diff --git a/README.md b/README.md index da36145..fd9a6a6 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # A native plotting widget for Textual apps -[Textual](https://www.textualize.io/) is an excellent Python framework for building applications in the terminal, or on the web. This library provides a plot widget which your app can use to plot all kinds of quantitative data. So, no pie charts, sorry. The widget support scatter plots and line plots, and can also draw using _high-resolution_ characters like unicode half blocks, quadrants and 8-dot Braille characters. It may still be apparent that these are drawn using characters that take up a full block in the terminal, especially when plot series overlap. However, the use of these characters can reduce the line thickness and improve the resolution tremendously. +[Textual](https://www.textualize.io/) is an excellent Python framework for building applications in the terminal, or on the web. This library provides a plot widget which your app can use to plot all kinds of quantitative data. So, no pie charts, sorry, but we do have a treemap! The widget support scatter plots and line plots, and can also draw using _high-resolution_ characters like unicode half blocks, quadrants and 8-dot Braille characters. It may still be apparent that these are drawn using characters that take up a full block in the terminal, especially when plot series overlap. However, the use of these characters can reduce the line thickness and improve the resolution tremendously. ## Screenshots diff --git a/pyproject.toml b/pyproject.toml index b0ecfdc..8704c73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,9 @@ license = "MIT" license-files = ["LICENSE"] requires-python = ">=3.10" dependencies = [ + "distinctipy>=1.3.0", "numpy>=2.2.1", + "squarify>=0.4.3", "textual>=1.0.0", "textual-hires-canvas>=0.14.0", ] diff --git a/src/textual_plot/__init__.py b/src/textual_plot/__init__.py index 5bfca09..e1e3239 100644 --- a/src/textual_plot/__init__.py +++ b/src/textual_plot/__init__.py @@ -3,7 +3,7 @@ DurationFormatter, NumericAxisFormatter, ) -from textual_plot.plot_widget import HiResMode, LegendLocation, PlotWidget +from textual_plot.plot_widget import HiResMode, LegendLocation, PlotWidget, ValueDisplay __all__ = [ "AxisFormatter", @@ -12,4 +12,5 @@ "LegendLocation", "NumericAxisFormatter", "PlotWidget", + "ValueDisplay", ] diff --git a/src/textual_plot/color_utils.py b/src/textual_plot/color_utils.py new file mode 100644 index 0000000..8f12458 --- /dev/null +++ b/src/textual_plot/color_utils.py @@ -0,0 +1,46 @@ +"""Color utilities for treemap and plot styling.""" + +from __future__ import annotations + +import colorsys +import re + +import distinctipy + + +def parse_style_to_rgb(style: str) -> tuple[float, float, float] | None: + """Parse rgb(r,g,b) style string to (r,g,b) tuple in 0-1 range.""" + m = re.match(r"rgb\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", style, re.I) + if m: + return (int(m.group(1)) / 255, int(m.group(2)) / 255, int(m.group(3)) / 255) + return None + + +def adjust_luminance( + rgb: tuple[float, float, float], factor: float +) -> tuple[float, float, float]: + """Adjust luminance of RGB (0-1) by factor. factor>1 lightens, factor<1 darkens.""" + h, lightness, s = colorsys.rgb_to_hls(rgb[0], rgb[1], rgb[2]) + lightness = max(0, min(1, lightness * factor)) + return colorsys.hls_to_rgb(h, lightness, s) + + +def tint_with_hue( + child_rgb: tuple[float, float, float], + parent_rgb: tuple[float, float, float], +) -> tuple[float, float, float]: + """Apply parent's hue to child color; keep child's saturation and value for distinction.""" + ph, ps, pv = colorsys.rgb_to_hsv(parent_rgb[0], parent_rgb[1], parent_rgb[2]) + ch, cs, cv = colorsys.rgb_to_hsv(child_rgb[0], child_rgb[1], child_rgb[2]) + r, g, b = colorsys.hsv_to_rgb(ph, cs, cv) + return (r, g, b) + + +def rgb_too_close(a: tuple[float, float, float], b: tuple[float, float, float]) -> bool: + """True if a and b are too similar (perceptual distance).""" + return distinctipy.color_distance(a, b) < 0.02 + + +def rgb_style(rgb: tuple[float, float, float]) -> str: + """Convert (r,g,b) 0-1 to rgb(r,g,b) style string.""" + return f"rgb({int(rgb[0] * 255)},{int(rgb[1] * 255)},{int(rgb[2] * 255)})" diff --git a/src/textual_plot/demo.py b/src/textual_plot/demo.py index 3998b50..52b384a 100644 --- a/src/textual_plot/demo.py +++ b/src/textual_plot/demo.py @@ -326,6 +326,117 @@ def action_cycle_hires_mode(self) -> None: self.plot() +class TreemapPlot(Container): + BINDINGS = [("h", "cycle_hires_mode", "HiRes")] + + _hires_mode = itertools.cycle([None, HiResMode.BRAILLE]) + hires_mode = next(_hires_mode) + + def compose(self) -> ComposeResult: + yield PlotWidget() + + def on_mount(self) -> None: + self.plot() + + def plot(self) -> None: + plot = self.query_one(PlotWidget) + plot.clear() + + values = [500, 433, 280, 195, 165, 120, 95, 78, 62, 48, 35, 28, 25, 18, 12] + labels = [ + "Electronics", + "Clothing", + "Home", + "Sports", + "Books", + "Toys", + "Garden", + "Auto", + "Health", + "Beauty", + "Office", + "Food", + "Pets", + "Music", + "Art", + ] + plot.treemap( + values, + labels=labels, + padding=1, + hires_mode=self.hires_mode, + ) + plot.show_legend() + + def action_cycle_hires_mode(self) -> None: + self.hires_mode = next(self._hires_mode) + self.plot() + + +class NestedTreemapPlot(Container): + """Nested treemap: click to zoom in, Escape to zoom out, arrows to select.""" + + BINDINGS = [("h", "cycle_hires_mode", "HiRes")] + + _hires_mode = itertools.cycle([None, HiResMode.BRAILLE]) + hires_mode = next(_hires_mode) + + def compose(self) -> ComposeResult: + yield PlotWidget() + + def on_mount(self) -> None: + self.plot() + + def plot(self) -> None: + plot = self.query_one(PlotWidget) + plot.clear() + # Nested: show_nested=True draws full hierarchy with luminance variance + plot.treemap( + [ + { + "label": "Electronics", + "children": [ + { + "label": "Phones", + "children": [ + {"label": "iPhone", "value": 55}, + {"label": "Android", "value": 35}, + {"label": "Other", "value": 10}, + ], + }, + {"label": "Laptops", "value": 50}, + {"label": "Tablets", "value": 25}, + {"label": "Accessories", "value": 25}, + ], + }, + { + "label": "Clothing", + "children": [ + {"label": "Shirts", "value": 80}, + {"label": "Pants", "value": 60}, + {"label": "Shoes", "value": 40}, + ], + }, + { + "label": "Home", + "children": [ + {"label": "Furniture", "value": 90}, + {"label": "Decor", "value": 45}, + {"label": "Kitchen", "value": 30}, + ], + }, + ], + padding=1, + hires_mode=self.hires_mode, + show_nested=True, + ) + plot.show_legend() + + def action_cycle_hires_mode(self) -> None: + self.hires_mode = next(self._hires_mode) + self.plot() + + class DemoApp(App[None]): AUTO_FOCUS = "SinePlot > PlotWidget" @@ -350,6 +461,10 @@ def compose(self) -> ComposeResult: yield ErrorBarPlot() with TabPane("Bar plot", id="barplot"): yield BarPlot() + with TabPane("Treemap", id="treemap"): + yield TreemapPlot() + with TabPane("Nested Treemap", id="nested_treemap"): + yield NestedTreemapPlot() def on_mount(self) -> None: self.theme = "tokyo-night" diff --git a/src/textual_plot/plot_widget.py b/src/textual_plot/plot_widget.py index 59abf8f..3a2ef29 100644 --- a/src/textual_plot/plot_widget.py +++ b/src/textual_plot/plot_widget.py @@ -13,8 +13,10 @@ from dataclasses import dataclass from math import ceil, floor from statistics import mean -from typing import Sequence, TypeAlias +from typing import Any, Sequence, TypeAlias +import distinctipy +import numpy as np from rich.text import Text from textual.binding import Binding @@ -22,7 +24,8 @@ from typing import Self else: from typing_extensions import Self -import numpy as np + +import squarify # type: ignore[import-untyped] from numpy.typing import ArrayLike, NDArray from textual import on from textual._box_drawing import BOX_CHARACTERS, combine_quads @@ -30,6 +33,8 @@ from textual.containers import Grid from textual.css.query import NoMatches from textual.events import ( + Click, + Key, MouseDown, MouseMove, MouseScrollDown, @@ -48,12 +53,26 @@ CategoricalAxisFormatter, NumericAxisFormatter, ) +from textual_plot.color_utils import parse_style_to_rgb, rgb_style +from textual_plot.treemap_utils import ( + TREEMAP_PARENT_LABEL_TOP, + format_treemap_nested_path, + get_path_styles, + get_treemap_level, + get_treemap_node_at_path, + normalize_treemap_tree, + squarify_recursive, + treemap_max_depth, +) -__all__ = ["HiResMode", "LegendLocation", "PlotWidget"] +__all__ = ["HiResMode", "LegendLocation", "PlotWidget", "ValueDisplay"] FloatScalar: TypeAlias = float | np.floating FloatArray: TypeAlias = NDArray[np.floating] +# Nested treemap: each node is {"value": float, "label": str, "children": list | None} +TreemapNode: TypeAlias = dict[str, Any] + LEGEND_LINE = { None: "███", @@ -76,6 +95,7 @@ class LegendLocation(enum.Enum): TOPRIGHT = enum.auto() BOTTOMLEFT = enum.auto() BOTTOMRIGHT = enum.auto() + LEFT = enum.auto() # Y-axis label area (left margin); used for treemap default @dataclass @@ -143,6 +163,46 @@ class BarPlot(DataSet): bar_style: str | list[str] +class ValueDisplay(enum.Enum): + """What to show for treemap rectangle values in labels.""" + + VALUE = "value" # Raw value + PERCENT = "percent" # Percent of total + BOTH = "both" # Value and percent + CURRENCY = "currency" # Currency format (2 decimals, configurable symbol) + NONE = "none" # No value/percent line + + +@dataclass +class TreemapPlot: + """A dataset for rendering as a treemap. + + Attributes: + values: Array of numeric values (sizes) for each rectangle. + labels: Optional labels for each rectangle (for legend). + styles: Rich style string or list of styles for each rectangle. + padding: Pixel padding between rectangles and from edges. + hires_mode: HiResMode for high-resolution rendering or None. + aspect_preference: Bias toward wider (>1) or taller (<1) rectangles. + value_display: What to show on the second line of labels. + currency_symbol: Symbol for CURRENCY display mode (e.g. "$", "€"). + tree: Normalized tree structure for nested treemaps (when is_nested). + show_nested: When True and tree is nested, draw full hierarchy with + luminance variance (children darker) instead of zoom-in/zoom-out. + """ + + values: FloatArray + labels: list[str] | None + styles: str | list[str] + padding: int + hires_mode: HiResMode | None + aspect_preference: float + value_display: ValueDisplay + currency_symbol: str + tree: list[TreemapNode] | None = None + show_nested: bool = False + + @dataclass class VLinePlot: """A vertical line to be drawn on the plot. @@ -200,7 +260,7 @@ class ScaleChanged(Message): DEFAULT_CSS = """ PlotWidget { - layers: plot legend; + layers: plot legend info; &:focus > .plot--axis { color: $primary; @@ -229,6 +289,14 @@ class ScaleChanged(Message): } } + #info { + layer: info; + width: auto; + border-top: solid $secondary; + padding: 0 2; + display: none; + } + #legend { layer: legend; width: auto; @@ -266,6 +334,8 @@ class ScaleChanged(Message): Binding("up", "pan_up", "Pan up", group=PAN_GROUP), Binding("down", "pan_down", "Pan down", group=PAN_GROUP), ("r", "reset_scales", "Reset scales"), + Binding("escape", "treemap_zoom_out", "Treemap zoom out", show=False), + ("l", "toggle_legend", "Toggle legend"), ] margin_top = reactive(2) @@ -276,7 +346,7 @@ class ScaleChanged(Message): KEYBOARD_ZOOM_FACTOR: float = 0.15 KEYBOARD_PAN_FACTOR: float = 2.0 - _datasets: list[DataSet] + _datasets: list[DataSet | TreemapPlot] _labels: list[str | None] _user_x_min: float | None = None @@ -306,6 +376,7 @@ class ScaleChanged(Message): _allow_pan_and_zoom: bool = True _is_dragging_legend: bool = False + _legend_visible: bool = False _needs_rerender: bool = False _needs_canvas_resize: bool = False @@ -341,6 +412,10 @@ def __init__( self._labels = [] self._v_lines: list[VLinePlot] = [] self._v_lines_labels: list[str | None] = [] + self._treemap_hover_rects: list[dict] = [] + self._treemap_selected_rect: dict | None = None + self._treemap_path: list[int] = [] + self._treemap_selected_index: int | None = None self._allow_pan_and_zoom = allow_pan_and_zoom self.invert_mouse_wheel = invert_mouse_wheel self._x_formatter = NumericAxisFormatter() @@ -358,6 +433,7 @@ def compose(self) -> ComposeResult: yield Canvas(1, 1, id="plot") yield Canvas(1, 1, id="margin-bottom") yield Legend(id="legend") + yield Static("", id="info") def on_mount(self) -> None: """Initialize the plot widget when mounted to the DOM.""" @@ -394,6 +470,11 @@ def clear(self) -> None: self._labels = [] self._v_lines = [] self._v_lines_labels = [] + self._treemap_hover_rects = [] + self._treemap_selected_rect = None + self._treemap_path = [] + self._treemap_selected_index = None + self._update_info(None) self._update_legend() self._rerender() @@ -598,6 +679,102 @@ def bar( self._update_legend() self._rerender() + def treemap( + self, + values: ArrayLike, + labels: Sequence[str] | None = None, + styles: str | Sequence[str] | None = None, + padding: int = 1, + hires_mode: HiResMode | None = None, + label: str | None = None, + aspect_preference: float = 1.5, + value_display: ValueDisplay | str = ValueDisplay.BOTH, + currency_symbol: str = "$", + show_nested: bool = False, + ) -> None: + """Graph data as a treemap. + + Treemaps display hierarchical or flat data as nested rectangles, with + each rectangle's area proportional to its value. Uses the squarify + algorithm for layout. + + Args: + values: An ArrayLike with the numeric values (sizes) for each + rectangle. Supports flat lists for single-level treemaps. + labels: Optional sequence of labels for each rectangle (for legend). + Defaults to None. + styles: A style string for all rectangles or a sequence of styles + for each. Defaults to a color cycle if None. + padding: Pixel padding between rectangles and from edges. + Defaults to 1. + hires_mode: A HiResMode enum or None for standard rendering. + Defaults to None. + label: A string with the label for the dataset in the legend. + Defaults to None. + aspect_preference: Bias toward wider (>1) or taller (<1) rectangles. + Values >1 prefer wider rectangles (better for text labels). + Defaults to 1.5. + value_display: What to show on the second line of labels: "value", + "percent", "both", "currency", or "none". Defaults to "both". + currency_symbol: Symbol for currency display mode (e.g. "$", "€"). + Defaults to "$". + show_nested: When True and data is nested, draw the full hierarchy + at once with luminance variance (children darker) instead of + zoom-in/zoom-out. Defaults to False. + """ + vd = ( + ValueDisplay(value_display) + if isinstance(value_display, str) + else value_display + ) + tree_nodes, is_nested = normalize_treemap_tree(values, labels) + if not tree_nodes: + return + + values_array = np.array([n["value"] for n in tree_nodes], dtype=float) + labels_list = [n["label"] for n in tree_nodes] + + # Default colors: use distinctipy for perceptually distinct CIELAB-inspired palette + n = len(values_array) + if styles is None: + distinct_colors = distinctipy.get_colors( + n, + pastel_factor=0.2, + rng=42, + # colorblind_type="Deuteranomaly", + ) # deterministic, slightly pastel for terminal contrast + styles_list = [rgb_style(c) for c in distinct_colors] + elif isinstance(styles, str): + styles_list = [styles] * n + else: + styles_list = list(styles) + if len(styles_list) < n: + extra = distinctipy.get_colors( + n - len(styles_list), + pastel_factor=0.2, + rng=42, + # colorblind_type="Deuteranomaly", + ) + styles_list.extend(rgb_style(c) for c in extra) + + self._datasets.append( + TreemapPlot( + values=values_array, + labels=labels_list, + styles=styles_list, + padding=padding, + hires_mode=hires_mode, + aspect_preference=aspect_preference, + value_display=vd, + currency_symbol=currency_symbol, + tree=tree_nodes if is_nested else None, + show_nested=show_nested and is_nested, + ) + ) + self._labels.append(label) + self._update_legend() + self._rerender() + def add_v_line( self, x: float, line_style: str = "white", label: str | None = None ) -> None: @@ -720,10 +897,76 @@ def show_legend( raise TypeError( f"Expected LegendLocation, got {type(location).__name__} instead." ) + elif ( + is_visible + and self._datasets + and all(isinstance(d, TreemapPlot) for d in self._datasets) + ): + self._legend_location = LegendLocation.LEFT + self._legend_relative_offset = Offset(0, 0) + self._legend_visible = is_visible self.query_one("#legend", Static).display = is_visible if is_visible: self._update_legend() + def action_toggle_legend(self) -> None: + """Toggle legend visibility (collapse/expand).""" + self._legend_visible = not self._legend_visible + self.show_legend(is_visible=self._legend_visible) + self._rerender() + + def _update_info(self, rect_info: dict | None) -> None: + """Update the plot info box with rectangle data or hide it.""" + try: + info_box = self.query_one("#info", Static) + except NoMatches: + return + if rect_info is None: + info_box.display = False + return + total = rect_info["total"] + value = rect_info["value"] + pct = 100 * value / total if total else 0 + value_display = rect_info.get("value_display", ValueDisplay.BOTH) + currency_symbol = rect_info.get("currency_symbol", "$") + if value_display == ValueDisplay.CURRENCY: + value_str = f"{currency_symbol}{value:,.2f}" + else: + value_str = f"{value:,.0f}" if value == int(value) else f"{value:,.1f}" + style = rect_info.get("style", "") + path = rect_info.get("path") + tree = rect_info.get("tree") + path_to_style = rect_info.get("path_to_style") + if path is not None and tree is not None and len(path) > 0 and path_to_style: + segments, _ = format_treemap_nested_path( + tree, path, total, value_display, currency_symbol + ) + path_styles = get_path_styles(path, path_to_style) + title = Text() + for i, (seg_text, _) in enumerate(segments): + seg_style = path_styles[i] if i < len(path_styles) else style + title.append(" ███ ", style=seg_style) + title.append(f" {seg_text}") + if i < len(segments) - 1: + title.append(" ▶ ") + else: + title = Text("███") + title.stylize(style) + content = ( + f"{rect_info['label']} · Value: {value_str} · Percent: {pct:.1f}%" + ) + title.append(f" {content}") + info_box.update(title) + info_box.display = True + # Position in margin-bottom (x-axis label) area - unused for treemap + canvas = self.query_one("#plot", Canvas) + if canvas.size: + info_box.offset = Offset( + self.margin_left + 1, + self.margin_top + canvas.size.height, + ) + info_box.refresh(layout=True) + def _update_legend(self) -> None: """Update the content and position of the plot legend.""" legend = self.query_one("#legend", Static) @@ -732,6 +975,20 @@ def _update_legend(self) -> None: legend_lines = [] for label, dataset in zip(self._labels, self._datasets): + # Treemap with per-rectangle labels: show them even if dataset label is None + if isinstance(dataset, TreemapPlot) and dataset.labels: + rect_styles = ( + dataset.styles + if isinstance(dataset.styles, list) + else [dataset.styles] * len(dataset.labels) + ) + for rect_label, rect_style in zip(dataset.labels, rect_styles): + text = Text("███") + text.stylize(rect_style) + text.append(f" {rect_label}") + legend_lines.append(text.markup) + continue + if label is not None: if isinstance(dataset, LinePlot): marker = LEGEND_LINE[dataset.hires_mode] @@ -758,6 +1015,13 @@ def _update_legend(self) -> None: else LEGEND_MARKER[dataset.hires_mode] ).center(3) style = dataset.marker_style + elif isinstance(dataset, TreemapPlot): + marker = "███" + style = ( + dataset.styles[0] + if isinstance(dataset.styles, list) + else dataset.styles + ) else: # unsupported dataset type continue @@ -830,11 +1094,17 @@ def _get_legend_origin_coordinates(self, location: LegendLocation) -> Offset: all_labels.extend( [label for label in self._v_lines_labels if label is not None] ) + for dataset in self._datasets: + if isinstance(dataset, TreemapPlot) and dataset.labels: + all_labels.extend(dataset.labels) # markers and lines in the legend are 3 characters wide, plus a space, so 4 max_length = 4 + max((len(s) for s in all_labels), default=0) - if location in (LegendLocation.TOPLEFT, LegendLocation.BOTTOMLEFT): + if location == LegendLocation.LEFT: + x0 = 1 + y0 = self.margin_top + 1 + elif location in (LegendLocation.TOPLEFT, LegendLocation.BOTTOMLEFT): x0 = self.margin_left + 1 else: # LegendLocation is TOPRIGHT or BOTTOMRIGHT @@ -842,7 +1112,9 @@ def _get_legend_origin_coordinates(self, location: LegendLocation) -> Offset: # leave room for the border x0 -= legend.styles.border.spacing.left + legend.styles.border.spacing.right - if location in (LegendLocation.TOPLEFT, LegendLocation.TOPRIGHT): + if location == LegendLocation.LEFT: + pass # y0 already set + elif location in (LegendLocation.TOPLEFT, LegendLocation.TOPRIGHT): y0 = self.margin_top + 1 else: # LegendLocation is BOTTOMLEFT or BOTTOMRIGHT @@ -918,13 +1190,20 @@ def _render_plot(self) -> None: # clear canvas canvas.reset() - # determine axis limits - if self._datasets or self._v_lines: + # Clear treemap hover rects; treemap render will repopulate if applicable + self._treemap_hover_rects = [] + self._treemap_selected_rect = None + # Preserve _treemap_path and _treemap_selected_index for nested zoom state + self._update_info(None) + + # determine axis limits (skip TreemapPlot - it uses pixel coordinates) + coord_datasets = [d for d in self._datasets if not isinstance(d, TreemapPlot)] + if coord_datasets or self._v_lines: xs = [] ys = [] # Collect x and y values, accounting for bar widths - for dataset in self._datasets: + for dataset in coord_datasets: if isinstance(dataset, BarPlot): # For bar plots, include the left and right edges x_left = dataset.x - dataset.width / 2 @@ -975,6 +1254,8 @@ def _render_plot(self) -> None: self._render_bar_plot(dataset) elif isinstance(dataset, ScatterPlot): self._render_scatter_plot(dataset) + elif isinstance(dataset, TreemapPlot): + self._render_treemap_plot(dataset) # render axis, ticks and labels canvas.draw_rectangle_box( @@ -1209,6 +1490,485 @@ def _render_bar_plot(self, dataset: BarPlot) -> None: x1, y1 = self.get_pixel_from_coordinate(x_right, y_bottom) canvas.draw_filled_rectangle(x0, y0, x1, y1, style=style) + def _render_treemap_plot(self, dataset: TreemapPlot) -> None: + """Render a treemap dataset on the canvas. + + Uses squarify for layout. Rectangles are drawn in canvas coordinates + within the scale rectangle, with optional padding. For nested data, + renders the current level based on _treemap_path, or the full hierarchy + with luminance variance when show_nested=True. + """ + canvas = self.query_one("#plot", Canvas) + sr = self._scale_rectangle + if sr.width <= 0 or sr.height <= 0: + return + + # Nested full-hierarchy mode: draw all levels with luminance variance + if getattr(dataset, "show_nested", False) and dataset.tree is not None: + self._render_treemap_nested(canvas, dataset, sr) + return + + # Get current level: from tree if nested, else from dataset + if dataset.tree is not None: + level_nodes = get_treemap_level(dataset.tree, self._treemap_path) + if not level_nodes: + return + values_at_level = np.array([n["value"] for n in level_nodes]) + labels_at_level = [n["label"] for n in level_nodes] + has_children_list = [n.get("children") is not None for n in level_nodes] + else: + level_nodes = None + values_at_level = dataset.values + labels_at_level = dataset.labels or [ + f"Item {i + 1}" for i in range(len(dataset.values)) + ] + has_children_list = [False] * len(values_at_level) + + n = len(values_at_level) + if n == 0: + return + + pad = max(0, dataset.padding) + effective_width = max(1, sr.width - 2 * pad) + effective_height = max(1, sr.height - 2 * pad) + + aspect = max(0.25, min(4.0, dataset.aspect_preference)) + layout_width = effective_width / aspect + layout_height = effective_height + + normalized = squarify.normalize_sizes( + values_at_level.tolist(), layout_width, layout_height + ) + rects = squarify.squarify(normalized, 0, 0, layout_width, layout_height) + + ox = sr.x + pad + oy = sr.y + pad + + # Styles for current level (may differ from dataset.styles when zoomed) + if dataset.tree is not None and n != len(dataset.styles): + distinct_colors = distinctipy.get_colors( + n, + pastel_factor=0.2, + rng=42, + # colorblind_type="Deuteranomaly", + ) + styles_at_level = [rgb_style(c) for c in distinct_colors] + else: + styles_at_level = ( + dataset.styles + if isinstance(dataset.styles, list) + else [dataset.styles] * n + ) + if len(styles_at_level) < n: + extra = distinctipy.get_colors( + n - len(styles_at_level), + pastel_factor=0.2, + rng=42, + # colorblind_type="Deuteranomaly", + ) + styles_at_level.extend(rgb_style(c) for c in extra) + + total_value = float(np.sum(values_at_level)) + hover_rects: list[dict] = [] + + selected_idx = self._treemap_selected_index + if selected_idx is not None and (selected_idx < 0 or selected_idx >= n): + selected_idx = None + + for i, rect in enumerate(rects): + style = ( + styles_at_level[i] if i < len(styles_at_level) else styles_at_level[0] + ) + assert isinstance(style, str) + + # Scale from layout space to actual space + rx = rect["x"] * aspect + ry = rect["y"] + rdx = rect["dx"] * aspect + rdy = rect["dy"] + + # Convert to integers - squarify returns floats but canvas expects int pixels + x0 = int(ox + rx) + y0 = int(oy + ry) + x1 = int(x0 + rdx) + y1 = int(y0 + rdy) + + # Ensure at least 1px size for visibility + if x1 <= x0: + x1 = x0 + 1 + if y1 <= y0: + y1 = y0 + 1 + + label_str = ( + labels_at_level[i] if i < len(labels_at_level) else f"Item {i + 1}" + ) + hover_rects.append( + { + "x0": x0, + "y0": y0, + "x1": x1, + "y1": y1, + "label": label_str, + "value": float(values_at_level[i]), + "total": total_value, + "style": style, + "value_display": dataset.value_display, + "currency_symbol": dataset.currency_symbol, + "has_children": has_children_list[i] + if i < len(has_children_list) + else False, + "node_index": i, + } + ) + + # Selection fill: hires mode uses solid for selected; otherwise braille + if selected_idx == i: + if dataset.hires_mode: + canvas.draw_filled_rectangle(x0, y0, x1, y1, style=style) + else: + canvas.draw_filled_hires_rectangle( + float(x0), + float(y0), + float(x1), + float(y1), + hires_mode=HiResMode.BRAILLE, + style=style, + ) + elif dataset.hires_mode: + canvas.draw_filled_hires_rectangle( + float(x0), + float(y0), + float(x1), + float(y1), + hires_mode=dataset.hires_mode, + style=style, + ) + else: + canvas.draw_filled_rectangle(x0, y0, x1, y1, style=style) + + # Draw label on rectangle if provided and rect is large enough + if i < len(labels_at_level): + rect_w = x1 - x0 + rect_h = y1 - y0 + value = values_at_level[i] + pct = 100 * value / total_value if total_value else 0 + value_str = f"{value:,.0f}" if value == int(value) else f"{value:,.1f}" + currency_str = f"{dataset.currency_symbol}{value:,.2f}" + pct_str = f"{pct:.1f}%" + # Build second line based on value_display + if dataset.value_display == ValueDisplay.VALUE: + line2 = value_str + elif dataset.value_display == ValueDisplay.PERCENT: + line2 = pct_str + elif dataset.value_display == ValueDisplay.BOTH: + line2 = f"{value_str} ({pct_str})" + elif dataset.value_display == ValueDisplay.CURRENCY: + line2 = currency_str + else: + line2 = None + needs_two_lines = line2 is not None and rect_h >= 2 + if rect_w >= 3 and rect_h >= (2 if needs_two_lines else 1): + label_line1 = labels_at_level[i] + max_len = max(1, rect_w - 2) + if len(label_line1) > max_len: + label_line1 = label_line1[: max_len - 1] + "…" + if needs_two_lines and len(line2) > max_len: + line2 = line2[: max_len - 1] + "…" + cx = (x0 + x1) // 2 + cy = (y0 + y1) // 2 + bg_rgb = parse_style_to_rgb(style) + if bg_rgb is not None: + text_color = distinctipy.get_text_color(bg_rgb) + fg = "black" if text_color == (0, 0, 0) else "white" + else: + fg = "white" + style_str = ( + f"bold {fg} on {style}" + if selected_idx == i + else f"{fg} on {style}" + ) + if needs_two_lines: + canvas.write_text( + cx, + cy - 1, + f"[{style_str}]{label_line1}", + align=TextAlign.CENTER, + ) + canvas.write_text( + cx, cy, f"[{style_str}]{line2}", align=TextAlign.CENTER + ) + else: + canvas.write_text( + cx, + cy, + f"[{style_str}]{label_line1}", + align=TextAlign.CENTER, + ) + + self._treemap_hover_rects = hover_rects + # Restore selected rect from new hover_rects for info box + if selected_idx is not None and selected_idx < len(hover_rects): + self._treemap_selected_rect = hover_rects[selected_idx] + self._update_info(self._treemap_selected_rect) + + def _render_treemap_nested( + self, canvas: Canvas, dataset: TreemapPlot, sr: Region + ) -> None: + """Render full nested hierarchy with luminance variance (children darker). + + When _treemap_path is non-empty, renders the subtree at that path as the root + (zoom-in view). Double-click a parent to zoom in; Escape to zoom out. + """ + pad = max(0, dataset.padding) + effective_width = max(1, sr.width - 2 * pad) + effective_height = max(1, sr.height - 2 * pad) + aspect = max(0.25, min(4.0, dataset.aspect_preference)) + ox = sr.x + pad + oy = sr.y + pad + + # When zoomed in, use subtree at _treemap_path as root + base_path = list(self._treemap_path) + if base_path: + node = get_treemap_node_at_path(dataset.tree, base_path) + if node is None: + self._treemap_path.clear() + base_path = [] + tree_to_render = dataset.tree + else: + tree_to_render = [node] + else: + tree_to_render = dataset.tree + + # Pre-compute full tree's path_to_style so zoomed view keeps colors consistent + n_top_full = len(dataset.tree) + if isinstance(dataset.styles, list) and len(dataset.styles) >= n_top_full: + base_styles_full = list(dataset.styles[:n_top_full]) + else: + distinct_full = distinctipy.get_colors( + n_top_full, + pastel_factor=0.2, + rng=42, + colorblind_type="Deuteranomaly", + ) + base_styles_full = [rgb_style(c) for c in distinct_full] + layout_w_full = effective_width / aspect + extra_rows_full = 3 * max(0, treemap_max_depth(dataset.tree) - 1) + layout_h_full = max(1, effective_height - extra_rows_full) + full_rect_infos = squarify_recursive( + dataset.tree, + 0, + 0, + layout_w_full, + layout_h_full, + aspect, + [], + base_styles_full, + [], + None, + ) + full_path_to_style = {tuple(r["path"]): r["style"] for r in full_rect_infos} + + n_top = len(tree_to_render) + if isinstance(dataset.styles, list) and len(dataset.styles) >= n_top: + base_styles = list(dataset.styles[:n_top]) + else: + base_styles = [ + full_path_to_style.get(tuple(base_path + [i]), base_styles_full[0]) + for i in range(n_top) + ] + if not base_styles: + distinct = distinctipy.get_colors( + n_top, + pastel_factor=0.2, + rng=42, + colorblind_type="Deuteranomaly", + ) + base_styles = [rgb_style(c) for c in distinct] + + layout_w = effective_width / aspect + extra_rows = 3 * max(0, treemap_max_depth(tree_to_render) - 1) + layout_h = max(1, effective_height - extra_rows) + if not base_path: + rect_infos = full_rect_infos + path_to_style = full_path_to_style + else: + rect_infos = squarify_recursive( + tree_to_render, + 0, + 0, + layout_w, + layout_h, + aspect, + [], + base_styles, + [], + None, + ) + + # Map zoom-relative paths [0], [0,0], [0,1] to full paths base_path, base_path+[0], ... + def _full_path(p: list[int]) -> tuple[int, ...]: + if p and p[0] == 0: + return tuple(base_path + p[1:]) + return tuple(base_path + p) + + for r in rect_infos: + fp = _full_path(r["path"]) + if fp in full_path_to_style: + r["style"] = full_path_to_style[fp] + path_to_style = {_full_path(r["path"]): r["style"] for r in rect_infos} + + total_value = sum(r["value"] for r in rect_infos) + hover_rects: list[dict] = [] + selected_idx = self._treemap_selected_index + if selected_idx is not None and ( + selected_idx < 0 or selected_idx >= len(rect_infos) + ): + selected_idx = None + if selected_idx is None and rect_infos: + selected_idx = 0 + self._treemap_selected_index = 0 + + for i, info in enumerate(rect_infos): + rx, ry = info["x"], info["y"] + rdx, rdy = info["dx"], info["dy"] + x0 = int(ox + rx * aspect) + y0 = int(oy + ry) + x1 = int(x0 + rdx * aspect) + y1 = int(y0 + rdy) + if x1 <= x0: + x1 = x0 + 1 + if y1 <= y0: + y1 = y0 + 1 + + style = info["style"] + if base_path and info["path"] and info["path"][0] == 0: + full_path = base_path + info["path"][1:] + else: + full_path = base_path + info["path"] + rect_info = { + "x0": x0, + "y0": y0, + "x1": x1, + "y1": y1, + "label": info["label"], + "value": info["value"], + "total": total_value, + "style": style, + "value_display": dataset.value_display, + "currency_symbol": dataset.currency_symbol, + "has_children": info.get("has_children", False), + "node_index": i, + "path": full_path, + "tree": dataset.tree, + "path_to_style": path_to_style, + } + hover_rects.append(rect_info) + + draw_style = style + # Selected rect: hires mode uses solid fill; otherwise use braille for distinction + if selected_idx == i: + if dataset.hires_mode: + canvas.draw_filled_rectangle(x0, y0, x1, y1, style=draw_style) + else: + canvas.draw_filled_hires_rectangle( + float(x0), + float(y0), + float(x1), + float(y1), + hires_mode=HiResMode.BRAILLE, + style=draw_style, + ) + elif dataset.hires_mode: + canvas.draw_filled_hires_rectangle( + float(x0), + float(y0), + float(x1), + float(y1), + hires_mode=dataset.hires_mode, + style=draw_style, + ) + else: + canvas.draw_filled_rectangle(x0, y0, x1, y1, style=draw_style) + + rect_w = x1 - x0 + rect_h = y1 - y0 + value = info["value"] + pct = 100 * value / total_value if total_value else 0 + value_str = f"{value:,.0f}" if value == int(value) else f"{value:,.1f}" + currency_str = f"{dataset.currency_symbol}{value:,.2f}" + pct_str = f"{pct:.1f}%" + if dataset.value_display == ValueDisplay.VALUE: + line2 = value_str + elif dataset.value_display == ValueDisplay.PERCENT: + line2 = pct_str + elif dataset.value_display == ValueDisplay.BOTH: + line2 = f"{value_str} ({pct_str})" + elif dataset.value_display == ValueDisplay.CURRENCY: + line2 = currency_str + else: + line2 = None + needs_two_lines = line2 is not None and rect_h >= 2 + is_parent = info.get("has_children", False) + if rect_w >= 3: + cx = (x0 + x1) // 2 + bg_rgb = parse_style_to_rgb(draw_style) + if bg_rgb is not None: + text_color = distinctipy.get_text_color(bg_rgb) + fg = "black" if text_color == (0, 0, 0) else "white" + else: + fg = "white" + style_str = ( + f"bold {fg} on {draw_style}" + if selected_idx == i + else f"{fg} on {draw_style}" + ) + max_len = max(1, rect_w - 2) + if is_parent: + # Parent: label only, no values, in center row of top 3 + label_line1 = info["label"] + if len(label_line1) > max_len: + label_line1 = label_line1[: max_len - 1] + "…" + cy = y0 + 1 # center of top 3 rows + if rect_h >= TREEMAP_PARENT_LABEL_TOP: + canvas.write_text( + cx, + cy, + f"[{style_str}]{label_line1}", + align=TextAlign.CENTER, + ) + elif rect_h >= (2 if needs_two_lines else 1): + # Leaf: label and value + label_line1 = info["label"] + if len(label_line1) > max_len: + label_line1 = label_line1[: max_len - 1] + "…" + if needs_two_lines and line2 and len(line2) > max_len: + line2 = line2[: max_len - 1] + "…" + cy = (y0 + y1) // 2 + if needs_two_lines and line2: + canvas.write_text( + cx, + cy - 1, + f"[{style_str}]{label_line1}", + align=TextAlign.CENTER, + ) + canvas.write_text( + cx, cy, f"[{style_str}]{line2}", align=TextAlign.CENTER + ) + else: + canvas.write_text( + cx, + cy, + f"[{style_str}]{label_line1}", + align=TextAlign.CENTER, + ) + + self._treemap_hover_rects = hover_rects + if selected_idx is not None and selected_idx < len(hover_rects): + self._treemap_selected_rect = hover_rects[selected_idx] + else: + self._treemap_selected_rect = None + self._update_info(self._treemap_selected_rect) + def _render_v_line_plot(self, vline: VLinePlot) -> None: """Render a vertical line on the canvas. @@ -1232,8 +1992,13 @@ def _render_v_line_plot(self, vline: VLinePlot) -> None: def _render_x_ticks(self) -> None: """Render tick marks and labels for the x-axis.""" - canvas = self.query_one("#plot", Canvas) bottom_margin = self.query_one("#margin-bottom", Canvas) + # Hide ticks when plot contains only treemaps (no numeric axes) + if self._datasets and all(isinstance(d, TreemapPlot) for d in self._datasets): + bottom_margin.reset() + return + + canvas = self.query_one("#plot", Canvas) bottom_margin.reset() x_ticks: Sequence[float] @@ -1277,8 +2042,13 @@ def _render_x_ticks(self) -> None: def _render_y_ticks(self) -> None: """Render tick marks and labels for the y-axis.""" - canvas = self.query_one("#plot", Canvas) left_margin = self.query_one("#margin-left", Canvas) + # Hide ticks when plot contains only treemaps (no numeric axes) + if self._datasets and all(isinstance(d, TreemapPlot) for d in self._datasets): + left_margin.reset() + return + + canvas = self.query_one("#plot", Canvas) left_margin.reset() y_ticks: Sequence[float] @@ -1320,6 +2090,8 @@ def _render_y_ticks(self) -> None: def _render_x_label(self) -> None: """Render the x-axis label.""" + if self._datasets and all(isinstance(d, TreemapPlot) for d in self._datasets): + return canvas = self.query_one("#plot", Canvas) margin = self.query_one("#margin-bottom", Canvas) margin.write_text( @@ -1331,6 +2103,8 @@ def _render_x_label(self) -> None: def _render_y_label(self) -> None: """Render the y-axis label.""" + if self._datasets and all(isinstance(d, TreemapPlot) for d in self._datasets): + return margin = self.query_one("#margin-top", Canvas) margin.write_text( self.margin_left - 2, @@ -1503,30 +2277,38 @@ def _zoom( ) self._rerender() - @on(MouseScrollDown) - def zoom_in(self, event: MouseScrollDown) -> None: - """Zoom into the plot when scrolling down. + @on(MouseScrollUp) + def zoom_in(self, event: MouseScrollUp) -> None: + """Zoom into the plot when scrolling up. Args: - event: The mouse scroll down event. + event: The mouse scroll up event. """ event.stop() + if self._is_treemap_only_nested() and self._handle_treemap_key("plus"): + return self._zoom_with_mouse(event, self.MOUSE_ZOOM_FACTOR) - @on(MouseScrollUp) - def zoom_out(self, event: MouseScrollUp) -> None: - """Zoom out of the plot when scrolling up. + @on(MouseScrollDown) + def zoom_out(self, event: MouseScrollDown) -> None: + """Zoom out of the plot when scrolling down. Args: - event: The mouse scroll up event. + event: The mouse scroll down event. """ event.stop() + if self._is_treemap_only_nested() and self._handle_treemap_key("minus"): + return self._zoom_with_mouse(event, -self.MOUSE_ZOOM_FACTOR) def action_zoom_in(self) -> None: + if self._is_treemap_only_nested() and self._handle_treemap_key("plus"): + return self._zoom_with_keyboard(self.KEYBOARD_ZOOM_FACTOR) def action_zoom_out(self) -> None: + if self._is_treemap_only_nested() and self._handle_treemap_key("minus"): + return self._zoom_with_keyboard(-self.KEYBOARD_ZOOM_FACTOR) def action_zoom_x_in(self) -> None: @@ -1561,6 +2343,150 @@ def action_pan_down(self) -> None: """Pan the plot downward.""" self._pan(0, -self.KEYBOARD_PAN_FACTOR) + def _is_treemap_only_nested(self) -> bool: + """True if plot has only treemap datasets with nested (zoomable) data.""" + if not self._datasets or not all( + isinstance(d, TreemapPlot) for d in self._datasets + ): + return False + return any( + d.tree is not None for d in self._datasets if isinstance(d, TreemapPlot) + ) + + def _is_treemap_show_nested(self) -> bool: + """True if plot has a treemap with show_nested=True (full hierarchy view).""" + return any( + getattr(d, "show_nested", False) + for d in self._datasets + if isinstance(d, TreemapPlot) + ) + + def _handle_treemap_key(self, key: str) -> bool: + """Handle treemap-specific keys. Returns True if key was handled.""" + if not self._is_treemap_only_nested(): + return False + dataset = next( + (d for d in self._datasets if isinstance(d, TreemapPlot) and d.tree), None + ) + if dataset is None: + return False + + # In show_nested mode, use hover_rects count (all leaves); else use level count + if self._is_treemap_show_nested() and self._treemap_hover_rects: + n = len(self._treemap_hover_rects) + else: + n = len(get_treemap_level(dataset.tree, self._treemap_path)) + if n == 0: + return False + + if key == "escape": + if self._treemap_path: + self._treemap_path.pop() + level = get_treemap_level(dataset.tree, self._treemap_path) + self._treemap_selected_index = min( + self._treemap_selected_index or 0, + len(level) - 1, + ) + self._rerender() + return True + + if key in ("plus", "equal", "+"): + idx = ( + self._treemap_selected_index + if self._treemap_selected_index is not None + else 0 + ) + if self._treemap_hover_rects and idx < len(self._treemap_hover_rects): + rect = self._treemap_hover_rects[idx] + if rect.get("has_children"): + if self._is_treemap_show_nested(): + path = rect.get("path", []) + seg = ( + path[len(self._treemap_path)] + if len(path) > len(self._treemap_path) + else None + ) + if seg is not None: + self._treemap_path.append(seg) + self._treemap_selected_index = 0 + self._rerender() + else: + self._treemap_path.append(rect["node_index"]) + self._treemap_selected_index = 0 + self._rerender() + return True + + if key in ("minus", "-"): + if self._treemap_path: + self._treemap_path.pop() + level = get_treemap_level(dataset.tree, self._treemap_path) + self._treemap_selected_index = min( + self._treemap_selected_index or 0, + len(level) - 1, + ) + self._rerender() + return True + + if key in ("left", "right", "up", "down"): + idx = self._treemap_selected_index + if idx is None: + idx = 0 + if key == "left": + idx = max(0, idx - 1) + elif key == "right": + idx = min(n - 1, idx + 1) + elif key == "up": + idx = max(0, idx - 1) + elif key == "down": + idx = min(n - 1, idx + 1) + self._treemap_selected_index = idx + if self._treemap_hover_rects and idx < len(self._treemap_hover_rects): + self._treemap_selected_rect = self._treemap_hover_rects[idx] + self._update_info(self._treemap_selected_rect) + self._rerender() + return True + + return False + + @on(Key) + def _on_key(self, event: Key) -> None: + """Handle keys; treemap-specific handling takes precedence when applicable.""" + if self._handle_treemap_key(event.key): + event.stop() + + def action_treemap_zoom_out(self) -> None: + """Zoom out of nested treemap (bound to Escape).""" + self._handle_treemap_key("escape") + + def treemap_zoom_in(self, path: list[int] | int) -> None: + """Zoom into a nested treemap node. + + Args: + path: Index or list of indices into the tree. E.g. 0 zooms to first + top-level node; [0, 1] zooms to second child of first top-level. + """ + if not self._is_treemap_only_nested(): + return + path_list = [path] if isinstance(path, int) else list(path) + if not path_list: + return + dataset = next( + (d for d in self._datasets if isinstance(d, TreemapPlot) and d.tree), + None, + ) + if dataset is None: + return + node = get_treemap_node_at_path(dataset.tree, path_list) + if node is None or not node.get("children"): + return + self._treemap_path = path_list + self._treemap_selected_index = 0 + self._rerender() + + def treemap_zoom_out(self) -> None: + """Zoom out one level in nested treemap.""" + self._handle_treemap_key("escape") + @on(MouseDown) def start_dragging_legend(self, event: MouseDown) -> None: """Start dragging the legend when clicked with left mouse button. @@ -1586,6 +2512,87 @@ def stop_dragging_legend(self, event: MouseUp) -> None: self.query_one("#legend").remove_class("dragged") event.stop() + def _get_treemap_rect_at(self, event: MouseMove | MouseDown | Click) -> dict | None: + """Return the treemap rect at the mouse position, or None.""" + if not self._treemap_hover_rects: + return None + try: + canvas = self.query_one("#plot", Canvas) + except NoMatches: + return None + # Use screen offset relative to canvas for accurate hit-testing (avoids + # coordinate system mismatch with get_content_offset on nested widgets) + canvas_offset = event.screen_offset - self.screen.get_offset(canvas) + cx = canvas_offset.x + cy = canvas_offset.y + if ( + not canvas.size + or cx < 0 + or cy < 0 + or cx >= canvas.size.width + or cy >= canvas.size.height + ): + return None + for rect_info in reversed(self._treemap_hover_rects): + r = rect_info + if r["x0"] <= cx < r["x1"] and r["y0"] <= cy < r["y1"]: + return rect_info + return None + + @on(Click) + def _handle_treemap_click(self, event: Click) -> None: + """Select treemap rectangle on click; zoom in on double-click when rect has children.""" + if event.button != 1: + return + rect = self._get_treemap_rect_at(event) + if rect is None: + return + self._treemap_selected_rect = rect + self._treemap_selected_index = rect.get("node_index", 0) + + # Double-click: zoom in when rect has children + if event.chain == 2 and rect.get("has_children"): + if self._is_treemap_show_nested(): + # Append path segment to zoom into this parent + path = rect.get("path", []) + seg = ( + path[len(self._treemap_path)] + if len(path) > len(self._treemap_path) + else None + ) + if seg is not None: + self._treemap_path.append(seg) + self._treemap_selected_index = 0 + else: + # Non-nested: zoom in by index + self._treemap_path.append(rect["node_index"]) + self._treemap_selected_index = 0 + + # Single click in non-show_nested: zoom in on parent (legacy behavior) + elif ( + event.chain == 1 + and rect.get("has_children") + and not self._is_treemap_show_nested() + ): + self._treemap_path.append(rect["node_index"]) + self._treemap_selected_index = 0 + + self._update_info(rect) + self._rerender() + + @on(MouseMove) + def _handle_treemap_hover(self, event: MouseMove) -> None: + """Update treemap info box when hovering over rectangles.""" + if not self._treemap_hover_rects: + self._update_info(None) + return + # In show_nested mode: only show selected, not hover + if self._is_treemap_show_nested(): + self._update_info(self._treemap_selected_rect) + return + rect = self._get_treemap_rect_at(event) + self._update_info(rect if rect is not None else self._treemap_selected_rect) + @on(MouseMove) def drag_with_mouse(self, event: MouseMove) -> None: """Handle mouse drag operations for panning the plot or the legend. diff --git a/src/textual_plot/treemap_utils.py b/src/textual_plot/treemap_utils.py new file mode 100644 index 0000000..79222a8 --- /dev/null +++ b/src/textual_plot/treemap_utils.py @@ -0,0 +1,356 @@ +"""Treemap layout and tree utilities.""" + +from __future__ import annotations + +from typing import Any, Sequence + +import distinctipy +import numpy as np +import squarify # type: ignore[import-untyped] +from numpy.typing import ArrayLike + +from textual_plot.color_utils import ( + parse_style_to_rgb, + rgb_style, + rgb_too_close, + tint_with_hue, +) + +# 1-character border around children so parent is visible and selectable +TREEMAP_PARENT_PAD = 1 +# Top inset for parent labels (3 rows); sides/bottom stay 1 +TREEMAP_PARENT_LABEL_TOP = 3 + +TreemapNode: type = dict[str, Any] + + +def treemap_max_depth(nodes: list[TreemapNode]) -> int: + """Max nesting depth of tree (1 = no nesting, 2 = one level of children, etc).""" + if not nodes: + return 0 + return 1 + max(treemap_max_depth(n.get("children") or []) for n in nodes) + + +def normalize_treemap_tree( + values: ArrayLike | list[Any] | list[list[float]], + labels: Sequence[str] | Sequence[Sequence[str]] | None = None, +) -> tuple[list[TreemapNode], bool]: + """Normalize treemap input to a tree of nodes. Returns (nodes, is_nested).""" + if not values: + return ([], False) + + try: + vals_list = list(values) + except TypeError: + vals_list = np.array(values).tolist() + first = vals_list[0] if len(vals_list) > 0 else None + + # Flat: list of numbers + if isinstance(first, (int, float, np.floating, np.integer)): + arr = np.array(values, dtype=float) + arr = arr[~np.isnan(arr) & ~np.isinf(arr)] + n = len(arr) + labels_list = list(labels)[:n] if labels else None + nodes = [ + { + "value": float(arr[i]), + "label": ( + labels_list[i] + if labels_list and i < len(labels_list) + else f"Item {i + 1}" + ), + "children": None, + } + for i in range(n) + ] + return (nodes, False) + + # Nested: list of dicts with "label" and "children" + if isinstance(first, dict): + nodes = [] + for i, item in enumerate(vals_list): + if not isinstance(item, dict): + continue + children_raw = item.get("children") + label = item.get("label", f"Item {i + 1}") + if children_raw is not None: + child_labels = None + if labels and i < len(labels): + lab = labels[i] + if isinstance(lab, (list, tuple)): + child_labels = lab + sub_nodes, _ = normalize_treemap_tree(children_raw, child_labels) + value = sum(n["value"] for n in sub_nodes) + nodes.append({"value": value, "label": label, "children": sub_nodes}) + else: + value = float(item.get("value", 0)) + nodes.append({"value": value, "label": label, "children": None}) + return (nodes, any(n.get("children") for n in nodes)) + + # Nested: list of lists [[100, 50], [30, 20]] + if isinstance(first, (list, tuple)): + labels_nested = list(labels) if labels else [] + top_labels = ( + labels_nested[0] + if labels_nested and isinstance(labels_nested[0], (list, tuple)) + else labels_nested + ) + child_labels_list = ( + labels_nested[1] + if len(labels_nested) > 1 and isinstance(labels_nested[1], (list, tuple)) + else None + ) + nodes = [] + for i, group in enumerate(vals_list): + if not isinstance(group, (list, tuple)): + continue + sub_labels = ( + child_labels_list[i] + if child_labels_list and i < len(child_labels_list) + else None + ) + sub_nodes, _ = normalize_treemap_tree(group, sub_labels) + value = sum(n["value"] for n in sub_nodes) + group_label = f"Group {i + 1}" + if top_labels and i < len(top_labels) and isinstance(top_labels[i], str): + group_label = top_labels[i] + nodes.append({"value": value, "label": group_label, "children": sub_nodes}) + return (nodes, True) + + return ([], False) + + +def get_treemap_level(nodes: list[TreemapNode], path: list[int]) -> list[TreemapNode]: + """Get the node list at the given path. path=[] returns nodes.""" + current: list[TreemapNode] = nodes + for idx in path: + if idx < 0 or idx >= len(current): + return [] + current = current[idx].get("children") or [] + return current + + +def get_treemap_node_at_path( + tree: list[TreemapNode], path: list[int] +) -> TreemapNode | None: + """Get node at path. path=[0,1] -> tree[0].children[1].""" + if not path: + return None + cur: list[TreemapNode] = tree + for i in path[:-1]: + if i < 0 or i >= len(cur): + return None + cur = cur[i].get("children") or [] + if path[-1] < 0 or path[-1] >= len(cur): + return None + return cur[path[-1]] + + +def get_path_styles( + path: list[int], + path_to_style: dict[tuple[int, ...], str], +) -> list[str]: + """Get style for each node in path from path_to_style mapping.""" + if not path or not path_to_style: + return [] + return [ + path_to_style[tuple(path[:i])] + for i in range(1, len(path) + 1) + if tuple(path[:i]) in path_to_style + ] + + +def format_treemap_nested_path( + tree: list[TreemapNode], + path: list[int], + total: float, + value_display: "ValueDisplay", + currency_symbol: str, +) -> tuple[list[tuple[str, str | None]], str]: + """Build breadcrumb for nested treemap. Returns ([(label, style), ...], plain_text).""" + if not tree or not path or total <= 0: + return ([], "") + from textual_plot.plot_widget import ValueDisplay + + segments: list[tuple[str, str | None]] = [] + plain_parts: list[str] = [] + current_nodes: list[TreemapNode] = tree + parent_value = total + for depth, idx in enumerate(path): + if idx < 0 or idx >= len(current_nodes): + break + node = current_nodes[idx] + value = node["value"] + label = node.get("label", "?") + pct_parent = 100 * value / parent_value if parent_value else 0 + pct_all = 100 * value / total if total else 0 + if value_display == ValueDisplay.CURRENCY: + value_str = f"{currency_symbol}{value:,.2f}" + else: + value_str = f"{value:,.0f}" if value == int(value) else f"{value:,.1f}" + if depth == 0: + segments.append((f"{label} ▪ {value_str} ▪ {pct_all:.1f}%", None)) + plain_parts.append(f"{label} ▪ {value_str} ▪ {pct_all:.1f}%") + else: + segments.append( + ( + f"{label} ▪ {value_str} ▪ {pct_parent:.1f}% ({pct_all:.1f}% All)", + None, + ) + ) + plain_parts.append( + f"{label} ▪ {value_str} ▪ {pct_parent:.1f}% ({pct_all:.1f}% All)" + ) + current_nodes = node.get("children") or [] + parent_value = value + return (segments, " ▶ ".join(plain_parts)) + + +def squarify_recursive( + nodes: list[TreemapNode], + x: float, + y: float, + dx: float, + dy: float, + aspect: float, + path: list[int], + base_styles: list[str], + exclude_colors: list[tuple[float, float, float]], + path_to_style: dict[tuple[int, ...], str] | None = None, +) -> list[dict]: + """Recursively run squarify on a tree. Returns rect info for parents and leaves. + + Parents get distinctipy colors; children get distinctipy (excluding parent) tinted + with parent's hue. At every nesting level: parents output first (draw underneath); + children inset by TREEMAP_PARENT_PAD so parent shows as 1-char border. + + When path_to_style is provided, uses those styles for nodes instead of generating + new ones (keeps colors consistent when zooming). + """ + if not nodes or dx <= 0 or dy <= 0: + return [] + out: list[dict] = [] + values = [n["value"] for n in nodes] + normalized = squarify.normalize_sizes(values, dx, dy) + rects = squarify.squarify(normalized, 0, 0, dx, dy) + pad = TREEMAP_PARENT_PAD + # Horizontal: same as bottom (2 chars) so parent border visible on left and right + inset_x = float(pad) * 2 / max(0.01, aspect) + inset_y_top = float(TREEMAP_PARENT_LABEL_TOP) + inset_y_bottom = float(pad) * 2 # 2 rows at bottom so parent border is visible + for i, rect in enumerate(rects): + if i >= len(nodes): + break + node = nodes[i] + rx, ry = rect["x"], rect["y"] + rdx, rdy = rect["dx"], rect["dy"] + child_path = path + [i] + path_key = tuple(child_path) + if path_to_style is not None and path_key in path_to_style: + style = path_to_style[path_key] + else: + style = base_styles[i] if i < len(base_styles) else base_styles[0] + if node.get("children"): + parent_rgb = parse_style_to_rgb(style) + n_children = len(node["children"]) + child_styles = [] + if path_to_style is not None: + for j in range(n_children): + child_path_key = tuple(child_path + [j]) + if child_path_key in path_to_style: + child_styles.append(path_to_style[child_path_key]) + else: + child_styles.append( + base_styles[0] if base_styles else "rgb(128,128,128)" + ) + else: + child_exclude = ( + [parent_rgb] + exclude_colors + if parent_rgb + else list(exclude_colors) + ) + child_exclude.extend( + [(1.0, 1.0, 1.0), (0.0, 0.0, 0.0)] + ) # always exclude white/black + child_colors = distinctipy.get_colors( + n_children, + exclude_colors=child_exclude, + pastel_factor=0.2, + rng=42, + colorblind_type="Deuteranomaly", + ) + current_exclude = list(child_exclude) + for c in child_colors: + tinted = tint_with_hue(c, parent_rgb) if parent_rgb else c + retries = 0 + while ( + parent_rgb + and rgb_too_close(tinted, parent_rgb) + and retries < 50 + ): + current_exclude.append(c) + current_exclude.append(tinted) + c = distinctipy.distinct_color( + current_exclude, + pastel_factor=0.2, + rng=42, + colorblind_type="Deuteranomaly", + ) + tinted = tint_with_hue(c, parent_rgb) + retries += 1 + child_styles.append(rgb_style(tinted)) + current_exclude.append(c) + current_exclude.append(tinted) + parent_value = sum(n["value"] for n in node["children"]) + out.append( + { + "x": x + rx, + "y": y + ry, + "dx": rdx, + "dy": rdy, + "node": node, + "path": child_path, + "style": style, + "selection_base": style, + "value": parent_value, + "label": node["label"], + "has_children": True, + } + ) + child_dx = max(0.1, rdx - 2 * inset_x) + child_dy = max(0.1, rdy - inset_y_top - inset_y_bottom) + child_x = x + rx + inset_x + child_y = y + ry + inset_y_top + child_exclude_list = ( + [parent_rgb] + exclude_colors if parent_rgb else exclude_colors + ) + sub = squarify_recursive( + node["children"], + child_x, + child_y, + child_dx, + child_dy, + aspect, + child_path, + child_styles, + child_exclude_list, + path_to_style, + ) + out.extend(sub) + else: + out.append( + { + "x": x + rx, + "y": y + ry, + "dx": rdx, + "dy": rdy, + "node": node, + "path": child_path, + "style": style, + "selection_base": style, + "value": node["value"], + "label": node["label"], + "has_children": False, + } + ) + return out diff --git a/tests/test_treemap.py b/tests/test_treemap.py new file mode 100644 index 0000000..bd11514 --- /dev/null +++ b/tests/test_treemap.py @@ -0,0 +1,81 @@ +"""Tests for treemap plot functionality.""" + +import numpy as np + +from textual_plot.plot_widget import TreemapPlot, ValueDisplay + + +class TestTreemapPlot: + """Test TreemapPlot dataclass and squarify integration.""" + + def test_treemap_empty_after_filter_returns_early(self) -> None: + """treemap() returns early when all values are NaN/Inf.""" + from textual_plot import PlotWidget + + plot = PlotWidget() + plot.treemap([np.nan, np.inf]) + assert len(plot._datasets) == 0 + + def test_squarify_layout_integration(self) -> None: + """squarify produces valid rectangles for treemap layout.""" + import squarify + + values = [500, 433, 78, 25, 25, 7] + width, height = 100, 50 + normalized = squarify.normalize_sizes(values, width, height) + rects = squarify.squarify(normalized, 0, 0, width, height) + assert len(rects) == len(values) + total_area = sum(r["dx"] * r["dy"] for r in rects) + assert abs(total_area - width * height) < 0.01 # Allow small float error + + def test_treemap_plot_structure(self) -> None: + """TreemapPlot dataclass has expected structure for rendering.""" + values = np.array([10.0, 20.0, 30.0]) + dataset = TreemapPlot( + values=values, + labels=["A", "B", "C"], + styles=["red", "blue", "green"], + padding=1, + hires_mode=None, + aspect_preference=1.5, + value_display=ValueDisplay.BOTH, + currency_symbol="$", + tree=None, + show_nested=False, + ) + assert len(dataset.values) == 3 + assert dataset.labels == ["A", "B", "C"] + assert dataset.padding == 1 + + def test_treemap_show_nested_dataset_structure(self) -> None: + """TreemapPlot with show_nested=True has expected structure for nested rendering.""" + from textual_plot.treemap_utils import normalize_treemap_tree + + tree_nodes, is_nested = normalize_treemap_tree( + [ + { + "label": "A", + "children": [ + {"label": "A1", "value": 10}, + {"label": "A2", "value": 20}, + ], + }, + {"label": "B", "value": 30}, + ] + ) + assert is_nested + assert len(tree_nodes) == 2 + dataset = TreemapPlot( + values=np.array([30.0, 30.0]), + labels=["A", "B"], + styles=["red", "blue"], + padding=1, + hires_mode=None, + aspect_preference=1.5, + value_display=ValueDisplay.BOTH, + currency_symbol="$", + tree=tree_nodes, + show_nested=True, + ) + assert dataset.tree is not None + assert dataset.show_nested is True diff --git a/tests/treemap.py b/tests/treemap.py new file mode 100644 index 0000000..1d8360e --- /dev/null +++ b/tests/treemap.py @@ -0,0 +1,27 @@ +"""Demo of treemap plot in textual-plot.""" + +from textual.app import App, ComposeResult + +from textual_plot import PlotWidget + + +class TreemapApp(App[None]): + def compose(self) -> ComposeResult: + yield PlotWidget() + + def on_mount(self) -> None: + plot = self.query_one(PlotWidget) + values = [500, 433, 78, 25, 25, 7] + labels = ["A", "B", "C", "D", "E", "F"] + styles = ["red", "blue", "green", "yellow", "cyan", "magenta"] + plot.treemap( + values, + labels=labels, + styles=styles, + padding=1, + label="Categories", + ) + plot.show_legend() + + +TreemapApp().run() diff --git a/uv.lock b/uv.lock index 39a827f..476d716 100644 --- a/uv.lock +++ b/uv.lock @@ -321,6 +321,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, ] +[[package]] +name = "distinctipy" +version = "1.3.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8c/3c/e0b90a5bc396e2abaf207d9a41ea8aeab1f41760425262474903bade6a7b/distinctipy-1.3.4.tar.gz", hash = "sha256:fed97afff1afb73ecaa87c85461021f0ba89fae63067c0125b9673526510aac4", size = 29711, upload-time = "2024-01-10T21:32:24.032Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0a/75/fa882538bdb0c8fc4459f1595a761591b827691936d57c08c492676f19bc/distinctipy-1.3.4-py3-none-any.whl", hash = "sha256:2bf57d9d20dbc5c2fd462298573cc963c037f493d04ec61e94cb8d0bf5023c74", size = 26743, upload-time = "2024-01-10T21:32:22.351Z" }, +] + [[package]] name = "exceptiongroup" version = "1.3.1" @@ -1598,6 +1611,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, ] +[[package]] +name = "squarify" +version = "0.4.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0d/01/1753243870dff9fa786c9712fdc8dffb56f356c46c77d7468cb12f6d8398/squarify-0.4.4.tar.gz", hash = "sha256:b8a110c8dc5f1cd1402ca12d79764a081e90bfc445346cfa166df929753ecb46", size = 5514, upload-time = "2024-07-19T18:57:41.418Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/3c/eedbe9fb07cc20fd9a8423da14b03bc270d0570b3ba9174a4497156a2152/squarify-0.4.4-py3-none-any.whl", hash = "sha256:d7597724e29d48aa14fd2f551060d6b09e1f0a67e4cd3ea329fe03b4c9a56f11", size = 4082, upload-time = "2024-07-19T18:57:40.338Z" }, +] + [[package]] name = "textual" version = "7.0.2" @@ -1651,8 +1673,10 @@ name = "textual-plot" version = "0.10.1" source = { editable = "." } dependencies = [ + { name = "distinctipy" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "squarify" }, { name = "textual" }, { name = "textual-hires-canvas" }, ] @@ -1669,7 +1693,9 @@ dev = [ [package.metadata] requires-dist = [ + { name = "distinctipy", specifier = ">=1.3.0" }, { name = "numpy", specifier = ">=2.2.1" }, + { name = "squarify", specifier = ">=0.4.3" }, { name = "textual", specifier = ">=1.0.0" }, { name = "textual-hires-canvas", specifier = ">=0.14.0" }, ]