From 63728e9121e738e9b31ceecafed37c02e9d8433f Mon Sep 17 00:00:00 2001 From: Ryan Meyers Date: Tue, 17 Feb 2026 11:40:32 -0600 Subject: [PATCH 1/4] Add treemap support in lieu of pie charts --- README.md | 2 +- pyproject.toml | 2 + src/textual_plot/__init__.py | 3 +- src/textual_plot/demo.py | 49 ++++ src/textual_plot/plot_widget.py | 464 +++++++++++++++++++++++++++++++- tests/test_treemap.py | 46 ++++ tests/treemap.py | 27 ++ uv.lock | 26 ++ 8 files changed, 607 insertions(+), 12 deletions(-) create mode 100644 tests/test_treemap.py create mode 100644 tests/treemap.py diff --git a/README.md b/README.md index da36145..fd9a6a6 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # A native plotting widget for Textual apps -[Textual](https://www.textualize.io/) is an excellent Python framework for building applications in the terminal, or on the web. This library provides a plot widget which your app can use to plot all kinds of quantitative data. So, no pie charts, sorry. The widget support scatter plots and line plots, and can also draw using _high-resolution_ characters like unicode half blocks, quadrants and 8-dot Braille characters. It may still be apparent that these are drawn using characters that take up a full block in the terminal, especially when plot series overlap. However, the use of these characters can reduce the line thickness and improve the resolution tremendously. +[Textual](https://www.textualize.io/) is an excellent Python framework for building applications in the terminal, or on the web. This library provides a plot widget which your app can use to plot all kinds of quantitative data. So, no pie charts, sorry, but we do have a treemap! The widget support scatter plots and line plots, and can also draw using _high-resolution_ characters like unicode half blocks, quadrants and 8-dot Braille characters. It may still be apparent that these are drawn using characters that take up a full block in the terminal, especially when plot series overlap. However, the use of these characters can reduce the line thickness and improve the resolution tremendously. ## Screenshots diff --git a/pyproject.toml b/pyproject.toml index b0ecfdc..8704c73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,9 @@ license = "MIT" license-files = ["LICENSE"] requires-python = ">=3.10" dependencies = [ + "distinctipy>=1.3.0", "numpy>=2.2.1", + "squarify>=0.4.3", "textual>=1.0.0", "textual-hires-canvas>=0.14.0", ] diff --git a/src/textual_plot/__init__.py b/src/textual_plot/__init__.py index 5bfca09..e1e3239 100644 --- a/src/textual_plot/__init__.py +++ b/src/textual_plot/__init__.py @@ -3,7 +3,7 @@ DurationFormatter, NumericAxisFormatter, ) -from textual_plot.plot_widget import HiResMode, LegendLocation, PlotWidget +from textual_plot.plot_widget import HiResMode, LegendLocation, PlotWidget, ValueDisplay __all__ = [ "AxisFormatter", @@ -12,4 +12,5 @@ "LegendLocation", "NumericAxisFormatter", "PlotWidget", + "ValueDisplay", ] diff --git a/src/textual_plot/demo.py b/src/textual_plot/demo.py index 3998b50..3c2f2ef 100644 --- a/src/textual_plot/demo.py +++ b/src/textual_plot/demo.py @@ -326,6 +326,53 @@ def action_cycle_hires_mode(self) -> None: self.plot() +class TreemapPlot(Container): + BINDINGS = [("h", "cycle_hires_mode", "HiRes")] + + _hires_mode = itertools.cycle([None, HiResMode.BRAILLE]) + hires_mode = next(_hires_mode) + + def compose(self) -> ComposeResult: + yield PlotWidget() + + def on_mount(self) -> None: + self.plot() + + def plot(self) -> None: + plot = self.query_one(PlotWidget) + plot.clear() + + values = [500, 433, 280, 195, 165, 120, 95, 78, 62, 48, 35, 28, 25, 18, 12] + labels = [ + "Electronics", + "Clothing", + "Home", + "Sports", + "Books", + "Toys", + "Garden", + "Auto", + "Health", + "Beauty", + "Office", + "Food", + "Pets", + "Music", + "Art", + ] + plot.treemap( + values, + labels=labels, + padding=1, + hires_mode=self.hires_mode, + ) + plot.show_legend() + + def action_cycle_hires_mode(self) -> None: + self.hires_mode = next(self._hires_mode) + self.plot() + + class DemoApp(App[None]): AUTO_FOCUS = "SinePlot > PlotWidget" @@ -350,6 +397,8 @@ def compose(self) -> ComposeResult: yield ErrorBarPlot() with TabPane("Bar plot", id="barplot"): yield BarPlot() + with TabPane("Treemap", id="treemap"): + yield TreemapPlot() def on_mount(self) -> None: self.theme = "tokyo-night" diff --git a/src/textual_plot/plot_widget.py b/src/textual_plot/plot_widget.py index 59abf8f..90b7f00 100644 --- a/src/textual_plot/plot_widget.py +++ b/src/textual_plot/plot_widget.py @@ -9,6 +9,7 @@ from __future__ import annotations import enum +import re import sys from dataclasses import dataclass from math import ceil, floor @@ -22,7 +23,9 @@ from typing import Self else: from typing_extensions import Self +import distinctipy import numpy as np +import squarify # type: ignore[import-untyped] from numpy.typing import ArrayLike, NDArray from textual import on from textual._box_drawing import BOX_CHARACTERS, combine_quads @@ -49,7 +52,7 @@ NumericAxisFormatter, ) -__all__ = ["HiResMode", "LegendLocation", "PlotWidget"] +__all__ = ["HiResMode", "LegendLocation", "PlotWidget", "ValueDisplay"] FloatScalar: TypeAlias = float | np.floating FloatArray: TypeAlias = NDArray[np.floating] @@ -76,6 +79,7 @@ class LegendLocation(enum.Enum): TOPRIGHT = enum.auto() BOTTOMLEFT = enum.auto() BOTTOMRIGHT = enum.auto() + LEFT = enum.auto() # Y-axis label area (left margin); used for treemap default @dataclass @@ -143,6 +147,41 @@ class BarPlot(DataSet): bar_style: str | list[str] +class ValueDisplay(enum.Enum): + """What to show for treemap rectangle values in labels.""" + + VALUE = "value" # Raw value + PERCENT = "percent" # Percent of total + BOTH = "both" # Value and percent + CURRENCY = "currency" # Currency format (2 decimals, configurable symbol) + NONE = "none" # No value/percent line + + +@dataclass +class TreemapPlot: + """A dataset for rendering as a treemap. + + Attributes: + values: Array of numeric values (sizes) for each rectangle. + labels: Optional labels for each rectangle (for legend). + styles: Rich style string or list of styles for each rectangle. + padding: Pixel padding between rectangles and from edges. + hires_mode: HiResMode for high-resolution rendering or None. + aspect_preference: Bias toward wider (>1) or taller (<1) rectangles. + value_display: What to show on the second line of labels. + currency_symbol: Symbol for CURRENCY display mode (e.g. "$", "€"). + """ + + values: FloatArray + labels: list[str] | None + styles: str | list[str] + padding: int + hires_mode: HiResMode | None + aspect_preference: float + value_display: ValueDisplay + currency_symbol: str + + @dataclass class VLinePlot: """A vertical line to be drawn on the plot. @@ -200,7 +239,7 @@ class ScaleChanged(Message): DEFAULT_CSS = """ PlotWidget { - layers: plot legend; + layers: plot legend info; &:focus > .plot--axis { color: $primary; @@ -229,6 +268,14 @@ class ScaleChanged(Message): } } + #info { + layer: info; + width: auto; + border-top: solid $secondary; + padding: 0 2; + display: none; + } + #legend { layer: legend; width: auto; @@ -276,7 +323,7 @@ class ScaleChanged(Message): KEYBOARD_ZOOM_FACTOR: float = 0.15 KEYBOARD_PAN_FACTOR: float = 2.0 - _datasets: list[DataSet] + _datasets: list[DataSet | TreemapPlot] _labels: list[str | None] _user_x_min: float | None = None @@ -341,6 +388,8 @@ def __init__( self._labels = [] self._v_lines: list[VLinePlot] = [] self._v_lines_labels: list[str | None] = [] + self._treemap_hover_rects: list[dict] = [] + self._treemap_selected_rect: dict | None = None self._allow_pan_and_zoom = allow_pan_and_zoom self.invert_mouse_wheel = invert_mouse_wheel self._x_formatter = NumericAxisFormatter() @@ -358,6 +407,7 @@ def compose(self) -> ComposeResult: yield Canvas(1, 1, id="plot") yield Canvas(1, 1, id="margin-bottom") yield Legend(id="legend") + yield Static("", id="info") def on_mount(self) -> None: """Initialize the plot widget when mounted to the DOM.""" @@ -394,6 +444,9 @@ def clear(self) -> None: self._labels = [] self._v_lines = [] self._v_lines_labels = [] + self._treemap_hover_rects = [] + self._treemap_selected_rect = None + self._update_info(None) self._update_legend() self._rerender() @@ -598,6 +651,96 @@ def bar( self._update_legend() self._rerender() + def treemap( + self, + values: ArrayLike, + labels: Sequence[str] | None = None, + styles: str | Sequence[str] | None = None, + padding: int = 1, + hires_mode: HiResMode | None = None, + label: str | None = None, + aspect_preference: float = 1.5, + value_display: ValueDisplay | str = ValueDisplay.BOTH, + currency_symbol: str = "$", + ) -> None: + """Graph data as a treemap. + + Treemaps display hierarchical or flat data as nested rectangles, with + each rectangle's area proportional to its value. Uses the squarify + algorithm for layout. + + Args: + values: An ArrayLike with the numeric values (sizes) for each + rectangle. Supports flat lists for single-level treemaps. + labels: Optional sequence of labels for each rectangle (for legend). + Defaults to None. + styles: A style string for all rectangles or a sequence of styles + for each. Defaults to a color cycle if None. + padding: Pixel padding between rectangles and from edges. + Defaults to 1. + hires_mode: A HiResMode enum or None for standard rendering. + Defaults to None. + label: A string with the label for the dataset in the legend. + Defaults to None. + aspect_preference: Bias toward wider (>1) or taller (<1) rectangles. + Values >1 prefer wider rectangles (better for text labels). + Defaults to 1.5. + value_display: What to show on the second line of labels: "value", + "percent", "both", "currency", or "none". Defaults to "both". + currency_symbol: Symbol for currency display mode (e.g. "$", "€"). + Defaults to "$". + """ + vd = ( + ValueDisplay(value_display) + if isinstance(value_display, str) + else value_display + ) + values_array = np.array(values, dtype=float) + values_array = values_array[~np.isnan(values_array) & ~np.isinf(values_array)] + if len(values_array) == 0: + return + + # Default colors: use distinctipy for perceptually distinct CIELAB-inspired palette + def _rgb_style(rgb: tuple[float, float, float]) -> str: + r, g, b = rgb + return f"rgb({int(r * 255)},{int(g * 255)},{int(b * 255)})" + + n = len(values_array) + if styles is None: + distinct_colors = distinctipy.get_colors( + n, pastel_factor=0.2, rng=42 + ) # deterministic, slightly pastel for terminal contrast + styles_list = [_rgb_style(c) for c in distinct_colors] + elif isinstance(styles, str): + styles_list = [styles] * n + else: + styles_list = list(styles) + if len(styles_list) < n: + extra = distinctipy.get_colors( + n - len(styles_list), pastel_factor=0.2, rng=42 + ) + styles_list.extend(_rgb_style(c) for c in extra) + + labels_list: list[str] | None = None + if labels is not None: + labels_list = list(labels)[: len(values_array)] + + self._datasets.append( + TreemapPlot( + values=values_array, + labels=labels_list, + styles=styles_list, + padding=padding, + hires_mode=hires_mode, + aspect_preference=aspect_preference, + value_display=vd, + currency_symbol=currency_symbol, + ) + ) + self._labels.append(label) + self._update_legend() + self._rerender() + def add_v_line( self, x: float, line_style: str = "white", label: str | None = None ) -> None: @@ -720,10 +863,52 @@ def show_legend( raise TypeError( f"Expected LegendLocation, got {type(location).__name__} instead." ) + elif ( + is_visible + and self._datasets + and all(isinstance(d, TreemapPlot) for d in self._datasets) + ): + self._legend_location = LegendLocation.LEFT + self._legend_relative_offset = Offset(0, 0) self.query_one("#legend", Static).display = is_visible if is_visible: self._update_legend() + def _update_info(self, rect_info: dict | None) -> None: + """Update the plot info box with rectangle data or hide it.""" + try: + info_box = self.query_one("#info", Static) + except NoMatches: + return + if rect_info is None: + info_box.display = False + return + total = rect_info["total"] + value = rect_info["value"] + pct = 100 * value / total if total else 0 + value_display = rect_info.get("value_display", ValueDisplay.BOTH) + currency_symbol = rect_info.get("currency_symbol", "$") + if value_display == ValueDisplay.CURRENCY: + value_str = f"{currency_symbol}{value:,.2f}" + else: + value_str = f"{value:,.0f}" if value == int(value) else f"{value:,.1f}" + style = rect_info.get("style", "") + title = Text("███") + title.stylize(style) + title.append( + f" {rect_info['label']} · Value: {value_str} · Percent: {pct:.1f}%" + ) + info_box.update(title) + info_box.display = True + # Position in margin-bottom (x-axis label) area - unused for treemap + canvas = self.query_one("#plot", Canvas) + if canvas.size: + info_box.offset = Offset( + self.margin_left + 1, + self.margin_top + canvas.size.height, + ) + info_box.refresh(layout=True) + def _update_legend(self) -> None: """Update the content and position of the plot legend.""" legend = self.query_one("#legend", Static) @@ -732,6 +917,20 @@ def _update_legend(self) -> None: legend_lines = [] for label, dataset in zip(self._labels, self._datasets): + # Treemap with per-rectangle labels: show them even if dataset label is None + if isinstance(dataset, TreemapPlot) and dataset.labels: + rect_styles = ( + dataset.styles + if isinstance(dataset.styles, list) + else [dataset.styles] * len(dataset.labels) + ) + for rect_label, rect_style in zip(dataset.labels, rect_styles): + text = Text("███") + text.stylize(rect_style) + text.append(f" {rect_label}") + legend_lines.append(text.markup) + continue + if label is not None: if isinstance(dataset, LinePlot): marker = LEGEND_LINE[dataset.hires_mode] @@ -758,6 +957,13 @@ def _update_legend(self) -> None: else LEGEND_MARKER[dataset.hires_mode] ).center(3) style = dataset.marker_style + elif isinstance(dataset, TreemapPlot): + marker = "███" + style = ( + dataset.styles[0] + if isinstance(dataset.styles, list) + else dataset.styles + ) else: # unsupported dataset type continue @@ -830,11 +1036,17 @@ def _get_legend_origin_coordinates(self, location: LegendLocation) -> Offset: all_labels.extend( [label for label in self._v_lines_labels if label is not None] ) + for dataset in self._datasets: + if isinstance(dataset, TreemapPlot) and dataset.labels: + all_labels.extend(dataset.labels) # markers and lines in the legend are 3 characters wide, plus a space, so 4 max_length = 4 + max((len(s) for s in all_labels), default=0) - if location in (LegendLocation.TOPLEFT, LegendLocation.BOTTOMLEFT): + if location == LegendLocation.LEFT: + x0 = 1 + y0 = self.margin_top + 1 + elif location in (LegendLocation.TOPLEFT, LegendLocation.BOTTOMLEFT): x0 = self.margin_left + 1 else: # LegendLocation is TOPRIGHT or BOTTOMRIGHT @@ -842,7 +1054,9 @@ def _get_legend_origin_coordinates(self, location: LegendLocation) -> Offset: # leave room for the border x0 -= legend.styles.border.spacing.left + legend.styles.border.spacing.right - if location in (LegendLocation.TOPLEFT, LegendLocation.TOPRIGHT): + if location == LegendLocation.LEFT: + pass # y0 already set + elif location in (LegendLocation.TOPLEFT, LegendLocation.TOPRIGHT): y0 = self.margin_top + 1 else: # LegendLocation is BOTTOMLEFT or BOTTOMRIGHT @@ -918,13 +1132,19 @@ def _render_plot(self) -> None: # clear canvas canvas.reset() - # determine axis limits - if self._datasets or self._v_lines: + # Clear treemap hover rects; treemap render will repopulate if applicable + self._treemap_hover_rects = [] + self._treemap_selected_rect = None + self._update_info(None) + + # determine axis limits (skip TreemapPlot - it uses pixel coordinates) + coord_datasets = [d for d in self._datasets if not isinstance(d, TreemapPlot)] + if coord_datasets or self._v_lines: xs = [] ys = [] # Collect x and y values, accounting for bar widths - for dataset in self._datasets: + for dataset in coord_datasets: if isinstance(dataset, BarPlot): # For bar plots, include the left and right edges x_left = dataset.x - dataset.width / 2 @@ -975,6 +1195,8 @@ def _render_plot(self) -> None: self._render_bar_plot(dataset) elif isinstance(dataset, ScatterPlot): self._render_scatter_plot(dataset) + elif isinstance(dataset, TreemapPlot): + self._render_treemap_plot(dataset) # render axis, ticks and labels canvas.draw_rectangle_box( @@ -1209,6 +1431,158 @@ def _render_bar_plot(self, dataset: BarPlot) -> None: x1, y1 = self.get_pixel_from_coordinate(x_right, y_bottom) canvas.draw_filled_rectangle(x0, y0, x1, y1, style=style) + def _render_treemap_plot(self, dataset: TreemapPlot) -> None: + """Render a treemap dataset on the canvas. + + Uses squarify for layout. Rectangles are drawn in canvas coordinates + within the scale rectangle, with optional padding. + + Args: + dataset: The treemap dataset to render. + """ + canvas = self.query_one("#plot", Canvas) + sr = self._scale_rectangle + if sr.width <= 0 or sr.height <= 0: + return + + pad = max(0, dataset.padding) + effective_width = max(1, sr.width - 2 * pad) + effective_height = max(1, sr.height - 2 * pad) + + # Squarify has no ratio param; bias aspect by laying out in modified space + # then scaling. aspect_preference > 1 = prefer wider rects (better for labels) + aspect = max(0.25, min(4.0, dataset.aspect_preference)) + layout_width = effective_width / aspect + layout_height = effective_height + + normalized = squarify.normalize_sizes( + dataset.values.tolist(), layout_width, layout_height + ) + rects = squarify.squarify(normalized, 0, 0, layout_width, layout_height) + + # Scale rects from layout space to actual space (stretch x when aspect > 1) + ox = sr.x + pad + oy = sr.y + pad + + styles = dataset.styles + is_style_list = isinstance(styles, list) + total_value = float(np.sum(dataset.values)) + + # Store rects for hover detection (replace when multiple treemap datasets) + hover_rects: list[dict] = [] + + for i, rect in enumerate(rects): + style = styles[i] if is_style_list else styles + assert isinstance(style, str) + + # Scale from layout space to actual space + rx = rect["x"] * aspect + ry = rect["y"] + rdx = rect["dx"] * aspect + rdy = rect["dy"] + + # Convert to integers - squarify returns floats but canvas expects int pixels + x0 = int(ox + rx) + y0 = int(oy + ry) + x1 = int(x0 + rdx) + y1 = int(y0 + rdy) + + # Ensure at least 1px size for visibility + if x1 <= x0: + x1 = x0 + 1 + if y1 <= y0: + y1 = y0 + 1 + + # Store for hover info box + label_str = ( + dataset.labels[i] + if dataset.labels and i < len(dataset.labels) + else f"Item {i + 1}" + ) + hover_rects.append( + { + "x0": x0, + "y0": y0, + "x1": x1, + "y1": y1, + "label": label_str, + "value": float(dataset.values[i]), + "total": total_value, + "style": style, + "value_display": dataset.value_display, + "currency_symbol": dataset.currency_symbol, + } + ) + + if dataset.hires_mode: + canvas.draw_filled_hires_rectangle( + float(x0), + float(y0), + float(x1), + float(y1), + hires_mode=dataset.hires_mode, + style=style, + ) + else: + canvas.draw_filled_rectangle(x0, y0, x1, y1, style=style) + + # Draw label on rectangle if provided and rect is large enough + if dataset.labels and i < len(dataset.labels): + rect_w = x1 - x0 + rect_h = y1 - y0 + value = dataset.values[i] + pct = 100 * value / total_value if total_value else 0 + value_str = f"{value:,.0f}" if value == int(value) else f"{value:,.1f}" + currency_str = f"{dataset.currency_symbol}{value:,.2f}" + pct_str = f"{pct:.1f}%" + # Build second line based on value_display + if dataset.value_display == ValueDisplay.VALUE: + line2 = value_str + elif dataset.value_display == ValueDisplay.PERCENT: + line2 = pct_str + elif dataset.value_display == ValueDisplay.BOTH: + line2 = f"{value_str} ({pct_str})" + elif dataset.value_display == ValueDisplay.CURRENCY: + line2 = currency_str + else: + line2 = None + needs_two_lines = line2 is not None and rect_h >= 2 + if rect_w >= 3 and rect_h >= (2 if needs_two_lines else 1): + label_line1 = dataset.labels[i] + max_len = max(1, rect_w - 2) + if len(label_line1) > max_len: + label_line1 = label_line1[: max_len - 1] + "…" + if needs_two_lines and len(line2) > max_len: + line2 = line2[: max_len - 1] + "…" + cx = (x0 + x1) // 2 + cy = (y0 + y1) // 2 + bg_rgb = _parse_style_to_rgb(style) + if bg_rgb is not None: + text_color = distinctipy.get_text_color(bg_rgb) + fg = "black" if text_color == (0, 0, 0) else "white" + else: + fg = "white" + style_str = f"bold {fg} on {style}" + if needs_two_lines: + canvas.write_text( + cx, + cy - 1, + f"[{style_str}]{label_line1}", + align=TextAlign.CENTER, + ) + canvas.write_text( + cx, cy, f"[{style_str}]{line2}", align=TextAlign.CENTER + ) + else: + canvas.write_text( + cx, + cy, + f"[{style_str}]{label_line1}", + align=TextAlign.CENTER, + ) + + self._treemap_hover_rects = hover_rects + def _render_v_line_plot(self, vline: VLinePlot) -> None: """Render a vertical line on the canvas. @@ -1232,8 +1606,13 @@ def _render_v_line_plot(self, vline: VLinePlot) -> None: def _render_x_ticks(self) -> None: """Render tick marks and labels for the x-axis.""" - canvas = self.query_one("#plot", Canvas) bottom_margin = self.query_one("#margin-bottom", Canvas) + # Hide ticks when plot contains only treemaps (no numeric axes) + if self._datasets and all(isinstance(d, TreemapPlot) for d in self._datasets): + bottom_margin.reset() + return + + canvas = self.query_one("#plot", Canvas) bottom_margin.reset() x_ticks: Sequence[float] @@ -1277,8 +1656,13 @@ def _render_x_ticks(self) -> None: def _render_y_ticks(self) -> None: """Render tick marks and labels for the y-axis.""" - canvas = self.query_one("#plot", Canvas) left_margin = self.query_one("#margin-left", Canvas) + # Hide ticks when plot contains only treemaps (no numeric axes) + if self._datasets and all(isinstance(d, TreemapPlot) for d in self._datasets): + left_margin.reset() + return + + canvas = self.query_one("#plot", Canvas) left_margin.reset() y_ticks: Sequence[float] @@ -1320,6 +1704,8 @@ def _render_y_ticks(self) -> None: def _render_x_label(self) -> None: """Render the x-axis label.""" + if self._datasets and all(isinstance(d, TreemapPlot) for d in self._datasets): + return canvas = self.query_one("#plot", Canvas) margin = self.query_one("#margin-bottom", Canvas) margin.write_text( @@ -1331,6 +1717,8 @@ def _render_x_label(self) -> None: def _render_y_label(self) -> None: """Render the y-axis label.""" + if self._datasets and all(isinstance(d, TreemapPlot) for d in self._datasets): + return margin = self.query_one("#margin-top", Canvas) margin.write_text( self.margin_left - 2, @@ -1586,6 +1974,54 @@ def stop_dragging_legend(self, event: MouseUp) -> None: self.query_one("#legend").remove_class("dragged") event.stop() + def _get_treemap_rect_at(self, offset: Offset) -> dict | None: + """Return the treemap rect at the given content offset, or None.""" + if not self._treemap_hover_rects: + return None + try: + canvas = self.query_one("#plot", Canvas) + except NoMatches: + return None + cx = offset.x - self.margin_left + cy = offset.y - self.margin_top + if ( + not canvas.size + or cx < 0 + or cy < 0 + or cx >= canvas.size.width + or cy >= canvas.size.height + ): + return None + for rect_info in reversed(self._treemap_hover_rects): + r = rect_info + if r["x0"] <= cx < r["x1"] and r["y0"] <= cy < r["y1"]: + return rect_info + return None + + @on(MouseDown) + def _handle_treemap_click(self, event: MouseDown) -> None: + """Select treemap rectangle on click to pin info box.""" + if event.button != 1: + return + if (offset := event.get_content_offset(self)) is None: + return + rect = self._get_treemap_rect_at(offset) + self._treemap_selected_rect = rect + self._update_info(rect) + + @on(MouseMove) + def _handle_treemap_hover(self, event: MouseMove) -> None: + """Update treemap info box when hovering over rectangles.""" + if not self._treemap_hover_rects: + self._update_info(None) + return + if (offset := event.get_content_offset(self)) is None: + self._update_info(self._treemap_selected_rect) + return + rect = self._get_treemap_rect_at(offset) + # Show hovered rect if any, else show selected rect + self._update_info(rect if rect is not None else self._treemap_selected_rect) + @on(MouseMove) def drag_with_mouse(self, event: MouseMove) -> None: """Handle mouse drag operations for panning the plot or the legend. @@ -1777,6 +2213,14 @@ def linear_mapper( return a_prime + (x - a) * (b_prime - a_prime) / (b - a) +def _parse_style_to_rgb(style: str) -> tuple[float, float, float] | None: + """Parse rgb(r,g,b) style string to (r,g,b) tuple in 0-1 range for distinctipy.""" + m = re.match(r"rgb\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", style, re.I) + if m: + return (int(m.group(1)) / 255, int(m.group(2)) / 255, int(m.group(3)) / 255) + return None + + def drop_nans_and_infs(x: FloatArray, y: FloatArray) -> tuple[FloatArray, FloatArray]: """Drop NaNs and Infs from x and y arrays. diff --git a/tests/test_treemap.py b/tests/test_treemap.py new file mode 100644 index 0000000..18af960 --- /dev/null +++ b/tests/test_treemap.py @@ -0,0 +1,46 @@ +"""Tests for treemap plot functionality.""" + +import numpy as np + +from textual_plot.plot_widget import TreemapPlot, ValueDisplay + + +class TestTreemapPlot: + """Test TreemapPlot dataclass and squarify integration.""" + + def test_treemap_empty_after_filter_returns_early(self) -> None: + """treemap() returns early when all values are NaN/Inf.""" + from textual_plot import PlotWidget + + plot = PlotWidget() + plot.treemap([np.nan, np.inf]) + assert len(plot._datasets) == 0 + + def test_squarify_layout_integration(self) -> None: + """squarify produces valid rectangles for treemap layout.""" + import squarify + + values = [500, 433, 78, 25, 25, 7] + width, height = 100, 50 + normalized = squarify.normalize_sizes(values, width, height) + rects = squarify.squarify(normalized, 0, 0, width, height) + assert len(rects) == len(values) + total_area = sum(r["dx"] * r["dy"] for r in rects) + assert abs(total_area - width * height) < 0.01 # Allow small float error + + def test_treemap_plot_structure(self) -> None: + """TreemapPlot dataclass has expected structure for rendering.""" + values = np.array([10.0, 20.0, 30.0]) + dataset = TreemapPlot( + values=values, + labels=["A", "B", "C"], + styles=["red", "blue", "green"], + padding=1, + hires_mode=None, + aspect_preference=1.5, + value_display=ValueDisplay.BOTH, + currency_symbol="$", + ) + assert len(dataset.values) == 3 + assert dataset.labels == ["A", "B", "C"] + assert dataset.padding == 1 diff --git a/tests/treemap.py b/tests/treemap.py new file mode 100644 index 0000000..1d8360e --- /dev/null +++ b/tests/treemap.py @@ -0,0 +1,27 @@ +"""Demo of treemap plot in textual-plot.""" + +from textual.app import App, ComposeResult + +from textual_plot import PlotWidget + + +class TreemapApp(App[None]): + def compose(self) -> ComposeResult: + yield PlotWidget() + + def on_mount(self) -> None: + plot = self.query_one(PlotWidget) + values = [500, 433, 78, 25, 25, 7] + labels = ["A", "B", "C", "D", "E", "F"] + styles = ["red", "blue", "green", "yellow", "cyan", "magenta"] + plot.treemap( + values, + labels=labels, + styles=styles, + padding=1, + label="Categories", + ) + plot.show_legend() + + +TreemapApp().run() diff --git a/uv.lock b/uv.lock index 39a827f..476d716 100644 --- a/uv.lock +++ b/uv.lock @@ -321,6 +321,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, ] +[[package]] +name = "distinctipy" +version = "1.3.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8c/3c/e0b90a5bc396e2abaf207d9a41ea8aeab1f41760425262474903bade6a7b/distinctipy-1.3.4.tar.gz", hash = "sha256:fed97afff1afb73ecaa87c85461021f0ba89fae63067c0125b9673526510aac4", size = 29711, upload-time = "2024-01-10T21:32:24.032Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0a/75/fa882538bdb0c8fc4459f1595a761591b827691936d57c08c492676f19bc/distinctipy-1.3.4-py3-none-any.whl", hash = "sha256:2bf57d9d20dbc5c2fd462298573cc963c037f493d04ec61e94cb8d0bf5023c74", size = 26743, upload-time = "2024-01-10T21:32:22.351Z" }, +] + [[package]] name = "exceptiongroup" version = "1.3.1" @@ -1598,6 +1611,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, ] +[[package]] +name = "squarify" +version = "0.4.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0d/01/1753243870dff9fa786c9712fdc8dffb56f356c46c77d7468cb12f6d8398/squarify-0.4.4.tar.gz", hash = "sha256:b8a110c8dc5f1cd1402ca12d79764a081e90bfc445346cfa166df929753ecb46", size = 5514, upload-time = "2024-07-19T18:57:41.418Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/3c/eedbe9fb07cc20fd9a8423da14b03bc270d0570b3ba9174a4497156a2152/squarify-0.4.4-py3-none-any.whl", hash = "sha256:d7597724e29d48aa14fd2f551060d6b09e1f0a67e4cd3ea329fe03b4c9a56f11", size = 4082, upload-time = "2024-07-19T18:57:40.338Z" }, +] + [[package]] name = "textual" version = "7.0.2" @@ -1651,8 +1673,10 @@ name = "textual-plot" version = "0.10.1" source = { editable = "." } dependencies = [ + { name = "distinctipy" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "squarify" }, { name = "textual" }, { name = "textual-hires-canvas" }, ] @@ -1669,7 +1693,9 @@ dev = [ [package.metadata] requires-dist = [ + { name = "distinctipy", specifier = ">=1.3.0" }, { name = "numpy", specifier = ">=2.2.1" }, + { name = "squarify", specifier = ">=0.4.3" }, { name = "textual", specifier = ">=1.0.0" }, { name = "textual-hires-canvas", specifier = ">=0.14.0" }, ] From 99ebc0029009c386e33e9bbf7691e8fddc202fe0 Mon Sep 17 00:00:00 2001 From: Ryan Meyers Date: Tue, 17 Feb 2026 11:52:31 -0600 Subject: [PATCH 2/4] Correct hover offset --- src/textual_plot/plot_widget.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/textual_plot/plot_widget.py b/src/textual_plot/plot_widget.py index 90b7f00..7e58117 100644 --- a/src/textual_plot/plot_widget.py +++ b/src/textual_plot/plot_widget.py @@ -1974,16 +1974,19 @@ def stop_dragging_legend(self, event: MouseUp) -> None: self.query_one("#legend").remove_class("dragged") event.stop() - def _get_treemap_rect_at(self, offset: Offset) -> dict | None: - """Return the treemap rect at the given content offset, or None.""" + def _get_treemap_rect_at(self, event: MouseMove | MouseDown) -> dict | None: + """Return the treemap rect at the mouse position, or None.""" if not self._treemap_hover_rects: return None try: canvas = self.query_one("#plot", Canvas) except NoMatches: return None - cx = offset.x - self.margin_left - cy = offset.y - self.margin_top + # Use screen offset relative to canvas for accurate hit-testing (avoids + # coordinate system mismatch with get_content_offset on nested widgets) + canvas_offset = event.screen_offset - self.screen.get_offset(canvas) + cx = canvas_offset.x + cy = canvas_offset.y if ( not canvas.size or cx < 0 @@ -2003,9 +2006,7 @@ def _handle_treemap_click(self, event: MouseDown) -> None: """Select treemap rectangle on click to pin info box.""" if event.button != 1: return - if (offset := event.get_content_offset(self)) is None: - return - rect = self._get_treemap_rect_at(offset) + rect = self._get_treemap_rect_at(event) self._treemap_selected_rect = rect self._update_info(rect) @@ -2015,10 +2016,7 @@ def _handle_treemap_hover(self, event: MouseMove) -> None: if not self._treemap_hover_rects: self._update_info(None) return - if (offset := event.get_content_offset(self)) is None: - self._update_info(self._treemap_selected_rect) - return - rect = self._get_treemap_rect_at(offset) + rect = self._get_treemap_rect_at(event) # Show hovered rect if any, else show selected rect self._update_info(rect if rect is not None else self._treemap_selected_rect) From 9bb6fc67c9a2a63f8670f0676aca58c8fb8e084e Mon Sep 17 00:00:00 2001 From: Ryan Meyers Date: Wed, 18 Feb 2026 09:45:41 -0600 Subject: [PATCH 3/4] Nested treeplots --- src/textual_plot/color_utils.py | 46 +++ src/textual_plot/demo.py | 66 ++++ src/textual_plot/plot_widget.py | 518 ++++++++++++++++++++++++++---- src/textual_plot/treemap_utils.py | 331 +++++++++++++++++++ tests/test_treemap.py | 35 ++ 5 files changed, 938 insertions(+), 58 deletions(-) create mode 100644 src/textual_plot/color_utils.py create mode 100644 src/textual_plot/treemap_utils.py diff --git a/src/textual_plot/color_utils.py b/src/textual_plot/color_utils.py new file mode 100644 index 0000000..8f12458 --- /dev/null +++ b/src/textual_plot/color_utils.py @@ -0,0 +1,46 @@ +"""Color utilities for treemap and plot styling.""" + +from __future__ import annotations + +import colorsys +import re + +import distinctipy + + +def parse_style_to_rgb(style: str) -> tuple[float, float, float] | None: + """Parse rgb(r,g,b) style string to (r,g,b) tuple in 0-1 range.""" + m = re.match(r"rgb\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", style, re.I) + if m: + return (int(m.group(1)) / 255, int(m.group(2)) / 255, int(m.group(3)) / 255) + return None + + +def adjust_luminance( + rgb: tuple[float, float, float], factor: float +) -> tuple[float, float, float]: + """Adjust luminance of RGB (0-1) by factor. factor>1 lightens, factor<1 darkens.""" + h, lightness, s = colorsys.rgb_to_hls(rgb[0], rgb[1], rgb[2]) + lightness = max(0, min(1, lightness * factor)) + return colorsys.hls_to_rgb(h, lightness, s) + + +def tint_with_hue( + child_rgb: tuple[float, float, float], + parent_rgb: tuple[float, float, float], +) -> tuple[float, float, float]: + """Apply parent's hue to child color; keep child's saturation and value for distinction.""" + ph, ps, pv = colorsys.rgb_to_hsv(parent_rgb[0], parent_rgb[1], parent_rgb[2]) + ch, cs, cv = colorsys.rgb_to_hsv(child_rgb[0], child_rgb[1], child_rgb[2]) + r, g, b = colorsys.hsv_to_rgb(ph, cs, cv) + return (r, g, b) + + +def rgb_too_close(a: tuple[float, float, float], b: tuple[float, float, float]) -> bool: + """True if a and b are too similar (perceptual distance).""" + return distinctipy.color_distance(a, b) < 0.02 + + +def rgb_style(rgb: tuple[float, float, float]) -> str: + """Convert (r,g,b) 0-1 to rgb(r,g,b) style string.""" + return f"rgb({int(rgb[0] * 255)},{int(rgb[1] * 255)},{int(rgb[2] * 255)})" diff --git a/src/textual_plot/demo.py b/src/textual_plot/demo.py index 3c2f2ef..52b384a 100644 --- a/src/textual_plot/demo.py +++ b/src/textual_plot/demo.py @@ -373,6 +373,70 @@ def action_cycle_hires_mode(self) -> None: self.plot() +class NestedTreemapPlot(Container): + """Nested treemap: click to zoom in, Escape to zoom out, arrows to select.""" + + BINDINGS = [("h", "cycle_hires_mode", "HiRes")] + + _hires_mode = itertools.cycle([None, HiResMode.BRAILLE]) + hires_mode = next(_hires_mode) + + def compose(self) -> ComposeResult: + yield PlotWidget() + + def on_mount(self) -> None: + self.plot() + + def plot(self) -> None: + plot = self.query_one(PlotWidget) + plot.clear() + # Nested: show_nested=True draws full hierarchy with luminance variance + plot.treemap( + [ + { + "label": "Electronics", + "children": [ + { + "label": "Phones", + "children": [ + {"label": "iPhone", "value": 55}, + {"label": "Android", "value": 35}, + {"label": "Other", "value": 10}, + ], + }, + {"label": "Laptops", "value": 50}, + {"label": "Tablets", "value": 25}, + {"label": "Accessories", "value": 25}, + ], + }, + { + "label": "Clothing", + "children": [ + {"label": "Shirts", "value": 80}, + {"label": "Pants", "value": 60}, + {"label": "Shoes", "value": 40}, + ], + }, + { + "label": "Home", + "children": [ + {"label": "Furniture", "value": 90}, + {"label": "Decor", "value": 45}, + {"label": "Kitchen", "value": 30}, + ], + }, + ], + padding=1, + hires_mode=self.hires_mode, + show_nested=True, + ) + plot.show_legend() + + def action_cycle_hires_mode(self) -> None: + self.hires_mode = next(self._hires_mode) + self.plot() + + class DemoApp(App[None]): AUTO_FOCUS = "SinePlot > PlotWidget" @@ -399,6 +463,8 @@ def compose(self) -> ComposeResult: yield BarPlot() with TabPane("Treemap", id="treemap"): yield TreemapPlot() + with TabPane("Nested Treemap", id="nested_treemap"): + yield NestedTreemapPlot() def on_mount(self) -> None: self.theme = "tokyo-night" diff --git a/src/textual_plot/plot_widget.py b/src/textual_plot/plot_widget.py index 7e58117..17c7904 100644 --- a/src/textual_plot/plot_widget.py +++ b/src/textual_plot/plot_widget.py @@ -9,13 +9,14 @@ from __future__ import annotations import enum -import re import sys from dataclasses import dataclass from math import ceil, floor from statistics import mean -from typing import Sequence, TypeAlias +from typing import Any, Sequence, TypeAlias +import distinctipy +import numpy as np from rich.text import Text from textual.binding import Binding @@ -23,8 +24,7 @@ from typing import Self else: from typing_extensions import Self -import distinctipy -import numpy as np + import squarify # type: ignore[import-untyped] from numpy.typing import ArrayLike, NDArray from textual import on @@ -33,6 +33,7 @@ from textual.containers import Grid from textual.css.query import NoMatches from textual.events import ( + Key, MouseDown, MouseMove, MouseScrollDown, @@ -51,12 +52,25 @@ CategoricalAxisFormatter, NumericAxisFormatter, ) +from textual_plot.color_utils import parse_style_to_rgb, rgb_style +from textual_plot.treemap_utils import ( + TREEMAP_PARENT_LABEL_TOP, + format_treemap_nested_path, + get_path_styles, + get_treemap_level, + normalize_treemap_tree, + squarify_recursive, + treemap_max_depth, +) __all__ = ["HiResMode", "LegendLocation", "PlotWidget", "ValueDisplay"] FloatScalar: TypeAlias = float | np.floating FloatArray: TypeAlias = NDArray[np.floating] +# Nested treemap: each node is {"value": float, "label": str, "children": list | None} +TreemapNode: TypeAlias = dict[str, Any] + LEGEND_LINE = { None: "███", @@ -170,6 +184,9 @@ class TreemapPlot: aspect_preference: Bias toward wider (>1) or taller (<1) rectangles. value_display: What to show on the second line of labels. currency_symbol: Symbol for CURRENCY display mode (e.g. "$", "€"). + tree: Normalized tree structure for nested treemaps (when is_nested). + show_nested: When True and tree is nested, draw full hierarchy with + luminance variance (children darker) instead of zoom-in/zoom-out. """ values: FloatArray @@ -180,6 +197,8 @@ class TreemapPlot: aspect_preference: float value_display: ValueDisplay currency_symbol: str + tree: list[TreemapNode] | None = None + show_nested: bool = False @dataclass @@ -313,6 +332,7 @@ class ScaleChanged(Message): Binding("up", "pan_up", "Pan up", group=PAN_GROUP), Binding("down", "pan_down", "Pan down", group=PAN_GROUP), ("r", "reset_scales", "Reset scales"), + Binding("escape", "treemap_zoom_out", "Treemap zoom out", show=False), ] margin_top = reactive(2) @@ -390,6 +410,8 @@ def __init__( self._v_lines_labels: list[str | None] = [] self._treemap_hover_rects: list[dict] = [] self._treemap_selected_rect: dict | None = None + self._treemap_path: list[int] = [] + self._treemap_selected_index: int | None = None self._allow_pan_and_zoom = allow_pan_and_zoom self.invert_mouse_wheel = invert_mouse_wheel self._x_formatter = NumericAxisFormatter() @@ -446,6 +468,8 @@ def clear(self) -> None: self._v_lines_labels = [] self._treemap_hover_rects = [] self._treemap_selected_rect = None + self._treemap_path = [] + self._treemap_selected_index = None self._update_info(None) self._update_legend() self._rerender() @@ -662,6 +686,7 @@ def treemap( aspect_preference: float = 1.5, value_display: ValueDisplay | str = ValueDisplay.BOTH, currency_symbol: str = "$", + show_nested: bool = False, ) -> None: """Graph data as a treemap. @@ -689,41 +714,44 @@ def treemap( "percent", "both", "currency", or "none". Defaults to "both". currency_symbol: Symbol for currency display mode (e.g. "$", "€"). Defaults to "$". + show_nested: When True and data is nested, draw the full hierarchy + at once with luminance variance (children darker) instead of + zoom-in/zoom-out. Defaults to False. """ vd = ( ValueDisplay(value_display) if isinstance(value_display, str) else value_display ) - values_array = np.array(values, dtype=float) - values_array = values_array[~np.isnan(values_array) & ~np.isinf(values_array)] - if len(values_array) == 0: + tree_nodes, is_nested = normalize_treemap_tree(values, labels) + if not tree_nodes: return - # Default colors: use distinctipy for perceptually distinct CIELAB-inspired palette - def _rgb_style(rgb: tuple[float, float, float]) -> str: - r, g, b = rgb - return f"rgb({int(r * 255)},{int(g * 255)},{int(b * 255)})" + values_array = np.array([n["value"] for n in tree_nodes], dtype=float) + labels_list = [n["label"] for n in tree_nodes] + # Default colors: use distinctipy for perceptually distinct CIELAB-inspired palette n = len(values_array) if styles is None: distinct_colors = distinctipy.get_colors( - n, pastel_factor=0.2, rng=42 + n, + pastel_factor=0.2, + rng=42, + # colorblind_type="Deuteranomaly", ) # deterministic, slightly pastel for terminal contrast - styles_list = [_rgb_style(c) for c in distinct_colors] + styles_list = [rgb_style(c) for c in distinct_colors] elif isinstance(styles, str): styles_list = [styles] * n else: styles_list = list(styles) if len(styles_list) < n: extra = distinctipy.get_colors( - n - len(styles_list), pastel_factor=0.2, rng=42 + n - len(styles_list), + pastel_factor=0.2, + rng=42, + # colorblind_type="Deuteranomaly", ) - styles_list.extend(_rgb_style(c) for c in extra) - - labels_list: list[str] | None = None - if labels is not None: - labels_list = list(labels)[: len(values_array)] + styles_list.extend(rgb_style(c) for c in extra) self._datasets.append( TreemapPlot( @@ -735,6 +763,8 @@ def _rgb_style(rgb: tuple[float, float, float]) -> str: aspect_preference=aspect_preference, value_display=vd, currency_symbol=currency_symbol, + tree=tree_nodes if is_nested else None, + show_nested=show_nested and is_nested, ) ) self._labels.append(label) @@ -893,11 +923,28 @@ def _update_info(self, rect_info: dict | None) -> None: else: value_str = f"{value:,.0f}" if value == int(value) else f"{value:,.1f}" style = rect_info.get("style", "") - title = Text("███") - title.stylize(style) - title.append( - f" {rect_info['label']} · Value: {value_str} · Percent: {pct:.1f}%" - ) + path = rect_info.get("path") + tree = rect_info.get("tree") + path_to_style = rect_info.get("path_to_style") + if path is not None and tree is not None and len(path) > 0 and path_to_style: + segments, _ = format_treemap_nested_path( + tree, path, total, value_display, currency_symbol + ) + path_styles = get_path_styles(path, path_to_style) + title = Text() + for i, (seg_text, _) in enumerate(segments): + seg_style = path_styles[i] if i < len(path_styles) else style + title.append(" ███ ", style=seg_style) + title.append(f" {seg_text}") + if i < len(segments) - 1: + title.append(" ▶ ") + else: + title = Text("███") + title.stylize(style) + content = ( + f"{rect_info['label']} · Value: {value_str} · Percent: {pct:.1f}%" + ) + title.append(f" {content}") info_box.update(title) info_box.display = True # Position in margin-bottom (x-axis label) area - unused for treemap @@ -1135,6 +1182,7 @@ def _render_plot(self) -> None: # Clear treemap hover rects; treemap render will repopulate if applicable self._treemap_hover_rects = [] self._treemap_selected_rect = None + # Preserve _treemap_path and _treemap_selected_index for nested zoom state self._update_info(None) # determine axis limits (skip TreemapPlot - it uses pixel coordinates) @@ -1435,44 +1483,91 @@ def _render_treemap_plot(self, dataset: TreemapPlot) -> None: """Render a treemap dataset on the canvas. Uses squarify for layout. Rectangles are drawn in canvas coordinates - within the scale rectangle, with optional padding. - - Args: - dataset: The treemap dataset to render. + within the scale rectangle, with optional padding. For nested data, + renders the current level based on _treemap_path, or the full hierarchy + with luminance variance when show_nested=True. """ canvas = self.query_one("#plot", Canvas) sr = self._scale_rectangle if sr.width <= 0 or sr.height <= 0: return + # Nested full-hierarchy mode: draw all levels with luminance variance + if getattr(dataset, "show_nested", False) and dataset.tree is not None: + self._render_treemap_nested(canvas, dataset, sr) + return + + # Get current level: from tree if nested, else from dataset + if dataset.tree is not None: + level_nodes = get_treemap_level(dataset.tree, self._treemap_path) + if not level_nodes: + return + values_at_level = np.array([n["value"] for n in level_nodes]) + labels_at_level = [n["label"] for n in level_nodes] + has_children_list = [n.get("children") is not None for n in level_nodes] + else: + level_nodes = None + values_at_level = dataset.values + labels_at_level = dataset.labels or [ + f"Item {i + 1}" for i in range(len(dataset.values)) + ] + has_children_list = [False] * len(values_at_level) + + n = len(values_at_level) + if n == 0: + return + pad = max(0, dataset.padding) effective_width = max(1, sr.width - 2 * pad) effective_height = max(1, sr.height - 2 * pad) - # Squarify has no ratio param; bias aspect by laying out in modified space - # then scaling. aspect_preference > 1 = prefer wider rects (better for labels) aspect = max(0.25, min(4.0, dataset.aspect_preference)) layout_width = effective_width / aspect layout_height = effective_height normalized = squarify.normalize_sizes( - dataset.values.tolist(), layout_width, layout_height + values_at_level.tolist(), layout_width, layout_height ) rects = squarify.squarify(normalized, 0, 0, layout_width, layout_height) - # Scale rects from layout space to actual space (stretch x when aspect > 1) ox = sr.x + pad oy = sr.y + pad - styles = dataset.styles - is_style_list = isinstance(styles, list) - total_value = float(np.sum(dataset.values)) + # Styles for current level (may differ from dataset.styles when zoomed) + if dataset.tree is not None and n != len(dataset.styles): + distinct_colors = distinctipy.get_colors( + n, + pastel_factor=0.2, + rng=42, + # colorblind_type="Deuteranomaly", + ) + styles_at_level = [rgb_style(c) for c in distinct_colors] + else: + styles_at_level = ( + dataset.styles + if isinstance(dataset.styles, list) + else [dataset.styles] * n + ) + if len(styles_at_level) < n: + extra = distinctipy.get_colors( + n - len(styles_at_level), + pastel_factor=0.2, + rng=42, + # colorblind_type="Deuteranomaly", + ) + styles_at_level.extend(rgb_style(c) for c in extra) - # Store rects for hover detection (replace when multiple treemap datasets) + total_value = float(np.sum(values_at_level)) hover_rects: list[dict] = [] + selected_idx = self._treemap_selected_index + if selected_idx is not None and (selected_idx < 0 or selected_idx >= n): + selected_idx = None + for i, rect in enumerate(rects): - style = styles[i] if is_style_list else styles + style = ( + styles_at_level[i] if i < len(styles_at_level) else styles_at_level[0] + ) assert isinstance(style, str) # Scale from layout space to actual space @@ -1493,11 +1588,8 @@ def _render_treemap_plot(self, dataset: TreemapPlot) -> None: if y1 <= y0: y1 = y0 + 1 - # Store for hover info box label_str = ( - dataset.labels[i] - if dataset.labels and i < len(dataset.labels) - else f"Item {i + 1}" + labels_at_level[i] if i < len(labels_at_level) else f"Item {i + 1}" ) hover_rects.append( { @@ -1506,15 +1598,32 @@ def _render_treemap_plot(self, dataset: TreemapPlot) -> None: "x1": x1, "y1": y1, "label": label_str, - "value": float(dataset.values[i]), + "value": float(values_at_level[i]), "total": total_value, "style": style, "value_display": dataset.value_display, "currency_symbol": dataset.currency_symbol, + "has_children": has_children_list[i] + if i < len(has_children_list) + else False, + "node_index": i, } ) - if dataset.hires_mode: + # Selection fill: hires mode uses solid for selected; otherwise braille + if selected_idx == i: + if dataset.hires_mode: + canvas.draw_filled_rectangle(x0, y0, x1, y1, style=style) + else: + canvas.draw_filled_hires_rectangle( + float(x0), + float(y0), + float(x1), + float(y1), + hires_mode=HiResMode.BRAILLE, + style=style, + ) + elif dataset.hires_mode: canvas.draw_filled_hires_rectangle( float(x0), float(y0), @@ -1527,10 +1636,10 @@ def _render_treemap_plot(self, dataset: TreemapPlot) -> None: canvas.draw_filled_rectangle(x0, y0, x1, y1, style=style) # Draw label on rectangle if provided and rect is large enough - if dataset.labels and i < len(dataset.labels): + if i < len(labels_at_level): rect_w = x1 - x0 rect_h = y1 - y0 - value = dataset.values[i] + value = values_at_level[i] pct = 100 * value / total_value if total_value else 0 value_str = f"{value:,.0f}" if value == int(value) else f"{value:,.1f}" currency_str = f"{dataset.currency_symbol}{value:,.2f}" @@ -1548,7 +1657,7 @@ def _render_treemap_plot(self, dataset: TreemapPlot) -> None: line2 = None needs_two_lines = line2 is not None and rect_h >= 2 if rect_w >= 3 and rect_h >= (2 if needs_two_lines else 1): - label_line1 = dataset.labels[i] + label_line1 = labels_at_level[i] max_len = max(1, rect_w - 2) if len(label_line1) > max_len: label_line1 = label_line1[: max_len - 1] + "…" @@ -1556,13 +1665,17 @@ def _render_treemap_plot(self, dataset: TreemapPlot) -> None: line2 = line2[: max_len - 1] + "…" cx = (x0 + x1) // 2 cy = (y0 + y1) // 2 - bg_rgb = _parse_style_to_rgb(style) + bg_rgb = parse_style_to_rgb(style) if bg_rgb is not None: text_color = distinctipy.get_text_color(bg_rgb) fg = "black" if text_color == (0, 0, 0) else "white" else: fg = "white" - style_str = f"bold {fg} on {style}" + style_str = ( + f"bold {fg} on {style}" + if selected_idx == i + else f"{fg} on {style}" + ) if needs_two_lines: canvas.write_text( cx, @@ -1582,6 +1695,190 @@ def _render_treemap_plot(self, dataset: TreemapPlot) -> None: ) self._treemap_hover_rects = hover_rects + # Restore selected rect from new hover_rects for info box + if selected_idx is not None and selected_idx < len(hover_rects): + self._treemap_selected_rect = hover_rects[selected_idx] + self._update_info(self._treemap_selected_rect) + + def _render_treemap_nested( + self, canvas: Canvas, dataset: TreemapPlot, sr: Region + ) -> None: + """Render full nested hierarchy with luminance variance (children darker).""" + pad = max(0, dataset.padding) + effective_width = max(1, sr.width - 2 * pad) + effective_height = max(1, sr.height - 2 * pad) + aspect = max(0.25, min(4.0, dataset.aspect_preference)) + ox = sr.x + pad + oy = sr.y + pad + + n_top = len(dataset.tree) + if isinstance(dataset.styles, list) and len(dataset.styles) >= n_top: + base_styles = list(dataset.styles[:n_top]) + else: + distinct = distinctipy.get_colors( + n_top, + pastel_factor=0.2, + rng=42, + colorblind_type="Deuteranomaly", + ) + base_styles = [rgb_style(c) for c in distinct] + + layout_w = effective_width / aspect + # Reserve vertical space for parent label rows (3 extra per nesting level: 2 top + 1 bottom) + extra_rows = 3 * max(0, treemap_max_depth(dataset.tree) - 1) + layout_h = max(1, effective_height - extra_rows) + rect_infos = squarify_recursive( + dataset.tree, 0, 0, layout_w, layout_h, aspect, [], base_styles, [] + ) + path_to_style = {tuple(r["path"]): r["style"] for r in rect_infos} + + total_value = sum(r["value"] for r in rect_infos) + hover_rects: list[dict] = [] + selected_idx = self._treemap_selected_index + if selected_idx is not None and ( + selected_idx < 0 or selected_idx >= len(rect_infos) + ): + selected_idx = None + if selected_idx is None and rect_infos: + selected_idx = 0 + self._treemap_selected_index = 0 + + for i, info in enumerate(rect_infos): + rx, ry = info["x"], info["y"] + rdx, rdy = info["dx"], info["dy"] + x0 = int(ox + rx * aspect) + y0 = int(oy + ry) + x1 = int(x0 + rdx * aspect) + y1 = int(y0 + rdy) + if x1 <= x0: + x1 = x0 + 1 + if y1 <= y0: + y1 = y0 + 1 + + style = info["style"] + rect_info = { + "x0": x0, + "y0": y0, + "x1": x1, + "y1": y1, + "label": info["label"], + "value": info["value"], + "total": total_value, + "style": style, + "value_display": dataset.value_display, + "currency_symbol": dataset.currency_symbol, + "has_children": info.get("has_children", False), + "node_index": i, + "path": info["path"], + "tree": dataset.tree, + "path_to_style": path_to_style, + } + hover_rects.append(rect_info) + + draw_style = style + # Selected rect: hires mode uses solid fill; otherwise use braille for distinction + if selected_idx == i: + if dataset.hires_mode: + canvas.draw_filled_rectangle(x0, y0, x1, y1, style=draw_style) + else: + canvas.draw_filled_hires_rectangle( + float(x0), + float(y0), + float(x1), + float(y1), + hires_mode=HiResMode.BRAILLE, + style=draw_style, + ) + elif dataset.hires_mode: + canvas.draw_filled_hires_rectangle( + float(x0), + float(y0), + float(x1), + float(y1), + hires_mode=dataset.hires_mode, + style=draw_style, + ) + else: + canvas.draw_filled_rectangle(x0, y0, x1, y1, style=draw_style) + + rect_w = x1 - x0 + rect_h = y1 - y0 + value = info["value"] + pct = 100 * value / total_value if total_value else 0 + value_str = f"{value:,.0f}" if value == int(value) else f"{value:,.1f}" + currency_str = f"{dataset.currency_symbol}{value:,.2f}" + pct_str = f"{pct:.1f}%" + if dataset.value_display == ValueDisplay.VALUE: + line2 = value_str + elif dataset.value_display == ValueDisplay.PERCENT: + line2 = pct_str + elif dataset.value_display == ValueDisplay.BOTH: + line2 = f"{value_str} ({pct_str})" + elif dataset.value_display == ValueDisplay.CURRENCY: + line2 = currency_str + else: + line2 = None + needs_two_lines = line2 is not None and rect_h >= 2 + is_parent = info.get("has_children", False) + if rect_w >= 3: + cx = (x0 + x1) // 2 + bg_rgb = parse_style_to_rgb(draw_style) + if bg_rgb is not None: + text_color = distinctipy.get_text_color(bg_rgb) + fg = "black" if text_color == (0, 0, 0) else "white" + else: + fg = "white" + style_str = ( + f"bold {fg} on {draw_style}" + if selected_idx == i + else f"{fg} on {draw_style}" + ) + max_len = max(1, rect_w - 2) + if is_parent: + # Parent: label only, no values, in center row of top 3 + label_line1 = info["label"] + if len(label_line1) > max_len: + label_line1 = label_line1[: max_len - 1] + "…" + cy = y0 + 1 # center of top 3 rows + if rect_h >= TREEMAP_PARENT_LABEL_TOP: + canvas.write_text( + cx, + cy, + f"[{style_str}]{label_line1}", + align=TextAlign.CENTER, + ) + elif rect_h >= (2 if needs_two_lines else 1): + # Leaf: label and value + label_line1 = info["label"] + if len(label_line1) > max_len: + label_line1 = label_line1[: max_len - 1] + "…" + if needs_two_lines and line2 and len(line2) > max_len: + line2 = line2[: max_len - 1] + "…" + cy = (y0 + y1) // 2 + if needs_two_lines and line2: + canvas.write_text( + cx, + cy - 1, + f"[{style_str}]{label_line1}", + align=TextAlign.CENTER, + ) + canvas.write_text( + cx, cy, f"[{style_str}]{line2}", align=TextAlign.CENTER + ) + else: + canvas.write_text( + cx, + cy, + f"[{style_str}]{label_line1}", + align=TextAlign.CENTER, + ) + + self._treemap_hover_rects = hover_rects + if selected_idx is not None and selected_idx < len(hover_rects): + self._treemap_selected_rect = hover_rects[selected_idx] + else: + self._treemap_selected_rect = None + self._update_info(self._treemap_selected_rect) def _render_v_line_plot(self, vline: VLinePlot) -> None: """Render a vertical line on the canvas. @@ -1949,6 +2246,108 @@ def action_pan_down(self) -> None: """Pan the plot downward.""" self._pan(0, -self.KEYBOARD_PAN_FACTOR) + def _is_treemap_only_nested(self) -> bool: + """True if plot has only treemap datasets with nested (zoomable) data.""" + if not self._datasets or not all( + isinstance(d, TreemapPlot) for d in self._datasets + ): + return False + return any( + d.tree is not None for d in self._datasets if isinstance(d, TreemapPlot) + ) + + def _is_treemap_show_nested(self) -> bool: + """True if plot has a treemap with show_nested=True (full hierarchy view).""" + return any( + getattr(d, "show_nested", False) + for d in self._datasets + if isinstance(d, TreemapPlot) + ) + + def _handle_treemap_key(self, key: str) -> bool: + """Handle treemap-specific keys. Returns True if key was handled.""" + if not self._is_treemap_only_nested(): + return False + dataset = next( + (d for d in self._datasets if isinstance(d, TreemapPlot) and d.tree), None + ) + if dataset is None: + return False + + # In show_nested mode, use hover_rects count (all leaves); else use level count + if self._is_treemap_show_nested() and self._treemap_hover_rects: + n = len(self._treemap_hover_rects) + else: + n = len(get_treemap_level(dataset.tree, self._treemap_path)) + if n == 0: + return False + + if key == "escape": + if self._treemap_path and not self._is_treemap_show_nested(): + self._treemap_path.pop() + self._treemap_selected_index = min( + self._treemap_selected_index or 0, + len(get_treemap_level(dataset.tree, self._treemap_path)) - 1, + ) + self._rerender() + return True + + if key in ("plus", "equal", "+"): + if not self._is_treemap_show_nested(): + idx = ( + self._treemap_selected_index + if self._treemap_selected_index is not None + else 0 + ) + if self._treemap_hover_rects and idx < len(self._treemap_hover_rects): + rect = self._treemap_hover_rects[idx] + if rect.get("has_children"): + self._treemap_path.append(rect["node_index"]) + self._treemap_selected_index = 0 + self._rerender() + return True + + if key in ("minus", "-"): + if self._treemap_path and not self._is_treemap_show_nested(): + self._treemap_path.pop() + self._treemap_selected_index = min( + self._treemap_selected_index or 0, + len(get_treemap_level(dataset.tree, self._treemap_path)) - 1, + ) + self._rerender() + return True + + if key in ("left", "right", "up", "down"): + idx = self._treemap_selected_index + if idx is None: + idx = 0 + if key == "left": + idx = max(0, idx - 1) + elif key == "right": + idx = min(n - 1, idx + 1) + elif key == "up": + idx = max(0, idx - 1) + elif key == "down": + idx = min(n - 1, idx + 1) + self._treemap_selected_index = idx + if self._treemap_hover_rects and idx < len(self._treemap_hover_rects): + self._treemap_selected_rect = self._treemap_hover_rects[idx] + self._update_info(self._treemap_selected_rect) + self._rerender() + return True + + return False + + @on(Key) + def _on_key(self, event: Key) -> None: + """Handle keys; treemap-specific handling takes precedence when applicable.""" + if self._handle_treemap_key(event.key): + event.stop() + + def action_treemap_zoom_out(self) -> None: + """Zoom out of nested treemap (bound to Escape).""" + self._handle_treemap_key("escape") + @on(MouseDown) def start_dragging_legend(self, event: MouseDown) -> None: """Start dragging the legend when clicked with left mouse button. @@ -2003,12 +2402,20 @@ def _get_treemap_rect_at(self, event: MouseMove | MouseDown) -> dict | None: @on(MouseDown) def _handle_treemap_click(self, event: MouseDown) -> None: - """Select treemap rectangle on click to pin info box.""" + """Select treemap rectangle on click; zoom in if it has children (non-show_nested only).""" if event.button != 1: return rect = self._get_treemap_rect_at(event) + if rect is None: + return self._treemap_selected_rect = rect + self._treemap_selected_index = rect.get("node_index", 0) + # Zoom in if rect has children (only when not in show_nested full-hierarchy view) + if rect.get("has_children") and not self._is_treemap_show_nested(): + self._treemap_path.append(rect["node_index"]) + self._treemap_selected_index = 0 self._update_info(rect) + self._rerender() @on(MouseMove) def _handle_treemap_hover(self, event: MouseMove) -> None: @@ -2016,8 +2423,11 @@ def _handle_treemap_hover(self, event: MouseMove) -> None: if not self._treemap_hover_rects: self._update_info(None) return + # In show_nested mode: only show selected, not hover + if self._is_treemap_show_nested(): + self._update_info(self._treemap_selected_rect) + return rect = self._get_treemap_rect_at(event) - # Show hovered rect if any, else show selected rect self._update_info(rect if rect is not None else self._treemap_selected_rect) @on(MouseMove) @@ -2211,14 +2621,6 @@ def linear_mapper( return a_prime + (x - a) * (b_prime - a_prime) / (b - a) -def _parse_style_to_rgb(style: str) -> tuple[float, float, float] | None: - """Parse rgb(r,g,b) style string to (r,g,b) tuple in 0-1 range for distinctipy.""" - m = re.match(r"rgb\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", style, re.I) - if m: - return (int(m.group(1)) / 255, int(m.group(2)) / 255, int(m.group(3)) / 255) - return None - - def drop_nans_and_infs(x: FloatArray, y: FloatArray) -> tuple[FloatArray, FloatArray]: """Drop NaNs and Infs from x and y arrays. diff --git a/src/textual_plot/treemap_utils.py b/src/textual_plot/treemap_utils.py new file mode 100644 index 0000000..9c7aaa4 --- /dev/null +++ b/src/textual_plot/treemap_utils.py @@ -0,0 +1,331 @@ +"""Treemap layout and tree utilities.""" + +from __future__ import annotations + +from typing import Any, Sequence + +import distinctipy +import numpy as np +import squarify # type: ignore[import-untyped] +from numpy.typing import ArrayLike + +from textual_plot.color_utils import ( + parse_style_to_rgb, + rgb_style, + rgb_too_close, + tint_with_hue, +) + +# 1-character border around children so parent is visible and selectable +TREEMAP_PARENT_PAD = 1 +# Top inset for parent labels (3 rows); sides/bottom stay 1 +TREEMAP_PARENT_LABEL_TOP = 3 + +TreemapNode: type = dict[str, Any] + + +def treemap_max_depth(nodes: list[TreemapNode]) -> int: + """Max nesting depth of tree (1 = no nesting, 2 = one level of children, etc).""" + if not nodes: + return 0 + return 1 + max(treemap_max_depth(n.get("children") or []) for n in nodes) + + +def normalize_treemap_tree( + values: ArrayLike | list[Any] | list[list[float]], + labels: Sequence[str] | Sequence[Sequence[str]] | None = None, +) -> tuple[list[TreemapNode], bool]: + """Normalize treemap input to a tree of nodes. Returns (nodes, is_nested).""" + if not values: + return ([], False) + + try: + vals_list = list(values) + except TypeError: + vals_list = np.array(values).tolist() + first = vals_list[0] if len(vals_list) > 0 else None + + # Flat: list of numbers + if isinstance(first, (int, float, np.floating, np.integer)): + arr = np.array(values, dtype=float) + arr = arr[~np.isnan(arr) & ~np.isinf(arr)] + n = len(arr) + labels_list = list(labels)[:n] if labels else None + nodes = [ + { + "value": float(arr[i]), + "label": ( + labels_list[i] + if labels_list and i < len(labels_list) + else f"Item {i + 1}" + ), + "children": None, + } + for i in range(n) + ] + return (nodes, False) + + # Nested: list of dicts with "label" and "children" + if isinstance(first, dict): + nodes = [] + for i, item in enumerate(vals_list): + if not isinstance(item, dict): + continue + children_raw = item.get("children") + label = item.get("label", f"Item {i + 1}") + if children_raw is not None: + child_labels = None + if labels and i < len(labels): + lab = labels[i] + if isinstance(lab, (list, tuple)): + child_labels = lab + sub_nodes, _ = normalize_treemap_tree(children_raw, child_labels) + value = sum(n["value"] for n in sub_nodes) + nodes.append({"value": value, "label": label, "children": sub_nodes}) + else: + value = float(item.get("value", 0)) + nodes.append({"value": value, "label": label, "children": None}) + return (nodes, any(n.get("children") for n in nodes)) + + # Nested: list of lists [[100, 50], [30, 20]] + if isinstance(first, (list, tuple)): + labels_nested = list(labels) if labels else [] + top_labels = ( + labels_nested[0] + if labels_nested and isinstance(labels_nested[0], (list, tuple)) + else labels_nested + ) + child_labels_list = ( + labels_nested[1] + if len(labels_nested) > 1 and isinstance(labels_nested[1], (list, tuple)) + else None + ) + nodes = [] + for i, group in enumerate(vals_list): + if not isinstance(group, (list, tuple)): + continue + sub_labels = ( + child_labels_list[i] + if child_labels_list and i < len(child_labels_list) + else None + ) + sub_nodes, _ = normalize_treemap_tree(group, sub_labels) + value = sum(n["value"] for n in sub_nodes) + group_label = f"Group {i + 1}" + if top_labels and i < len(top_labels) and isinstance(top_labels[i], str): + group_label = top_labels[i] + nodes.append({"value": value, "label": group_label, "children": sub_nodes}) + return (nodes, True) + + return ([], False) + + +def get_treemap_level(nodes: list[TreemapNode], path: list[int]) -> list[TreemapNode]: + """Get the node list at the given path. path=[] returns nodes.""" + current: list[TreemapNode] = nodes + for idx in path: + if idx < 0 or idx >= len(current): + return [] + current = current[idx].get("children") or [] + return current + + +def get_treemap_node_at_path( + tree: list[TreemapNode], path: list[int] +) -> TreemapNode | None: + """Get node at path. path=[0,1] -> tree[0].children[1].""" + if not path: + return None + cur: list[TreemapNode] = tree + for i in path[:-1]: + if i < 0 or i >= len(cur): + return None + cur = cur[i].get("children") or [] + if path[-1] < 0 or path[-1] >= len(cur): + return None + return cur[path[-1]] + + +def get_path_styles( + path: list[int], + path_to_style: dict[tuple[int, ...], str], +) -> list[str]: + """Get style for each node in path from path_to_style mapping.""" + if not path or not path_to_style: + return [] + return [ + path_to_style[tuple(path[:i])] + for i in range(1, len(path) + 1) + if tuple(path[:i]) in path_to_style + ] + + +def format_treemap_nested_path( + tree: list[TreemapNode], + path: list[int], + total: float, + value_display: "ValueDisplay", + currency_symbol: str, +) -> tuple[list[tuple[str, str | None]], str]: + """Build breadcrumb for nested treemap. Returns ([(label, style), ...], plain_text).""" + if not tree or not path or total <= 0: + return ([], "") + from textual_plot.plot_widget import ValueDisplay + + segments: list[tuple[str, str | None]] = [] + plain_parts: list[str] = [] + current_nodes: list[TreemapNode] = tree + parent_value = total + for depth, idx in enumerate(path): + if idx < 0 or idx >= len(current_nodes): + break + node = current_nodes[idx] + value = node["value"] + label = node.get("label", "?") + pct_parent = 100 * value / parent_value if parent_value else 0 + pct_all = 100 * value / total if total else 0 + if value_display == ValueDisplay.CURRENCY: + value_str = f"{currency_symbol}{value:,.2f}" + else: + value_str = f"{value:,.0f}" if value == int(value) else f"{value:,.1f}" + if depth == 0: + segments.append((f"{label} ▪ {value_str} ▪ {pct_all:.1f}%", None)) + plain_parts.append(f"{label} ▪ {value_str} ▪ {pct_all:.1f}%") + else: + segments.append( + ( + f"{label} ▪ {value_str} ▪ {pct_parent:.1f}% ({pct_all:.1f}% All)", + None, + ) + ) + plain_parts.append( + f"{label} ▪ {value_str} ▪ {pct_parent:.1f}% ({pct_all:.1f}% All)" + ) + current_nodes = node.get("children") or [] + parent_value = value + return (segments, " ▶ ".join(plain_parts)) + + +def squarify_recursive( + nodes: list[TreemapNode], + x: float, + y: float, + dx: float, + dy: float, + aspect: float, + path: list[int], + base_styles: list[str], + exclude_colors: list[tuple[float, float, float]], +) -> list[dict]: + """Recursively run squarify on a tree. Returns rect info for parents and leaves. + + Parents get distinctipy colors; children get distinctipy (excluding parent) tinted + with parent's hue. At every nesting level: parents output first (draw underneath); + children inset by TREEMAP_PARENT_PAD so parent shows as 1-char border. + """ + if not nodes or dx <= 0 or dy <= 0: + return [] + out: list[dict] = [] + values = [n["value"] for n in nodes] + normalized = squarify.normalize_sizes(values, dx, dy) + rects = squarify.squarify(normalized, 0, 0, dx, dy) + pad = TREEMAP_PARENT_PAD + # Horizontal: same as bottom (2 chars) so parent border visible on left and right + inset_x = float(pad) * 2 / max(0.01, aspect) + inset_y_top = float(TREEMAP_PARENT_LABEL_TOP) + inset_y_bottom = float(pad) * 2 # 2 rows at bottom so parent border is visible + for i, rect in enumerate(rects): + if i >= len(nodes): + break + node = nodes[i] + rx, ry = rect["x"], rect["y"] + rdx, rdy = rect["dx"], rect["dy"] + child_path = path + [i] + style = base_styles[i] if i < len(base_styles) else base_styles[0] + if node.get("children"): + parent_rgb = parse_style_to_rgb(style) + n_children = len(node["children"]) + child_exclude = ( + [parent_rgb] + exclude_colors if parent_rgb else list(exclude_colors) + ) + child_exclude.extend( + [(1.0, 1.0, 1.0), (0.0, 0.0, 0.0)] + ) # always exclude white/black + child_colors = distinctipy.get_colors( + n_children, + exclude_colors=child_exclude, + pastel_factor=0.2, + rng=42, + colorblind_type="Deuteranomaly", + ) + child_styles = [] + current_exclude = list(child_exclude) + for c in child_colors: + tinted = tint_with_hue(c, parent_rgb) if parent_rgb else c + retries = 0 + while parent_rgb and rgb_too_close(tinted, parent_rgb) and retries < 50: + current_exclude.append(c) + current_exclude.append(tinted) + c = distinctipy.distinct_color( + current_exclude, + pastel_factor=0.2, + rng=42, + colorblind_type="Deuteranomaly", + ) + tinted = tint_with_hue(c, parent_rgb) + retries += 1 + child_styles.append(rgb_style(tinted)) + current_exclude.append(c) + current_exclude.append(tinted) + parent_value = sum(n["value"] for n in node["children"]) + out.append( + { + "x": x + rx, + "y": y + ry, + "dx": rdx, + "dy": rdy, + "node": node, + "path": child_path, + "style": style, + "selection_base": style, + "value": parent_value, + "label": node["label"], + "has_children": True, + } + ) + child_dx = max(0.1, rdx - 2 * inset_x) + child_dy = max(0.1, rdy - inset_y_top - inset_y_bottom) + child_x = x + rx + inset_x + child_y = y + ry + inset_y_top + child_exclude_list = ( + [parent_rgb] + exclude_colors if parent_rgb else exclude_colors + ) + sub = squarify_recursive( + node["children"], + child_x, + child_y, + child_dx, + child_dy, + aspect, + child_path, + child_styles, + child_exclude_list, + ) + out.extend(sub) + else: + out.append( + { + "x": x + rx, + "y": y + ry, + "dx": rdx, + "dy": rdy, + "node": node, + "path": child_path, + "style": style, + "selection_base": style, + "value": node["value"], + "label": node["label"], + "has_children": False, + } + ) + return out diff --git a/tests/test_treemap.py b/tests/test_treemap.py index 18af960..bd11514 100644 --- a/tests/test_treemap.py +++ b/tests/test_treemap.py @@ -40,7 +40,42 @@ def test_treemap_plot_structure(self) -> None: aspect_preference=1.5, value_display=ValueDisplay.BOTH, currency_symbol="$", + tree=None, + show_nested=False, ) assert len(dataset.values) == 3 assert dataset.labels == ["A", "B", "C"] assert dataset.padding == 1 + + def test_treemap_show_nested_dataset_structure(self) -> None: + """TreemapPlot with show_nested=True has expected structure for nested rendering.""" + from textual_plot.treemap_utils import normalize_treemap_tree + + tree_nodes, is_nested = normalize_treemap_tree( + [ + { + "label": "A", + "children": [ + {"label": "A1", "value": 10}, + {"label": "A2", "value": 20}, + ], + }, + {"label": "B", "value": 30}, + ] + ) + assert is_nested + assert len(tree_nodes) == 2 + dataset = TreemapPlot( + values=np.array([30.0, 30.0]), + labels=["A", "B"], + styles=["red", "blue"], + padding=1, + hires_mode=None, + aspect_preference=1.5, + value_display=ValueDisplay.BOTH, + currency_symbol="$", + tree=tree_nodes, + show_nested=True, + ) + assert dataset.tree is not None + assert dataset.show_nested is True From 02333dc5e19b30f156813ad3970526460ab2de3d Mon Sep 17 00:00:00 2001 From: Ryan Meyers Date: Wed, 18 Feb 2026 11:34:55 -0600 Subject: [PATCH 4/4] Better zooming --- src/textual_plot/plot_widget.py | 245 +++++++++++++++++++++++++----- src/textual_plot/treemap_utils.py | 85 +++++++---- 2 files changed, 259 insertions(+), 71 deletions(-) diff --git a/src/textual_plot/plot_widget.py b/src/textual_plot/plot_widget.py index 17c7904..3a2ef29 100644 --- a/src/textual_plot/plot_widget.py +++ b/src/textual_plot/plot_widget.py @@ -33,6 +33,7 @@ from textual.containers import Grid from textual.css.query import NoMatches from textual.events import ( + Click, Key, MouseDown, MouseMove, @@ -58,6 +59,7 @@ format_treemap_nested_path, get_path_styles, get_treemap_level, + get_treemap_node_at_path, normalize_treemap_tree, squarify_recursive, treemap_max_depth, @@ -333,6 +335,7 @@ class ScaleChanged(Message): Binding("down", "pan_down", "Pan down", group=PAN_GROUP), ("r", "reset_scales", "Reset scales"), Binding("escape", "treemap_zoom_out", "Treemap zoom out", show=False), + ("l", "toggle_legend", "Toggle legend"), ] margin_top = reactive(2) @@ -373,6 +376,7 @@ class ScaleChanged(Message): _allow_pan_and_zoom: bool = True _is_dragging_legend: bool = False + _legend_visible: bool = False _needs_rerender: bool = False _needs_canvas_resize: bool = False @@ -900,10 +904,17 @@ def show_legend( ): self._legend_location = LegendLocation.LEFT self._legend_relative_offset = Offset(0, 0) + self._legend_visible = is_visible self.query_one("#legend", Static).display = is_visible if is_visible: self._update_legend() + def action_toggle_legend(self) -> None: + """Toggle legend visibility (collapse/expand).""" + self._legend_visible = not self._legend_visible + self.show_legend(is_visible=self._legend_visible) + self._rerender() + def _update_info(self, rect_info: dict | None) -> None: """Update the plot info box with rectangle data or hide it.""" try: @@ -1703,7 +1714,11 @@ def _render_treemap_plot(self, dataset: TreemapPlot) -> None: def _render_treemap_nested( self, canvas: Canvas, dataset: TreemapPlot, sr: Region ) -> None: - """Render full nested hierarchy with luminance variance (children darker).""" + """Render full nested hierarchy with luminance variance (children darker). + + When _treemap_path is non-empty, renders the subtree at that path as the root + (zoom-in view). Double-click a parent to zoom in; Escape to zoom out. + """ pad = max(0, dataset.padding) effective_width = max(1, sr.width - 2 * pad) effective_height = max(1, sr.height - 2 * pad) @@ -1711,26 +1726,96 @@ def _render_treemap_nested( ox = sr.x + pad oy = sr.y + pad - n_top = len(dataset.tree) - if isinstance(dataset.styles, list) and len(dataset.styles) >= n_top: - base_styles = list(dataset.styles[:n_top]) + # When zoomed in, use subtree at _treemap_path as root + base_path = list(self._treemap_path) + if base_path: + node = get_treemap_node_at_path(dataset.tree, base_path) + if node is None: + self._treemap_path.clear() + base_path = [] + tree_to_render = dataset.tree + else: + tree_to_render = [node] + else: + tree_to_render = dataset.tree + + # Pre-compute full tree's path_to_style so zoomed view keeps colors consistent + n_top_full = len(dataset.tree) + if isinstance(dataset.styles, list) and len(dataset.styles) >= n_top_full: + base_styles_full = list(dataset.styles[:n_top_full]) else: - distinct = distinctipy.get_colors( - n_top, + distinct_full = distinctipy.get_colors( + n_top_full, pastel_factor=0.2, rng=42, colorblind_type="Deuteranomaly", ) - base_styles = [rgb_style(c) for c in distinct] + base_styles_full = [rgb_style(c) for c in distinct_full] + layout_w_full = effective_width / aspect + extra_rows_full = 3 * max(0, treemap_max_depth(dataset.tree) - 1) + layout_h_full = max(1, effective_height - extra_rows_full) + full_rect_infos = squarify_recursive( + dataset.tree, + 0, + 0, + layout_w_full, + layout_h_full, + aspect, + [], + base_styles_full, + [], + None, + ) + full_path_to_style = {tuple(r["path"]): r["style"] for r in full_rect_infos} + + n_top = len(tree_to_render) + if isinstance(dataset.styles, list) and len(dataset.styles) >= n_top: + base_styles = list(dataset.styles[:n_top]) + else: + base_styles = [ + full_path_to_style.get(tuple(base_path + [i]), base_styles_full[0]) + for i in range(n_top) + ] + if not base_styles: + distinct = distinctipy.get_colors( + n_top, + pastel_factor=0.2, + rng=42, + colorblind_type="Deuteranomaly", + ) + base_styles = [rgb_style(c) for c in distinct] layout_w = effective_width / aspect - # Reserve vertical space for parent label rows (3 extra per nesting level: 2 top + 1 bottom) - extra_rows = 3 * max(0, treemap_max_depth(dataset.tree) - 1) + extra_rows = 3 * max(0, treemap_max_depth(tree_to_render) - 1) layout_h = max(1, effective_height - extra_rows) - rect_infos = squarify_recursive( - dataset.tree, 0, 0, layout_w, layout_h, aspect, [], base_styles, [] - ) - path_to_style = {tuple(r["path"]): r["style"] for r in rect_infos} + if not base_path: + rect_infos = full_rect_infos + path_to_style = full_path_to_style + else: + rect_infos = squarify_recursive( + tree_to_render, + 0, + 0, + layout_w, + layout_h, + aspect, + [], + base_styles, + [], + None, + ) + + # Map zoom-relative paths [0], [0,0], [0,1] to full paths base_path, base_path+[0], ... + def _full_path(p: list[int]) -> tuple[int, ...]: + if p and p[0] == 0: + return tuple(base_path + p[1:]) + return tuple(base_path + p) + + for r in rect_infos: + fp = _full_path(r["path"]) + if fp in full_path_to_style: + r["style"] = full_path_to_style[fp] + path_to_style = {_full_path(r["path"]): r["style"] for r in rect_infos} total_value = sum(r["value"] for r in rect_infos) hover_rects: list[dict] = [] @@ -1756,6 +1841,10 @@ def _render_treemap_nested( y1 = y0 + 1 style = info["style"] + if base_path and info["path"] and info["path"][0] == 0: + full_path = base_path + info["path"][1:] + else: + full_path = base_path + info["path"] rect_info = { "x0": x0, "y0": y0, @@ -1769,7 +1858,7 @@ def _render_treemap_nested( "currency_symbol": dataset.currency_symbol, "has_children": info.get("has_children", False), "node_index": i, - "path": info["path"], + "path": full_path, "tree": dataset.tree, "path_to_style": path_to_style, } @@ -2188,30 +2277,38 @@ def _zoom( ) self._rerender() - @on(MouseScrollDown) - def zoom_in(self, event: MouseScrollDown) -> None: - """Zoom into the plot when scrolling down. + @on(MouseScrollUp) + def zoom_in(self, event: MouseScrollUp) -> None: + """Zoom into the plot when scrolling up. Args: - event: The mouse scroll down event. + event: The mouse scroll up event. """ event.stop() + if self._is_treemap_only_nested() and self._handle_treemap_key("plus"): + return self._zoom_with_mouse(event, self.MOUSE_ZOOM_FACTOR) - @on(MouseScrollUp) - def zoom_out(self, event: MouseScrollUp) -> None: - """Zoom out of the plot when scrolling up. + @on(MouseScrollDown) + def zoom_out(self, event: MouseScrollDown) -> None: + """Zoom out of the plot when scrolling down. Args: - event: The mouse scroll up event. + event: The mouse scroll down event. """ event.stop() + if self._is_treemap_only_nested() and self._handle_treemap_key("minus"): + return self._zoom_with_mouse(event, -self.MOUSE_ZOOM_FACTOR) def action_zoom_in(self) -> None: + if self._is_treemap_only_nested() and self._handle_treemap_key("plus"): + return self._zoom_with_keyboard(self.KEYBOARD_ZOOM_FACTOR) def action_zoom_out(self) -> None: + if self._is_treemap_only_nested() and self._handle_treemap_key("minus"): + return self._zoom_with_keyboard(-self.KEYBOARD_ZOOM_FACTOR) def action_zoom_x_in(self) -> None: @@ -2283,36 +2380,49 @@ def _handle_treemap_key(self, key: str) -> bool: return False if key == "escape": - if self._treemap_path and not self._is_treemap_show_nested(): + if self._treemap_path: self._treemap_path.pop() + level = get_treemap_level(dataset.tree, self._treemap_path) self._treemap_selected_index = min( self._treemap_selected_index or 0, - len(get_treemap_level(dataset.tree, self._treemap_path)) - 1, + len(level) - 1, ) self._rerender() return True if key in ("plus", "equal", "+"): - if not self._is_treemap_show_nested(): - idx = ( - self._treemap_selected_index - if self._treemap_selected_index is not None - else 0 - ) - if self._treemap_hover_rects and idx < len(self._treemap_hover_rects): - rect = self._treemap_hover_rects[idx] - if rect.get("has_children"): + idx = ( + self._treemap_selected_index + if self._treemap_selected_index is not None + else 0 + ) + if self._treemap_hover_rects and idx < len(self._treemap_hover_rects): + rect = self._treemap_hover_rects[idx] + if rect.get("has_children"): + if self._is_treemap_show_nested(): + path = rect.get("path", []) + seg = ( + path[len(self._treemap_path)] + if len(path) > len(self._treemap_path) + else None + ) + if seg is not None: + self._treemap_path.append(seg) + self._treemap_selected_index = 0 + self._rerender() + else: self._treemap_path.append(rect["node_index"]) self._treemap_selected_index = 0 self._rerender() return True if key in ("minus", "-"): - if self._treemap_path and not self._is_treemap_show_nested(): + if self._treemap_path: self._treemap_path.pop() + level = get_treemap_level(dataset.tree, self._treemap_path) self._treemap_selected_index = min( self._treemap_selected_index or 0, - len(get_treemap_level(dataset.tree, self._treemap_path)) - 1, + len(level) - 1, ) self._rerender() return True @@ -2348,6 +2458,35 @@ def action_treemap_zoom_out(self) -> None: """Zoom out of nested treemap (bound to Escape).""" self._handle_treemap_key("escape") + def treemap_zoom_in(self, path: list[int] | int) -> None: + """Zoom into a nested treemap node. + + Args: + path: Index or list of indices into the tree. E.g. 0 zooms to first + top-level node; [0, 1] zooms to second child of first top-level. + """ + if not self._is_treemap_only_nested(): + return + path_list = [path] if isinstance(path, int) else list(path) + if not path_list: + return + dataset = next( + (d for d in self._datasets if isinstance(d, TreemapPlot) and d.tree), + None, + ) + if dataset is None: + return + node = get_treemap_node_at_path(dataset.tree, path_list) + if node is None or not node.get("children"): + return + self._treemap_path = path_list + self._treemap_selected_index = 0 + self._rerender() + + def treemap_zoom_out(self) -> None: + """Zoom out one level in nested treemap.""" + self._handle_treemap_key("escape") + @on(MouseDown) def start_dragging_legend(self, event: MouseDown) -> None: """Start dragging the legend when clicked with left mouse button. @@ -2373,7 +2512,7 @@ def stop_dragging_legend(self, event: MouseUp) -> None: self.query_one("#legend").remove_class("dragged") event.stop() - def _get_treemap_rect_at(self, event: MouseMove | MouseDown) -> dict | None: + def _get_treemap_rect_at(self, event: MouseMove | MouseDown | Click) -> dict | None: """Return the treemap rect at the mouse position, or None.""" if not self._treemap_hover_rects: return None @@ -2400,9 +2539,9 @@ def _get_treemap_rect_at(self, event: MouseMove | MouseDown) -> dict | None: return rect_info return None - @on(MouseDown) - def _handle_treemap_click(self, event: MouseDown) -> None: - """Select treemap rectangle on click; zoom in if it has children (non-show_nested only).""" + @on(Click) + def _handle_treemap_click(self, event: Click) -> None: + """Select treemap rectangle on click; zoom in on double-click when rect has children.""" if event.button != 1: return rect = self._get_treemap_rect_at(event) @@ -2410,10 +2549,34 @@ def _handle_treemap_click(self, event: MouseDown) -> None: return self._treemap_selected_rect = rect self._treemap_selected_index = rect.get("node_index", 0) - # Zoom in if rect has children (only when not in show_nested full-hierarchy view) - if rect.get("has_children") and not self._is_treemap_show_nested(): + + # Double-click: zoom in when rect has children + if event.chain == 2 and rect.get("has_children"): + if self._is_treemap_show_nested(): + # Append path segment to zoom into this parent + path = rect.get("path", []) + seg = ( + path[len(self._treemap_path)] + if len(path) > len(self._treemap_path) + else None + ) + if seg is not None: + self._treemap_path.append(seg) + self._treemap_selected_index = 0 + else: + # Non-nested: zoom in by index + self._treemap_path.append(rect["node_index"]) + self._treemap_selected_index = 0 + + # Single click in non-show_nested: zoom in on parent (legacy behavior) + elif ( + event.chain == 1 + and rect.get("has_children") + and not self._is_treemap_show_nested() + ): self._treemap_path.append(rect["node_index"]) self._treemap_selected_index = 0 + self._update_info(rect) self._rerender() diff --git a/src/textual_plot/treemap_utils.py b/src/textual_plot/treemap_utils.py index 9c7aaa4..79222a8 100644 --- a/src/textual_plot/treemap_utils.py +++ b/src/textual_plot/treemap_utils.py @@ -216,12 +216,16 @@ def squarify_recursive( path: list[int], base_styles: list[str], exclude_colors: list[tuple[float, float, float]], + path_to_style: dict[tuple[int, ...], str] | None = None, ) -> list[dict]: """Recursively run squarify on a tree. Returns rect info for parents and leaves. Parents get distinctipy colors; children get distinctipy (excluding parent) tinted with parent's hue. At every nesting level: parents output first (draw underneath); children inset by TREEMAP_PARENT_PAD so parent shows as 1-char border. + + When path_to_style is provided, uses those styles for nodes instead of generating + new ones (keeps colors consistent when zooming). """ if not nodes or dx <= 0 or dy <= 0: return [] @@ -241,42 +245,62 @@ def squarify_recursive( rx, ry = rect["x"], rect["y"] rdx, rdy = rect["dx"], rect["dy"] child_path = path + [i] - style = base_styles[i] if i < len(base_styles) else base_styles[0] + path_key = tuple(child_path) + if path_to_style is not None and path_key in path_to_style: + style = path_to_style[path_key] + else: + style = base_styles[i] if i < len(base_styles) else base_styles[0] if node.get("children"): parent_rgb = parse_style_to_rgb(style) n_children = len(node["children"]) - child_exclude = ( - [parent_rgb] + exclude_colors if parent_rgb else list(exclude_colors) - ) - child_exclude.extend( - [(1.0, 1.0, 1.0), (0.0, 0.0, 0.0)] - ) # always exclude white/black - child_colors = distinctipy.get_colors( - n_children, - exclude_colors=child_exclude, - pastel_factor=0.2, - rng=42, - colorblind_type="Deuteranomaly", - ) child_styles = [] - current_exclude = list(child_exclude) - for c in child_colors: - tinted = tint_with_hue(c, parent_rgb) if parent_rgb else c - retries = 0 - while parent_rgb and rgb_too_close(tinted, parent_rgb) and retries < 50: + if path_to_style is not None: + for j in range(n_children): + child_path_key = tuple(child_path + [j]) + if child_path_key in path_to_style: + child_styles.append(path_to_style[child_path_key]) + else: + child_styles.append( + base_styles[0] if base_styles else "rgb(128,128,128)" + ) + else: + child_exclude = ( + [parent_rgb] + exclude_colors + if parent_rgb + else list(exclude_colors) + ) + child_exclude.extend( + [(1.0, 1.0, 1.0), (0.0, 0.0, 0.0)] + ) # always exclude white/black + child_colors = distinctipy.get_colors( + n_children, + exclude_colors=child_exclude, + pastel_factor=0.2, + rng=42, + colorblind_type="Deuteranomaly", + ) + current_exclude = list(child_exclude) + for c in child_colors: + tinted = tint_with_hue(c, parent_rgb) if parent_rgb else c + retries = 0 + while ( + parent_rgb + and rgb_too_close(tinted, parent_rgb) + and retries < 50 + ): + current_exclude.append(c) + current_exclude.append(tinted) + c = distinctipy.distinct_color( + current_exclude, + pastel_factor=0.2, + rng=42, + colorblind_type="Deuteranomaly", + ) + tinted = tint_with_hue(c, parent_rgb) + retries += 1 + child_styles.append(rgb_style(tinted)) current_exclude.append(c) current_exclude.append(tinted) - c = distinctipy.distinct_color( - current_exclude, - pastel_factor=0.2, - rng=42, - colorblind_type="Deuteranomaly", - ) - tinted = tint_with_hue(c, parent_rgb) - retries += 1 - child_styles.append(rgb_style(tinted)) - current_exclude.append(c) - current_exclude.append(tinted) parent_value = sum(n["value"] for n in node["children"]) out.append( { @@ -310,6 +334,7 @@ def squarify_recursive( child_path, child_styles, child_exclude_list, + path_to_style, ) out.extend(sub) else: