diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000..1c52f9fcaf --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +src/memray/_vendor/** linguist-generated=true diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index edb7c8701b..a8d1f4398b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -22,8 +22,14 @@ jobs: run: | sudo apt-get update sudo apt-get install -qy npm git + - name: Install Python dependencies + run: | + python3 -m pip install -r requirements-extra.txt - name: Check if files are up to date - run: make build-js + run: | + make build-js + make vendor-update + python3 tools/check_vendor_versions.py - name: Check for changes run: | git add -u diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cafc5b7878..80b7e9afc8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,4 @@ -exclude: "^(src/memray/reporters/templates/assets|src/vendor|benchmarks|docs/_static/flamegraphs)/" +exclude: "^(src/memray/reporters/templates/assets|src/vendor|src/memray/_vendor|benchmarks|docs/_static/flamegraphs)/" repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 @@ -11,7 +11,7 @@ repos: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace - exclude: "^([.]bumpversion[.]cfg|.*/__snapshots__/)" + exclude: "^([.]bumpversion[.]cfg|.*/__snapshots__/|tools/vendoring/patches/.*\\.patch$)" - repo: https://github.com/pre-commit/pygrep-hooks rev: v1.10.0 @@ -64,3 +64,12 @@ repos: files: ^news/ types: [rst] additional_dependencies: ["sphinx"] + + - repo: local + hooks: + - id: no-bare-textual-imports + name: Forbid bare textual imports + language: pygrep + entry: '^\s*(from textual\b|import textual\b)' + types: [python] + exclude: "^(src/memray/_vendor/|tests/conftest\\.py)" diff --git a/MANIFEST.in b/MANIFEST.in index 184f439c81..d7ccc604dd 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -36,6 +36,7 @@ recursive-include src/vendor * recursive-include src/memray *.py recursive-include src/memray *.pyi recursive-include src/memray *.html *.js *.css +recursive-include src/memray/_vendor/textual * recursive-include src/memray *.pyx *.pxd recursive-include src/memray *.gdb *.lldb recursive-include src/memray/_memray * diff --git a/Makefile b/Makefile index 1c1478d0f7..b0c184ab0d 100644 --- a/Makefile +++ b/Makefile @@ -34,6 +34,10 @@ cpp_files := $(shell find src/memray/_memray -name \*.cpp -o -name \*.h) # Use this to inject arbitrary commands before the make targets (e.g. docker) ENV := +.PHONY: vendor-update +vendor-update: ## Update vendored dependencies (Textual) + $(PYTHON) -m vendoring sync . + .PHONY: build build: build-js build-vendor build-ext ## (default) Build package extensions, JS assets, and vendor assets in-place diff --git a/README.md b/README.md index ed3cd1536e..1db2e728ec 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ It really makes a difference! # Installation -Memray requires Python 3.7+ and can be easily installed using most common Python +Memray requires Python 3.9+ and can be easily installed using most common Python packaging tools. We recommend installing the latest stable release from [PyPI](https://pypi.org/project/memray/) with pip: @@ -119,6 +119,28 @@ pre-commit install This will ensure that your contribution passes our linting checks. +# Vendoring Textual + +Memray vendors Textual under `src/memray/_vendor/textual`. + +To bump Textual: + +```shell +# 1) Update pins +# - vendor.txt: textual== +# - tools/vendoring/patches/textual-version.patch: __version__ = "" +# - setup.py test_requires: textual== + +# 2) Regenerate vendored tree +make vendor-update + +# 3) Verify version consistency checks +python3 tools/check_vendor_versions.py +``` + +CI (`.github/workflows/build.yml`, `check_generated_files`) reruns vendoring +and `tools/check_vendor_versions.py`, then fails if generated files drift. + # Documentation You can find the latest documentation available [here](https://bloomberg.github.io/memray/). diff --git a/docs/supported_environments.rst b/docs/supported_environments.rst index 28b66cafaf..37d82b3d3b 100644 --- a/docs/supported_environments.rst +++ b/docs/supported_environments.rst @@ -11,7 +11,7 @@ Supported Python versions Every Python version that hasn't reached end of life is supported. -Currently that's Python 3.8 through 3.14. +Currently that's Python 3.9 through 3.14. Supported operating systems --------------------------- @@ -38,7 +38,7 @@ are available on PyPI. For macOS, we test on ``x86-64`` and ``arm64`` - so, both Intel and Apple Silicon Macs. Pre-built wheels are available for both architectures, though -only for Python 3.8 and newer. +only for Python 3.9 and newer. Supported runtime environments ------------------------------ diff --git a/pyproject.toml b/pyproject.toml index 5078f5d77d..052e6be90c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ build-backend = 'setuptools.build_meta' line-length = 95 select = ["C4", "E", "F", "I001", "PERF", "W"] fix = true +exclude = ["src/memray/_vendor"] [tool.ruff.isort] force-single-line = true @@ -59,10 +60,10 @@ ignore = [ ] [tool.mypy] -exclude="tests/integration/(native_extension|multithreaded_extension)/" +exclude = "tests/integration/(native_extension|multithreaded_extension)/|_vendor/" [tool.cibuildwheel] -build = ["cp38-*", "cp39-*", "cp310-*", "cp311-*"] +build = ["cp39-*", "cp310-*", "cp311-*"] skip = "*musllinux*{i686,aarch64}*" manylinux-x86_64-image = "manylinux2014" manylinux-i686-image = "manylinux2014" @@ -116,6 +117,24 @@ before-test = [ "codesign --remove-signature /Library/Frameworks/Python.framework/Versions/*/Resources/Python.app/Contents/MacOS/Python || true", ] +[tool.vendoring] +destination = "src/memray/_vendor/" +requirements = "vendor.txt" +namespace = "memray._vendor" +protected-files = ["__init__.py"] +patches-dir = "tools/vendoring/patches" + +[tool.vendoring.transformations] +drop = [ + "tests/", + "docs/", + "examples/", + "*.md", + "*.rst", + "*.txt", + "CHANGELOG*", +] + [tool.coverage.run] plugins = [ "Cython.Coverage", @@ -128,6 +147,7 @@ branch = true parallel = true omit = [ "*__init__.py", + "src/memray/_vendor/*", ] [tool.coverage.report] diff --git a/requirements-extra.txt b/requirements-extra.txt index 19921e61b5..a82662ed32 100644 --- a/requirements-extra.txt +++ b/requirements-extra.txt @@ -2,4 +2,5 @@ mypy bump2version towncrier pre-commit +vendoring -r requirements-docs.txt diff --git a/requirements-test.txt b/requirements-test.txt index 3ff4f2fd8e..60ba668bfb 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -5,7 +5,7 @@ pytest pytest-cov ipython setuptools +textual==8.2.1 pkgconfig pytest-textual-snapshot -textual >= 0.43, != 0.65.2, != 0.66 packaging diff --git a/setup.py b/setup.py index c57166c910..3a751bac5d 100644 --- a/setup.py +++ b/setup.py @@ -90,9 +90,9 @@ def build_js_files(self): install_requires = [ "jinja2 >= 2.9", - "typing_extensions; python_version < '3.8.0'", - "rich >= 11.2.0", - "textual >= 0.41.0", + "rich >= 14.2.0", + "markdown-it-py", + "platformdirs", ] docs_requires = [ "IPython", @@ -109,6 +109,7 @@ def build_js_files(self): "isort", "mypy", "check-manifest", + "vendoring", ] test_requires = [ @@ -119,7 +120,7 @@ def build_js_files(self): "ipython", "setuptools", "pytest-textual-snapshot", - "textual >= 0.43, != 0.65.2, != 0.66", + "textual==8.2.1", "packaging", ] @@ -305,7 +306,7 @@ def build_js_files(self): setup( name="memray", version=about["__version__"], - python_requires=">=3.7.0", + python_requires=">=3.9.0", description="A memory profiler for Python applications", long_description=LONG_DESCRIPTION, long_description_content_type="text/markdown", @@ -316,8 +317,6 @@ def build_js_files(self): "License :: OSI Approved :: Apache Software License", "Operating System :: POSIX :: Linux", "Operating System :: MacOS", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/src/memray/_vendor/__init__.py b/src/memray/_vendor/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/src/memray/_vendor/__init__.py @@ -0,0 +1 @@ + diff --git a/src/memray/_vendor/textual/LICENSE b/src/memray/_vendor/textual/LICENSE new file mode 100644 index 0000000000..3a4399759e --- /dev/null +++ b/src/memray/_vendor/textual/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Will McGugan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/src/memray/_vendor/textual/__init__.py b/src/memray/_vendor/textual/__init__.py new file mode 100644 index 0000000000..ca39047418 --- /dev/null +++ b/src/memray/_vendor/textual/__init__.py @@ -0,0 +1,194 @@ +""" +The root Textual module. + +Exposes some commonly used symbols. + +""" + +from __future__ import annotations + +import inspect +import weakref +from typing import TYPE_CHECKING, Callable + +import rich.repr + +from memray._vendor.textual import constants +from memray._vendor.textual._context import active_app +from memray._vendor.textual._log import LogGroup, LogVerbosity +from memray._vendor.textual._on import on +from memray._vendor.textual._work_decorator import work + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + +__all__ = [ + "__version__", # type: ignore + "log", + "on", + "work", +] + + +LogCallable: TypeAlias = "Callable" + + +__version__ = "8.2.1" + +if TYPE_CHECKING: + from memray._vendor.textual.app import App as _App + + +class LoggerError(Exception): + """Raised when the logger failed.""" + + +@rich.repr.auto +class Logger: + """A [logger class](/guide/devtools/#logging-handler) that logs to the Textual [console](/guide/devtools#console).""" + + def __init__( + self, + log_callable: LogCallable | None, + group: LogGroup = LogGroup.INFO, + verbosity: LogVerbosity = LogVerbosity.NORMAL, + app: _App | None = None, + ) -> None: + self._log = log_callable + self._group = group + self._verbosity = verbosity + self._app = None if app is None else weakref.ref(app) + + @property + def app(self) -> _App | None: + """The associated application, or `None` if there isn't one.""" + return None if self._app is None else self._app() + + def __rich_repr__(self) -> rich.repr.Result: + yield self._group, LogGroup.INFO + yield self._verbosity, LogVerbosity.NORMAL + + def __call__(self, *args: object, **kwargs) -> None: + if constants.LOG_FILE: + output = " ".join(str(arg) for arg in args) + if kwargs: + key_values = " ".join( + f"{key}={value!r}" for key, value in kwargs.items() + ) + output = f"{output} {key_values}" if output else key_values + + with open(constants.LOG_FILE, "a", encoding="utf-8") as log_file: + print(output, file=log_file) + + app = self.app + if app is None: + try: + app = active_app.get() + except LookupError: + if constants.DEBUG: + print_args = ( + *args, + *[f"{key}={value!r}" for key, value in kwargs.items()], + ) + print(*print_args) + return + if not app._is_devtools_connected: + return + + current_frame = inspect.currentframe() + assert current_frame is not None + previous_frame = current_frame.f_back + assert previous_frame is not None + caller = inspect.getframeinfo(previous_frame) + + _log = self._log or app._log + try: + _log( + self._group, + self._verbosity, + caller, + *args, + **kwargs, + ) + except LoggerError: + # If there is not active app, try printing + if constants.DEBUG: + print_args = ( + *args, + *[f"{key}={value!r}" for key, value in kwargs.items()], + ) + print(*print_args) + + def verbosity(self, verbose: bool) -> Logger: + """Get a new logger with selective verbosity. + + Args: + verbose: True to use HIGH verbosity, otherwise NORMAL. + + Returns: + New logger. + """ + verbosity = LogVerbosity.HIGH if verbose else LogVerbosity.NORMAL + return Logger(self._log, self._group, verbosity, app=self.app) + + @property + def verbose(self) -> Logger: + """A verbose logger.""" + return Logger(self._log, self._group, LogVerbosity.HIGH, app=self.app) + + @property + def event(self) -> Logger: + """Logs events.""" + return Logger(self._log, LogGroup.EVENT, app=self.app) + + @property + def debug(self) -> Logger: + """Logs debug messages.""" + return Logger(self._log, LogGroup.DEBUG, app=self.app) + + @property + def info(self) -> Logger: + """Logs information.""" + return Logger(self._log, LogGroup.INFO, app=self.app) + + @property + def warning(self) -> Logger: + """Logs warnings.""" + return Logger(self._log, LogGroup.WARNING, app=self.app) + + @property + def error(self) -> Logger: + """Logs errors.""" + return Logger(self._log, LogGroup.ERROR, app=self.app) + + @property + def system(self) -> Logger: + """Logs system information.""" + return Logger(self._log, LogGroup.SYSTEM, app=self.app) + + @property + def logging(self) -> Logger: + """Logs from stdlib logging module.""" + return Logger(self._log, LogGroup.LOGGING, app=self.app) + + @property + def worker(self) -> Logger: + """Logs worker information.""" + return Logger(self._log, LogGroup.WORKER, app=self.app) + + +log = Logger(None) +"""Global logger that logs to the currently active app. + +Example: + ```python + from memray._vendor.textual import log + log(locals()) + ``` + +!!! note + This logger will only work if there is an active app in the current thread. + Use `app.log` to write logs from a thread without an active app. + + +""" diff --git a/src/memray/_vendor/textual/__main__.py b/src/memray/_vendor/textual/__main__.py new file mode 100644 index 0000000000..4cc0f56638 --- /dev/null +++ b/src/memray/_vendor/textual/__main__.py @@ -0,0 +1,19 @@ +from rich import print +from rich.panel import Panel + +from memray._vendor.textual.demo.demo_app import DemoApp + +if __name__ == "__main__": + app = DemoApp() + app.run() + print( + Panel.fit( + "[b magenta]Hope you liked the demo![/]\n\n" + "Please consider sponsoring me if you get value from my work.\n\n" + "Even the price of a ☕ can brighten my day!\n\n" + "https://github.com/sponsors/willmcgugan\n\n" + "- Will McGugan", + border_style="red", + title="Consider sponsoring", + ) + ) diff --git a/src/memray/_vendor/textual/_animator.py b/src/memray/_vendor/textual/_animator.py new file mode 100644 index 0000000000..5d815f2f9b --- /dev/null +++ b/src/memray/_vendor/textual/_animator.py @@ -0,0 +1,590 @@ +from __future__ import annotations + +import asyncio +from abc import ABC, abstractmethod +from dataclasses import dataclass +from functools import partial +from typing import TYPE_CHECKING, Any, Callable, TypeVar + +from typing_extensions import Protocol, runtime_checkable + +from memray._vendor.textual import _time +from memray._vendor.textual._callback import invoke +from memray._vendor.textual._compat import cached_property +from memray._vendor.textual._easing import DEFAULT_EASING, EASING +from memray._vendor.textual._types import AnimationLevel, CallbackType +from memray._vendor.textual.timer import Timer + +if TYPE_CHECKING: + from memray._vendor.textual.app import App + + AnimationKey = tuple[int, str] + """Animation keys are the id of the object and the attribute being animated.""" + +EasingFunction = Callable[[float], float] +"""Signature for a function that parametrizes animation speed. + +An easing function must map the interval [0, 1] into the interval [0, 1]. +""" + + +class AnimationError(Exception): + """An issue prevented animation from starting.""" + + +ReturnType = TypeVar("ReturnType") + + +@runtime_checkable +class Animatable(Protocol): + """Protocol for objects that can have their intrinsic values animated. + + For example, the transition between two colors can be animated + because the class [`Color`][textual.color.Color.blend] satisfies this protocol. + """ + + def blend( + self: ReturnType, destination: ReturnType, factor: float + ) -> ReturnType: # pragma: no cover + ... + + +class Animation(ABC): + on_complete: CallbackType | None = None + """Callback to run after animation completes""" + + @abstractmethod + def __call__( + self, + time: float, + app_animation_level: AnimationLevel = "full", + ) -> bool: # pragma: no cover + """Call the animation, return a boolean indicating whether animation is in-progress or complete. + + Args: + time: The current timestamp + + Returns: + True if the animation has finished, otherwise False. + """ + raise NotImplementedError("") + + async def invoke_callback(self) -> None: + """Calls the [`on_complete`][Animation.on_complete] callback if one is provided.""" + if self.on_complete is not None: + await invoke(self.on_complete) + + @abstractmethod + async def stop(self, complete: bool = True) -> None: + """Stop the animation. + + Args: + complete: Flag to say if the animation should be taken to completion. + """ + raise NotImplementedError + + def __eq__(self, other: object) -> bool: + return False + + +@dataclass +class SimpleAnimation(Animation): + obj: object + attribute: str + start_time: float + duration: float + start_value: float | Animatable + end_value: float | Animatable + final_value: object + easing: EasingFunction + on_complete: CallbackType | None = None + level: AnimationLevel = "full" + """Minimum level required for the animation to take place (inclusive).""" + + def __call__( + self, time: float, app_animation_level: AnimationLevel = "full" + ) -> bool: + if ( + self.duration == 0 + or app_animation_level == "none" + or app_animation_level == "basic" + and self.level == "full" + ): + setattr(self.obj, self.attribute, self.final_value) + return True + + factor = min(1.0, (time - self.start_time) / self.duration) + eased_factor = self.easing(factor) + + if factor == 1.0: + value = self.final_value + elif isinstance(self.start_value, Animatable): + assert isinstance( + self.end_value, Animatable + ), "end_value must be animatable" + value = self.start_value.blend(self.end_value, eased_factor) + else: + assert isinstance( + self.start_value, (int, float) + ), f"`start_value` must be float, not {self.start_value!r}" + assert isinstance( + self.end_value, (int, float) + ), f"`end_value` must be float, not {self.end_value!r}" + + if self.end_value > self.start_value: + eased_factor = self.easing(factor) + value = ( + self.start_value + + (self.end_value - self.start_value) * eased_factor + ) + else: + eased_factor = 1 - self.easing(factor) + value = ( + self.end_value + (self.start_value - self.end_value) * eased_factor + ) + setattr(self.obj, self.attribute, value) + return factor >= 1 + + async def stop(self, complete: bool = True) -> None: + """Stop the animation. + + Args: + complete: Flag to say if the animation should be taken to completion. + + Note: + [`on_complete`][Animation.on_complete] will be called regardless + of the value provided for `complete`. + """ + if complete: + setattr(self.obj, self.attribute, self.end_value) + await self.invoke_callback() + + def __eq__(self, other: object) -> bool: + if isinstance(other, SimpleAnimation): + return ( + self.final_value == other.final_value + and self.duration == other.duration + ) + return False + + +class BoundAnimator: + def __init__(self, animator: Animator, obj: object) -> None: + self._animator = animator + self._obj = obj + + def __call__( + self, + attribute: str, + value: str | float | Animatable, + *, + final_value: object = ..., + duration: float | None = None, + speed: float | None = None, + delay: float = 0.0, + easing: EasingFunction | str = DEFAULT_EASING, + on_complete: CallbackType | None = None, + level: AnimationLevel = "full", + ) -> None: + """Animate an attribute. + + Args: + attribute: Name of the attribute to animate. + value: The value to animate to. + final_value: The final value of the animation. Defaults to `value` if not set. + duration: The duration (in seconds) of the animation. + speed: The speed of the animation. + delay: A delay (in seconds) before the animation starts. + easing: An easing method. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + """ + start_value = getattr(self._obj, attribute) + if isinstance(value, str) and hasattr(start_value, "parse"): + # Color and Scalar have a parse method + # I'm exploiting a coincidence here, but I think this should be a first-class concept + # TODO: add a `Parsable` protocol + value = start_value.parse(value) + easing_function = EASING[easing] if isinstance(easing, str) else easing + return self._animator.animate( + self._obj, + attribute=attribute, + value=value, + final_value=final_value, + duration=duration, + speed=speed, + delay=delay, + easing=easing_function, + on_complete=on_complete, + level=level, + ) + + +class Animator: + """An object to manage updates to a given attribute over a period of time.""" + + def __init__(self, app: App, frames_per_second: int = 60) -> None: + """Initialise the animator object. + + Args: + app: The application that owns the animator. + frames_per_second: The number of frames/second to run the animation at. + """ + self._animations: dict[AnimationKey, Animation] = {} + """Dictionary that maps animation keys to the corresponding animation instances.""" + self._scheduled: dict[AnimationKey, Timer] = {} + """Dictionary of scheduled animations, comprising of their keys and the timer objects.""" + self.app = app + """The app that owns the animator object.""" + self._timer = Timer( + app, + 1 / frames_per_second, + name="Animator", + callback=self, + pause=True, + ) + + @cached_property + def _idle_event(self) -> asyncio.Event: + """The timer that runs the animator.""" + return asyncio.Event() + + @cached_property + def _complete_event(self) -> asyncio.Event: + """Flag if no animations are currently taking place.""" + return asyncio.Event() + + async def start(self) -> None: + """Start the animator task.""" + self._idle_event.set() + self._complete_event.set() + self._timer._start() + + async def stop(self) -> None: + """Stop the animator task.""" + try: + self._timer.stop() + except asyncio.CancelledError: + pass + finally: + self._idle_event.set() + self._complete_event.set() + + def bind(self, obj: object) -> BoundAnimator: + """Bind the animator to a given object. + + Args: + obj: The object to bind to. + + Returns: + The bound animator. + """ + return BoundAnimator(self, obj) + + def is_being_animated(self, obj: object, attribute: str) -> bool: + """Does the object/attribute pair have an ongoing or scheduled animation? + + Args: + obj: An object to check for. + attribute: The attribute on the object to test for. + + Returns: + `True` if that attribute is being animated for that object, `False` if not. + """ + key = (id(obj), attribute) + return key in self._animations or key in self._scheduled + + def animate( + self, + obj: object, + attribute: str, + value: Any, + *, + final_value: object = ..., + duration: float | None = None, + speed: float | None = None, + easing: EasingFunction | str = DEFAULT_EASING, + delay: float = 0.0, + on_complete: CallbackType | None = None, + level: AnimationLevel = "full", + ) -> None: + """Animate an attribute to a new value. + + Args: + obj: The object containing the attribute. + attribute: The name of the attribute. + value: The destination value of the attribute. + final_value: The final value, or ellipsis if it is the same as ``value``. + duration: The duration of the animation, or ``None`` to use speed. + speed: The speed of the animation. + easing: An easing function. + delay: Number of seconds to delay the start of the animation by. + on_complete: Callback to run after the animation completes. + level: Minimum level required for the animation to take place (inclusive). + """ + self._record_animation(attribute) + animate_callback = partial( + self._animate, + obj, + attribute, + value, + final_value=final_value, + duration=duration, + speed=speed, + easing=easing, + on_complete=on_complete, + level=level, + ) + if delay: + self._complete_event.clear() + self._scheduled[(id(obj), attribute)] = self.app.set_timer( + delay, animate_callback + ) + else: + animate_callback() + + def _record_animation(self, attribute: str) -> None: + """Called when an attribute is to be animated. + + Args: + attribute: Attribute being animated. + """ + + def _animate( + self, + obj: object, + attribute: str, + value: Any, + *, + final_value: object = ..., + duration: float | None = None, + speed: float | None = None, + easing: EasingFunction | str = DEFAULT_EASING, + on_complete: CallbackType | None = None, + level: AnimationLevel = "full", + ) -> None: + """Animate an attribute to a new value. + + Args: + obj: The object containing the attribute. + attribute: The name of the attribute. + value: The destination value of the attribute. + final_value: The final value, or ellipsis if it is the same as ``value``. + duration: The duration of the animation, or ``None`` to use speed. + speed: The speed of the animation. + easing: An easing function. + on_complete: Callback to run after the animation completes. + level: Minimum level required for the animation to take place (inclusive). + """ + if not hasattr(obj, attribute): + raise AttributeError( + f"Can't animate attribute {attribute!r} on {obj!r}; attribute does not exist" + ) + assert (duration is not None and speed is None) or ( + duration is None and speed is not None + ), "An Animation should have a duration OR a speed" + + # If an animation is already scheduled for this attribute, unschedule it. + animation_key = (id(obj), attribute) + try: + del self._scheduled[animation_key] + except KeyError: + pass + + if final_value is ...: + final_value = value + + start_time = self._get_time() + easing_function = EASING[easing] if isinstance(easing, str) else easing + animation: Animation | None = None + + if hasattr(obj, "__textual_animation__"): + animation = getattr(obj, "__textual_animation__")( + attribute, + getattr(obj, attribute), + value, + start_time, + duration=duration, + speed=speed, + easing=easing_function, + on_complete=on_complete, + level=level, + ) + + if animation is None: + if not isinstance(value, (int, float)) and not isinstance( + value, Animatable + ): + raise AnimationError( + f"Don't know how to animate {value!r}; " + "Can only animate , , or objects with a blend method" + ) + + start_value = getattr(obj, attribute) + if start_value == value: + self._animations.pop(animation_key, None) + if on_complete is not None: + self.app.call_later(on_complete) + return + + if duration is not None: + animation_duration = duration + else: + if hasattr(value, "get_distance_to"): + animation_duration = value.get_distance_to(start_value) / ( + speed or 50 + ) + else: + animation_duration = abs(value - start_value) / (speed or 50) + + animation = SimpleAnimation( + obj, + attribute=attribute, + start_time=start_time, + duration=animation_duration, + start_value=start_value, + end_value=value, + final_value=final_value, + easing=easing_function, + on_complete=( + partial(self.app.call_later, on_complete) + if on_complete is not None + else None + ), + level=level, + ) + + assert animation is not None, "animation expected to be non-None" + + if (current_animation := self._animations.get(animation_key)) is not None: + if (on_complete := current_animation.on_complete) is not None: + self.app.call_later(on_complete) + + self._animations[animation_key] = animation + self._timer.resume() + self._idle_event.clear() + self._complete_event.clear() + + async def _stop_scheduled_animation( + self, key: AnimationKey, complete: bool + ) -> None: + """Stop a scheduled animation. + + Args: + key: The key for the animation to stop. + complete: Should the animation be moved to its completed state? + """ + # First off, pull the timer out of the schedule and stop it; it + # won't be needed. + try: + schedule = self._scheduled.pop(key) + except KeyError: + return + schedule.stop() + # If we've been asked to complete (there's no point in making the + # animation only to then do nothing with it), and if there was a + # callback (there will be, but this just keeps type checkers happy + # really)... + if complete and schedule._callback is not None: + # ...invoke it to get the animator created and in the running + # animations. Yes, this does mean that a stopped scheduled + # animation will start running early... + await invoke(schedule._callback) + # ...but only so we can call on it to run right to the very end + # right away. + await self._stop_running_animation(key, complete) + + async def _stop_running_animation(self, key: AnimationKey, complete: bool) -> None: + """Stop a running animation. + + Args: + key: The key for the animation to stop. + complete: Should the animation be moved to its completed state? + """ + try: + animation = self._animations.pop(key) + except KeyError: + return + await animation.stop(complete) + + async def stop_animation( + self, obj: object, attribute: str, complete: bool = True + ) -> None: + """Stop an animation on an attribute. + + Args: + obj: The object containing the attribute. + attribute: The name of the attribute. + complete: Should the animation be set to its final value? + + Note: + If there is no animation scheduled or running, this is a no-op. + """ + key = (id(obj), attribute) + if key in self._scheduled: + await self._stop_scheduled_animation(key, complete) + elif key in self._animations: + await self._stop_running_animation(key, complete) + + def force_stop_animation(self, obj: object, attribute: str) -> None: + """Force stop an animation on an attribute. This will immediately stop the animation, + without running any associated callbacks, setting the attribute to its final value. + + Args: + obj: The object containing the attribute. + attribute: The name of the attribute. + + Note: + If there is no animation scheduled or running, this is a no-op. + """ + from memray._vendor.textual.css.scalar_animation import ScalarAnimation + + animation_key = (id(obj), attribute) + try: + animation = self._animations.pop(animation_key) + except KeyError: + return + + if isinstance(animation, SimpleAnimation): + setattr(obj, attribute, animation.end_value) + elif isinstance(animation, ScalarAnimation): + setattr(obj, attribute, animation.final_value) + + if animation.on_complete is not None: + animation.on_complete() + + def __call__(self) -> None: + if not self._animations: + self._timer.pause() + self._idle_event.set() + if not self._scheduled: + self._complete_event.set() + else: + app_animation_level = self.app.animation_level + animation_time = self._get_time() + animation_keys = list(self._animations.keys()) + for animation_key in animation_keys: + animation = self._animations[animation_key] + animation_complete = animation(animation_time, app_animation_level) + if animation_complete: + del self._animations[animation_key] + if animation.on_complete is not None: + animation.on_complete() + + def _get_time(self) -> float: + """Get the current wall clock time, via the internal Timer. + + Returns: + The wall clock time. + """ + # N.B. We could remove this method and always call `self._timer.get_time()` internally, + # but it's handy to have in mocking situations. + return _time.get_time() + + async def wait_for_idle(self) -> None: + """Wait for any animations to complete.""" + await self._idle_event.wait() + + async def wait_until_complete(self) -> None: + """Wait for any current and scheduled animations to complete.""" + await self._complete_event.wait() diff --git a/src/memray/_vendor/textual/_ansi_sequences.py b/src/memray/_vendor/textual/_ansi_sequences.py new file mode 100644 index 0000000000..2db80a59c5 --- /dev/null +++ b/src/memray/_vendor/textual/_ansi_sequences.py @@ -0,0 +1,454 @@ +from __future__ import annotations + +from typing import Mapping, Tuple + +from typing_extensions import Final + +from memray._vendor.textual.keys import Keys + + +class IgnoredSequence: + """Class used to mark that a sequence should be ignored.""" + + +IGNORE_SEQUENCE: Final[IgnoredSequence] = IgnoredSequence() +"""Constant to indicate that a sequence should be ignored.""" + + +# Mapping of vt100 escape codes to Keys. +ANSI_SEQUENCES_KEYS: Mapping[str, Tuple[Keys, ...] | str | IgnoredSequence] = { + # Control keys. + " ": (Keys.Space,), + "\r": (Keys.Enter,), + "\x00": (Keys.ControlAt,), # Control-At (Also for Ctrl-Space) + "\x01": (Keys.ControlA,), # Control-A (home) + "\x02": (Keys.ControlB,), # Control-B (emacs cursor left) + "\x03": (Keys.ControlC,), # Control-C (interrupt) + "\x04": (Keys.ControlD,), # Control-D (exit) + "\x05": (Keys.ControlE,), # Control-E (end) + "\x06": (Keys.ControlF,), # Control-F (cursor forward) + "\x07": (Keys.ControlG,), # Control-G + "\x08": (Keys.Backspace,), # Control-H (8) (Identical to '\b') + "\x09": (Keys.Tab,), # Control-I (9) (Identical to '\t') + "\x0a": (Keys.ControlJ,), # Control-J (10) (Identical to '\n') + "\x0b": (Keys.ControlK,), # Control-K (delete until end of line; vertical tab) + "\x0c": (Keys.ControlL,), # Control-L (clear; form feed) + # "\x0d": (Keys.ControlM,), # Control-M (13) (Identical to '\r') + "\x0e": (Keys.ControlN,), # Control-N (14) (history forward) + "\x0f": (Keys.ControlO,), # Control-O (15) + "\x10": (Keys.ControlP,), # Control-P (16) (history back) + "\x11": (Keys.ControlQ,), # Control-Q + "\x12": (Keys.ControlR,), # Control-R (18) (reverse search) + "\x13": (Keys.ControlS,), # Control-S (19) (forward search) + "\x14": (Keys.ControlT,), # Control-T + "\x15": (Keys.ControlU,), # Control-U + "\x16": (Keys.ControlV,), # Control-V + "\x17": (Keys.ControlW,), # Control-W + "\x18": (Keys.ControlX,), # Control-X + "\x19": (Keys.ControlY,), # Control-Y (25) + "\x1a": (Keys.ControlZ,), # Control-Z + "\x1b": (Keys.Escape,), # Also Control-[ + "\x1b\x1b": ( + Keys.Escape, + ), # Windows issues esc esc for a single press of escape key + "\x9b": (Keys.ShiftEscape,), + "\x1c": (Keys.ControlBackslash,), # Both Control-\ (also Ctrl-| ) + "\x1d": (Keys.ControlSquareClose,), # Control-] + "\x1e": (Keys.ControlCircumflex,), # Control-^ + "\x1f": (Keys.ControlUnderscore,), # Control-underscore (Also for Ctrl-hyphen.) + # ASCII Delete (0x7f) + # Vt220 (and Linux terminal) send this when pressing backspace. We map this + # to ControlH, because that will make it easier to create key bindings that + # work everywhere, with the trade-off that it's no longer possible to + # handle backspace and control-h individually for the few terminals that + # support it. (Most terminals send ControlH when backspace is pressed.) + # See: http://www.ibb.net/~anne/keyboard.html + "\x7f": (Keys.Backspace,), + "\x1b\x7f": (Keys.ControlW,), + # Various + "\x1b[1~": (Keys.Home,), # tmux + "\x1b[2~": (Keys.Insert,), + "\x1b[3~": (Keys.Delete,), + "\x1b[4~": (Keys.End,), # tmux + "\x1b[5~": (Keys.PageUp,), + "\x1b[6~": (Keys.PageDown,), + "\x1b[7~": (Keys.Home,), # xrvt + "\x1b[8~": (Keys.End,), # xrvt + "\x1b[Z": (Keys.BackTab,), # shift + tab + "\x1b\x09": (Keys.BackTab,), # Linux console + "\x1b[~": (Keys.BackTab,), # Windows console + # -- + # Function keys. + "\x1bOP": (Keys.F1,), + "\x1bOQ": (Keys.F2,), + "\x1bOR": (Keys.F3,), + "\x1bOS": (Keys.F4,), + "\x1b[[A": (Keys.F1,), # Linux console. + "\x1b[[B": (Keys.F2,), # Linux console. + "\x1b[[C": (Keys.F3,), # Linux console. + "\x1b[[D": (Keys.F4,), # Linux console. + "\x1b[[E": (Keys.F5,), # Linux console. + "\x1b[11~": (Keys.F1,), # rxvt-unicode + "\x1b[12~": (Keys.F2,), # rxvt-unicode + "\x1b[13~": (Keys.F3,), # rxvt-unicode + "\x1b[14~": (Keys.F4,), # rxvt-unicode + "\x1b[15~": (Keys.F5,), + "\x1b[17~": (Keys.F6,), + "\x1b[18~": (Keys.F7,), + "\x1b[19~": (Keys.F8,), + "\x1b[20~": (Keys.F9,), + "\x1b[21~": (Keys.F10,), + "\x1b[23~": (Keys.F11,), + "\x1b[24~": (Keys.F12,), + "\x1b[25~": (Keys.F13,), + "\x1b[26~": (Keys.F14,), + "\x1b[28~": (Keys.F15,), + "\x1b[29~": (Keys.F16,), + "\x1b[31~": (Keys.F17,), + "\x1b[32~": (Keys.F18,), + "\x1b[33~": (Keys.F19,), + "\x1b[34~": (Keys.F20,), + # Xterm + "\x1b[1;2P": (Keys.F13,), + "\x1b[1;2Q": (Keys.F14,), + "\x1b[1;2R": ( + Keys.F15, + ), # Conflicts with CPR response; enabled after https://github.com/Textualize/textual/issues/3440. + "\x1b[1;2S": (Keys.F16,), + "\x1b[15;2~": (Keys.F17,), + "\x1b[17;2~": (Keys.F18,), + "\x1b[18;2~": (Keys.F19,), + "\x1b[19;2~": (Keys.F20,), + "\x1b[20;2~": (Keys.F21,), + "\x1b[21;2~": (Keys.F22,), + "\x1b[23;2~": (Keys.F23,), + "\x1b[24;2~": (Keys.F24,), + "\x1b[23$": (Keys.F23,), # rxvt + "\x1b[24$": (Keys.F24,), # rxvt + # -- + # Control + function keys. + "\x1b[1;5P": (Keys.ControlF1,), + "\x1b[1;5Q": (Keys.ControlF2,), + "\x1b[1;5R": ( + Keys.ControlF3, + ), # Conflicts with CPR response; enabled after https://github.com/Textualize/textual/issues/3440. + "\x1b[1;5S": (Keys.ControlF4,), + "\x1b[15;5~": (Keys.ControlF5,), + "\x1b[17;5~": (Keys.ControlF6,), + "\x1b[18;5~": (Keys.ControlF7,), + "\x1b[19;5~": (Keys.ControlF8,), + "\x1b[20;5~": (Keys.ControlF9,), + "\x1b[21;5~": (Keys.ControlF10,), + "\x1b[23;5~": (Keys.ControlF11,), + "\x1b[24;5~": (Keys.ControlF12,), + "\x1b[1;6P": (Keys.ControlF13,), + "\x1b[1;6Q": (Keys.ControlF14,), + "\x1b[1;6R": ( + Keys.ControlF15, + ), # Conflicts with CPR response; enabled after https://github.com/Textualize/textual/issues/3440. + "\x1b[1;6S": (Keys.ControlF16,), + "\x1b[15;6~": (Keys.ControlF17,), + "\x1b[17;6~": (Keys.ControlF18,), + "\x1b[18;6~": (Keys.ControlF19,), + "\x1b[19;6~": (Keys.ControlF20,), + "\x1b[20;6~": (Keys.ControlF21,), + "\x1b[21;6~": (Keys.ControlF22,), + "\x1b[23;6~": (Keys.ControlF23,), + "\x1b[24;6~": (Keys.ControlF24,), + # rxvt-unicode control function keys: + "\x1b[11^": (Keys.ControlF1,), + "\x1b[12^": (Keys.ControlF2,), + "\x1b[13^": (Keys.ControlF3,), + "\x1b[14^": (Keys.ControlF4,), + "\x1b[15^": (Keys.ControlF5,), + "\x1b[17^": (Keys.ControlF6,), + "\x1b[18^": (Keys.ControlF7,), + "\x1b[19^": (Keys.ControlF8,), + "\x1b[20^": (Keys.ControlF9,), + "\x1b[21^": (Keys.ControlF10,), + "\x1b[23^": (Keys.ControlF11,), + "\x1b[24^": (Keys.ControlF12,), + # rxvt-unicode control+shift function keys: + "\x1b[25^": (Keys.ControlF13,), + "\x1b[26^": (Keys.ControlF14,), + "\x1b[28^": (Keys.ControlF15,), + "\x1b[29^": (Keys.ControlF16,), + "\x1b[31^": (Keys.ControlF17,), + "\x1b[32^": (Keys.ControlF18,), + "\x1b[33^": (Keys.ControlF19,), + "\x1b[34^": (Keys.ControlF20,), + "\x1b[23@": (Keys.ControlF21,), + "\x1b[24@": (Keys.ControlF22,), + # -- + # Tmux (Win32 subsystem) sends the following scroll events. + "\x1b[62~": (Keys.ScrollUp,), + "\x1b[63~": (Keys.ScrollDown,), + # Meta/control/escape + pageup/pagedown/insert/delete. + "\x1b[3;2~": (Keys.ShiftDelete,), # xterm, gnome-terminal. + "\x1b[3$": (Keys.ShiftDelete,), # rxvt + "\x1b[5;2~": (Keys.ShiftPageUp,), + "\x1b[6;2~": (Keys.ShiftPageDown,), + "\x1b[2;3~": (Keys.Escape, Keys.Insert), + "\x1b[3;3~": (Keys.Escape, Keys.Delete), + "\x1b[5;3~": (Keys.Escape, Keys.PageUp), + "\x1b[6;3~": (Keys.Escape, Keys.PageDown), + "\x1b[2;4~": (Keys.Escape, Keys.ShiftInsert), + "\x1b[3;4~": (Keys.Escape, Keys.ShiftDelete), + "\x1b[5;4~": (Keys.Escape, Keys.ShiftPageUp), + "\x1b[6;4~": (Keys.Escape, Keys.ShiftPageDown), + "\x1b[3;5~": (Keys.ControlDelete,), # xterm, gnome-terminal. + "\x1b[3^": (Keys.ControlDelete,), # rxvt + "\x1b[5;5~": (Keys.ControlPageUp,), + "\x1b[6;5~": (Keys.ControlPageDown,), + "\x1b[5^": (Keys.ControlPageUp,), # rxvt + "\x1b[6^": (Keys.ControlPageDown,), # rxvt + "\x1b[3;6~": (Keys.ControlShiftDelete,), + "\x1b[5;6~": (Keys.ControlShiftPageUp,), + "\x1b[6;6~": (Keys.ControlShiftPageDown,), + "\x1b[2;7~": (Keys.Escape, Keys.ControlInsert), + "\x1b[5;7~": (Keys.Escape, Keys.ControlPageDown), + "\x1b[6;7~": (Keys.Escape, Keys.ControlPageDown), + "\x1b[2;8~": (Keys.Escape, Keys.ControlShiftInsert), + "\x1b[5;8~": (Keys.Escape, Keys.ControlShiftPageDown), + "\x1b[6;8~": (Keys.Escape, Keys.ControlShiftPageDown), + # -- + # Arrows. + # (Normal cursor mode). + "\x1b[A": (Keys.Up,), + "\x1b[B": (Keys.Down,), + "\x1b[C": (Keys.Right,), + "\x1b[D": (Keys.Left,), + "\x1b[H": (Keys.Home,), + "\x1b[F": (Keys.End,), + # Tmux sends following keystrokes when control+arrow is pressed, but for + # Emacs ansi-term sends the same sequences for normal arrow keys. Consider + # it a normal arrow press, because that's more important. + # (Application cursor mode). + "\x1bOA": (Keys.Up,), + "\x1bOB": (Keys.Down,), + "\x1bOC": (Keys.Right,), + "\x1bOD": (Keys.Left,), + "\x1bOF": (Keys.End,), + "\x1bOH": (Keys.Home,), + # Shift + arrows. + "\x1b[1;2A": (Keys.ShiftUp,), + "\x1b[1;2B": (Keys.ShiftDown,), + "\x1b[1;2C": (Keys.ShiftRight,), + "\x1b[1;2D": (Keys.ShiftLeft,), + "\x1b[1;2F": (Keys.ShiftEnd,), + "\x1b[1;2H": (Keys.ShiftHome,), + # Shift+navigation in rxvt + "\x1b[a": (Keys.ShiftUp,), + "\x1b[b": (Keys.ShiftDown,), + "\x1b[c": (Keys.ShiftRight,), + "\x1b[d": (Keys.ShiftLeft,), + "\x1b[7$": (Keys.ShiftHome,), + "\x1b[8$": (Keys.ShiftEnd,), + # Meta + arrow keys. Several terminals handle this differently. + # The following sequences are for xterm and gnome-terminal. + # (Iterm sends ESC followed by the normal arrow_up/down/left/right + # sequences, and the OSX Terminal sends ESCb and ESCf for "alt + # arrow_left" and "alt arrow_right." We don't handle these + # explicitly, in here, because would could not distinguish between + # pressing ESC (to go to Vi navigation mode), followed by just the + # 'b' or 'f' key. These combinations are handled in + # the input processor.) + "\x1b[1;3A": (Keys.Escape, Keys.Up), + "\x1b[1;3B": (Keys.Escape, Keys.Down), + "\x1b[1;3C": (Keys.Escape, Keys.Right), + "\x1b[1;3D": (Keys.Escape, Keys.Left), + "\x1b[1;3F": (Keys.Escape, Keys.End), + "\x1b[1;3H": (Keys.Escape, Keys.Home), + # Alt+shift+number. + "\x1b[1;4A": (Keys.Escape, Keys.ShiftUp), + "\x1b[1;4B": (Keys.Escape, Keys.ShiftDown), + "\x1b[1;4C": (Keys.Escape, Keys.ShiftRight), + "\x1b[1;4D": (Keys.Escape, Keys.ShiftLeft), + "\x1b[1;4F": (Keys.Escape, Keys.ShiftEnd), + "\x1b[1;4H": (Keys.Escape, Keys.ShiftHome), + # Control + arrows. + "\x1b[1;5A": (Keys.ControlUp,), # Cursor Mode + "\x1b[1;5B": (Keys.ControlDown,), # Cursor Mode + "\x1b[1;5C": (Keys.ControlRight,), # Cursor Mode + "\x1b[1;5D": (Keys.ControlLeft,), # Cursor Mode + "\x1bf": (Keys.ControlRight,), # iTerm natural editing keys + "\x1bb": (Keys.ControlLeft,), # iTerm natural editing keys + "\x1b[1;5F": (Keys.ControlEnd,), + "\x1b[1;5H": (Keys.ControlHome,), + # rxvt + "\x1b[7^": (Keys.ControlEnd,), + "\x1b[8^": (Keys.ControlHome,), + # Tmux sends following keystrokes when control+arrow is pressed, but for + # Emacs ansi-term sends the same sequences for normal arrow keys. Consider + # it a normal arrow press, because that's more important. + "\x1b[5A": (Keys.ControlUp,), + "\x1b[5B": (Keys.ControlDown,), + "\x1b[5C": (Keys.ControlRight,), + "\x1b[5D": (Keys.ControlLeft,), + # Control arrow keys in rxvt + "\x1bOa": (Keys.ControlUp,), + "\x1bOb": (Keys.ControlUp,), + "\x1bOc": (Keys.ControlRight,), + "\x1bOd": (Keys.ControlLeft,), + # Control + shift + arrows. + "\x1b[1;6A": (Keys.ControlShiftUp,), + "\x1b[1;6B": (Keys.ControlShiftDown,), + "\x1b[1;6C": (Keys.ControlShiftRight,), + "\x1b[1;6D": (Keys.ControlShiftLeft,), + "\x1b[1;6F": (Keys.ControlShiftEnd,), + "\x1b[1;6H": (Keys.ControlShiftHome,), + # Control + Meta + arrows. + "\x1b[1;7A": (Keys.Escape, Keys.ControlUp), + "\x1b[1;7B": (Keys.Escape, Keys.ControlDown), + "\x1b[1;7C": (Keys.Escape, Keys.ControlRight), + "\x1b[1;7D": (Keys.Escape, Keys.ControlLeft), + "\x1b[1;7F": (Keys.Escape, Keys.ControlEnd), + "\x1b[1;7H": (Keys.Escape, Keys.ControlHome), + # Meta + Shift + arrows. + "\x1b[1;8A": (Keys.Escape, Keys.ControlShiftUp), + "\x1b[1;8B": (Keys.Escape, Keys.ControlShiftDown), + "\x1b[1;8C": (Keys.Escape, Keys.ControlShiftRight), + "\x1b[1;8D": (Keys.Escape, Keys.ControlShiftLeft), + "\x1b[1;8F": (Keys.Escape, Keys.ControlShiftEnd), + "\x1b[1;8H": (Keys.Escape, Keys.ControlShiftHome), + # Meta + arrow on (some?) Macs when using iTerm defaults (see issue #483). + "\x1b[1;9A": (Keys.Escape, Keys.Up), + "\x1b[1;9B": (Keys.Escape, Keys.Down), + "\x1b[1;9C": (Keys.Escape, Keys.Right), + "\x1b[1;9D": (Keys.Escape, Keys.Left), + # -- + # Control/shift/meta + number in mintty. + # (c-2 will actually send c-@ and c-6 will send c-^.) + "\x1b[1;5p": (Keys.Control0,), + "\x1b[1;5q": (Keys.Control1,), + "\x1b[1;5r": (Keys.Control2,), + "\x1b[1;5s": (Keys.Control3,), + "\x1b[1;5t": (Keys.Control4,), + "\x1b[1;5u": (Keys.Control5,), + "\x1b[1;5v": (Keys.Control6,), + "\x1b[1;5w": (Keys.Control7,), + "\x1b[1;5x": (Keys.Control8,), + "\x1b[1;5y": (Keys.Control9,), + "\x1b[1;6p": (Keys.ControlShift0,), + "\x1b[1;6q": (Keys.ControlShift1,), + "\x1b[1;6r": (Keys.ControlShift2,), + "\x1b[1;6s": (Keys.ControlShift3,), + "\x1b[1;6t": (Keys.ControlShift4,), + "\x1b[1;6u": (Keys.ControlShift5,), + "\x1b[1;6v": (Keys.ControlShift6,), + "\x1b[1;6w": (Keys.ControlShift7,), + "\x1b[1;6x": (Keys.ControlShift8,), + "\x1b[1;6y": (Keys.ControlShift9,), + "\x1b[1;7p": (Keys.Escape, Keys.Control0), + "\x1b[1;7q": (Keys.Escape, Keys.Control1), + "\x1b[1;7r": (Keys.Escape, Keys.Control2), + "\x1b[1;7s": (Keys.Escape, Keys.Control3), + "\x1b[1;7t": (Keys.Escape, Keys.Control4), + "\x1b[1;7u": (Keys.Escape, Keys.Control5), + "\x1b[1;7v": (Keys.Escape, Keys.Control6), + "\x1b[1;7w": (Keys.Escape, Keys.Control7), + "\x1b[1;7x": (Keys.Escape, Keys.Control8), + "\x1b[1;7y": (Keys.Escape, Keys.Control9), + "\x1b[1;8p": (Keys.Escape, Keys.ControlShift0), + "\x1b[1;8q": (Keys.Escape, Keys.ControlShift1), + "\x1b[1;8r": (Keys.Escape, Keys.ControlShift2), + "\x1b[1;8s": (Keys.Escape, Keys.ControlShift3), + "\x1b[1;8t": (Keys.Escape, Keys.ControlShift4), + "\x1b[1;8u": (Keys.Escape, Keys.ControlShift5), + "\x1b[1;8v": (Keys.Escape, Keys.ControlShift6), + "\x1b[1;8w": (Keys.Escape, Keys.ControlShift7), + "\x1b[1;8x": (Keys.Escape, Keys.ControlShift8), + "\x1b[1;8y": (Keys.Escape, Keys.ControlShift9), + # Simplify some sequences that appear to be unique to rxvt; see + # https://github.com/Textualize/textual/issues/3741 for context. + "\x1bOj": "*", + "\x1bOk": "+", + "\x1bOm": "-", + "\x1bOn": ".", + "\x1bOo": "/", + "\x1bOp": "0", + "\x1bOq": "1", + "\x1bOr": "2", + "\x1bOs": "3", + "\x1bOt": "4", + "\x1bOu": "5", + "\x1bOv": "6", + "\x1bOw": "7", + "\x1bOx": "8", + "\x1bOy": "9", + "\x1bOM": (Keys.Enter,), + # WezTerm on macOS emits sequences for Opt and keys on the top numeric + # row; whereas other terminals provide various characters. The following + # swallow up those sequences and turns them into characters the same as + # the other terminals. + "\x1b§": "§", + "\x1b1": "¡", + "\x1b2": "™", + "\x1b3": "£", + "\x1b4": "¢", + "\x1b5": "∞", + "\x1b6": "§", + "\x1b7": "¶", + "\x1b8": "•", + "\x1b9": "ª", + "\x1b0": "º", + "\x1b-": "–", + "\x1b=": "≠", + # Ctrl+§ on kitty is different from most other terminals on macOS. + "\x1b[167;5u": "0", + ############################################################################ + # The ignore section. Only add sequences here if they are going to be + # ignored. Also, when adding a sequence here, please include a note as + # to why it is being ignored; ideally citing sources if possible. + ############################################################################ + # The following 2 are inherited from prompt toolkit. They relate to a + # press of 5 on the numeric keypad, when *not* in number mode. + "\x1b[E": IGNORE_SEQUENCE, # Xterm. + "\x1b[G": IGNORE_SEQUENCE, # Linux console. + # Various ctrl+cmd+ keys under Kitty on macOS. + "\x1b[3;13~": IGNORE_SEQUENCE, # ctrl-cmd-del + "\x1b[1;13H": IGNORE_SEQUENCE, # ctrl-cmd-home + "\x1b[1;13F": IGNORE_SEQUENCE, # ctrl-cmd-end + "\x1b[5;13~": IGNORE_SEQUENCE, # ctrl-cmd-pgup + "\x1b[6;13~": IGNORE_SEQUENCE, # ctrl-cmd-pgdn + "\x1b[49;13u": IGNORE_SEQUENCE, # ctrl-cmd-1 + "\x1b[50;13u": IGNORE_SEQUENCE, # ctrl-cmd-2 + "\x1b[51;13u": IGNORE_SEQUENCE, # ctrl-cmd-3 + "\x1b[52;13u": IGNORE_SEQUENCE, # ctrl-cmd-4 + "\x1b[53;13u": IGNORE_SEQUENCE, # ctrl-cmd-5 + "\x1b[54;13u": IGNORE_SEQUENCE, # ctrl-cmd-6 + "\x1b[55;13u": IGNORE_SEQUENCE, # ctrl-cmd-7 + "\x1b[56;13u": IGNORE_SEQUENCE, # ctrl-cmd-8 + "\x1b[57;13u": IGNORE_SEQUENCE, # ctrl-cmd-9 + "\x1b[48;13u": IGNORE_SEQUENCE, # ctrl-cmd-0 + "\x1b[45;13u": IGNORE_SEQUENCE, # ctrl-cmd-- + "\x1b[61;13u": IGNORE_SEQUENCE, # ctrl-cmd-+ + "\x1b[91;13u": IGNORE_SEQUENCE, # ctrl-cmd-[ + "\x1b[93;13u": IGNORE_SEQUENCE, # ctrl-cmd-] + "\x1b[92;13u": IGNORE_SEQUENCE, # ctrl-cmd-\ + "\x1b[39;13u": IGNORE_SEQUENCE, # ctrl-cmd-' + "\x1b[59;13u": IGNORE_SEQUENCE, # ctrl-cmd-; + "\x1b[47;13u": IGNORE_SEQUENCE, # ctrl-cmd-/ + "\x1b[46;13u": IGNORE_SEQUENCE, # ctrl-cmd-. +} + +# https://gist.github.com/christianparpart/d8a62cc1ab659194337d73e399004036 +SYNC_START = "\x1b[?2026h" +SYNC_END = "\x1b[?2026l" + + +def set_pointer_shape(shape: str) -> str: + """Generate escape sequence to set pointer (cursor) shape using Kitty protocol. + + Args: + shape: The pointer shape name (e.g., "default", "pointer", "text", "crosshair", etc.) + + Returns: + The escape sequence to set the pointer shape. + + See: https://sw.kovidgoyal.net/kitty/pointer-shapes/ + """ + # Kitty pointer shape protocol: ESC ] 22 ; ST + # where ST is ESC \ or BEL (\x07) + # Using BEL as terminator for better compatibility + return f"\x1b]22;{shape}\x07" diff --git a/src/memray/_vendor/textual/_ansi_theme.py b/src/memray/_vendor/textual/_ansi_theme.py new file mode 100644 index 0000000000..c3418d2227 --- /dev/null +++ b/src/memray/_vendor/textual/_ansi_theme.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from rich.terminal_theme import TerminalTheme + + +def rgb(red: int, green: int, blue: int) -> tuple[int, int, int]: + """Define an RGB color. + + This exists mainly so that a VSCode extension can render the colors inline. + + Args: + red: Red component. + green: Green component. + blue: Blue component. + + Returns: + Color triplet. + """ + return red, green, blue + + +MONOKAI = TerminalTheme( + rgb(12, 12, 12), + rgb(217, 217, 217), + [ + rgb(26, 26, 26), + rgb(244, 0, 95), + rgb(152, 224, 36), + rgb(253, 151, 31), + rgb(157, 101, 255), + rgb(244, 0, 95), + rgb(88, 209, 235), + rgb(196, 197, 181), + rgb(98, 94, 76), + ], + [ + rgb(244, 0, 95), + rgb(152, 224, 36), + rgb(224, 213, 97), + rgb(157, 101, 255), + rgb(244, 0, 95), + rgb(88, 209, 235), + rgb(246, 246, 239), + ], +) + +ALABASTER = TerminalTheme( + rgb(247, 247, 247), + rgb(0, 0, 0), + [ + rgb(0, 0, 0), + rgb(170, 55, 49), + rgb(68, 140, 39), + rgb(203, 144, 0), + rgb(50, 92, 192), + rgb(122, 62, 157), + rgb(0, 131, 178), + rgb(247, 247, 247), + rgb(119, 119, 119), + ], + [ + rgb(240, 80, 80), + rgb(96, 203, 0), + rgb(255, 188, 93), + rgb(0, 122, 204), + rgb(230, 76, 230), + rgb(0, 170, 203), + rgb(247, 247, 247), + ], +) + +DEFAULT_TERMINAL_THEME = MONOKAI diff --git a/src/memray/_vendor/textual/_arrange.py b/src/memray/_vendor/textual/_arrange.py new file mode 100644 index 0000000000..e7d9d7a650 --- /dev/null +++ b/src/memray/_vendor/textual/_arrange.py @@ -0,0 +1,253 @@ +from __future__ import annotations + +from collections import defaultdict +from fractions import Fraction +from operator import attrgetter +from typing import TYPE_CHECKING, Iterable, Mapping, Sequence + +from memray._vendor.textual._partition import partition +from memray._vendor.textual.geometry import NULL_OFFSET, NULL_SPACING, Region, Size, Spacing +from memray._vendor.textual.layout import DockArrangeResult, WidgetPlacement + +if TYPE_CHECKING: + from memray._vendor.textual.widget import Widget + +# TODO: This is a bit of a fudge, need to ensure it is impossible for layouts to generate this value +TOP_Z = 2**31 - 1 + + +def _build_layers(widgets: Iterable[Widget]) -> Mapping[str, Sequence[Widget]]: + """Organize widgets into layers. + + Args: + widgets: The widgets. + + Returns: + A mapping of layer name onto the widgets within the layer. + """ + layers: defaultdict[str, list[Widget]] = defaultdict(list) + for widget in widgets: + layers[widget.layer].append(widget) + return layers + + +_get_dock = attrgetter("styles.is_docked") +_get_split = attrgetter("styles.is_split") +_get_display = attrgetter("display") + + +def arrange( + widget: Widget, + children: Sequence[Widget], + size: Size, + viewport: Size, + optimal: bool = False, +) -> DockArrangeResult: + """Arrange widgets by applying docks and calling layouts + + Args: + widget: The parent (container) widget. + size: The size of the available area. + viewport: The size of the viewport (terminal). + + Returns: + Widget arrangement information. + """ + placements: list[WidgetPlacement] = [] + scroll_spacing = NULL_SPACING + styles = widget.styles + + # Widgets which will be displayed + display_widgets = list(filter(_get_display, children)) + # Widgets organized into layers + layers = _build_layers(display_widgets) + + for widgets in layers.values(): + # Partition widgets into split widgets and non-split widgets + non_split_widgets, split_widgets = partition(_get_split, widgets) + if split_widgets: + _split_placements, dock_region = _arrange_split_widgets( + split_widgets, size, viewport + ) + placements.extend(_split_placements) + else: + dock_region = size.region + + split_spacing = size.region.get_spacing_between(dock_region) + + # Partition widgets into "layout" widgets (those that appears in the normal 'flow' of the + # document), and "dock" widgets which are positioned relative to an edge + layout_widgets, dock_widgets = partition(_get_dock, non_split_widgets) + + # Arrange docked widgets + if dock_widgets: + _dock_placements, dock_spacing = _arrange_dock_widgets( + dock_widgets, dock_region, viewport, greedy=not optimal + ) + placements.extend(_dock_placements) + dock_region = dock_region.shrink(dock_spacing) + else: + dock_spacing = Spacing() + + dock_spacing += split_spacing + + if layout_widgets: + # Arrange layout widgets (i.e. not docked) + layout_placements = widget.process_layout( + widget.layout.arrange( + widget, layout_widgets, dock_region.size, greedy=not optimal + ) + ) + scroll_spacing = scroll_spacing.grow_maximum(dock_spacing) + placement_offset = dock_region.offset + # Perform any alignment of the widgets. + if styles.align_horizontal != "left" or styles.align_vertical != "top": + bounding_region = WidgetPlacement.get_bounds(layout_placements) + container_width, container_height = dock_region.size + placement_offset += styles._align_size( + bounding_region.size, + widget._extrema.apply_dimensions( + 0 if styles.is_auto_width else container_width, + 0 if styles.is_auto_height else container_height, + ), + ).clamped + + if placement_offset: + # Translate placements if required. + layout_placements = WidgetPlacement.translate( + layout_placements, placement_offset + ) + + WidgetPlacement.apply_absolute(layout_placements) + placements.extend(layout_placements) + + return DockArrangeResult(placements, set(display_widgets), scroll_spacing) + + +def _arrange_dock_widgets( + dock_widgets: Sequence[Widget], region: Region, viewport: Size, greedy: bool = True +) -> tuple[list[WidgetPlacement], Spacing]: + """Arrange widgets which are *docked*. + + Args: + dock_widgets: Widgets with a non-empty dock. + region: Region to dock within. + viewport: Size of the viewport. + + Returns: + A tuple of widget placements, and additional spacing around them. + """ + _WidgetPlacement = WidgetPlacement + top_z = TOP_Z + region_offset = region.offset + size = region.size + width, height = size + null_spacing = NULL_SPACING + + top = right = bottom = left = 0 + + placements: list[WidgetPlacement] = [] + append_placement = placements.append + + for dock_widget in dock_widgets: + edge = dock_widget.styles.dock + + box_model = dock_widget._get_box_model( + size, viewport, Fraction(size.width), Fraction(size.height), greedy=greedy + ) + widget_width_fraction, widget_height_fraction, margin = box_model + widget_width = int(widget_width_fraction) + margin.width + widget_height = int(widget_height_fraction) + margin.height + + if edge == "bottom": + dock_region = Region(0, height - widget_height, widget_width, widget_height) + bottom = max(bottom, widget_height) + elif edge == "top": + dock_region = Region(0, 0, widget_width, widget_height) + top = max(top, widget_height) + elif edge == "left": + dock_region = Region(0, 0, widget_width, widget_height) + left = max(left, widget_width) + elif edge == "right": + dock_region = Region(width - widget_width, 0, widget_width, widget_height) + right = max(right, widget_width) + else: + # Should not occur, mainly to keep Mypy happy + raise AssertionError("invalid value for dock edge") # pragma: no-cover + + dock_region = dock_region.shrink(margin) + styles = dock_widget.styles + offset = ( + styles.offset.resolve( + size, + viewport, + ) + if styles.has_rule("offset") + else NULL_OFFSET + ) + append_placement( + _WidgetPlacement( + dock_region.translate(region_offset), + offset, + null_spacing, + dock_widget, + top_z, + True, + False, + ) + ) + + dock_spacing = Spacing(top, right, bottom, left) + return (placements, dock_spacing) + + +def _arrange_split_widgets( + split_widgets: Sequence[Widget], size: Size, viewport: Size +) -> tuple[list[WidgetPlacement], Region]: + """Arrange split widgets. + + Split widgets are "docked" but also reduce the area available for regular widgets. + + Args: + split_widgets: Widgets to arrange. + size: Available area to arrange. + viewport: Viewport (size of terminal). + + Returns: + A tuple of widget placements, and the remaining view area. + """ + _WidgetPlacement = WidgetPlacement + placements: list[WidgetPlacement] = [] + append_placement = placements.append + view_region = size.region + null_spacing = NULL_SPACING + null_offset = NULL_OFFSET + + for split_widget in split_widgets: + split = split_widget.styles.split + box_model = split_widget._get_box_model( + size, viewport, Fraction(size.width), Fraction(size.height) + ) + widget_width_fraction, widget_height_fraction, margin = box_model + if split == "bottom": + widget_height = int(widget_height_fraction) + margin.height + view_region, split_region = view_region.split_horizontal(-widget_height) + elif split == "top": + widget_height = int(widget_height_fraction) + margin.height + split_region, view_region = view_region.split_horizontal(widget_height) + elif split == "left": + widget_width = int(widget_width_fraction) + margin.width + split_region, view_region = view_region.split_vertical(widget_width) + elif split == "right": + widget_width = int(widget_width_fraction) + margin.width + view_region, split_region = view_region.split_vertical(-widget_width) + else: + raise AssertionError("invalid value for split edge") # pragma: no-cover + + append_placement( + _WidgetPlacement( + split_region, null_offset, null_spacing, split_widget, 1, True, False + ) + ) + + return placements, view_region diff --git a/src/memray/_vendor/textual/_auto_scroll.py b/src/memray/_vendor/textual/_auto_scroll.py new file mode 100644 index 0000000000..7d724f62d6 --- /dev/null +++ b/src/memray/_vendor/textual/_auto_scroll.py @@ -0,0 +1,30 @@ +from memray._vendor.textual.geometry import Region + + +def get_auto_scroll_regions( + widget_region: Region, auto_scroll_lines: int +) -> tuple[Region, Region]: + """Get non-overlapping regions which should auto scroll when selecting. + + Args: + widget_region: The region occupied by the widget. + auto_scroll_lines: Number of lines in auto scroll regions. + + Returns: + A pair of regions. The first for the region to scroll up, the second for the region to scroll down. + """ + x, y, width, height = widget_region + + # Divide the region in to two, non overlapping regions + top_half, bottom_half = widget_region.split_horizontal(height // 2) + + # Get a region at the top with the desired dimensions + up_region = Region(x, y, width, auto_scroll_lines) + # Ensure it is no larger than the top half + up_region = top_half.intersection(up_region) + + # Repeat for the bottom half + down_region = Region(x, y + height - auto_scroll_lines, width, auto_scroll_lines) + down_region = bottom_half.intersection(down_region) + + return up_region, down_region diff --git a/src/memray/_vendor/textual/_binary_encode.py b/src/memray/_vendor/textual/_binary_encode.py new file mode 100644 index 0000000000..36cb54096c --- /dev/null +++ b/src/memray/_vendor/textual/_binary_encode.py @@ -0,0 +1,325 @@ +""" +An encoding / decoding format suitable for serializing data structures to binary. + +This is based on https://en.wikipedia.org/wiki/Bencode with some extensions. + +The following data types may be encoded: + +- None +- int +- bool +- bytes +- str +- list +- tuple +- dict + +""" + +from __future__ import annotations + +from typing import Any, Callable + + +class DecodeError(Exception): + """A problem decoding data.""" + + +def dump(data: object) -> bytes: + """Encodes a data structure into bytes. + + Args: + data: Data structure + + Returns: + A byte string encoding the data. + """ + + def encode_none(_datum: None) -> bytes: + """ + Encodes a None value. + + Args: + datum: Always None. + + Returns: + None encoded. + """ + return b"N" + + def encode_bool(datum: bool) -> bytes: + """ + Encode a boolean value. + + Args: + datum: The boolean value to encode. + + Returns: + The encoded bytes. + """ + return b"T" if datum else b"F" + + def encode_int(datum: int) -> bytes: + """ + Encode an integer value. + + Args: + datum: The integer value to encode. + + Returns: + The encoded bytes. + """ + return b"i%ie" % datum + + def encode_bytes(datum: bytes) -> bytes: + """ + Encode a bytes value. + + Args: + datum: The bytes value to encode. + + Returns: + The encoded bytes. + """ + return b"%i:%s" % (len(datum), datum) + + def encode_string(datum: str) -> bytes: + """ + Encode a string value. + + Args: + datum: The string value to encode. + + Returns: + The encoded bytes. + """ + encoded_data = datum.encode("utf-8") + return b"s%i:%s" % (len(encoded_data), encoded_data) + + def encode_list(datum: list) -> bytes: + """ + Encode a list value. + + Args: + datum: The list value to encode. + + Returns: + The encoded bytes. + """ + return b"l%se" % b"".join(encode(element) for element in datum) + + def encode_tuple(datum: tuple) -> bytes: + """ + Encode a tuple value. + + Args: + datum: The tuple value to encode. + + Returns: + The encoded bytes. + """ + return b"t%se" % b"".join(encode(element) for element in datum) + + def encode_dict(datum: dict) -> bytes: + """ + Encode a dictionary value. + + Args: + datum: The dictionary value to encode. + + Returns: + The encoded bytes. + """ + return b"d%se" % b"".join( + b"%s%s" % (encode(key), encode(value)) for key, value in datum.items() + ) + + ENCODERS: dict[type, Callable[[Any], Any]] = { + type(None): encode_none, + bool: encode_bool, + int: encode_int, + bytes: encode_bytes, + str: encode_string, + list: encode_list, + tuple: encode_tuple, + dict: encode_dict, + } + + def encode(datum: object) -> bytes: + """Recursively encode data. + + Args: + datum: Data suitable for encoding. + + Raises: + TypeError: If `datum` is not one of the supported types. + + Returns: + Encoded data bytes. + """ + try: + decoder = ENCODERS[type(datum)] + except KeyError: + raise TypeError("Can't encode {datum!r}") from None + return decoder(datum) + + return encode(data) + + +def load(encoded: bytes) -> object: + """Load an encoded data structure from bytes. + + Args: + encoded: Encoded data in bytes. + + Raises: + DecodeError: If an error was encountered decoding the string. + + Returns: + Decoded data. + """ + if not isinstance(encoded, bytes): + raise TypeError("must be bytes") + max_position = len(encoded) + position = 0 + + def get_byte() -> bytes: + """Get an encoded byte and advance position. + + Raises: + DecodeError: If the end of the data was reached + + Returns: + A bytes object with a single byte. + """ + nonlocal position + if position >= max_position: + raise DecodeError("More data expected") + character = encoded[position : position + 1] + position += 1 + return character + + def peek_byte() -> bytes: + """Get the byte at the current position, but don't advance position. + + Returns: + A bytes object with a single byte. + """ + return encoded[position : position + 1] + + def get_bytes(size: int) -> bytes: + """Get a number of bytes of encode data. + + Args: + size: Number of bytes to retrieve. + + Raises: + DecodeError: If there aren't enough bytes. + + Returns: + A bytes object. + """ + nonlocal position + bytes_data = encoded[position : position + size] + if len(bytes_data) != size: + raise DecodeError(b"Missing bytes in {bytes_data!r}") + position += size + return bytes_data + + def decode_int() -> int: + """Decode an int from the encoded data. + + Returns: + An integer. + """ + int_bytes = b"" + while (byte := get_byte()) != b"e": + int_bytes += byte + return int(int_bytes) + + def decode_bytes(size_bytes: bytes) -> bytes: + """Decode a bytes string from the encoded data. + + Returns: + A bytes object. + """ + while (byte := get_byte()) != b":": + size_bytes += byte + bytes_string = get_bytes(int(size_bytes)) + return bytes_string + + def decode_string() -> str: + """Decode a (utf-8 encoded) string from the encoded data. + + Returns: + A string. + """ + size_bytes = b"" + while (byte := get_byte()) != b":": + size_bytes += byte + bytes_string = get_bytes(int(size_bytes)) + decoded_string = bytes_string.decode("utf-8", errors="replace") + return decoded_string + + def decode_list() -> list[object]: + """Decode a list. + + Returns: + A list of data. + """ + elements: list[object] = [] + add_element = elements.append + while peek_byte() != b"e": + add_element(decode()) + get_byte() + return elements + + def decode_tuple() -> tuple[object, ...]: + """Decode a tuple. + + Returns: + A tuple of decoded data. + """ + elements: list[object] = [] + add_element = elements.append + while peek_byte() != b"e": + add_element(decode()) + get_byte() + return tuple(elements) + + def decode_dict() -> dict[object, object]: + """Decode a dict. + + Returns: + A dict of decoded data. + """ + elements: dict[object, object] = {} + add_element = elements.__setitem__ + while peek_byte() != b"e": + add_element(decode(), decode()) + get_byte() + return elements + + DECODERS = { + b"i": decode_int, + b"s": decode_string, + b"l": decode_list, + b"t": decode_tuple, + b"d": decode_dict, + b"T": lambda: True, + b"F": lambda: False, + b"N": lambda: None, + } + + def decode() -> object: + """Recursively decode data. + + Returns: + Decoded data. + """ + decoder = DECODERS.get(initial := get_byte(), None) + if decoder is None: + return decode_bytes(initial) + return decoder() + + return decode() diff --git a/src/memray/_vendor/textual/_border.py b/src/memray/_vendor/textual/_border.py new file mode 100644 index 0000000000..80898999f6 --- /dev/null +++ b/src/memray/_vendor/textual/_border.py @@ -0,0 +1,479 @@ +from __future__ import annotations + +from functools import lru_cache +from typing import TYPE_CHECKING, Iterable, Tuple, cast + +from rich.segment import Segment + +from memray._vendor.textual.color import Color +from memray._vendor.textual.css.types import AlignHorizontal, EdgeStyle, EdgeType +from memray._vendor.textual.style import Style + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from memray._vendor.textual.content import Content + +INNER = 1 +OUTER = 2 + +BORDER_CHARS: dict[ + EdgeType, tuple[tuple[str, str, str], tuple[str, str, str], tuple[str, str, str]] +] = { + # Three tuples for the top, middle, and bottom rows. + # The sub-tuples are the characters for the left, center, and right borders. + "": ( + (" ", " ", " "), + (" ", " ", " "), + (" ", " ", " "), + ), + "ascii": ( + ("+", "-", "+"), + ("|", " ", "|"), + ("+", "-", "+"), + ), + "none": ( + (" ", " ", " "), + (" ", " ", " "), + (" ", " ", " "), + ), + "hidden": ( + (" ", " ", " "), + (" ", " ", " "), + (" ", " ", " "), + ), + "blank": ( + (" ", " ", " "), + (" ", " ", " "), + (" ", " ", " "), + ), + "round": ( + ("╭", "─", "╮"), + ("│", " ", "│"), + ("╰", "─", "╯"), + ), + "solid": ( + ("┌", "─", "┐"), + ("│", " ", "│"), + ("└", "─", "┘"), + ), + "double": ( + ("╔", "═", "╗"), + ("║", " ", "║"), + ("╚", "═", "╝"), + ), + "dashed": ( + ("┏", "╍", "┓"), + ("╏", " ", "╏"), + ("┗", "╍", "┛"), + ), + "heavy": ( + ("┏", "━", "┓"), + ("┃", " ", "┃"), + ("┗", "━", "┛"), + ), + "inner": ( + ("▗", "▄", "▖"), + ("▐", " ", "▌"), + ("▝", "▀", "▘"), + ), + "outer": ( + ("▛", "▀", "▜"), + ("▌", " ", "▐"), + ("▙", "▄", "▟"), + ), + "thick": ( + ("█", "▀", "█"), + ("█", " ", "█"), + ("█", "▄", "█"), + ), + "block": ( + ("▄", "▄", "▄"), + ("█", " ", "█"), + ("▀", "▀", "▀"), + ), + "hkey": ( + ("▔", "▔", "▔"), + (" ", " ", " "), + ("▁", "▁", "▁"), + ), + "vkey": ( + ("▏", " ", "▕"), + ("▏", " ", "▕"), + ("▏", " ", "▕"), + ), + "tall": ( + ("▊", "▔", "▎"), + ("▊", " ", "▎"), + ("▊", "▁", "▎"), + ), + "panel": ( + ("▊", "█", "▎"), + ("▊", " ", "▎"), + ("▊", "▁", "▎"), + ), + "tab": ( + ("▁", "▁", "▁"), + ("▎", " ", "▊"), + ("▔", "▔", "▔"), + ), + "wide": ( + ("▁", "▁", "▁"), + ("▎", " ", "▊"), + ("▔", "▔", "▔"), + ), +} + +# Some of the borders are on the widget background and some are on the background of the parent +# This table selects which for each character, 0 indicates the widget, 1 selects the parent. +# 2 and 3 reverse a cross-combination of the background and foreground colors of 0 and 1. +BORDER_LOCATIONS: dict[ + EdgeType, tuple[tuple[int, int, int], tuple[int, int, int], tuple[int, int, int]] +] = { + "": ( + (0, 0, 0), + (0, 0, 0), + (0, 0, 0), + ), + "ascii": ( + (0, 0, 0), + (0, 0, 0), + (0, 0, 0), + ), + "none": ( + (0, 0, 0), + (0, 0, 0), + (0, 0, 0), + ), + "hidden": ( + (0, 0, 0), + (0, 0, 0), + (0, 0, 0), + ), + "blank": ( + (0, 0, 0), + (0, 0, 0), + (0, 0, 0), + ), + "round": ( + (0, 0, 0), + (0, 0, 0), + (0, 0, 0), + ), + "solid": ( + (0, 0, 0), + (0, 0, 0), + (0, 0, 0), + ), + "double": ( + (0, 0, 0), + (0, 0, 0), + (0, 0, 0), + ), + "dashed": ( + (0, 0, 0), + (0, 0, 0), + (0, 0, 0), + ), + "heavy": ( + (0, 0, 0), + (0, 0, 0), + (0, 0, 0), + ), + "inner": ( + (1, 1, 1), + (1, 1, 1), + (1, 1, 1), + ), + "outer": ( + (0, 0, 0), + (0, 0, 0), + (0, 0, 0), + ), + "thick": ( + (0, 0, 0), + (0, 0, 0), + (0, 0, 0), + ), + "block": ( + (1, 1, 1), + (0, 0, 0), + (1, 1, 1), + ), + "hkey": ( + (0, 0, 0), + (0, 0, 0), + (0, 0, 0), + ), + "vkey": ( + (0, 0, 0), + (0, 0, 0), + (0, 0, 0), + ), + "tall": ( + (2, 0, 1), + (2, 0, 1), + (2, 0, 1), + ), + "panel": ( + (2, 0, 1), + (2, 0, 1), + (2, 0, 1), + ), + "tab": ( + (1, 1, 1), + (0, 1, 3), + (1, 1, 1), + ), + "wide": ( + (1, 1, 1), + (0, 1, 3), + (1, 1, 1), + ), +} + +# Some borders (such as panel) require that the title (and subtitle) be draw in reverse. +# This is a mapping of the border type on to a tuple for the top and bottom borders, to indicate +# reverse colors is required. +BORDER_TITLE_FLIP: dict[str, tuple[bool, bool]] = { + "panel": (True, False), + "tab": (True, True), +} + +# In a similar fashion, we extract the border _label_ locations for easier access when +# rendering a border label. +# The values are a pair with (title location, subtitle location). +BORDER_LABEL_LOCATIONS: dict[EdgeType, tuple[int, int]] = { + edge_type: (locations[0][1], locations[2][1]) + for edge_type, locations in BORDER_LOCATIONS.items() +} + +INVISIBLE_EDGE_TYPES = cast("frozenset[EdgeType]", frozenset(("", "none", "hidden"))) + +BorderValue: TypeAlias = Tuple[EdgeType, Color] + +BoxSegments: TypeAlias = Tuple[ + Tuple[Segment, Segment, Segment], + Tuple[Segment, Segment, Segment], + Tuple[Segment, Segment, Segment], +] + +Borders: TypeAlias = Tuple[EdgeStyle, EdgeStyle, EdgeStyle, EdgeStyle] + +REVERSE_STYLE = Style(reverse=True) + + +@lru_cache(maxsize=1024) +def get_box( + name: EdgeType, + inner_style: Style, + outer_style: Style, + style: Style, +) -> BoxSegments: + """Get segments used to render a box. + + Args: + name: Name of the box type. + inner_style: The inner style (widget background). + outer_style: The outer style (parent background). + style: Widget style. + + Returns: + A tuple of 3 Segment triplets. + """ + _Segment = Segment + ( + (top1, top2, top3), + (mid1, mid2, mid3), + (bottom1, bottom2, bottom3), + ) = BORDER_CHARS[name] + + ( + (ltop1, ltop2, ltop3), + (lmid1, lmid2, lmid3), + (lbottom1, lbottom2, lbottom3), + ) = BORDER_LOCATIONS[name] + + inner = inner_style + style + outer = outer_style + style + + styles = ( + inner.rich_style, + outer.rich_style, + Style(outer.background, inner.foreground, reverse=True).rich_style, + Style(inner.background, outer.foreground, reverse=True).rich_style, + ) + + return ( + ( + _Segment(top1, styles[ltop1]), + _Segment(top2, styles[ltop2]), + _Segment(top3, styles[ltop3]), + ), + ( + _Segment(mid1, styles[lmid1]), + _Segment(mid2, styles[lmid2]), + _Segment(mid3, styles[lmid3]), + ), + ( + _Segment(bottom1, styles[lbottom1]), + _Segment(bottom2, styles[lbottom2]), + _Segment(bottom3, styles[lbottom3]), + ), + ) + + +def render_border_label( + label: tuple[Content, Style], + is_title: bool, + name: EdgeType, + width: int, + inner_style: Style, + outer_style: Style, + style: Style, + has_left_corner: bool, + has_right_corner: bool, +) -> Iterable[Segment]: + """Render a border label (the title or subtitle) with optional markup. + + The styling that may be embedded in the label will be reapplied after taking into + account the inner, outer, and border-specific, styles. + + Args: + label: Tuple of label and style to render in the border. + is_title: Whether we are rendering the title (`True`) or the subtitle (`False`). + name: Name of the box type. + width: The width, in cells, of the space available for the whole edge. + This is the total space that may also be needed for the border corners and + the whitespace padding around the (sub)title. Thus, the effective space + available for the border label is: + - `width` if no corner is needed; + - `width - 2` if one corner is needed; and + - `width - 4` if both corners are needed. + inner_style: The inner style (widget background). + outer_style: The outer style (parent background). + style: Widget style. + console: The console that will render the markup in the label. + has_left_corner: Whether the border edge will have to render a left corner. + has_right_corner: Whether the border edge will have to render a right corner. + + Returns: + A list of segments that represent the full label and surrounding padding. + """ + # How many cells do we need to reserve for surrounding blanks and corners? + corners_needed = has_left_corner + has_right_corner + cells_reserved = 2 * corners_needed + + text_label, label_style = label + + if not text_label.cell_length or width <= cells_reserved: + return + + text_label = text_label.truncate(width - cells_reserved, ellipsis=True) + if has_left_corner: + text_label = text_label.pad_left(1) + if has_right_corner: + text_label = text_label.pad_right(1) + text_label = text_label.stylize_before(label_style) + + label_style_location = BORDER_LABEL_LOCATIONS[name][0 if is_title else 1] + flip_top, flip_bottom = BORDER_TITLE_FLIP.get(name, (False, False)) + + inner = inner_style + style + outer = outer_style + style + + base_style: Style + if label_style_location == 0: + base_style = inner + elif label_style_location == 1: + base_style = outer + elif label_style_location == 2: + base_style = Style(outer.background, inner.foreground, reverse=True) + elif label_style_location == 3: + base_style = Style(inner.background, outer.foreground, reverse=True) + else: + assert False + + if (flip_top and is_title) or (flip_bottom and not is_title): + base_style = base_style.without_color + Style( + background=base_style.foreground, + foreground=base_style.background, + ) + + segments = text_label.render_segments(base_style) + yield from segments + + +def render_row( + box_row: tuple[Segment, Segment, Segment], + width: int, + left: bool, + right: bool, + label_segments: Iterable[Segment], + label_alignment: AlignHorizontal = "left", +) -> Iterable[Segment]: + """Compose a box row with its padded label. + + This is the function that actually does the work that `render_row` is intended + to do, but we have many lists of segments flowing around, so it becomes easier + to yield the segments bit by bit, and the aggregate everything into a list later. + + Args: + box_row: Corners and side segments. + width: Total width of resulting line. + left: Render left corner. + right: Render right corner. + label_segments: The segments that make up the label. + label_alignment: Where to horizontally align the label. + + Returns: + An iterable of segments. + """ + box1, box2, box3 = box_row + + corners_needed = left + right + label_segments_list = list(label_segments) + + label_length = sum((segment.cell_length for segment in label_segments_list), 0) + space_available = max(0, width - corners_needed - label_length) + + if left: + yield box1 + + if not space_available: + yield from label_segments_list + elif not label_length: + yield Segment(box2.text * space_available, box2.style) + elif label_alignment == "left" or label_alignment == "right": + edge = Segment(box2.text * (space_available - 1), box2.style) + if label_alignment == "left": + yield Segment(box2.text, box2.style) + yield from label_segments_list + yield edge + else: + yield edge + yield from label_segments_list + yield Segment(box2.text, box2.style) + elif label_alignment == "center": + length_on_left = space_available // 2 + length_on_right = space_available - length_on_left + yield Segment(box2.text * length_on_left, box2.style) + yield from label_segments_list + yield Segment(box2.text * length_on_right, box2.style) + else: + assert False + + if right: + yield box3 + + +_edge_type_normalization_table: dict[EdgeType, EdgeType] = { + # i.e. we normalize "border: none;" to "border: ;". + # As a result our layout-related calculations that include borders are simpler (and have better performance) + "none": "", + "hidden": "", +} + + +def normalize_border_value(value: BorderValue) -> BorderValue: + return _edge_type_normalization_table.get(value[0], value[0]), value[1] diff --git a/src/memray/_vendor/textual/_box_drawing.py b/src/memray/_vendor/textual/_box_drawing.py new file mode 100644 index 0000000000..e30d73b4c6 --- /dev/null +++ b/src/memray/_vendor/textual/_box_drawing.py @@ -0,0 +1,366 @@ +""" +Box drawing utilities for Canvas. + +The box drawing characters have zero to four lines radiating from the center of the glyph. +There are three line types: thin, heavy, and double. These are indicated by 1, 2, and 3 respectively (0 for no line). + +This code represents the characters as a tuple of 4 integers, (, , , ). This format +makes it possible to logically combine characters together, as there is no mathematical relationship in the unicode db. + +Note that not all combinations are possible. Characters can have a maximum of 2 border types in a single glyph. +There are also fewer characters for the "double" line type. + +""" + +from __future__ import annotations + +from functools import lru_cache + +from typing_extensions import TypeAlias + +Quad: TypeAlias = "tuple[int, int, int, int]" +"""Four values indicating the composition of the box character.""" + +# Yes, I typed this out by hand. - WM +BOX_CHARACTERS: dict[Quad, str] = { + (0, 0, 0, 0): " ", + (0, 0, 0, 1): "╴", + (0, 0, 0, 2): "╸", + (0, 0, 0, 3): "╸", + # + (0, 0, 1, 0): "╷", + (0, 0, 1, 1): "┐", + (0, 0, 1, 2): "┑", + (0, 0, 1, 3): "╕", + # + (0, 0, 2, 0): "╻", + (0, 0, 2, 1): "┒", + (0, 0, 2, 2): "┓", + (0, 0, 2, 3): "╕", + # + (0, 0, 3, 0): "╻", + (0, 0, 3, 1): "╖", + (0, 0, 3, 2): "╖", + (0, 0, 3, 3): "╗", + # + (0, 1, 0, 0): "╶", + (0, 1, 0, 1): "─", + (0, 1, 0, 2): "╾", + (0, 1, 0, 3): "╼", + # + (0, 1, 1, 0): "┌", + (0, 1, 1, 1): "┬", + (0, 1, 1, 2): "┭", + (0, 1, 1, 3): "╤", + # + (0, 1, 2, 0): "┎", + (0, 1, 2, 1): "┰", + (0, 1, 2, 2): "┱", + (0, 1, 2, 3): "┱", + # + (0, 1, 3, 0): "╓", + (0, 1, 3, 1): "╥", + (0, 1, 3, 2): "╥", + (0, 1, 3, 3): "╥", + # + (0, 2, 0, 0): "╺", + (0, 2, 0, 1): "╼", + (0, 2, 0, 2): "━", + (0, 2, 0, 3): "━", + # + (0, 2, 1, 0): "┍", + (0, 2, 1, 1): "┮", + (0, 2, 1, 2): "┯", + (0, 2, 1, 3): "┯", + # + (0, 2, 2, 0): "┏", + (0, 2, 2, 1): "┲", + (0, 2, 2, 2): "┳", + (0, 2, 2, 3): "╦", + # + (0, 2, 3, 0): "╒", + (0, 2, 3, 1): "╥", + (0, 2, 3, 2): "╥", + (0, 2, 3, 3): "╦", + # + (0, 3, 0, 0): "╺", + (0, 3, 0, 1): "╾", + (0, 3, 0, 2): "╾", + (0, 3, 0, 3): "═", + # + (0, 3, 1, 0): "╒", + (0, 3, 1, 1): "╤", + (0, 3, 1, 2): "╤", + (0, 3, 1, 3): "╤", + # + (0, 3, 2, 0): "╒", + (0, 3, 2, 1): "╤", + (0, 3, 2, 2): "╤", + (0, 3, 2, 3): "╤", + # + (0, 3, 3, 0): "╔", + (0, 3, 3, 1): "╦", + (0, 3, 3, 2): "╦", + (0, 3, 3, 3): "╦", + # + (1, 0, 0, 0): "╵", + (1, 0, 0, 1): "┘", + (1, 0, 0, 2): "┙", + (1, 0, 0, 3): "╛", + # + (1, 0, 1, 0): "│", + (1, 0, 1, 1): "┤", + (1, 0, 1, 2): "┥", + (1, 0, 1, 3): "╡", + # + (1, 0, 2, 0): "╽", + (1, 0, 2, 1): "┧", + (1, 0, 2, 2): "┪", + (1, 0, 2, 3): "┪", + # + (1, 0, 3, 0): "╽", + (1, 0, 3, 1): "┧", + (1, 0, 3, 2): "┪", + (1, 0, 3, 3): "┪", + # + (1, 1, 0, 0): "└", + (1, 1, 0, 1): "┴", + (1, 1, 0, 2): "┵", + (1, 1, 0, 3): "┵", + # + (1, 1, 1, 0): "├", + (1, 1, 1, 1): "┼", + (1, 1, 1, 2): "┽", + (1, 1, 1, 3): "┽", + # + (1, 1, 2, 0): "┟", + (1, 1, 2, 1): "╁", + (1, 1, 2, 2): "╅", + (1, 1, 2, 3): "╅", + # + (1, 1, 3, 0): "┟", + (1, 1, 3, 1): "╁", + (1, 1, 3, 2): "╅", + (1, 1, 3, 3): "╅", + # + (1, 2, 0, 0): "┕", + (1, 2, 0, 1): "┶", + (1, 2, 0, 2): "┷", + (1, 2, 0, 3): "╧", + # + (1, 2, 1, 0): "┝", + (1, 2, 1, 1): "┾", + (1, 2, 1, 2): "┿", + (1, 2, 1, 3): "┿", + # + (1, 2, 2, 0): "┢", + (1, 2, 2, 1): "╆", + (1, 2, 2, 2): "╈", + (1, 2, 2, 3): "╈", + # + (1, 2, 3, 0): "┢", + (1, 2, 3, 1): "╆", + (1, 2, 3, 2): "╈", + (1, 2, 3, 3): "╈", + # + (1, 3, 0, 0): "╘", + (1, 3, 0, 1): "╧", + (1, 3, 0, 2): "╧", + (1, 3, 0, 3): "╧", + # + (1, 3, 1, 0): "╞", + (1, 3, 1, 1): "╬", + (1, 3, 1, 2): "╪", + (1, 3, 1, 3): "╪", + # + (1, 3, 2, 0): "╟", + (1, 3, 2, 1): "┾", + (1, 3, 2, 2): "┾", + (1, 3, 2, 3): "╪", + # + (1, 3, 3, 0): "╞", + (1, 3, 3, 1): "╆", + (1, 3, 3, 2): "╆", + (1, 3, 3, 3): "╈", + # + (2, 0, 0, 0): "╹", + (2, 0, 0, 1): "┚", + (2, 0, 0, 2): "┛", + (2, 0, 0, 3): "╛", + # + (2, 0, 1, 0): "╿", + (2, 0, 1, 1): "┦", + (2, 0, 1, 2): "┩", + (2, 0, 1, 3): "┩", + # + (2, 0, 2, 0): "┃", + (2, 0, 2, 1): "┨", + (2, 0, 2, 2): "┫", + (2, 0, 2, 3): "╢", + # + (2, 0, 3, 0): "║", + (2, 0, 3, 1): "╢", + (2, 0, 3, 2): "╢", + (2, 0, 3, 3): "╢", + # + (2, 1, 0, 0): "┖", + (2, 1, 0, 1): "┸", + (2, 1, 0, 2): "┹", + (2, 1, 0, 3): "┹", + # + (2, 1, 1, 0): "┞", + (2, 1, 1, 1): "╀", + (2, 1, 1, 2): "╃", + (2, 1, 1, 3): "╃", + # + (2, 1, 2, 0): "┠", + (2, 1, 2, 1): "╂", + (2, 1, 2, 2): "╉", + (2, 1, 2, 3): "╉", + # + (2, 1, 3, 0): "╟", + (2, 1, 3, 1): "╫", + (2, 1, 3, 2): "╫", + (2, 1, 3, 3): "╫", + # + (2, 2, 0, 0): "┗", + (2, 2, 0, 1): "┺", + (2, 2, 0, 2): "┻", + (2, 2, 0, 3): "┻", + # + (2, 2, 1, 0): "┡", + (2, 2, 1, 1): "╄", + (2, 2, 1, 2): "╇", + (2, 2, 1, 3): "╇", + # + (2, 2, 2, 0): "┣", + (2, 2, 2, 1): "╊", + (2, 2, 2, 2): "╋", + (2, 2, 2, 3): "╬", + # + (2, 2, 3, 0): "╠", + (2, 2, 3, 1): "╬", + (2, 2, 3, 2): "╬", + (2, 2, 3, 3): "╬", + # + (2, 3, 0, 0): "╚", + (2, 3, 0, 1): "╩", + (2, 3, 0, 2): "╩", + (2, 3, 0, 3): "╩", + # + (2, 3, 1, 0): "╞", + (2, 3, 1, 1): "╬", + (2, 3, 1, 2): "╬", + (2, 3, 1, 3): "╬", + # + (2, 3, 2, 0): "╞", + (2, 3, 2, 1): "╬", + (2, 3, 2, 2): "╬", + (2, 3, 2, 3): "╬", + # + (2, 3, 3, 0): "╠", + (2, 3, 3, 1): "╬", + (2, 3, 3, 2): "╬", + (2, 3, 3, 3): "╬", + # + (3, 0, 0, 0): "╹", + (3, 0, 0, 1): "╜", + (3, 0, 0, 2): "╜", + (3, 0, 0, 3): "╝", + # + (3, 0, 1, 0): "╿", + (3, 0, 1, 1): "┦", + (3, 0, 1, 2): "┦", + (3, 0, 1, 3): "┩", + # + (3, 0, 2, 0): "║", + (3, 0, 2, 1): "╢", + (3, 0, 2, 2): "╢", + (3, 0, 2, 3): "╣", + # + (3, 0, 3, 0): "║", + (3, 0, 3, 1): "╢", + (3, 0, 3, 2): "╢", + (3, 0, 3, 3): "╣", + # + (3, 1, 0, 0): "╙", + (3, 1, 0, 1): "╨", + (3, 1, 0, 2): "╨", + (3, 1, 0, 3): "╩", + # + (3, 1, 1, 0): "╟", + (3, 1, 1, 1): "╬", + (3, 1, 1, 2): "╬", + (3, 1, 1, 3): "╬", + # + (3, 1, 2, 0): "╟", + (3, 1, 2, 1): "╬", + (3, 1, 2, 2): "╬", + (3, 1, 2, 3): "╬", + # + (3, 1, 3, 0): "╟", + (3, 1, 3, 1): "╫", + (3, 1, 3, 2): "╫", + (3, 1, 3, 3): "╉", + # + (3, 2, 0, 0): "╙", + (3, 2, 0, 1): "╨", + (3, 2, 0, 2): "╨", + (3, 2, 0, 3): "╩", + # + (3, 2, 1, 0): "╟", + (3, 2, 1, 1): "╬", + (3, 2, 1, 2): "╬", + (3, 2, 1, 3): "╬", + # + (3, 2, 2, 0): "╟", + (3, 2, 2, 1): "╬", + (3, 2, 2, 2): "╬", + (3, 2, 2, 3): "╬", + # + (3, 2, 3, 0): "╟", + (3, 2, 3, 1): "╫", + (3, 2, 3, 2): "╫", + (3, 2, 3, 3): "╬", + # + (3, 3, 0, 0): "╚", + (3, 3, 0, 1): "╩", + (3, 3, 0, 2): "╩", + (3, 3, 0, 3): "╩", + # + (3, 3, 1, 0): "╠", + (3, 3, 1, 1): "╄", + (3, 3, 1, 2): "╄", + (3, 3, 1, 3): "╇", + # + (3, 3, 2, 0): "╠", + (3, 3, 2, 1): "╬", + (3, 3, 2, 2): "╬", + (3, 3, 2, 3): "╬", + # + (3, 3, 3, 0): "╠", + (3, 3, 3, 1): "╊", + (3, 3, 3, 2): "╬", + (3, 3, 3, 3): "╬", +} + + +@lru_cache(1024) +def combine_quads(box1: Quad, box2: Quad) -> Quad: + """Combine two box drawing quads. + + Args: + box1: Existing box quad. + box2: New box quad. + + Returns: + A new box quad. + """ + top1, right1, bottom1, left1 = box1 + top2, right2, bottom2, left2 = box2 + return ( + top2 or top1, + right2 or right1, + bottom2 or bottom1, + left2 or left1, + ) diff --git a/src/memray/_vendor/textual/_callback.py b/src/memray/_vendor/textual/_callback.py new file mode 100644 index 0000000000..8e6bf54e48 --- /dev/null +++ b/src/memray/_vendor/textual/_callback.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import asyncio +from functools import partial +from inspect import isawaitable, signature +from typing import TYPE_CHECKING, Any, Callable + +from memray._vendor.textual import active_app + +if TYPE_CHECKING: + from memray._vendor.textual.app import App + +# Maximum seconds before warning about a slow callback +INVOKE_TIMEOUT_WARNING = 3 + + +def count_parameters(func: Callable) -> int: + """Count the number of parameters in a callable""" + try: + return func._param_count + except AttributeError: + pass + if isinstance(func, partial): + param_count = _count_parameters(func.func) - ( + len(func.args) + len(func.keywords) + ) + elif hasattr(func, "__self__"): + # Bound method + func = func.__func__ # type: ignore + param_count = _count_parameters(func) - 1 + else: + param_count = _count_parameters(func) + try: + func._param_count = param_count + except TypeError: + pass + return param_count + + +def _count_parameters(func: Callable) -> int: + """Count the number of parameters in a callable""" + return len(signature(func).parameters) + + +async def _invoke(callback: Callable, *params: object) -> Any: + """Invoke a callback with an arbitrary number of parameters. + + Args: + callback: The callable to be invoked. + + Returns: + The return value of the invoked callable. + """ + _rich_traceback_guard = True + parameter_count = count_parameters(callback) + result = callback(*params[:parameter_count]) + if isawaitable(result): + result = await result + return result + + +async def invoke(callback: Callable[..., Any], *params: object) -> Any: + """Invoke a callback with an arbitrary number of parameters. + + Args: + callback: The callable to be invoked. + + Returns: + The return value of the invoked callable. + """ + + app: App | None + try: + app = active_app.get() + except LookupError: + # May occur if this method is called outside of an app context (i.e. in a unit test) + app = None + + if app is not None and "debug" in app.features: + # In debug mode we will warn about callbacks that may be stuck + def log_slow() -> None: + """Log a message regarding a slow callback.""" + assert app is not None + app.log.warning( + f"Callback {callback} is still pending after {INVOKE_TIMEOUT_WARNING} seconds" + ) + + call_later_handle = asyncio.get_running_loop().call_later( + INVOKE_TIMEOUT_WARNING, log_slow + ) + try: + return await _invoke(callback, *params) + finally: + call_later_handle.cancel() + else: + return await _invoke(callback, *params) diff --git a/src/memray/_vendor/textual/_cells.py b/src/memray/_vendor/textual/_cells.py new file mode 100644 index 0000000000..56c135b139 --- /dev/null +++ b/src/memray/_vendor/textual/_cells.py @@ -0,0 +1,44 @@ +from typing import Callable + +from memray._vendor.textual.expand_tabs import get_tab_widths + +__all__ = ["cell_len", "cell_width_to_column_index"] + + +cell_len: Callable[[str], int] +try: + from rich.cells import cached_cell_len as cell_len +except ImportError: + from rich.cells import cell_len + + +def cell_width_to_column_index(line: str, cell_width: int, tab_width: int) -> int: + """Retrieve the column index corresponding to the given cell width. + + Args: + line: The line of text to search within. + cell_width: The cell width to convert to column index. + tab_width: The tab stop width to expand tabs contained within the line. + + Returns: + The column corresponding to the cell width. + """ + column_index = 0 + total_cell_offset = 0 + for part, expanded_tab_width in get_tab_widths(line, tab_width): + # Check if the click landed on a character within this part. + for character in part: + total_cell_offset += cell_len(character) + if total_cell_offset > cell_width: + return column_index + column_index += 1 + + # Account for the appearance of the tab character for this part + total_cell_offset += expanded_tab_width + # Check if the click falls within the boundary of the expanded tab. + if total_cell_offset > cell_width: + return column_index + + column_index += 1 + + return len(line) diff --git a/src/memray/_vendor/textual/_color_constants.py b/src/memray/_vendor/textual/_color_constants.py new file mode 100644 index 0000000000..05759e4a57 --- /dev/null +++ b/src/memray/_vendor/textual/_color_constants.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +ANSI_COLORS = [ + "black", + "red", + "green", + "yellow", + "blue", + "magenta", + "cyan", + "white", + "bright_black", + "bright_red", + "bright_green", + "bright_yellow", + "bright_blue", + "bright_magenta", + "bright_cyan", + "bright_white", +] +"""The names of ANSI colors (prefixed with ansi_ in CSS).""" + +COLOR_NAME_TO_RGB: dict[str, tuple[int, int, int] | tuple[int, int, int, int]] = { + # Let's start with a specific pseudo-color:: + "transparent": (0, 0, 0, 0), + # Then, the 16 common ANSI colors: + "ansi_black": (0, 0, 0), + "ansi_red": (128, 0, 0), + "ansi_green": (0, 128, 0), + "ansi_yellow": (128, 128, 0), + "ansi_blue": (0, 0, 128), + "ansi_magenta": (128, 0, 128), + "ansi_cyan": (0, 128, 128), + "ansi_white": (192, 192, 192), + "ansi_bright_black": (128, 128, 128), + "ansi_bright_red": (255, 0, 0), + "ansi_bright_green": (0, 255, 0), + "ansi_bright_yellow": (255, 255, 0), + "ansi_bright_blue": (0, 0, 255), + "ansi_bright_magenta": (255, 0, 255), + "ansi_bright_cyan": (0, 255, 255), + "ansi_bright_white": (255, 255, 255), + # And then, Web color keywords: (up to CSS Color Module Level 4) + "black": (0, 0, 0), + "silver": (192, 192, 192), + "gray": (128, 128, 128), + "white": (255, 255, 255), + "maroon": (128, 0, 0), + "red": (255, 0, 0), + "purple": (128, 0, 128), + "fuchsia": (255, 0, 255), + "green": (0, 128, 0), + "lime": (0, 255, 0), + "olive": (128, 128, 0), + "yellow": (255, 255, 0), + "navy": (0, 0, 128), + "blue": (0, 0, 255), + "teal": (0, 128, 128), + "aqua": (0, 255, 255), + "orange": (255, 165, 0), + "aliceblue": (240, 248, 255), + "antiquewhite": (250, 235, 215), + "aquamarine": (127, 255, 212), + "azure": (240, 255, 255), + "beige": (245, 245, 220), + "bisque": (255, 228, 196), + "blanchedalmond": (255, 235, 205), + "blueviolet": (138, 43, 226), + "brown": (165, 42, 42), + "burlywood": (222, 184, 135), + "cadetblue": (95, 158, 160), + "chartreuse": (127, 255, 0), + "chocolate": (210, 105, 30), + "coral": (255, 127, 80), + "cornflowerblue": (100, 149, 237), + "cornsilk": (255, 248, 220), + "crimson": (220, 20, 60), + "cyan": (0, 255, 255), + "darkblue": (0, 0, 139), + "darkcyan": (0, 139, 139), + "darkgoldenrod": (184, 134, 11), + "darkgray": (169, 169, 169), + "darkgreen": (0, 100, 0), + "darkgrey": (169, 169, 169), + "darkkhaki": (189, 183, 107), + "darkmagenta": (139, 0, 139), + "darkolivegreen": (85, 107, 47), + "darkorange": (255, 140, 0), + "darkorchid": (153, 50, 204), + "darkred": (139, 0, 0), + "darksalmon": (233, 150, 122), + "darkseagreen": (143, 188, 143), + "darkslateblue": (72, 61, 139), + "darkslategray": (47, 79, 79), + "darkslategrey": (47, 79, 79), + "darkturquoise": (0, 206, 209), + "darkviolet": (148, 0, 211), + "deeppink": (255, 20, 147), + "deepskyblue": (0, 191, 255), + "dimgray": (105, 105, 105), + "dimgrey": (105, 105, 105), + "dodgerblue": (30, 144, 255), + "firebrick": (178, 34, 34), + "floralwhite": (255, 250, 240), + "forestgreen": (34, 139, 34), + "gainsboro": (220, 220, 220), + "ghostwhite": (248, 248, 255), + "gold": (255, 215, 0), + "goldenrod": (218, 165, 32), + "greenyellow": (173, 255, 47), + "grey": (128, 128, 128), + "honeydew": (240, 255, 240), + "hotpink": (255, 105, 180), + "indianred": (205, 92, 92), + "indigo": (75, 0, 130), + "ivory": (255, 255, 240), + "khaki": (240, 230, 140), + "lavender": (230, 230, 250), + "lavenderblush": (255, 240, 245), + "lawngreen": (124, 252, 0), + "lemonchiffon": (255, 250, 205), + "lightblue": (173, 216, 230), + "lightcoral": (240, 128, 128), + "lightcyan": (224, 255, 255), + "lightgoldenrodyellow": (250, 250, 210), + "lightgray": (211, 211, 211), + "lightgreen": (144, 238, 144), + "lightgrey": (211, 211, 211), + "lightpink": (255, 182, 193), + "lightsalmon": (255, 160, 122), + "lightseagreen": (32, 178, 170), + "lightskyblue": (135, 206, 250), + "lightslategray": (119, 136, 153), + "lightslategrey": (119, 136, 153), + "lightsteelblue": (176, 196, 222), + "lightyellow": (255, 255, 224), + "limegreen": (50, 205, 50), + "linen": (250, 240, 230), + "magenta": (255, 0, 255), + "mediumaquamarine": (102, 205, 170), + "mediumblue": (0, 0, 205), + "mediumorchid": (186, 85, 211), + "mediumpurple": (147, 112, 219), + "mediumseagreen": (60, 179, 113), + "mediumslateblue": (123, 104, 238), + "mediumspringgreen": (0, 250, 154), + "mediumturquoise": (72, 209, 204), + "mediumvioletred": (199, 21, 133), + "midnightblue": (25, 25, 112), + "mintcream": (245, 255, 250), + "mistyrose": (255, 228, 225), + "moccasin": (255, 228, 181), + "navajowhite": (255, 222, 173), + "oldlace": (253, 245, 230), + "olivedrab": (107, 142, 35), + "orangered": (255, 69, 0), + "orchid": (218, 112, 214), + "palegoldenrod": (238, 232, 170), + "palegreen": (152, 251, 152), + "paleturquoise": (175, 238, 238), + "palevioletred": (219, 112, 147), + "papayawhip": (255, 239, 213), + "peachpuff": (255, 218, 185), + "peru": (205, 133, 63), + "pink": (255, 192, 203), + "plum": (221, 160, 221), + "powderblue": (176, 224, 230), + "rosybrown": (188, 143, 143), + "royalblue": (65, 105, 225), + "saddlebrown": (139, 69, 19), + "salmon": (250, 128, 114), + "sandybrown": (244, 164, 96), + "seagreen": (46, 139, 87), + "seashell": (255, 245, 238), + "sienna": (160, 82, 45), + "skyblue": (135, 206, 235), + "slateblue": (106, 90, 205), + "slategray": (112, 128, 144), + "slategrey": (112, 128, 144), + "snow": (255, 250, 250), + "springgreen": (0, 255, 127), + "steelblue": (70, 130, 180), + "tan": (210, 180, 140), + "thistle": (216, 191, 216), + "tomato": (255, 99, 71), + "turquoise": (64, 224, 208), + "violet": (238, 130, 238), + "wheat": (245, 222, 179), + "whitesmoke": (245, 245, 245), + "yellowgreen": (154, 205, 50), + "rebeccapurple": (102, 51, 153), +} diff --git a/src/memray/_vendor/textual/_compat.py b/src/memray/_vendor/textual/_compat.py new file mode 100644 index 0000000000..32d7d7d136 --- /dev/null +++ b/src/memray/_vendor/textual/_compat.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import sys +from typing import Any, Generic, TypeVar, overload + +if sys.version_info >= (3, 12): + from functools import cached_property +else: + # based on the code from Python 3.14: + # https://github.com/python/cpython/blob/ + # 5507eff19c757a908a2ff29dfe423e35595fda00/Lib/functools.py#L1089-L1138 + # Copyright (C) 2006 Python Software Foundation. + # vendored under the PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 because + # prior to Python 3.12 cached_property used a threading.Lock, which makes + # it very slow. + _T_co = TypeVar("_T_co", covariant=True) + _NOT_FOUND = object() + + class cached_property(Generic[_T_co]): + def __init__(self, func: Callable[[Any, _T_co]]) -> None: + self.func = func + self.attrname = None + self.__doc__ = func.__doc__ + self.__module__ = func.__module__ + + def __set_name__(self, owner: type[any], name: str) -> None: + if self.attrname is None: + self.attrname = name + elif name != self.attrname: + raise TypeError( + "Cannot assign the same cached_property to two different names " + f"({self.attrname!r} and {name!r})." + ) + + @overload + def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ... + + @overload + def __get__( + self, instance: object, owner: type[Any] | None = None + ) -> _T_co: ... + + def __get__( + self, instance: object, owner: type[Any] | None = None + ) -> _T_co | Self: + if instance is None: + return self + if self.attrname is None: + raise TypeError( + "Cannot use cached_property instance without calling __set_name__ on it." + ) + try: + cache = instance.__dict__ + except ( + AttributeError + ): # not all objects have __dict__ (e.g. class defines slots) + msg = ( + f"No '__dict__' attribute on {type(instance).__name__!r} " + f"instance to cache {self.attrname!r} property." + ) + raise TypeError(msg) from None + val = cache.get(self.attrname, _NOT_FOUND) + if val is _NOT_FOUND: + val = self.func(instance) + try: + cache[self.attrname] = val + except TypeError: + msg = ( + f"The '__dict__' attribute on {type(instance).__name__!r} instance " + f"does not support item assignment for caching {self.attrname!r} property." + ) + raise TypeError(msg) from None + return val diff --git a/src/memray/_vendor/textual/_compositor.py b/src/memray/_vendor/textual/_compositor.py new file mode 100644 index 0000000000..24dbfbadca --- /dev/null +++ b/src/memray/_vendor/textual/_compositor.py @@ -0,0 +1,1269 @@ +""" + +The compositor handles combining widgets into a single screen (i.e. compositing). + +It also stores the results of that process, so that Textual knows the widgets on +the screen and their locations. The compositor uses this information to answer +queries regarding the widget under an offset, or the style under an offset. + +Additionally, the compositor can render portions of the screen which may have updated, +without having to render the entire screen. +""" + +from __future__ import annotations + +from operator import itemgetter +from typing import ( + TYPE_CHECKING, + Callable, + Iterable, + Mapping, + NamedTuple, + Sequence, + cast, +) + +import rich.repr +from rich.console import Console, ConsoleOptions, RenderableType, RenderResult +from rich.control import Control +from rich.segment import Segment +from rich.style import Style + +from memray._vendor.textual import errors +from memray._vendor.textual._cells import cell_len +from memray._vendor.textual._context import visible_screen_stack +from memray._vendor.textual._loop import loop_last +from memray._vendor.textual.geometry import NULL_SPACING, Offset, Region, Size, Spacing +from memray._vendor.textual.map_geometry import MapGeometry +from memray._vendor.textual.strip import Strip, StripRenderable +from memray._vendor.textual.widget import Widget + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from memray._vendor.textual.screen import Screen + + +class ReflowResult(NamedTuple): + """The result of a reflow operation. Describes the chances to widgets.""" + + hidden: set[Widget] # Widgets that are hidden + shown: set[Widget] # Widgets that are shown + resized: set[Widget] # Widgets that have been resized + + +# Maps a widget on to its geometry (information that describes its position in the composition) +CompositorMap: TypeAlias = "dict[Widget, MapGeometry]" + + +class CompositorUpdate: + """An update generated by the compositor, which also doubles as console renderables.""" + + def render_segments(self, console: Console) -> str: + """Render the update to raw data, suitable for writing to terminal. + + Args: + console: Console instance. + + Returns: + Raw data with escape sequences. + """ + return "" + + +@rich.repr.auto(angular=True) +class LayoutUpdate(CompositorUpdate): + """A renderable containing the result of a render for a given region.""" + + def __init__(self, strips: list[Iterable[Strip]], region: Region) -> None: + self.strips = strips + self.region = region + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + x = self.region.x + new_line = Segment.line() + move_to = Control.move_to + for last, (y, line) in loop_last(enumerate(self.strips, self.region.y)): + yield move_to(x, y).segment + for strip in line: + yield from strip + if not last: + yield new_line + + def render_segments(self, console: Console) -> str: + """Render the update to raw data, suitable for writing to terminal. + + Args: + console: Console instance. + + Returns: + Raw data with escape sequences. + """ + sequences: list[str] = [] + append = sequences.append + extend = sequences.extend + x = self.region.x + move_to = Control.move_to + for last, (y, line) in loop_last(enumerate(self.strips, self.region.y)): + append(move_to(x, y).segment.text) + extend([strip.render(console) for strip in line]) + if not last: + append("\n") + return "".join(sequences) + + def __rich_repr__(self) -> rich.repr.Result: + yield self.region + + +@rich.repr.auto(angular=True) +class InlineUpdate(CompositorUpdate): + """A renderable to write an inline update.""" + + def __init__(self, strips: list[Strip], clear: bool = False) -> None: + self.strips = strips + self.clear = clear + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + new_line = Segment.line() + for last, line in loop_last(self.strips): + yield from line + if not last: + yield new_line + + def render_segments(self, console: Console) -> str: + """Render the update to raw data, suitable for writing to terminal. + + Args: + console: Console instance. + + Returns: + Raw data with escape sequences. + """ + sequences: list[str] = [] + append = sequences.append + for last, strip in loop_last(self.strips): + append(strip.render(console)) + if not last: + append("\n") + if self.clear: + if len(self.strips) > 1: + append("\n") + append("\x1b[J") # Clear down + if len(self.strips) > 1: + back_lines = len(self.strips) if self.clear else len(self.strips) - 1 + append(f"\x1b[{back_lines}A\r") # Move cursor back to original position + else: + append("\r") + append("\x1b[6n") # Query new cursor position + return "".join(sequences) + + +@rich.repr.auto(angular=True) +class ChopsUpdate(CompositorUpdate): + """A renderable that applies updated spans to the screen.""" + + def __init__( + self, + chops: Sequence[Mapping[int, Strip | None]], + spans: list[tuple[int, int, int]], + chop_ends: list[list[int]], + ) -> None: + """A renderable which updates chops (fragments of lines). + + Args: + chops: A mapping of offsets to list of segments, per line. + crop: Region to restrict update to. + chop_ends: A list of the end offsets for each line + """ + self.chops = chops + self.spans = spans + self.chop_ends = chop_ends + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + move_to = Control.move_to + new_line = Segment.line() + chops = self.chops + chop_ends = self.chop_ends + last_y = self.spans[-1][0] + + _cell_len = cell_len + for y, x1, x2 in self.spans: + line = chops[y] + ends = chop_ends[y] + for end, (x, strip) in zip(ends, line.items()): + # TODO: crop to x extents + if strip is None: + continue + + if x > x2 or end <= x1: + continue + + if x2 > x >= x1 and end <= x2: + yield move_to(x, y).segment + yield from strip + continue + + iter_segments = iter(strip) + if x < x1: + for segment in iter_segments: + next_x = x + _cell_len(segment.text) + if next_x > x1: + yield move_to(x, y).segment + yield segment + break + x = next_x + else: + yield move_to(x, y).segment + if end <= x2: + yield from iter_segments + else: + for segment in iter_segments: + if x >= x2: + break + yield segment + x += _cell_len(segment.text) + + if y != last_y: + yield new_line + + def render_segments(self, console: Console) -> str: + """Render the update to raw data, suitable for writing to terminal. + + Args: + console: Console instance. + + Returns: + Raw data with escape sequences. + """ + sequences: list[str] = [] + append = sequences.append + + move_to = Control.move_to + chops = self.chops + chop_ends = self.chop_ends + last_y = self.spans[-1][0] + + for y, x1, x2 in self.spans: + line = chops[y] + ends = chop_ends[y] + for end, (x, strip) in zip(ends, line.items()): + if strip is None: + continue + + if x > x2 or end <= x1: + continue + + if x2 > x >= x1 and end <= x2: + append(move_to(x, y).segment.text) + append(strip.render(console)) + continue + + strip = strip.crop(0, min(end, x2) - x) + append(move_to(x, y).segment.text) + append(strip.render(console)) + + if y != last_y: + append("\n") + + terminal_sequences = "".join(sequences) + return terminal_sequences + + def __rich_repr__(self) -> rich.repr.Result: + yield from () + + +@rich.repr.auto(angular=True) +class Compositor: + """Responsible for storing information regarding the relative positions of Widgets and rendering them.""" + + def __init__(self) -> None: + # A mapping of Widget on to its "render location" (absolute position / depth) + self._full_map: CompositorMap = {} + self._full_map_invalidated = True + self._visible_map: CompositorMap | None = None + self._layers: list[tuple[Widget, MapGeometry]] | None = None + + # All widgets considered in the arrangement + # Note this may be a superset of self.full_map.keys() as some widgets may be invisible for various reasons + self.widgets: set[Widget] = set() + + # Mapping of visible widgets on to their region, and clip region + self._visible_widgets: dict[Widget, tuple[Region, Region]] | None = None + + # The top level widget + self.root: Widget | None = None + + # Dimensions of the arrangement + self.size = Size(0, 0) + + # The points in each line where the line bisects the left and right edges of the widget + self._cuts: list[list[int]] | None = None + + # Regions that require an update + self._dirty_regions: set[Region] = set() + + # Mapping of line numbers on to lists of widget and regions + self._layers_visible: list[list[tuple[Widget, Region, Region]]] | None = None + + def clear(self) -> None: + """Remove all references to widgets (used when the screen closes).""" + self._full_map.clear() + self._visible_map = None + self._layers = None + self.widgets.clear() + self._visible_widgets = None + self._layers_visible = None + + @classmethod + def _regions_to_spans( + cls, regions: Iterable[Region] + ) -> Iterable[tuple[int, int, int]]: + """Converts the regions to horizontal spans. Spans will be combined if they overlap + or are contiguous to produce optimal non-overlapping spans. + + Args: + regions: An iterable of Regions. + + Returns: + Yields tuples of (Y, X1, X2). + """ + inline_ranges: dict[int, list[tuple[int, int]]] = {} + setdefault = inline_ranges.setdefault + for region_x, region_y, width, height in regions: + span = (region_x, region_x + width) + for y in range(region_y, region_y + height): + setdefault(y, []).append(span) + + slice_remaining = slice(1, None) + for y, ranges in sorted(inline_ranges.items()): + if len(ranges) == 1: + # Special case of 1 span + yield (y, *ranges[0]) + else: + ranges.sort() + x1, x2 = ranges[0] + for next_x1, next_x2 in ranges[slice_remaining]: + if next_x1 <= x2: + if next_x2 > x2: + x2 = next_x2 + else: + yield (y, x1, x2) + x1 = next_x1 + x2 = next_x2 + yield (y, x1, x2) + + def __rich_repr__(self) -> rich.repr.Result: + yield "size", self.size + yield "widgets", self.widgets + + def reflow(self, parent: Widget, size: Size) -> ReflowResult: + """Reflow (layout) widget and its children. + + Args: + parent: The root widget. + size: Size of the area to be filled. + + Returns: + Hidden, shown, and resized widgets. + """ + self._cuts = None + self._layers = None + self._layers_visible = None + self._visible_widgets = None + self._visible_map = None + self.root = parent + self.size = size + + # Keep a copy of the old map because we're going to compare it with the update + old_map = self._full_map + old_widgets = old_map.keys() + + map, widgets = self._arrange_root(parent, size, visible_only=False) + + new_widgets = map.keys() + + # Newly visible widgets + shown_widgets = new_widgets - old_widgets + + # Newly hidden widgets + hidden_widgets = self.widgets - widgets + + # Replace map and widgets + self._full_map = map + self.widgets = widgets + + # Contains widgets + geometry for every widget that changed (added, removed, or updated) + changes = map.items() ^ old_map.items() + + # Widgets in both new and old + common_widgets = old_widgets & new_widgets + + # Mark dirty regions. + screen_region = size.region + if screen_region not in self._dirty_regions: + regions = { + region + for region in ( + map_geometry.clip.intersection(map_geometry.region) + for _, map_geometry in changes + ) + if region + } + self._dirty_regions.update(regions) + + resized_widgets = { + widget + for widget, (region, *_) in changes + if (widget in common_widgets and old_map[widget].region.size != region.size) + } + return ReflowResult( + hidden=hidden_widgets, + shown=shown_widgets, + resized=resized_widgets, + ) + + def reflow_visible(self, parent: Widget, size: Size) -> set[Widget]: + """Reflow only the visible children. + + This is a fast-path for scrolling. + + Args: + parent: The root widget. + size: Size of the area to be filled. + + Returns: + Set of widgets that were exposed by the scroll. + """ + self._cuts = None + self._layers = None + self._layers_visible = None + self._visible_widgets = None + self._full_map_invalidated = True + self.root = parent + self.size = size + + # Keep a copy of the old map because we're going to compare it with the update + old_map = self._visible_map or {} + map, widgets = self._arrange_root(parent, size, visible_only=True) + + # Replace map and widgets + self._visible_map = map + self.widgets = widgets + + exposed_widgets = map.keys() - old_map.keys() + + # Contains widgets + geometry for every widget that changed (added, removed, or updated) + changes = map.items() ^ old_map.items() + + # Mark dirty regions. + screen_region = size.region + if screen_region not in self._dirty_regions: + regions = { + region + for region in ( + map_geometry.clip.intersection(map_geometry.region) + for _, map_geometry in changes + ) + if region + } + self._dirty_regions.update(regions) + + return exposed_widgets + + @property + def full_map(self) -> CompositorMap: + """Lazily built compositor map that covers all widgets.""" + + if self.root is None: + return {} + if self._full_map_invalidated: + self._full_map_invalidated = False + map, _widgets = self._arrange_root(self.root, self.size, visible_only=False) + # Update any widgets which became visible in the interim + self._full_map = map + self._visible_widgets = None + self._visible_map = None + + return self._full_map + + @property + def visible_widgets(self) -> dict[Widget, tuple[Region, Region]]: + """Get a mapping of widgets on to region and clip. + + Returns: + Visible widget mapping. + """ + + if self._visible_widgets is None: + map = ( + self._visible_map + if self._visible_map is not None + else (self._full_map or {}) + ) + screen = self.size.region + in_screen = screen.overlaps + overlaps = Region.overlaps + + # Widgets and regions in render order + visible_widgets = [ + (order, widget, region, clip) + for widget, (region, order, clip, _, _, _, _) in map.items() + if in_screen(region) and overlaps(clip, region) + ] + visible_widgets.sort(key=itemgetter(0), reverse=True) + self._visible_widgets = { + widget: (region, clip) for _, widget, region, clip in visible_widgets + } + return self._visible_widgets + + def _arrange_root( + self, root: Widget, size: Size, visible_only: bool = True + ) -> tuple[CompositorMap, set[Widget]]: + """Arrange a widget's children based on its layout attribute. + + Args: + root: Top level widget. + size: Size of visible area (screen). + visible_only: Only update visible widgets (used in scrolling). + + Returns: + Compositor map and set of widgets. + """ + + map: CompositorMap = {} + widgets: set[Widget] = set() + add_new_widget = widgets.add + invisible_widgets: set[Widget] = set() + add_new_invisible_widget = invisible_widgets.add + layer_order: int = 0 + + no_clip = size.region + + def add_widget( + widget: Widget, + virtual_region: Region, + region: Region, + order: tuple[tuple[int, int, int], ...], + layer_order: int, + clip: Region, + visible: bool, + dock_gutter: Spacing, + _MapGeometry: type[MapGeometry] = MapGeometry, + ) -> None: + """Called recursively to place a widget and its children in the map. + + Args: + widget: The widget to add. + virtual_region: The Widget region relative to its container. + region: The region the widget will occupy. + order: Painting order information. + layer_order: The order of the widget in its layer. + clip: The clipping region (i.e. the viewport which contains it). + visible: Whether the widget should be visible by default. + This may be overridden by the CSS rule `visibility`. + """ + if not widget._is_mounted: + return + styles = widget.styles + + if (visibility := styles.get_rule("visibility")) is not None: + visible = visibility == "visible" + + if visible: + add_new_widget(widget) + else: + add_new_invisible_widget(widget) + + # Container region is minus border + container_region = region.shrink(styles.gutter) + container_size = container_region.size + + # Widgets with scrollbars (containers or scroll view) require additional processing + if widget.is_scrollable: + # The region that contains the content (container region minus scrollbars) + child_region = ( + container_region + if widget.loading + else widget._get_scrollable_region(container_region) + ) + + # The region covered by children relative to parent widget + total_region = child_region.reset_offset + + if widget.is_container: + # Arrange the layout + arrange_result = widget.arrange(child_region.size) + + arranged_widgets = arrange_result.widgets + widgets.update(arranged_widgets) + + # Get the region that will be updated + sub_clip = clip.intersection(child_region) + + if widget._anchored and not widget._anchor_released: + new_scroll_y = ( + arrange_result.spatial_map.total_region.bottom + - ( + widget.container_size.height + - widget.scrollbar_size_horizontal + ) + ) + widget.set_reactive(Widget.scroll_y, new_scroll_y) + widget.set_reactive(Widget.scroll_target_y, new_scroll_y) + widget.vertical_scrollbar._reactive_position = new_scroll_y + + if visible_only: + placements = arrange_result.get_visible_placements( + sub_clip - child_region.offset + widget.scroll_offset + ) + else: + placements = arrange_result.placements + total_region = total_region.union(arrange_result.total_region) + + # An offset added to all placements + placement_offset = container_region.offset + placement_scroll_offset = placement_offset - widget.scroll_offset + + placements = [ + placement.process_offset(size.region, placement_scroll_offset) + for placement in placements + ] + + layers_to_index = { + layer_name: index + for index, layer_name in enumerate(widget.layers) + } + + get_layer_index = layers_to_index.get + + if widget._cover_widget is not None: + map[widget._cover_widget] = _MapGeometry( + region.shrink(widget.styles.gutter), + order, + clip, + region.size, + container_size, + virtual_region, + dock_gutter, + ) + + # Add all the widgets + for ( + sub_region, + sub_region_offset, + _, + sub_widget, + z, + fixed, + overlay, + absolute, + ) in reversed(placements): + layer_index = get_layer_index(sub_widget.layer, 0) + # Combine regions with children to calculate the "virtual size" + if fixed: + widget_region = ( + sub_region + sub_region_offset + placement_offset + ) + else: + widget_region = ( + sub_region + sub_region_offset + placement_scroll_offset + ) + + widget_order = order + ((layer_index, z, layer_order),) + + if widget._cover_widget is None: + add_widget( + sub_widget, + sub_region, + widget_region, + ((1, 0, 0),) if overlay else widget_order, + layer_order, + no_clip if overlay else sub_clip, + visible, + arrange_result.scroll_spacing, + ) + layer_order -= 1 + else: + if widget._anchored and not widget._anchor_released: + new_scroll_y = widget.virtual_size.height - ( + widget.container_size.height + - widget.scrollbar_size_horizontal + ) + widget.scroll_y = new_scroll_y + widget.scroll_target_y = new_scroll_y + widget.vertical_scrollbar.position = new_scroll_y + + if visible: + # Add any scrollbars + if ( + widget.show_vertical_scrollbar + or widget.show_horizontal_scrollbar + ) and styles.scrollbar_visibility == "visible": + for chrome_widget, chrome_region in widget._arrange_scrollbars( + container_region + ): + map[chrome_widget] = _MapGeometry( + chrome_region, + order, + clip, + container_size, + container_size, + chrome_region, + dock_gutter, + ) + + map[widget._render_widget] = _MapGeometry( + region, + order, + clip, + total_region.size, + container_size, + virtual_region, + dock_gutter, + ) + + elif visible: + # Add the widget to the map + map[widget._render_widget] = _MapGeometry( + region, + order, + clip, + region.size, + container_size, + virtual_region, + dock_gutter, + ) + + # Add top level (root) widget + add_widget( + root, + size.region, + size.region, + ((0, 0, 0),), + layer_order, + size.region, + True, + NULL_SPACING, + ) + widgets -= invisible_widgets + return map, widgets + + @property + def layers(self) -> list[tuple[Widget, MapGeometry]]: + """Get widgets and geometry in layer order.""" + map = self._visible_map if self._visible_map is not None else self._full_map + if self._layers is None: + self._layers = sorted( + map.items(), key=lambda item: item[1].order, reverse=True + ) + return self._layers + + @property + def layers_visible(self) -> list[list[tuple[Widget, Region, Region]]]: + """Visible widgets and regions in layers order. + + Returns: + Lists visible widgets per layer. Widgets are give as a tuple of + (WIDGET, CROPPED_REGION, REGION). CROPPED_REGION is clipped by + the container. + + """ + + if self._layers_visible is None: + layers_visible: list[list[tuple[Widget, Region, Region]]] + layers_visible = [[] for y in range(self.size.height)] + layers_visible_appends = [layer.append for layer in layers_visible] + intersection = Region.intersection + _range = range + for widget, (region, clip) in self.visible_widgets.items(): + cropped_region = intersection(region, clip) + _x, region_y, _width, region_height = cropped_region + if region_height: + widget_location = (widget, cropped_region, region) + for y in _range(region_y, region_y + region_height): + layers_visible_appends[y](widget_location) + self._layers_visible = layers_visible + return self._layers_visible + + def __contains__(self, widget: Widget) -> bool: + """Check if the widget was included in the last update. + + Args: + widget: A widget. + + Returns: + `True` if the widget was in the last refresh, or `False` if it wasn't. + """ + # Try to avoid a recalculation of full_map if possible. + return ( + widget in self.widgets + or (self._visible_map is not None and widget in self._visible_map) + or widget in self.full_map + ) + + def get_offset(self, widget: Widget) -> Offset: + """Get the offset of a widget. + + Args: + widget: Widget to query. + + Returns: + Offset of widget. + """ + try: + if self._visible_map is not None: + try: + return self._visible_map[widget].region.offset + except KeyError: + pass + return self.full_map[widget].region.offset + except KeyError: + raise errors.NoWidget("Widget is not in layout") + + def get_widget_at(self, x: int, y: int) -> tuple[Widget, Region]: + """Get the widget under a given coordinate. + + Args: + x: X Coordinate. + y: Y Coordinate. + + Raises: + errors.NoWidget: If there is not widget underneath (x, y). + + Returns: + A tuple of the widget and its region. + """ + + contains = Region.contains + if len(self.layers_visible) > y >= 0: + for widget, cropped_region, region in self.layers_visible[int(y)]: + if contains(cropped_region, x, y) and widget.visible: + return widget, region + raise errors.NoWidget(f"No widget under screen coordinate ({x}, {y})") + + def get_widgets_at(self, x: int, y: int) -> Iterable[tuple[Widget, Region]]: + """Get all widgets under a given coordinate. + + Args: + x: X coordinate. + y: Y coordinate. + + Returns: + Sequence of (WIDGET, REGION) tuples. + """ + contains = Region.contains + if len(self.layers_visible) > y >= 0: + for widget, cropped_region, region in self.layers_visible[y]: + if contains(cropped_region, x, y) and widget.visible: + yield widget, region + + def get_style_at(self, x: int, y: int) -> Style: + """Get the Style at the given cell or Style.null() + + Args: + x: X position within the Layout. + y: Y position within the Layout. + + Returns: + The Style at the cell (x, y) within the Layout. + """ + try: + widget, region = self.get_widget_at(x, y) + except errors.NoWidget: + return Style.null() + if widget not in self.visible_widgets: + return Style.null() + + x -= region.x + y -= region.y + + visible_screen_stack.set(widget.app._background_screens) + lines = widget.render_lines(Region(0, y, region.width, 1)) + + if not lines: + return Style.null() + end = 0 + + for segment in lines[0]: + end += segment.cell_length + if x < end: + return segment.style or Style.null() + + return Style.null() + + def get_widget_and_offset_at( + self, x: int, y: int + ) -> tuple[Widget | None, Offset | None]: + """Get the Style at the given cell, the offset within the content. + + Args: + x: X position within the Layout. + y: Y position within the Layout. + + Returns: + A tuple of the widget at (x, y) and the offset within the widget. + """ + try: + widget, region = self.get_widget_at(x, y) + except errors.NoWidget: + return None, None + if widget not in self.visible_widgets: + return None, None + + if y >= widget.content_region.bottom: + x, y = widget.content_region.bottom_right_inclusive + + gutter_left, gutter_right = widget.gutter.top_left + x -= region.x + gutter_left + y -= region.y + gutter_right + + if y < 0: + return None, None + + visible_screen_stack.set(widget.app._background_screens) + line = widget.render_line(y) + + end = 0 + start = 0 + offset_y: int | None = None + offset_x = 0 + offset_x2 = 0 + + from rich.cells import get_character_cell_size + + for segment in line: + end += segment.cell_length + style = segment.style + if style is not None and style._meta is not None: + meta = style.meta + if "offset" in meta: + offset_x, offset_y = style.meta["offset"] + offset_x2 = offset_x + len(segment.text) + + if x < end and x >= start: + segment_cell_length = 0 + cell_cut = x - start + segment_offset = 0 + for character in segment.text: + if segment_cell_length >= cell_cut: + break + segment_cell_length += get_character_cell_size(character) + segment_offset += 1 + return widget, ( + None + if offset_y is None + else Offset(offset_x + segment_offset, offset_y) + ) + start = end + + return widget, (None if offset_y is None else Offset(offset_x2, offset_y)) + + def find_widget(self, widget: Widget) -> MapGeometry: + """Get information regarding the relative position of a widget in the Compositor. + + Args: + widget: The Widget in this layout you wish to know the Region of. + + Raises: + NoWidget: If the Widget is not contained in this Layout. + + Returns: + Widget's composition information. + """ + if self.root is None: + raise errors.NoWidget("Widget is not in layout") + try: + if not self._full_map_invalidated: + try: + return self._full_map[widget] + except KeyError: + pass + if self._visible_map is not None: + try: + return self._visible_map[widget] + except KeyError: + pass + region = self.full_map[widget] + except KeyError: + raise errors.NoWidget("Widget is not in layout") + else: + return region + + @property + def cuts(self) -> list[list[int]]: + """Get vertical cuts. + + A cut is every point on a line where a widget starts or ends. + + Returns: + A list of cuts for every line. + """ + if self._cuts is not None: + return self._cuts + + width, height = self.size + cuts = [[0, width] for _ in range(height)] + + intersection = Region.intersection + extend = list.extend + + for region, clip in self.visible_widgets.values(): + x, y, region_width, region_height = intersection(region, clip) + if region_width and region_height: + region_cuts = (x, x + region_width) + for cut in cuts[y : y + region_height]: + extend(cut, region_cuts) + + # Sort the cuts for each line + self._cuts = [sorted(set(line_cuts)) for line_cuts in cuts] + + return self._cuts + + def _get_renders( + self, crop: Region | None = None + ) -> Iterable[tuple[Region, Region, list[Strip]]]: + """Get rendered widgets (lists of segments) in the composition. + + Args: + crop: Region to crop to, or `None` for entire screen. + + Returns: + An iterable of , , and + """ + # If a renderable throws an error while rendering, the user likely doesn't care about the traceback + # up to this point. + _rich_traceback_guard = True + + _Region = Region + + visible_widgets = self.visible_widgets + + if crop: + crop_overlaps = crop.overlaps + widget_regions = [ + (widget, region, clip) + for widget, (region, clip) in visible_widgets.items() + if crop_overlaps(clip) + ] + else: + widget_regions = [ + (widget, region, clip) + for widget, (region, clip) in visible_widgets.items() + ] + + intersection = _Region.intersection + contains_region = _Region.contains_region + + for widget, region, clip in widget_regions: + if contains_region(clip, region): + yield ( + region, + clip, + widget.render_lines( + _Region( + 0, + 0, + region.width, + region.height, + ) + ), + ) + else: + new_x, new_y, new_width, new_height = intersection(region, clip) + if new_width and new_height: + yield ( + region, + clip, + widget.render_lines( + _Region( + new_x - region.x, + new_y - region.y, + new_width, + new_height, + ) + ), + ) + + def render_update( + self, + full: bool = False, + screen_stack: list[Screen] | None = None, + simplify: bool = False, + ) -> RenderableType | None: + """Render an update renderable. + + Args: + full: Perform a full update if `True`, otherwise a partial update. + screen_stack: Screen stack list. Defaults to None. + simplify: Simplify segments. + + Returns: + A renderable for the update, or `None` if no update was required. + """ + + visible_screen_stack.set([] if screen_stack is None else screen_stack) + screen_region = self.size.region + if full or screen_region in self._dirty_regions: + return self.render_full_update(simplify=simplify) + else: + return self.render_partial_update() + + def render_inline( + self, + size: Size, + screen_stack: list[Screen] | None = None, + clear: bool = False, + ) -> RenderableType: + """Render an inline update. + + Args: + size: Inline size. + screen_stack: Screen stack list. Defaults to None. + clear: Also clear below the inline update (set when size decreases). + + Returns: + A renderable. + """ + visible_screen_stack.set([] if screen_stack is None else screen_stack) + strips = self.render_strips(size) + return InlineUpdate(strips, clear=clear) + + def render_full_update(self, simplify: bool = False) -> LayoutUpdate: + """Render a full update. + + Args: + simplify: Simplify the segments (combine contiguous segments). + + Returns: + A LayoutUpdate renderable. + """ + screen_region = self.size.region + self._dirty_regions.clear() + crop = screen_region + chops = self._render_chops(crop, lambda y: True) + render_strips: list[Iterable[Strip]] + if simplify: + # Simplify is done when exporting to SVG + # It doesn't make things faster + render_strips = [ + [Strip.join(chop.values()).simplify().discard_meta()] for chop in chops + ] + else: + render_strips = [chop.values() for chop in chops] + + return LayoutUpdate(render_strips, screen_region) + + def render_partial_update(self) -> ChopsUpdate | None: + """Render a partial update. + + Returns: + A ChopsUpdate if there is anything to update, otherwise `None`. + """ + screen_region = self.size.region + update_regions = self._dirty_regions.copy() + self._dirty_regions.clear() + if update_regions: + # Create a crop region that surrounds all updates. + crop = Region.from_union(update_regions).intersection(screen_region) + spans = list(self._regions_to_spans(update_regions)) + is_rendered_line = {y for y, _, _ in spans}.__contains__ + else: + return None + chops = self._render_chops(crop, is_rendered_line) + chop_ends = [cut_set[1:] for cut_set in self.cuts] + return ChopsUpdate(chops, spans, chop_ends) + + def render_strips(self, size: Size | None = None) -> list[Strip]: + """Render to a list of strips. + + Args: + size: Size of render. + + Returns: + A list of strips with the screen content. + """ + if size is None: + size = self.size + chops = self._render_chops(size.region, lambda y: True) + render_strips = [Strip.join(chop.values()) for chop in chops[: size.height]] + return render_strips + + def _render_chops( + self, + crop: Region, + is_rendered_line: Callable[[int], bool], + ) -> Sequence[Mapping[int, Strip]]: + """Render update 'chops'. + + Args: + crop: Region to crop to. + is_rendered_line: Callable to check if line should be rendered. + + Returns: + Chops structure. + """ + cuts = self.cuts + fromkeys = cast("Callable[[list[int]], dict[int, Strip | None]]", dict.fromkeys) + chops: list[dict[int, Strip | None]] + chops = [fromkeys(cut_set[:-1]) for cut_set in cuts] + + cut_strips: Iterable[Strip] + + # Go through all the renders in reverse order and fill buckets with no render + renders = self._get_renders(crop) + intersection = Region.intersection + + for region, clip, strips in renders: + render_region = intersection(region, clip) + render_x = render_region.x + first_cut, last_cut = render_region.column_span + + for y, strip in zip(render_region.line_range, strips): + if not is_rendered_line(y): + continue + + chops_line = chops[y] + final_cuts = [cut for cut in cuts[y] if (last_cut >= cut >= first_cut)] + cut_strips = strip.divide([cut - render_x for cut in final_cuts[1:]]) + + # Since we are painting front to back, the first segments for a cut "wins" + get_chops_line = chops_line.get + for cut, strip in zip(final_cuts, cut_strips): + if get_chops_line(cut) is None: + chops_line[cut] = strip + return cast("Sequence[Mapping[int, Strip]]", chops) + + def __rich__(self) -> StripRenderable: + return StripRenderable(self.render_strips()) + + def update_widgets(self, widgets: set[Widget]) -> None: + """Update the given widgets in the composition. + + Args: + widgets: Set of Widgets to update. + """ + + # If there are any *new* widgets we need to invalidate the full map + if not self._full_map_invalidated and not widgets.issubset( + self.visible_widgets.keys() + ): + self._full_map_invalidated = True + + regions: list[Region] = [] + add_region = regions.append + get_widget = self.visible_widgets.__getitem__ + for widget in self.visible_widgets.keys() & widgets: + region, clip = get_widget(widget) + offset = region.offset + intersection = clip.intersection + for dirty_region in widget._exchange_repaint_regions(): + if update_region := intersection(dirty_region.translate(offset)): + add_region(update_region) + + self._dirty_regions.update(regions) diff --git a/src/memray/_vendor/textual/_context.py b/src/memray/_vendor/textual/_context.py new file mode 100644 index 0000000000..8f57a1793a --- /dev/null +++ b/src/memray/_vendor/textual/_context.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from contextvars import ContextVar +from typing import TYPE_CHECKING, Any, Callable + +if TYPE_CHECKING: + from memray._vendor.textual.app import App + from memray._vendor.textual.message import Message + from memray._vendor.textual.message_pump import MessagePump + from memray._vendor.textual.screen import Screen + + +class NoActiveAppError(RuntimeError): + """Runtime error raised if we try to retrieve the active app when there is none.""" + + +active_app: ContextVar["App[Any]"] = ContextVar("active_app") +active_message_pump: ContextVar["MessagePump"] = ContextVar("active_message_pump") + +prevent_message_types_stack: ContextVar[list[set[type[Message]]]] = ContextVar( + "prevent_message_types_stack" +) +visible_screen_stack: ContextVar[list[Screen[object]]] = ContextVar( + "visible_screen_stack" +) +"""A stack of visible screens (with background alpha < 1), used in the screen render process.""" +message_hook: ContextVar[Callable[[Message], None]] = ContextVar("message_hook") +"""A callable that accepts a message. Used by App.run_test.""" diff --git a/src/memray/_vendor/textual/_debug.py b/src/memray/_vendor/textual/_debug.py new file mode 100644 index 0000000000..d16e0be28a --- /dev/null +++ b/src/memray/_vendor/textual/_debug.py @@ -0,0 +1,26 @@ +""" +Functions related to debugging. +""" + +from __future__ import annotations + +from memray._vendor.textual import constants + + +def get_caller_file_and_line() -> str | None: + """Get the caller filename and line, if in debug mode, otherwise return `None`: + + Returns: + Path and file if `constants.DEBUG==True` + """ + + if not constants.DEBUG: + return None + import inspect + + try: + current_frame = inspect.currentframe() + caller_frame = inspect.getframeinfo(current_frame.f_back.f_back) + return f"{caller_frame.filename}:{caller_frame.lineno}" + except Exception: + return None diff --git a/src/memray/_vendor/textual/_dispatch_key.py b/src/memray/_vendor/textual/_dispatch_key.py new file mode 100644 index 0000000000..616488f835 --- /dev/null +++ b/src/memray/_vendor/textual/_dispatch_key.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from typing import Callable + +from memray._vendor.textual import events +from memray._vendor.textual._callback import invoke +from memray._vendor.textual.dom import DOMNode +from memray._vendor.textual.errors import DuplicateKeyHandlers +from memray._vendor.textual.message_pump import MessagePump + + +async def dispatch_key(node: DOMNode, event: events.Key) -> bool: + """Dispatch a key event to method. + + This function will call the method named 'key_' on a node if it exists. + Some keys have aliases. The first alias found will be invoked if it exists. + If multiple handlers exist that match the key, an exception is raised. + + Args: + event: A key event. + + Returns: + True if key was handled, otherwise False. + + Raises: + DuplicateKeyHandlers: When there's more than 1 handler that could handle this key. + """ + + def get_key_handler(pump: MessagePump, key: str) -> Callable | None: + """Look for the public and private handler methods by name on self.""" + return getattr(pump, f"key_{key}", None) or getattr(pump, f"_key_{key}", None) + + handled = False + invoked_method = None + key_name = event.name + if not key_name: + return False + + def _raise_duplicate_key_handlers_error( + key_name: str, first_handler: str, second_handler: str + ) -> None: + """Raise exception for case where user presses a key and there are multiple candidate key handler methods for it.""" + raise DuplicateKeyHandlers( + f"Multiple handlers for key press {key_name!r}.\n" + f"We found both {first_handler!r} and {second_handler!r}, " + f"and didn't know which to call.\n" + f"Consider combining them into a single handler.", + ) + + try: + screen = node.screen + except Exception: + screen = None + for key_method_name in event.name_aliases: + if (key_method := get_key_handler(node, key_method_name)) is not None: + if invoked_method: + _raise_duplicate_key_handlers_error( + key_name, invoked_method.__name__, key_method.__name__ + ) + # If key handlers return False, then they are not considered handled + # This allows key handlers to do some conditional logic + + if screen is not None and not screen.is_active: + break + handled = (await invoke(key_method, event)) is not False + invoked_method = key_method + + return handled diff --git a/src/memray/_vendor/textual/_doc.py b/src/memray/_vendor/textual/_doc.py new file mode 100644 index 0000000000..904b344ad3 --- /dev/null +++ b/src/memray/_vendor/textual/_doc.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +import hashlib +import inspect +import os +import shlex +from pathlib import Path +from typing import Awaitable, Callable, Iterable, cast + +from memray._vendor.textual._import_app import import_app +from memray._vendor.textual.app import App +from memray._vendor.textual.pilot import Pilot + +SCREENSHOT_CACHE = ".screenshot_cache" + + +# This module defines our "Custom Fences", powered by SuperFences +# @link https://facelessuser.github.io/pymdown-extensions/extensions/superfences/#custom-fences +def format_svg(source, language, css_class, options, md, attrs, **kwargs) -> str: + """A superfences formatter to insert an SVG screenshot.""" + + try: + cmd: list[str] = shlex.split(attrs["path"]) + path = cmd[0] + + _press = attrs.get("press", None) + _type = attrs.get("type", None) + press = [*_press.split(",")] if _press else [] + if _type is not None: + press.extend(_type.replace("\\t", "\t")) + title = attrs.get("title") + + print(f"screenshotting {path!r}") + + cwd = os.getcwd() + try: + rows = int(attrs.get("lines", 24)) + columns = int(attrs.get("columns", 80)) + hover = attrs.get("hover", "") + svg = take_svg_screenshot( + None, + path, + press, + hover=hover, + title=title, + terminal_size=(columns, rows), + wait_for_animation=False, + simplify=False, + ) + finally: + os.chdir(cwd) + + assert svg is not None + return svg + + except Exception as error: + import traceback + + traceback.print_exception(error) + return "" + + +def take_svg_screenshot( + app: App | None = None, + app_path: str | None = None, + press: Iterable[str] = (), + hover: str = "", + title: str | None = None, + terminal_size: tuple[int, int] = (80, 24), + run_before: Callable[[Pilot], Awaitable[None] | None] | None = None, + wait_for_animation: bool = True, + simplify=True, +) -> str: + """ + + Args: + app: An app instance. Must be supplied if app_path is not. + app_path: A path to an app. Must be supplied if app is not. + press: Key presses to run before taking screenshot. "_" is a short pause. + hover: Hover over the given widget. + title: The terminal title in the output image. + terminal_size: A pair of integers (rows, columns), representing terminal size. + run_before: An arbitrary callable that runs arbitrary code before taking the + screenshot. Use this to simulate complex user interactions with the app + that cannot be simulated by key presses. + wait_for_animation: Wait for animation to complete before taking screenshot. + simplify: Simplify the segments by combining contiguous segments with the same style. + + Returns: + An SVG string, showing the content of the terminal window at the time + the screenshot was taken. + """ + + if app is None: + assert app_path is not None + app = import_app(app_path) + + assert app is not None + + if title is None: + title = app.title + + def get_cache_key(app: App) -> str: + hash = hashlib.md5() + file_paths = [app_path] + app.css_path + for path in file_paths: + assert path is not None + with open(path, "rb") as source_file: + hash.update(source_file.read()) + hash.update(f"{press}-{hover}-{title}-{terminal_size}".encode("utf-8")) + cache_key = f"{hash.hexdigest()}.svg" + return cache_key + + if app_path is not None and run_before is None: + screenshot_cache = Path(SCREENSHOT_CACHE) + screenshot_cache.mkdir(exist_ok=True) + + screenshot_path = screenshot_cache / get_cache_key(app) + if screenshot_path.exists(): + return screenshot_path.read_text() + + async def auto_pilot(pilot: Pilot) -> None: + app = pilot.app + if run_before is not None: + result = run_before(pilot) + if inspect.isawaitable(result): + await result + await pilot.pause() + await pilot.press(*press) + if hover: + await pilot.hover(hover) + await pilot.pause(0.5) + if wait_for_animation: + await pilot.wait_for_scheduled_animations() + await pilot.pause() + await pilot.pause() + await pilot.wait_for_scheduled_animations() + svg = app.export_screenshot(title=title, simplify=simplify) + + app.exit(svg) + + svg = cast( + str, + app.run( + headless=True, + auto_pilot=auto_pilot, + size=terminal_size, + ), + ) + + if app_path is not None and run_before is None: + screenshot_path.write_text(svg) + + assert svg is not None + + return svg + + +def rich(source, language, css_class, options, md, attrs, **kwargs) -> str: + """A superfences formatter to insert an SVG screenshot.""" + + import io + + from rich.console import Console + + title = attrs.get("title", "Rich") + + rows = int(attrs.get("lines", 24)) + columns = int(attrs.get("columns", 80)) + + console = Console( + file=io.StringIO(), + record=True, + force_terminal=True, + color_system="truecolor", + width=columns, + height=rows, + ) + error_console = Console(stderr=True) + + globals: dict = {} + try: + exec(source, globals) + except Exception: + error_console.print_exception() + # console.bell() + + if "output" in globals: + console.print(globals["output"]) + output_svg = console.export_svg(title=title) + return output_svg diff --git a/src/memray/_vendor/textual/_duration.py b/src/memray/_vendor/textual/_duration.py new file mode 100644 index 0000000000..8d6145605a --- /dev/null +++ b/src/memray/_vendor/textual/_duration.py @@ -0,0 +1,44 @@ +import re + +_match_duration = re.compile(r"^(-?\d+\.?\d*)(s|ms)$").match + + +class DurationError(Exception): + """ + Exception indicating a general issue with a CSS duration. + """ + + +class DurationParseError(DurationError): + """ + Indicates a malformed duration string that could not be parsed. + """ + + +def _duration_as_seconds(duration: str) -> float: + """ + Args: + duration: A string of the form `"2s"` or `"300ms"`, representing 2 seconds and + 300 milliseconds respectively. If no unit is supplied, e.g. `"2"`, then the duration is + assumed to be in seconds. + Raises: + DurationParseError: If the argument `duration` is not a valid duration string. + Returns: + The duration in seconds. + """ + match = _match_duration(duration) + + if match: + value, unit_name = match.groups() + value = float(value) + if unit_name == "ms": + duration_secs = value / 1000 + else: + duration_secs = value + else: + try: + duration_secs = float(duration) + except ValueError: + raise DurationParseError(f"{duration!r} is not a valid duration.") from None + + return duration_secs diff --git a/src/memray/_vendor/textual/_easing.py b/src/memray/_vendor/textual/_easing.py new file mode 100644 index 0000000000..13b274ddd8 --- /dev/null +++ b/src/memray/_vendor/textual/_easing.py @@ -0,0 +1,131 @@ +""" +Define a series of easing functions for more natural-looking animations. +Taken from https://easings.net/ and translated from JavaScript. +""" + +from math import cos, pi, sin, sqrt + + +def _in_out_expo(x: float) -> float: + """https://easings.net/#easeInOutExpo""" + if 0 < x < 0.5: + return pow(2, 20 * x - 10) / 2 + elif 0.5 <= x < 1: + return (2 - pow(2, -20 * x + 10)) / 2 + else: + return x # x in (0, 1) + + +def _in_out_circ(x: float) -> float: + """https://easings.net/#easeInOutCirc""" + if x < 0.5: + return (1 - sqrt(1 - pow(2 * x, 2))) / 2 + else: + return (sqrt(1 - pow(-2 * x + 2, 2)) + 1) / 2 + + +def _in_out_back(x: float) -> float: + """https://easings.net/#easeInOutBack""" + c = 1.70158 * 1.525 + if x < 0.5: + return (pow(2 * x, 2) * ((c + 1) * 2 * x - c)) / 2 + else: + return (pow(2 * x - 2, 2) * ((c + 1) * (x * 2 - 2) + c) + 2) / 2 + + +def _in_elastic(x: float) -> float: + """https://easings.net/#easeInElastic""" + c = 2 * pi / 3 + if 0 < x < 1: + return -pow(2, 10 * x - 10) * sin((x * 10 - 10.75) * c) + else: + return x # x in (0, 1) + + +def _in_out_elastic(x: float) -> float: + """https://easings.net/#easeInOutElastic""" + c = 2 * pi / 4.5 + if 0 < x < 0.5: + return -(pow(2, 20 * x - 10) * sin((20 * x - 11.125) * c)) / 2 + elif 0.5 <= x < 1: + return (pow(2, -20 * x + 10) * sin((20 * x - 11.125) * c)) / 2 + 1 + else: + return x # x in (0, 1) + + +def _out_elastic(x: float) -> float: + """https://easings.net/#easeInOutElastic""" + c = 2 * pi / 3 + if 0 < x < 1: + return pow(2, -10 * x) * sin((x * 10 - 0.75) * c) + 1 + else: + return x # x in (0, 1) + + +def _out_bounce(x: float) -> float: + """https://easings.net/#easeOutBounce""" + n, d = 7.5625, 2.75 + if x < 1 / d: + return n * x * x + elif x < 2 / d: + x_ = x - 1.5 / d + return n * x_ * x_ + 0.75 + elif x < 2.5 / d: + x_ = x - 2.25 / d + return n * x_ * x_ + 0.9375 + else: + x_ = x - 2.625 / d + return n * x_ * x_ + 0.984375 + + +def _in_bounce(x: float) -> float: + """https://easings.net/#easeInBounce""" + return 1 - _out_bounce(1 - x) + + +def _in_out_bounce(x: float) -> float: + """https://easings.net/#easeInOutBounce""" + if x < 0.5: + return (1 - _out_bounce(1 - 2 * x)) / 2 + else: + return (1 + _out_bounce(2 * x - 1)) / 2 + + +EASING = { + "none": lambda x: 1.0, + "round": lambda x: 0.0 if x < 0.5 else 1.0, + "linear": lambda x: x, + "in_sine": lambda x: 1 - cos((x * pi) / 2), + "in_out_sine": lambda x: -(cos(x * pi) - 1) / 2, + "out_sine": lambda x: sin((x * pi) / 2), + "in_quad": lambda x: x * x, + "in_out_quad": lambda x: 2 * x * x if x < 0.5 else 1 - pow(-2 * x + 2, 2) / 2, + "out_quad": lambda x: 1 - pow(1 - x, 2), + "in_cubic": lambda x: x * x * x, + "in_out_cubic": lambda x: 4 * x * x * x if x < 0.5 else 1 - pow(-2 * x + 2, 3) / 2, + "out_cubic": lambda x: 1 - pow(1 - x, 3), + "in_quart": lambda x: pow(x, 4), + "in_out_quart": lambda x: 8 * pow(x, 4) if x < 0.5 else 1 - pow(-2 * x + 2, 4) / 2, + "out_quart": lambda x: 1 - pow(1 - x, 4), + "in_quint": lambda x: pow(x, 5), + "in_out_quint": lambda x: 16 * pow(x, 5) if x < 0.5 else 1 - pow(-2 * x + 2, 5) / 2, + "out_quint": lambda x: 1 - pow(1 - x, 5), + "in_expo": lambda x: pow(2, 10 * x - 10) if x else 0, + "in_out_expo": _in_out_expo, + "out_expo": lambda x: 1 - pow(2, -10 * x) if x != 1 else 1, + "in_circ": lambda x: 1 - sqrt(1 - pow(x, 2)), + "in_out_circ": _in_out_circ, + "out_circ": lambda x: sqrt(1 - pow(x - 1, 2)), + "in_back": lambda x: 2.70158 * pow(x, 3) - 1.70158 * pow(x, 2), + "in_out_back": _in_out_back, + "out_back": lambda x: 1 + 2.70158 * pow(x - 1, 3) + 1.70158 * pow(x - 1, 2), + "in_elastic": _in_elastic, + "in_out_elastic": _in_out_elastic, + "out_elastic": _out_elastic, + "in_bounce": _in_bounce, + "in_out_bounce": _in_out_bounce, + "out_bounce": _out_bounce, +} + +DEFAULT_EASING = "in_out_cubic" +DEFAULT_SCROLL_EASING = "out_cubic" diff --git a/src/memray/_vendor/textual/_event_broker.py b/src/memray/_vendor/textual/_event_broker.py new file mode 100644 index 0000000000..fe6727e105 --- /dev/null +++ b/src/memray/_vendor/textual/_event_broker.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import Any, NamedTuple + + +class NoHandler(Exception): + """Raised when handler isn't found in the meta.""" + + +class HandlerArguments(NamedTuple): + """Information for event handler.""" + + modifiers: set[str] + action: Any + + +def extract_handler_actions(event_name: str, meta: dict[str, Any]) -> HandlerArguments: + """Extract action from meta dict. + + Args: + event_name: Event to check from. + meta: Meta information (stored in Rich Style) + + Raises: + NoHandler: If no handler is found. + + Returns: + Action information. + """ + event_path = event_name.split(".") + for key, value in meta.items(): + if key.startswith("@"): + name_args = key[1:].split(".") + if name_args[: len(event_path)] == event_path: + modifiers = name_args[len(event_path) :] + return HandlerArguments(set(modifiers), value) + raise NoHandler(f"No handler for {event_name!r}") diff --git a/src/memray/_vendor/textual/_extrema.py b/src/memray/_vendor/textual/_extrema.py new file mode 100644 index 0000000000..69ee65ac39 --- /dev/null +++ b/src/memray/_vendor/textual/_extrema.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from fractions import Fraction +from typing import NamedTuple + +from memray._vendor.textual.geometry import Size + + +class Extrema(NamedTuple): + """Specifies minimum and maximum dimensions.""" + + min_width: Fraction | None = None + max_width: Fraction | None = None + min_height: Fraction | None = None + max_height: Fraction | None = None + + def apply_width(self, width: Fraction) -> Fraction: + """Apply width extrema. + + Args: + width: Width value. + + Returns: + Width, clamped between minimum and maximum. + + """ + min_width, max_width = self[:2] + if min_width is not None: + width = max(width, min_width) + if max_width is not None: + width = min(width, max_width) + return width + + def apply_height(self, height: Fraction) -> Fraction: + """Apply height extrema. + + Args: + height: Height value. + + Returns: + Height, clamped between minimum and maximum. + + """ + min_height, max_height = self[2:] + if min_height is not None: + height = max(height, min_height) + if max_height is not None: + height = min(height, max_height) + return height + + def apply_dimensions(self, width: int, height: int) -> Size: + """Apply extrema to integer dimensions. + + Args: + width: Integer width. + height: Integer height. + + Returns: + Size with extrema applied. + """ + return Size( + int(self.apply_width(Fraction(width))), + int(self.apply_height(Fraction(height))), + ) diff --git a/src/memray/_vendor/textual/_files.py b/src/memray/_vendor/textual/_files.py new file mode 100644 index 0000000000..37fd76adfb --- /dev/null +++ b/src/memray/_vendor/textual/_files.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from datetime import datetime + + +def generate_datetime_filename( + prefix: str, suffix: str, datetime_format: str | None = None +) -> str: + """Generate a filename which includes the current date and time. + + Useful for ensuring a degree of uniqueness when saving files. + + Args: + prefix: Prefix to attach to the start of the filename, before the timestamp string. + suffix: Suffix to attach to the end of the filename, after the timestamp string. + This should include the file extension. + datetime_format: The format of the datetime to include in the filename. + If None, the ISO format will be used. + """ + if datetime_format is None: + dt = datetime.now().isoformat() + else: + dt = datetime.now().strftime(datetime_format) + + file_name_stem = f"{prefix} {dt}" + for reserved in ' <>:"/\\|?*.': + file_name_stem = file_name_stem.replace(reserved, "_") + return file_name_stem + suffix diff --git a/src/memray/_vendor/textual/_immutable_sequence_view.py b/src/memray/_vendor/textual/_immutable_sequence_view.py new file mode 100644 index 0000000000..823c73e6a4 --- /dev/null +++ b/src/memray/_vendor/textual/_immutable_sequence_view.py @@ -0,0 +1,69 @@ +"""Provides an immutable sequence view class.""" + +from __future__ import annotations + +from sys import maxsize +from typing import TYPE_CHECKING, Generic, Iterator, Sequence, TypeVar, overload + +T = TypeVar("T") + + +class ImmutableSequenceView(Generic[T]): + """Class to wrap a sequence of some sort, but not allow modification.""" + + def __init__(self, wrap: Sequence[T]) -> None: + """Initialise the immutable sequence. + + Args: + wrap: The sequence being wrapped. + """ + self._wrap = wrap + + if TYPE_CHECKING: + + @overload + def __getitem__(self, index: int) -> T: ... + + @overload + def __getitem__(self, index: slice) -> ImmutableSequenceView[T]: ... + + def __getitem__(self, index: int | slice) -> T | ImmutableSequenceView[T]: + return ( + self._wrap[index] + if isinstance(index, int) + else ImmutableSequenceView[T](self._wrap[index]) + ) + + def __iter__(self) -> Iterator[T]: + return iter(self._wrap) + + def __len__(self) -> int: + return len(self._wrap) + + def __length_hint__(self) -> int: + return len(self) + + def __bool__(self) -> bool: + return bool(self._wrap) + + def __contains__(self, item: T) -> bool: + return item in self._wrap + + def index(self, item: T, start: int = 0, stop: int = maxsize) -> int: + """Return the index of the given item. + + Args: + item: The item to find in the sequence. + start: Optional start location. + stop: Optional stop location. + + Returns: + The index of the item in the sequence. + + Raises: + ValueError: If the item is not in the sequence. + """ + return self._wrap.index(item, start, stop) + + def __reversed__(self) -> Iterator[T]: + return reversed(self._wrap) diff --git a/src/memray/_vendor/textual/_import_app.py b/src/memray/_vendor/textual/_import_app.py new file mode 100644 index 0000000000..42820d15da --- /dev/null +++ b/src/memray/_vendor/textual/_import_app.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import os +import runpy +import shlex +import sys +from pathlib import Path +from typing import TYPE_CHECKING, cast + +if TYPE_CHECKING: + from memray._vendor.textual.app import App + + +class AppFail(Exception): + pass + + +def shebang_python(candidate: Path) -> bool: + """Does the given file look like it's run with Python? + + Args: + candidate: The candidate file to check. + + Returns: + ``True`` if it looks to #! python, ``False`` if not. + """ + try: + with candidate.open("rb") as source: + first_line = source.readline() + except IOError: + return False + return first_line.startswith(b"#!") and b"python" in first_line + + +def import_app(import_name: str) -> App: + """Import an app from a path or import name. + + Args: + import_name: A name to import, such as `foo.bar`, or a path ending with .py. + + Raises: + AppFail: If the app could not be found for any reason. + + Returns: + A Textual application + """ + + import importlib + import inspect + + from memray._vendor.textual.app import WINDOWS, App + + import_name, *argv = shlex.split(import_name, posix=not WINDOWS) + drive, import_name = os.path.splitdrive(import_name) + + lib, _colon, name = import_name.partition(":") + + if drive: + lib = os.path.join(drive, os.sep, lib) + + if lib.endswith(".py") or shebang_python(Path(lib)): + path = os.path.abspath(lib) + sys.path.append(str(Path(path).parent)) + try: + global_vars = runpy.run_path(path, {}) + except Exception as error: + raise AppFail(str(error)) + + sys.argv[:] = [path, *argv] + + if name: + # User has given a name, use that + try: + app = global_vars[name] + except KeyError: + raise AppFail(f"App {name!r} not found in {lib!r}") + else: + # User has not given a name + if "app" in global_vars: + # App exists, lets use that + try: + app = global_vars["app"] + except KeyError: + raise AppFail(f"App {name!r} not found in {lib!r}") + else: + # Find an App class or instance that is *not* the base class + apps = [ + value + for value in global_vars.values() + if ( + isinstance(value, App) + or (inspect.isclass(value) and issubclass(value, App)) + and value is not App + ) + ] + if not apps: + raise AppFail( + f'Unable to find app in {lib!r}, try specifying app with "foo.py:app"' + ) + if len(apps) > 1: + raise AppFail( + f'Multiple apps found {lib!r}, try specifying app with "foo.py:app"' + ) + app = apps[0] + app._BASE_PATH = path + + else: + # Assuming the user wants to import the file + sys.path.append("") + try: + module = importlib.import_module(lib) + except ImportError as error: + raise AppFail(str(error)) + + find_app = name or "app" + try: + app = getattr(module, find_app or "app") + except AttributeError: + raise AppFail(f"Unable to find {find_app!r} in {module!r}") + + sys.argv[:] = [import_name, *argv] + + if inspect.isclass(app) and issubclass(app, App): + app = app() + + return cast(App, app) diff --git a/src/memray/_vendor/textual/_keyboard_protocol.py b/src/memray/_vendor/textual/_keyboard_protocol.py new file mode 100644 index 0000000000..9b76541102 --- /dev/null +++ b/src/memray/_vendor/textual/_keyboard_protocol.py @@ -0,0 +1,123 @@ +# https://sw.kovidgoyal.net/kitty/keyboard-protocol/#functional-key-definitions +FUNCTIONAL_KEYS = { + "27u": "escape", + "13u": "enter", + "9u": "tab", + "127u": "backspace", + "2~": "insert", + "3~": "delete", + "1D": "left", + "1C": "right", + "1A": "up", + "1B": "down", + "5~": "pageup", + "6~": "pagedown", + "1H": "home", + "1~": "home", + "7~": "home", + "1F": "end", + "4~": "end", + "8~": "end", + "57358u": "caps_lock", + "57359u": "scroll_lock", + "57360u": "num_lock", + "57361u": "print_screen", + "57362u": "pause", + "57363u": "menu", + "1P": "f1", + "11~": "f1", + "1Q": "f2", + "12~": "f2", + "13~": "f3", + "1R": "f3", + "1S": "f4", + "14~": "f4", + "15~": "f5", + "17~": "f6", + "18~": "f7", + "19~": "f8", + "20~": "f9", + "21~": "f10", + "23~": "f11", + "24~": "f12", + "57376u": "f13", + "57377u": "f14", + "57378u": "f15", + "57379u": "f16", + "57380u": "f17", + "57381u": "f18", + "57382u": "f19", + "57383u": "f20", + "57384u": "f21", + "57385u": "f22", + "57386u": "f23", + "57387u": "f24", + "57388u": "f25", + "57389u": "f26", + "57390u": "f27", + "57391u": "f28", + "57392u": "f29", + "57393u": "f30", + "57394u": "f31", + "57395u": "f32", + "57396u": "f33", + "57397u": "f34", + "57398u": "f35", + "57399u": "0", + "57400u": "1", + "57401u": "2", + "57402u": "3", + "57403u": "4", + "57404u": "5", + "57405u": "6", + "57406u": "7", + "57407u": "8", + "57408u": "9", + "57409u": "decimal", + "57410u": "divide", + "57411u": "multiply", + "57412u": "subtract", + "57413u": "add", + "57414u": "enter", + "57415u": "equal", + "57416u": "separator", + "57417u": "left", + "57418u": "right", + "57419u": "up", + "57420u": "down", + "57421u": "pageup", + "57422u": "pagedown", + "57423u": "home", + "57424u": "end", + "57425u": "insert", + "57426u": "delete", + "1E": "kp_begin", + "57427~": "kp_begin", + "57428u": "media_play", + "57429u": "media_pause", + "57430u": "media_play_pause", + "57431u": "media_reverse", + "57432u": "media_stop", + "57433u": "media_fast_forward", + "57434u": "media_rewind", + "57435u": "media_track_next", + "57436u": "media_track_previous", + "57437u": "media_record", + "57438u": "lower_volume", + "57439u": "raise_volume", + "57440u": "mute_volume", + "57441u": "left_shift", + "57442u": "left_control", + "57443u": "left_alt", + "57444u": "left_super", + "57445u": "left_hyper", + "57446u": "left_meta", + "57447u": "right_shift", + "57448u": "right_control", + "57449u": "right_alt", + "57450u": "right_super", + "57451u": "right_hyper", + "57452u": "right_meta", + "57453u": "iso_level3_shift", + "57454u": "iso_level5_shift", +} diff --git a/src/memray/_vendor/textual/_layout_resolve.py b/src/memray/_vendor/textual/_layout_resolve.py new file mode 100644 index 0000000000..902e8476cb --- /dev/null +++ b/src/memray/_vendor/textual/_layout_resolve.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from fractions import Fraction +from typing import Sequence, cast + +from typing_extensions import Protocol + + +class EdgeProtocol(Protocol): + """Any object that defines an edge (such as Layout).""" + + # Size of edge in cells, or None for no fixed size + size: int | None + # Portion of flexible space to use if size is None + fraction: int + # Minimum size for edge, in cells + min_size: int + + +def layout_resolve(total: int, edges: Sequence[EdgeProtocol]) -> list[int]: + """Divide total space to satisfy size, fraction, and min_size, constraints. + + The returned list of integers should add up to total in most cases, unless it is + impossible to satisfy all the constraints. For instance, if there are two edges + with a minimum size of 20 each and `total` is 30 then the returned list will be + greater than total. In practice, this would mean that a Layout object would + clip the rows that would overflow the screen height. + + Args: + total: Total number of characters. + edges: Edges within total space. + + Returns: + Number of characters for each edge. + """ + # Size of edge or None for yet to be determined + sizes = [(edge.size or None) for edge in edges] + + if None not in sizes: + # No flexible edges + return cast("list[int]", sizes) + + # Get flexible edges and index to map these back on to sizes list + flexible_edges = [ + (index, edge) + for index, (size, edge) in enumerate(zip(sizes, edges)) + if size is None + ] + # Remaining space in total + remaining = total - sum([size or 0 for size in sizes]) + if remaining <= 0: + # No room for flexible edges + return [ + ((edge.min_size or 1) if size is None else size) + for size, edge in zip(sizes, edges) + ] + + # Get the total fraction value for all flexible edges + total_flexible = sum([(edge.fraction or 1) for _, edge in flexible_edges]) + while flexible_edges: + # Calculate number of characters in a ratio portion + portion = Fraction(remaining, total_flexible) + + # If any edges will be less than their minimum, replace size with the minimum + for flexible_index, (index, edge) in enumerate(flexible_edges): + if portion * edge.fraction < edge.min_size: + # This flexible edge will be smaller than its minimum size + # We need to fix the size and redistribute the outstanding space + sizes[index] = edge.min_size + remaining -= edge.min_size + total_flexible -= edge.fraction or 1 + del flexible_edges[flexible_index] + # New fixed size will invalidate calculations, so we need to repeat the process + break + else: + # Distribute flexible space and compensate for rounding error + # Since edge sizes can only be integers we need to add the remainder + # to the following line + remainder = Fraction(0) + for index, edge in flexible_edges: + sizes[index], remainder = divmod(portion * edge.fraction + remainder, 1) + break + + # Sizes now contains integers only + return cast("list[int]", sizes) diff --git a/src/memray/_vendor/textual/_line_split.py b/src/memray/_vendor/textual/_line_split.py new file mode 100644 index 0000000000..0f5058741a --- /dev/null +++ b/src/memray/_vendor/textual/_line_split.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import re + +# Pre-compile the regular expression and store it in a global constant +LINE_AND_ENDING_PATTERN = re.compile(r"(.*?)(\r\n|\r|\n|$)", re.S) + + +def line_split(input_string: str) -> list[tuple[str, str]]: + r""" + Splits an arbitrary string into a list of tuples, where each tuple contains a line of text and its line ending. + + Args: + input_string (str): The string to split. + + Returns: + list[tuple[str, str]]: A list of tuples, where each tuple contains a line of text and its line ending. + + Example: + split_string_to_lines_and_endings("Hello\r\nWorld\nThis is a test\rLast line") + >>> [('Hello', '\r\n'), ('World', '\n'), ('This is a test', '\r'), ('Last line', '')] + """ + return LINE_AND_ENDING_PATTERN.findall(input_string)[:-1] if input_string else [] diff --git a/src/memray/_vendor/textual/_log.py b/src/memray/_vendor/textual/_log.py new file mode 100644 index 0000000000..d42811089b --- /dev/null +++ b/src/memray/_vendor/textual/_log.py @@ -0,0 +1,23 @@ +from enum import Enum + + +class LogGroup(Enum): + """A log group is a classification of the log message (*not* a level).""" + + UNDEFINED = 0 # Mainly for testing + EVENT = 1 + DEBUG = 2 + INFO = 3 + WARNING = 4 + ERROR = 5 + PRINT = 6 + SYSTEM = 7 + LOGGING = 8 + WORKER = 9 + + +class LogVerbosity(Enum): + """Tags log messages as being verbose and potentially excluded from output.""" + + NORMAL = 0 + HIGH = 1 diff --git a/src/memray/_vendor/textual/_loop.py b/src/memray/_vendor/textual/_loop.py new file mode 100644 index 0000000000..67546999c0 --- /dev/null +++ b/src/memray/_vendor/textual/_loop.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from typing import Iterable, Literal, Sequence, TypeVar + +T = TypeVar("T") + + +def loop_first(values: Iterable[T]) -> Iterable[tuple[bool, T]]: + """Iterate and generate a tuple with a flag for first value.""" + iter_values = iter(values) + try: + value = next(iter_values) + except StopIteration: + return + yield True, value + for value in iter_values: + yield False, value + + +def loop_last(values: Iterable[T]) -> Iterable[tuple[bool, T]]: + """Iterate and generate a tuple with a flag for last value.""" + iter_values = iter(values) + try: + previous_value = next(iter_values) + except StopIteration: + return + for value in iter_values: + yield False, previous_value + previous_value = value + yield True, previous_value + + +def loop_first_last(values: Iterable[T]) -> Iterable[tuple[bool, bool, T]]: + """Iterate and generate a tuple with a flag for first and last value.""" + iter_values = iter(values) + try: + previous_value = next(iter_values) + except StopIteration: + return + first = True + for value in iter_values: + yield first, False, previous_value + first = False + previous_value = value + yield first, True, previous_value + + +def loop_from_index( + values: Sequence[T], + index: int, + direction: Literal[-1, +1] = +1, + wrap: bool = True, +) -> Iterable[tuple[int, T]]: + """Iterate over values in a sequence from a given starting index, potentially wrapping the index + if it would go out of bounds. + + Note that the first value to be yielded is a step from `index`, and `index` will be yielded *last*. + + + Args: + values: A sequence of values. + index: Starting index. + direction: Direction to move index (+1 for forward, -1 for backward). + bool: Should the index wrap when out of bounds? + + Yields: + A tuple of index and value from the sequence. + """ + # Sanity check for devs who miss the typing errors + assert direction in (-1, +1), "direction must be -1 or +1" + count = len(values) + if wrap: + for _ in range(count): + index = (index + direction) % count + yield (index, values[index]) + else: + if direction == +1: + for _ in range(count): + if (index := index + 1) >= count: + break + yield (index, values[index]) + else: + for _ in range(count): + if (index := index - 1) < 0: + break + yield (index, values[index]) diff --git a/src/memray/_vendor/textual/_markup_playground.py b/src/memray/_vendor/textual/_markup_playground.py new file mode 100644 index 0000000000..7a1d455ac2 --- /dev/null +++ b/src/memray/_vendor/textual/_markup_playground.py @@ -0,0 +1,150 @@ +import json + +from memray._vendor.textual import containers, events, on +from memray._vendor.textual.app import App, ComposeResult +from memray._vendor.textual.content import Content +from memray._vendor.textual.reactive import reactive +from memray._vendor.textual.widgets import Footer, Pretty, Static, TextArea + + +class MarkupPlayground(App): + + TITLE = "Markup Playground" + CSS = """ + Screen { + layout: vertical; + #editor { + width: 1fr; + height: 1fr; + border: tab $foreground 50%; + padding: 1; + margin: 1 0 0 0; + &:focus { + border: tab $primary; + } + + } + #variables { + width: 1fr; + height: 1fr; + border: tab $foreground 50%; + padding: 1; + margin: 1 0 0 1; + &:focus { + border: tab $primary; + } + } + #variables.-bad-json { + border: tab $error; + } + #results-container { + border: tab $success; + &.-error { + border: tab $error; + } + overflow-y: auto; + } + #results { + padding: 1 1; + width: 1fr; + } + #spans-container { + border: tab $success; + overflow-y: auto; + margin: 0 0 0 1; + } + #spans { + padding: 1 1; + width: 1fr; + } + HorizontalGroup { + height: 1fr; + } + } + """ + AUTO_FOCUS = "#editor" + + BINDINGS = [ + ("f1", "toggle('show_variables')", "Variables"), + ("f2", "toggle('show_spans')", "Spans"), + ] + variables: reactive[dict[str, object]] = reactive({}) + + show_variables = reactive(True) + show_spans = reactive(False) + + def compose(self) -> ComposeResult: + with containers.HorizontalGroup(): + yield (editor := TextArea(id="editor", soft_wrap=False)) + yield (variables := TextArea("", id="variables", language="json")) + editor.border_title = "Markup" + variables.border_title = "Variables (JSON)" + + with containers.HorizontalGroup(): + with containers.VerticalScroll(id="results-container") as container: + yield Static(id="results") + container.border_title = "Output" + with containers.VerticalScroll(id="spans-container") as container: + yield Pretty([], id="spans") + container.border_title = "Spans" + + yield Footer() + + def watch_show_variables(self, show_variables: bool) -> None: + self.query_one("#variables").display = show_variables + + def watch_show_spans(self, show_spans: bool) -> None: + self.query_one("#spans-container").display = show_spans + + @on(TextArea.Changed, "#editor") + def on_markup_changed(self, event: TextArea.Changed) -> None: + self.update_markup() + + def update_markup(self) -> None: + results = self.query_one("#results", Static) + editor = self.query_one("#editor", TextArea) + spans = self.query_one("#spans", Pretty) + try: + content = Content.from_markup(editor.text, **self.variables) + results.update(content) + spans.update(content.spans) + except Exception: + from rich.traceback import Traceback + + results.update(Traceback()) + spans.update([]) + + self.query_one("#results-container").add_class("-error").scroll_end( + animate=False + ) + else: + self.query_one("#results-container").remove_class("-error") + + def watch_variables(self, variables: dict[str, object]) -> None: + self.update_markup() + + @on(TextArea.Changed, "#variables") + def on_variables_change(self, event: TextArea.Changed) -> None: + variables_text_area = self.query_one("#variables", TextArea) + try: + variables = json.loads(variables_text_area.text) + except Exception as error: + variables_text_area.add_class("-bad-json") + self.variables = {} + else: + variables_text_area.remove_class("-bad-json") + self.variables = variables + + @on(events.DescendantBlur, "#variables") + def on_variables_blur(self) -> None: + variables_text_area = self.query_one("#variables", TextArea) + try: + variables = json.loads(variables_text_area.text) + except Exception as error: + if not variables_text_area.has_class("-bad-json"): + self.notify(f"Bad JSON: ${error}", title="Variables", severity="error") + variables_text_area.add_class("-bad-json") + else: + variables_text_area.remove_class("-bad-json") + variables_text_area.text = json.dumps(variables, indent=4) + self.variables = variables diff --git a/src/memray/_vendor/textual/_node_list.py b/src/memray/_vendor/textual/_node_list.py new file mode 100644 index 0000000000..a1d66f530d --- /dev/null +++ b/src/memray/_vendor/textual/_node_list.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +import sys +import weakref +from operator import attrgetter +from typing import TYPE_CHECKING, Any, Callable, Iterator, Sequence, overload + +import rich.repr + +if TYPE_CHECKING: + from _typeshed import SupportsRichComparison + + from memray._vendor.textual.dom import DOMNode + from memray._vendor.textual.widget import Widget + + +_display_getter = attrgetter("display") +_visible_getter = attrgetter("visible") + + +class DuplicateIds(Exception): + """Raised when attempting to add a widget with an id that already exists.""" + + +class ReadOnlyError(AttributeError): + """Raise if you try to mutate the list.""" + + +@rich.repr.auto(angular=True) +class NodeList(Sequence["Widget"]): + """ + A container for widgets that forms one level of hierarchy. + + Although named a list, widgets may appear only once, making them more like a set. + """ + + def __init__(self, parent: DOMNode | None = None) -> None: + """Initialize a node list. + + Args: + parent: The parent node which holds a reference to this object, or `None` if + there is no parent. + """ + self._parent = None if parent is None else weakref.ref(parent) + # The nodes in the list + self._nodes: list[Widget] = [] + self._nodes_set: set[Widget] = set() + self._displayed_nodes: tuple[int, list[Widget]] = (-1, []) + self._displayed_visible_nodes: tuple[int, list[Widget]] = (-1, []) + + # We cache widgets by their IDs too for a quick lookup + # Note that only widgets with IDs are cached like this, so + # this cache will likely hold fewer values than self._nodes. + self._nodes_by_id: dict[str, Widget] = {} + + # Increments when list is updated (used for caching) + self._updates = 0 + + def __bool__(self) -> bool: + return bool(self._nodes) + + def __length_hint__(self) -> int: + return len(self._nodes) + + def __rich_repr__(self) -> rich.repr.Result: + yield self._nodes + + def __len__(self) -> int: + return len(self._nodes) + + def __contains__(self, widget: object) -> bool: + return widget in self._nodes + + def updated(self) -> None: + """Mark the nodes as having been updated.""" + self._updates += 1 + node = None if self._parent is None else self._parent() + while node is not None and (node := node._parent) is not None: + node._nodes._updates += 1 + + def _sort( + self, + *, + key: Callable[[Widget], SupportsRichComparison] | None = None, + reverse: bool = False, + ): + """Sort nodes. + + Args: + key: A key function which accepts a widget, or `None` for no key function. + reverse: Sort in descending order. + """ + if key is None: + self._nodes.sort(key=attrgetter("sort_order"), reverse=reverse) + else: + self._nodes.sort(key=key, reverse=reverse) + + self.updated() + + def index(self, widget: Any, start: int = 0, stop: int = sys.maxsize) -> int: + """Return the index of the given widget. + + Args: + widget: The widget to find in the node list. + + Returns: + The index of the widget in the node list. + + Raises: + ValueError: If the widget is not in the node list. + """ + return self._nodes.index(widget, start, stop) + + def _get_by_id(self, widget_id: str) -> Widget | None: + """Get the widget for the given widget_id, or None if there's no matches in this list""" + return self._nodes_by_id.get(widget_id) + + def _append(self, widget: Widget) -> None: + """Append a Widget. + + Args: + widget: A widget. + """ + if widget not in self._nodes_set: + self._nodes.append(widget) + self._nodes_set.add(widget) + widget_id = widget.id + if widget_id is not None: + self._ensure_unique_id(widget_id) + self._nodes_by_id[widget_id] = widget + self.updated() + + def _insert(self, index: int, widget: Widget) -> None: + """Insert a Widget. + + Args: + widget: A widget. + """ + if widget not in self._nodes_set: + self._nodes.insert(index, widget) + self._nodes_set.add(widget) + widget_id = widget.id + if widget_id is not None: + self._ensure_unique_id(widget_id) + self._nodes_by_id[widget_id] = widget + self.updated() + + def _ensure_unique_id(self, widget_id: str) -> None: + """Ensure a new widget ID would be unique. + + Args: + widget_id: New widget ID. + + Raises: + DuplicateIds: If the given ID is not unique. + """ + if widget_id in self._nodes_by_id: + raise DuplicateIds( + f"Tried to insert a widget with ID {widget_id!r}, but a widget already exists with that ID ({self._nodes_by_id[widget_id]!r}); " + "ensure all child widgets have a unique ID." + ) + + def _remove(self, widget: Widget) -> None: + """Remove a widget from the list. + + Removing a widget not in the list is a null-op. + + Args: + widget: A Widget in the list. + """ + if widget in self._nodes_set: + del self._nodes[self._nodes.index(widget)] + self._nodes_set.remove(widget) + widget_id = widget.id + if widget_id in self._nodes_by_id: + del self._nodes_by_id[widget_id] + self.updated() + + def _clear(self) -> None: + """Clear the node list.""" + if self._nodes: + self._nodes.clear() + self._nodes_set.clear() + self._nodes_by_id.clear() + self.updated() + + def __iter__(self) -> Iterator[Widget]: + return iter(self._nodes) + + def __reversed__(self) -> Iterator[Widget]: + return reversed(self._nodes) + + @property + def displayed(self) -> Sequence[Widget]: + """Just the nodes where `display==True`.""" + if self._displayed_nodes[0] != self._updates: + self._displayed_nodes = ( + self._updates, + list(filter(_display_getter, self._nodes)), + ) + return self._displayed_nodes[1] + + @property + def displayed_and_visible(self) -> Sequence[Widget]: + """Nodes with both `display==True` and `visible==True`.""" + if self._displayed_visible_nodes[0] != self._updates: + self._displayed_nodes = ( + self._updates, + list(filter(_visible_getter, self.displayed)), + ) + return self._displayed_nodes[1] + + @property + def displayed_reverse(self) -> Iterator[Widget]: + """Just the nodes where `display==True`, in reverse order.""" + return filter(_display_getter, reversed(self._nodes)) + + if TYPE_CHECKING: + + @overload + def __getitem__(self, index: int) -> Widget: ... + + @overload + def __getitem__(self, index: slice) -> list[Widget]: ... + + def __getitem__(self, index: int | slice) -> Widget | list[Widget]: + return self._nodes[index] + + if not TYPE_CHECKING: + # This confused the type checker for some reason + def __getattr__(self, key: str) -> object: + if key in {"clear", "append", "pop", "insert", "remove", "extend"}: + raise ReadOnlyError( + "Widget.children is read-only: use Widget.mount(...) or Widget.remove(...) to add or remove widgets" + ) + raise AttributeError(key) diff --git a/src/memray/_vendor/textual/_on.py b/src/memray/_vendor/textual/_on.py new file mode 100644 index 0000000000..2f81c66eaa --- /dev/null +++ b/src/memray/_vendor/textual/_on.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from typing import Callable, TypeVar + +from memray._vendor.textual.css.model import SelectorSet +from memray._vendor.textual.css.parse import parse_selectors +from memray._vendor.textual.css.tokenizer import TokenError +from memray._vendor.textual.message import Message + +DecoratedType = TypeVar("DecoratedType") + + +class OnDecoratorError(Exception): + """Errors related to the `on` decorator. + + Typically raised at import time as an early warning system. + """ + + +class OnNoWidget(Exception): + """A selector was applied to an attribute that isn't a widget.""" + + +def on( + message_type: type[Message], selector: str | None = None, **kwargs: str +) -> Callable[[DecoratedType], DecoratedType]: + """Decorator to declare that the method is a message handler. + + The decorator accepts an optional CSS selector that will be matched against a widget exposed by + a `control` property on the message. + + Example: + ```python + # Handle the press of buttons with ID "#quit". + @on(Button.Pressed, "#quit") + def quit_button(self) -> None: + self.app.quit() + ``` + + Keyword arguments can be used to match additional selectors for attributes + listed in [`ALLOW_SELECTOR_MATCH`][textual.message.Message.ALLOW_SELECTOR_MATCH]. + + Example: + ```python + # Handle the activation of the tab "#home" within the `TabbedContent` "#tabs". + @on(TabbedContent.TabActivated, "#tabs", pane="#home") + def switch_to_home(self) -> None: + self.log("Switching back to the home tab.") + ... + ``` + + Args: + message_type: The message type (i.e. the class). + selector: An optional [selector](/guide/CSS#selectors). If supplied, the handler will only be called if `selector` + matches the widget from the `control` attribute of the message. + **kwargs: Additional selectors for other attributes of the message. + """ + + selectors: dict[str, str] = {} + if selector is not None: + selectors["control"] = selector + if kwargs: + selectors.update(kwargs) + + parsed_selectors: dict[str, tuple[SelectorSet, ...]] = {} + for attribute, css_selector in selectors.items(): + if attribute == "control": + if message_type.control == Message.control: + raise OnDecoratorError( + "The message class must have a 'control' to match with the on decorator" + ) + elif attribute not in message_type.ALLOW_SELECTOR_MATCH: + raise OnDecoratorError( + f"The attribute {attribute!r} can't be matched; have you added it to " + + f"{message_type.__name__}.ALLOW_SELECTOR_MATCH?" + ) + try: + parsed_selectors[attribute] = parse_selectors(css_selector) + except TokenError: + raise OnDecoratorError( + f"Unable to parse selector {css_selector!r} for {attribute}; check for syntax errors" + ) from None + + def decorator(method: DecoratedType) -> DecoratedType: + """Store message and selector in function attribute, return callable unaltered.""" + + if not hasattr(method, "_textual_on"): + setattr(method, "_textual_on", []) + getattr(method, "_textual_on").append((message_type, parsed_selectors)) + + return method + + return decorator diff --git a/src/memray/_vendor/textual/_opacity.py b/src/memray/_vendor/textual/_opacity.py new file mode 100644 index 0000000000..b0bcf06169 --- /dev/null +++ b/src/memray/_vendor/textual/_opacity.py @@ -0,0 +1,42 @@ +from typing import Iterable, cast + +from rich.segment import Segment +from rich.style import Style + +from memray._vendor.textual.color import Color + + +def _apply_opacity( + segments: Iterable[Segment], + base_background: Color, + opacity: float, +) -> Iterable[Segment]: + """Takes an iterable of foreground Segments and blends them into the supplied + background color, yielding copies of the Segments with blended foreground and + background colors applied. + + Args: + segments: The segments in the foreground. + base_background: The background color to blend foreground into. + opacity: The blending factor. A value of 1.0 means output segments will + have identical foreground and background colors to input segments. + """ + _Segment = Segment + from_rich_color = Color.from_rich_color + from_color = Style.from_color + blend = base_background.blend + styled_segments = cast("Iterable[tuple[str, Style, object]]", segments) + for text, style, _ in styled_segments: + blended_style = style + + if style.color is not None: + color = from_rich_color(style.color) + blended_foreground = blend(color, opacity) + blended_style += from_color(color=blended_foreground.rich_color) + + if style.bgcolor is not None: + bgcolor = from_rich_color(style.bgcolor) + blended_background = blend(bgcolor, opacity) + blended_style += from_color(bgcolor=blended_background.rich_color) + + yield _Segment(text, blended_style) diff --git a/src/memray/_vendor/textual/_parser.py b/src/memray/_vendor/textual/_parser.py new file mode 100644 index 0000000000..9dea24dbff --- /dev/null +++ b/src/memray/_vendor/textual/_parser.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +from collections import deque +from typing import Callable, Deque, Generator, Generic, Iterable, NamedTuple, TypeVar + +from memray._vendor.textual._time import get_time + + +class ParseError(Exception): + """Base class for parse related errors.""" + + +class ParseEOF(ParseError): + """End of Stream.""" + + +class ParseTimeout(ParseError): + """Read has timed out.""" + + +class Read1(NamedTuple): + """Reads a single character.""" + + timeout: float | None = None + """Optional timeout in seconds.""" + + +class Peek1(NamedTuple): + """Reads a single character, but does not advance the parser position.""" + + timeout: float | None = None + """Optional timeout in seconds.""" + + +T = TypeVar("T") +TokenCallback = Callable[[T], None] + + +class Parser(Generic[T]): + """Base class for a simple parser.""" + + read1 = Read1 + peek1 = Peek1 + + def __init__(self) -> None: + self._eof = False + self._tokens: Deque[T] = deque() + self._gen = self.parse(self._tokens.append) + self._awaiting: Read1 | Peek1 = next(self._gen) + self._timeout_time: float | None = None + + @property + def is_eof(self) -> bool: + """Is the parser at the end of the file (i.e. exhausted)?""" + return self._eof + + def tick(self) -> Iterable[T]: + """Call at regular intervals to check for timeouts.""" + if self._timeout_time is not None and get_time() >= self._timeout_time: + self._timeout_time = None + self._awaiting = self._gen.throw(ParseTimeout()) + while self._tokens: + yield self._tokens.popleft() + + def feed(self, data: str) -> Iterable[T]: + """Feed data to be parsed. + + Args: + data: Data to parser. + + Raises: + ParseError: If the data could not be parsed. + + Yields: + T: A generic data type. + """ + if self._eof: + raise ParseError("end of file reached") from None + + tokens = self._tokens + popleft = tokens.popleft + + if not data: + self._eof = True + try: + self._gen.throw(ParseEOF()) + except StopIteration: + pass + while tokens: + yield popleft() + return + + pos = 0 + data_size = len(data) + + while tokens: + yield popleft() + + while pos < data_size: + _awaiting = self._awaiting + if isinstance(_awaiting, Read1): + self._timeout_time = None + self._awaiting = self._gen.send(data[pos]) + pos += 1 + elif isinstance(_awaiting, Peek1): + self._timeout_time = None + self._awaiting = self._gen.send(data[pos]) + + if self._awaiting.timeout is not None: + self._timeout_time = get_time() + self._awaiting.timeout + + while tokens: + yield popleft() + + def parse( + self, token_callback: TokenCallback + ) -> Generator[Read1 | Peek1, str, None]: + """Implement to parse a stream of text. + + Args: + token_callback: Callable to report a successful parsed data type. + + Yields: + ParseAwaitable: One of `self.read1` or `self.peek1` + """ + yield from () diff --git a/src/memray/_vendor/textual/_partition.py b/src/memray/_vendor/textual/_partition.py new file mode 100644 index 0000000000..c28a925fc1 --- /dev/null +++ b/src/memray/_vendor/textual/_partition.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from typing import Callable, Iterable, TypeVar + +T = TypeVar("T") + + +def partition( + predicate: Callable[[T], object], iterable: Iterable[T] +) -> tuple[list[T], list[T]]: + """Partition a sequence into two list from a given predicate. The first list will contain + the values where the predicate is False, the second list will contain the remaining values. + + Args: + predicate: A callable that returns True or False for a given value. + iterable: In Iterable of values. + + Returns: + A list of values where the predicate is False, and a list + where the predicate is True. + """ + + result: tuple[list[T], list[T]] = ([], []) + appends = (result[1].append, result[0].append) + for value in iterable: + appends[not predicate(value)](value) + return result diff --git a/src/memray/_vendor/textual/_path.py b/src/memray/_vendor/textual/_path.py new file mode 100644 index 0000000000..e83f01ca9d --- /dev/null +++ b/src/memray/_vendor/textual/_path.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import inspect +from pathlib import Path, PurePath +from typing import List, Union + +from typing_extensions import TypeAlias + +CSSPathType: TypeAlias = Union[ + str, + PurePath, + List[Union[str, PurePath]], +] +"""Valid ways of specifying paths to CSS files.""" + + +class CSSPathError(Exception): + """Raised when supplied CSS path(s) are invalid.""" + + +def _css_path_type_as_list(css_path: CSSPathType) -> list[PurePath]: + """Normalize the supplied CSSPathType into a list of paths. + + Args: + css_path: Value to be normalized. + + Raises: + CSSPathError: If the argument has the wrong format. + + Returns: + A list of paths. + """ + + paths: list[PurePath] = [] + if isinstance(css_path, str): + paths = [Path(css_path)] + elif isinstance(css_path, PurePath): + paths = [css_path] + elif isinstance(css_path, list): + paths = [Path(path) for path in css_path] + else: + raise CSSPathError("Expected a str, Path or list[str | Path] for the CSS_PATH.") + + return paths + + +def _make_path_object_relative(path: str | PurePath, obj: object) -> Path: + """Convert the supplied path to a Path object that is relative to a given Python object. + If the supplied path is absolute, it will simply be converted to a Path object. + Used, for example, to return the path of a CSS file relative to a Textual App instance. + + Args: + path: A path. + obj: A Python object to resolve the path relative to. + + Returns: + A resolved Path object, relative to obj + """ + path = Path(path) + + # If the path supplied by the user is absolute, we can use it directly + if path.is_absolute(): + return path + + # Otherwise (relative path), resolve it relative to obj... + base_path = getattr(obj, "_BASE_PATH", None) + if base_path is not None: + subclass_path = Path(base_path) + else: + subclass_path = Path(inspect.getfile(obj.__class__)) + resolved_path = (subclass_path.parent / path).resolve() + return resolved_path diff --git a/src/memray/_vendor/textual/_profile.py b/src/memray/_vendor/textual/_profile.py new file mode 100644 index 0000000000..2b35cd9543 --- /dev/null +++ b/src/memray/_vendor/textual/_profile.py @@ -0,0 +1,26 @@ +""" +Timer context manager, only used in debug. +""" + +import contextlib +from time import perf_counter +from typing import Generator + +from memray._vendor.textual import log + + +@contextlib.contextmanager +def timer(subject: str = "time", threshold: float = 0) -> Generator[None, None, None]: + """print the elapsed time. (only used in debugging). + + Args: + subject: Text shown in log. + threshold: Time in second after which the log is written. + + """ + start = perf_counter() + yield + elapsed = perf_counter() - start + if elapsed >= threshold: + elapsed_ms = elapsed * 1000 + log(f"{subject} elapsed {elapsed_ms:.4f}ms") diff --git a/src/memray/_vendor/textual/_queue.py b/src/memray/_vendor/textual/_queue.py new file mode 100644 index 0000000000..d01e96adb7 --- /dev/null +++ b/src/memray/_vendor/textual/_queue.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import asyncio +from asyncio import Event +from collections import deque +from typing import Generic, TypeVar + +QueueType = TypeVar("QueueType") + + +class Queue(Generic[QueueType]): + """A cut-down version of asyncio.Queue + + This has just enough functionality to run the message pumps. + + """ + + def __init__(self) -> None: + self.values: deque[QueueType] = deque() + self.ready_event = Event() + + def put_nowait(self, value: QueueType) -> None: + self.values.append(value) + self.ready_event.set() + + def qsize(self) -> int: + return len(self.values) + + def empty(self) -> bool: + return not self.values + + def task_done(self) -> None: + pass + + async def get(self) -> QueueType: + if not self.ready_event.is_set(): + await self.ready_event.wait() + value = self.values.popleft() + if not self.values: + self.ready_event.clear() + return value + + def get_nowait(self) -> QueueType: + if not self.values: + raise asyncio.QueueEmpty() + value = self.values.popleft() + if not self.values: + self.ready_event.clear() + return value diff --git a/src/memray/_vendor/textual/_resolve.py b/src/memray/_vendor/textual/_resolve.py new file mode 100644 index 0000000000..cdea629608 --- /dev/null +++ b/src/memray/_vendor/textual/_resolve.py @@ -0,0 +1,347 @@ +from __future__ import annotations + +from fractions import Fraction +from itertools import accumulate +from typing import TYPE_CHECKING, Iterable, Sequence, cast + +from typing_extensions import Literal + +from memray._vendor.textual.box_model import BoxModel +from memray._vendor.textual.css.scalar import Scalar +from memray._vendor.textual.css.styles import RenderStyles +from memray._vendor.textual.geometry import Size + +if TYPE_CHECKING: + from memray._vendor.textual.widget import Widget + + +def resolve( + dimensions: Sequence[Scalar], + total: int, + gutter: int, + size: Size, + viewport: Size, + *, + expand: bool = False, + shrink: bool = False, + minimums: list[int] | None = None, +) -> list[tuple[int, int]]: + """Resolve a list of dimensions. + + Args: + dimensions: Scalars for column / row sizes. + total: Total space to divide. + gutter: Gutter between rows / columns. + size: Size of container. + viewport: Size of viewport. + + Returns: + List of (, ) + """ + resolved: list[tuple[Scalar, Fraction | None]] = [ + ( + (scalar, None) + if scalar.is_fraction + else (scalar, scalar.resolve(size, viewport)) + ) + for scalar in dimensions + ] + + from_float = Fraction.from_float + total_fraction = from_float( + sum([scalar.value for scalar, fraction in resolved if fraction is None]) + ) + + total_gutter = gutter * (len(dimensions) - 1) + if total_fraction: + consumed = sum([fraction for _, fraction in resolved if fraction is not None]) + remaining = max(Fraction(0), Fraction(total - total_gutter) - consumed) + fraction_unit = Fraction(remaining, total_fraction) + resolved_fractions = [ + from_float(scalar.value) * fraction_unit if fraction is None else fraction + for scalar, fraction in resolved + ] + else: + resolved_fractions = cast( + "list[Fraction]", [fraction for _, fraction in resolved] + ) + + fraction_gutter = Fraction(gutter) + + if expand or shrink: + total_space = total - total_gutter + used_space = sum(resolved_fractions) + if expand: + remaining_space = total_space - used_space + if remaining_space > 0: + resolved_fractions = [ + width + Fraction(width, used_space) * remaining_space + for width in resolved_fractions + ] + if shrink: + one = Fraction(1) + excess_space = used_space - total_space + if minimums is not None and excess_space > 0: + for index, (minimum_width, width) in enumerate( + zip(map(Fraction, minimums), resolved_fractions) + ): + remove_space = max(Fraction(width, used_space), one) * excess_space + updated_width = max(minimum_width, width - remove_space) + resolved_fractions[index] = updated_width + used_space = used_space - width + updated_width + excess_space = used_space - total_space + if excess_space <= 0: + break + + used_space = sum(resolved_fractions) + excess_space = used_space - total_space + + if excess_space > 0: + resolved_fractions = [ + width - Fraction(width, used_space) * excess_space + for width in resolved_fractions + ] + + offsets = [0] + [ + fraction.__floor__() + for fraction in accumulate( + value + for fraction in resolved_fractions + for value in (fraction, fraction_gutter) + ) + ] + results = [ + (offset1, offset2 - offset1) + for offset1, offset2 in zip(offsets[::2], offsets[1::2]) + ] + + return results + + +def resolve_fraction_unit( + widget_styles: Iterable[RenderStyles], + size: Size, + viewport_size: Size, + remaining_space: Fraction, + resolve_dimension: Literal["width", "height"] = "width", +) -> Fraction: + """Calculate the fraction. + + Args: + widget_styles: Styles for widgets with fraction units. + size: Container size. + viewport_size: Viewport size. + remaining_space: Remaining space for fr units. + resolve_dimension: Which dimension to resolve. + + Returns: + The value of 1fr. + """ + _Fraction = Fraction + if not remaining_space or not widget_styles: + return _Fraction(1) + + initial_space = remaining_space + + def resolve_scalar( + scalar: Scalar | None, fraction_unit: Fraction = Fraction(1) + ) -> Fraction | None: + """Resolve a scalar if it is not None. + + Args: + scalar: Optional scalar to resolve. + fraction_unit: Size of 1fr. + + Returns: + Fraction if resolved, otherwise None. + """ + return ( + None + if scalar is None + else scalar.resolve(size, viewport_size, fraction_unit) + ) + + resolve: list[tuple[Scalar, Fraction | None, Fraction | None]] = [] + + if resolve_dimension == "width": + resolve = [ + ( + cast(Scalar, styles.width), + resolve_scalar(styles.min_width), + resolve_scalar(styles.max_width), + ) + for styles in widget_styles + if styles.overlay != "screen" + ] + else: + resolve = [ + ( + cast(Scalar, styles.height), + resolve_scalar(styles.min_height), + resolve_scalar(styles.max_height), + ) + for styles in widget_styles + if styles.overlay != "screen" + ] + + resolved: list[Fraction | None] = [None] * len(resolve) + remaining_fraction = Fraction(sum(scalar.value for scalar, _, _ in resolve)) + + while remaining_fraction > 0: + remaining_space_changed = False + resolve_fraction = _Fraction(remaining_space, remaining_fraction) + for index, (scalar, min_value, max_value) in enumerate(resolve): + value = resolved[index] + if value is None: + resolved_scalar = scalar.resolve(size, viewport_size, resolve_fraction) + if min_value is not None and resolved_scalar < min_value: + remaining_space -= min_value + remaining_fraction -= _Fraction(scalar.value) + resolved[index] = min_value + remaining_space_changed = True + elif max_value is not None and resolved_scalar > max_value: + remaining_space -= max_value + remaining_fraction -= _Fraction(scalar.value) + resolved[index] = max_value + remaining_space_changed = True + + if not remaining_space_changed: + break + + return ( + Fraction(remaining_space, remaining_fraction) + if remaining_fraction > 0 + else initial_space + ) + + +def resolve_box_models( + dimensions: list[Scalar | None], + widgets: list[Widget], + size: Size, + viewport_size: Size, + margin: Size, + resolve_dimension: Literal["width", "height"] = "width", + greedy: bool = True, +) -> list[BoxModel]: + """Resolve box models for a list of dimensions + + Args: + dimensions: A list of Scalars or Nones for each dimension. + widgets: Widgets in resolve. + size: Size of container. + viewport_size: Viewport size. + margin: Total space occupied by margin + resolve_dimension: Which dimension to resolve. + + Returns: + List of resolved box models. + """ + + margin_width, margin_height = margin + fraction_width = Fraction(size.width) + fraction_height = Fraction(size.height) + fraction_zero = Fraction(0) + margin_size = size - margin + + margins = [widget.styles.margin.totals for widget in widgets] + + # Fixed box models + box_models: list[BoxModel | None] = [ + ( + None + if _dimension is not None and _dimension.is_fraction + else widget._get_box_model( + size, + viewport_size, + ( + fraction_zero + if (_width := fraction_width - margin_width) < 0 + else _width + ), + ( + fraction_zero + if (_height := fraction_height - margin_height) < 0 + else _height + ), + greedy=greedy, + ) + ) + for (_dimension, widget, (margin_width, margin_height)) in zip( + dimensions, widgets, margins + ) + ] + + if None not in box_models: + # No fr units, so we're done + return cast("list[BoxModel]", box_models) + + # If all box models have been calculated + widget_styles = [widget.styles for widget in widgets] + if resolve_dimension == "width": + total_remaining = int( + sum( + [ + box_model.width + for widget, box_model in zip(widgets, box_models) + if (box_model is not None and widget.styles.overlay != "screen") + ] + ) + ) + + remaining_space = int(max(0, size.width - total_remaining - margin_width)) + fraction_unit = resolve_fraction_unit( + [ + styles + for styles in widget_styles + if styles.width is not None + and styles.width.is_fraction + and styles.overlay != "screen" + ], + size, + viewport_size, + Fraction(remaining_space), + resolve_dimension, + ) + width_fraction = fraction_unit + height_fraction = Fraction(margin_size.height) + else: + total_remaining = int( + sum( + [ + box_model.height + for widget, box_model in zip(widgets, box_models) + if (box_model is not None and widget.styles.overlay != "screen") + ] + ) + ) + + remaining_space = int(max(0, size.height - total_remaining - margin_height)) + + fraction_unit = resolve_fraction_unit( + [ + styles + for styles in widget_styles + if ( + styles.height is not None + and styles.height.is_fraction + and styles.overlay != "screen" + ) + ], + size, + viewport_size, + Fraction(remaining_space), + resolve_dimension, + ) + width_fraction = Fraction(margin_size.width) + height_fraction = fraction_unit + + box_models = [ + box_model + or widget._get_box_model( + size, viewport_size, width_fraction, height_fraction, greedy=greedy + ) + for widget, box_model in zip(widgets, box_models) + ] + + return cast("list[BoxModel]", box_models) diff --git a/src/memray/_vendor/textual/_segment_tools.py b/src/memray/_vendor/textual/_segment_tools.py new file mode 100644 index 0000000000..725951bc01 --- /dev/null +++ b/src/memray/_vendor/textual/_segment_tools.py @@ -0,0 +1,307 @@ +""" +Tools for processing Segments, or lists of Segments. +""" + +from __future__ import annotations + +import re +from functools import lru_cache +from typing import Iterable + +from rich.segment import Segment +from rich.style import Style + +from memray._vendor.textual._cells import cell_len +from memray._vendor.textual.css.types import AlignHorizontal, AlignVertical +from memray._vendor.textual.geometry import Size + + +@lru_cache(1024 * 8) +def make_blank(width, style: Style) -> Segment: + """Make a blank segment. + + Args: + width: Width of blank. + style: Style of blank. + + Returns: + A single segment + """ + return Segment(" " * width, style) + + +class NoCellPositionForIndex(Exception): + pass + + +def index_to_cell_position(segments: Iterable[Segment], index: int) -> int: + """Given a character index, return the cell position of that character within + an Iterable of Segments. This is the sum of the cell lengths of all the characters + *before* the character at `index`. + + Args: + segments: The segments to find the cell position within. + index: The index to convert into a cell position. + + Returns: + The cell position of the character at `index`. + + Raises: + NoCellPositionForIndex: If the supplied index doesn't fall within the given segments. + """ + if not segments: + raise NoCellPositionForIndex + + if index == 0: + return 0 + + cell_position_end = 0 + segment_length = 0 + segment_end_index = 0 + segment_cell_length = 0 + text = "" + iter_segments = iter(segments) + try: + while segment_end_index < index: + segment = next(iter_segments) + text = segment.text + segment_length = len(text) + segment_cell_length = cell_len(text) + cell_position_end += segment_cell_length + segment_end_index += segment_length + except StopIteration: + raise NoCellPositionForIndex + + # Check how far into this segment the target index is + segment_index_start = segment_end_index - segment_length + index_within_segment = index - segment_index_start + segment_cell_start = cell_position_end - segment_cell_length + + return segment_cell_start + cell_len(text[:index_within_segment]) + + +def line_crop( + segments: list[Segment], start: int, end: int, total: int +) -> list[Segment]: + """Crops a list of segments between two cell offsets. + + Args: + segments: A list of Segments for a line. + start: Start offset (cells) + end: End offset (cells, exclusive) + total: Total cell length of segments. + Returns: + A new shorter list of segments + """ + # This is essentially a specialized version of Segment.divide + # The following line has equivalent functionality (but a little slower) + # return list(Segment.divide(segments, [start, end]))[1] + + _cell_len = cell_len + pos = 0 + output_segments: list[Segment] = [] + add_segment = output_segments.append + iter_segments = iter(segments) + segment: Segment | None = None + for segment in iter_segments: + end_pos = pos + _cell_len(segment.text) + if end_pos > start: + segment = segment.split_cells(start - pos)[1] + break + pos = end_pos + else: + return [] + + if end >= total: + # The end crop is the end of the segments, so we can collect all remaining segments + if segment: + add_segment(segment) + output_segments.extend(iter_segments) + return output_segments + + pos = start + while segment is not None: + end_pos = pos + _cell_len(segment.text) + if end_pos < end: + add_segment(segment) + else: + add_segment(segment.split_cells(end - pos)[0]) + break + pos = end_pos + segment = next(iter_segments, None) + + return output_segments + + +def line_trim(segments: list[Segment], start: bool, end: bool) -> list[Segment]: + """Optionally remove a cell from the start and / or end of a list of segments. + + Args: + segments: A line (list of Segments) + start: Remove cell from start. + end: Remove cell from end. + + Returns: + A new list of segments. + """ + segments = segments.copy() + if segments and start: + _, first_segment = segments[0].split_cells(1) + if first_segment.text: + segments[0] = first_segment + else: + segments.pop(0) + if segments and end: + last_segment = segments[-1] + last_segment, _ = last_segment.split_cells(len(last_segment.text) - 1) + if last_segment.text: + segments[-1] = last_segment + else: + segments.pop() + return segments + + +def line_pad( + segments: Iterable[Segment], pad_left: int, pad_right: int, style: Style +) -> list[Segment]: + """Adds padding to the left and / or right of a list of segments. + + Args: + segments: A line of segments. + pad_left: Cells to pad on the left. + pad_right: Cells to pad on the right. + style: Style of padded cells. + + Returns: + A new line with padding. + """ + if pad_left and pad_right: + return [ + make_blank(pad_left, style), + *segments, + make_blank(pad_right, style), + ] + elif pad_left: + return [ + make_blank(pad_left, style), + *segments, + ] + elif pad_right: + return [ + *segments, + make_blank(pad_right, style), + ] + return list(segments) + + +def align_lines( + lines: list[list[Segment]], + style: Style, + size: Size, + horizontal: AlignHorizontal, + vertical: AlignVertical, +) -> Iterable[list[Segment]]: + """Align lines. + + Args: + lines: A list of lines. + style: Background style. + size: Size of container. + horizontal: Horizontal alignment. + vertical: Vertical alignment. + + Returns: + Aligned lines. + """ + if not lines: + return + width, height = size + get_line_length = Segment.get_line_length + line_lengths = [get_line_length(line) for line in lines] + shape_width = max(line_lengths) + shape_height = len(line_lengths) + + def blank_lines(count: int) -> list[list[Segment]]: + """Create blank lines. + + Args: + count: Desired number of blank lines. + + Returns: + A list of blank lines. + """ + return [[make_blank(width, style)]] * count + + top_blank_lines = bottom_blank_lines = 0 + vertical_excess_space = max(0, height - shape_height) + + if vertical == "top": + bottom_blank_lines = vertical_excess_space + elif vertical == "middle": + top_blank_lines = vertical_excess_space // 2 + bottom_blank_lines = vertical_excess_space - top_blank_lines + elif vertical == "bottom": + top_blank_lines = vertical_excess_space + + if top_blank_lines: + yield from blank_lines(top_blank_lines) + + if horizontal == "left": + for cell_length, line in zip(line_lengths, lines): + if cell_length == width: + yield line + else: + yield line_pad(line, 0, width - cell_length, style) + + elif horizontal == "center": + left_space = max(0, width - shape_width) // 2 + for cell_length, line in zip(line_lengths, lines): + if cell_length == width: + yield line + else: + yield line_pad( + line, left_space, width - cell_length - left_space, style + ) + + elif horizontal == "right": + for cell_length, line in zip(line_lengths, lines): + if width == cell_length: + yield line + else: + yield line_pad(line, width - cell_length, 0, style) + + if bottom_blank_lines: + yield from blank_lines(bottom_blank_lines) + + +_re_spaces = re.compile(r"(\s+|\S+)") + + +def apply_hatch( + segments: Iterable[Segment], + character: str, + hatch_style: Style, + _split=_re_spaces.split, +) -> Iterable[Segment]: + """Replace run of spaces with another character + style. + + Args: + segments: Segments to process. + character: Character to replace spaces. + hatch_style: Style of replacement characters. + + Yields: + Segments. + """ + _Segment = Segment + for segment in segments: + if " " not in segment.text: + yield segment + else: + text, style, _ = segment + for token in _split(text): + if token: + if token.isspace(): + yield _Segment(character * len(token), hatch_style) + else: + yield _Segment(token, style) diff --git a/src/memray/_vendor/textual/_sleep.py b/src/memray/_vendor/textual/_sleep.py new file mode 100644 index 0000000000..bdc7010dc1 --- /dev/null +++ b/src/memray/_vendor/textual/_sleep.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from asyncio import Future, get_running_loop +from threading import Event, Thread +from time import perf_counter, sleep + + +class Sleeper(Thread): + def __init__( + self, + ) -> None: + self._exit = False + self._sleep_time = 0.0 + self._event = Event() + self.future: Future | None = None + self._loop = get_running_loop() + super().__init__(daemon=True) + + def run(self): + while True: + self._event.wait() + if self._exit: + break + sleep(self._sleep_time) + self._event.clear() + # self.future.set_result(None) + assert self.future is not None + self._loop.call_soon_threadsafe(self.future.set_result, None) + + async def sleep(self, sleep_time: float) -> None: + future = self.future = self._loop.create_future() + self._sleep_time = sleep_time + self._event.set() + await future + + +async def check_sleeps() -> None: + sleeper = Sleeper() + sleeper.start() + + async def profile_sleep(sleep_for: float) -> float: + start = perf_counter() + + while perf_counter() - start < sleep_for: + sleep(0) + elapsed = perf_counter() - start + return elapsed + + for t in range(15, 120, 5): + sleep_time = 1 / t + elapsed = await profile_sleep(sleep_time) + difference = (elapsed / sleep_time * 100) - 100 + print( + f"sleep={sleep_time*1000:.01f}ms clock={elapsed*1000:.01f}ms diff={difference:.02f}%" + ) + + +from asyncio import run + +run(check_sleeps()) diff --git a/src/memray/_vendor/textual/_slug.py b/src/memray/_vendor/textual/_slug.py new file mode 100644 index 0000000000..686ae06d5d --- /dev/null +++ b/src/memray/_vendor/textual/_slug.py @@ -0,0 +1,140 @@ +"""Provides a utility function and class for creating Markdown-friendly slugs. + +The approach to creating slugs is designed to be as close to +GitHub-flavoured Markdown as possible. However, because there doesn't appear +to be any actual documentation for this 'standard', the code here involves +some guesswork and also some pragmatic shortcuts. + +Expect this to grow over time. + +The main rules used in here at the moment are: + +1. Strip all leading and trailing whitespace. +2. Remove all non-lingual characters (emoji, etc). +3. Remove all punctuation and whitespace apart from dash and underscore. +""" + +from __future__ import annotations + +from collections import defaultdict +from re import compile +from string import punctuation +from typing import Pattern +from urllib.parse import quote + +from typing_extensions import Final + +WHITESPACE_REPLACEMENT: Final[str] = "-" +"""The character to replace undesirable characters with.""" + +REMOVABLE: Final[str] = punctuation.replace(WHITESPACE_REPLACEMENT, "").replace("_", "") +"""The collection of characters that should be removed altogether.""" + +NONLINGUAL: Final[str] = ( + r"\U000024C2-\U0001F251" + r"\U00002702-\U000027B0" + r"\U0001F1E0-\U0001F1FF" + r"\U0001F300-\U0001F5FF" # Miscellaneous Symbols And Pictographs + r"\U0001F600-\U0001F64F" # Emoticons + r"\U0001F680-\U0001F6FF" # Transport and Map Symbols + r"\U0001F900-\U0001F9FF" # Supplemental Symbols and Pictographs + r"\u200D" + r"\u2640-\u2642" +) +"""A string that can be used in a regular expression to remove most non-lingual characters.""" + +STRIP_RE: Final[Pattern] = compile(f"[{REMOVABLE}{NONLINGUAL}]+") +"""A regular expression for finding all the characters that should be removed.""" + +WHITESPACE_RE: Final[Pattern] = compile(r"\s") +"""A regular expression for finding all the whitespace and turning it into `REPLACEMENT`.""" + + +def slug(text: str) -> str: + """Create a Markdown-friendly slug from the given text. + + Args: + text: The text to generate a slug from. + + Returns: + A slug for the given text. + + The rules used in generating the slug are based on observations of how + GitHub-flavoured Markdown works. + """ + result = text.strip().lower() + for rule, replacement in ( + (STRIP_RE, ""), + (WHITESPACE_RE, WHITESPACE_REPLACEMENT), + ): + result = rule.sub(replacement, result) + return quote(result) + + +class TrackedSlugs: + """Provides a class for generating tracked slugs. + + While [`slug`][textual._slug.slug] will generate a slug for a given + string, it does not guarantee that it is unique for a given context. If + you want to ensure that the same string generates unique slugs (perhaps + heading slugs within a Markdown document, as an example), use an + instance of this class to generate them. + + Example: + ```python + >>> slug("hello world") + 'hello-world' + >>> slug("hello world") + 'hello-world' + >>> unique = TrackedSlugs() + >>> unique.slug("hello world") + 'hello-world' + >>> unique.slug("hello world") + 'hello-world-1' + ``` + """ + + def __init__(self) -> None: + """Initialise the tracked slug object.""" + self._used: defaultdict[str, int] = defaultdict(int) + """Keeps track of how many times a particular slug has been used.""" + + def slug(self, text: str) -> str: + """Create a Markdown-friendly unique slug from the given text. + + Args: + text: The text to generate a slug from. + + Returns: + A slug for the given text. + """ + slugged = slug(text) + used = self._used[slugged] + self._used[slugged] += 1 + if used: + slugged = f"{slugged}-{used}" + return slugged + + +VALID_ID_CHARACTERS = frozenset("abcdefghijklmnopqrstuvwxyz0123456789-") + + +def slug_for_tcss_id(text: str) -> str: + """Produce a slug usable as a TCSS id from the given text. + + Args: + text: Text. + + Returns: + A slugified version of text suitable for use as a TCSS id. + """ + is_valid = VALID_ID_CHARACTERS.__contains__ + slug = "".join( + (character if is_valid(character) else "{:x}".format(ord(character))) + for character in text.casefold().replace(" ", "-") + ) + if not slug: + return "_" + if slug[0].isdecimal(): + return f"_{slug}" + return slug diff --git a/src/memray/_vendor/textual/_spatial_map.py b/src/memray/_vendor/textual/_spatial_map.py new file mode 100644 index 0000000000..1d7c587547 --- /dev/null +++ b/src/memray/_vendor/textual/_spatial_map.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from collections import defaultdict +from itertools import product +from typing import Generic, Iterable, TypeVar + +from typing_extensions import TypeAlias + +from memray._vendor.textual.geometry import Offset, Region + +ValueType = TypeVar("ValueType") +GridCoordinate: TypeAlias = "tuple[int, int]" + + +class SpatialMap(Generic[ValueType]): + """A spatial map allows for data to be associated with rectangular regions + in Euclidean space, and efficiently queried. + + When the SpatialMap is populated, a reference to each value is placed into one or + more buckets associated with a regular grid that covers 2D space. + + The SpatialMap is able to quickly retrieve the values under a given "window" region + by combining the values in the grid squares under the visible area. + """ + + def __init__(self, grid_width: int = 100, grid_height: int = 20) -> None: + """Create a spatial map with the given grid size. + + Args: + grid_width: Width of a grid square. + grid_height: Height of a grid square. + """ + self._grid_size = (grid_width, grid_height) + self.total_region = Region() + self._map: defaultdict[GridCoordinate, list[ValueType]] = defaultdict(list) + self._fixed: list[ValueType] = [] + + def _region_to_grid_coordinates(self, region: Region) -> Iterable[GridCoordinate]: + """Get the grid squares under a region. + + Args: + region: A region. + + Returns: + Iterable of grid coordinates (tuple of 2 values). + """ + # (x1, y1) is the coordinate of the top left cell + # (x2, y2) is the coordinate of the bottom right cell + x1, y1, width, height = region + x2 = x1 + width - 1 + y2 = y1 + height - 1 + grid_width, grid_height = self._grid_size + + return product( + range(x1 // grid_width, x2 // grid_width + 1), + range(y1 // grid_height, y2 // grid_height + 1), + ) + + def insert( + self, regions_and_values: Iterable[tuple[Region, Offset, bool, bool, ValueType]] + ) -> None: + """Insert values into the Spatial map. + + Values are associated with their region in Euclidean space, and a boolean that + indicates fixed regions. Fixed regions don't scroll and are always visible. + + Args: + regions_and_values: An iterable of (REGION, OFFSET, FIXED, OVERLAY, VALUE). + """ + append_fixed = self._fixed.append + get_grid_list = self._map.__getitem__ + _region_to_grid = self._region_to_grid_coordinates + total_region = self.total_region + for region, offset, fixed, overlay, value in regions_and_values: + if fixed: + append_fixed(value) + else: + if not overlay: + total_region = total_region.union(region) + for grid in _region_to_grid(region + offset): + get_grid_list(grid).append(value) + self.total_region = total_region + + def get_values_in_region(self, region: Region) -> list[ValueType]: + """Get a superset of all the values that intersect with a given region. + + Note that this may return false positives. + + Args: + region: A region. + + Returns: + Values under the region. + """ + results: list[ValueType] = self._fixed.copy() + add_results = results.extend + get_grid_values = self._map.get + for grid_coordinate in self._region_to_grid_coordinates(region): + grid_values = get_grid_values(grid_coordinate) + if grid_values is not None: + add_results(grid_values) + unique_values = list(dict.fromkeys(results)) + return unique_values diff --git a/src/memray/_vendor/textual/_styles_cache.py b/src/memray/_vendor/textual/_styles_cache.py new file mode 100644 index 0000000000..b6f56c377f --- /dev/null +++ b/src/memray/_vendor/textual/_styles_cache.py @@ -0,0 +1,513 @@ +from __future__ import annotations + +from functools import lru_cache +from typing import TYPE_CHECKING, Callable, Iterable, Sequence + +import rich.repr +from rich.segment import Segment +from rich.style import Style as RichStyle +from rich.terminal_theme import TerminalTheme + +from memray._vendor.textual import log +from memray._vendor.textual._ansi_theme import DEFAULT_TERMINAL_THEME +from memray._vendor.textual._border import get_box, render_border_label, render_row +from memray._vendor.textual._context import active_app +from memray._vendor.textual._opacity import _apply_opacity +from memray._vendor.textual._segment_tools import apply_hatch, line_pad, line_trim, make_blank +from memray._vendor.textual.color import TRANSPARENT, Color +from memray._vendor.textual.constants import DEBUG +from memray._vendor.textual.content import Content +from memray._vendor.textual.filter import LineFilter +from memray._vendor.textual.geometry import Region, Size, Spacing +from memray._vendor.textual.renderables.text_opacity import TextOpacity +from memray._vendor.textual.renderables.tint import Tint +from memray._vendor.textual.strip import Strip +from memray._vendor.textual.style import Style + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from memray._vendor.textual.css.styles import StylesBase + from memray._vendor.textual.widget import Widget + +RenderLineCallback: TypeAlias = Callable[[int], Strip] + + +@rich.repr.auto(angular=True) +class StylesCache: + """Responsible for rendering CSS Styles and keeping a cache of rendered lines. + + The render method applies border, outline, and padding set in the Styles object to widget content. + + The diagram below shows content (possibly from a Rich renderable) with padding and border. The + labels A. B. and C. indicate the code path (see comments in render_line below) chosen to render + the indicated lines. + + ``` + ┏━━━━━━━━━━━━━━━━━━━━━━┓◀── A. border + ┃ ┃◀┐ + ┃ ┃ └─ B. border + padding + + ┃ Lorem ipsum dolor ┃◀┐ border + ┃ sit amet, ┃ │ + ┃ consectetur ┃ └─ C. border + padding + + ┃ adipiscing elit, ┃ content + padding + + ┃ sed do eiusmod ┃ border + ┃ tempor incididunt ┃ + ┃ ┃ + ┃ ┃ + ┗━━━━━━━━━━━━━━━━━━━━━━┛ + ``` + """ + + def __init__(self) -> None: + self._cache: dict[int, Strip] = {} + self._dirty_lines: set[int] = set() + self._width = 1 + self._simple_strip: Strip | None = None + """A simple strip consisting of left border + background + right border, which may be reused in a render.""" + + def __rich_repr__(self) -> rich.repr.Result: + if self._dirty_lines: + yield "dirty", self._dirty_lines + yield "width", self._width, 1 + + def set_dirty(self, *regions: Region) -> None: + """Add a dirty regions.""" + if regions: + for region in regions: + self._dirty_lines.update(region.line_range) + else: + self.clear() + + def is_dirty(self, y: int) -> bool: + """Check if a given line is dirty (needs to be rendered again). + + Args: + y: Y coordinate of line. + + Returns: + True if line requires a render, False if can be cached. + """ + return y in self._dirty_lines + + def clear(self) -> None: + """Clear the styles cache (will cause the content to re-render).""" + + self._cache.clear() + self._dirty_lines.clear() + + def render_widget(self, widget: Widget, crop: Region) -> list[Strip]: + """Render the content for a widget. + + Args: + widget: A widget. + region: A region of the widget to render. + + Returns: + Rendered lines. + """ + border_title = widget._border_title + border_subtitle = widget._border_subtitle + self._simple_strip = None + + base_background, background = widget.background_colors + styles = widget.styles + strips = self.render( + styles, + widget.region.size, + base_background, + background, + widget.render_line, + widget.get_line_filters(), + ( + None + if border_title is None + else ( + border_title, + *widget._get_title_style_information(base_background), + ) + ), + ( + None + if border_subtitle is None + else ( + border_subtitle, + *widget._get_subtitle_style_information(base_background), + ) + ), + content_size=widget.content_region.size, + padding=styles.padding, + crop=crop, + opacity=widget.opacity, + ansi_theme=widget.app.ansi_theme, + ) + + if widget.auto_links: + hover_style = widget.hover_style + if ( + hover_style._link_id + and hover_style._meta + and "@click" in hover_style.meta + ): + link_style_hover = widget.link_style_hover + if link_style_hover: + strips = [ + strip.style_links(hover_style.link_id, link_style_hover) + for strip in strips + ] + + return strips + + def render( + self, + styles: StylesBase, + size: Size, + base_background: Color, + background: Color, + render_content_line: RenderLineCallback, + filters: Sequence[LineFilter], + border_title: tuple[Content, Color, Color, Style] | None, + border_subtitle: tuple[Content, Color, Color, Style] | None, + content_size: Size | None = None, + padding: Spacing | None = None, + crop: Region | None = None, + opacity: float = 1.0, + ansi_theme: TerminalTheme = DEFAULT_TERMINAL_THEME, + ) -> list[Strip]: + """Render a widget content plus CSS styles. + + Args: + styles: CSS Styles object. + size: Size of widget. + base_background: Background color beneath widget. + background: Background color of widget. + render_content_line: Callback to render content line. + console: The console in use by the app. + border_title: Optional tuple of (title, color, background, style). + border_subtitle: Optional tuple of (subtitle, color, background, style). + content_size: Size of content or None to assume full size. + padding: Override padding from Styles, or None to use styles.padding. + crop: Region to crop to. + filters: Additional post-processing for the segments. + opacity: Widget opacity. + ansi_theme: Theme for ANSI colors. + + Returns: + Rendered lines. + """ + if content_size is None: + content_size = size + if padding is None: + padding = styles.padding + if crop is None: + crop = size.region + + width, _height = size + if width != self._width: + self.clear() + self._width = width + strips: list[Strip] = [] + add_strip = strips.append + + is_dirty = self._dirty_lines.__contains__ + render_line = self.render_line + + for y in crop.line_range: + if is_dirty(y) or y not in self._cache: + strip = render_line( + styles, + y, + size, + content_size, + padding, + base_background, + background, + render_content_line, + border_title, + border_subtitle, + opacity, + ansi_theme, + ) + self._cache[y] = strip + else: + strip = self._cache[y] + + for filter in filters: + strip = strip.apply_filter(filter, background) + + if DEBUG: + if any([not (segment.control or segment.text) for segment in strip]): + log.warning(f"Strip contains invalid empty Segments: {strip!r}.") + + add_strip(strip) + + self._dirty_lines.difference_update(crop.line_range) + + if crop.column_span != (0, width): + x1, x2 = crop.column_span + strips = [strip.crop(x1, x2) for strip in strips] + + return strips + + @lru_cache(1024) + def get_inner_outer( + cls, base_background: Color, background: Color + ) -> tuple[Style, Style]: + """Get inner and outer background colors.""" + return ( + Style(background=base_background + background), + Style(background=base_background), + ) + + def render_line( + self, + styles: StylesBase, + y: int, + size: Size, + content_size: Size, + padding: Spacing, + base_background: Color, + background: Color, + render_content_line: Callable[[int], Strip], + border_title: tuple[Content, Color, Color, Style] | None, + border_subtitle: tuple[Content, Color, Color, Style] | None, + opacity: float, + ansi_theme: TerminalTheme, + ) -> Strip: + """Render a styled line. + + Args: + styles: Styles object. + y: The y coordinate of the line (relative to widget screen offset). + size: Size of the widget. + content_size: Size of the content area. + padding: Padding. + base_background: Background color of widget beneath this line. + background: Background color of widget. + render_content_line: Callback to render a line of content. + console: The console in use by the app. + border_title: Optional tuple of (title, color, background, style). + border_subtitle: Optional tuple of (subtitle, color, background, style). + opacity: Opacity of line. + + Returns: + A line of segments. + """ + + gutter = styles.gutter + width, height = size + content_width, content_height = content_size + + pad_top, pad_right, pad_bottom, pad_left = padding + + ( + (border_top, border_top_color), + (border_right, border_right_color), + (border_bottom, border_bottom_color), + (border_left, border_left_color), + ) = styles.border + + ( + (outline_top, outline_top_color), + (outline_right, outline_right_color), + (outline_bottom, outline_bottom_color), + (outline_left, outline_left_color), + ) = styles.outline + + from_color = RichStyle.from_color + inner, outer = self.get_inner_outer(base_background, background) + + def line_post(segments: Iterable[Segment]) -> Iterable[Segment]: + """Apply effects to segments inside the border.""" + if styles.has_rule("hatch") and styles.hatch != "none": + character, color = styles.hatch + if character != " " and color.a > 0: + hatch_style = from_color( + (background + color).rich_color, background.rich_color + ) + return apply_hatch(segments, character, hatch_style) + return segments + + def post(segments: Iterable[Segment]) -> Iterable[Segment]: + """Post process segments to apply opacity and tint. + + Args: + segments: Iterable of segments. + + Returns: + New list of segments + """ + try: + app = active_app.get() + ansi_theme = app.ansi_theme + except LookupError: + ansi_theme = DEFAULT_TERMINAL_THEME + + if styles.tint.a: + segments = Tint.process_segments( + segments, styles.tint, ansi_theme, background + ) + if opacity != 1.0: + segments = _apply_opacity(segments, base_background, opacity) + return segments + + cache_simple_strip: bool = False + line: Iterable[Segment] + # Draw top or bottom borders (A) + if (border_top and y == 0) or (border_bottom and y == height - 1): + is_top = y == 0 + border_color = base_background + ( + border_top_color if is_top else border_bottom_color + ).multiply_alpha(opacity) + border_color_as_style = Style(foreground=border_color) + border_edge_type = border_top if is_top else border_bottom + has_left = border_left != "" + has_right = border_right != "" + border_label = border_title if is_top else border_subtitle + if border_label is None: + render_label = None + else: + label, label_color, label_background, style = border_label + base_label_background = base_background + background + style += Style( + ( + (base_label_background + label_background) + if label_background.a + else TRANSPARENT + ), + ( + (base_label_background + label_color) + if label_color.a + else TRANSPARENT + ), + ) + render_label = (label, style) + + # Try to save time with expensive call to `render_border_label`: + if render_label: + label_segments = render_border_label( + render_label, + is_top, + border_edge_type, + width - 2, + inner, + outer, + border_color_as_style, + has_left, + has_right, + ) + else: + label_segments = [] + box_segments = get_box( + border_edge_type, + inner, + outer, + border_color_as_style, + ) + label_alignment = ( + styles.border_title_align if is_top else styles.border_subtitle_align + ) + line = render_row( + box_segments[0 if is_top else 2], + width, + has_left, + has_right, + label_segments, + label_alignment, # type: ignore + ) + # Draw padding (B) + elif (pad_top and y < gutter.top) or ( + pad_bottom and y >= height - gutter.bottom + ): + if self._simple_strip is not None: + return self._simple_strip + cache_simple_strip = True + background_rich_style = inner.rich_style + left_style = Style( + foreground=base_background + border_left_color.multiply_alpha(opacity) + ) + left = get_box(border_left, inner, outer, left_style)[1][0] + right_style = Style( + foreground=base_background + border_right_color.multiply_alpha(opacity) + ) + right = get_box(border_right, inner, outer, right_style)[1][2] + if border_left and border_right: + line = [left, make_blank(width - 2, background_rich_style), right] + elif border_left: + line = [left, make_blank(width - 1, background_rich_style)] + elif border_right: + line = [make_blank(width - 1, background_rich_style), right] + else: + line = [make_blank(width, background_rich_style)] + line = line_post(line) + else: + # Content with border and padding (C) + content_y = y - gutter.top + if content_y < content_height: + line = render_content_line(y - gutter.top) + line = line.adjust_cell_length(content_width, inner.rich_style) + else: + line = Strip.blank(content_width, inner.rich_style) + + if (text_opacity := styles.text_opacity) != 1.0: + line = TextOpacity.process_segments(line, text_opacity, ansi_theme) + if pad_left or pad_right: + line = line_post(line_pad(line, pad_left, pad_right, inner.rich_style)) + else: + line = line_post(line) + + if border_left or border_right: + # Add left / right border + left_style = Style( + foreground=base_background + + border_left_color.multiply_alpha(opacity) + ) + left = get_box(border_left, inner, outer, left_style)[1][0] + right_style = Style( + foreground=base_background + + border_right_color.multiply_alpha(opacity) + ) + right = get_box(border_right, inner, outer, right_style)[1][2] + + if border_left and border_right: + line = [left, *line, right] + elif border_left: + line = [left, *line] + else: + line = [*line, right] + + # Draw any outline + if (outline_top and y == 0) or (outline_bottom and y == height - 1): + # Top or bottom outlines + outline_color = outline_top_color if y == 0 else outline_bottom_color + box_segments = get_box( + outline_top if y == 0 else outline_bottom, + inner, + outer, + Style(foreground=base_background + outline_color), + ) + line = render_row( + box_segments[0 if y == 0 else 2], + width, + outline_left != "", + outline_right != "", + (), + ) + + elif outline_left or outline_right: + # Lines in side outline + left_style = Style(foreground=(base_background + outline_left_color)) + left = get_box(outline_left, inner, outer, left_style)[1][0] + right_style = Style(foreground=(base_background + outline_right_color)) + right = get_box(outline_right, inner, outer, right_style)[1][2] + line = line_trim(list(line), outline_left != "", outline_right != "") + if outline_left and outline_right: + line = [left, *line, right] + elif outline_left: + line = [left, *line] + else: + line = [*line, right] + strip = Strip(post(line), width) + if cache_simple_strip: + self._simple_strip = strip + return strip diff --git a/src/memray/_vendor/textual/_text_area_theme.py b/src/memray/_vendor/textual/_text_area_theme.py new file mode 100644 index 0000000000..c4321fd1d5 --- /dev/null +++ b/src/memray/_vendor/textual/_text_area_theme.py @@ -0,0 +1,434 @@ +from __future__ import annotations + +from dataclasses import dataclass, field, fields +from typing import TYPE_CHECKING + +from rich.style import Style + +from memray._vendor.textual.color import Color + +if TYPE_CHECKING: + from memray._vendor.textual.widgets import TextArea + + +@dataclass +class TextAreaTheme: + """A theme for the `TextArea` widget. + + Allows theming the general widget (gutter, selections, cursor, and so on) and + mapping of tree-sitter tokens to Rich styles. + + For example, consider the following snippet from the `markdown.scm` highlight + query file. We've assigned the `heading_content` token type to the name `heading`. + + ``` + (heading_content) @heading + ``` + + Now, we can map this `heading` name to a Rich style, and it will be styled as + such in the `TextArea`, assuming a parser which returns a `heading_content` + node is used (as will be the case when language="markdown"). + + ``` + TextAreaTheme('my_theme', syntax_styles={'heading': Style(color='cyan', bold=True)}) + ``` + + We can register this theme with our `TextArea` using the [`TextArea.register_theme`][textual.widgets._text_area.TextArea.register_theme] method, + and headings in our markdown files will be styled bold cyan. + """ + + name: str + """The name of the theme.""" + + base_style: Style | None = None + """The background style of the text area. If `None` the parent style will be used.""" + + gutter_style: Style | None = None + """The style of the gutter. If `None`, a legible Style will be generated.""" + + cursor_style: Style | None = None + """The style of the cursor. If `None`, a legible Style will be generated.""" + + cursor_line_style: Style | None = None + """The style to apply to the line the cursor is on.""" + + cursor_line_gutter_style: Style | None = None + """The style to apply to the gutter of the line the cursor is on. If `None`, a legible Style will be + generated.""" + + bracket_matching_style: Style | None = None + """The style to apply to matching brackets. If `None`, a legible Style will be generated.""" + + selection_style: Style | None = None + """The style of the selection. If `None` a default selection Style will be generated.""" + + syntax_styles: dict[str, Style] = field(default_factory=dict) + """The mapping of tree-sitter names from the `highlight_query` to Rich styles.""" + + _theme_configured_attributes: set[str] = field(init=False, default_factory=set) + """Records which attributes were set via the theme object (as opposed to CSS components).""" + + def __post_init__(self) -> None: + theme_fields = fields(self) + for field in theme_fields: + if getattr(self, field.name) is not None: + self._theme_configured_attributes.add(field.name) + + def apply_css(self, text_area: TextArea) -> None: + """Apply CSS rules from a TextArea to be used for fallback styling. + + If any attributes in the theme aren't supplied, they'll be filled with the appropriate + base CSS (e.g. color, background, etc.) and component CSS (e.g. text-area--cursor) from + the supplied TextArea. + + Args: + text_area: The TextArea instance to retrieve fallback styling from. + """ + self.base_style = text_area.rich_style or Style() + get_style = text_area.get_component_rich_style + + if self.base_style.color is None: + self.base_style = Style(color="#f3f3f3", bgcolor=self.base_style.bgcolor) + + app_theme = text_area.app.current_theme + + if self.base_style.bgcolor is None: + self.base_style = Style( + color=self.base_style.color, bgcolor=app_theme.surface + ) + + configured = self._theme_configured_attributes.__contains__ + + assert self.base_style is not None + assert self.base_style.color is not None + assert self.base_style.bgcolor is not None + + if not configured("gutter_style"): + gutter_style = get_style("text-area--gutter") + if gutter_style: + self.gutter_style = gutter_style + else: + self.gutter_style = self.base_style.copy() + + background_color = Color.from_rich_color(self.base_style.bgcolor) + if not configured("cursor_style"): + # If the theme doesn't contain a cursor style, fallback to component styles. + cursor_style = get_style("text-area--cursor") + if cursor_style: + self.cursor_style = cursor_style + else: + # There's no component style either, fallback to a default. + self.cursor_style = Style.from_color( + color=background_color.rich_color, + bgcolor=background_color.inverse.rich_color, + ) + + # Apply fallbacks for the styles of the active line and active line gutter. + if not configured("cursor_line_style"): + self.cursor_line_style = get_style("text-area--cursor-line") + + if not configured("cursor_line_gutter_style"): + self.cursor_line_gutter_style = get_style("text-area--cursor-gutter") + + if not configured("bracket_matching_style"): + matching_bracket_style = get_style("text-area--matching-bracket") + if matching_bracket_style: + self.bracket_matching_style = matching_bracket_style + else: + bracket_matching_background = background_color.blend( + background_color.inverse, factor=0.05 + ) + self.bracket_matching_style = Style( + bgcolor=bracket_matching_background.rich_color + ) + + if not configured("selection_style"): + selection_style = get_style("text-area--selection") + if selection_style: + self.selection_style = selection_style + else: + selection_background_color = background_color.blend( + app_theme.primary, factor=0.5 + ) + self.selection_style = Style.from_color( + bgcolor=selection_background_color.rich_color + ) + + @classmethod + def get_builtin_theme(cls, theme_name: str) -> TextAreaTheme | None: + """Get a `TextAreaTheme` by name. + + Given a `theme_name`, return the corresponding `TextAreaTheme` object. + + Args: + theme_name: The name of the theme. + + Returns: + The `TextAreaTheme` corresponding to the name or `None` if the theme isn't + found. + """ + return _BUILTIN_THEMES.get(theme_name) + + def get_highlight(self, name: str) -> Style | None: + """Return the Rich style corresponding to the name defined in the tree-sitter + highlight query for the current theme. + + Args: + name: The name of the highlight. + + Returns: + The `Style` to use for this highlight, or `None` if no style. + """ + return self.syntax_styles.get(name) + + @classmethod + def builtin_themes(cls) -> list[TextAreaTheme]: + """Get a list of all builtin TextAreaThemes. + + Returns: + A list of all builtin TextAreaThemes. + """ + return list(_BUILTIN_THEMES.values()) + + +_MONOKAI = TextAreaTheme( + name="monokai", + base_style=Style(color="#f8f8f2", bgcolor="#272822"), + gutter_style=Style(color="#90908a", bgcolor="#272822"), + cursor_style=Style(color="#272822", bgcolor="#f8f8f0"), + cursor_line_style=Style(bgcolor="#3e3d32"), + cursor_line_gutter_style=Style(color="#c2c2bf", bgcolor="#3e3d32"), + bracket_matching_style=Style(bgcolor="#838889", bold=True), + selection_style=Style(bgcolor="#65686a"), + syntax_styles={ + "string": Style(color="#E6DB74"), + "string.documentation": Style(color="#E6DB74"), + "comment": Style(color="#75715E"), + "heading.marker": Style(color="#90908a"), + "keyword": Style(color="#F92672"), + "operator": Style(color="#f8f8f2"), + "repeat": Style(color="#F92672"), + "exception": Style(color="#F92672"), + "include": Style(color="#F92672"), + "keyword.function": Style(color="#F92672"), + "keyword.return": Style(color="#F92672"), + "keyword.operator": Style(color="#F92672"), + "conditional": Style(color="#F92672"), + "number": Style(color="#AE81FF"), + "float": Style(color="#AE81FF"), + "class": Style(color="#A6E22E"), + "type": Style(color="#A6E22E"), + "type.class": Style(color="#A6E22E"), + "type.builtin": Style(color="#F92672"), + "variable.builtin": Style(color="#f8f8f2"), + "function": Style(color="#A6E22E"), + "function.call": Style(color="#A6E22E"), + "method": Style(color="#A6E22E"), + "method.call": Style(color="#A6E22E"), + "boolean": Style(color="#66D9EF", italic=True), + "constant.builtin": Style(color="#66D9EF", italic=True), + "json.null": Style(color="#66D9EF", italic=True), + "regex.punctuation.bracket": Style(color="#F92672"), + "regex.operator": Style(color="#F92672"), + "html.end_tag_error": Style(color="red", underline=True), + "tag": Style(color="#F92672"), + "yaml.field": Style(color="#F92672", bold=True), + "json.label": Style(color="#F92672", bold=True), + "toml.type": Style(color="#F92672"), + "toml.datetime": Style(color="#AE81FF"), + "css.property": Style(color="#AE81FF"), + "heading": Style(color="#F92672", bold=True), + "bold": Style(bold=True), + "italic": Style(italic=True), + "strikethrough": Style(strike=True), + "link.label": Style(color="#F92672"), + "link.uri": Style(color="#66D9EF", underline=True), + "list.marker": Style(color="#90908a"), + "inline_code": Style(color="#E6DB74"), + "punctuation.bracket": Style(color="#f8f8f2"), + "punctuation.delimiter": Style(color="#f8f8f2"), + "punctuation.special": Style(color="#f8f8f2"), + }, +) + +_DRACULA = TextAreaTheme( + name="dracula", + base_style=Style(color="#f8f8f2", bgcolor="#1E1F35"), + gutter_style=Style(color="#6272a4"), + cursor_style=Style(color="#282a36", bgcolor="#f8f8f0"), + cursor_line_style=Style(bgcolor="#282b45"), + cursor_line_gutter_style=Style(color="#c2c2bf", bgcolor="#282b45", bold=True), + bracket_matching_style=Style(bgcolor="#99999d", bold=True, underline=True), + selection_style=Style(bgcolor="#44475A"), + syntax_styles={ + "string": Style(color="#f1fa8c"), + "string.documentation": Style(color="#f1fa8c"), + "comment": Style(color="#6272a4"), + "heading.marker": Style(color="#6272a4"), + "keyword": Style(color="#ff79c6"), + "operator": Style(color="#f8f8f2"), + "repeat": Style(color="#ff79c6"), + "exception": Style(color="#ff79c6"), + "include": Style(color="#ff79c6"), + "keyword.function": Style(color="#ff79c6"), + "keyword.return": Style(color="#ff79c6"), + "keyword.operator": Style(color="#ff79c6"), + "conditional": Style(color="#ff79c6"), + "number": Style(color="#bd93f9"), + "float": Style(color="#bd93f9"), + "class": Style(color="#50fa7b"), + "type": Style(color="#ff79c6"), + "type.class": Style(color="#50fa7b"), + "type.builtin": Style(color="#bd93f9"), + "variable.builtin": Style(color="#f8f8f2"), + "function": Style(color="#50fa7b"), + "function.call": Style(color="#50fa7b"), + "method": Style(color="#50fa7b"), + "method.call": Style(color="#50fa7b"), + "boolean": Style(color="#50fa7b"), + "constant.builtin": Style(color="#bd93f9"), + "json.null": Style(color="#bd93f9"), + "regex.punctuation.bracket": Style(color="#ff79c6"), + "regex.operator": Style(color="#ff79c6"), + "html.end_tag_error": Style(color="#F83333", underline=True), + "tag": Style(color="#ff79c6"), + "yaml.field": Style(color="#ff79c6", bold=True), + "json.label": Style(color="#ff79c6", bold=True), + "toml.type": Style(color="#ff79c6"), + "toml.datetime": Style(color="#bd93f9"), + "css.property": Style(color="#bd93f9"), + "heading": Style(color="#ff79c6", bold=True), + "bold": Style(bold=True), + "italic": Style(italic=True), + "strikethrough": Style(strike=True), + "link.label": Style(color="#ff79c6"), + "link.uri": Style(color="#bd93f9", underline=True), + "list.marker": Style(color="#6272a4"), + "inline_code": Style(color="#f1fa8c"), + "punctuation.bracket": Style(color="#f8f8f2"), + "punctuation.delimiter": Style(color="#f8f8f2"), + "punctuation.special": Style(color="#f8f8f2"), + }, +) + +_DARK_VS = TextAreaTheme( + name="vscode_dark", + base_style=Style(color="#CCCCCC", bgcolor="#1F1F1F"), + gutter_style=Style(color="#6E7681", bgcolor="#1F1F1F"), + cursor_style=Style(color="#1e1e1e", bgcolor="#f0f0f0"), + cursor_line_style=Style(bgcolor="#2b2b2b"), + bracket_matching_style=Style(bgcolor="#3a3a3a", bold=True), + cursor_line_gutter_style=Style(color="#CCCCCC", bgcolor="#2b2b2b"), + selection_style=Style(bgcolor="#264F78"), + syntax_styles={ + "string": Style(color="#ce9178"), + "string.documentation": Style(color="#ce9178"), + "comment": Style(color="#6A9955"), + "heading.marker": Style(color="#6E7681"), + "keyword": Style(color="#C586C0"), + "operator": Style(color="#CCCCCC"), + "conditional": Style(color="#569cd6"), + "keyword.function": Style(color="#569cd6"), + "keyword.return": Style(color="#569cd6"), + "keyword.operator": Style(color="#569cd6"), + "repeat": Style(color="#569cd6"), + "exception": Style(color="#569cd6"), + "include": Style(color="#569cd6"), + "number": Style(color="#b5cea8"), + "float": Style(color="#b5cea8"), + "class": Style(color="#4EC9B0"), + "type": Style(color="#EFCB43"), + "type.class": Style(color="#4EC9B0"), + "type.builtin": Style(color="#9CDCFE"), + "function": Style(color="#DCDCAA"), + "function.call": Style(color="#DCDCAA"), + "method": Style(color="#4EC9B0"), + "method.call": Style(color="#4EC9B0"), + "constructor": Style(color="#4EC9B0"), + "boolean": Style(color="#7DAF9C"), + "constant.builtin": Style(color="#7DAF9C"), + "json.null": Style(color="#7DAF9C"), + "tag": Style(color="#EFCB43"), + "yaml.field": Style(color="#569cd6", bold=True), + "json.label": Style(color="#569cd6", bold=True), + "toml.type": Style(color="#569cd6"), + "toml.datetime": Style(color="#C586C0", italic=True), + "css.property": Style(color="#569cd6"), + "heading": Style(color="#569cd6", bold=True), + "bold": Style(bold=True), + "italic": Style(italic=True), + "strikethrough": Style(strike=True), + "link.uri": Style(color="#40A6FF", underline=True), + "link.label": Style(color="#569cd6"), + "list.marker": Style(color="#6E7681"), + "inline_code": Style(color="#ce9178"), + "info_string": Style(color="#ce9178", bold=True, italic=True), + "punctuation.bracket": Style(color="#CCCCCC"), + "punctuation.delimiter": Style(color="#CCCCCC"), + "punctuation.special": Style(color="#CCCCCC"), + }, +) + +_GITHUB_LIGHT = TextAreaTheme( + name="github_light", + base_style=Style(color="#24292e", bgcolor="#f0f0f0"), + gutter_style=Style(color="#BBBBBB", bgcolor="#f0f0f0"), + cursor_style=Style(color="#fafbfc", bgcolor="#24292e"), + cursor_line_style=Style(bgcolor="#ebebeb"), + bracket_matching_style=Style(color="#24292e", underline=True), + cursor_line_gutter_style=Style(color="#A4A4A4", bgcolor="#ebebeb"), + selection_style=Style(bgcolor="#c8c8fa"), + syntax_styles={ + "string": Style(color="#093069"), + "string.documentation": Style(color="#093069"), + "comment": Style(color="#6a737d"), + "heading.marker": Style(color="#A4A4A4"), + "type": Style(color="#A4A4A4"), + "type.class": Style(color="#A4A4A4"), + "type.builtin": Style(color="#7DAF9C"), + "keyword": Style(color="#d73a49"), + "operator": Style(color="#0450AE"), + "conditional": Style(color="#CF222E"), + "keyword.function": Style(color="#CF222E"), + "keyword.return": Style(color="#CF222E"), + "keyword.operator": Style(color="#CF222E"), + "repeat": Style(color="#CF222E"), + "exception": Style(color="#CF222E"), + "include": Style(color="#CF222E"), + "number": Style(color="#d73a49"), + "float": Style(color="#d73a49"), + "parameter": Style(color="#24292e"), + "class": Style(color="#963800"), + "variable": Style(color="#e36209"), + "function": Style(color="#6639BB"), + "method": Style(color="#6639BB"), + "boolean": Style(color="#7DAF9C"), + "constant.builtin": Style(color="#7DAF9C"), + "tag": Style(color="#6639BB"), + "yaml.field": Style(color="#6639BB"), + "json.label": Style(color="#6639BB"), + "toml.type": Style(color="#6639BB"), + "css.property": Style(color="#6639BB"), + "heading": Style(color="#24292e", bold=True), + "bold": Style(bold=True), + "italic": Style(italic=True), + "strikethrough": Style(strike=True), + "link.uri": Style(color="#40A6FF", underline=True), + "link.label": Style(color="#6639BB"), + "list.marker": Style(color="#A4A4A4"), + "inline_code": Style(color="#093069"), + "punctuation.bracket": Style(color="#24292e"), + "punctuation.delimiter": Style(color="#24292e"), + "punctuation.special": Style(color="#24292e"), + }, +) + +_CSS_THEME = TextAreaTheme(name="css", syntax_styles=_DARK_VS.syntax_styles) + +_BUILTIN_THEMES = { + "css": _CSS_THEME, + "monokai": _MONOKAI, + "dracula": _DRACULA, + "vscode_dark": _DARK_VS, + "github_light": _GITHUB_LIGHT, +} diff --git a/src/memray/_vendor/textual/_time.py b/src/memray/_vendor/textual/_time.py new file mode 100644 index 0000000000..ec09d2fb64 --- /dev/null +++ b/src/memray/_vendor/textual/_time.py @@ -0,0 +1,52 @@ +import asyncio +import sys +from asyncio import sleep as asyncio_sleep +from time import monotonic, perf_counter + +WINDOWS = sys.platform == "win32" + + +if WINDOWS: + time = perf_counter +else: + time = monotonic + + +if WINDOWS: + # sleep on windows as a resolution of 15ms + # Python3.11 is somewhat better, but this home-grown version beats it + # Deduced from practical experiments + + from memray._vendor.textual._win_sleep import sleep as win_sleep + + async def sleep(secs: float) -> None: + """Sleep for a given number of seconds. + + Args: + secs: Number of seconds to sleep for. + """ + await asyncio.create_task(win_sleep(secs)) + +else: + + async def sleep(secs: float) -> None: + """Sleep for a given number of seconds. + + Args: + secs: Number of seconds to sleep for. + """ + # From practical experiments, asyncio.sleep sleeps for at least half a millisecond too much + # Presumably there is overhead asyncio itself which accounts for this + # We will reduce the sleep to compensate, and also don't sleep at all for less than half a millisecond + sleep_for = secs - 0.0005 + if sleep_for > 0: + await asyncio_sleep(sleep_for) + + +get_time = time +"""Get the current wall clock (monotonic) time. + +Returns: + The value (in fractional seconds) of a monotonic clock, + i.e. a clock that cannot go backwards. +""" diff --git a/src/memray/_vendor/textual/_tree_sitter.py b/src/memray/_vendor/textual/_tree_sitter.py new file mode 100644 index 0000000000..8f62b6edbe --- /dev/null +++ b/src/memray/_vendor/textual/_tree_sitter.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from importlib import import_module + +from memray._vendor.textual import log + +try: + from tree_sitter import Language + + _LANGUAGE_CACHE: dict[str, Language] = {} + + _tree_sitter = True + + def get_language(language_name: str) -> Language | None: + if language_name in _LANGUAGE_CACHE: + return _LANGUAGE_CACHE[language_name] + + try: + module = import_module(f"tree_sitter_{language_name}") + except ImportError: + return None + else: + try: + if language_name == "xml": + # xml uses language_xml() instead of language() + # it's the only outlier amongst the languages in the `textual[syntax]` extra + language = Language(module.language_xml()) + else: + language = Language(module.language()) + except (OSError, AttributeError): + log.warning(f"Could not load language {language_name!r}.") + return None + else: + _LANGUAGE_CACHE[language_name] = language + return language + +except ImportError: + _tree_sitter = False + + def get_language(language_name: str) -> Language | None: + return None + + +TREE_SITTER = _tree_sitter diff --git a/src/memray/_vendor/textual/_two_way_dict.py b/src/memray/_vendor/textual/_two_way_dict.py new file mode 100644 index 0000000000..ac0ffe16ef --- /dev/null +++ b/src/memray/_vendor/textual/_two_way_dict.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from typing import Generic, TypeVar + +Key = TypeVar("Key") +Value = TypeVar("Value") + + +class TwoWayDict(Generic[Key, Value]): + """ + A two-way mapping offering O(1) access in both directions. + + Wraps two dictionaries and uses them to provide efficient access to + both values (given keys) and keys (given values). + """ + + def __init__(self, initial: dict[Key, Value]) -> None: + self._forward: dict[Key, Value] = initial + self._reverse: dict[Value, Key] = {value: key for key, value in initial.items()} + + def __setitem__(self, key: Key, value: Value) -> None: + # TODO: Duplicate values need to be managed to ensure consistency, + # decide on best approach. + self._forward.__setitem__(key, value) + self._reverse.__setitem__(value, key) + + def __delitem__(self, key: Key) -> None: + value = self._forward[key] + self._forward.__delitem__(key) + self._reverse.__delitem__(value) + + def __iter__(self): + return iter(self._forward) + + def get(self, key: Key) -> Value | None: + """Given a key, efficiently lookup and return the associated value. + + Args: + key: The key + + Returns: + The value + """ + return self._forward.get(key) + + def get_key(self, value: Value) -> Key | None: + """Given a value, efficiently lookup and return the associated key. + + Args: + value: The value + + Returns: + The key + """ + return self._reverse.get(value) + + def contains_value(self, value: Value) -> bool: + """Check if `value` is a value within this TwoWayDict. + + Args: + value: The value to check. + + Returns: + True if the value is within the values of this dict. + """ + return value in self._reverse + + def __len__(self): + return len(self._forward) + + def __contains__(self, item: Key) -> bool: + return item in self._forward diff --git a/src/memray/_vendor/textual/_types.py b/src/memray/_vendor/textual/_types.py new file mode 100644 index 0000000000..427649cb48 --- /dev/null +++ b/src/memray/_vendor/textual/_types.py @@ -0,0 +1,57 @@ +from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Literal, Union + +from typing_extensions import Protocol + +if TYPE_CHECKING: + from rich.segment import Segment + + from memray._vendor.textual.message import Message + + +class MessageTarget(Protocol): + """Protocol that must be followed by objects that can receive messages.""" + + async def _post_message(self, message: "Message") -> bool: ... + + def post_message(self, message: "Message") -> bool: ... + + +class EventTarget(Protocol): + async def _post_message(self, message: "Message") -> bool: ... + + def post_message(self, message: "Message") -> bool: ... + + +class UnusedParameter: + """Helper type for a parameter that isn't specified in a method call.""" + + +SegmentLines = List[List["Segment"]] +CallbackType = Union[Callable[[], Awaitable[None]], Callable[[], None]] +"""Type used for arbitrary callables used in callbacks.""" +IgnoreReturnCallbackType = Union[Callable[[], Awaitable[Any]], Callable[[], Any]] +"""A callback which ignores the return type.""" +WatchCallbackBothValuesType = Union[ + Callable[[Any, Any], Awaitable[None]], + Callable[[Any, Any], None], +] +"""Type for watch methods that accept the old and new values of reactive objects.""" +WatchCallbackNewValueType = Union[ + Callable[[Any], Awaitable[None]], + Callable[[Any], None], +] +"""Type for watch methods that accept only the new value of reactive objects.""" +WatchCallbackNoArgsType = Union[ + Callable[[], Awaitable[None]], + Callable[[], None], +] +"""Type for watch methods that do not require the explicit value of the reactive.""" +WatchCallbackType = Union[ + WatchCallbackBothValuesType, + WatchCallbackNewValueType, + WatchCallbackNoArgsType, +] +"""Type used for callbacks passed to the `watch` method of widgets.""" + +AnimationLevel = Literal["none", "basic", "full"] +"""The levels that the [`TEXTUAL_ANIMATIONS`][textual.constants.TEXTUAL_ANIMATIONS] env var can be set to.""" diff --git a/src/memray/_vendor/textual/_wait.py b/src/memray/_vendor/textual/_wait.py new file mode 100644 index 0000000000..9774ef3c04 --- /dev/null +++ b/src/memray/_vendor/textual/_wait.py @@ -0,0 +1,41 @@ +from asyncio import sleep +from time import monotonic, process_time + +SLEEP_GRANULARITY: float = 1 / 50 +SLEEP_IDLE: float = SLEEP_GRANULARITY / 20.0 + + +async def wait_for_idle( + min_sleep: float = SLEEP_GRANULARITY, max_sleep: float = 1 +) -> None: + """Wait until the process isn't working very hard. + + This will compare wall clock time with process time. If the process time + is not advancing at the same rate as wall clock time it means the process is + idle (i.e. sleeping or waiting for input). + + When the process is idle it suggests that input has been processed and the state + is predictable enough to test. + + Args: + min_sleep: Minimum time to wait. + max_sleep: Maximum time to wait. + """ + start_time = monotonic() + + while True: + cpu_time = process_time() + # Sleep for a predetermined amount of time + await sleep(SLEEP_GRANULARITY) + # Calculate the wall clock elapsed time and the process elapsed time + cpu_elapsed = process_time() - cpu_time + elapsed_time = monotonic() - start_time + + # If we have slept the maximum, we can break + if elapsed_time >= max_sleep: + break + + # If we have slept at least the minimum and the cpu elapsed is significantly less + # than wall clock, then we can assume the process has finished working for now + if elapsed_time > min_sleep and cpu_elapsed < SLEEP_IDLE: + break diff --git a/src/memray/_vendor/textual/_widget_navigation.py b/src/memray/_vendor/textual/_widget_navigation.py new file mode 100644 index 0000000000..5ee430cd76 --- /dev/null +++ b/src/memray/_vendor/textual/_widget_navigation.py @@ -0,0 +1,182 @@ +""" +Utilities to move index-based selections backward/forward. + +These utilities concern themselves with selections where not all options are available, +otherwise it would be enough to increment/decrement the index and use the operator `%` +to implement wrapping. +""" + +from __future__ import annotations + +from itertools import count +from typing import Literal, Protocol, Sequence + +from typing_extensions import TypeAlias + +from memray._vendor.textual._loop import loop_from_index + + +class Disableable(Protocol): + """Non-widgets that have an enabled/disabled status.""" + + disabled: bool + + +Direction: TypeAlias = Literal[-1, 1] +"""Valid values to determine navigation direction. + +In a vertical setting, 1 points down and -1 points up. +In a horizontal setting, 1 points right and -1 points left. +""" + + +def get_directed_distance( + index: int, start: int, direction: Direction, wrap_at: int +) -> int: + """Computes the distance going from `start` to `index` in the given direction. + + Starting at `start`, this is the number of steps you need to take in the given + `direction` to reach `index`, assuming there is wrapping at 0 and `wrap_at`. + This is also the smallest non-negative integer solution `d` to + `(start + d * direction) % wrap_at == index`. + + The diagram below illustrates the computation of `d1 = distance(2, 8, 1, 10)` and + `d2 = distance(2, 8, -1, 10)`: + + ``` + start ────────────────────┐ + index ────────┐ │ + indices 0 1 2 3 4 5 6 7 8 9 + d1 2 3 4 0 1 + > > > > > (direction == 1) + d2 6 5 4 3 2 1 0 + < < < < < < < (direction == -1) + ``` + + Args: + index: The index that we want to reach. + start: The starting point to consider when computing the distance. + direction: The direction in which we want to compute the distance. + wrap_at: Controls at what point wrapping around takes place. + + Returns: + The computed distance. + """ + return direction * (index - start) % wrap_at + + +def find_first_enabled( + candidates: Sequence[Disableable], +) -> int | None: + """Find the first enabled candidate in a sequence of possibly-disabled objects. + + Args: + candidates: The sequence of candidates to consider. + + Returns: + The first enabled candidate or `None` if none were available. + """ + return next( + (index for index, candidate in enumerate(candidates) if not candidate.disabled), + None, + ) + + +def find_last_enabled(candidates: Sequence[Disableable]) -> int | None: + """Find the last enabled candidate in a sequence of possibly-disabled objects. + + Args: + candidates: The sequence of candidates to consider. + + Returns: + The last enabled candidate or `None` if none were available. + """ + total_candidates = len(candidates) + return next( + ( + total_candidates - offset_from_end + for offset_from_end, candidate in enumerate(reversed(candidates), start=1) + if not candidate.disabled + ), + None, + ) + + +def find_next_enabled( + candidates: Sequence[Disableable], + anchor: int | None, + direction: Direction, +) -> int | None: + """Find the next enabled object if we're currently at the given anchor. + + The definition of "next" depends on the given direction and this function will wrap + around the ends of the sequence of object candidates. + + Args: + candidates: The sequence of object candidates to consider. + anchor: The point of the sequence from which we'll start looking for the next + enabled object. + direction: The direction in which to traverse the candidates when looking for + the next enabled candidate. + + Returns: + The next enabled object. If none are available, return the anchor. + """ + + if anchor is None: + if candidates: + return ( + find_first_enabled(candidates) + if direction == 1 + else find_last_enabled(candidates) + ) + return None + + for index, candidate in loop_from_index(candidates, anchor, direction, wrap=True): + if not candidate.disabled: + return index + return anchor + + +def find_next_enabled_no_wrap( + candidates: Sequence[Disableable], + anchor: int | None, + direction: Direction, + with_anchor: bool = False, +) -> int | None: + """Find the next enabled object starting from the given anchor (without wrapping). + + The meaning of "next" and "past" depend on the direction specified. + + Args: + candidates: The sequence of object candidates to consider. + anchor: The point of the sequence from which we'll start looking for the next + enabled object. + direction: The direction in which to traverse the candidates when looking for + the next enabled candidate. + with_anchor: Whether to consider the anchor or not. + + Returns: + The next enabled object. If none are available, return None. + """ + + if anchor is None: + if candidates: + return ( + find_first_enabled(candidates) + if direction == 1 + else find_last_enabled(candidates) + ) + return None + + start = anchor if with_anchor else anchor + direction + counter = count(start, direction) + valid_candidates = ( + candidates[start:] if direction == 1 else reversed(candidates[: start + 1]) + ) + + for idx, candidate in zip(counter, valid_candidates): + if candidate.disabled: + continue + return idx + return None diff --git a/src/memray/_vendor/textual/_win_sleep.py b/src/memray/_vendor/textual/_win_sleep.py new file mode 100644 index 0000000000..c8f938c1fc --- /dev/null +++ b/src/memray/_vendor/textual/_win_sleep.py @@ -0,0 +1,121 @@ +""" +A version of `time.sleep` that is more accurate than the standard library (even on Python 3.11). + +This should only be imported on Windows. +""" + +from __future__ import annotations + +import asyncio +from time import sleep as time_sleep +from typing import Coroutine + +__all__ = ["sleep"] + + +INFINITE = 0xFFFFFFFF +WAIT_FAILED = 0xFFFFFFFF +CREATE_WAITABLE_TIMER_HIGH_RESOLUTION = 0x00000002 +TIMER_ALL_ACCESS = 0x1F0003 + + +async def time_sleep_coro(secs: float): + """Coroutine wrapper around `time.sleep`.""" + await asyncio.sleep(secs) + + +try: + import ctypes + from ctypes.wintypes import HANDLE, LARGE_INTEGER + + kernel32 = ctypes.windll.kernel32 # type: ignore[attr-defined] +except Exception: + + def sleep(secs: float) -> Coroutine[None, None, None]: + """Wrapper around `time.sleep` to match the signature of the main case below.""" + return time_sleep_coro(secs) + +else: + + async def no_sleep_coro(): + """Creates a coroutine that does nothing for when no sleep is needed.""" + pass + + def sleep(secs: float) -> Coroutine[None, None, None]: + """A replacement sleep for Windows. + + Note that unlike `time.sleep` this *may* sleep for slightly less than the + specified time. This is generally not an issue for Textual's use case. + + In order to create a timer that _can_ be cancelled on Windows, we need to + create a timer and a separate event, and then we wait for either of the two + things. When Textual wants to quit, we set the cancel event. + + Args: + secs: Seconds to sleep for. + """ + + # Subtract a millisecond to account for overhead + sleep_for = max(0, secs - 0.001) + if sleep_for < 0.0005: + # Less than 0.5ms and its not worth doing the sleep + return no_sleep_coro() + + timer = kernel32.CreateWaitableTimerExW( + None, + None, + CREATE_WAITABLE_TIMER_HIGH_RESOLUTION, + TIMER_ALL_ACCESS, + ) + if not timer: + return time_sleep_coro(sleep_for) + + if not kernel32.SetWaitableTimer( + timer, + ctypes.byref(LARGE_INTEGER(int(sleep_for * -10_000_000))), + 0, + None, + None, + 0, + ): + kernel32.CloseHandle(timer) + return time_sleep_coro(sleep_for) + + cancel_event = kernel32.CreateEventExW(None, None, 0, TIMER_ALL_ACCESS) + if not cancel_event: + kernel32.CloseHandle(timer) + return time_sleep_coro(sleep_for) + + def cancel_inner(): + """Sets the cancel event so we know we can stop waiting for the timer.""" + kernel32.SetEvent(cancel_event) + + async def cancel(): + """Cancels the timer by setting the cancel event.""" + await asyncio.get_running_loop().run_in_executor(None, cancel_inner) + + def wait_inner(): + """Function responsible for waiting for the timer or the cancel event.""" + if ( + kernel32.WaitForMultipleObjects( + 2, + ctypes.pointer((HANDLE * 2)(cancel_event, timer)), + False, + INFINITE, + ) + == WAIT_FAILED + ): + time_sleep(sleep_for) + + async def wait(): + """Wraps the actual sleeping so we can detect if the thread was cancelled.""" + try: + await asyncio.get_running_loop().run_in_executor(None, wait_inner) + except asyncio.CancelledError: + await cancel() + raise + finally: + kernel32.CloseHandle(timer) + kernel32.CloseHandle(cancel_event) + + return wait() diff --git a/src/memray/_vendor/textual/_work_decorator.py b/src/memray/_vendor/textual/_work_decorator.py new file mode 100644 index 0000000000..c50f496d84 --- /dev/null +++ b/src/memray/_vendor/textual/_work_decorator.py @@ -0,0 +1,158 @@ +""" +A decorator used to create [workers](/guide/workers). +""" + +from __future__ import annotations + +from functools import partial, wraps +from inspect import iscoroutinefunction +from typing import TYPE_CHECKING, Callable, Coroutine, TypeVar, Union, cast, overload + +from typing_extensions import ParamSpec, TypeAlias + +if TYPE_CHECKING: + from memray._vendor.textual.worker import Worker + + +FactoryParamSpec = ParamSpec("FactoryParamSpec") +DecoratorParamSpec = ParamSpec("DecoratorParamSpec") +ReturnType = TypeVar("ReturnType") + +Decorator: TypeAlias = Callable[ + [ + Union[ + Callable[DecoratorParamSpec, ReturnType], + Callable[DecoratorParamSpec, Coroutine[None, None, ReturnType]], + ] + ], + Callable[DecoratorParamSpec, "Worker[ReturnType]"], +] + + +class WorkerDeclarationError(Exception): + """An error in the declaration of a worker method.""" + + +if TYPE_CHECKING: + + @overload + def work( + method: Callable[FactoryParamSpec, Coroutine[None, None, ReturnType]], + *, + name: str = "", + group: str = "default", + exit_on_error: bool = True, + exclusive: bool = False, + description: str | None = None, + thread: bool = False, + ) -> Callable[FactoryParamSpec, "Worker[ReturnType]"]: ... + + @overload + def work( + method: Callable[FactoryParamSpec, ReturnType], + *, + name: str = "", + group: str = "default", + exit_on_error: bool = True, + exclusive: bool = False, + description: str | None = None, + thread: bool = False, + ) -> Callable[FactoryParamSpec, "Worker[ReturnType]"]: ... + + @overload + def work( + *, + name: str = "", + group: str = "default", + exit_on_error: bool = True, + exclusive: bool = False, + description: str | None = None, + thread: bool = False, + ) -> Decorator[..., ReturnType]: ... + + +def work( + method: ( + Callable[FactoryParamSpec, ReturnType] + | Callable[FactoryParamSpec, Coroutine[None, None, ReturnType]] + | None + ) = None, + *, + name: str = "", + group: str = "default", + exit_on_error: bool = True, + exclusive: bool = False, + description: str | None = None, + thread: bool = False, +) -> Callable[FactoryParamSpec, Worker[ReturnType]] | Decorator: + """A decorator used to create [workers](/guide/workers). + + Args: + method: A function or coroutine. + name: A short string to identify the worker (in logs and debugging). + group: A short string to identify a group of workers. + exit_on_error: Exit the app if the worker raises an error. Set to `False` to suppress exceptions. + exclusive: Cancel all workers in the same group. + description: Readable description of the worker for debugging purposes. + By default, it uses a string representation of the decorated method + and its arguments. + thread: Mark the method as a thread worker. + """ + + def decorator( + method: ( + Callable[DecoratorParamSpec, ReturnType] + | Callable[DecoratorParamSpec, Coroutine[None, None, ReturnType]] + ), + ) -> Callable[DecoratorParamSpec, Worker[ReturnType]]: + """The decorator.""" + + # Methods that aren't async *must* be marked as being a thread + # worker. + if not iscoroutinefunction(method) and not thread: + raise WorkerDeclarationError( + "Can not create a worker from a non-async function unless `thread=True` is set on the work decorator." + ) + + @wraps(method) + def decorated( + *args: DecoratorParamSpec.args, **kwargs: DecoratorParamSpec.kwargs + ) -> Worker[ReturnType]: + """The replaced callable.""" + from memray._vendor.textual.dom import DOMNode + + self = args[0] + assert isinstance(self, DOMNode) + + if description is not None: + debug_description = description + else: + try: + positional_arguments = ", ".join(repr(arg) for arg in args[1:]) + keyword_arguments = ", ".join( + f"{name}={value!r}" for name, value in kwargs.items() + ) + tokens = [positional_arguments, keyword_arguments] + debug_description = f"{method.__name__}({', '.join(token for token in tokens if token)})" + except Exception: + debug_description = "" + worker = cast( + "Worker[ReturnType]", + self.run_worker( + partial(method, *args, **kwargs), + name=name or method.__name__, + group=group, + description=debug_description, + exclusive=exclusive, + exit_on_error=exit_on_error, + thread=thread, + ), + ) + return worker + + return decorated + + if method is None: + return decorator + else: + return decorator(method) diff --git a/src/memray/_vendor/textual/_wrap.py b/src/memray/_vendor/textual/_wrap.py new file mode 100644 index 0000000000..4d9544d5df --- /dev/null +++ b/src/memray/_vendor/textual/_wrap.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import re +from typing import Iterable + +from rich.cells import get_character_cell_size + +from memray._vendor.textual._cells import cell_len +from memray._vendor.textual._loop import loop_last +from memray._vendor.textual.expand_tabs import get_tab_widths + +re_chunk = re.compile(r"\S+\s*|\s+") + + +def chunks(text: str) -> Iterable[tuple[int, int, str]]: + """Yields each "chunk" from the text as a tuple containing (start_index, end_index, chunk_content). + A "chunk" in this context refers to a word and any whitespace around it. + + Args: + text: The text to split into chunks. + + Returns: + Yields tuples containing the start, end and content for each chunk. + """ + end = 0 + while (chunk_match := re_chunk.match(text, end)) is not None: + start, end = chunk_match.span() + chunk = chunk_match.group(0) + yield start, end, chunk + + +def compute_wrap_offsets( + text: str, + width: int, + tab_size: int, + fold: bool = True, + precomputed_tab_sections: list[tuple[str, int]] | None = None, +) -> list[int]: + """Given a string of text, and a width (measured in cells), return a list + of codepoint indices which the string should be split at in order for it to fit + within the given width. + + Args: + text: The text to examine. + width: The available cell width. + tab_size: The tab stop width. + fold: If True, words longer than `width` will be folded onto a new line. + precomputed_tab_sections: The output of `get_tab_widths` can be passed here directly, + to prevent us from having to recompute the value. + + Returns: + A list of indices to break the line at. + """ + tab_size = min(tab_size, width) + if precomputed_tab_sections: + tab_sections = precomputed_tab_sections + else: + tab_sections = get_tab_widths(text, tab_size) + + break_positions: list[int] = [] # offsets to insert the breaks at + append = break_positions.append + cell_offset = 0 + _cell_len = cell_len + + tab_section_index = 0 + cumulative_width = 0 + cumulative_widths: list[int] = [] # prefix sum of tab widths for each codepoint + record_widths = cumulative_widths.extend + + for last, (tab_section, tab_width) in loop_last(tab_sections): + # add 1 since the \t character is stripped by get_tab_widths + section_codepoint_length = len(tab_section) + int(bool(tab_width)) + widths = [cumulative_width] * section_codepoint_length + record_widths(widths) + cumulative_width += tab_width + if last: + cumulative_widths.append(cumulative_width) + + for start, end, chunk in chunks(text): + chunk_width = _cell_len(chunk) # this cell len excludes tabs completely + tab_width_before_start = cumulative_widths[start] + tab_width_before_end = cumulative_widths[end] + chunk_tab_width = tab_width_before_end - tab_width_before_start + chunk_width += chunk_tab_width + remaining_space = width - cell_offset + chunk_fits = remaining_space >= chunk_width + + if chunk_fits: + # Simplest case - the word fits within the remaining width for this line. + cell_offset += chunk_width + else: + # Not enough space remaining for this word on the current line. + if chunk_width > width: + # The word doesn't fit on any line, so we must fold it + if fold: + _get_character_cell_size = get_character_cell_size + lines: list[list[str]] = [[]] + + append_new_line = lines.append + append_to_last_line = lines[-1].append + + total_width = 0 + for character in chunk: + if character == "\t": + # Tab characters have dynamic width, so look it up + cell_width = tab_sections[tab_section_index][1] + tab_section_index += 1 + else: + cell_width = _get_character_cell_size(character) + + if total_width + cell_width > width: + append_new_line([character]) + append_to_last_line = lines[-1].append + total_width = cell_width + else: + append_to_last_line(character) + total_width += cell_width + + folded_word = ["".join(line) for line in lines] + for last, line in loop_last(folded_word): + if start: + append(start) + if last: + # Since cell_len ignores tabs, we need to check the width + # of the tabs in this line. The width of tabs within the + # line is computed by taking the difference between the + # cumulative width of tabs up to the end of the line and the + # cumulative width of tabs up to the start of the line. + line_tab_widths = ( + cumulative_widths[start + len(line)] + - cumulative_widths[start] + ) + cell_offset = _cell_len(line) + line_tab_widths + else: + start += len(line) + else: + # Folding isn't allowed, so crop the word. + if start: + append(start) + cell_offset = chunk_width + elif cell_offset and start: + # The word doesn't fit within the remaining space on the current + # line, but it *can* fit on to the next (empty) line. + append(start) + cell_offset = chunk_width + + return break_positions diff --git a/src/memray/_vendor/textual/_xterm_parser.py b/src/memray/_vendor/textual/_xterm_parser.py new file mode 100644 index 0000000000..328689daa5 --- /dev/null +++ b/src/memray/_vendor/textual/_xterm_parser.py @@ -0,0 +1,400 @@ +from __future__ import annotations + +import os +import re +from typing import Any, Generator, Iterable + +from typing_extensions import Final + +from memray._vendor.textual import constants, events, messages +from memray._vendor.textual._ansi_sequences import ANSI_SEQUENCES_KEYS, IGNORE_SEQUENCE +from memray._vendor.textual._keyboard_protocol import FUNCTIONAL_KEYS +from memray._vendor.textual._parser import ParseEOF, Parser, ParseTimeout, Peek1, Read1, TokenCallback +from memray._vendor.textual.keys import KEY_NAME_REPLACEMENTS, Keys, _character_to_key +from memray._vendor.textual.message import Message + +# When trying to determine whether the current sequence is a supported/valid +# escape sequence, at which length should we give up and consider our search +# to be unsuccessful? +_MAX_SEQUENCE_SEARCH_THRESHOLD = 32 + +_re_mouse_event = re.compile("^" + re.escape("\x1b[") + r"(\d+);(?P\d)\$y" +) + +_re_cursor_position = re.compile(r"\x1b\[(?P\d+);(?P\d+)R") + +BRACKETED_PASTE_START: Final[str] = "\x1b[200~" +"""Sequence received when a bracketed paste event starts.""" +BRACKETED_PASTE_END: Final[str] = "\x1b[201~" +"""Sequence received when a bracketed paste event ends.""" +FOCUSIN: Final[str] = "\x1b[I" +"""Sequence received when the terminal receives focus.""" +FOCUSOUT: Final[str] = "\x1b[O" +"""Sequence received when focus is lost from the terminal.""" + +SPECIAL_SEQUENCES = {BRACKETED_PASTE_START, BRACKETED_PASTE_END, FOCUSIN, FOCUSOUT} +"""Set of special sequences.""" + +_re_extended_key: Final = re.compile(r"\x1b\[(?:(\d+)(?:;(\d+))?)?([u~ABCDEFHPQRS])") +_re_in_band_window_resize: Final = re.compile( + r"\x1b\[48;(\d+(?:\:.*?)?);(\d+(?:\:.*?)?);(\d+(?:\:.*?)?);(\d+(?:\:.*?)?)t" +) + + +IS_ITERM = ( + os.environ.get("LC_TERMINAL", "") == "iTerm2" + or os.environ.get("TERM_PROGRAM", "") == "iTerm.app" +) + + +class XTermParser(Parser[Message]): + _re_sgr_mouse = re.compile(r"\x1b\[<(\d+);(-?\d+);(-?\d+)([Mm])") + + def __init__(self, debug: bool = False) -> None: + self.last_x = 0.0 + self.last_y = 0.0 + self.mouse_pixels = False + self.terminal_size: tuple[int, int] | None = None + self.terminal_pixel_size: tuple[int, int] | None = None + self._debug_log_file = open("keys.log", "at") if debug else None + super().__init__() + self.debug_log("---") + + def debug_log(self, *args: Any) -> None: # pragma: no cover + if self._debug_log_file is not None: + self._debug_log_file.write(" ".join(args) + "\n") + self._debug_log_file.flush() + + def feed(self, data: str) -> Iterable[Message]: + self.debug_log(f"FEED {data!r}") + return super().feed(data) + + def parse_mouse_code(self, code: str) -> Message | None: + sgr_match = self._re_sgr_mouse.match(code) + if sgr_match: + _buttons, _x, _y, state = sgr_match.groups() + buttons = int(_buttons) + x = float(int(_x) - 1) + y = float(int(_y) - 1) + if x < 0 or y < 0: + # TODO: Workaround for Ghostty erroneous negative coordinate bug + return None + if ( + self.mouse_pixels + and self.terminal_pixel_size is not None + and self.terminal_size is not None + ): + pixel_width, pixel_height = self.terminal_pixel_size + width, height = self.terminal_size + x_ratio = pixel_width / width + y_ratio = pixel_height / height + x /= x_ratio + y /= y_ratio + + delta_x = int(x) - int(self.last_x) + delta_y = int(y) - int(self.last_y) + self.last_x = x + self.last_y = y + event_class: type[events.MouseEvent] + + if buttons & 64: + event_class = [ + events.MouseScrollUp, + events.MouseScrollDown, + events.MouseScrollLeft, + events.MouseScrollRight, + ][buttons & 3] + button = 0 + else: + button = (buttons + 1) & 3 + # XTerm events for mouse movement can look like mouse button down events. But if there is no key pressed, + # it's a mouse move event. + if buttons & 32 or button == 0: + event_class = events.MouseMove + else: + event_class = events.MouseDown if state == "M" else events.MouseUp + + event = event_class( + None, + x, + y, + delta_x, + delta_y, + button, + bool(buttons & 4), + bool(buttons & 8), + bool(buttons & 16), + screen_x=x, + screen_y=y, + ) + return event + return None + + def parse( + self, token_callback: TokenCallback + ) -> Generator[Read1 | Peek1, str, None]: + ESC = "\x1b" + read1 = self.read1 + sequence_to_key_events = self._sequence_to_key_events + paste_buffer: list[str] = [] + bracketed_paste = False + + def on_token(token: Message) -> None: + """Hook to log events.""" + self.debug_log(str(token)) + if isinstance(token, events.Resize): + self.terminal_size = token.size + self.terminal_pixel_size = token.pixel_size + token_callback(token) + + def on_key_token(event: events.Key) -> None: + """Token callback wrapper for handling keys. + + Args: + event: The key event to send to the callback. + + This wrapper looks for keys that should be ignored, and filters + them out, logging the ignored sequence when it does. + """ + if event.key == Keys.Ignore: + self.debug_log(f"ignored={event.character!r}") + else: + on_token(event) + + def reissue_sequence_as_keys( + reissue_sequence: str, process_alt: bool = False + ) -> None: + """Called when an escape sequence hasn't been understood. + + Args: + reissue_sequence: Key sequence to report to the app. + """ + + alt = False + + if reissue_sequence: + self.debug_log("REISSUE", repr(reissue_sequence)) + for character in reissue_sequence: + if process_alt and character == ESC: + alt = True + continue + key_events = sequence_to_key_events(character, alt=alt) + for event in key_events: + if event.key == "escape" and not process_alt: + event = events.Key("circumflex_accent", "^") + on_token(event) + alt = False + + while not self.is_eof: + if not bracketed_paste and paste_buffer: + # We're at the end of the bracketed paste. + # The paste buffer has content, but the bracketed paste has finished, + # so we flush the paste buffer. We have to remove the final character + # since if bracketed paste has come to an end, we'll have added the + # ESC from the closing bracket, since at that point we didn't know what + # the full escape code was. + pasted_text = "".join(paste_buffer[:-1]) + # Note the removal of NUL characters: https://github.com/Textualize/textual/issues/1661 + on_token(events.Paste(pasted_text.replace("\x00", ""))) + paste_buffer.clear() + + try: + character = yield read1() + except ParseEOF: + return + + if bracketed_paste: + paste_buffer.append(character) + + self.debug_log(f"character={character!r}") + if character != ESC: + if not bracketed_paste: + for event in sequence_to_key_events(character): + on_key_token(event) + if not character: + return + continue + + # # Could be the escape key was pressed OR the start of an escape sequence + sequence: str = ESC + + def send_sequence(process_alt: bool = True) -> None: + """Send escape key and reissue sequence.""" + if sequence == ESC: + on_token(events.Key("escape", "\x1b")) + else: + reissue_sequence_as_keys(sequence, process_alt=process_alt) + + while True: + try: + new_character = yield read1(constants.ESCAPE_DELAY) + except ParseTimeout: + send_sequence() + break + except ParseEOF: + send_sequence() + return + + if new_character == ESC: + send_sequence(process_alt=False) + sequence = character + continue + else: + sequence += new_character + if len(sequence) > _MAX_SEQUENCE_SEARCH_THRESHOLD: + reissue_sequence_as_keys(sequence) + break + + self.debug_log(f"sequence={sequence!r}") + if sequence in SPECIAL_SEQUENCES: + if sequence == FOCUSIN: + on_token(events.AppFocus()) + elif sequence == FOCUSOUT: + on_token(events.AppBlur()) + elif sequence == BRACKETED_PASTE_START: + bracketed_paste = True + elif sequence == BRACKETED_PASTE_END: + bracketed_paste = False + break + if match := _re_in_band_window_resize.fullmatch(sequence): + height, width, pixel_height, pixel_width = [ + group.partition(":")[0] for group in match.groups() + ] + resize_event = events.Resize.from_dimensions( + (int(width), int(height)), + (int(pixel_width), int(pixel_height)), + ) + + self.terminal_size = resize_event.size + self.terminal_pixel_size = resize_event.pixel_size + self.mouse_pixels = True + on_token(resize_event) + break + + if not bracketed_paste: + # Check cursor position report + cursor_position_match = _re_cursor_position.match(sequence) + if cursor_position_match is not None: + row, column = map(int, cursor_position_match.groups()) + x = int(column) - 1 + y = int(row) - 1 + on_token(events.CursorPosition(x, y)) + break + + # Was it a pressed key event that we received? + key_events = list(sequence_to_key_events(sequence)) + for key_event in key_events: + on_key_token(key_event) + if key_events: + break + # Or a mouse event? + mouse_match = _re_mouse_event.match(sequence) + if mouse_match is not None: + mouse_code = mouse_match.group(0) + mouse_event = self.parse_mouse_code(mouse_code) + if mouse_event is not None: + on_token(mouse_event) + break + + # Or a mode report? + # (i.e. the terminal saying it supports a mode we requested) + mode_report_match = _re_terminal_mode_response.match(sequence) + if mode_report_match is not None: + mode_id = mode_report_match["mode_id"] + setting_parameter = int(mode_report_match["setting_parameter"]) + if mode_id == "2026" and setting_parameter > 0: + on_token(messages.TerminalSupportsSynchronizedOutput()) + elif ( + mode_id == "2048" + and constants.SMOOTH_SCROLL + and not IS_ITERM + ): + # TODO: iTerm is buggy in one or more of the protocols required here + in_band_event = ( + messages.InBandWindowResize.from_setting_parameter( + setting_parameter + ) + ) + on_token(in_band_event) + break + + if self._debug_log_file is not None: + self._debug_log_file.close() + self._debug_log_file = None + + def _sequence_to_key_events( + self, sequence: str, alt: bool = False + ) -> Iterable[events.Key]: + """Map a sequence of code points on to a sequence of keys. + + Args: + sequence: Sequence of code points. + + Returns: + Keys + """ + + if (match := _re_extended_key.fullmatch(sequence)) is not None: + number, modifiers, end = match.groups() + number = number or 1 + if not (key := FUNCTIONAL_KEYS.get(f"{number}{end}", "")): + try: + key = _character_to_key(chr(int(number))) + except Exception: + key = chr(int(number)) + key_tokens: list[str] = [] + if modifiers: + modifier_bits = int(modifiers) - 1 + # Not convinced of the utility in reporting caps_lock and num_lock + MODIFIERS = ("shift", "alt", "ctrl", "super", "hyper", "meta") + # Ignore caps_lock and num_lock modifiers + for bit, modifier in enumerate(MODIFIERS): + if modifier_bits & (1 << bit): + key_tokens.append(modifier) + + key_tokens.sort() + key_tokens.append(key.lower()) + yield events.Key( + "+".join(key_tokens), sequence if len(sequence) == 1 else None + ) + return + + keys = ANSI_SEQUENCES_KEYS.get(sequence) + # If we're being asked to ignore the key... + if keys is IGNORE_SEQUENCE: + # ...build a special ignore key event, which has the ignore + # name as the key (that is, the key this sequence is bound + # to is the ignore key) and the sequence that was ignored as + # the character. + yield events.Key(Keys.Ignore, sequence) + return + if isinstance(keys, tuple): + # If the sequence mapped to a tuple, then it's values from the + # `Keys` enum. Raise key events from what we find in the tuple. + for key in keys: + yield events.Key(key.value, sequence if len(sequence) == 1 else None) + return + # If keys is a string, the intention is that it's a mapping to a + # character, which should really be treated as the sequence for the + # purposes of the next step... + if isinstance(keys, str): + sequence = keys + # If the sequence is a single character, attempt to process it as a + # key. + if len(sequence) == 1: + try: + if not sequence.isalnum(): + name = _character_to_key(sequence) + else: + name = sequence + + name = KEY_NAME_REPLACEMENTS.get(name, name) + if len(name) == 1 and alt: + if name.isupper(): + name = f"shift+{name.lower()}" + name = f"alt+{name}" + yield events.Key(name, sequence) + except Exception: + yield events.Key(sequence, sequence) diff --git a/src/memray/_vendor/textual/actions.py b/src/memray/_vendor/textual/actions.py new file mode 100644 index 0000000000..a7ff7bbde2 --- /dev/null +++ b/src/memray/_vendor/textual/actions.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import ast +import re +from functools import lru_cache +from typing import Any + +from typing_extensions import TypeAlias + +ActionParseResult: TypeAlias = "tuple[str, str, tuple[object, ...]]" +"""An action is its name and the arbitrary tuple of its arguments.""" + + +class SkipAction(Exception): + """Raise in an action to skip the action (and allow any parent bindings to run).""" + + +class ActionError(Exception): + pass + + +re_action_args = re.compile(r"([\w\.]+)\((.*)\)") + + +@lru_cache(maxsize=1024) +def parse(action: str) -> ActionParseResult: + """Parses an action string. + + Args: + action: String containing action. + + Raises: + ActionError: If the action has invalid syntax. + + Returns: + Action name and arguments. + """ + args_match = re_action_args.match(action) + if args_match is not None: + action_name, action_args_str = args_match.groups() + if action_args_str: + try: + # We wrap `action_args_str` to be able to disambiguate the cases where + # the list of arguments is a comma-separated list of values from the + # case where the argument is a single tuple. + action_args: tuple[Any, ...] = ast.literal_eval(f"({action_args_str},)") + except Exception: + raise ActionError( + f"unable to parse {action_args_str!r} in action {action!r}" + ) + else: + action_args = () + else: + action_name = action + action_args = () + + namespace, _, action_name = action_name.rpartition(".") + + return namespace, action_name, action_args diff --git a/src/memray/_vendor/textual/app.py b/src/memray/_vendor/textual/app.py new file mode 100644 index 0000000000..5f6740a0eb --- /dev/null +++ b/src/memray/_vendor/textual/app.py @@ -0,0 +1,4985 @@ +""" + +Here you will find the [App][textual.app.App] class, which is the base class for Textual apps. + +See [app basics](/guide/app) for how to build Textual apps. +""" + +from __future__ import annotations + +import asyncio +import importlib +import inspect +import io +import mimetypes +import os +import signal +import sys +import threading +import uuid +import warnings +from asyncio import AbstractEventLoop, Task, create_task +from concurrent.futures import Future +from contextlib import ( + asynccontextmanager, + contextmanager, + redirect_stderr, + redirect_stdout, +) +from functools import partial +from pathlib import Path +from time import perf_counter +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Awaitable, + BinaryIO, + Callable, + ClassVar, + Generator, + Generic, + Iterable, + Iterator, + Mapping, + NamedTuple, + Sequence, + TextIO, + Type, + TypeVar, + overload, +) +from weakref import WeakKeyDictionary, WeakSet + +import rich +import rich.repr +from platformdirs import user_downloads_path +from rich.console import Console, ConsoleDimensions, ConsoleOptions, RenderableType +from rich.control import Control +from rich.protocol import is_renderable +from rich.segment import Segment, Segments +from rich.terminal_theme import TerminalTheme + +from memray._vendor.textual import ( + Logger, + LogGroup, + LogVerbosity, + actions, + constants, + events, + log, + messages, + on, +) +from memray._vendor.textual._animator import DEFAULT_EASING, Animatable, Animator, EasingFunction +from memray._vendor.textual._ansi_sequences import SYNC_END, SYNC_START +from memray._vendor.textual._ansi_theme import ALABASTER, MONOKAI +from memray._vendor.textual._callback import invoke +from memray._vendor.textual._compat import cached_property +from memray._vendor.textual._compositor import CompositorUpdate +from memray._vendor.textual._context import active_app, active_message_pump +from memray._vendor.textual._context import message_hook as message_hook_context_var +from memray._vendor.textual._dispatch_key import dispatch_key +from memray._vendor.textual._event_broker import NoHandler, extract_handler_actions +from memray._vendor.textual._files import generate_datetime_filename +from memray._vendor.textual._path import ( + CSSPathType, + _css_path_type_as_list, + _make_path_object_relative, +) +from memray._vendor.textual._types import AnimationLevel +from memray._vendor.textual._wait import wait_for_idle +from memray._vendor.textual.actions import ActionParseResult, SkipAction +from memray._vendor.textual.await_complete import AwaitComplete +from memray._vendor.textual.await_remove import AwaitRemove +from memray._vendor.textual.binding import Binding, BindingsMap, BindingType, Keymap +from memray._vendor.textual.command import CommandListItem, CommandPalette, Provider, SimpleProvider +from memray._vendor.textual.compose import compose +from memray._vendor.textual.content import Content +from memray._vendor.textual.css.errors import StylesheetError +from memray._vendor.textual.css.query import NoMatches +from memray._vendor.textual.css.stylesheet import RulesMap, Stylesheet +from memray._vendor.textual.dom import DOMNode, NoScreen +from memray._vendor.textual.driver import Driver +from memray._vendor.textual.errors import NoWidget +from memray._vendor.textual.features import FeatureFlag, parse_features +from memray._vendor.textual.file_monitor import FileMonitor +from memray._vendor.textual.filter import ANSIToTruecolor, DimFilter, Monochrome, NoColor +from memray._vendor.textual.geometry import Offset, Region, Size +from memray._vendor.textual.keys import ( + REPLACED_KEYS, + _character_to_key, + _get_unicode_name_from_key, + _normalize_key_list, + format_key, +) +from memray._vendor.textual.messages import CallbackType, Prune +from memray._vendor.textual.notifications import Notification, Notifications, Notify, SeverityLevel +from memray._vendor.textual.reactive import Reactive +from memray._vendor.textual.renderables.blank import Blank +from memray._vendor.textual.screen import ( + ActiveBinding, + Screen, + ScreenResultCallbackType, + ScreenResultType, + SystemModalScreen, +) +from memray._vendor.textual.signal import Signal +from memray._vendor.textual.theme import BUILTIN_THEMES, Theme, ThemeProvider +from memray._vendor.textual.timer import Timer +from memray._vendor.textual.visual import SupportsVisual, Visual +from memray._vendor.textual.widget import AwaitMount, Widget +from memray._vendor.textual.widgets._toast import ToastRack +from memray._vendor.textual.worker import NoActiveWorker, get_current_worker +from memray._vendor.textual.worker_manager import WorkerManager + +if TYPE_CHECKING: + from textual_dev.client import DevtoolsClient + from typing_extensions import Coroutine, Literal, Self, TypeAlias + + from memray._vendor.textual._types import MessageTarget + + # Unused & ignored imports are needed for the docs to link to these objects: + from memray._vendor.textual.css.query import WrongType # type: ignore # noqa: F401 + from memray._vendor.textual.filter import LineFilter + from memray._vendor.textual.message import Message + from memray._vendor.textual.pilot import Pilot + from memray._vendor.textual.system_commands import SystemCommandsProvider + from memray._vendor.textual.widget import MountError # type: ignore # noqa: F401 + +WINDOWS = sys.platform == "win32" + +# asyncio will warn against resources not being cleared +if constants.DEBUG: + warnings.simplefilter("always", ResourceWarning) + +# `asyncio.get_event_loop()` is deprecated since Python 3.10: +_ASYNCIO_GET_EVENT_LOOP_IS_DEPRECATED = sys.version_info >= (3, 10, 0) + +ComposeResult = Iterable[Widget] +RenderResult: TypeAlias = "RenderableType | Visual | SupportsVisual" +"""Result of Widget.render()""" + +AutopilotCallbackType: TypeAlias = ( + "Callable[[Pilot[object]], Coroutine[Any, Any, None]]" +) +"""Signature for valid callbacks that can be used to control apps.""" + +CommandCallback: TypeAlias = "Callable[[], Awaitable[Any]] | Callable[[], Any]" +"""Signature for callbacks used in [`get_system_commands`][textual.app.App.get_system_commands]""" + +ScreenType = TypeVar("ScreenType", bound=Screen) +"""Type var for a Screen, used in [`get_screen`][textual.app.App.get_screen].""" + + +class SystemCommand(NamedTuple): + """Defines a system command used in the command palette (yielded from [`get_system_commands`][textual.app.App.get_system_commands]).""" + + title: str + """The title of the command (used in search).""" + help: str + """Additional help text, shown under the title.""" + callback: CommandCallback + """A callback to invoke when the command is selected.""" + discover: bool = True + """Should the command show when the search is empty?""" + + +def get_system_commands_provider() -> type[SystemCommandsProvider]: + """Callable to lazy load the system commands. + + Returns: + System commands class. + """ + from memray._vendor.textual.system_commands import SystemCommandsProvider + + return SystemCommandsProvider + + +class AppError(Exception): + """Base class for general App related exceptions.""" + + +class ActionError(Exception): + """Base class for exceptions relating to actions.""" + + +class ScreenError(Exception): + """Base class for exceptions that relate to screens.""" + + +class ScreenStackError(ScreenError): + """Raised when trying to manipulate the screen stack incorrectly.""" + + +class ModeError(Exception): + """Base class for exceptions related to modes.""" + + +class InvalidModeError(ModeError): + """Raised if there is an issue with a mode name.""" + + +class UnknownModeError(ModeError): + """Raised when attempting to use a mode that is not known.""" + + +class ActiveModeError(ModeError): + """Raised when attempting to remove the currently active mode.""" + + +class SuspendNotSupported(Exception): + """Raised if suspending the application is not supported. + + This exception is raised if [`App.suspend`][textual.app.App.suspend] is called while + the application is running in an environment where this isn't supported. + """ + + +class InvalidThemeError(Exception): + """Raised when an invalid theme is set.""" + + +ReturnType = TypeVar("ReturnType") +CallThreadReturnType = TypeVar("CallThreadReturnType") + + +class _NullFile: + """A file-like where writes go nowhere.""" + + def write(self, text: str) -> None: + pass + + def flush(self) -> None: + pass + + def isatty(self) -> bool: + return True + + +class _PrintCapture: + """A file-like which captures output.""" + + def __init__(self, app: App, stderr: bool = False) -> None: + """ + + Args: + app: App instance. + stderr: Write from stderr. + """ + self.app = app + self.stderr = stderr + + def write(self, text: str) -> None: + """Called when writing to stdout or stderr. + + Args: + text: Text that was "printed". + """ + self.app._print(text, stderr=self.stderr) + + def flush(self) -> None: + """Called when stdout or stderr was flushed.""" + self.app._flush(stderr=self.stderr) + + def isatty(self) -> bool: + """Pretend we're a terminal.""" + # TODO: should this be configurable? + return True + + def fileno(self) -> int: + """Return invalid fileno.""" + return -1 + + +@rich.repr.auto +class App(Generic[ReturnType], DOMNode): + """The base class for Textual Applications.""" + + CSS: ClassVar[str] = "" + """Inline CSS, useful for quick scripts. This is loaded after CSS_PATH, + and therefore takes priority in the event of a specificity clash.""" + + # Default (the lowest priority) CSS + DEFAULT_CSS: ClassVar[str] + DEFAULT_CSS = """ + App { + background: $background; + color: $foreground; + + &:ansi { + background: ansi_default; + color: ansi_default; + + .-ansi-scrollbar { + scrollbar-background: ansi_default; + scrollbar-background-hover: ansi_default; + scrollbar-background-active: ansi_default; + scrollbar-color: ansi_blue; + scrollbar-color-active: ansi_bright_blue; + scrollbar-color-hover: ansi_bright_blue; + scrollbar-corner-color: ansi_default; + } + + .bindings-table--key { + color: ansi_magenta; + } + .bindings-table--description { + color: ansi_default; + } + + .bindings-table--header { + color: ansi_default; + } + + .bindings-table--divider { + color: transparent; + text-style: dim; + } + } + + /* When a widget is maximized */ + Screen.-maximized-view { + layout: vertical !important; + hatch: right $panel; + overflow-y: auto !important; + align: center middle; + .-maximized { + dock: initial !important; + } + } + /* Fade the header title when app is blurred */ + &:blur HeaderTitle { + text-opacity: 50%; + } + } + *:disabled:can-focus { + opacity: 0.7; + } + """ + + MODES: ClassVar[dict[str, str | Callable[[], Screen]]] = {} + """Modes associated with the app and their base screens. + + The base screen is the screen at the bottom of the mode stack. You can think of + it as the default screen for that stack. + The base screens can be names of screens listed in [SCREENS][textual.app.App.SCREENS], + [`Screen`][textual.screen.Screen] instances, or callables that return screens. + + Example: + ```py + class HelpScreen(Screen[None]): + ... + + class MainAppScreen(Screen[None]): + ... + + class MyApp(App[None]): + MODES = { + "default": "main", + "help": HelpScreen, + } + + SCREENS = { + "main": MainAppScreen, + } + + ... + ``` + """ + DEFAULT_MODE: ClassVar[str] = "_default" + """Name of the default mode.""" + + SCREENS: ClassVar[dict[str, Callable[[], Screen[Any]]]] = {} + """Screens associated with the app for the lifetime of the app.""" + + AUTO_FOCUS: ClassVar[str | None] = "*" + """A selector to determine what to focus automatically when a screen is activated. + + The widget focused is the first that matches the given [CSS selector](/guide/queries/#query-selectors). + Setting to `None` or `""` disables auto focus. + """ + + ALLOW_SELECT: ClassVar[bool] = True + """A switch to toggle arbitrary text selection for the app. + + Note that this doesn't apply to Input and TextArea which have builtin support for selection. + """ + + _BASE_PATH: str | None = None + CSS_PATH: ClassVar[CSSPathType | None] = None + """File paths to load CSS from.""" + + TITLE: str | None = None + """A class variable to set the *default* title for the application. + + To update the title while the app is running, you can set the [title][textual.app.App.title] attribute. + See also [the `Screen.TITLE` attribute][textual.screen.Screen.TITLE]. + """ + + SUB_TITLE: str | None = None + """A class variable to set the default sub-title for the application. + + To update the sub-title while the app is running, you can set the [sub_title][textual.app.App.sub_title] attribute. + See also [the `Screen.SUB_TITLE` attribute][textual.screen.Screen.SUB_TITLE]. + """ + + ENABLE_COMMAND_PALETTE: ClassVar[bool] = True + """Should the [command palette][textual.command.CommandPalette] be enabled for the application?""" + + NOTIFICATION_TIMEOUT: ClassVar[float] = 5 + """Default number of seconds to show notifications before removing them.""" + + COMMANDS: ClassVar[set[type[Provider] | Callable[[], type[Provider]]]] = { + get_system_commands_provider + } + """Command providers used by the [command palette](/guide/command_palette). + + Should be a set of [command.Provider][textual.command.Provider] classes. + """ + + COMMAND_PALETTE_BINDING: ClassVar[str] = "ctrl+p" + """The key that launches the command palette (if enabled by [`App.ENABLE_COMMAND_PALETTE`][textual.app.App.ENABLE_COMMAND_PALETTE]).""" + + COMMAND_PALETTE_DISPLAY: ClassVar[str | None] = None + """How the command palette key should be displayed in the footer (or `None` for default).""" + + ALLOW_IN_MAXIMIZED_VIEW: ClassVar[str] = "Footer" + """The default value of [Screen.ALLOW_IN_MAXIMIZED_VIEW][textual.screen.Screen.ALLOW_IN_MAXIMIZED_VIEW].""" + + CLICK_CHAIN_TIME_THRESHOLD: ClassVar[float] = 0.5 + """The maximum number of seconds between clicks to upgrade a single click to a double click, + a double click to a triple click, etc.""" + + BINDINGS: ClassVar[list[BindingType]] = [ + Binding( + "ctrl+q", + "quit", + "Quit", + tooltip="Quit the app and return to the command prompt.", + show=False, + priority=True, + ), + Binding("ctrl+c", "help_quit", show=False, system=True), + ] + """The default key bindings.""" + + CLOSE_TIMEOUT: float | None = 5.0 + """Timeout waiting for widget's to close, or `None` for no timeout.""" + + TOOLTIP_DELAY: float = 0.5 + """The time in seconds after which a tooltip gets displayed.""" + + BINDING_GROUP_TITLE: str | None = None + """Set to text to show in the key panel.""" + + ESCAPE_TO_MINIMIZE: ClassVar[bool] = True + """Use escape key to minimize widgets (potentially overriding bindings). + + This is the default value, used if the active screen's `ESCAPE_TO_MINIMIZE` is not changed from `None`. + """ + + INLINE_PADDING: ClassVar[int] = 1 + """Number of blank lines above an inline app.""" + + SUSPENDED_SCREEN_CLASS: ClassVar[str] = "" + """Class to apply to suspended screens, or empty string for no class.""" + + HORIZONTAL_BREAKPOINTS: ClassVar[list[tuple[int, str]]] | None = [] + """List of horizontal breakpoints for responsive classes. + + This allows for styles to be responsive to the dimensions of the terminal. + For instance, you might want to show less information, or fewer columns on a narrow displays -- or more information when the terminal is sized wider than usual. + + A breakpoint consists of a tuple containing the minimum width where the class should applied, and the name of the class to set. + + Note that only one class name is set, and you should avoid having more than one breakpoint set for the same size. + + Example: + ```python + # Up to 80 cells wide, the app has the class "-normal" + # 80 - 119 cells wide, the app has the class "-wide" + # 120 cells or wider, the app has the class "-very-wide" + HORIZONTAL_BREAKPOINTS = [(0, "-normal"), (80, "-wide"), (120, "-very-wide")] + ``` + + """ + VERTICAL_BREAKPOINTS: ClassVar[list[tuple[int, str]]] | None = [] + """List of vertical breakpoints for responsive classes. + + Contents are the same as [`HORIZONTAL_BREAKPOINTS`][textual.app.App.HORIZONTAL_BREAKPOINTS], but the integer is compared to the height, rather than the width. + """ + + # TODO: Enable by default after suitable testing period + PAUSE_GC_ON_SCROLL: ClassVar[bool] = False + """Pause Python GC (Garbage Collection) when scrolling, for potentially smoother scrolling with many widgets (experimental).""" + + ENABLE_SELECT_AUTO_SCROLL: ClassVar[bool] = True + """Enable automatic scrolling if selecting and the mouse is at the top or bottom of the widget?""" + + SELECT_AUTO_SCROLL_LINES: ClassVar[int] = 3 + """Number of lines in auto-scrolling regions at the top and bottom of a widget.""" + + SELECT_AUTO_SCROLL_SPEED: ClassVar[float] = 60.0 + """Maximum speed of select auto-scroll in lines per second.""" + + _PSEUDO_CLASSES: ClassVar[dict[str, Callable[[App[Any]], bool]]] = { + "focus": lambda app: app.app_focus, + "blur": lambda app: not app.app_focus, + "dark": lambda app: app.current_theme.dark, + "light": lambda app: not app.current_theme.dark, + "inline": lambda app: app.is_inline, + "ansi": lambda app: app.ansi_color, + "nocolor": lambda app: app.no_color, + } + + title: Reactive[str] = Reactive("", compute=False) + """The title of the app, displayed in the header.""" + sub_title: Reactive[str] = Reactive("", compute=False) + """The app's sub-title, combined with [`title`][textual.app.App.title] in the header.""" + + app_focus = Reactive(True, compute=False) + """Indicates if the app has focus. + + When run in the terminal, the app always has focus. When run in the web, the app will + get focus when the terminal widget has focus. + """ + + theme: Reactive[str] = Reactive(constants.DEFAULT_THEME) + """The name of the currently active theme.""" + + ansi_theme_dark = Reactive(MONOKAI, init=False) + """Maps ANSI colors to hex colors using a Rich TerminalTheme object while using a dark theme.""" + + ansi_theme_light = Reactive(ALABASTER, init=False) + """Maps ANSI colors to hex colors using a Rich TerminalTheme object while using a light theme.""" + + ansi_color = Reactive(False) + """Allow ANSI colors in UI?""" + + def __init__( + self, + driver_class: Type[Driver] | None = None, + css_path: CSSPathType | None = None, + watch_css: bool = False, + ansi_color: bool = False, + ): + """Create an instance of an app. + + Args: + driver_class: Driver class or `None` to auto-detect. + This will be used by some Textual tools. + css_path: Path to CSS or `None` to use the `CSS_PATH` class variable. + To load multiple CSS files, pass a list of strings or paths which + will be loaded in order. + watch_css: Reload CSS if the files changed. This is set automatically if + you are using `textual run` with the `dev` switch. + ansi_color: Allow ANSI colors if `True`, or convert ANSI colors to RGB if `False`. + + Raises: + CssPathError: When the supplied CSS path(s) are an unexpected type. + """ + self._start_time = perf_counter() + super().__init__(classes=self.DEFAULT_CLASSES) + self.features: frozenset[FeatureFlag] = parse_features(os.getenv("TEXTUAL", "")) + + self._registered_themes: dict[str, Theme] = {} + """Themes that have been registered with the App using `App.register_theme`. + + This excludes the built-in themes.""" + + for theme in BUILTIN_THEMES.values(): + self.register_theme(theme) + + ansi_theme = ( + self.ansi_theme_dark if self.current_theme.dark else self.ansi_theme_light + ) + self.set_reactive(App.ansi_color, ansi_color) + self._filters: list[LineFilter] = [ + ANSIToTruecolor(ansi_theme, enabled=not ansi_color) + ] + environ = dict(os.environ) + self.no_color = environ.pop("NO_COLOR", None) is not None + if self.no_color: + self._filters.append(NoColor() if self.ansi_color else Monochrome()) + + for filter_name in constants.FILTERS.split(","): + filter = filter_name.lower().strip() + if filter == "dim": + self._filters.append(DimFilter()) + + self.console = Console( + color_system=constants.COLOR_SYSTEM, + file=_NullFile(), + markup=True, + highlight=False, + emoji=False, + legacy_windows=False, + _environ=environ, + force_terminal=True, + safe_box=False, + soft_wrap=False, + ) + self._workers = WorkerManager(self) + self.error_console = Console(markup=False, highlight=False, stderr=True) + self.driver_class = driver_class or self.get_driver_class() + self._screen_stacks: dict[str, list[Screen[Any]]] = {self.DEFAULT_MODE: []} + """A stack of screens per mode.""" + self._current_mode: str = self.DEFAULT_MODE + """The current mode the app is in.""" + self._sync_available = False + + self.mouse_over: Widget | None = None + """The widget directly under the mouse.""" + self.hover_over: Widget | None = None + """The first widget with a hover style under the mouse.""" + self.mouse_captured: Widget | None = None + self._driver: Driver | None = None + self._exit_renderables: list[RenderableType] = [] + + self._action_targets = {"app", "screen", "focused"} + self._animator = Animator(self) + self._animate = self._animator.bind(self) + + self.mouse_position = Offset(0, 0) + """The current screen-space mouse position.""" + + self.mouse_position_high_resolution: tuple[float, float] = (0.0, 0.0) + """A high resolution (floating point) mouse position. If supported by the terminal, this may be more granular than `mouse_position`""" + + self._mouse_down_widget: Widget | None = None + """The widget that was most recently mouse downed (used to create click events).""" + + self._click_chain_last_offset: Offset | None = None + """The last offset at which a Click occurred, in screen-space.""" + + self._click_chain_last_time: float | None = None + """The last time at which a Click occurred.""" + + self._chained_clicks: int = 1 + """Counter which tracks the number of clicks received in a row.""" + + self._previous_cursor_position = Offset(0, 0) + """The previous cursor position""" + + self.cursor_position = Offset(0, 0) + """The position of the terminal cursor in screen-space. + + This can be set by widgets and is useful for controlling the + positioning of OS IME and emoji popup menus.""" + + self._exception: Exception | None = None + """The unhandled exception which is leading to the app shutting down, + or None if the app is still running with no unhandled exceptions.""" + + self.title = ( + self.TITLE if self.TITLE is not None else f"{self.__class__.__name__}" + ) + """The title for the application. + + The initial value for `title` will be set to the `TITLE` class variable if it exists, or + the name of the app if it doesn't. + + Assign a new value to this attribute to change the title. + The new value is always converted to string. + """ + + self.sub_title = self.SUB_TITLE if self.SUB_TITLE is not None else "" + """The sub-title for the application. + + The initial value for `sub_title` will be set to the `SUB_TITLE` class variable if it exists, or + an empty string if it doesn't. + + Sub-titles are typically used to show the high-level state of the app, such as the current mode, or path to + the file being worked on. + + Assign a new value to this attribute to change the sub-title. + The new value is always converted to string. + """ + + self.use_command_palette: bool = self.ENABLE_COMMAND_PALETTE + """A flag to say if the application should use the command palette. + + If set to `False` any call to + [`action_command_palette`][textual.app.App.action_command_palette] + will be ignored. + """ + + self._logger = Logger(self._log, app=self) + + self._css_has_errors = False + + self.theme_variables: dict[str, str] = {} + """Variables generated from the current theme.""" + + # Note that the theme must be set *before* self.get_css_variables() is called + # to ensure that the variables are retrieved from the currently active theme. + self.stylesheet = Stylesheet(variables=self.get_css_variables()) + + css_path = css_path or self.CSS_PATH + css_paths = [ + _make_path_object_relative(css_path, self) + for css_path in ( + _css_path_type_as_list(css_path) if css_path is not None else [] + ) + ] + self.css_path = css_paths + + self._registry: WeakSet[DOMNode] = WeakSet() + + self._keymap: Keymap = {} + + # Sensitivity on X is double the sensitivity on Y to account for + # cells being twice as tall as wide + self.scroll_sensitivity_x: float = 4.0 + """Number of columns to scroll in the X direction with wheel or trackpad.""" + self.scroll_sensitivity_y: float = 2.0 + """Number of lines to scroll in the Y direction with wheel or trackpad.""" + + self._installed_screens: dict[str, Screen | Callable[[], Screen]] = {} + self._installed_screens.update(**self.SCREENS) + self._modes: dict[str, str | Callable[[], Screen]] = self.MODES.copy() + """Contains the working-copy of the `MODES` for each instance.""" + + self._compose_stacks: list[list[Widget]] = [] + self._composed: list[list[Widget]] = [] + self._recompose_required = False + + self.devtools: DevtoolsClient | None = None + self._devtools_redirector: StdoutRedirector | None = None + if "devtools" in self.features: + try: + from textual_dev.client import DevtoolsClient + from textual_dev.redirect_output import StdoutRedirector + except ImportError: + # Dev dependencies not installed + pass + else: + self.devtools = DevtoolsClient(constants.DEVTOOLS_HOST) + self._devtools_redirector = StdoutRedirector(self.devtools) + + self._loop: asyncio.AbstractEventLoop | None = None + self._return_value: ReturnType | None = None + """Internal attribute used to set the return value for the app.""" + self._return_code: int | None = None + """Internal attribute used to set the return code for the app.""" + self._exit = False + self._disable_tooltips = False + self._disable_notifications = False + + self.css_monitor = ( + FileMonitor(self.css_path, self._on_css_change) + if watch_css or self.debug + else None + ) + self._screenshot: str | None = None + self._dom_ready = False + self._batch_count = 0 + self._notifications = Notifications() + + self._capture_print: WeakKeyDictionary[MessageTarget, tuple[bool, bool]] = ( + WeakKeyDictionary() + ) + """Registry of the MessageTargets which are capturing output at any given time.""" + self._capture_stdout = _PrintCapture(self, stderr=False) + """File-like object capturing data written to stdout.""" + self._capture_stderr = _PrintCapture(self, stderr=True) + """File-like object capturing data written to stderr.""" + self._original_stdout = sys.__stdout__ + """The original stdout stream (before redirection etc).""" + self._original_stderr = sys.__stderr__ + """The original stderr stream (before redirection etc).""" + + self.theme_changed_signal: Signal[Theme] = Signal(self, "theme-changed") + """Signal that is published when the App's theme is changed. + + Subscribers will receive the new theme object as an argument to the callback. + """ + + self.app_suspend_signal: Signal[App] = Signal(self, "app-suspend") + """The signal that is published when the app is suspended. + + When [`App.suspend`][textual.app.App.suspend] is called this signal + will be [published][textual.signal.Signal.publish]; + [subscribe][textual.signal.Signal.subscribe] to this signal to + perform work before the suspension takes place. + """ + self.app_resume_signal: Signal[App] = Signal(self, "app-resume") + """The signal that is published when the app is resumed after a suspend. + + When the app is resumed after a + [`App.suspend`][textual.app.App.suspend] call this signal will be + [published][textual.signal.Signal.publish]; + [subscribe][textual.signal.Signal.subscribe] to this signal to + perform work after the app has resumed. + """ + + self.mode_change_signal: Signal[str] = Signal(self, "mode-change") + """A signal published when the current screen mode changes.""" + + self.screen_change_signal: Signal[Screen] = Signal(self, "screen-change") + """A signal published when the current screen changes.""" + + self.set_class(self.current_theme.dark, "-dark-mode", update=False) + self.set_class(not self.current_theme.dark, "-light-mode", update=False) + + self.animation_level: AnimationLevel = constants.TEXTUAL_ANIMATIONS + """Determines what type of animations the app will display. + + See [`textual.constants.TEXTUAL_ANIMATIONS`][textual.constants.TEXTUAL_ANIMATIONS]. + """ + + self._last_focused_on_app_blur: Widget | None = None + """The widget that had focus when the last `AppBlur` happened. + + This will be used to restore correct focus when an `AppFocus` + happens. + """ + + self._previous_inline_height: int | None = None + """Size of previous inline update.""" + + self._resize_event: events.Resize | None = None + """A pending resize event, sent on idle.""" + + self._size: Size | None = None + + self._css_update_count: int = 0 + """Incremented when CSS is invalidated.""" + + self._clipboard: str = "" + """Contents of local clipboard.""" + + self.supports_smooth_scrolling: bool = False + """Does the terminal support smooth scrolling?""" + + self._compose_screen: Screen | None = None + """The screen composed by App.compose.""" + + self._realtime_animation_count = 0 + """Number of current realtime animations, such as scrolling.""" + + if self.ENABLE_COMMAND_PALETTE: + for _key, binding in self._bindings: + if binding.action in {"command_palette", "app.command_palette"}: + break + else: + self._bindings._add_binding( + Binding( + self.COMMAND_PALETTE_BINDING, + "command_palette", + "palette", + show=False, + key_display=self.COMMAND_PALETTE_DISPLAY, + priority=True, + tooltip="Open the command palette", + ) + ) + + def get_line_filters(self) -> Sequence[LineFilter]: + """Get currently enabled line filters. + + Returns: + A list of [LineFilter][textual.filters.LineFilter] instances. + """ + return [filter for filter in self._filters if filter.enabled] + + @property + def _is_devtools_connected(self) -> bool: + """Is the app connected to the devtools?""" + return self.devtools is not None and self.devtools.is_connected + + @cached_property + def _exception_event(self) -> asyncio.Event: + """An event that will be set when the first exception is encountered.""" + return asyncio.Event() + + def __init_subclass__(cls, *args, **kwargs) -> None: + for variable_name, screen_collection in ( + ("SCREENS", cls.SCREENS), + ("MODES", cls.MODES), + ): + for screen_name, screen_object in screen_collection.items(): + if not (isinstance(screen_object, str) or callable(screen_object)): + if isinstance(screen_object, Screen): + raise ValueError( + f"{variable_name} should contain a Screen type or callable, not an instance" + f" (got instance of {type(screen_object).__name__} for {screen_name!r})" + ) + raise TypeError( + f"expected a callable or string, got {screen_object!r}" + ) + + return super().__init_subclass__(*args, **kwargs) + + def _thread_init(self): + """Initialize threading primitives for the current thread. + + https://github.com/Textualize/textual/issues/5845 + + """ + self._message_queue + self._mounted_event + self._exception_event + self._thread_id = threading.get_ident() + + def _get_dom_base(self) -> DOMNode: + """When querying from the app, we want to query the default screen.""" + return self.default_screen + + def validate_title(self, title: Any) -> str: + """Make sure the title is set to a string.""" + return str(title) + + def validate_sub_title(self, sub_title: Any) -> str: + """Make sure the subtitle is set to a string.""" + return str(sub_title) + + @property + def default_screen(self) -> Screen: + """The default screen instance.""" + return self.screen if self._compose_screen is None else self._compose_screen + + @property + def workers(self) -> WorkerManager: + """The [worker](/guide/workers/) manager. + + Returns: + An object to manage workers. + """ + return self._workers + + @property + def return_value(self) -> ReturnType | None: + """The return value of the app, or `None` if it has not yet been set. + + The return value is set when calling [exit][textual.app.App.exit]. + """ + return self._return_value + + @property + def return_code(self) -> int | None: + """The return code with which the app exited. + + Non-zero codes indicate errors. + A value of 1 means the app exited with a fatal error. + If the app hasn't exited yet, this will be `None`. + + Example: + The return code can be used to exit the process via `sys.exit`. + ```py + my_app.run() + sys.exit(my_app.return_code) + ``` + """ + return self._return_code + + @property + def children(self) -> Sequence["Widget"]: + """A view onto the app's immediate children. + + This attribute exists on all widgets. + In the case of the App, it will only ever contain a single child, which will + be the currently active screen. + + Returns: + A sequence of widgets. + """ + try: + return ( + next( + screen + for screen in reversed(self._screen_stack) + if not isinstance(screen, SystemModalScreen) + ), + ) + except StopIteration: + return () + + @property + def clipboard(self) -> str: + """The value of the local clipboard. + + Note, that this only contains text copied in the app, and not + text copied from elsewhere in the OS. + """ + return self._clipboard + + def _realtime_animation_begin(self) -> None: + """A scroll or other animation that must be smooth has begun.""" + if self.PAUSE_GC_ON_SCROLL: + import gc + + gc.disable() + self._realtime_animation_count += 1 + + def _realtime_animation_complete(self) -> None: + """A scroll or other animation that must be smooth has completed.""" + self._realtime_animation_count -= 1 + if self._realtime_animation_count == 0 and self.PAUSE_GC_ON_SCROLL: + import gc + + gc.enable() + + def format_title(self, title: str, sub_title: str) -> Content: + """Format the title for display. + + Args: + title: The title. + sub_title: The sub title. + + Returns: + Content instance with title and subtitle. + """ + title_content = Content(title) + sub_title_content = Content(sub_title) + if sub_title_content: + return Content.assemble( + title_content, + (" — ", "dim"), + sub_title_content.stylize("dim"), + ) + else: + return title_content + + @contextmanager + def batch_update(self) -> Generator[None, None, None]: + """A context manager to suspend all repaints until the end of the batch.""" + self._begin_batch() + try: + yield + finally: + self._end_batch() + + def _begin_batch(self) -> None: + """Begin a batch update.""" + self._batch_count += 1 + + def _end_batch(self) -> None: + """End a batch update.""" + self._batch_count -= 1 + assert self._batch_count >= 0, "This won't happen if you use `batch_update`" + if not self._batch_count: + self.check_idle() + + def delay_update(self, delay: float = 0.05) -> None: + """Delay updates for a short period of time. + + May be used to mask a brief transition. + Consider this method only if you aren't able to use `App.batch_update`. + + Args: + delay: Delay before updating. + """ + self._begin_batch() + + def end_batch() -> None: + """Re-enable updates, and refresh screen.""" + self._end_batch() + if not self._batch_count: + self.screen.refresh() + + self.set_timer(delay, end_batch, name="delay_update") + + @contextmanager + def _context(self) -> Generator[None, None, None]: + """Context manager to set ContextVars.""" + app_reset_token = active_app.set(self) + message_pump_reset_token = active_message_pump.set(self) + try: + yield + finally: + active_message_pump.reset(message_pump_reset_token) + active_app.reset(app_reset_token) + + def _watch_ansi_color(self, ansi_color: bool) -> None: + """Enable or disable the truecolor filter when the reactive changes""" + for filter in self._filters: + if isinstance(filter, ANSIToTruecolor): + filter.enabled = not ansi_color + + def animate( + self, + attribute: str, + value: float | Animatable, + *, + final_value: object = ..., + duration: float | None = None, + speed: float | None = None, + delay: float = 0.0, + easing: EasingFunction | str = DEFAULT_EASING, + on_complete: CallbackType | None = None, + level: AnimationLevel = "full", + ) -> None: + """Animate an attribute. + + See the guide for how to use the [animation](/guide/animation) system. + + Args: + attribute: Name of the attribute to animate. + value: The value to animate to. + final_value: The final value of the animation. + duration: The duration (in seconds) of the animation. + speed: The speed of the animation. + delay: A delay (in seconds) before the animation starts. + easing: An easing method. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + """ + self._animate( + attribute, + value, + final_value=final_value, + duration=duration, + speed=speed, + delay=delay, + easing=easing, + on_complete=on_complete, + level=level, + ) + + async def stop_animation(self, attribute: str, complete: bool = True) -> None: + """Stop an animation on an attribute. + + Args: + attribute: Name of the attribute whose animation should be stopped. + complete: Should the animation be set to its final value? + + Note: + If there is no animation scheduled or running, this is a no-op. + """ + await self._animator.stop_animation(self, attribute, complete) + + @property + def is_dom_root(self) -> bool: + """Is this a root node (i.e. the App)?""" + return True + + @property + def is_attached(self) -> bool: + """Is this node linked to the app through the DOM?""" + return True + + @property + def debug(self) -> bool: + """Is debug mode enabled?""" + return "debug" in self.features or constants.DEBUG + + @property + def is_headless(self) -> bool: + """Is the app running in 'headless' mode? + + Headless mode is used when running tests with [run_test][textual.app.App.run_test]. + """ + return False if self._driver is None else self._driver.is_headless + + @property + def is_inline(self) -> bool: + """Is the app running in 'inline' mode?""" + return False if self._driver is None else self._driver.is_inline + + @property + def is_web(self) -> bool: + """Is the app running in 'web' mode via a browser?""" + return False if self._driver is None else self._driver.is_web + + @property + def screen_stack(self) -> list[Screen[Any]]: + """A snapshot of the current screen stack. + + Returns: + A snapshot of the current state of the screen stack. + """ + return self._screen_stacks[self._current_mode].copy() + + @property + def _screen_stack(self) -> list[Screen[Any]]: + """A reference to the current screen stack. + + Note: + Consider using [`screen_stack`][textual.app.App.screen_stack] instead. + + Returns: + A reference to the current screen stack. + """ + return self._screen_stacks[self._current_mode] + + @property + def current_mode(self) -> str: + """The name of the currently active mode.""" + return self._current_mode + + @property + def console_options(self) -> ConsoleOptions: + """Get options for the Rich console. + + Returns: + Console options (same object returned from `console.options`). + """ + size = ConsoleDimensions(*self.size) + console = self.console + return ConsoleOptions( + max_height=size.height, + size=size, + legacy_windows=console.legacy_windows, + min_width=1, + max_width=size.width, + encoding=console.encoding, + is_terminal=console.is_terminal, + ) + + def get_screen_stack(self, mode: str | None = None) -> list[Screen]: + """Get the screen stack for the given mode, or the current mode if no mode is specified. + + Args: + mode: Name of a model + + Raises: + KeyError: If there is no mode. + + Returns: + A list of screens. Note that this is a copy, and modifying the list will not impact the app's screen stack. + """ + if mode is None: + mode = self._current_mode + try: + stack = self._screen_stacks[mode] + except KeyError: + raise KeyError(f"No mode called {mode!r}") from None + return stack.copy() + + def exit( + self, + result: ReturnType | None = None, + return_code: int = 0, + message: RenderableType | None = None, + ) -> None: + """Exit the app, and return the supplied result. + + Args: + result: Return value. + return_code: The return code. Use non-zero values for error codes. + message: Optional message to display on exit. + """ + self._exit = True + self._return_value = result + self._return_code = return_code + self.post_message(messages.ExitApp()) + if message: + self._exit_renderables.append(message) + + @property + def focused(self) -> Widget | None: + """The widget that is focused on the currently active screen, or `None`. + + Focused widgets receive keyboard input. + + Returns: + The currently focused widget, or `None` if nothing is focused. + """ + focused = self.screen.focused + if focused is not None and focused.loading: + return None + return focused + + @property + def active_bindings(self) -> dict[str, ActiveBinding]: + """Get currently active bindings. + + If no widget is focused, then app-level bindings are returned. + If a widget is focused, then any bindings present in the active screen and app are merged and returned. + + This property may be used to inspect current bindings. + + Returns: + A dict that maps keys on to binding information. + """ + return self.screen.active_bindings + + def get_system_commands(self, screen: Screen) -> Iterable[SystemCommand]: + """A generator of system commands used in the command palette. + + Args: + screen: The screen where the command palette was invoked from. + + Implement this method in your App subclass if you want to add custom commands. + Here is an example: + + ```python + def get_system_commands(self, screen: Screen) -> Iterable[SystemCommand]: + yield from super().get_system_commands(screen) + yield SystemCommand("Bell", "Ring the bell", self.bell) + ``` + + !!! note + Requires that [`SystemCommandsProvider`][textual.system_commands.SystemCommandsProvider] is in `App.COMMANDS` class variable. + + Yields: + [SystemCommand][textual.app.SystemCommand] instances. + """ + if not self.ansi_color: + yield SystemCommand( + "Theme", + "Change the current theme", + self.action_change_theme, + ) + yield SystemCommand( + "Quit", + "Quit the application as soon as possible", + self.action_quit, + ) + + if screen.query("HelpPanel"): + yield SystemCommand( + "Keys", + "Hide the keys and widget help panel", + self.action_hide_help_panel, + ) + else: + yield SystemCommand( + "Keys", + "Show help for the focused widget and a summary of available keys", + self.action_show_help_panel, + ) + + if screen.maximized is not None: + yield SystemCommand( + "Minimize", + "Minimize the widget and restore to normal size", + screen.action_minimize, + ) + elif screen.focused is not None and screen.focused.allow_maximize: + yield SystemCommand( + "Maximize", "Maximize the focused widget", screen.action_maximize + ) + + yield SystemCommand( + "Screenshot", + "Save an SVG 'screenshot' of the current screen", + lambda: self.set_timer(0.1, self.deliver_screenshot), + ) + + def get_default_screen(self) -> Screen: + """Get the default screen. + + This is called when the App is first composed. The returned screen instance + will be the first screen on the stack. + + Implement this method if you would like to use a custom Screen as the default screen. + + Returns: + A screen instance. + """ + return Screen(id="_default") + + def compose(self) -> ComposeResult: + """Yield child widgets for a container. + + This method should be implemented in a subclass. + """ + yield from () + + def get_theme_variable_defaults(self) -> dict[str, str]: + """Get the default values for the `variables` used in a theme. + + If the currently specified theme doesn't define a value for a variable, + the value specified here will be used as a fallback. + + If a variable is referenced in CSS but does not appear either here + or in the theme, the CSS will fail to parse on startup. + + This method allows applications to define their own variables, beyond + those offered by Textual, which can then be overridden by a Theme. + + Returns: + A mapping of variable name (e.g. "my-button-background-color") to value. + Values can be any valid CSS value, e.g. "red 50%", "auto 90%", + "#ff0000", "rgb(255, 0, 0)", etc. + """ + return {} + + def get_css_variables(self) -> dict[str, str]: + """Get a mapping of variables used to pre-populate CSS. + + May be implemented in a subclass to add new CSS variables. + + Returns: + A mapping of variable name to value. + """ + theme = self.current_theme + # Build the Textual color system from the theme. + # This will contain $secondary, $primary, $background, etc. + variables = theme.to_color_system().generate() + # Apply the additional variables from the theme + variables = {**variables, **(theme.variables)} + theme_variables = self.get_theme_variable_defaults() + + combined_variables = {**theme_variables, **variables} + self.theme_variables = combined_variables + return combined_variables + + def get_theme(self, theme_name: str) -> Theme | None: + """Get a theme by name. + + Args: + theme_name: The name of the theme to get. May also be a comma + separated list of names, to pick the first available theme. + + Returns: + A Theme instance and None if the theme doesn't exist. + """ + theme_names = [token.strip() for token in theme_name.split(",")] + for theme_name in theme_names: + if theme_name in self.available_themes: + return self.available_themes[theme_name] + return None + + def register_theme(self, theme: Theme) -> None: + """Register a theme with the app. + + If the theme already exists, it will be overridden. + + After registering a theme, you can activate it by setting the + `App.theme` attribute. To retrieve a registered theme, use the + `App.get_theme` method. + + Args: + theme: The theme to register. + """ + self._registered_themes[theme.name] = theme + + def unregister_theme(self, theme_name: str) -> None: + """Unregister a theme with the app. + + Args: + theme_name: The name of the theme to unregister. + """ + if theme_name in self._registered_themes: + del self._registered_themes[theme_name] + + @property + def available_themes(self) -> dict[str, Theme]: + """All available themes (all built-in themes plus any that have been registered). + + A dictionary mapping theme names to Theme instances. + """ + return {**self._registered_themes} + + @property + def current_theme(self) -> Theme: + theme = self.get_theme(self.theme) + if theme is None: + theme = self.get_theme("textual-dark") + assert theme is not None # validated by _validate_theme + return theme + + def _validate_theme(self, theme_name: str) -> str: + if theme_name not in self.available_themes: + message = ( + f"Theme {theme_name!r} has not been registered. " + "Call 'App.register_theme' before setting the 'App.theme' attribute." + ) + raise InvalidThemeError(message) + return theme_name + + def _watch_theme(self, theme_name: str) -> None: + """Apply a theme to the application. + + This method is called when the theme reactive attribute is set. + """ + theme = self.current_theme + dark = theme.dark + self.ansi_color = theme_name == "textual-ansi" + self.set_class(dark, "-dark-mode", update=False) + self.set_class(not dark, "-light-mode", update=False) + self._refresh_truecolor_filter(self.ansi_theme) + self._invalidate_css() + self.call_next(partial(self.refresh_css, animate=False)) + self.call_next(self.theme_changed_signal.publish, theme) + + def _invalidate_css(self) -> None: + """Invalidate CSS, so it will be refreshed.""" + self._css_update_count += 1 + + def watch_ansi_theme_dark(self, theme: TerminalTheme) -> None: + if self.current_theme.dark: + self._refresh_truecolor_filter(theme) + self._invalidate_css() + self.call_next(self.refresh_css) + + def watch_ansi_theme_light(self, theme: TerminalTheme) -> None: + if not self.current_theme.dark: + self._refresh_truecolor_filter(theme) + self._invalidate_css() + self.call_next(self.refresh_css) + + @property + def ansi_theme(self) -> TerminalTheme: + """The ANSI TerminalTheme currently being used. + + Defines how colors defined as ANSI (e.g. `magenta`) inside Rich renderables + are mapped to hex codes. + """ + return ( + self.ansi_theme_dark if self.current_theme.dark else self.ansi_theme_light + ) + + def _refresh_truecolor_filter(self, theme: TerminalTheme) -> None: + """Update the ANSI to Truecolor filter, if available, with a new theme mapping. + + Args: + theme: The new terminal theme to use for mapping ANSI to truecolor. + """ + filters = self._filters + for index, filter in enumerate(filters): + if isinstance(filter, ANSIToTruecolor): + filters[index] = ANSIToTruecolor(theme, enabled=not self.ansi_color) + return + + def get_driver_class(self) -> Type[Driver]: + """Get a driver class for this platform. + + This method is called by the constructor, and unlikely to be required when + building a Textual app. + + Returns: + A Driver class which manages input and display. + """ + + driver_class: Type[Driver] + + driver_import = constants.DRIVER + if driver_import is not None: + # The driver class is set from the environment + # Syntax should be foo.bar.baz:MyDriver + module_import, _, driver_symbol = driver_import.partition(":") + driver_module = importlib.import_module(module_import) + driver_class = getattr(driver_module, driver_symbol) + if not inspect.isclass(driver_class) or not issubclass( + driver_class, Driver + ): + raise RuntimeError( + f"Unable to import {driver_import!r}; {driver_class!r} is not a Driver class " + ) + return driver_class + + if WINDOWS: + from memray._vendor.textual.drivers.windows_driver import WindowsDriver + + driver_class = WindowsDriver + else: + from memray._vendor.textual.drivers.linux_driver import LinuxDriver + + driver_class = LinuxDriver + return driver_class + + def __rich_repr__(self) -> rich.repr.Result: + yield "title", self.title + yield "id", self.id, None + if self.name: + yield "name", self.name + if self.classes: + yield "classes", set(self.classes) + pseudo_classes = self.pseudo_classes + if pseudo_classes: + yield "pseudo_classes", set(pseudo_classes) + + @property + def animator(self) -> Animator: + """The animator object.""" + return self._animator + + @property + def screen(self) -> Screen[object]: + """The current active screen. + + Returns: + The currently active (visible) screen. + + Raises: + ScreenStackError: If there are no screens on the stack. + """ + try: + return self._screen_stack[-1] + except KeyError: + raise UnknownModeError(f"No known mode {self._current_mode!r}") from None + except IndexError: + raise ScreenStackError("No screens on stack") from None + + @property + def _background_screens(self) -> list[Screen]: + """A list of screens that may be visible due to background opacity (top-most first, not including current screen).""" + screens: list[Screen] = [] + for screen in reversed(self._screen_stack[:-1]): + screens.append(screen) + if screen.styles.background.a == 1: + break + background_screens = screens[::-1] + return background_screens + + @property + def size(self) -> Size: + """The size of the terminal. + + Returns: + Size of the terminal. + """ + if self._size is not None: + return self._size + if self._driver is not None and self._driver._size is not None: + width, height = self._driver._size + else: + width, height = self.console.size + return Size(width, height) + + @property + def viewport_size(self) -> Size: + """Get the viewport size (size of the screen).""" + try: + return self.screen.size + except (ScreenStackError, NoScreen): + return self.size + + def _get_inline_height(self) -> int: + """Get the inline height (height when in inline mode). + + Returns: + Height in lines. + """ + size = self.size + return max(screen._get_inline_height(size) for screen in self._screen_stack) + + @property + def log(self) -> Logger: + """The textual logger. + + Example: + ```python + self.log("Hello, World!") + self.log(self.tree) + ``` + + Returns: + A Textual logger. + """ + return self._logger + + def _log( + self, + group: LogGroup, + verbosity: LogVerbosity, + _textual_calling_frame: inspect.Traceback, + *objects: Any, + **kwargs, + ) -> None: + """Write to logs or devtools. + + Positional args will be logged. Keyword args will be prefixed with the key. + + Example: + ```python + data = [1,2,3] + self.log("Hello, World", state=data) + self.log(self.tree) + self.log(locals()) + ``` + + Args: + verbosity: Verbosity level 0-3. + """ + + devtools = self.devtools + if devtools is None or not devtools.is_connected: + return + + if verbosity.value > LogVerbosity.NORMAL.value and not devtools.verbose: + return + + try: + from textual_dev.client import DevtoolsLog + + if len(objects) == 1 and not kwargs: + devtools.log( + DevtoolsLog(objects, caller=_textual_calling_frame), + group, + verbosity, + ) + else: + output = " ".join(str(arg) for arg in objects) + if kwargs: + key_values = " ".join( + f"{key}={value!r}" for key, value in kwargs.items() + ) + output = f"{output} {key_values}" if output else key_values + devtools.log( + DevtoolsLog(output, caller=_textual_calling_frame), + group, + verbosity, + ) + except Exception as error: + self._handle_exception(error) + + def get_loading_widget(self) -> Widget: + """Get a widget to be used as a loading indicator. + + Extend this method if you want to display the loading state a little differently. + + Returns: + A widget to display a loading state. + """ + from memray._vendor.textual.widgets import LoadingIndicator + + return LoadingIndicator() + + def copy_to_clipboard(self, text: str) -> None: + """Copy text to the clipboard. + + !!! note + + This does not work on macOS Terminal, but will work on most other terminals. + + Args: + text: Text you wish to copy to the clipboard. + """ + self._clipboard = text + if self._driver is None: + return + import base64 + + base64_text = base64.b64encode(text.encode("utf-8")).decode("utf-8") + self._driver.write(f"\x1b]52;c;{base64_text}\a") + + def call_from_thread( + self, + callback: Callable[..., CallThreadReturnType | Awaitable[CallThreadReturnType]], + *args: Any, + **kwargs: Any, + ) -> CallThreadReturnType: + """Run a callable from another thread, and return the result. + + Like asyncio apps in general, Textual apps are not thread-safe. If you call methods + or set attributes on Textual objects from a thread, you may get unpredictable results. + + This method will ensure that your code runs within the correct context. + + !!! tip + + Consider using [post_message][textual.message_pump.MessagePump.post_message] which is also thread-safe. + + Args: + callback: A callable to run. + *args: Arguments to the callback. + **kwargs: Keyword arguments for the callback. + + Raises: + RuntimeError: If the app isn't running or if this method is called from the same + thread where the app is running. + + Returns: + The result of the callback. + """ + + if self._loop is None: + raise RuntimeError("App is not running") + + if self._thread_id == threading.get_ident(): + raise RuntimeError( + "The `call_from_thread` method must run in a different thread from the app" + ) + + callback_with_args = partial(callback, *args, **kwargs) + + async def run_callback() -> CallThreadReturnType: + """Run the callback, set the result or error on the future.""" + with self._context(): + return await invoke(callback_with_args) + + # Post the message to the main loop + future: Future[CallThreadReturnType] = asyncio.run_coroutine_threadsafe( + run_callback(), loop=self._loop + ) + result = future.result() + return result + + def action_change_theme(self) -> None: + """An [action](/guide/actions) to change the current theme.""" + self.search_themes() + + def action_screenshot( + self, filename: str | None = None, path: str | None = None + ) -> None: + """This [action](/guide/actions) will save an SVG file containing the current contents of the screen. + + Args: + filename: Filename of screenshot, or None to auto-generate. + path: Path to directory. Defaults to the user's Downloads directory. + """ + self.deliver_screenshot(filename, path) + + def export_screenshot( + self, + *, + title: str | None = None, + simplify: bool = False, + ) -> str: + """Export an SVG screenshot of the current screen. + + See also [save_screenshot][textual.app.App.save_screenshot] which writes the screenshot to a file. + + Args: + title: The title of the exported screenshot or None + to use app title. + simplify: Simplify the segments by combining contiguous segments with the same style. + """ + assert self._driver is not None, "App must be running" + width, height = self.size + + console = Console( + width=width, + height=height, + file=io.StringIO(), + force_terminal=True, + color_system="truecolor", + record=True, + legacy_windows=False, + safe_box=False, + ) + screen_render = self.screen._compositor.render_update( + full=True, screen_stack=self.app._background_screens, simplify=simplify + ) + console.print(screen_render) + return console.export_svg(title=title or self.title) + + def save_screenshot( + self, + filename: str | None = None, + path: str | None = None, + time_format: str | None = None, + ) -> str: + """Save an SVG screenshot of the current screen. + + Args: + filename: Filename of SVG screenshot, or None to auto-generate + a filename with the date and time. + path: Path to directory for output. Defaults to current working directory. + time_format: Date and time format to use if filename is None. + Defaults to a format like ISO 8601 with some reserved characters replaced with underscores. + + Returns: + Filename of screenshot. + """ + path = path or "./" + if not filename: + svg_filename = generate_datetime_filename(self.title, ".svg", time_format) + else: + svg_filename = filename + svg_path = os.path.expanduser(os.path.join(path, svg_filename)) + screenshot_svg = self.export_screenshot() + with open(svg_path, "w", encoding="utf-8") as svg_file: + svg_file.write(screenshot_svg) + return svg_path + + def deliver_screenshot( + self, + filename: str | None = None, + path: str | None = None, + time_format: str | None = None, + ) -> str | None: + """Deliver a screenshot of the app. + + This will save the screenshot when running locally, or serve it when the app + is running in a web browser. + + Args: + filename: Filename of SVG screenshot, or None to auto-generate + a filename with the date and time. + path: Path to directory for output when saving locally (not used when app is running in the browser). + Defaults to current working directory. + time_format: Date and time format to use if filename is None. + Defaults to a format like ISO 8601 with some reserved characters replaced with underscores. + + Returns: + The delivery key that uniquely identifies the file delivery. + """ + if not filename: + svg_filename = generate_datetime_filename(self.title, ".svg", time_format) + else: + svg_filename = filename + screenshot_svg = self.export_screenshot() + return self.deliver_text( + io.StringIO(screenshot_svg), + save_directory=path, + save_filename=svg_filename, + open_method="browser", + mime_type="image/svg+xml", + name="screenshot", + ) + + def search_commands( + self, + commands: Sequence[CommandListItem], + placeholder: str = "Search for commands…", + ) -> AwaitMount: + """Show a list of commands in the app. + + Args: + commands: A list of SimpleCommand instances. + placeholder: Placeholder text for the search field. + + Returns: + AwaitMount: An awaitable that resolves when the commands are shown. + """ + return self.push_screen( + CommandPalette( + providers=[SimpleProvider(self.screen, commands)], + placeholder=placeholder, + ) + ) + + def search_themes(self) -> None: + """Show a fuzzy search command palette containing all registered themes. + + Selecting a theme in the list will change the app's theme. + """ + self.push_screen( + CommandPalette( + providers=[ThemeProvider], + placeholder="Search for themes…", + ), + ) + + def bind( + self, + keys: str, + action: str, + *, + description: str = "", + show: bool = True, + key_display: str | None = None, + ) -> None: + """Bind a key to an action. + + !!! warning + This method may be private or removed in a future version of Textual. + See [dynamic actions](/guide/actions#dynamic-actions) for a more flexible alternative to updating bindings. + + Args: + keys: A comma separated list of keys, i.e. + action: Action to bind to. + description: Short description of action. + show: Show key in UI. + key_display: Replacement text for key, or None to use default. + """ + self._bindings.bind( + keys, action, description, show=show, key_display=key_display + ) + + def get_key_display(self, binding: Binding) -> str: + """Format a bound key for display in footer / key panel etc. + + !!! note + You can implement this in a subclass if you want to change how keys are displayed in your app. + + Args: + binding: A Binding. + + Returns: + A string used to represent the key. + """ + # Dev has overridden the key display, so use that + if binding.key_display: + return binding.key_display + + # Extract modifiers + modifiers, key = binding.parse_key() + + # Format the key (replace unicode names with character) + key = format_key(key) + + # Convert ctrl modifier to caret + if "ctrl" in modifiers: + modifiers.pop(modifiers.index("ctrl")) + key = f"^{key}" + # Join everything with + + key_tokens = modifiers + [key] + return "+".join(key_tokens) + + async def _press_keys(self, keys: Iterable[str]) -> None: + """A task to send key events.""" + import unicodedata + + app = self + driver = app._driver + assert driver is not None + for key in keys: + if key.startswith("wait:"): + _, wait_ms = key.split(":") + await asyncio.sleep(float(wait_ms) / 1000) + await app._animator.wait_until_complete() + else: + if len(key) == 1 and not key.isalnum(): + key = _character_to_key(key) + original_key = REPLACED_KEYS.get(key, key) + char: str | None + try: + char = unicodedata.lookup(_get_unicode_name_from_key(original_key)) + except KeyError: + char = key if len(key) == 1 else None + key_event = events.Key(key, char) + key_event.set_sender(app) + driver.send_message(key_event) + await wait_for_idle(0) + await app._animator.wait_until_complete() + await wait_for_idle(0) + + def _flush(self, stderr: bool = False) -> None: + """Called when stdout or stderr is flushed. + + Args: + stderr: True if the print was to stderr, or False for stdout. + + """ + if self._devtools_redirector is not None: + self._devtools_redirector.flush() + + def _print(self, text: str, stderr: bool = False) -> None: + """Called with captured print. + + Dispatches printed content to appropriate destinations: devtools, + widgets currently capturing output, stdout/stderr. + + Args: + text: Text that has been printed. + stderr: True if the print was to stderr, or False for stdout. + """ + if self._devtools_redirector is not None: + current_frame = inspect.currentframe() + self._devtools_redirector.write( + text, current_frame.f_back if current_frame is not None else None + ) + + # If we're in headless mode, we want printed text to still reach stdout/stderr. + if self.is_headless: + target_stream = self._original_stderr if stderr else self._original_stdout + target_stream.write(text) + + # Send Print events to all widgets that are currently capturing output. + for target, (_stdout, _stderr) in self._capture_print.items(): + if (_stderr and stderr) or (_stdout and not stderr): + target.post_message(events.Print(text, stderr=stderr)) + + def begin_capture_print( + self, target: MessageTarget, stdout: bool = True, stderr: bool = True + ) -> None: + """Capture content that is printed (or written to stdout / stderr). + + If printing is captured, the `target` will be sent an [events.Print][textual.events.Print] message. + + Args: + target: The widget where print content will be sent. + stdout: Capture stdout. + stderr: Capture stderr. + """ + if not stdout and not stderr: + self.end_capture_print(target) + else: + self._capture_print[target] = (stdout, stderr) + + def end_capture_print(self, target: MessageTarget) -> None: + """End capturing of prints. + + Args: + target: The widget that was capturing prints. + """ + self._capture_print.pop(target) + + @asynccontextmanager + async def run_test( + self, + *, + headless: bool = True, + size: tuple[int, int] | None = (80, 24), + tooltips: bool = False, + notifications: bool = False, + message_hook: Callable[[Message], None] | None = None, + ) -> AsyncGenerator[Pilot[ReturnType], None]: + """An asynchronous context manager for testing apps. + + !!! tip + + See the guide for [testing](/guide/testing) Textual apps. + + Use this to run your app in "headless" mode (no output) and drive the app via a [Pilot][textual.pilot.Pilot] object. + + Example: + + ```python + async with app.run_test() as pilot: + await pilot.click("#Button.ok") + assert ... + ``` + + Args: + headless: Run in headless mode (no output or input). + size: Force terminal size to `(WIDTH, HEIGHT)`, + or None to auto-detect. + tooltips: Enable tooltips when testing. + notifications: Enable notifications when testing. + message_hook: An optional callback that will be called each time any message arrives at any + message pump in the app. + """ + from memray._vendor.textual.pilot import Pilot + + app = self + app._disable_tooltips = not tooltips + app._disable_notifications = not notifications + app_ready_event = asyncio.Event() + + def on_app_ready() -> None: + """Called when app is ready to process events.""" + app_ready_event.set() + + async def run_app(app: App[ReturnType]) -> None: + """Run the apps message loop. + + Args: + app: App to run. + """ + + with app._context(): + try: + if message_hook is not None: + message_hook_context_var.set(message_hook) + app._loop = asyncio.get_running_loop() + app._thread_id = threading.get_ident() + await app._process_messages( + ready_callback=on_app_ready, + headless=headless, + terminal_size=size, + ) + finally: + app_ready_event.set() + + # Launch the app in the "background" + + self._task = app_task = create_task(run_app(app), name=f"run_test {app}") + + # Wait until the app has performed all startup routines. + await app_ready_event.wait() + with app._context(): + # Context manager returns pilot object to manipulate the app + try: + pilot = Pilot(app) + await pilot._wait_for_screen() + yield pilot + finally: + await asyncio.sleep(0) + # Shutdown the app cleanly + await app._shutdown() + await app_task + # Re-raise the exception which caused panic so test frameworks are aware + if self._exception: + raise self._exception + + async def run_async( + self, + *, + headless: bool = False, + inline: bool = False, + inline_no_clear: bool = False, + mouse: bool = True, + size: tuple[int, int] | None = None, + auto_pilot: AutopilotCallbackType | None = None, + ) -> ReturnType | None: + """Run the app asynchronously. + + Args: + headless: Run in headless mode (no output). + inline: Run the app inline (under the prompt). + inline_no_clear: Don't clear the app output when exiting an inline app. + mouse: Enable mouse support. + size: Force terminal size to `(WIDTH, HEIGHT)`, + or None to auto-detect. + auto_pilot: An autopilot coroutine. + + Returns: + App return value. + """ + from memray._vendor.textual.pilot import Pilot + + app = self + auto_pilot_task: Task | None = None + + if auto_pilot is None and constants.PRESS: + keys = constants.PRESS.split(",") + + async def press_keys(pilot: Pilot[ReturnType]) -> None: + """Auto press keys.""" + await pilot.press(*keys) + + auto_pilot = press_keys + + async def app_ready() -> None: + """Called by the message loop when the app is ready.""" + nonlocal auto_pilot_task + + if auto_pilot is not None: + + async def run_auto_pilot( + auto_pilot: AutopilotCallbackType, pilot: Pilot + ) -> None: + with self._context(): + try: + await auto_pilot(pilot) + except Exception: + app.exit() + raise + + pilot = Pilot(app) + auto_pilot_task = create_task( + run_auto_pilot(auto_pilot, pilot), name=repr(pilot) + ) + + self._thread_init() + + loop = app._loop = asyncio.get_running_loop() + if hasattr(asyncio, "eager_task_factory"): + loop.set_task_factory(asyncio.eager_task_factory) + with app._context(): + try: + await app._process_messages( + ready_callback=None if auto_pilot is None else app_ready, + headless=headless, + inline=inline, + inline_no_clear=inline_no_clear, + mouse=mouse, + terminal_size=size, + ) + finally: + try: + if auto_pilot_task is not None: + await auto_pilot_task + finally: + try: + await asyncio.shield(app._shutdown()) + except asyncio.CancelledError: + pass + app._loop = None + app._thread_id = 0 + + return app.return_value + + def run( + self, + *, + headless: bool = False, + inline: bool = False, + inline_no_clear: bool = False, + mouse: bool = True, + size: tuple[int, int] | None = None, + auto_pilot: AutopilotCallbackType | None = None, + loop: AbstractEventLoop | None = None, + ) -> ReturnType | None: + """Run the app. + + Args: + headless: Run in headless mode (no output). + inline: Run the app inline (under the prompt). + inline_no_clear: Don't clear the app output when exiting an inline app. + mouse: Enable mouse support. + size: Force terminal size to `(WIDTH, HEIGHT)`, + or None to auto-detect. + auto_pilot: An auto pilot coroutine. + loop: Asyncio loop instance, or `None` to use default. + Returns: + App return value. + """ + + async def run_app() -> ReturnType | None: + """Run the app.""" + return await self.run_async( + headless=headless, + inline=inline, + inline_no_clear=inline_no_clear, + mouse=mouse, + size=size, + auto_pilot=auto_pilot, + ) + + if loop is None: + if _ASYNCIO_GET_EVENT_LOOP_IS_DEPRECATED: + # N.B. This does work with Python<3.10, but global Locks, Events, etc + # eagerly bind the event loop, and result in Future bound to wrong + # loop errors. + return asyncio.run(run_app()) + try: + global_loop = asyncio.get_event_loop() + except RuntimeError: + # the global event loop may have been destroyed by someone running + # asyncio.run(), or asyncio.set_event_loop(None), in which case + # we need to use asyncio.run() also. (We run this outside the + # context of an exception handler) + pass + else: + return global_loop.run_until_complete(run_app()) + return asyncio.run(run_app()) + return loop.run_until_complete(run_app()) + + async def _on_css_change(self) -> None: + """Callback for the file monitor, called when CSS files change.""" + css_paths = ( + self.css_monitor._paths if self.css_monitor is not None else self.css_path + ) + if css_paths: + try: + time = perf_counter() + stylesheet = self.stylesheet.copy() + try: + stylesheet.read_all(css_paths) + except StylesheetError as error: + # If one of the CSS paths is no longer available (or perhaps temporarily unavailable), + # we'll end up with partial CSS, which is probably confusing more than anything. We opt to do + # nothing here, knowing that we'll retry again very soon, on the next file monitor invocation. + # Related issue: https://github.com/Textualize/textual/issues/3996 + self.log.warning(str(error)) + return + stylesheet.parse() + elapsed = (perf_counter() - time) * 1000 + if self._css_has_errors: + from rich.panel import Panel + + self.log.system( + Panel( + "CSS files successfully loaded after previous error:\n\n- " + + "\n- ".join(str(path) for path in css_paths), + style="green", + border_style="green", + ) + ) + self.log.system( + f" loaded {len(css_paths)} CSS files in {elapsed:.0f} ms" + ) + except Exception as error: + # TODO: Catch specific exceptions + self._css_has_errors = True + self.log.error(error) + self.bell() + else: + self._css_has_errors = False + self.stylesheet = stylesheet + self.stylesheet.update(self) + for screen in self.screen_stack: + self.stylesheet.update(screen) + + def render(self) -> RenderResult: + """Render method, inherited from widget, to render the screen's background. + + May be overridden to customize background visuals. + + """ + return Blank(self.styles.background) + + ExpectType = TypeVar("ExpectType", bound=Widget) + + if TYPE_CHECKING: + + @overload + def get_child_by_id(self, id: str) -> Widget: ... + + @overload + def get_child_by_id( + self, id: str, expect_type: type[ExpectType] + ) -> ExpectType: ... + + def get_child_by_id( + self, id: str, expect_type: type[ExpectType] | None = None + ) -> ExpectType | Widget: + """Get the first child (immediate descendant) of this DOMNode with the given ID. + + Args: + id: The ID of the node to search for. + expect_type: Require the object be of the supplied type, + or use `None` to apply no type restriction. + + Returns: + The first child of this node with the specified ID. + + Raises: + NoMatches: If no children could be found for this ID. + WrongType: If the wrong type was found. + """ + return ( + self.screen.get_child_by_id(id) + if expect_type is None + else self.screen.get_child_by_id(id, expect_type) + ) + + if TYPE_CHECKING: + + @overload + def get_widget_by_id(self, id: str) -> Widget: ... + + @overload + def get_widget_by_id( + self, id: str, expect_type: type[ExpectType] + ) -> ExpectType: ... + + def get_widget_by_id( + self, id: str, expect_type: type[ExpectType] | None = None + ) -> ExpectType | Widget: + """Get the first descendant widget with the given ID. + + Performs a breadth-first search rooted at the current screen. + It will not return the Screen if that matches the ID. + To get the screen, use `self.screen`. + + Args: + id: The ID to search for in the subtree + expect_type: Require the object be of the supplied type, or None for any type. + Defaults to None. + + Returns: + The first descendant encountered with this ID. + + Raises: + NoMatches: if no children could be found for this ID + WrongType: if the wrong type was found. + """ + return ( + self.screen.get_widget_by_id(id) + if expect_type is None + else self.screen.get_widget_by_id(id, expect_type) + ) + + def get_child_by_type(self, expect_type: type[ExpectType]) -> ExpectType: + """Get a child of a give type. + + Args: + expect_type: The type of the expected child. + + Raises: + NoMatches: If no valid child is found. + + Returns: + A widget. + """ + return self.screen.get_child_by_type(expect_type) + + def update_styles(self, node: DOMNode, animate: bool = True) -> None: + """Immediately update the styles of this node and all descendant nodes. + + Called by Textual whenever CSS classes / pseudo classes change. + For example, when you hover over a button, the :hover pseudo class + will be added, and this method is called to apply the corresponding + :hover styles. + + Args: + node: Node to update. + animate: Enable animation? + """ + if isinstance(node, App): + for screen in reversed(self.screen_stack): + screen.update_node_styles(animate=animate) + if not (screen.is_modal and screen.styles.background.a < 1): + break + else: + descendants = node.walk_children(with_self=True) + self.stylesheet.update_nodes(descendants, animate=animate) + + def mount( + self, + *widgets: Widget, + before: int | str | Widget | None = None, + after: int | str | Widget | None = None, + ) -> AwaitMount: + """Mount the given widgets relative to the app's screen. + + Args: + *widgets: The widget(s) to mount. + before: Optional location to mount before. An `int` is the index + of the child to mount before, a `str` is a `query_one` query to + find the widget to mount before. + after: Optional location to mount after. An `int` is the index + of the child to mount after, a `str` is a `query_one` query to + find the widget to mount after. + + Returns: + An awaitable object that waits for widgets to be mounted. + + Raises: + MountError: If there is a problem with the mount request. + + Note: + Only one of `before` or `after` can be provided. If both are + provided a `MountError` will be raised. + """ + return self.screen.mount(*widgets, before=before, after=after) + + def mount_all( + self, + widgets: Iterable[Widget], + *, + before: int | str | Widget | None = None, + after: int | str | Widget | None = None, + ) -> AwaitMount: + """Mount widgets from an iterable. + + Args: + widgets: An iterable of widgets. + before: Optional location to mount before. An `int` is the index + of the child to mount before, a `str` is a `query_one` query to + find the widget to mount before. + after: Optional location to mount after. An `int` is the index + of the child to mount after, a `str` is a `query_one` query to + find the widget to mount after. + + Returns: + An awaitable object that waits for widgets to be mounted. + + Raises: + MountError: If there is a problem with the mount request. + + Note: + Only one of `before` or `after` can be provided. If both are + provided a `MountError` will be raised. + """ + return self.mount(*widgets, before=before, after=after) + + def _init_mode(self, mode: str) -> AwaitMount: + """Do internal initialization of a new screen stack mode. + + Args: + mode: Name of the mode. + + Returns: + An optionally awaitable object which can be awaited until the screen + associated with the mode has been mounted. + """ + + stack = self._screen_stacks.get(mode, []) + if stack: + # Mode already exists + # Return an dummy await + return AwaitMount(stack[0], []) + + if mode in self._modes: + # Mode is defined in MODES + _screen = self._modes[mode] + if isinstance(_screen, Screen): + raise TypeError( + "MODES cannot contain instances, use a type instead " + f"(got instance of {type(_screen).__name__} for {mode!r})" + ) + new_screen: Screen | str = _screen() if callable(_screen) else _screen + screen, await_mount = self._get_screen(new_screen) + stack.append(screen) + self._load_screen_css(screen) + if screen._css_update_count != self._css_update_count: + self.refresh_css() + + screen.post_message(events.ScreenResume()) + else: + # Mode is not defined + screen = self.get_default_screen() + stack.append(screen) + self._register(self, screen) + screen.post_message(events.ScreenResume()) + await_mount = AwaitMount(stack[0], []) + + screen._screen_resized(self.size) + + self._screen_stacks[mode] = stack + return await_mount + + def switch_mode(self, mode: str) -> AwaitMount: + """Switch to a given mode. + + Args: + mode: The mode to switch to. + + Returns: + An optionally awaitable object which waits for the screen associated + with the mode to be mounted. + + Raises: + UnknownModeError: If trying to switch to an unknown mode. + + """ + + if mode == self._current_mode: + return AwaitMount(self.screen, []) + + if mode not in self._modes: + raise UnknownModeError(f"No known mode {mode!r}") + + self.delay_update() + + self.screen.post_message(events.ScreenSuspend()) + self.screen.refresh() + + if mode not in self._screen_stacks: + await_mount = self._init_mode(mode) + else: + await_mount = AwaitMount(self.screen, []) + + self._current_mode = mode + if self.screen._css_update_count != self._css_update_count: + self.refresh_css() + + self.mode_change_signal.publish(mode) + self.screen_change_signal.publish(self.screen) + + self.screen._screen_resized(self.size) + + self.screen.post_message(events.ScreenResume()) + + self.log.system(f"{self._current_mode!r} is the current mode") + self.log.system(f"{self.screen} is active") + + return await_mount + + def add_mode(self, mode: str, base_screen: str | Callable[[], Screen]) -> None: + """Adds a mode and its corresponding base screen to the app. + + Args: + mode: The new mode. + base_screen: The base screen associated with the given mode. + + Raises: + InvalidModeError: If the name of the mode is not valid/duplicated. + """ + if mode == "_default": + raise InvalidModeError("Cannot use '_default' as a custom mode.") + elif mode in self._modes: + raise InvalidModeError(f"Duplicated mode name {mode!r}.") + + if isinstance(base_screen, Screen): + raise TypeError( + "add_mode() must be called with a Screen type, not an instance" + f" (got instance of {type(base_screen).__name__})" + ) + self._modes[mode] = base_screen + + def remove_mode(self, mode: str) -> AwaitComplete: + """Removes a mode from the app. + + Screens that are running in the stack of that mode are scheduled for pruning. + + Args: + mode: The mode to remove. It can't be the active mode. + + Raises: + ActiveModeError: If trying to remove the active mode. + UnknownModeError: If trying to remove an unknown mode. + """ + if mode == self._current_mode: + raise ActiveModeError(f"Can't remove active mode {mode!r}") + elif mode not in self._modes: + raise UnknownModeError(f"Unknown mode {mode!r}") + else: + del self._modes[mode] + + if mode not in self._screen_stacks: + return AwaitComplete.nothing() + + stack = self._screen_stacks[mode] + del self._screen_stacks[mode] + + async def remove_screens() -> None: + """Remove screens.""" + for screen in reversed(stack): + await self._replace_screen(screen) + + return AwaitComplete(remove_screens()).call_next(self) + + def is_screen_installed(self, screen: Screen | str) -> bool: + """Check if a given screen has been installed. + + Args: + screen: Either a Screen object or screen name (the `name` argument when installed). + + Returns: + True if the screen is currently installed, + """ + if isinstance(screen, str): + return screen in self._installed_screens + else: + return screen in self._installed_screens.values() + + @overload + def get_screen(self, screen: ScreenType) -> ScreenType: ... + + @overload + def get_screen(self, screen: str) -> Screen: ... + + @overload + def get_screen( + self, screen: str, screen_class: Type[ScreenType] | None = None + ) -> ScreenType: ... + + @overload + def get_screen( + self, screen: ScreenType, screen_class: Type[ScreenType] | None = None + ) -> ScreenType: ... + + def get_screen( + self, screen: Screen | str, screen_class: Type[Screen] | None = None + ) -> Screen: + """Get an installed screen. + + Example: + ```python + my_screen = self.get_screen("settings", MyScreen) + ``` + + Args: + screen: Either a Screen object or screen name (the `name` argument when installed). + screen_class: Class of expected screen, or `None` for any screen class. + + Raises: + KeyError: If the named screen doesn't exist. + + Returns: + A screen instance. + """ + if isinstance(screen, str): + try: + next_screen = self._installed_screens[screen] + except KeyError: + raise KeyError(f"No screen called {screen!r} installed") from None + if callable(next_screen): + next_screen = next_screen() + self._installed_screens[screen] = next_screen + else: + next_screen = screen + if screen_class is not None and not isinstance(next_screen, screen_class): + raise TypeError( + f"Expected a screen of type {screen_class}, got {type(next_screen)}" + ) + return next_screen + + def _get_screen(self, screen: Screen | str) -> tuple[Screen, AwaitMount]: + """Get an installed screen and an AwaitMount object. + + If the screen isn't running, it will be registered before it is run. + + Args: + screen: Either a Screen object or screen name (the `name` argument when installed). + + Raises: + KeyError: If the named screen doesn't exist. + + Returns: + A screen instance and an awaitable that awaits the children mounting. + """ + _screen = self.get_screen(screen) + if not _screen.is_running: + widgets = self._register(self, _screen) + await_mount = AwaitMount(_screen, widgets) + self.call_next(await_mount) + return (_screen, await_mount) + else: + await_mount = AwaitMount(_screen, []) + self.call_next(await_mount) + return (_screen, await_mount) + + def _load_screen_css(self, screen: Screen): + """Loads the CSS associated with a screen.""" + + if self.css_monitor is not None: + self.css_monitor.add_paths(screen.css_path) + + update = False + for path in screen.css_path: + if not self.stylesheet.has_source(str(path), ""): + self.stylesheet.read(path) + update = True + if screen.CSS: + try: + screen_path = inspect.getfile(screen.__class__) + except (TypeError, OSError): + screen_path = "" + screen_class_var = f"{screen.__class__.__name__}.CSS" + read_from = (screen_path, screen_class_var) + if not self.stylesheet.has_source(screen_path, screen_class_var): + self.stylesheet.add_source( + screen.CSS, + read_from=read_from, + is_default_css=False, + scope=screen._css_type_name if screen.SCOPED_CSS else "", + ) + update = True + if update: + self.stylesheet.reparse() + self.stylesheet.update(self) + + async def _replace_screen(self, screen: Screen) -> Screen: + """Handle the replaced screen. + + Args: + screen: A screen object. + + Returns: + The screen that was replaced. + """ + if self._screen_stack: + self.screen.refresh() + screen.post_message(events.ScreenSuspend()) + self.log.system(f"{screen} SUSPENDED") + if not self.is_screen_installed(screen) and all( + screen not in stack for stack in self._screen_stacks.values() + ): + self.capture_mouse(None) + await screen.remove() + self.log.system(f"{screen} REMOVED") + return screen + + if TYPE_CHECKING: + + @overload + def push_screen( + self, + screen: Screen[ScreenResultType] | str, + callback: ScreenResultCallbackType[ScreenResultType] | None = None, + wait_for_dismiss: Literal[False] = False, + *, + mode: str | None = None, + ) -> AwaitMount: ... + + @overload + def push_screen( + self, + screen: Screen[ScreenResultType] | str, + callback: ScreenResultCallbackType[ScreenResultType] | None = None, + wait_for_dismiss: Literal[True] = True, + *, + mode: str | None = None, + ) -> asyncio.Future[ScreenResultType]: ... + + def push_screen( + self, + screen: Screen[ScreenResultType] | str, + callback: ScreenResultCallbackType[ScreenResultType] | None = None, + wait_for_dismiss: bool = False, + *, + mode: str | None = None, + ) -> AwaitMount | asyncio.Future[ScreenResultType]: + """Push a new [screen](/guide/screens) on the screen stack, making it the current screen. + + Args: + screen: A Screen instance or the name of an installed screen. + callback: An optional callback function that will be called if the screen is [dismissed][textual.screen.Screen.dismiss] with a result. + wait_for_dismiss: If `True`, awaiting this method will return the dismiss value from the screen. When set to `False`, awaiting + this method will wait for the screen to be mounted. Note that `wait_for_dismiss` should only be set to `True` when running in a worker. + mode: The mode to push the screen to, or `None` to the currently active mode. If pushing to a non-current mode, the screen will not be immediately visible. + + Raises: + NoActiveWorker: If using `wait_for_dismiss` outside of a worker. + + Returns: + An optional awaitable that awaits the mounting of the screen and its children, or an asyncio Future + to await the result of the screen. + """ + if not isinstance(screen, (Screen, str)): + raise TypeError( + f"push_screen requires a Screen instance or str; not {screen!r}" + ) + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + # Mainly for testing, when push_screen isn't called in an async context + future: asyncio.Future[ScreenResultType] = asyncio.Future() + else: + future = loop.create_future() + + if mode is None: + mode = self._current_mode + + try: + screen_stack = self._screen_stacks[mode] + except KeyError: + raise UnknownModeError(f"No such mode {mode!r}") + + if screen_stack and screen_stack[-1].is_active: + self.app.capture_mouse(None) + mode_screen = screen_stack[-1] + mode_screen.post_message(events.ScreenSuspend()) + mode_screen.refresh() + next_screen, await_mount = self._get_screen(screen) + try: + message_pump = active_message_pump.get() + except LookupError: + message_pump = self.app + + next_screen._push_result_callback(message_pump, callback, future) + self._load_screen_css(next_screen) + next_screen._update_auto_focus() + screen_stack.append(next_screen) + if next_screen.is_active: + next_screen.post_message(events.ScreenResume()) + self.screen_change_signal.publish(next_screen) + if wait_for_dismiss: + try: + get_current_worker() + except NoActiveWorker: + raise NoActiveWorker( + "push_screen must be run from a worker when `wait_for_dismiss` is True" + ) from None + return future + else: + return await_mount + + if TYPE_CHECKING: + + @overload + async def push_screen_wait( + self, screen: Screen[ScreenResultType], *, mode: str | None = None + ) -> ScreenResultType: ... + + @overload + async def push_screen_wait( + self, screen: str, *, mode: str | None = None + ) -> Any: ... + + async def push_screen_wait( + self, screen: Screen[ScreenResultType] | str, *, mode: str | None = None + ) -> ScreenResultType | Any: + """Push a screen and wait for the result (received from [`Screen.dismiss`][textual.screen.Screen.dismiss]). + + Note that this method may only be called when running in a worker. + + Args: + screen: A screen or the name of an installed screen. + mode: The mode to push the screen to, or `None` to the currently active mode. If pushing to a non-current mode, the screen will not be immediately visible. + + Returns: + The screen's result. + """ + await self._flush_next_callbacks() + # The shield prevents the cancellation of the current task from canceling the push_screen awaitable + return await asyncio.shield( + self.push_screen(screen, wait_for_dismiss=True, mode=mode) + ) + + def switch_screen(self, screen: Screen | str) -> AwaitComplete: + """Switch to another [screen](/guide/screens) by replacing the top of the screen stack with a new screen. + + Args: + screen: Either a Screen object or screen name (the `name` argument when installed). + """ + if not isinstance(screen, (Screen, str)): + raise TypeError( + f"switch_screen requires a Screen instance or str; not {screen!r}" + ) + + next_screen, await_mount = self._get_screen(screen) + if screen is self.screen or next_screen is self.screen: + self.log.system(f"Screen {screen} is already current.") + return AwaitComplete.nothing() + + self.app.capture_mouse(None) + top_screen = self._screen_stack.pop() + + top_screen._pop_result_callback() + self._load_screen_css(next_screen) + self._screen_stack.append(next_screen) + self.screen.post_message(events.ScreenResume()) + self.screen._push_result_callback(self.screen, None) + self.screen_change_signal.publish(self.screen) + self.log.system(f"{self.screen} is current (SWITCHED)") + + async def do_switch() -> None: + """Task to perform switch.""" + + await await_mount() + await self._replace_screen(top_screen) + + return AwaitComplete(do_switch()).call_next(self) + + def install_screen(self, screen: Screen, name: str) -> None: + """Install a screen. + + Installing a screen prevents Textual from destroying it when it is no longer on the screen stack. + Note that you don't need to install a screen to use it. See [push_screen][textual.app.App.push_screen] + or [switch_screen][textual.app.App.switch_screen] to make a new screen current. + + Args: + screen: Screen to install. + name: Unique name to identify the screen. + + Raises: + ScreenError: If the screen can't be installed. + + Returns: + An awaitable that awaits the mounting of the screen and its children. + """ + if name in self._installed_screens: + raise ScreenError(f"Can't install screen; {name!r} is already installed") + if screen in self._installed_screens.values(): + raise ScreenError( + f"Can't install screen; {screen!r} has already been installed" + ) + self._installed_screens[name] = screen + self.log.system(f"{screen} INSTALLED name={name!r}") + + def uninstall_screen(self, screen: Screen | str) -> str | None: + """Uninstall a screen. + + If the screen was not previously installed, then this method is a null-op. + Uninstalling a screen allows Textual to delete it when it is popped or switched. + Note that uninstalling a screen is only required if you have previously installed it + with [install_screen][textual.app.App.install_screen]. + Textual will also uninstall screens automatically on exit. + + Args: + screen: The screen to uninstall or the name of an installed screen. + + Returns: + The name of the screen that was uninstalled, or None if no screen was uninstalled. + """ + if isinstance(screen, str): + if screen not in self._installed_screens: + return None + uninstall_screen = self._installed_screens[screen] + if any(uninstall_screen in stack for stack in self._screen_stacks.values()): + raise ScreenStackError("Can't uninstall screen in screen stack") + del self._installed_screens[screen] + self.log.system(f"{uninstall_screen} UNINSTALLED name={screen!r}") + return screen + else: + if any(screen in stack for stack in self._screen_stacks.values()): + raise ScreenStackError("Can't uninstall screen in screen stack") + for name, installed_screen in self._installed_screens.items(): + if installed_screen is screen: + self._installed_screens.pop(name) + self.log.system(f"{screen} UNINSTALLED name={name!r}") + return name + return None + + def pop_screen(self) -> AwaitComplete: + """Pop the current [screen](/guide/screens) from the stack, and switch to the previous screen. + + Returns: + The screen that was replaced. + """ + + screen_stack = self._screen_stack + if len(screen_stack) <= 1: + raise ScreenStackError( + "Can't pop screen; there must be at least one screen on the stack" + ) + + previous_screen = screen_stack.pop() + previous_screen._pop_result_callback() + self.screen.post_message( + events.ScreenResume(refresh_styles=previous_screen.styles.background.a < 0) + ) + self.screen_change_signal.publish(self.screen) + self.log.system(f"{self.screen} is active") + + async def do_pop() -> None: + """Task to pop the screen.""" + await self._replace_screen(previous_screen) + + return AwaitComplete(do_pop()).call_next(self) + + def _pop_to_screen(self, screen: Screen) -> None: + """Pop screens until the given screen is active. + + Args: + screen: desired active screen + + Raises: + ScreenError: If the screen doesn't exist in the stack. + """ + screens_to_pop: list[Screen] = [] + for pop_screen in reversed(self.screen_stack): + if pop_screen is not screen: + screens_to_pop.append(pop_screen) + else: + break + else: + raise ScreenError(f"Screen {screen!r} not in screen stack") + + async def pop_screens() -> None: + """Pop any screens in `screens_to_pop`.""" + with self.batch_update(): + for screen in screens_to_pop: + await screen.dismiss() + + if screens_to_pop: + self.call_later(pop_screens) + + def set_focus(self, widget: Widget | None, scroll_visible: bool = True) -> None: + """Focus (or unfocus) a widget. A focused widget will receive key events first. + + Args: + widget: Widget to focus. + scroll_visible: Scroll widget into view. + """ + self.screen.set_focus(widget, scroll_visible) + + def _set_mouse_over( + self, widget: Widget | None, hover_widget: Widget | None + ) -> None: + """Called when the mouse is over another widget. + + Args: + widget: Widget under mouse, or None for no widgets. + """ + if widget is None: + if self.mouse_over is not None: + try: + self.mouse_over.post_message(events.Leave(self.mouse_over)) + finally: + self.mouse_over = None + else: + if self.mouse_over is not widget: + try: + if self.mouse_over is not None: + self.mouse_over.post_message(events.Leave(self.mouse_over)) + if widget is not None: + widget.post_message(events.Enter(widget)) + finally: + self.mouse_over = widget + + current_hover_over = self.hover_over + if current_hover_over is not None: + current_hover_over.mouse_hover = False + + if hover_widget is not None: + hover_widget.mouse_hover = True + if hover_widget._has_hover_style: + hover_widget.update_node_styles() + if current_hover_over is not None and current_hover_over._has_hover_style: + current_hover_over.update_node_styles() + self.hover_over = hover_widget + + def _update_mouse_over(self, screen: Screen) -> None: + """Updates the mouse over after the next refresh. + + This method is called whenever a widget is added or removed, which may change + the widget under the mouse. + + """ + + if self.mouse_over is None or not screen.is_active: + return + + async def check_mouse() -> None: + """Check if the mouse over widget has changed.""" + try: + hover_widgets = screen.get_hover_widgets_at(*self.mouse_position) + except NoWidget: + pass + else: + mouse_over, hover_over = hover_widgets.widgets + if ( + mouse_over is not self.mouse_over + or hover_over is not self.hover_over + ): + self._set_mouse_over(mouse_over, hover_over) + + self.call_after_refresh(check_mouse) + + def capture_mouse(self, widget: Widget | None) -> None: + """Send all mouse events to the given widget or disable mouse capture. + + Normally mouse events are sent to the widget directly under the pointer. + Capturing the mouse allows a widget to receive mouse events even when the pointer is over another widget. + + Args: + widget: Widget to capture mouse events, or `None` to end mouse capture. + """ + if widget == self.mouse_captured: + return + if self.mouse_captured is not None: + self.mouse_captured.post_message(events.MouseRelease(self.mouse_position)) + self.mouse_captured = widget + if widget is not None: + widget.post_message(events.MouseCapture(self.mouse_position)) + self.screen.update_pointer_shape() + + def panic(self, *renderables: RenderableType) -> None: + """Exits the app and display error message(s). + + Used in response to unexpected errors. + For a more graceful exit, see the [exit][textual.app.App.exit] method. + + Args: + *renderables: Text or Rich renderable(s) to display on exit. + """ + assert all( + is_renderable(renderable) for renderable in renderables + ), "Can only call panic with strings or Rich renderables" + + def render(renderable: RenderableType) -> list[Segment]: + """Render a panic renderables.""" + segments = list(self.console.render(renderable, self.console.options)) + return segments + + pre_rendered = [Segments(render(renderable)) for renderable in renderables] + self._exit_renderables.extend(pre_rendered) + + self._close_messages_no_wait() + + def _handle_exception(self, error: Exception) -> None: + """Called with an unhandled exception. + + Always results in the app exiting. + + Args: + error: An exception instance. + """ + self._return_code = 1 + # If we're running via pilot and this is the first exception encountered, + # take note of it so that we can re-raise for test frameworks later. + if self._exception is None: + self._exception = error + self._exception_event.set() + + if hasattr(error, "__rich__"): + # Exception has a rich method, so we can defer to that for the rendering + self.panic(error) + else: + # Use default exception rendering + self._fatal_error() + + def _fatal_error(self) -> None: + """Exits the app after an unhandled exception.""" + from rich.traceback import Traceback + + self.bell() + traceback = Traceback( + show_locals=True, width=None, locals_max_length=5, suppress=[rich] + ) + self._exit_renderables.append( + Segments(self.console.render(traceback, self.console.options)) + ) + self._close_messages_no_wait() + + def _print_error_renderables(self) -> None: + """Print and clear exit renderables.""" + error_count = len(self._exit_renderables) + if "debug" in self.features: + for renderable in self._exit_renderables: + self.error_console.print(renderable) + if error_count > 1: + self.error_console.print( + f"\n[b]NOTE:[/b] {error_count} errors shown above.", markup=True + ) + elif self._exit_renderables: + self.error_console.print(self._exit_renderables[0]) + if error_count > 1: + self.error_console.print( + f"\n[b]NOTE:[/b] 1 of {error_count} errors shown. Run with [b]textual run --dev[/] to see all errors.", + markup=True, + ) + + self._exit_renderables.clear() + + def _build_driver( + self, headless: bool, inline: bool, mouse: bool, size: tuple[int, int] | None + ) -> Driver: + """Construct a driver instance. + + Args: + headless: Request headless driver. + inline: Request inline driver. + mouse: Request mouse support. + size: Initial size. + + Returns: + Driver instance. + """ + driver: Driver + driver_class: type[Driver] + if headless: + from memray._vendor.textual.drivers.headless_driver import HeadlessDriver + + driver_class = HeadlessDriver + elif inline and not WINDOWS: + from memray._vendor.textual.drivers.linux_inline_driver import LinuxInlineDriver + + driver_class = LinuxInlineDriver + else: + driver_class = self.driver_class + + driver = self._driver = driver_class( + self, + debug=constants.DEBUG, + mouse=mouse, + size=size, + ) + return driver + + async def _init_devtools(self): + """Initialize developer tools.""" + if self.devtools is not None: + from textual_dev.client import DevtoolsConnectionError + + try: + await self.devtools.connect() + self.log.system(f"Connected to devtools ( {self.devtools.url} )") + except DevtoolsConnectionError: + self.log.system(f"Couldn't connect to devtools ( {self.devtools.url} )") + + async def _process_messages( + self, + ready_callback: CallbackType | None = None, + headless: bool = False, + inline: bool = False, + inline_no_clear: bool = False, + mouse: bool = True, + terminal_size: tuple[int, int] | None = None, + message_hook: Callable[[Message], None] | None = None, + ) -> None: + self._thread_init() + + async def app_prelude() -> bool: + """Work required before running the app. + + Returns: + `True` if the app should continue, or `False` if there was a problem starting. + """ + await self._init_devtools() + self.log.system("---") + self.log.system(loop=asyncio.get_running_loop()) + self.log.system(features=self.features) + if constants.LOG_FILE is not None: + _log_path = os.path.abspath(constants.LOG_FILE) + self.log.system(f"Writing logs to {_log_path!r}") + + try: + if self.css_path: + self.stylesheet.read_all(self.css_path) + for read_from, css, tie_breaker, scope in self._get_default_css(): + self.stylesheet.add_source( + css, + read_from=read_from, + is_default_css=True, + tie_breaker=tie_breaker, + scope=scope, + ) + if self.CSS: + try: + app_path = inspect.getfile(self.__class__) + except (TypeError, OSError): + app_path = "" + read_from = (app_path, f"{self.__class__.__name__}.CSS") + self.stylesheet.add_source( + self.CSS, read_from=read_from, is_default_css=False + ) + except Exception as error: + self._handle_exception(error) + self._print_error_renderables() + return False + + if self.css_monitor: + self.set_interval(0.25, self.css_monitor, name="css monitor") + self.log.system("STARTED", self.css_monitor) + return True + + async def run_process_messages(): + """The main message loop, invoke below.""" + + async def invoke_ready_callback() -> None: + if ready_callback is not None: + ready_result = ready_callback() + if inspect.isawaitable(ready_result): + await ready_result + + with self.batch_update(): + try: + try: + await self._dispatch_message(events.Compose()) + await self._dispatch_message( + events.Resize.from_dimensions(self.size, None) + ) + default_screen = self.screen + self.stylesheet.apply(self) + await self._dispatch_message(events.Mount()) + self.check_idle() + finally: + self._mounted_event.set() + self._is_mounted = True + + Reactive._initialize_object(self) + + if self.screen is not default_screen: + self.stylesheet.apply(default_screen) + + await self.animator.start() + + except Exception: + await self.animator.stop() + raise + + finally: + self._running = True + await self._ready() + await invoke_ready_callback() + + try: + await self._process_messages_loop() + except asyncio.CancelledError: + pass + finally: + self.workers.cancel_all() + self._running = False + try: + await self.animator.stop() + finally: + await Timer._stop_all(self._timers) + + with self._context(): + if not await app_prelude(): + return + self._running = True + try: + load_event = events.Load() + await self._dispatch_message(load_event) + + driver = self._driver = self._build_driver( + headless=headless, + inline=inline, + mouse=mouse, + size=terminal_size, + ) + self.log(driver=driver) + + if not self._exit: + driver.start_application_mode() + try: + with redirect_stdout(self._capture_stdout): + with redirect_stderr(self._capture_stderr): + await run_process_messages() + + finally: + Reactive._clear_watchers(self) + if self._driver.is_inline: + cursor_x, cursor_y = self._previous_cursor_position + self._driver.write( + Control.move(-cursor_x, -cursor_y).segment.text + ) + self._driver.flush() + if inline_no_clear and not self.app._exit_renderables: + console = Console() + try: + console.print(self.screen._compositor) + except ScreenStackError: + console.print() + else: + self._driver.write( + Control.move(0, -self.INLINE_PADDING).segment.text + ) + + driver.stop_application_mode() + except Exception as error: + self._handle_exception(error) + + async def _pre_process(self) -> bool: + """Special case for the app, which doesn't need the functionality in MessagePump.""" + return True + + async def _ready(self) -> None: + """Called immediately prior to processing messages. + + May be used as a hook for any operations that should run first. + """ + + ready_time = (perf_counter() - self._start_time) * 1000 + self.log.system(f"ready in {ready_time:0.0f} milliseconds") + + async def take_screenshot() -> None: + """Take a screenshot and exit.""" + self.save_screenshot( + path=constants.SCREENSHOT_LOCATION, + filename=constants.SCREENSHOT_FILENAME, + ) + self.exit() + + if constants.SCREENSHOT_DELAY >= 0: + self.set_timer( + constants.SCREENSHOT_DELAY, take_screenshot, name="screenshot timer" + ) + + async def _on_compose(self) -> None: + _rich_traceback_omit = True + self._compose_screen = self.screen + try: + widgets = [*self.screen._nodes, *compose(self)] + except TypeError as error: + raise TypeError( + f"{self!r} compose() method returned an invalid result; {error}" + ) from error + + await self.mount_all(widgets) + + async def _check_recompose(self) -> None: + """Check if a recompose is required.""" + if self._recompose_required: + self._recompose_required = False + await self.recompose() + + async def recompose(self) -> None: + """Recompose the widget. + + Recomposing will remove children and call `self.compose` again to remount. + """ + if self._exit: + return + try: + async with self.screen.batch(): + await self.screen.query("*").exclude(".-textual-system").remove() + await self.screen.mount_all(compose(self)) + except ScreenStackError: + pass + + def _register_child( + self, parent: DOMNode, child: Widget, before: int | None, after: int | None + ) -> None: + """Register a widget as a child of another. + + Args: + parent: Parent node. + child: The child widget to register. + before: A location to mount before. + after: A location to mount after. + """ + + # Let's be 100% sure that we've not been asked to do a before and an + # after at the same time. It's possible that we can remove this + # check later on, but for the purposes of development right now, + # it's likely a good idea to keep it here to check assumptions in + # the rest of the code. + if before is not None and after is not None: + raise AppError("Only one of 'before' and 'after' may be specified.") + + # If we don't already know about this widget... + if child not in self._registry: + # Now to figure out where to place it. If we've got a `before`... + if before is not None: + # ...it's safe to NodeList._insert before that location. + parent._nodes._insert(before, child) + elif after is not None and after != -1: + # In this case we've got an after. -1 holds the special + # position (for now) of meaning "okay really what I mean is + # do an append, like if I'd asked to add with no before or + # after". So... we insert before the next item in the node + # list, if after isn't -1. + parent._nodes._insert(after + 1, child) + else: + # At this point we appear to not be adding before or after, + # or we've got a before/after value that really means + # "please append". So... + parent._nodes._append(child) + + # Now that the widget is in the NodeList of its parent, sort out + # the rest of the admin. + self._registry.add(child) + child._attach(parent) + child._post_register(self) + + def _register( + self, + parent: DOMNode, + *widgets: Widget, + before: int | None = None, + after: int | None = None, + cache: dict[tuple, RulesMap] | None = None, + ) -> list[Widget]: + """Register widget(s) so they may receive events. + + Args: + parent: Parent node. + *widgets: The widget(s) to register. + before: A location to mount before. + after: A location to mount after. + cache: Optional rules map cache. + + Returns: + List of modified widgets. + """ + + if not widgets: + return [] + + if cache is None: + cache = {} + widget_list: Iterable[Widget] + if before is not None or after is not None: + # There's a before or after, which means there's going to be an + # insertion, so make it easier to get the new things in the + # correct order. + widget_list = reversed(widgets) + else: + widget_list = widgets + + apply_stylesheet = self.stylesheet.apply + new_widgets: list[Widget] = [] + add_new_widget = new_widgets.append + for widget in widget_list: + widget._closing = False + widget._closed = False + widget._pruning = False + if not isinstance(widget, Widget): + raise AppError(f"Can't register {widget!r}; expected a Widget instance") + if widget not in self._registry: + add_new_widget(widget) + self._register_child(parent, widget, before, after) + if widget._nodes: + self._register(widget, *widget._nodes, cache=cache) + for widget in new_widgets: + apply_stylesheet(widget, cache=cache) + widget._start_messages() + + if not self._running: + # If the app is not running, prevent awaiting of the widget tasks + return [] + + return list(widgets) + + def _unregister(self, widget: Widget) -> None: + """Unregister a widget. + + Args: + widget: A Widget to unregister + """ + widget.blur() + if isinstance(widget._parent, Widget): + widget._parent._nodes._remove(widget) + widget._detach() + self._registry.discard(widget) + + async def _disconnect_devtools(self): + if self.devtools is not None: + await self.devtools.disconnect() + + def _start_widget(self, parent: Widget, widget: Widget) -> None: + """Start a widget (run its task) so that it can receive messages. + + Args: + parent: The parent of the Widget. + widget: The Widget to start. + """ + + widget._attach(parent) + widget._start_messages() + self.app._registry.add(widget) + + def is_mounted(self, widget: Widget) -> bool: + """Check if a widget is mounted. + + Args: + widget: A widget. + + Returns: + True of the widget is mounted. + """ + return widget in self._registry + + async def _close_all(self) -> None: + """Close all message pumps.""" + + # Close all screens on all stacks: + for stack in self._screen_stacks.values(): + for stack_screen in reversed(stack): + if stack_screen._running: + await self._prune(stack_screen) + stack.clear() + self._installed_screens.clear() + self._modes.clear() + + # Close any remaining nodes + # Should be empty by now + remaining_nodes = list(self._registry) + for child in remaining_nodes: + await child._close_messages() + + async def _shutdown(self) -> None: + self._begin_batch() # Prevents any layout / repaint while shutting down + driver = self._driver + self._running = False + if driver is not None: + driver.disable_input() + + await self._close_all() + await self._close_messages() + await self._dispatch_message(events.Unmount()) + + if self._driver is not None: + self._driver.close() + + self._nodes._clear() + + if self.devtools is not None and self.devtools.is_connected: + await self._disconnect_devtools() + + self._print_error_renderables() + + if constants.SHOW_RETURN: + from rich.console import Console + from rich.pretty import Pretty + + console = Console() + console.print("[b]The app returned:") + console.print(Pretty(self._return_value)) + + async def _on_exit_app(self) -> None: + self._begin_batch() # Prevent repaint / layout while shutting down + self._message_queue.put_nowait(None) + + def refresh( + self, + *, + repaint: bool = True, + layout: bool = False, + recompose: bool = False, + ) -> Self: + """Refresh the entire screen. + + Args: + repaint: Repaint the widget (will call render() again). + layout: Also layout widgets in the view. + recompose: Re-compose the widget (will remove and re-mount children). + + Returns: + The `App` instance. + """ + if recompose: + self._recompose_required = recompose + self.call_next(self._check_recompose) + return self + + if self._screen_stack: + self.screen.refresh(repaint=repaint, layout=layout) + self.check_idle() + return self + + def refresh_css(self, animate: bool = True) -> None: + """Refresh CSS. + + Args: + animate: Also execute CSS animations. + """ + stylesheet = self.app.stylesheet + stylesheet.set_variables(self.get_css_variables()) + stylesheet.reparse() + stylesheet.update(self.app, animate=animate) + try: + if self.screen.is_mounted: + self.screen._refresh_layout(self.size) + self.screen._css_update_count = self._css_update_count + except ScreenError: + pass + # The other screens in the stack will need to know about some style + # changes, as a final pass let's check in on every screen that isn't + # the current one and update them too. + for screen in self.screen_stack: + if screen != self.screen: + stylesheet.update(screen, animate=animate) + screen._css_update_count = self._css_update_count + + def _display(self, screen: Screen, renderable: RenderableType | None) -> None: + """Display a renderable within a sync. + + Args: + screen: Screen instance + renderable: A Rich renderable. + """ + + try: + if renderable is None: + return + if self._batch_count: + return + if ( + self._running + and not self._closed + and not self.is_headless + and self._driver is not None + ): + console = self.console + self._begin_update() + try: + try: + if isinstance(renderable, CompositorUpdate): + cursor_position = self.screen.outer_size.clamp_offset( + self.cursor_position + ) + if self._driver.is_inline: + terminal_sequence = Control.move( + *(-self._previous_cursor_position) + ).segment.text + terminal_sequence += renderable.render_segments(console) + terminal_sequence += Control.move( + *cursor_position + ).segment.text + else: + terminal_sequence = renderable.render_segments(console) + terminal_sequence += Control.move_to( + *cursor_position + ).segment.text + self._previous_cursor_position = cursor_position + else: + segments = console.render(renderable) + terminal_sequence = console._render_buffer(segments) + except Exception as error: + self._handle_exception(error) + else: + if WINDOWS: + # Combat a problem with Python on Windows. + # + # https://github.com/Textualize/textual/issues/2548 + # https://github.com/python/cpython/issues/82052 + CHUNK_SIZE = 8192 + write = self._driver.write + for chunk in ( + terminal_sequence[offset : offset + CHUNK_SIZE] + for offset in range( + 0, len(terminal_sequence), CHUNK_SIZE + ) + ): + write(chunk) + else: + self._driver.write(terminal_sequence) + finally: + self._end_update() + + self._driver.flush() + + finally: + self.post_display_hook() + + def post_display_hook(self) -> None: + """Called immediately after a display is done. Used in tests.""" + + def get_widget_at(self, x: int, y: int) -> tuple[Widget, Region]: + """Get the widget under the given coordinates. + + Args: + x: X coordinate. + y: Y coordinate. + + Returns: + The widget and the widget's screen region. + """ + return self.screen.get_widget_at(x, y) + + def bell(self) -> None: + """Play the console 'bell'. + + For terminals that support a bell, this typically makes a notification or error sound. + Some terminals may make no sound or display a visual bell indicator, depending on configuration. + """ + if not self.is_headless and self._driver is not None: + self._driver.write("\07") + + def _set_pointer_shape(self, shape: str) -> None: + """Generate escape sequence to set pointer (cursor) shape using Kitty protocol. + + Args: + shape: The pointer shape name (e.g., "default", "pointer", "text", "crosshair", etc.) + + Returns: + The escape sequence to set the pointer shape. + + See: https://sw.kovidgoyal.net/kitty/pointer-shapes/ + """ + # Kitty pointer shape protocol: ESC ] 22 ; ST + # where ST is ESC \ or BEL (\x07) + # Using BEL as terminator for better compatibility + if self._driver is not None: + shape_sequence = f"\x1b]22;{shape}\x07" + self._driver.write(shape_sequence) + + @property + def _binding_chain(self) -> list[tuple[DOMNode, BindingsMap]]: + """Get a chain of nodes and bindings to consider. + + If no widget is focused, returns the bindings from both the screen and the app level bindings. + Otherwise, combines all the bindings from the currently focused node up the DOM to the root App. + """ + focused = self.focused + namespace_bindings: list[tuple[DOMNode, BindingsMap]] + + if focused is None: + namespace_bindings = [ + (self.screen, self.screen._bindings), + (self, self._bindings), + ] + else: + namespace_bindings = [ + (node, node._bindings) for node in focused.ancestors_with_self + ] + + return namespace_bindings + + def simulate_key(self, key: str) -> None: + """Simulate a key press. + + This will perform the same action as if the user had pressed the key. + + Args: + key: Key to simulate. May also be the name of a key, e.g. "space". + """ + self.post_message(events.Key(key, None)) + + async def _check_bindings(self, key: str, priority: bool = False) -> bool: + """Handle a key press. + + This method is used internally by the bindings system. + + Args: + key: A key. + priority: If `True` check from `App` down, otherwise from focused up. + + Returns: + True if the key was handled by a binding, otherwise False + """ + for namespace, bindings in ( + reversed(self.screen._binding_chain) + if priority + else self.screen._modal_binding_chain + ): + key_bindings = bindings.key_to_bindings.get(key, ()) + for binding in key_bindings: + if binding.priority == priority: + if await self.run_action(binding.action, namespace): + return True + return False + + def action_help_quit(self) -> None: + """Bound to ctrl+C to alert the user that it no longer quits.""" + # Doing this because users will reflexively hit ctrl+C to exit + # Ctrl+C is now bound to copy if an input / textarea is focused. + # This makes is possible, even likely, that a user may do it accidentally -- which would be maddening. + # Rather than do nothing, we can make an educated guess the user was trying + # to quit, and inform them how you really quit. + for key, active_binding in self.active_bindings.items(): + if active_binding.binding.action in ("quit", "app.quit"): + self.notify( + f"Press [b]{key}[/b] to quit the app", title="Do you want to quit?" + ) + return + + @classmethod + def _normalize_keymap(cls, keymap: Keymap) -> Keymap: + """Normalizes the keys in a keymap, so they use long form, i.e. "question_mark" rather than "?".""" + return { + binding_id: _normalize_key_list(keys) for binding_id, keys in keymap.items() + } + + def set_keymap(self, keymap: Keymap) -> None: + """Set the keymap, a mapping of binding IDs to key strings. + + Bindings in the keymap are used to override default key bindings, + i.e. those defined in `BINDINGS` class variables. + + Bindings with IDs that are present in the keymap will have + their key string replaced with the value from the keymap. + + Args: + keymap: A mapping of binding IDs to key strings. + """ + + self._keymap = self._normalize_keymap(keymap) + self.refresh_bindings() + + def update_keymap(self, keymap: Keymap) -> None: + """Update the App's keymap, merging with `keymap`. + + If a Binding ID exists in both the App's keymap and the `keymap` + argument, the `keymap` argument takes precedence. + + Args: + keymap: A mapping of binding IDs to key strings. + """ + + self._keymap = {**self._keymap, **self._normalize_keymap(keymap)} + self.refresh_bindings() + + def handle_bindings_clash( + self, clashed_bindings: set[Binding], node: DOMNode + ) -> None: + """Handle a clash between bindings. + + Bindings clashes are likely due to users setting conflicting + keys via their keymap. + + This method is intended to be overridden by subclasses. + + Textual will call this each time a clash is encountered - + which may be on each keypress if a clashing widget is focused + or is in the bindings chain. + + Args: + clashed_bindings: The bindings that are clashing. + node: The node that has the clashing bindings. + """ + pass + + async def on_event(self, event: events.Event) -> None: + # Handle input events that haven't been forwarded + # If the event has been forwarded it may have bubbled up back to the App + if isinstance(event, events.Compose): + await self._init_mode(self._current_mode) + await super().on_event(event) + elif isinstance(event, events.InputEvent) and not event.is_forwarded: + if not self.app_focus and isinstance(event, (events.Key, events.MouseDown)): + self.app_focus = True + if isinstance(event, events.MouseEvent): + # Record current mouse position on App + self.mouse_position = Offset(event.x, event.y) + self.mouse_position_high_resolution = (event.screen_x, event.screen_y) + if isinstance(event, events.MouseDown): + try: + self._mouse_down_widget, _ = self.get_widget_at( + event.x, event.y + ) + except NoWidget: + # Shouldn't occur, since at the very least this will find the Screen + self._mouse_down_widget = None + + self.screen._forward_event(event) + + # If a MouseUp occurs at the same widget as a MouseDown, then we should + # consider it a click, and produce a Click event. + if ( + isinstance(event, events.MouseUp) + and self._mouse_down_widget is not None + ): + try: + screen_offset = event.screen_offset + mouse_down_widget = self._mouse_down_widget + mouse_up_widget, _ = self.get_widget_at(*screen_offset) + if mouse_up_widget is mouse_down_widget: + same_offset = ( + self._click_chain_last_offset is not None + and self._click_chain_last_offset == screen_offset + ) + within_time_threshold = ( + self._click_chain_last_time is not None + and event.time - self._click_chain_last_time + <= self.CLICK_CHAIN_TIME_THRESHOLD + ) + + if same_offset and within_time_threshold: + self._chained_clicks += 1 + else: + self._chained_clicks = 1 + + click_event = events.Click.from_event( + mouse_down_widget, event, chain=self._chained_clicks + ) + + self._click_chain_last_time = event.time + self._click_chain_last_offset = screen_offset + + self.screen._forward_event(click_event) + except NoWidget: + pass + + elif isinstance(event, events.Key): + # Special case for maximized widgets + # If something is maximized, then escape should minimize + if ( + self.screen.maximized is not None + and event.key == "escape" + and self.escape_to_minimize + ): + self.screen.minimize() + return + if self.focused: + try: + self.screen._clear_tooltip() + except NoScreen: + pass + if not await self._check_bindings(event.key, priority=True): + forward_target = self.focused or self.screen + forward_target._forward_event(event) + else: + self.screen._forward_event(event) + + elif isinstance(event, events.Paste) and not event.is_forwarded: + if self.focused is not None: + self.focused._forward_event(event) + else: + self.screen._forward_event(event) + else: + await super().on_event(event) + + @property + def escape_to_minimize(self) -> bool: + """Use the escape key to minimize? + + When a widget is [maximized][textual.screen.Screen.maximize], this boolean determines if the `escape` key will + minimize the widget (potentially overriding any bindings). + + The default logic is to use the screen's `ESCAPE_TO_MINIMIZE` classvar if it is set to `True` or `False`. + If the classvar on the screen is *not* set (and left as `None`), then the app's `ESCAPE_TO_MINIMIZE` is used. + + """ + return bool( + self.ESCAPE_TO_MINIMIZE + if self.screen.ESCAPE_TO_MINIMIZE is None + else self.screen.ESCAPE_TO_MINIMIZE + ) + + def _parse_action( + self, + action: str | ActionParseResult, + default_namespace: DOMNode, + namespaces: Mapping[str, DOMNode] | None = None, + ) -> tuple[DOMNode, str, tuple[object, ...]]: + """Parse an action. + + Args: + action: An action string. + default_namespace: Namespace to user when none is supplied in the action. + namespaces: Mapping of namespaces. + + Raises: + ActionError: If there are any errors parsing the action string. + + Returns: + A tuple of (node or None, action name, tuple of parameters). + """ + if isinstance(action, tuple): + destination, action_name, params = action + else: + destination, action_name, params = actions.parse(action) + + action_target: DOMNode | None = ( + None if namespaces is None else namespaces.get(destination) + ) + if destination and action_target is None: + if destination not in self._action_targets: + raise ActionError(f"Action namespace {destination} is not known") + action_target = getattr(self, destination, None) + if action_target is None: + raise ActionError(f"Action target {destination!r} not available") + return ( + (default_namespace if action_target is None else action_target), + action_name, + params, + ) + + def _check_action_state( + self, action: str, default_namespace: DOMNode + ) -> bool | None: + """Check if an action is enabled. + + Args: + action: An action string. + default_namespace: The default namespace if one is not specified in the action. + + Returns: + State of an action. + """ + action_target, action_name, parameters = self._parse_action( + action, default_namespace + ) + return action_target.check_action(action_name, parameters) + + async def run_action( + self, + action: str | ActionParseResult, + default_namespace: DOMNode | None = None, + namespaces: Mapping[str, DOMNode] | None = None, + ) -> bool: + """Perform an [action](/guide/actions). + + Actions are typically associated with key bindings, where you wouldn't need to call this method manually. + + Args: + action: Action encoded in a string. + default_namespace: Namespace to use if not provided in the action, + or None to use app. + namespaces: Mapping of namespaces. + + Returns: + True if the event has been handled. + """ + action_target, action_name, params = self._parse_action( + action, self if default_namespace is None else default_namespace, namespaces + ) + if action_target.check_action(action_name, params): + return await self._dispatch_action(action_target, action_name, params) + else: + return False + + async def _dispatch_action( + self, namespace: DOMNode, action_name: str, params: Any + ) -> bool: + """Dispatch an action to an action method. + + Args: + namespace: Namespace (object) of action. + action_name: Name of the action. + params: Action parameters. + + Returns: + True if handled, otherwise False. + """ + _rich_traceback_guard = True + + log.system( + "", + namespace=namespace, + action_name=action_name, + params=params, + ) + + try: + private_method = getattr(namespace, f"_action_{action_name}", None) + if callable(private_method): + await invoke(private_method, *params) + return True + public_method = getattr(namespace, f"action_{action_name}", None) + if callable(public_method): + await invoke(public_method, *params) + return True + log.system( + f" {action_name!r} has no target." + f" Could not find methods '_action_{action_name}' or 'action_{action_name}'" + ) + except SkipAction: + # The action method raised this to explicitly not handle the action + log.system(f" {action_name!r} skipped.") + + return False + + async def _broker_event( + self, event_name: str, event: events.Event, default_namespace: DOMNode + ) -> bool: + """Allow the app an opportunity to dispatch events to action system. + + Args: + event_name: _description_ + event: An event object. + default_namespace: The default namespace, where one isn't supplied. + + Returns: + True if an action was processed. + """ + try: + style = getattr(event, "style") + except AttributeError: + return False + try: + _modifiers, action = extract_handler_actions(event_name, style.meta) + except NoHandler: + return False + else: + event.stop() + + if isinstance(action, str): + await self.run_action(action, default_namespace) + elif isinstance(action, tuple) and len(action) == 2: + action_name, action_params = action + namespace, parsed_action, _ = actions.parse(action_name) + await self.run_action( + (namespace, parsed_action, action_params), + default_namespace, + ) + else: + if isinstance(action, tuple) and self.debug: + # It's a tuple and made it this far, which means it'll be a + # malformed action. This is a no-op, but let's log that + # anyway. + log.warning( + f"Can't parse @{event_name} action from style meta; check your console markup syntax" + ) + return False + return True + + async def _on_update(self, message: messages.Update) -> None: + message.stop() + + async def _on_layout(self, message: messages.Layout) -> None: + message.stop() + + async def _on_key(self, event: events.Key) -> None: + if not (await self._check_bindings(event.key)): + await dispatch_key(self, event) + + async def _on_resize(self, event: events.Resize) -> None: + event.stop() + self._size = event.size + self._resize_event = event + + async def _on_app_focus(self, event: events.AppFocus) -> None: + """App has focus.""" + # Required by textual-web to manage focus in a web page. + self.app_focus = True + self.screen.refresh_bindings() + + async def _on_app_blur(self, event: events.AppBlur) -> None: + """App has lost focus.""" + # Required by textual-web to manage focus in a web page. + self.app_focus = False + self.screen.refresh_bindings() + + def _prune(self, *nodes: Widget, parent: DOMNode | None = None) -> AwaitRemove: + """Prune nodes from DOM. + + Args: + parent: Parent node. + + Returns: + Optional awaitable. + """ + if not nodes: + return AwaitRemove([]) + pruning_nodes: set[Widget] = {*nodes} + for node in nodes: + node.post_message(Prune()) + pruning_nodes.update(node.walk_children(with_self=True)) + + try: + screen = nodes[0].screen + except (ScreenStackError, NoScreen): + pass + else: + if screen.focused and screen.focused in pruning_nodes: + screen._reset_focus(screen.focused, list(pruning_nodes)) + + for node in pruning_nodes: + node._pruning = True + + def post_mount() -> None: + """Called after removing children.""" + + if parent is not None: + try: + screen = parent.screen + except (ScreenStackError, NoScreen): + pass + else: + if screen._running: + self._update_mouse_over(screen) + finally: + parent.refresh(layout=True) + + await_complete = AwaitRemove( + [task for node in nodes if (task := node._task) is not None], + post_mount, + ) + self.call_next(await_complete) + return await_complete + + def _watch_app_focus(self, focus: bool) -> None: + """Respond to changes in app focus.""" + self.screen.update_node_styles() + if focus: + # If we've got a last-focused widget, if it still has a screen, + # and if the screen is still the current screen and if nothing + # is focused right now... + try: + if ( + self._last_focused_on_app_blur is not None + and self._last_focused_on_app_blur.screen is self.screen + and self.screen.focused is None + ): + # ...settle focus back on that widget. + # Don't scroll the newly focused widget, as this can be quite jarring + self.screen.set_focus( + self._last_focused_on_app_blur, + scroll_visible=False, + from_app_focus=True, + ) + except NoScreen: + pass + # Now that we have focus back on the app and we don't need the + # widget reference any more, don't keep it hanging around here. + self._last_focused_on_app_blur = None + else: + # Remember which widget has focus, when the app gets focus back + # we'll want to try and focus it again. + self._last_focused_on_app_blur = self.screen.focused + # Remove focus for now. + self.screen.set_focus(None) + + async def action_simulate_key(self, key: str) -> None: + """An [action](/guide/actions) to simulate a key press. + + This will invoke the same actions as if the user had pressed the key. + + Args: + key: The key to process. + """ + self.simulate_key(key) + + async def action_quit(self) -> None: + """An [action](/guide/actions) to quit the app as soon as possible.""" + self.exit() + + async def action_bell(self) -> None: + """An [action](/guide/actions) to play the terminal 'bell'.""" + self.bell() + + async def action_focus(self, widget_id: str) -> None: + """An [action](/guide/actions) to focus the given widget. + + Args: + widget_id: ID of widget to focus. + """ + try: + node = self.query(f"#{widget_id}").first() + except NoMatches: + pass + else: + if isinstance(node, Widget): + self.set_focus(node) + + async def action_switch_screen(self, screen: str) -> None: + """An [action](/guide/actions) to switch screens. + + Args: + screen: Name of the screen. + """ + self.switch_screen(screen) + + async def action_push_screen(self, screen: str) -> None: + """An [action](/guide/actions) to push a new screen on to the stack and make it active. + + Args: + screen: Name of the screen. + """ + self.push_screen(screen) + + async def action_pop_screen(self) -> None: + """An [action](/guide/actions) to remove the topmost screen and makes the new topmost screen active.""" + self.pop_screen() + + async def action_switch_mode(self, mode: str) -> None: + """An [action](/guide/actions) that switches to the given mode.""" + self.switch_mode(mode) + + async def action_back(self) -> None: + """An [action](/guide/actions) to go back to the previous screen (pop the current screen). + + Note: + If there is no screen to go back to, this is a non-operation (in + other words it's safe to call even if there are no other screens + on the stack.) + """ + try: + self.pop_screen() + except ScreenStackError: + pass + + async def action_add_class(self, selector: str, class_name: str) -> None: + """An [action](/guide/actions) to add a CSS class to the selected widget. + + Args: + selector: Selects the widget to add the class to. + class_name: The class to add to the selected widget. + """ + self.screen.query(selector).add_class(class_name) + + async def action_remove_class(self, selector: str, class_name: str) -> None: + """An [action](/guide/actions) to remove a CSS class from the selected widget. + + Args: + selector: Selects the widget to remove the class from. + class_name: The class to remove from the selected widget.""" + self.screen.query(selector).remove_class(class_name) + + async def action_toggle_class(self, selector: str, class_name: str) -> None: + """An [action](/guide/actions) to toggle a CSS class on the selected widget. + + Args: + selector: Selects the widget to toggle the class on. + class_name: The class to toggle on the selected widget. + """ + self.screen.query(selector).toggle_class(class_name) + + def action_toggle_dark(self) -> None: + """An [action](/guide/actions) to toggle the theme between textual-light + and textual-dark. This is offered as a convenience to simplify backwards + compatibility with previous versions of Textual which only had light mode + and dark mode.""" + self.theme = ( + "textual-dark" if self.theme == "textual-light" else "textual-light" + ) + + def action_focus_next(self) -> None: + """An [action](/guide/actions) to focus the next widget.""" + self.screen.focus_next() + + def action_focus_previous(self) -> None: + """An [action](/guide/actions) to focus the previous widget.""" + self.screen.focus_previous() + + def action_hide_help_panel(self) -> None: + """Hide the keys panel (if present).""" + self.screen.query("HelpPanel").remove() + + def action_show_help_panel(self) -> None: + """Show the keys panel.""" + from memray._vendor.textual.widgets import HelpPanel + + try: + self.screen.query_one(HelpPanel) + except NoMatches: + self.screen.mount(HelpPanel()) + + def action_notify( + self, message: str, title: str = "", severity: str = "information" + ) -> None: + """Show a notification.""" + self.notify(message, title=title, severity=severity) + + def _on_terminal_supports_synchronized_output( + self, message: messages.TerminalSupportsSynchronizedOutput + ) -> None: + log.system("SynchronizedOutput mode is supported") + if self._driver is not None and not self._driver.is_inline: + self._sync_available = True + + def _begin_update(self) -> None: + if self._sync_available and self._driver is not None: + self._driver.write(SYNC_START) + + def _end_update(self) -> None: + if self._sync_available and self._driver is not None: + self._driver.write(SYNC_END) + + def _refresh_notifications(self) -> None: + """Refresh the notifications on the current screen, if one is available.""" + # If we've got a screen to hand... + try: + screen = self.screen + except ScreenStackError: + pass + else: + try: + # ...see if it has a toast rack. + toast_rack = screen.get_child_by_type(ToastRack) + except NoMatches: + # It doesn't. That's fine. Either there won't ever be one, + # or one will turn up. Things will work out later. + return + # Update the toast rack. + self.call_later(toast_rack.show, self._notifications) + + def clear_selection(self) -> None: + """Clear text selection on the active screen.""" + try: + self.screen.clear_selection() + except NoScreen: + pass + + def notify( + self, + message: str, + *, + title: str = "", + severity: SeverityLevel = "information", + timeout: float | None = None, + markup: bool = True, + ) -> None: + """Create a notification. + + !!! tip + + This method is thread-safe. + + + Args: + message: The message for the notification. + title: The title for the notification. + severity: The severity of the notification. + timeout: The timeout (in seconds) for the notification, or `None` for default. + markup: Render the message as content markup? + + The `notify` method is used to create an application-wide + notification, shown in a [`Toast`][textual.widgets._toast.Toast], + normally originating in the bottom right corner of the display. + + Notifications can have the following severity levels: + + - `information` + - `warning` + - `error` + + The default is `information`. + + Example: + ```python + # Show an information notification. + self.notify("It's an older code, sir, but it checks out.") + + # Show a warning. Note that Textual's notification system allows + # for the use of Rich console markup. + self.notify( + "Now witness the firepower of this fully " + "[b]ARMED[/b] and [i][b]OPERATIONAL[/b][/i] battle station!", + title="Possible trap detected", + severity="warning", + ) + + # Show an error. Set a longer timeout so it's noticed. + self.notify("It's a trap!", severity="error", timeout=10) + + # Show an information notification, but without any sort of title. + self.notify("It's against my programming to impersonate a deity.", title="") + ``` + """ + if timeout is None: + timeout = self.NOTIFICATION_TIMEOUT + notification = Notification(message, title, severity, timeout, markup=markup) + self.post_message(Notify(notification)) + + def _on_notify(self, event: Notify) -> None: + """Handle notification message.""" + self._notifications.add(event.notification) + self._refresh_notifications() + + def _unnotify(self, notification: Notification, refresh: bool = True) -> None: + """Remove a notification from the notification collection. + + Args: + notification: The notification to remove. + refresh: Flag to say if the display of notifications should be refreshed. + """ + del self._notifications[notification] + if refresh: + self._refresh_notifications() + + def clear_notifications(self) -> None: + """Clear all the current notifications.""" + self._notifications.clear() + self._refresh_notifications() + + def action_command_palette(self) -> None: + """Show the Textual command palette.""" + if self.use_command_palette and not CommandPalette.is_open(self): + self.push_screen(CommandPalette(id="--command-palette")) + + def _suspend_signal(self) -> None: + """Signal that the application is being suspended.""" + self.app_suspend_signal.publish(self) + + @on(Driver.SignalResume) + def _resume_signal(self) -> None: + """Signal that the application is being resumed from a suspension.""" + self.app_resume_signal.publish(self) + + @contextmanager + def suspend(self) -> Iterator[None]: + """A context manager that temporarily suspends the app. + + While inside the `with` block, the app will stop reading input and + emitting output. Other applications will have full control of the + terminal, configured as it was before the app started running. When + the `with` block ends, the application will start reading input and + emitting output again. + + Example: + ```python + with self.suspend(): + os.system("emacs -nw") + ``` + + Raises: + SuspendNotSupported: If the environment doesn't support suspending. + + !!! note + Suspending the application is currently only supported on + Unix-like operating systems and Microsoft Windows. Suspending is + not supported in Textual Web. + """ + if self._driver is None: + return + if self._driver.can_suspend: + # Publish a suspend signal *before* we suspend application mode. + self._suspend_signal() + self._driver.suspend_application_mode() + # We're going to handle the start of the driver again so mark + # this next part as such; the reason for this is that the code + # the developer may be running could be in this process, and on + # Unix-like systems the user may `action_suspend_process` the + # app, and we don't want to have the driver auto-restart + # application mode when the application comes back to the + # foreground, in this context. + with ( + self._driver.no_automatic_restart(), + redirect_stdout(sys.__stdout__), + redirect_stderr(sys.__stderr__), + ): + yield + # We're done with the dev's code so resume application mode. + self._driver.resume_application_mode() + # ...and publish a resume signal. + self._resume_signal() + self.refresh(layout=True) + else: + raise SuspendNotSupported( + "App.suspend is not supported in this environment." + ) + + def action_suspend_process(self) -> None: + """Suspend the process into the background. + + Note: + On Unix and Unix-like systems a `SIGTSTP` is sent to the + application's process. Currently on Windows and when running + under Textual Web this is a non-operation. + """ + # Check if we're in an environment that permits this kind of + # suspend. + if not WINDOWS and self._driver is not None and self._driver.can_suspend: + # First, ensure that the suspend signal gets published while + # we're still in application mode. + self._suspend_signal() + # With that out of the way, send the SIGTSTP signal. + os.kill(os.getpid(), signal.SIGTSTP) + # NOTE: There is no call to publish the resume signal here, this + # will be handled by the driver posting a SignalResume event + # (see the event handler on App._resume_signal) above. + + def open_url(self, url: str, *, new_tab: bool = True) -> None: + """Open a URL in the default web browser. + + Args: + url: The URL to open. + new_tab: Whether to open the URL in a new tab. + """ + if self._driver is not None: + self._driver.open_url(url, new_tab) + + def deliver_text( + self, + path_or_file: str | Path | TextIO, + *, + save_directory: str | Path | None = None, + save_filename: str | None = None, + open_method: Literal["browser", "download"] = "download", + encoding: str | None = None, + mime_type: str | None = None, + name: str | None = None, + ) -> str | None: + """Deliver a text file to the end-user of the application. + + If a TextIO object is supplied, it will be closed by this method + and *must not be used* after this method is called. + + If running in a terminal, this will save the file to the user's + downloads directory. + + If running via a web browser, this will initiate a download via + a single-use URL. + + After the file has been delivered, a `DeliveryComplete` message will be posted + to this `App`, which contains the `delivery_key` returned by this method. By + handling this message, you can add custom logic to your application that fires + only after the file has been delivered. + + Args: + path_or_file: The path or file-like object to save. + save_directory: The directory to save the file to. + save_filename: The filename to save the file to. If `path_or_file` + is a file-like object, the filename will be generated from + the `name` attribute if available. If `path_or_file` is a path + the filename will be generated from the path. + encoding: The encoding to use when saving the file. If `None`, + the encoding will be determined by supplied file-like object + (if possible). If this is not possible, 'utf-8' will be used. + mime_type: The MIME type of the file or None to guess based on file extension. + If no MIME type is supplied and we cannot guess the MIME type, from the + file extension, the MIME type will be set to "text/plain". + name: A user-defined named which will be returned in [`DeliveryComplete`][textual.events.DeliveryComplete] + and [`DeliveryComplete`][textual.events.DeliveryComplete]. + + Returns: + The delivery key that uniquely identifies the file delivery. + """ + # Ensure `path_or_file` is a file-like object - convert if needed. + if isinstance(path_or_file, (str, Path)): + binary_path = Path(path_or_file) + binary = binary_path.open("rb") + file_name = save_filename or binary_path.name + else: + encoding = encoding or getattr(path_or_file, "encoding", None) or "utf-8" + binary = path_or_file + file_name = save_filename or getattr(path_or_file, "name", None) + + # If we could infer a filename, and no MIME type was supplied, guess the MIME type. + if file_name and not mime_type: + mime_type, _ = mimetypes.guess_type(file_name) + + # Still no MIME type? Default it to "text/plain". + if mime_type is None: + mime_type = "text/plain" + + return self._deliver_binary( + binary, + save_directory=save_directory, + save_filename=file_name, + open_method=open_method, + encoding=encoding, + mime_type=mime_type, + name=name, + ) + + def deliver_binary( + self, + path_or_file: str | Path | BinaryIO, + *, + save_directory: str | Path | None = None, + save_filename: str | None = None, + open_method: Literal["browser", "download"] = "download", + mime_type: str | None = None, + name: str | None = None, + ) -> str | None: + """Deliver a binary file to the end-user of the application. + + If an IO object is supplied, it will be closed by this method + and *must not be used* after it is supplied to this method. + + If running in a terminal, this will save the file to the user's + downloads directory. + + If running via a web browser, this will initiate a download via + a single-use URL. + + This operation runs in a thread when running on web, so this method + returning does not indicate that the file has been delivered. + + After the file has been delivered, a `DeliveryComplete` message will be posted + to this `App`, which contains the `delivery_key` returned by this method. By + handling this message, you can add custom logic to your application that fires + only after the file has been delivered. + + Args: + path_or_file: The path or file-like object to save. + save_directory: The directory to save the file to. If None, + the default "downloads" directory will be used. This + argument is ignored when running via the web. + save_filename: The filename to save the file to. If None, the following logic + applies to generate the filename: + - If `path_or_file` is a file-like object, the filename will be taken from + the `name` attribute if available. + - If `path_or_file` is a path, the filename will be taken from the path. + - If a filename is not available, a filename will be generated using the + App's title and the current date and time. + open_method: The method to use to open the file. "browser" will open the file in the + web browser, "download" will initiate a download. Note that this can sometimes + be impacted by the browser's settings. + mime_type: The MIME type of the file or None to guess based on file extension. + If no MIME type is supplied and we cannot guess the MIME type, from the + file extension, the MIME type will be set to "application/octet-stream". + name: A user-defined named which will be returned in [`DeliveryComplete`][textual.events.DeliveryComplete] + and [`DeliveryComplete`][textual.events.DeliveryComplete]. + + Returns: + The delivery key that uniquely identifies the file delivery. + """ + # Ensure `path_or_file` is a file-like object - convert if needed. + if isinstance(path_or_file, (str, Path)): + binary_path = Path(path_or_file) + binary = binary_path.open("rb") + file_name = save_filename or binary_path.name + else: # IO object + binary = path_or_file + file_name = save_filename or getattr(path_or_file, "name", None) + + # If we could infer a filename, and no MIME type was supplied, guess the MIME type. + if file_name and not mime_type: + mime_type, _ = mimetypes.guess_type(file_name) + + # Still no MIME type? Default it to "application/octet-stream". + if mime_type is None: + mime_type = "application/octet-stream" + + return self._deliver_binary( + binary, + save_directory=save_directory, + save_filename=file_name, + open_method=open_method, + mime_type=mime_type, + encoding=None, + name=name, + ) + + def _deliver_binary( + self, + binary: BinaryIO | TextIO, + *, + save_directory: str | Path | None, + save_filename: str | None, + open_method: Literal["browser", "download"], + encoding: str | None = None, + mime_type: str | None = None, + name: str | None = None, + ) -> str | None: + """Deliver a binary file to the end-user of the application.""" + if self._driver is None: + return None + + # Generate a filename if the file-like object doesn't have one. + if save_filename is None: + save_filename = generate_datetime_filename(self.title, "") + + # Find the appropriate save location if not specified. + save_directory = ( + user_downloads_path() if save_directory is None else Path(save_directory) + ) + + # Generate a unique key for this delivery + delivery_key = str(uuid.uuid4().hex) + + # Save the file. The driver will determine the appropriate action + # to take here. It could mean simply writing to the save_path, or + # sending the file to the web browser for download. + self._driver.deliver_binary( + binary, + delivery_key=delivery_key, + save_path=save_directory / save_filename, + encoding=encoding, + open_method=open_method, + mime_type=mime_type, + name=name, + ) + + return delivery_key + + @on(events.DeliveryComplete) + def _on_delivery_complete(self, event: events.DeliveryComplete) -> None: + """Handle a successfully delivered screenshot.""" + if event.name == "screenshot": + if event.path is None: + self.notify("Saved screenshot", title="Screenshot") + else: + self.notify( + f"Saved screenshot to [$text-success]{str(event.path)!r}", + title="Screenshot", + ) + + @on(events.DeliveryFailed) + def _on_delivery_failed(self, event: events.DeliveryComplete) -> None: + """Handle a failure to deliver the screenshot.""" + if event.name == "screenshot": + self.notify( + "Failed to save screenshot", title="Screenshot", severity="error" + ) + + @on(messages.InBandWindowResize) + def _on_in_band_window_resize(self, message: messages.InBandWindowResize) -> None: + """In band window resize enables smooth scrolling.""" + self.supports_smooth_scrolling = message.enabled + self.log.debug(message) + + def _on_idle(self) -> None: + """Send app resize events on idle, so we don't do more resizing that necessary.""" + event = self._resize_event + if event is not None: + self._resize_event = None + self.screen.post_message(event) + for screen in self._background_screens: + screen.post_message(event) diff --git a/src/memray/_vendor/textual/await_complete.py b/src/memray/_vendor/textual/await_complete.py new file mode 100644 index 0000000000..bcc3402f05 --- /dev/null +++ b/src/memray/_vendor/textual/await_complete.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from asyncio import Future, gather +from typing import TYPE_CHECKING, Any, Awaitable, Generator + +import rich.repr +from typing_extensions import Self + +from memray._vendor.textual._debug import get_caller_file_and_line +from memray._vendor.textual.message_pump import MessagePump + +if TYPE_CHECKING: + from memray._vendor.textual.types import CallbackType + + +@rich.repr.auto(angular=True) +class AwaitComplete: + """An 'optionally-awaitable' object which runs one or more coroutines (or other awaitables) concurrently.""" + + def __init__( + self, *awaitables: Awaitable, pre_await: CallbackType | None = None + ) -> None: + """Create an AwaitComplete. + + Args: + awaitables: One or more awaitables to run concurrently. + """ + self._awaitables = awaitables + self._future: Future[Any] = gather(*awaitables) + self._pre_await: CallbackType | None = pre_await + self._caller = get_caller_file_and_line() + + def __rich_repr__(self) -> rich.repr.Result: + yield self._awaitables + yield "pre_await", self._pre_await, None + yield "caller", self._caller, None + + def set_pre_await_callback(self, pre_await: CallbackType | None) -> None: + """Set a callback to run prior to awaiting. + + This is used by Textual, mainly to check for possible deadlocks. + You are unlikely to need to call this method in an app. + + Args: + pre_await: A callback. + """ + self._pre_await = pre_await + + def call_next(self, node: MessagePump) -> Self: + """Await after the next message. + + Args: + node: The node which created the object. + """ + node.call_next(self) + return self + + async def __call__(self) -> Any: + return await self + + def __await__(self) -> Generator[Any, None, Any]: + _rich_traceback_omit = True + if self._pre_await is not None: + self._pre_await() + return self._future.__await__() + + @property + def is_done(self) -> bool: + """`True` if the task has completed.""" + return self._future.done() + + @property + def exception(self) -> BaseException | None: + """An exception if the awaitables failed.""" + if self._future.done(): + return self._future.exception() + return None + + @classmethod + def nothing(cls): + """Returns an already completed instance of AwaitComplete.""" + instance = cls() + instance._future = Future() + instance._future.set_result(None) # Mark it as completed with no result + return instance diff --git a/src/memray/_vendor/textual/await_remove.py b/src/memray/_vendor/textual/await_remove.py new file mode 100644 index 0000000000..f508d553b4 --- /dev/null +++ b/src/memray/_vendor/textual/await_remove.py @@ -0,0 +1,47 @@ +""" +An *optionally* awaitable object returned by methods that remove widgets. +""" + +from __future__ import annotations + +import asyncio +from asyncio import Task, gather +from typing import Generator + +import rich.repr + +from memray._vendor.textual._callback import invoke +from memray._vendor.textual._debug import get_caller_file_and_line +from memray._vendor.textual._types import CallbackType + + +@rich.repr.auto +class AwaitRemove: + """An awaitable that waits for nodes to be removed.""" + + def __init__( + self, tasks: list[Task], post_remove: CallbackType | None = None + ) -> None: + self._tasks = tasks + self._post_remove = post_remove + self._caller = get_caller_file_and_line() + + def __rich_repr__(self) -> rich.repr.Result: + yield "tasks", self._tasks + yield "post_remove", self._post_remove + yield "caller", self._caller, None + + async def __call__(self) -> None: + await self + + def __await__(self) -> Generator[None, None, None]: + current_task = asyncio.current_task() + tasks = [task for task in self._tasks if task is not current_task] + + async def await_prune() -> None: + """Wait for the prune operation to finish.""" + await gather(*tasks) + if self._post_remove is not None: + await invoke(self._post_remove) + + return await_prune().__await__() diff --git a/src/memray/_vendor/textual/binding.py b/src/memray/_vendor/textual/binding.py new file mode 100644 index 0000000000..4b0573a6c6 --- /dev/null +++ b/src/memray/_vendor/textual/binding.py @@ -0,0 +1,400 @@ +""" + +This module contains the `Binding` class and related objects. + +See [bindings](/guide/input#bindings) in the guide for details. +""" + +from __future__ import annotations + +import dataclasses +from dataclasses import dataclass +from typing import TYPE_CHECKING, Iterable, Iterator, Mapping, NamedTuple + +import rich.repr + +from memray._vendor.textual.keys import _character_to_key + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from memray._vendor.textual.dom import DOMNode + +BindingType: TypeAlias = "Binding | tuple[str, str] | tuple[str, str, str]" +"""The possible types of a binding found in the `BINDINGS` class variable.""" + +BindingIDString: TypeAlias = str +"""The ID of a Binding defined somewhere in the application. + +Corresponds to the `id` parameter of the `Binding` class. +""" + +KeyString: TypeAlias = str +"""A string that represents a key binding. + +For example, "x", "ctrl+i", "ctrl+shift+a", "ctrl+j,space,x", etc. +""" + +Keymap = Mapping[BindingIDString, KeyString] +"""A mapping of binding IDs to key strings, used for overriding default key bindings.""" + + +class BindingError(Exception): + """A binding related error.""" + + +class NoBinding(Exception): + """A binding was not found.""" + + +class InvalidBinding(Exception): + """Binding key is in an invalid format.""" + + +@dataclass(frozen=True) +class Binding: + """The configuration of a key binding.""" + + key: str + """Key to bind. This can also be a comma-separated list of keys to map multiple keys to a single action.""" + action: str + """Action to bind to.""" + description: str = "" + """Description of action.""" + show: bool = True + """Show the action in Footer, or False to hide.""" + key_display: str | None = None + """How the key should be shown in footer. + + If `None`, the display of the key will use the result of `App.get_key_display`. + + If overridden in a keymap then this value is ignored. + """ + priority: bool = False + """Enable priority binding for this key.""" + tooltip: str = "" + """Optional tooltip to show in footer.""" + + id: str | None = None + """ID of the binding. Intended to be globally unique, but uniqueness is not enforced. + + If specified in the App's keymap then Textual will use this ID to lookup the binding, + and substitute the `key` property of the Binding with the key specified in the keymap. + """ + system: bool = False + """Make this binding a system binding, which removes it from the key panel.""" + + @dataclass(frozen=True) + class Group: + """A binding group causes the keys to be grouped under a single description.""" + + description: str = "" + """Description of the group.""" + + compact: bool = False + """Show keys in compact form (no spaces).""" + + group: Group | None = None + """Optional binding group (used to group related bindings in the footer).""" + + def parse_key(self) -> tuple[list[str], str]: + """Parse a key into a list of modifiers, and the actual key. + + Returns: + A tuple of (MODIFIER LIST, KEY). + """ + *modifiers, key = self.key.split("+") + return modifiers, key + + def with_key(self, key: str, key_display: str | None = None) -> Binding: + """Return a new binding with the key and key_display set to the specified values. + + Args: + key: The new key to set. + key_display: The new key display to set. + + Returns: + A new binding with the key set to the specified value. + """ + return dataclasses.replace(self, key=key, key_display=key_display) + + @classmethod + def make_bindings(cls, bindings: Iterable[BindingType]) -> Iterable[Binding]: + """Convert a list of BindingType (the types that can be specified in BINDINGS) + into an Iterable[Binding]. + + Compound bindings like "j,down" will be expanded into 2 Binding instances. + + Args: + bindings: An iterable of BindingType. + + Returns: + An iterable of Binding. + """ + bindings = list(bindings) + for binding in bindings: + # If it's a tuple of length 3, convert into a Binding first + if isinstance(binding, tuple): + if len(binding) not in (2, 3): + raise BindingError( + f"BINDINGS must contain a tuple of two or three strings, not {binding!r}" + ) + # `binding` is a tuple of 2 or 3 values at this point + binding = Binding(*binding) # type: ignore[reportArgumentType] + + # At this point we have a Binding instance, but the key may + # be a list of keys, so now we unroll that single Binding + # into a (potential) collection of Binding instances. + for key in binding.key.split(","): + key = key.strip() + if not key: + raise InvalidBinding( + f"Can not bind empty string in {binding.key!r}" + ) + if len(key) == 1: + key = _character_to_key(key) + + yield Binding( + key=key, + action=binding.action, + description=binding.description, + show=bool(binding.description and binding.show), + key_display=binding.key_display, + priority=binding.priority, + tooltip=binding.tooltip, + id=binding.id, + system=binding.system, + group=binding.group, + ) + + +class ActiveBinding(NamedTuple): + """Information about an active binding (returned from [active_bindings][textual.screen.Screen.active_bindings]).""" + + node: DOMNode + """The node where the binding is defined.""" + binding: Binding + """The binding information.""" + enabled: bool + """Is the binding enabled? (enabled bindings are typically rendered dim)""" + tooltip: str = "" + """Optional tooltip shown in Footer.""" + + +@rich.repr.auto +class BindingsMap: + """Manage a set of bindings.""" + + def __init__( + self, + bindings: Iterable[BindingType] | None = None, + ) -> None: + """Initialise a collection of bindings. + + Args: + bindings: An optional set of initial bindings. + + Note: + The iterable of bindings can contain either a `Binding` + instance, or a tuple of 3 values mapping to the first three + properties of a `Binding`. + """ + + self.key_to_bindings: dict[str, list[Binding]] = {} + """Mapping of key (e.g. "ctrl+a") to list of bindings for that key.""" + + for binding in Binding.make_bindings(bindings or {}): + self.key_to_bindings.setdefault(binding.key, []).append(binding) + + def _add_binding(self, binding: Binding) -> None: + """Add a new binding. + + Args: + binding: New Binding to add. + """ + self.key_to_bindings.setdefault(binding.key, []).append(binding) + + def __iter__(self) -> Iterator[tuple[str, Binding]]: + """Iterating produces a sequence of (KEY, BINDING) tuples.""" + return iter( + [ + (key, binding) + for key, bindings in self.key_to_bindings.items() + for binding in bindings + ] + ) + + @classmethod + def from_keys(cls, keys: dict[str, list[Binding]]) -> BindingsMap: + """Construct a BindingsMap from a dict of keys and bindings. + + Args: + keys: A dict that maps a key on to a list of `Binding` objects. + + Returns: + New `BindingsMap` + """ + bindings = cls() + bindings.key_to_bindings = keys + return bindings + + def copy(self) -> BindingsMap: + """Return a copy of this instance. + + Return: + New bindings object. + """ + copy = BindingsMap() + copy.key_to_bindings = self.key_to_bindings.copy() + return copy + + def __rich_repr__(self) -> rich.repr.Result: + yield self.key_to_bindings + + @classmethod + def merge(cls, bindings: Iterable[BindingsMap]) -> BindingsMap: + """Merge a bindings. + + Args: + bindings: A number of bindings. + + Returns: + New `BindingsMap`. + """ + keys: dict[str, list[Binding]] = {} + for _bindings in bindings: + for key, key_bindings in _bindings.key_to_bindings.items(): + keys.setdefault(key, []).extend(key_bindings) + return BindingsMap.from_keys(keys) + + def apply_keymap(self, keymap: Keymap) -> KeymapApplyResult: + """Replace bindings for keys that are present in `keymap`. + + Preserves existing bindings for keys that are not in `keymap`. + + Args: + keymap: A keymap to overlay. + + Returns: + KeymapApplyResult: The result of applying the keymap, including any clashed bindings. + """ + clashed_bindings: set[Binding] = set() + new_bindings: dict[str, list[Binding]] = {} + + key_to_bindings = list(self.key_to_bindings.items()) + for key, bindings in key_to_bindings: + for binding in bindings: + binding_id = binding.id + if binding_id is None: + # Bindings without an ID are irrelevant when applying a keymap + continue + + # If the keymap has an override for this binding ID + if keymap_key_string := keymap.get(binding_id): + keymap_keys = keymap_key_string.split(",") + + # Remove the old binding + for key, key_bindings in key_to_bindings: + key = key.strip() + if any(binding.id == binding_id for binding in key_bindings): + if key in self.key_to_bindings: + del self.key_to_bindings[key] + + for keymap_key in keymap_keys: + if ( + keymap_key in self.key_to_bindings + or keymap_key in new_bindings + ): + # The key is already mapped either by default or by the keymap, + # so there's a clash unless the existing binding is being rebound + # to a different key. + clashing_bindings = self.key_to_bindings.get( + keymap_key, [] + ) + new_bindings.get(keymap_key, []) + for clashed_binding in clashing_bindings: + # If the existing binding is not being rebound, it's a clash + if not ( + clashed_binding.id + and keymap.get(clashed_binding.id) + != clashed_binding.key + ): + clashed_bindings.add(clashed_binding) + + if keymap_key in self.key_to_bindings: + del self.key_to_bindings[keymap_key] + + for keymap_key in keymap_keys: + new_bindings.setdefault(keymap_key, []).append( + binding.with_key(key=keymap_key, key_display=None) + ) + + # Update the key_to_bindings with the new bindings + self.key_to_bindings.update(new_bindings) + return KeymapApplyResult(clashed_bindings) + + @property + def shown_keys(self) -> list[Binding]: + """A list of bindings for shown keys.""" + keys = [ + binding + for bindings in self.key_to_bindings.values() + for binding in bindings + if binding.show + ] + return keys + + def bind( + self, + keys: str, + action: str, + description: str = "", + show: bool = True, + key_display: str | None = None, + priority: bool = False, + ) -> None: + """Bind keys to an action. + + Args: + keys: The keys to bind. Can be a comma-separated list of keys. + action: The action to bind the keys to. + description: An optional description for the binding. + show: A flag to say if the binding should appear in the footer. + key_display: Optional string to display in the footer for the key. + priority: Is this a priority binding, checked form app down to focused widget? + """ + all_keys = [key.strip() for key in keys.split(",")] + for key in all_keys: + self.key_to_bindings.setdefault(key, []).append( + Binding( + key, + action, + description, + show=bool(description and show), + key_display=key_display, + priority=priority, + ) + ) + + def get_bindings_for_key(self, key: str) -> list[Binding]: + """Get a list of bindings for a given key. + + Args: + key: Key to look up. + + Raises: + NoBinding: If the binding does not exist. + + Returns: + A list of bindings associated with the key. + """ + try: + return self.key_to_bindings[key] + except KeyError: + raise NoBinding(f"No binding for {key}") from None + + +class KeymapApplyResult(NamedTuple): + """The result of applying a keymap.""" + + clashed_bindings: set[Binding] + """A list of bindings that were clashed and replaced by the keymap.""" diff --git a/src/memray/_vendor/textual/box_model.py b/src/memray/_vendor/textual/box_model.py new file mode 100644 index 0000000000..aa68fcdd4f --- /dev/null +++ b/src/memray/_vendor/textual/box_model.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from fractions import Fraction +from typing import NamedTuple + +from memray._vendor.textual.geometry import Spacing + + +class BoxModel(NamedTuple): + """The result of `get_box_model`.""" + + # Content + padding + border + width: Fraction + height: Fraction + margin: Spacing # Additional margin diff --git a/src/memray/_vendor/textual/cache.py b/src/memray/_vendor/textual/cache.py new file mode 100644 index 0000000000..66fcdb7eed --- /dev/null +++ b/src/memray/_vendor/textual/cache.py @@ -0,0 +1,314 @@ +""" + +Cache classes are dict-like containers used to avoid recalculating expensive operations such as rendering. + +You can also use them in your own apps for similar reasons. + +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Generic, KeysView, TypeVar, overload + +CacheKey = TypeVar("CacheKey") +CacheValue = TypeVar("CacheValue") +DefaultValue = TypeVar("DefaultValue") + +__all__ = ["LRUCache", "FIFOCache"] + + +class LRUCache(Generic[CacheKey, CacheValue]): + """ + A dictionary-like container with a maximum size. + + If an additional item is added when the LRUCache is full, the least + recently used key is discarded to make room for the new item. + + The implementation is similar to functools.lru_cache, which uses a (doubly) + linked list to keep track of the most recently used items. + + Each entry is stored as [PREV, NEXT, KEY, VALUE] where PREV is a reference + to the previous entry, and NEXT is a reference to the next value. + + Note that stdlib's @lru_cache is implemented in C and faster! It's best to use + @lru_cache where you are caching things that are fairly quick and called many times. + Use LRUCache where you want increased flexibility and you are caching slow operations + where the overhead of the cache is a small fraction of the total processing time. + """ + + __slots__ = [ + "_maxsize", + "_cache", + "_full", + "_head", + "hits", + "misses", + ] + + def __init__(self, maxsize: int) -> None: + """Initialize a LRUCache. + + Args: + maxsize: Maximum size of the cache, before old items are discarded. + """ + self._maxsize = maxsize + self._cache: Dict[CacheKey, list[object]] = {} + self._full = False + self._head: list[object] = [] + self.hits = 0 + self.misses = 0 + super().__init__() + + @property + def maxsize(self) -> int: + """int: Maximum size of cache, before new values evict old values.""" + return self._maxsize + + @maxsize.setter + def maxsize(self, maxsize: int) -> None: + self._maxsize = maxsize + + def __bool__(self) -> bool: + return bool(self._cache) + + def __len__(self) -> int: + return len(self._cache) + + def __repr__(self) -> str: + return f"" + + def grow(self, maxsize: int) -> None: + """Grow the maximum size to at least `maxsize` elements. + + Args: + maxsize: New maximum size. + """ + self.maxsize = max(self.maxsize, maxsize) + + def clear(self) -> None: + """Clear the cache.""" + self._cache.clear() + self._full = False + self._head = [] + + def keys(self) -> KeysView[CacheKey]: + """Get cache keys.""" + # Mostly for tests + return self._cache.keys() + + def set(self, key: CacheKey, value: CacheValue) -> None: + """Set a value. + + Args: + key: Key. + value: Value. + """ + if self._cache.get(key) is None: + head = self._head + if not head: + # First link references itself + self._head[:] = [head, head, key, value] + else: + # Add a new root to the beginning + self._head = [head[0], head, key, value] + # Updated references on previous root + head[0][1] = self._head # type: ignore[index] + head[0] = self._head + self._cache[key] = self._head + + if self._full or len(self._cache) > self._maxsize: + # Cache is full, we need to evict the oldest one + self._full = True + head = self._head + last = head[0] + last[0][1] = head # type: ignore[index] + head[0] = last[0] # type: ignore[index] + del self._cache[last[2]] # type: ignore[index] + + __setitem__ = set + + if TYPE_CHECKING: + + @overload + def get(self, key: CacheKey) -> CacheValue | None: ... + + @overload + def get( + self, key: CacheKey, default: DefaultValue + ) -> CacheValue | DefaultValue: ... + + def get( + self, key: CacheKey, default: DefaultValue | None = None + ) -> CacheValue | DefaultValue | None: + """Get a value from the cache, or return a default if the key is not present. + + Args: + key: Key + default: Default to return if key is not present. + + Returns: + Either the value or a default. + """ + + if (link := self._cache.get(key)) is None: + self.misses += 1 + return default + if link is not self._head: + # Remove link from list + link[0][1] = link[1] # type: ignore[index] + link[1][0] = link[0] # type: ignore[index] + head = self._head + # Move link to head of list + link[0] = head[0] + link[1] = head + self._head = head[0][1] = head[0] = link # type: ignore[index] + self.hits += 1 + return link[3] # type: ignore[return-value] + + def __getitem__(self, key: CacheKey) -> CacheValue: + link = self._cache.get(key) + if (link := self._cache.get(key)) is None: + self.misses += 1 + raise KeyError(key) + if link is not self._head: + link[0][1] = link[1] # type: ignore[index] + link[1][0] = link[0] # type: ignore[index] + head = self._head + link[0] = head[0] + link[1] = head + self._head = head[0][1] = head[0] = link # type: ignore[index] + self.hits += 1 + return link[3] # type: ignore[return-value] + + def __contains__(self, key: CacheKey) -> bool: + return key in self._cache + + def discard(self, key: CacheKey) -> None: + """Discard item in cache from key. + + Args: + key: Cache key. + """ + if key not in self._cache: + return + link = self._cache[key] + + # Remove link from list + link[0][1] = link[1] # type: ignore[index] + link[1][0] = link[0] # type: ignore[index] + # Remove link from cache + + if self._head[2] == key: + self._head = self._head[1] # type: ignore[assignment] + if self._head[2] == key: # type: ignore[index] + self._head = [] + + del self._cache[key] + self._full = False + + +class FIFOCache(Generic[CacheKey, CacheValue]): + """A simple cache that discards the first added key when full (First In First Out). + + This has a lower overhead than LRUCache, but won't manage a working set as efficiently. + It is most suitable for a cache with a relatively low maximum size that is not expected to + do many lookups. + + """ + + __slots__ = [ + "_maxsize", + "_cache", + "hits", + "misses", + ] + + def __init__(self, maxsize: int) -> None: + """Initialize a FIFOCache. + + Args: + maxsize: Maximum size of cache before discarding items. + """ + self._maxsize = maxsize + self._cache: dict[CacheKey, CacheValue] = {} + self.hits = 0 + self.misses = 0 + + def __bool__(self) -> bool: + return bool(self._cache) + + def __len__(self) -> int: + return len(self._cache) + + def __repr__(self) -> str: + return ( + f"" + ) + + def clear(self) -> None: + """Clear the cache.""" + self._cache.clear() + + def keys(self) -> KeysView[CacheKey]: + """Get cache keys.""" + # Mostly for tests + return self._cache.keys() + + def set(self, key: CacheKey, value: CacheValue) -> None: + """Set a value. + + Args: + key: Key. + value: Value. + """ + if key not in self._cache and len(self._cache) >= self._maxsize: + for first_key in self._cache: + self._cache.pop(first_key) + break + self._cache[key] = value + + __setitem__ = set + + if TYPE_CHECKING: + + @overload + def get(self, key: CacheKey) -> CacheValue | None: ... + + @overload + def get( + self, key: CacheKey, default: DefaultValue + ) -> CacheValue | DefaultValue: ... + + def get( + self, key: CacheKey, default: DefaultValue | None = None + ) -> CacheValue | DefaultValue | None: + """Get a value from the cache, or return a default if the key is not present. + + Args: + key: Key + default: Default to return if key is not present. + + Returns: + Either the value or a default. + """ + try: + result = self._cache[key] + except KeyError: + self.misses += 1 + return default + else: + self.hits += 1 + return result + + def __getitem__(self, key: CacheKey) -> CacheValue: + try: + result = self._cache[key] + except KeyError: + self.misses += 1 + raise KeyError(key) from None + else: + self.hits += 1 + return result + + def __contains__(self, key: CacheKey) -> bool: + return key in self._cache diff --git a/src/memray/_vendor/textual/canvas.py b/src/memray/_vendor/textual/canvas.py new file mode 100644 index 0000000000..96d9ea77ca --- /dev/null +++ b/src/memray/_vendor/textual/canvas.py @@ -0,0 +1,284 @@ +""" +A Canvas class used to render keylines. + +!!! note + This API is experimental, and may change in the near future. + +""" + +from __future__ import annotations + +import sys +from array import array +from collections import defaultdict +from dataclasses import dataclass +from operator import itemgetter +from typing import NamedTuple, Sequence + +from rich.segment import Segment +from rich.style import Style +from typing_extensions import Literal, TypeAlias + +from memray._vendor.textual._box_drawing import BOX_CHARACTERS, Quad, combine_quads +from memray._vendor.textual.color import Color +from memray._vendor.textual.geometry import Offset, clamp +from memray._vendor.textual.strip import Strip, StripRenderable + +CanvasLineType: TypeAlias = Literal["thin", "heavy", "double"] + + +_LINE_TYPE_INDEX = {"thin": 1, "heavy": 2, "double": 3} + + +class _Span(NamedTuple): + """Associates a sequence of character indices with a color.""" + + start: int + end: int # exclusive + color: Color + + +class Primitive: + """Base class for a canvas primitive.""" + + def render(self, canvas: Canvas) -> None: + """Render to the canvas. + + Args: + canvas: Canvas instance. + """ + raise NotImplementedError() + + +@dataclass +class HorizontalLine(Primitive): + """A horizontal line.""" + + origin: Offset + length: int + color: Color + line_type: CanvasLineType = "thin" + + def render(self, canvas: Canvas) -> None: + x, y = self.origin + if y < 0 or y > canvas.height - 1: + return + box = canvas.box + box_line = box[y] + + line_type_index = _LINE_TYPE_INDEX[self.line_type] + _combine_quads = combine_quads + + right = x + self.length - 1 + + x_range = canvas.x_range(x, x + self.length) + + if x in x_range: + box_line[x] = _combine_quads(box_line[x], (0, line_type_index, 0, 0)) + if right in x_range: + box_line[right] = _combine_quads( + box_line[right], (0, 0, 0, line_type_index) + ) + + line_quad = (0, line_type_index, 0, line_type_index) + for box_x in canvas.x_range(x + 1, x + self.length - 1): + box_line[box_x] = _combine_quads(box_line[box_x], line_quad) + + canvas.spans[y].append(_Span(x, x + self.length, self.color)) + + +@dataclass +class VerticalLine(Primitive): + """A vertical line.""" + + origin: Offset + length: int + color: Color + line_type: CanvasLineType = "thin" + + def render(self, canvas: Canvas) -> None: + x, y = self.origin + if x < 0 or x >= canvas.width: + return + line_type_index = _LINE_TYPE_INDEX[self.line_type] + box = canvas.box + _combine_quads = combine_quads + + y_range = canvas.y_range(y, y + self.length) + + if y in y_range: + box[y][x] = _combine_quads(box[y][x], (0, 0, line_type_index, 0)) + bottom = y + self.length - 1 + + if bottom in y_range: + box[bottom][x] = _combine_quads(box[bottom][x], (line_type_index, 0, 0, 0)) + line_quad = (line_type_index, 0, line_type_index, 0) + + for box_y in canvas.y_range(y + 1, y + self.length - 1): + box[box_y][x] = _combine_quads(box[box_y][x], line_quad) + + spans = canvas.spans + span = _Span(x, x + 1, self.color) + for y in y_range: + spans[y].append(span) + + +@dataclass +class Rectangle(Primitive): + """A rectangle.""" + + origin: Offset + width: int + height: int + color: Color + line_type: CanvasLineType = "thin" + + def render(self, canvas: Canvas) -> None: + origin = self.origin + width = self.width + height = self.height + color = self.color + line_type = self.line_type + HorizontalLine(origin, width, color, line_type).render(canvas) + HorizontalLine(origin + (0, height - 1), width, color, line_type).render(canvas) + VerticalLine(origin, height, color, line_type).render(canvas) + VerticalLine(origin + (width - 1, 0), height, color, line_type).render(canvas) + + +class Canvas: + """A character canvas.""" + + def __init__(self, width: int, height: int) -> None: + """ + + Args: + width: Width of the canvas (in cells). + height Height of the canvas (in cells). + """ + self._width = width + self._height = height + blank_line = " " * width + array_type_code = "w" if sys.version_info >= (3, 13) else "u" + self.lines: list[array[str]] = [ + array(array_type_code, blank_line) for _ in range(height) + ] + self.box: list[defaultdict[int, Quad]] = [ + defaultdict(lambda: (0, 0, 0, 0)) for _ in range(height) + ] + self.spans: list[list[_Span]] = [[] for _ in range(height)] + + @property + def width(self) -> int: + """The canvas width.""" + return self._width + + @property + def height(self) -> int: + """The canvas height.""" + return self._height + + def x_range(self, start: int, end: int) -> range: + """Range of x values, clipped to the canvas dimensions. + + Args: + start: Start index. + end: End index. + + Returns: + A range object. + """ + return range( + clamp(start, 0, self._width), + clamp(end, 0, self._width), + ) + + def y_range(self, start: int, end: int) -> range: + """Range of y values, clipped to the canvas dimensions. + + Args: + start: Start index. + end: End index. + + Returns: + A range object. + """ + return range( + clamp(start, 0, self._height), + clamp(end, 0, self._height), + ) + + def render( + self, primitives: Sequence[Primitive], base_style: Style + ) -> StripRenderable: + """Render the canvas. + + Args: + primitives: A sequence of primitives. + base_style: The base style of the canvas. + + Returns: + A Rich renderable for the canvas. + """ + for primitive in primitives: + primitive.render(self) + + get_box = BOX_CHARACTERS.__getitem__ + for box, line in zip(self.box, self.lines): + for offset, quad in box.items(): + line[offset] = get_box(quad) + + width = self._width + span_sort_key = itemgetter(0, 1) + strips: list[Strip] = [] + color = ( + Color.from_rich_color(base_style.bgcolor) + if base_style.bgcolor + else Color.parse("transparent") + ) + _Segment = Segment + for raw_spans, line in zip(self.spans, self.lines): + text = line.tounicode() + + if raw_spans: + segments: list[Segment] = [] + colors = [color] + [span.color for span in raw_spans] + spans = [ + (0, False, 0), + *( + (span.start, False, index) + for index, span in enumerate(raw_spans, 1) + ), + *( + (span.end, True, index) + for index, span in enumerate(raw_spans, 1) + ), + (width, True, 0), + ] + spans.sort(key=span_sort_key) + color_indices: set[int] = set() + color_remove = color_indices.discard + color_add = color_indices.add + for (offset, leaving, style_id), (next_offset, _, _) in zip( + spans, spans[1:] + ): + if leaving: + color_remove(style_id) + else: + color_add(style_id) + if next_offset > offset: + segments.append( + _Segment( + text[offset:next_offset], + base_style + + Style.from_color( + colors[ + max(color_indices) if color_indices else 0 + ].rich_color + ), + ) + ) + strips.append(Strip(segments, width)) + else: + strips.append(Strip([_Segment(text, base_style)], width)) + + return StripRenderable(strips, width) diff --git a/src/memray/_vendor/textual/case.py b/src/memray/_vendor/textual/case.py new file mode 100644 index 0000000000..e091f7aef2 --- /dev/null +++ b/src/memray/_vendor/textual/case.py @@ -0,0 +1,23 @@ +import re +from typing import Match, Pattern + + +def camel_to_snake( + name: str, _re_snake: Pattern[str] = re.compile("[a-z][A-Z]") +) -> str: + """Convert name from CamelCase to snake_case. + + Args: + name: A symbol name, such as a class name. + + Returns: + Name in snake case. + """ + + def repl(match: Match[str]) -> str: + lower: str + upper: str + lower, upper = match.group() # type: ignore + return f"{lower}_{upper.lower()}" + + return _re_snake.sub(repl, name).lower() diff --git a/src/memray/_vendor/textual/clock.py b/src/memray/_vendor/textual/clock.py new file mode 100644 index 0000000000..6df651bcc9 --- /dev/null +++ b/src/memray/_vendor/textual/clock.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from time import monotonic +from typing import Callable + +import rich.repr + + +@rich.repr.auto(angular=True) +class Clock: + """An object to get relative time. + + The `time` attribute of clock will return the time in seconds since the + Clock was created or reset. + + """ + + def __init__(self, *, get_time: Callable[[], float] = monotonic) -> None: + """Create a clock. + + Args: + get_time: A callable to get time in seconds. + start: Start the clock (time is 0 unless clock has been started). + """ + self._get_time = get_time + self._start_time = self._get_time() + + def __rich_repr__(self) -> rich.repr.Result: + yield self.time + + def clone(self) -> Clock: + """Clone the Clock with an independent time.""" + return Clock(get_time=self._get_time) + + def reset(self) -> None: + """Reset the clock.""" + self._start_time = self._get_time() + + @property + def time(self) -> float: + """Time since creation or reset.""" + return self._get_time() - self._start_time + + +class MockClock(Clock): + """A mock clock object where the time may be explicitly set.""" + + def __init__(self, time: float = 0.0) -> None: + """Construct a mock clock.""" + self._time = time + super().__init__(get_time=lambda: self._time) + + def clone(self) -> MockClock: + """Clone the mocked clock (clone will return the same time as original).""" + clock = MockClock(self._time) + clock._get_time = self._get_time + clock._time = self._time + return clock + + def reset(self) -> None: + """A null-op because it doesn't make sense to reset a mocked clock.""" + + def set_time(self, time: float) -> None: + """Set the time for the clock. + + Args: + time: Time to set. + """ + self._time = time + + @property + def time(self) -> float: + """Time since creation or reset.""" + return self._get_time() diff --git a/src/memray/_vendor/textual/color.py b/src/memray/_vendor/textual/color.py new file mode 100644 index 0000000000..fead21139c --- /dev/null +++ b/src/memray/_vendor/textual/color.py @@ -0,0 +1,841 @@ +""" +This module contains a powerful [Color][textual.color.Color] class which Textual uses to manipulate colors. + +## Named colors + +The following named colors are used by the [parse][textual.color.Color.parse] method. + + +```{.rich columns="80" title="colors"} +from memray._vendor.textual._color_constants import COLOR_NAME_TO_RGB +from memray._vendor.textual.color import Color +from rich.table import Table +from rich.text import Text +table = Table("Name", "hex", "RGB", "Color", expand=True, highlight=True) + +for name, triplet in sorted(COLOR_NAME_TO_RGB.items()): + if len(triplet) != 3: + continue + color = Color(*triplet) + r, g, b = triplet + table.add_row( + f'"{name}"', + Text(f"{color.hex}", "bold green"), + f"rgb({r}, {g}, {b})", + Text(" ", style=f"on rgb({r},{g},{b})") + ) +output = table +``` +""" + +from __future__ import annotations + +import re +from colorsys import hls_to_rgb, hsv_to_rgb, rgb_to_hls, rgb_to_hsv +from functools import lru_cache +from operator import itemgetter +from typing import Callable, NamedTuple + +import rich.repr +from rich.color import Color as RichColor +from rich.color import ColorType +from rich.color_triplet import ColorTriplet +from rich.terminal_theme import TerminalTheme +from typing_extensions import Final + +from memray._vendor.textual._color_constants import ANSI_COLORS, COLOR_NAME_TO_RGB +from memray._vendor.textual.css.scalar import percentage_string_to_float +from memray._vendor.textual.css.tokenize import CLOSE_BRACE, COMMA, DECIMAL, OPEN_BRACE, PERCENT +from memray._vendor.textual.geometry import clamp +from memray._vendor.textual.suggestions import get_suggestion + +_TRUECOLOR = ColorType.TRUECOLOR + + +class HSL(NamedTuple): + """A color in HSL (Hue, Saturation, Lightness) format.""" + + h: float + """Hue in range 0 to 1.""" + s: float + """Saturation in range 0 to 1.""" + l: float + """Lightness in range 0 to 1.""" + + @property + def css(self) -> str: + """HSL in css format.""" + h, s, l = self + + def as_str(number: float) -> str: + """Format a float.""" + return f"{number:.1f}".rstrip("0").rstrip(".") + + return f"hsl({as_str(h*360)},{as_str(s*100)}%,{as_str(l*100)}%)" + + +class HSV(NamedTuple): + """A color in HSV (Hue, Saturation, Value) format.""" + + h: float + """Hue in range 0 to 1.""" + s: float + """Saturation in range 0 to 1.""" + v: float + """Value in range 0 to 1.""" + + +class Lab(NamedTuple): + """A color in CIE-L*ab format.""" + + L: float + """Lightness in range 0 to 100.""" + a: float + """A axis in range -127 to 128.""" + b: float + """B axis in range -127 to 128.""" + + +RE_COLOR = re.compile( + rf"""^ +\#([0-9a-fA-F]{{3}})$| +\#([0-9a-fA-F]{{4}})$| +\#([0-9a-fA-F]{{6}})$| +\#([0-9a-fA-F]{{8}})$| +rgb{OPEN_BRACE}({DECIMAL}{COMMA}{DECIMAL}{COMMA}{DECIMAL}){CLOSE_BRACE}$| +rgba{OPEN_BRACE}({DECIMAL}{COMMA}{DECIMAL}{COMMA}{DECIMAL}{COMMA}{DECIMAL}){CLOSE_BRACE}$| +hsl{OPEN_BRACE}({DECIMAL}{COMMA}{PERCENT}{COMMA}{PERCENT}){CLOSE_BRACE}$| +hsla{OPEN_BRACE}({DECIMAL}{COMMA}{PERCENT}{COMMA}{PERCENT}{COMMA}{DECIMAL}){CLOSE_BRACE}$ +""", + re.VERBOSE, +) + +# Fast way to split a string of 6 characters into 3 pairs of 2 characters +_split_pairs3: Callable[[str], tuple[str, str, str]] = itemgetter( + slice(0, 2), slice(2, 4), slice(4, 6) +) +# Fast way to split a string of 8 characters into 4 pairs of 2 characters +_split_pairs4: Callable[[str], tuple[str, str, str, str]] = itemgetter( + slice(0, 2), slice(2, 4), slice(4, 6), slice(6, 8) +) + + +class ColorParseError(Exception): + """A color failed to parse. + + Args: + message: The error message + suggested_color: A close color we can suggest. + """ + + def __init__(self, message: str, suggested_color: str | None = None): + super().__init__(message) + self.suggested_color = suggested_color + + +@rich.repr.auto +class Color(NamedTuple): + """A class to represent a color. + + Colors are stored as three values representing the degree of red, green, and blue in a color, and a + fourth "alpha" value which defines where the color lies on a gradient of opaque to transparent. + + Example: + ```python + >>> from textual.color import Color + >>> color = Color.parse("red") + >>> color + Color(255, 0, 0) + >>> color.darken(0.5) + Color(98, 0, 0) + >>> color + Color.parse("green") + Color(0, 128, 0) + >>> color_with_alpha = Color(100, 50, 25, 0.5) + >>> color_with_alpha + Color(100, 50, 25, a=0.5) + >>> color + color_with_alpha + Color(177, 25, 12) + ``` + """ + + r: int + """Red component in range 0 to 255.""" + g: int + """Green component in range 0 to 255.""" + b: int + """Blue component in range 0 to 255.""" + a: float = 1.0 + """Alpha (opacity) component in range 0 to 1.""" + ansi: int | None = None + """ANSI color index. `-1` means default color. `None` if not an ANSI color.""" + auto: bool = False + """Is the color automatic? (automatic colors may be white or black, to provide maximum contrast)""" + + @classmethod + def automatic(cls, alpha_percentage: float = 100.0) -> Color: + """Create an automatic color.""" + return cls(0, 0, 0, alpha_percentage / 100.0, auto=True) + + @classmethod + @lru_cache(maxsize=1024) + def from_rich_color( + cls, rich_color: RichColor | None, theme: TerminalTheme | None = None + ) -> Color: + """Create a new color from Rich's Color class. + + Args: + rich_color: An instance of [Rich color][rich.color.Color]. + theme: Optional Rich [terminal theme][rich.terminal_theme.TerminalTheme]. + + Returns: + A new Color instance. + """ + if rich_color is None: + return TRANSPARENT + r, g, b = rich_color.get_truecolor(theme) + return cls( + r, g, b, ansi=rich_color.number if rich_color.is_system_defined else None + ) + + @classmethod + def from_hsl(cls, h: float, s: float, l: float) -> Color: + """Create a color from HSL components. + + Args: + h: Hue. + s: Saturation. + l: Lightness. + + Returns: + A new color. + """ + r, g, b = hls_to_rgb(h, l, s) + return cls(int(r * 255 + 0.5), int(g * 255 + 0.5), int(b * 255 + 0.5)) + + @classmethod + def from_hsv(cls, h: float, s: float, v: float) -> Color: + """Create a color from HSV components. + + Args: + h: Hue. + s: Saturation. + v: Value. + + Returns: + A new color. + """ + r, g, b = hsv_to_rgb(h, s, v) + return cls(int(r * 255 + 0.5), int(g * 255 + 0.5), int(b * 255 + 0.5)) + + @property + def inverse(self) -> Color: + """The inverse of this color. + + Returns: + Inverse color. + """ + r, g, b, a, _, _ = self + return Color(255 - r, 255 - g, 255 - b, a) + + @property + def is_transparent(self) -> bool: + """Is the color transparent (i.e. has 0 alpha)?""" + return self.a == 0 and self.ansi is None + + @property + def clamped(self) -> Color: + """A clamped color (this color with all values in expected range).""" + r, g, b, a, ansi, auto = self + _clamp = clamp + color = Color( + _clamp(r, 0, 255), + _clamp(g, 0, 255), + _clamp(b, 0, 255), + _clamp(a, 0.0, 1.0), + ansi, + auto, + ) + return color + + @property + @lru_cache(1024) + def rich_color(self) -> RichColor: + """This color encoded in Rich's Color class. + + Returns: + A color object as used by Rich. + """ + r, g, b, a, ansi, _ = self + if ansi is not None: + return RichColor.parse("default") if ansi < 0 else RichColor.from_ansi(ansi) + return RichColor( + f"#{r:02x}{g:02x}{b:02x}", _TRUECOLOR, None, ColorTriplet(r, g, b) + ) + + @property + def normalized(self) -> tuple[float, float, float]: + """A tuple of the color components normalized to between 0 and 1. + + Returns: + Normalized components. + """ + r, g, b, _a, _, _ = self + return (r / 255, g / 255, b / 255) + + @property + def rgb(self) -> tuple[int, int, int]: + """The red, green, and blue color components as a tuple of ints.""" + r, g, b, _, _, _ = self + return (r, g, b) + + @property + def hsl(self) -> HSL: + """This color in HSL format. + + HSL color is an alternative way of representing a color, which can be used in certain color calculations. + + Returns: + Color encoded in HSL format. + """ + r, g, b = self.normalized + h, l, s = rgb_to_hls(r, g, b) + return HSL(h, s, l) + + @property + def hsv(self) -> HSV: + """This color in HSV format. + + HSV color is an alternative way of representing a color, which can be used in certain color calculations. + + Returns: + Color encoded in HSV format. + """ + r, g, b = self.normalized + h, s, v = rgb_to_hsv(r, g, b) + return HSV(h, s, v) + + @property + def brightness(self) -> float: + """The human perceptual brightness. + + A value of 1 is returned for pure white, and 0 for pure black. + Other colors lie on a gradient between the two extremes. + """ + r, g, b = self.normalized + brightness = (299 * r + 587 * g + 114 * b) / 1000 + return brightness + + @property + def hex(self) -> str: + """The color in CSS hex form, with 6 digits for RGB, and 8 digits for RGBA. + + For example, `"#46B3DE"` for an RGB color, or `"#3342457F"` for a color with alpha. + """ + r, g, b, a, ansi, _ = self.clamped + if ansi is not None: + return "ansi_default" if ansi == -1 else f"ansi_{ANSI_COLORS[ansi]}" + return ( + f"#{r:02X}{g:02X}{b:02X}" + if a == 1 + else f"#{r:02X}{g:02X}{b:02X}{int(a*255):02X}" + ) + + @property + def hex6(self) -> str: + """The color in CSS hex form, with 6 digits for RGB. Alpha is ignored. + + For example, `"#46B3DE"`. + """ + r, g, b, _a, _, _ = self.clamped + return f"#{r:02X}{g:02X}{b:02X}" + + @property + def css(self) -> str: + """The color in CSS RGB or RGBA form. + + For example, `"rgb(10,20,30)"` for an RGB color, or `"rgb(50,70,80,0.5)"` for an RGBA color. + """ + r, g, b, a, ansi, auto = self + if auto: + alpha_percentage = clamp(a, 0.0, 1.0) * 100.0 + if alpha_percentage == 100: + return "auto" + if not alpha_percentage % 1: + return f"auto {int(alpha_percentage)}%" + return f"auto {alpha_percentage:.1f}%" + if ansi is not None: + return "ansi_default" if ansi == -1 else f"ansi_{ANSI_COLORS[ansi]}" + return f"rgb({r},{g},{b})" if a == 1 else f"rgba({r},{g},{b},{a})" + + @property + def monochrome(self) -> Color: + """A monochrome version of this color. + + Returns: + The monochrome (black and white) version of this color. + """ + r, g, b, a, _, _ = self + gray = round(r * 0.2126 + g * 0.7152 + b * 0.0722) + return Color(gray, gray, gray, a) + + def __rich_repr__(self) -> rich.repr.Result: + r, g, b, a, ansi, auto = self + yield r + yield g + yield b + yield "a", a, 1.0 + yield "ansi", ansi, None + yield "auto", auto, False + + def with_alpha(self, alpha: float) -> Color: + """Create a new color with the given alpha. + + Args: + alpha: New value for alpha. + + Returns: + A new color. + """ + r, g, b, _, _, _ = self + return Color(r, g, b, alpha) + + def multiply_alpha(self, alpha: float) -> Color: + """Create a new color, multiplying the alpha by a constant. + + Args: + alpha: A value to multiple the alpha by (expected to be in the range 0 to 1). + + Returns: + A new color. + """ + if self.ansi is not None: + return self + r, g, b, a, _ansi, auto = self + return Color(r, g, b, a * alpha, auto=auto) + + @lru_cache(maxsize=1024) + def blend( + self, destination: Color, factor: float, alpha: float | None = None + ) -> Color: + """Generate a new color between two colors. + + This method calculates a new color on a gradient. + The position on the gradient is given by `factor`, which is a float between 0 and 1, where 0 is the original color, and 1 is the `destination` color. + A value of `gradient` between the two extremes produces a color somewhere between the two end points. + + Args: + destination: Another color. + factor: A blend factor, 0 -> 1. + alpha: New alpha for result. + + Returns: + A new color. + """ + if destination.auto: + destination = self.get_contrast_text(destination.a) + if destination.ansi is not None: + return destination + if factor <= 0: + return self + elif factor >= 1: + return destination + r1, g1, b1, a1, _, _ = self + r2, g2, b2, a2, _, _ = destination + + if alpha is None: + new_alpha = a1 + (a2 - a1) * factor + else: + new_alpha = alpha + + return Color( + int(r1 + (r2 - r1) * factor), + int(g1 + (g2 - g1) * factor), + int(b1 + (b2 - b1) * factor), + new_alpha, + ) + + @lru_cache(maxsize=1024) + def tint(self, color: Color) -> Color: + """Apply a tint to a color. + + Similar to blend, but combines color and alpha. + + Args: + color: A color with alpha component. + + Returns: + New color + """ + + r1, g1, b1, a1, ansi1, _ = self + if ansi1 is not None: + return self + r2, g2, b2, a2, ansi2, _ = color + if ansi2 is not None: + return self + return Color( + int(r1 + (r2 - r1) * a2), + int(g1 + (g2 - g1) * a2), + int(b1 + (b2 - b1) * a2), + a1, + ) + + def __add__(self, other: object) -> Color: + if isinstance(other, Color): + return self.blend(other, other.a, 1.0) + elif other is None: + return self + return NotImplemented + + def __radd__(self, other: object) -> Color: + if isinstance(other, Color): + return self.blend(other, other.a, 1.0) + elif other is None: + return self + return NotImplemented + + @classmethod + @lru_cache(maxsize=1024 * 4) + def parse(cls, color_text: str | Color) -> Color: + """Parse a string containing a named color or CSS-style color. + + Colors may be parsed from the following formats: + + - Text beginning with a `#` is parsed as a hexadecimal color code, + where R, G, B, and A must be hexadecimal digits (0-9A-F): + + - `#RGB` + - `#RGBA` + - `#RRGGBB` + - `#RRGGBBAA` + + - Alternatively, RGB colors can also be specified in the format + that follows, where R, G, and B must be numbers between 0 and 255 + and A must be a value between 0 and 1: + + - `rgb(R,G,B)` + - `rgb(R,G,B,A)` + + - The HSL model can also be used, with a syntax similar to the above, + if H is a value between 0 and 360, S and L are percentages, and A + is a value between 0 and 1: + + - `hsl(H,S,L)` + - `hsla(H,S,L,A)` + + Any other formats will raise a `ColorParseError`. + + Args: + color_text: Text with a valid color format. Color objects will + be returned unmodified. + + Raises: + ColorParseError: If the color is not encoded correctly. + + Returns: + Instance encoding the color specified by the argument. + """ + if isinstance(color_text, Color): + return color_text + if color_text == "ansi_default": + return cls(0, 0, 0, ansi=-1) + if color_text.startswith("ansi_"): + try: + ansi = ANSI_COLORS.index(color_text[5:]) + except ValueError: + pass + else: + return cls(*COLOR_NAME_TO_RGB.get(color_text), ansi=ansi) + color_from_name = COLOR_NAME_TO_RGB.get(color_text) + if color_from_name is not None: + return cls(*color_from_name) + color_match = RE_COLOR.match(color_text) + if color_match is None: + error_message = f"failed to parse {color_text!r} as a color" + suggested_color = None + if not color_text.startswith(("#", "rgb", "hsl")): + # Seems like we tried to use a color name: let's try to find one that is close enough: + suggested_color = get_suggestion( + color_text, list(COLOR_NAME_TO_RGB.keys()) + ) + if suggested_color: + error_message += f"; did you mean '{suggested_color}'?" + raise ColorParseError(error_message, suggested_color) + ( + rgb_hex_triple, + rgb_hex_quad, + rgb_hex, + rgba_hex, + rgb, + rgba, + hsl, + hsla, + ) = color_match.groups() + + if rgb_hex_triple is not None: + r, g, b = rgb_hex_triple # type: ignore[misc] + color = cls(int(f"{r}{r}", 16), int(f"{g}{g}", 16), int(f"{b}{b}", 16)) + elif rgb_hex_quad is not None: + r, g, b, a = rgb_hex_quad # type: ignore[misc] + color = cls( + int(f"{r}{r}", 16), + int(f"{g}{g}", 16), + int(f"{b}{b}", 16), + int(f"{a}{a}", 16) / 255.0, + ) + elif rgb_hex is not None: + r, g, b = [int(pair, 16) for pair in _split_pairs3(rgb_hex)] + color = cls(r, g, b, 1.0) + elif rgba_hex is not None: + r, g, b, a = [int(pair, 16) for pair in _split_pairs4(rgba_hex)] + color = cls(r, g, b, a / 255.0) + elif rgb is not None: + r, g, b = [clamp(int(float(value)), 0, 255) for value in rgb.split(",")] + color = cls(r, g, b, 1.0) + elif rgba is not None: + float_r, float_g, float_b, float_a = [ + float(value) for value in rgba.split(",") + ] + color = cls( + clamp(int(float_r), 0, 255), + clamp(int(float_g), 0, 255), + clamp(int(float_b), 0, 255), + clamp(float_a, 0.0, 1.0), + ) + elif hsl is not None: + h, s, l = hsl.split(",") + h = float(h) % 360 / 360 + s = percentage_string_to_float(s) + l = percentage_string_to_float(l) + color = Color.from_hsl(h, s, l) + elif hsla is not None: + h, s, l, a = hsla.split(",") + h = float(h) % 360 / 360 + s = percentage_string_to_float(s) + l = percentage_string_to_float(l) + a = clamp(float(a), 0.0, 1.0) + color = Color.from_hsl(h, s, l).with_alpha(a) + else: # pragma: no-cover + raise AssertionError( # pragma: no-cover + "Can't get here if RE_COLOR matches" + ) + return color + + @lru_cache(maxsize=1024) + def darken(self, amount: float, alpha: float | None = None) -> Color: + """Darken the color by a given amount. + + Args: + amount: Value between 0-1 to reduce luminance by. + alpha: Alpha component for new color or None to copy alpha. + + Returns: + New color. + """ + l, a, b = rgb_to_lab(self) + l -= amount * 100 + return lab_to_rgb(Lab(l, a, b), self.a if alpha is None else alpha).clamped + + def lighten(self, amount: float, alpha: float | None = None) -> Color: + """Lighten the color by a given amount. + + Args: + amount: Value between 0-1 to increase luminance by. + alpha: Alpha component for new color or None to copy alpha. + + Returns: + New color. + """ + return self.darken(-amount, alpha) + + @lru_cache(maxsize=1024) + def get_contrast_text(self, alpha: float = 0.95) -> Color: + """Get a light or dark color that best contrasts this color, for use with text. + + Args: + alpha: An alpha value to apply to the result. + + Returns: + A new color, either an off-white or off-black. + """ + return (WHITE if self.brightness < 0.5 else BLACK).with_alpha(alpha) + + +class Gradient: + """Defines a color gradient.""" + + def __init__(self, *stops: tuple[float, Color | str], quality: int = 50) -> None: + """Create a color gradient that blends colors to form a spectrum. + + A gradient is defined by a sequence of "stops" consisting of a tuple containing a float and a color. + The stop indicates the color at that point on a spectrum between 0 and 1. + Colors may be given as a [Color][textual.color.Color] instance, or a string that + can be parsed into a Color (with [Color.parse][textual.color.Color.parse]). + + The `quality` argument defines the number of _steps_ in the gradient. Intermediate colors are + interpolated from the two nearest colors. Increasing `quality` can generate a smoother looking gradient, + at the expense of a little extra work to pre-calculate the colors. + + Args: + stops: Color stops. + quality: The number of steps in the gradient. + + Raises: + ValueError: If any stops are missing (must be at least a stop for 0 and 1). + """ + parse = Color.parse + self._stops = sorted( + [ + ( + (position, parse(color)) + if isinstance(color, str) + else (position, color) + ) + for position, color in stops + ] + ) + if len(stops) < 2: + raise ValueError("At least 2 stops required.") + if self._stops[0][0] != 0.0: + raise ValueError("First stop must be 0.") + if self._stops[-1][0] != 1.0: + raise ValueError("Last stop must be 1.") + self._quality = quality + self._colors: list[Color] | None = None + self._rich_colors: list[RichColor] | None = None + + @classmethod + def from_colors(cls, *colors: Color | str, quality: int = 50) -> Gradient: + """Construct a gradient form a sequence of colors, where the stops are evenly spaced. + + Args: + *colors: Positional arguments may be Color instances or strings to parse into a color. + quality: The number of steps in the gradient. + + Returns: + A new Gradient instance. + """ + if len(colors) < 2: + raise ValueError("Two or more colors required.") + stops = [(i / (len(colors) - 1), Color.parse(c)) for i, c in enumerate(colors)] + return cls(*stops, quality=quality) + + @property + def colors(self) -> list[Color]: + """A list of colors in the gradient.""" + position = 0 + quality = self._quality + + if self._colors is None: + colors: list[Color] = [] + add_color = colors.append + (stop1, color1), (stop2, color2) = self._stops[0:2] + for step_position in range(quality): + step = step_position / (quality - 1) + while step > stop2: + position += 1 + (stop1, color1), (stop2, color2) = self._stops[ + position : position + 2 + ] + add_color(color1.blend(color2, (step - stop1) / (stop2 - stop1))) + self._colors = colors + assert len(self._colors) == self._quality + return self._colors + + def get_color(self, position: float) -> Color: + """Get a color from the gradient at a position between 0 and 1. + + Positions that are between stops will return a blended color. + + Args: + position: A number between 0 and 1, where 0 is the first stop, and 1 is the last. + + Returns: + A Textual color. + """ + + if position <= 0: + return self.colors[0] + if position >= 1: + return self.colors[-1] + + color_position = position * (self._quality - 1) + color_index = int(color_position) + color1, color2 = self.colors[color_index : color_index + 2] + return color1.blend(color2, color_position % 1) + + def get_rich_color(self, position: float) -> RichColor: + """Get a (Rich) color from the gradient at a position between 0 and 1. + + Positions that are between stops will return a blended color. + + Args: + position: A number between 0 and 1, where 0 is the first stop, and 1 is the last. + + Returns: + A (Rich) color. + """ + return self.get_color(position).rich_color + + +# Color constants +WHITE: Final = Color(255, 255, 255) +"""A constant for pure white.""" +BLACK: Final = Color(0, 0, 0) +"""A constant for pure black.""" +TRANSPARENT: Final = Color.parse("transparent") +"""A constant for transparent.""" + + +def rgb_to_lab(rgb: Color) -> Lab: + """Convert an RGB color to the CIE-L*ab format. + + Uses the standard RGB color space with a D65/2⁰ standard illuminant. + Conversion passes through the XYZ color space. + Cf. http://www.easyrgb.com/en/math.php. + """ + + r, g, b = rgb.r / 255, rgb.g / 255, rgb.b / 255 + + r = pow((r + 0.055) / 1.055, 2.4) if r > 0.04045 else r / 12.92 + g = pow((g + 0.055) / 1.055, 2.4) if g > 0.04045 else g / 12.92 + b = pow((b + 0.055) / 1.055, 2.4) if b > 0.04045 else b / 12.92 + + x = (r * 41.24 + g * 35.76 + b * 18.05) / 95.047 + y = (r * 21.26 + g * 71.52 + b * 7.22) / 100 + z = (r * 1.93 + g * 11.92 + b * 95.05) / 108.883 + + off = 16 / 116 + x = pow(x, 1 / 3) if x > 0.008856 else 7.787 * x + off + y = pow(y, 1 / 3) if y > 0.008856 else 7.787 * y + off + z = pow(z, 1 / 3) if z > 0.008856 else 7.787 * z + off + + return Lab(116 * y - 16, 500 * (x - y), 200 * (y - z)) + + +def lab_to_rgb(lab: Lab, alpha: float = 1.0) -> Color: + """Convert a CIE-L*ab color to RGB. + + Uses the standard RGB color space with a D65/2⁰ standard illuminant. + Conversion passes through the XYZ color space. + Cf. http://www.easyrgb.com/en/math.php. + """ + + y = (lab.L + 16) / 116 + x = lab.a / 500 + y + z = y - lab.b / 200 + + off = 16 / 116 + y = pow(y, 3) if y > 0.2068930344 else (y - off) / 7.787 + x = 0.95047 * pow(x, 3) if x > 0.2068930344 else 0.122059 * (x - off) + z = 1.08883 * pow(z, 3) if z > 0.2068930344 else 0.139827 * (z - off) + + r = x * 3.2406 + y * -1.5372 + z * -0.4986 + g = x * -0.9689 + y * 1.8758 + z * 0.0415 + b = x * 0.0557 + y * -0.2040 + z * 1.0570 + + r = 1.055 * pow(r, 1 / 2.4) - 0.055 if r > 0.0031308 else 12.92 * r + g = 1.055 * pow(g, 1 / 2.4) - 0.055 if g > 0.0031308 else 12.92 * g + b = 1.055 * pow(b, 1 / 2.4) - 0.055 if b > 0.0031308 else 12.92 * b + + return Color(int(r * 255), int(g * 255), int(b * 255), alpha) diff --git a/src/memray/_vendor/textual/command.py b/src/memray/_vendor/textual/command.py new file mode 100644 index 0000000000..e55edc8fd4 --- /dev/null +++ b/src/memray/_vendor/textual/command.py @@ -0,0 +1,1276 @@ +""" +This module contains classes for working with Textual's command palette. + +See the guide on the [Command Palette](../guide/command_palette.md) for full details. + +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from asyncio import ( + CancelledError, + Queue, + Task, + TimeoutError, + create_task, + wait, + wait_for, +) +from dataclasses import dataclass +from functools import total_ordering +from inspect import isclass +from operator import attrgetter +from time import monotonic +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + AsyncIterator, + Callable, + ClassVar, + Iterable, + NamedTuple, +) + +import rich.repr +from rich.align import Align +from rich.text import Text +from typing_extensions import Final, TypeAlias + +from memray._vendor.textual import on, work +from memray._vendor.textual.binding import Binding, BindingType +from memray._vendor.textual.containers import Horizontal, Vertical +from memray._vendor.textual.content import Content +from memray._vendor.textual.events import Click, Mount +from memray._vendor.textual.fuzzy import Matcher +from memray._vendor.textual.message import Message +from memray._vendor.textual.reactive import var +from memray._vendor.textual.screen import Screen, SystemModalScreen +from memray._vendor.textual.style import Style +from memray._vendor.textual.timer import Timer +from memray._vendor.textual.types import IgnoreReturnCallbackType +from memray._vendor.textual.visual import VisualType +from memray._vendor.textual.widget import Widget +from memray._vendor.textual.widgets import Button, Input, LoadingIndicator, OptionList, Static +from memray._vendor.textual.widgets.option_list import Option +from memray._vendor.textual.worker import get_current_worker + +if TYPE_CHECKING: + from memray._vendor.textual.app import App, ComposeResult + +__all__ = [ + "CommandPalette", + "DiscoveryHit", + "Hit", + "Hits", + "Matcher", + "Provider", +] + + +@dataclass +class Hit: + """Holds the details of a single command search hit.""" + + score: float + """The score of the command hit. + + The value should be between 0 (no match) and 1 (complete match). + """ + + match_display: VisualType + """A string or Rich renderable representation of the hit.""" + + command: IgnoreReturnCallbackType + """The function to call when the command is chosen.""" + + text: str | None = None + """The command text associated with the hit, as plain text. + + If `match_display` is not simple text, this attribute should be provided by the + [Provider][textual.command.Provider] object. + """ + + help: str | None = None + """Optional help text for the command.""" + + @property + def prompt(self) -> VisualType: + """The prompt to use when displaying the hit in the command palette.""" + return self.match_display + + def __lt__(self, other: object) -> bool: + if isinstance(other, Hit): + return self.score < other.score + return NotImplemented + + def __eq__(self, other: object) -> bool: + if isinstance(other, Hit): + return self.score == other.score + return NotImplemented + + def __post_init__(self) -> None: + """Ensure 'text' is populated.""" + if self.text is None: + self.text = str(self.match_display) + + +@dataclass +class DiscoveryHit: + """Holds the details of a single command search hit.""" + + display: VisualType + """A string or Rich renderable representation of the hit.""" + + command: IgnoreReturnCallbackType + """The function to call when the command is chosen.""" + + text: str | None = None + """The command text associated with the hit, as plain text. + + If `display` is not simple text, this attribute should be provided by + the [Provider][textual.command.Provider] object. + """ + + help: str | None = None + """Optional help text for the command.""" + + @property + def prompt(self) -> VisualType: + """The prompt to use when displaying the discovery hit in the command palette.""" + return self.display + + @property + def score(self) -> float: + """A discovery hit always has a score of 0. + + The order in which discovery hits are displayed is determined by the order + in which they are yielded by the Provider. It's up to the developer to yield + DiscoveryHits in the . + """ + return 0.0 + + def __lt__(self, other: object) -> bool: + if isinstance(other, DiscoveryHit): + assert self.text is not None + assert other.text is not None + return other.text < self.text + return NotImplemented + + def __eq__(self, other: object) -> bool: + if isinstance(other, Hit): + return self.text == other.text + return NotImplemented + + def __post_init__(self) -> None: + """Ensure 'text' is populated.""" + if self.text is None: + self.text = str(self.display) + + +Hits: TypeAlias = AsyncIterator["DiscoveryHit | Hit"] +"""Return type for the command provider's `search` method.""" + +ProviderSource: TypeAlias = "Iterable[type[Provider] | Callable[[], type[Provider]]]" +"""The type used to declare the providers for a CommandPalette.""" + + +class Provider(ABC): + """Base class for command palette command providers. + + To create new command provider, inherit from this class and implement + [`search`][textual.command.Provider.search]. + """ + + def __init__(self, screen: Screen[Any], match_style: Style | None = None) -> None: + """Initialise the command provider. + + Args: + screen: A reference to the active screen. + """ + if match_style is not None: + assert isinstance( + match_style, Style + ), "match_style must be a Visual style (from textual.style import Style)" + self.__screen = screen + self.__match_style = match_style + self._init_task: Task | None = None + self._init_success = False + + @property + def focused(self) -> Widget | None: + """The currently-focused widget in the currently-active screen in the application. + + If no widget has focus this will be `None`. + """ + return self.__screen.focused + + @property + def screen(self) -> Screen[object]: + """The currently-active screen in the application.""" + return self.__screen + + @property + def app(self) -> App[object]: + """A reference to the application.""" + return self.__screen.app + + @property + def match_style(self) -> Style | None: + """The preferred style to use when highlighting matching portions of the [`match_display`][textual.command.Hit.match_display].""" + return self.__match_style + + def matcher(self, user_input: str, case_sensitive: bool = False) -> Matcher: + """Create a [fuzzy matcher][textual.fuzzy.Matcher] for the given user input. + + Args: + user_input: The text that the user has input. + case_sensitive: Should matching be case sensitive? + + Returns: + A [fuzzy matcher][textual.fuzzy.Matcher] object for matching against candidate hits. + """ + return Matcher( + user_input, + match_style=self.match_style, + case_sensitive=case_sensitive, + ) + + def _post_init(self) -> None: + """Internal method to run post init task.""" + + async def post_init_task() -> None: + """Wrapper to post init that runs in a task.""" + try: + await self.startup() + except Exception: + from rich.traceback import Traceback + + self.app.log.error(Traceback()) + else: + self._init_success = True + + self._init_task = create_task(post_init_task()) + + async def _wait_init(self) -> None: + """Wait for initialization.""" + if self._init_task is not None: + await self._init_task + self._init_task = None + + async def startup(self) -> None: + """Called after the Provider is initialized, but before any calls to `search`.""" + + async def _search(self, query: str) -> Hits: + """Internal method to perform search. + + Args: + query: The user input to be matched. + + Yields: + Instances of [`Hit`][textual.command.Hit]. + """ + await self._wait_init() + if self._init_success: + # An empty search string is a discovery search, anything else is + # a conventional search. + hits = self.search(query) if query else self.discover() + async for hit in hits: + if hit is not NotImplemented: + yield hit + + @abstractmethod + async def search(self, query: str) -> Hits: + """A request to search for commands relevant to the given query. + + Args: + query: The user input to be matched. + + Yields: + Instances of [`Hit`][textual.command.Hit]. + """ + yield NotImplemented + + async def discover(self) -> Hits: + """A default collection of hits for the provider. + + Yields: + Instances of [`DiscoveryHit`][textual.command.DiscoveryHit]. + + Note: + This is different from + [`search`][textual.command.Provider.search] in that it should + yield [`DiscoveryHit`s][textual.command.DiscoveryHit] that + should be shown by default (before user input). + + It is permitted to *not* implement this method. + """ + yield NotImplemented + + async def _shutdown(self) -> None: + """Internal method to call shutdown and log errors.""" + try: + await self.shutdown() + except Exception: + from rich.traceback import Traceback + + self.app.log.error(Traceback()) + + async def shutdown(self) -> None: + """Called when the Provider is shutdown. + + Use this method to perform an cleanup, if required. + + """ + + +class SimpleCommand(NamedTuple): + """A simple command.""" + + name: str + """The name of the command.""" + callback: IgnoreReturnCallbackType + """The callback to invoke when the command is selected.""" + help_text: str | None = None + """The description of the command.""" + + +CommandListItem: TypeAlias = ( + "SimpleCommand | tuple[str, IgnoreReturnCallbackType, str | None] | tuple[str, IgnoreReturnCallbackType]" +) + + +class SimpleProvider(Provider): + """A simple provider which the caller can pass commands to.""" + + def __init__( + self, + screen: Screen[Any], + commands: list[CommandListItem], + ) -> None: + # Convert all commands to SimpleCommand instances + super().__init__(screen, None) + self._commands: list[SimpleCommand] = [] + for command in commands: + if isinstance(command, SimpleCommand): + self._commands.append(command) + elif len(command) == 2: + self._commands.append(SimpleCommand(*command, None)) + elif len(command) == 3: + self._commands.append(SimpleCommand(*command)) + else: + raise ValueError(f"Invalid command: {command}") + + def __call__( + self, screen: Screen[Any], match_style: Style | None = None + ) -> SimpleProvider: + self.__match_style = match_style + return self + + @property + def match_style(self) -> Style | None: + return self.__match_style + + async def search(self, query: str) -> Hits: + matcher = self.matcher(query) + for name, callback, help_text in self._commands: + if (match := matcher.match(name)) > 0: + yield Hit( + match, + matcher.highlight(name), + callback, + help=help_text, + ) + + async def discover(self) -> Hits: + """Handle a request for the discovery commands for this provider. + + Yields: + Commands that can be discovered. + """ + for name, callback, help_text in self._commands: + yield DiscoveryHit( + name, + callback, + help=help_text, + ) + + +@rich.repr.auto +@total_ordering +class Command(Option): + """Class that holds a hit in the [`CommandList`][textual.command.CommandList].""" + + def __init__( + self, + prompt: VisualType, + hit: DiscoveryHit | Hit, + id: str | None = None, + disabled: bool = False, + ) -> None: + """Initialise the option. + + Args: + prompt: The prompt for the option. + hit: The details of the hit associated with the option. + id: The optional ID for the option. + disabled: The initial enabled/disabled state. Enabled by default. + """ + super().__init__(prompt, id, disabled) + self.hit = hit + """The details of the hit associated with the option.""" + + def __hash__(self) -> int: + return id(self) + + def __lt__(self, other: object) -> bool: + if isinstance(other, Command): + return self.hit < other.hit + return NotImplemented + + def __eq__(self, other: object) -> bool: + if isinstance(other, Command): + return self.hit == other.hit + return NotImplemented + + +class CommandList(OptionList, can_focus=False): + """The command palette command list.""" + + DEFAULT_CSS = """ + CommandList { + visibility: hidden; + border-top: blank; + border-bottom: hkey black; + border-left: none; + border-right: none; + height: auto; + max-height: 70vh; + background: transparent; + padding: 0; + } + + CommandList:focus { + border: blank; + } + + CommandList.--visible { + visibility: visible; + } + + CommandList.--populating { + border-bottom: none; + } + + CommandList > .option-list--option-highlighted { + color: $block-cursor-blurred-foreground; + background: $block-cursor-blurred-background; + text-style: $block-cursor-blurred-text-style; + } + + CommandList:nocolor > .option-list--option-highlighted { + text-style: reverse; + } + + CommandList > .option-list--option { + padding: 0 2; + color: $foreground; + text-style: bold; + } + """ + + +class SearchIcon(Static, inherit_css=False): + """Widget for displaying a search icon before the command input.""" + + DEFAULT_CSS = """ + SearchIcon { + color: #000; /* required for snapshot tests */ + margin-left: 1; + margin-top: 1; + width: 2; + } + """ + + icon: var[str] = var("🔎") + """The icon to display.""" + + def render(self) -> VisualType: + """Render the icon. + + Returns: + The icon renderable. + """ + return self.icon + + +class CommandInput(Input): + """The command palette input control.""" + + DEFAULT_CSS = """ + CommandInput, CommandInput:focus { + border: blank; + width: 1fr; + padding-left: 0; + background: transparent; + background-tint: 0%; + } + """ + + +class CommandPalette(SystemModalScreen[None]): + """The Textual command palette.""" + + AUTO_FOCUS = "CommandInput" + + COMPONENT_CLASSES: ClassVar[set[str]] = Screen.COMPONENT_CLASSES | { + "command-palette--help-text", + "command-palette--highlight", + } + """ + | Class | Description | + | :- | :- | + | `command-palette--help-text` | Targets the help text of a matched command. | + | `command-palette--highlight` | Targets the highlights of a matched command. | + """ + + DEFAULT_CSS = """ + + CommandPalette:inline { + /* If the command palette is invoked in inline mode, we may need additional lines. */ + min-height: 20; + } + CommandPalette { + color: $foreground; + background: $background 60%; + align-horizontal: center; + + #--container { + display: none; + } + + &:ansi { + background: transparent; + } + } + + CommandPalette.-ready { + #--container { + display: block; + } + } + + CommandPalette > .command-palette--help-text { + color: $text-muted; + background: transparent; + text-style: not bold; + } + + CommandPalette > .command-palette--highlight { + text-style: bold underline; + } + + CommandPalette:nocolor > .command-palette--highlight { + text-style: underline; + } + + CommandPalette > Vertical { + margin-top: 3; + height: 100%; + visibility: hidden; + background: $surface; + &:dark { background: $panel-darken-1; } + } + + CommandPalette #--input { + height: auto; + visibility: visible; + border: hkey black 50%; + } + + CommandPalette #--input.--list-visible { + border-bottom: none; + } + + CommandPalette #--input Label { + margin-top: 1; + margin-left: 1; + } + + CommandPalette #--input Button { + min-width: 7; + margin-right: 1; + } + + CommandPalette #--results { + overlay: screen; + height: auto; + } + + CommandPalette LoadingIndicator { + height: auto; + visibility: hidden; + border-bottom: hkey $border; + } + + CommandPalette LoadingIndicator.--visible { + visibility: visible; + } + """ + + BINDINGS: ClassVar[list[BindingType]] = [ + Binding( + "ctrl+end, shift+end", + "command_list('last')", + "Go to bottom", + show=False, + ), + Binding( + "ctrl+home, shift+home", + "command_list('first')", + "Go to top", + show=False, + ), + Binding("down", "cursor_down", "Next command", show=False), + Binding("escape", "escape", "Exit the command palette"), + Binding("pagedown", "command_list('page_down')", "Next page", show=False), + Binding("pageup", "command_list('page_up')", "Previous page", show=False), + Binding("up", "command_list('cursor_up')", "Previous command", show=False), + ] + """ + | Key(s) | Description | + | :- | :- | + | ctrl+end, shift+end | Jump to the last available commands. | + | ctrl+home, shift+home | Jump to the first available commands. | + | down | Navigate down through the available commands. | + | escape | Exit the command palette. | + | pagedown | Navigate down a page through the available commands. | + | pageup | Navigate up a page through the available commands. | + | up | Navigate up through the available commands. | + """ + + run_on_select: ClassVar[bool] = True + """A flag to say if a command should be run when selected by the user. + + If `True` then when a user hits `Enter` on a command match in the result + list, or if they click on one with the mouse, the command will be + selected and run. If set to `False` the input will be filled with the + command and then `Enter` should be pressed on the keyboard or the 'go' + button should be pressed. + """ + + _list_visible: var[bool] = var(False, init=False) + """Internal reactive to toggle the visibility of the command list.""" + + _show_busy: var[bool] = var(False, init=False) + """Internal reactive to toggle the visibility of the busy indicator.""" + + _calling_screen: var[Screen[Any] | None] = var(None) + """A record of the screen that was active when we were called.""" + + @dataclass + class OptionHighlighted(Message): + """Posted to App when an option is highlighted in the command palette.""" + + highlighted_event: OptionList.OptionHighlighted + """The option highlighted event from the OptionList within the command palette.""" + + @dataclass + class Opened(Message): + """Posted to App when the command palette is opened.""" + + @dataclass + class Closed(Message): + """Posted to App when the command palette is closed.""" + + option_selected: bool + """True if an option was selected, False if the palette was closed without selecting an option.""" + + def __init__( + self, + providers: ProviderSource | None = None, + *, + placeholder: str = "Search for commands…", + name: str | None = None, + id: str | None = None, + classes: str | None = None, + ) -> None: + """Initialise the command palette. + + Args: + providers: An optional list of providers to use. If None, the providers supplied + in the App or Screen will be used. + placeholder: The placeholder text for the command palette. + """ + super().__init__( + id=id, + classes=classes, + name=name, + ) + self.add_class("--textual-command-palette") + + self._selected_command: DiscoveryHit | Hit | None = None + """The command that was selected by the user.""" + self._busy_timer: Timer | None = None + """Keeps track of if there's a busy indication timer in effect.""" + self._no_matches_timer: Timer | None = None + """Keeps track of if there are 'No matches found' message waiting to be displayed.""" + self._supplied_providers: ProviderSource | None = providers + self._providers: list[Provider] = [] + """List of Provider instances involved in searches.""" + self._hit_count: int = 0 + """Number of hits displayed.""" + self._placeholder = placeholder + + @staticmethod + def is_open(app: App[object]) -> bool: + """Is a command palette current open? + + Args: + app: The app to test. + + Returns: + `True` if a command palette is currently open, `False` if not. + """ + return app.screen.has_class("--textual-command-palette") + + @property + def _provider_classes(self) -> set[type[Provider]]: + """The currently available command providers. + + This is a combination of the command providers defined [in the + application][textual.app.App.COMMANDS] and those [defined in + the current screen][textual.screen.Screen.COMMANDS]. + """ + + def get_providers( + provider_source: ProviderSource, + ) -> Iterable[type[Provider]]: + """Load the providers from a source (typically from the COMMANDS class variable) + at the App or Screen level. + + Args: + provider_source: The source of providers. + + Returns: + An iterable of providers. + """ + for provider in provider_source: + if isinstance(provider, SimpleProvider): + yield provider + elif isclass(provider) and issubclass(provider, Provider): + yield provider + else: + # Lazy loaded providers + yield provider() # type: ignore + + if self._calling_screen is None: + return set() + elif self._supplied_providers is None: + return { + *get_providers(self.app.COMMANDS), + *get_providers(self._calling_screen.COMMANDS), + } + else: + return {*get_providers(self._supplied_providers)} + + def compose(self) -> ComposeResult: + """Compose the command palette. + + Returns: + The content of the screen. + """ + with Vertical(id="--container"): + with Horizontal(id="--input"): + yield SearchIcon() + yield CommandInput(placeholder=self._placeholder, select_on_focus=False) + if not self.run_on_select: + yield Button("\u25b6") + with Vertical(id="--results"): + yield CommandList() + yield LoadingIndicator() + + def _on_click(self, event: Click) -> None: # type: ignore[override] + """Handle the click event. + + Args: + event: The click event. + + This method is used to allow clicking on the 'background' as a + method of dismissing the palette. + """ + if self.get_widget_at(event.screen_x, event.screen_y)[0] is self: + self._cancel_gather_commands() + self.app.post_message(CommandPalette.Closed(option_selected=False)) + self.dismiss() + + def _on_mount(self, _: Mount) -> None: + """Configure the command palette once the DOM is ready.""" + + self.app.post_message(CommandPalette.Opened()) + self._calling_screen = self.app.screen_stack[-2] + + match_style = self.get_visual_style("command-palette--highlight", partial=True) + + assert self._calling_screen is not None + self._providers = [ + provider_class(self._calling_screen, match_style) + for provider_class in self._provider_classes + ] + for provider in self._providers: + provider._post_init() + self._gather_commands("") + + async def _on_unmount(self) -> None: # type: ignore[override] + """Shutdown providers when command palette is closed.""" + if self._providers: + await wait( + [create_task(provider._shutdown()) for provider in self._providers], + ) + self._providers.clear() + + def _stop_busy_countdown(self) -> None: + """Stop any busy countdown that's in effect.""" + if self._busy_timer is not None: + self._busy_timer.stop() + self._busy_timer = None + + _BUSY_COUNTDOWN: Final[float] = 0.5 + """How many seconds to wait for commands to come in before showing we're busy.""" + + def _start_busy_countdown(self) -> None: + """Start a countdown to showing that we're busy searching.""" + self._stop_busy_countdown() + + def _become_busy() -> None: + if self._list_visible: + self._show_busy = True + + self._busy_timer = self.set_timer(self._BUSY_COUNTDOWN, _become_busy) + + def _stop_no_matches_countdown(self) -> None: + """Stop any 'No matches' countdown that's in effect.""" + if self._no_matches_timer is not None: + self._no_matches_timer.stop() + self._no_matches_timer = None + + _NO_MATCHES_COUNTDOWN: Final[float] = 0.5 + """How many seconds to wait before showing 'No matches found'.""" + + def _start_no_matches_countdown(self, search_value: str) -> None: + """Start a countdown to showing that there are no matches for the query. + + Args: + search_value: The value being searched for. + + Adds a 'No matches found' option to the command list after + `_NO_MATCHES_COUNTDOWN` seconds. + """ + self._stop_no_matches_countdown() + + def _show_no_matches() -> None: + # If we were actually searching for something, show that we + # found no matches. + if search_value: + command_list = self.query_one(CommandList) + command_list.add_option( + Option( + Align.center(Text("No matches found", style="not bold")), + disabled=True, + id=self._NO_MATCHES, + ) + ) + self._list_visible = True + else: + # The search value was empty, which means we were in + # discover mode; in that case it makes no sense to show that + # no matches were found. Lack of commands that can be + # discovered is a situation we don't need to highlight. + self._list_visible = False + + self._no_matches_timer = self.set_timer( + self._NO_MATCHES_COUNTDOWN, + _show_no_matches, + ) + + def _watch__list_visible(self) -> None: + """React to the list visible flag being toggled.""" + self.query_one(CommandList).set_class(self._list_visible, "--visible") + self.query_one("#--input", Horizontal).set_class( + self._list_visible, "--list-visible" + ) + if not self._list_visible: + self._show_busy = False + + async def _watch__show_busy(self) -> None: + """React to the show busy flag being toggled. + + This watcher adds or removes a busy indication depending on the + flag's state. + """ + self.query_one(LoadingIndicator).set_class(self._show_busy, "--visible") + self.query_one(CommandList).set_class(self._show_busy, "--populating") + + @staticmethod + async def _consume(hits: Hits, commands: Queue[DiscoveryHit | Hit]) -> None: + """Consume a source of matching commands, feeding the given command queue. + + Args: + hits: The hits to consume. + commands: The command queue to feed. + """ + async for hit in hits: + await commands.put(hit) + + async def _search_for( + self, search_value: str + ) -> AsyncGenerator[DiscoveryHit | Hit, bool]: + """Search for a given search value amongst all of the command providers. + + Args: + search_value: The value to search for. + + Yields: + The hits made amongst the registered command providers. + """ + + # Set up a queue to stream in the command hits from all the providers. + commands: Queue[DiscoveryHit | Hit] = Queue() + + # Fire up an instance of each command provider, inside a task, and + # have them go start looking for matches. + searches = [ + create_task( + self._consume( + provider._search(search_value), + commands, + ) + ) + for provider in self._providers + ] + # Set up a delay for showing that we're busy. + self._start_busy_countdown() + + # Assume the search isn't aborted. + aborted = False + + # Now, while there's some task running... + while not aborted and any(not search.done() for search in searches): + try: + # ...briefly wait for something on the stack. If we get + # something yield it up to our caller. + aborted = yield await wait_for(commands.get(), 0.1) + except TimeoutError: + # A timeout is fine. We're just going to go back round again + # and see if anything else has turned up. + pass + except CancelledError: + # A cancelled error means things are being aborted. + aborted = True + else: + # There was no timeout, which means that we managed to yield + # up that command; we're done with it so let the queue know. + commands.task_done() + + # Check through all the finished searches, see if any have + # exceptions, and log them. In most other circumstances we'd + # re-raise the exception and quit the application, but the decision + # has been made to find and log exceptions with command providers. + # + # https://github.com/Textualize/textual/pull/3058#discussion_r1310051855 + for search in searches: + if search.done(): + exception = search.exception() + if exception is not None: + from rich.traceback import Traceback + + self.log.error( + Traceback.from_exception( + type(exception), exception, exception.__traceback__ + ) + ) + + # Having finished the main processing loop, we're not busy any more. + # Anything left in the queue (see next) will fall out more or less + # instantly. If we're aborted, that means a fresh search is incoming + # and it'll have cleaned up the countdown anyway, so don't do that + # here as they'll be a clash. + if not aborted: + self._stop_busy_countdown() + + # If all the providers are pretty fast it could be that we've reached + # this point but the queue isn't empty yet. So here we flush the + # queue of anything left. + while not aborted and not commands.empty(): + aborted = yield await commands.get() + + # If we were aborted, ensure that all of the searches are cancelled. + if aborted: + for search in searches: + search.cancel() + + def _refresh_command_list( + self, command_list: CommandList, commands: list[Command], clear_current: bool + ) -> None: + """Refresh the command list. + + Args: + command_list: The widget that shows the list of commands. + commands: The commands to show in the widget. + clear_current: Should the current content of the list be cleared first? + """ + + sorted_commands = sorted(commands, key=attrgetter("hit.score"), reverse=True) + command_list.clear_options().add_options(sorted_commands) + + if sorted_commands: + command_list.highlighted = 0 + + self._list_visible = bool(command_list.option_count) + self._hit_count = command_list.option_count + + _RESULT_BATCH_TIME: Final[float] = 0.25 + """How long to wait before adding commands to the command list.""" + + _NO_MATCHES: Final[str] = "--no-matches" + """The ID to give the disabled option that shows there were no matches.""" + + _GATHER_COMMANDS_GROUP: Final[str] = "--textual-command-palette-gather-commands" + """The group name of the command gathering worker.""" + + @work(exclusive=True, group=_GATHER_COMMANDS_GROUP) + async def _gather_commands(self, search_value: str) -> None: + """Gather up all of the commands that match the search value. + + Args: + search_value: The value to search for. + """ + # The list to hold on to the commands we've gathered from the + # command providers. + gathered_commands: list[Command] = [] + + # Get a reference to the widget that we're going to drop the + # (display of) commands into. + command_list = self.query_one(CommandList) + + # If there's just one option in the list, and it's the item that + # tells the user there were no matches, let's remove that. We're + # starting a new search so we don't want them thinking there's no + # matches already. + if ( + command_list.option_count == 1 + and command_list.get_option_at_index(0).id == self._NO_MATCHES + ): + command_list.remove_option(self._NO_MATCHES) + + # Each command will receive a sequential ID. This is going to be + # used to find commands back again when we update the visible list + # and want to settle the selection back on the command it was on. + command_id = 0 + + # We're going to be checking in on the worker as we loop around, so + # grab a reference to that. + worker = get_current_worker() + + # Reset busy mode. + self._show_busy = False + + # A flag to keep track of if the current content of the command hit + # list needs to be cleared. The initial clear *should* be in + # `_input`, but doing so caused an unsightly "flash" of the list; so + # here we sacrifice "correct" code for a better-looking UI. + clear_current = True + + # We're going to batch updates over time, so start off pretending + # we've just done an update. + last_update = monotonic() + + # Kick off the search, grabbing the iterator. + search_routine = self._search_for(search_value) + search_results = search_routine.__aiter__() + + # We're going to be doing the send/await dance in this code, so we + # need to grab the first yielded command to start things off. + try: + hit = await search_results.__anext__() + except StopAsyncIteration: + # We've been stopped before we've even really got going, likely + # because the user is very quick on the keyboard. + hit = None + + while hit: + # Turn the command into something for display, and add it to the + # list of commands that have been gathered so far. + + def build_prompt() -> Iterable[Content]: + """Generator for prompt content.""" + assert hit is not None + if isinstance(hit.prompt, Text): + yield Content.from_rich_text(hit.prompt) + else: + yield Content.from_markup(hit.prompt) + + # Optional help text + if hit.help: + help_style = Style.from_styles( + self.get_component_styles("command-palette--help-text") + ) + yield Content.from_markup(hit.help).stylize_before(help_style) + + prompt = Content("\n").join(build_prompt()) + + gathered_commands.append(Command(prompt, hit, id=str(command_id))) + + if worker.is_cancelled: + break + + now = monotonic() + if (now - last_update) > self._RESULT_BATCH_TIME: + self._refresh_command_list( + command_list, gathered_commands, clear_current + ) + clear_current = False + last_update = now + + command_id += 1 + + # Finally, get the available command from the incoming queue; + # note that we send the worker cancelled status down into the + # search method. + try: + hit = await search_routine.asend(worker.is_cancelled) + except StopAsyncIteration: + break + + # On the way out, if we're still in play, ensure everything has been + # dropped into the command list. + if not worker.is_cancelled: + self._refresh_command_list(command_list, gathered_commands, clear_current) + + # One way or another, we're not busy any more. + self._show_busy = False + + # If we didn't get any hits, and we're not cancelled, that would + # mean nothing was found. Give the user positive feedback to that + # effect. + if command_list.option_count == 0 and not worker.is_cancelled: + self._hit_count = 0 + self._start_no_matches_countdown(search_value) + + self.add_class("-ready") + + def _cancel_gather_commands(self) -> None: + """Cancel any operation that is gather commands.""" + self.workers.cancel_group(self, self._GATHER_COMMANDS_GROUP) + + @on(Input.Changed) + def _input(self, event: Input.Changed) -> None: + """React to input in the command palette. + + Args: + event: The input event. + """ + event.stop() + self._cancel_gather_commands() + self._stop_no_matches_countdown() + self._gather_commands(event.value.strip()) + + @on(OptionList.OptionSelected) + def _select_command(self, event: OptionList.OptionSelected) -> None: + """React to a command being selected from the dropdown. + + Args: + event: The option selection event. + """ + event.stop() + self._cancel_gather_commands() + input = self.query_one(CommandInput) + with self.prevent(Input.Changed): + assert isinstance(event.option, Command) + hit = event.option.hit + input.value = str(hit.text) + self._selected_command = hit + input.action_end() + self._list_visible = False + self.query_one(CommandList).clear_options() + self._hit_count = 0 + if self.run_on_select: + self._select_or_command() + + @on(Input.Submitted) + @on(Button.Pressed) + def _select_or_command( + self, event: Input.Submitted | Button.Pressed | None = None + ) -> None: + """Depending on context, select or execute a command.""" + # If the list is visible, that means we're in "pick a command" + # mode... + if event is not None: + event.stop() + if self._list_visible: + command_list = self.query_one(CommandList) + # ...so if nothing in the list is highlighted yet... + if command_list.highlighted is None: + # ...cause the first completion to be highlighted. + self._action_cursor_down() + # If there is one option, assume the user wants to select it + if command_list.option_count == 1: + # Call after a short delay to provide a little visual feedback + self._action_command_list("select") + else: + # The list is visible, something is highlighted, the user + # made a selection "gesture"; let's go select it! + self._action_command_list("select") + else: + # The list isn't visible, which means that if we have a + # command... + if self._selected_command is not None: + # ...we should return it to the parent screen and let it + # decide what to do with it (hopefully it'll run it). + self._cancel_gather_commands() + self.app.post_message(CommandPalette.Closed(option_selected=True)) + self.app.delay_update() + self.dismiss() + self.app.call_later(self._selected_command.command) + + @on(OptionList.OptionHighlighted) + def _stop_event_leak(self, event: OptionList.OptionHighlighted) -> None: + """Stop any unused events so they don't leak to the application.""" + event.stop() + self.app.post_message(CommandPalette.OptionHighlighted(highlighted_event=event)) + + def _action_escape(self) -> None: + """Handle a request to escape out of the command palette.""" + self._cancel_gather_commands() + self.app.post_message(CommandPalette.Closed(option_selected=False)) + self.dismiss() + + def _action_command_list(self, action: str) -> None: + """Pass an action on to the [`CommandList`][textual.command.CommandList]. + + Args: + action: The action to pass on to the [`CommandList`][textual.command.CommandList]. + """ + try: + command_action = getattr(self.query_one(CommandList), f"action_{action}") + except AttributeError: + return + command_action() + + def _action_cursor_down(self) -> None: + """Handle the cursor down action. + + This allows the cursor down key to either open the command list, if + it's closed but has options, or if it's open with options just + cursor through them. + """ + commands = self.query_one(CommandList) + if commands.option_count and not self._list_visible: + self._list_visible = True + commands.highlighted = 0 + elif ( + commands.option_count + and not commands.get_option_at_index(0).id == self._NO_MATCHES + ): + self._action_command_list("cursor_down") diff --git a/src/memray/_vendor/textual/compose.py b/src/memray/_vendor/textual/compose.py new file mode 100644 index 0000000000..c7a30e92f7 --- /dev/null +++ b/src/memray/_vendor/textual/compose.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from memray._vendor.textual.app import App, ComposeResult + from memray._vendor.textual.widget import Widget + +__all__ = ["compose"] + + +def compose( + node: App | Widget, compose_result: ComposeResult | None = None +) -> list[Widget]: + """Compose child widgets from a generator in the same way as [compose][textual.widget.Widget.compose]. + + Example: + ```python + def on_key(self, event:events.Key) -> None: + + def add_key(key:str) -> ComposeResult: + with containers.HorizontalGroup(): + yield Label("You pressed:") + yield Label(key) + + self.mount_all( + compose(self, add_key(event.key)), + ) + ``` + + Args: + node: The parent node. + compose_result: A compose result, or `None` to call `node.compose()`. + + Returns: + A list of widgets. + """ + _rich_traceback_omit = True + from memray._vendor.textual.widget import MountError, Widget + + app = node.app + nodes: list[Widget] = [] + compose_stack: list[Widget] = [] + composed: list[Widget] = [] + app._compose_stacks.append(compose_stack) + app._composed.append(composed) + iter_compose = iter( + compose_result if compose_result is not None else node.compose() + ) + is_generator = hasattr(iter_compose, "throw") + try: + while True: + try: + child = next(iter_compose) + except StopIteration: + break + + if not isinstance(child, Widget): + mount_error = MountError( + f"Can't mount {type(child)}; expected a Widget instance." + ) + if is_generator: + iter_compose.throw(mount_error) # type: ignore + else: + raise mount_error from None + + try: + child.id + except AttributeError: + mount_error = MountError( + "Widget is missing an 'id' attribute; did you forget to call super().__init__()?" + ) + if is_generator: + iter_compose.throw(mount_error) # type: ignore + else: + raise mount_error from None + + if composed: + nodes.extend(composed) + composed.clear() + if compose_stack: + try: + compose_stack[-1].compose_add_child(child) + except Exception as error: + if is_generator: + # So the error is raised inside the generator + # This will generate a more sensible traceback for the dev + iter_compose.throw(error) # type: ignore + else: + raise + else: + nodes.append(child) + if composed: + nodes.extend(composed) + composed.clear() + finally: + app._compose_stacks.pop() + app._composed.pop() + return nodes diff --git a/src/memray/_vendor/textual/constants.py b/src/memray/_vendor/textual/constants.py new file mode 100644 index 0000000000..0f7717bec5 --- /dev/null +++ b/src/memray/_vendor/textual/constants.py @@ -0,0 +1,172 @@ +""" +This module contains constants, which may be set in environment variables. +""" + +from __future__ import annotations + +import os +from typing import get_args + +from typing_extensions import Final, TypeGuard + +from memray._vendor.textual._types import AnimationLevel + +get_environ = os.environ.get + + +def _get_environ_bool(name: str) -> bool: + """Check an environment variable switch. + + Args: + name: Name of environment variable. + + Returns: + `True` if the env var is "1", otherwise `False`. + """ + has_environ = get_environ(name) == "1" + return has_environ + + +def _get_environ_int( + name: str, default: int, minimum: int | None = None, maximum: int | None = None +) -> int: + """Retrieves an integer environment variable. + + Args: + name: Name of environment variable. + default: The value to use if the value is not set, or set to something other + than a valid integer. + minimum: Optional minimum value. + + Returns: + The integer associated with the environment variable if it's set to a valid int + or the default value otherwise. + """ + try: + value = int(os.environ[name]) + except KeyError: + return default + except ValueError: + return default + if minimum is not None: + return max(minimum, value) + if maximum is not None: + return min(maximum, value) + return value + + +def _get_environ_port(name: str, default: int) -> int: + """Get a port no. from an environment variable. + + Note that there is no 'minimum' here, as ports are more like names than a scalar value. + + Args: + name: Name of environment variable. + default: The value to use if the value is not set, or set to something other + than a valid port. + + Returns: + An integer port number. + + """ + try: + value = int(os.environ[name]) + except KeyError: + return default + except ValueError: + return default + if value < 0 or value > 65535: + return default + return value + + +def _is_valid_animation_level(value: str) -> TypeGuard[AnimationLevel]: + """Checks if a string is a valid animation level. + + Args: + value: The string to check. + + Returns: + Whether it's a valid level or not. + """ + return value in get_args(AnimationLevel) + + +def _get_textual_animations() -> AnimationLevel: + """Get the value of the environment variable that controls textual animations. + + The variable can be in any of the values defined by [`AnimationLevel`][textual.constants.AnimationLevel]. + + Returns: + The value that the variable was set to. If the environment variable is set to an + invalid value, we default to showing all animations. + """ + value: str = get_environ("TEXTUAL_ANIMATIONS", "FULL").lower() + if _is_valid_animation_level(value): + return value + return "full" + + +DEBUG: Final[bool] = _get_environ_bool("TEXTUAL_DEBUG") +"""Enable debug mode.""" + +DRIVER: Final[str | None] = get_environ("TEXTUAL_DRIVER", None) +"""Import for replacement driver.""" + +FILTERS: Final[str] = get_environ("TEXTUAL_FILTERS", "") +"""A list of filters to apply to renderables.""" + +LOG_FILE: Final[str | None] = get_environ("TEXTUAL_LOG", None) +"""A last resort log file that appends all logs, when devtools isn't working.""" + +DEVTOOLS_HOST: Final[str] = get_environ("TEXTUAL_DEVTOOLS_HOST", "127.0.0.1") +"""The host where textual console is running.""" + +DEVTOOLS_PORT: Final[int] = _get_environ_port("TEXTUAL_DEVTOOLS_PORT", 8081) +"""Constant with the port that the devtools will connect to.""" + +SCREENSHOT_DELAY: Final[int] = _get_environ_int("TEXTUAL_SCREENSHOT", -1, minimum=-1) +"""Seconds delay before taking screenshot, -1 for no screenshot.""" + +SCREENSHOT_LOCATION: Final[str | None] = get_environ("TEXTUAL_SCREENSHOT_LOCATION") +"""The location where screenshots should be written.""" + +SCREENSHOT_FILENAME: Final[str | None] = get_environ("TEXTUAL_SCREENSHOT_FILENAME") +"""The filename to use for the screenshot.""" + +PRESS: Final[str] = get_environ("TEXTUAL_PRESS", "") +"""Keys to automatically press.""" + +SHOW_RETURN: Final[bool] = _get_environ_bool("TEXTUAL_SHOW_RETURN") +"""Write the return value on exit.""" + +MAX_FPS: Final[int] = _get_environ_int("TEXTUAL_FPS", 60, minimum=1) +"""Maximum frames per second for updates.""" + +COLOR_SYSTEM: Final[str | None] = get_environ("TEXTUAL_COLOR_SYSTEM", "auto") +"""Force color system override.""" + +TEXTUAL_ANIMATIONS: Final[AnimationLevel] = _get_textual_animations() +"""Determines whether animations run or not.""" + +ESCAPE_DELAY: Final[float] = _get_environ_int("ESCDELAY", 100, minimum=1) / 1000.0 +"""The delay (in seconds) before reporting an escape key (not used if the extend key protocol is available).""" + +SLOW_THRESHOLD: int = _get_environ_int("TEXTUAL_SLOW_THRESHOLD", 500, minimum=100) +"""The time threshold (in milliseconds) after which a warning is logged +if message processing exceeds this duration. +""" + +DEFAULT_THEME: Final[str] = get_environ("TEXTUAL_THEME", "textual-dark") +"""Textual theme to make default. More than one theme may be specified in a comma separated list. +Textual will use the first theme that exists. +""" + +SMOOTH_SCROLL: Final[bool] = _get_environ_int("TEXTUAL_SMOOTH_SCROLL", 1) == 1 +"""Should smooth scrolling be enabled? set `TEXTUAL_SMOOTH_SCROLL=0` to disable smooth scrolling. +""" + +DIM_FACTOR: Final[float] = ( + _get_environ_int("TEXTUAL_DIM_FACTOR", 66, minimum=0, maximum=100) / 100 +) +"""Percentage to use as opacity when converting ANSI 'dim' attribute to RGB.""" diff --git a/src/memray/_vendor/textual/containers.py b/src/memray/_vendor/textual/containers.py new file mode 100644 index 0000000000..bc2c5cde43 --- /dev/null +++ b/src/memray/_vendor/textual/containers.py @@ -0,0 +1,311 @@ +""" +Container widgets for quick styling. + +With the exception of `Center` and `Middle` containers will fill all of the space in the parent widget. + +""" + +from __future__ import annotations + +from typing import ClassVar + +from memray._vendor.textual.binding import Binding, BindingType +from memray._vendor.textual.layout import Layout +from memray._vendor.textual.layouts.grid import GridLayout +from memray._vendor.textual.reactive import reactive +from memray._vendor.textual.widget import Widget + + +class Container(Widget): + """Simple container widget, with vertical layout.""" + + DEFAULT_CSS = """ + Container { + width: 1fr; + height: 1fr; + layout: vertical; + overflow: hidden hidden; + } + """ + + +class ScrollableContainer(Widget, can_focus=True): + """A scrollable container with vertical layout, and auto scrollbars on both axis.""" + + # We don't typically want to maximize scrollable containers, + # since the user can easily navigate the contents + ALLOW_MAXIMIZE = False + + DEFAULT_CSS = """ + ScrollableContainer { + width: 1fr; + height: 1fr; + layout: vertical; + overflow: auto auto; + } + """ + + BINDINGS: ClassVar[list[BindingType]] = [ + Binding("up", "scroll_up", "Scroll Up", show=False), + Binding("down", "scroll_down", "Scroll Down", show=False), + Binding("left", "scroll_left", "Scroll Left", show=False), + Binding("right", "scroll_right", "Scroll Right", show=False), + Binding("home", "scroll_home", "Scroll Home", show=False), + Binding("end", "scroll_end", "Scroll End", show=False), + Binding("pageup", "page_up", "Page Up", show=False), + Binding("pagedown", "page_down", "Page Down", show=False), + Binding("ctrl+pageup", "page_left", "Page Left", show=False), + Binding("ctrl+pagedown", "page_right", "Page Right", show=False), + ] + """Keyboard bindings for scrollable containers. + + | Key(s) | Description | + | :- | :- | + | up | Scroll up, if vertical scrolling is available. | + | down | Scroll down, if vertical scrolling is available. | + | left | Scroll left, if horizontal scrolling is available. | + | right | Scroll right, if horizontal scrolling is available. | + | home | Scroll to the home position, if scrolling is available. | + | end | Scroll to the end position, if scrolling is available. | + | pageup | Scroll up one page, if vertical scrolling is available. | + | pagedown | Scroll down one page, if vertical scrolling is available. | + | ctrl+pageup | Scroll left one page, if horizontal scrolling is available. | + | ctrl+pagedown | Scroll right one page, if horizontal scrolling is available. | + """ + + def __init__( + self, + *children: Widget, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + can_focus: bool | None = None, + can_focus_children: bool | None = None, + can_maximize: bool | None = None, + ) -> None: + """ + Construct a scrollable container. + + Args: + *children: Child widgets. + name: The name of the widget. + id: The ID of the widget in the DOM. + classes: The CSS classes for the widget. + disabled: Whether the widget is disabled or not. + can_focus: Can this container be focused? + can_focus_children: Can this container's children be focused? + can_maximized: Allow this container to maximize? `None` to use default logic., + """ + + super().__init__( + *children, + name=name, + id=id, + classes=classes, + disabled=disabled, + ) + if can_focus is not None: + self.can_focus = can_focus + if can_focus_children is not None: + self.can_focus_children = can_focus_children + self.can_maximize = can_maximize + + @property + def allow_maximize(self) -> bool: + if self.can_maximize is None: + return super().allow_maximize + return self.can_maximize + + +class Vertical(Widget): + """An expanding container with vertical layout and no scrollbars.""" + + DEFAULT_CSS = """ + Vertical { + width: 1fr; + height: 1fr; + layout: vertical; + overflow: hidden hidden; + } + """ + + +class VerticalGroup(Widget): + """A non-expanding container with vertical layout and no scrollbars.""" + + DEFAULT_CSS = """ + VerticalGroup { + width: 1fr; + height: auto; + layout: vertical; + overflow: hidden hidden; + } + """ + + +class VerticalScroll(ScrollableContainer): + """A container with vertical layout and an automatic scrollbar on the Y axis.""" + + DEFAULT_CSS = """ + VerticalScroll { + layout: vertical; + overflow-x: hidden; + overflow-y: auto; + } + """ + + +class Horizontal(Widget): + """An expanding container with horizontal layout and no scrollbars.""" + + DEFAULT_CSS = """ + Horizontal { + width: 1fr; + height: 1fr; + layout: horizontal; + overflow: hidden hidden; + } + """ + + +class HorizontalGroup(Widget): + """A non-expanding container with horizontal layout and no scrollbars.""" + + DEFAULT_CSS = """ + HorizontalGroup { + width: 1fr; + height: auto; + layout: horizontal; + overflow: hidden hidden; + } + """ + + +class HorizontalScroll(ScrollableContainer): + """A container with horizontal layout and an automatic scrollbar on the X axis.""" + + DEFAULT_CSS = """ + HorizontalScroll { + layout: horizontal; + overflow-y: hidden; + overflow-x: auto; + } + """ + + +class Center(Widget): + """A container which aligns children on the X axis.""" + + DEFAULT_CSS = """ + Center { + align-horizontal: center; + width: 1fr; + height: auto; + } + """ + + +class Right(Widget): + """A container which aligns children on the X axis.""" + + DEFAULT_CSS = """ + Right { + align-horizontal: right; + width: 1fr; + height: auto; + } + """ + + +class Middle(Widget): + """A container which aligns children on the Y axis.""" + + DEFAULT_CSS = """ + Middle { + align-vertical: middle; + width: auto; + height: 1fr; + } + """ + + +class CenterMiddle(Widget): + """A container which aligns its children on both axis.""" + + DEFAULT_CSS = """ + CenterMiddle { + align: center middle; + width: 1fr; + height: 1fr; + } + """ + + +class Grid(Widget): + """A container with grid layout.""" + + DEFAULT_CSS = """ + Grid { + width: 1fr; + height: 1fr; + layout: grid; + } + """ + + +class ItemGrid(Widget): + """A container with grid layout and automatic columns.""" + + DEFAULT_CSS = """ + ItemGrid { + width: 1fr; + height: auto; + layout: grid; + } + """ + + stretch_height: reactive[bool] = reactive(True) + min_column_width: reactive[int | None] = reactive(None, layout=True) + max_column_width: reactive[int | None] = reactive(None, layout=True) + regular: reactive[bool] = reactive(False) + + def __init__( + self, + *children: Widget, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + min_column_width: int | None = None, + max_column_width: int | None = None, + stretch_height: bool = True, + regular: bool = False, + ) -> None: + """ + Construct a ItemGrid. + + Args: + *children: Child widgets. + name: The name of the widget. + id: The ID of the widget in the DOM. + classes: The CSS classes for the widget. + disabled: Whether the widget is disabled or not. + stretch_height: Expand the height of widgets to the row height. + min_column_width: The smallest permitted column width. + regular: All rows should have the same number of items. + """ + super().__init__( + *children, name=name, id=id, classes=classes, disabled=disabled + ) + self.set_reactive(ItemGrid.stretch_height, stretch_height) + self.set_reactive(ItemGrid.min_column_width, min_column_width) + self.set_reactive(ItemGrid.max_column_width, max_column_width) + self.set_reactive(ItemGrid.regular, regular) + + def pre_layout(self, layout: Layout) -> None: + if isinstance(layout, GridLayout): + layout.stretch_height = self.stretch_height + layout.min_column_width = self.min_column_width + layout.max_column_width = self.max_column_width + layout.regular = self.regular diff --git a/src/memray/_vendor/textual/content.py b/src/memray/_vendor/textual/content.py new file mode 100644 index 0000000000..3382504d13 --- /dev/null +++ b/src/memray/_vendor/textual/content.py @@ -0,0 +1,1833 @@ +""" +Content is a container for text, with spans marked up with color / style. +It is equivalent to Rich's Text object, with support for more of Textual features. + +Unlike Rich Text, Content is *immutable* so you can't modify it in place, and most methods will return a new Content instance. +This is more like the builtin str, and allows Textual to make some significant optimizations. + +""" + +from __future__ import annotations + +import re +from functools import cached_property, total_ordering +from operator import itemgetter +from typing import Callable, Iterable, NamedTuple, Sequence, Union + +import rich.repr +from rich._wrap import divide_line +from rich.cells import set_cell_size +from rich.console import Console +from rich.segment import Segment +from rich.style import Style as RichStyle +from rich.terminal_theme import TerminalTheme +from rich.text import Text +from typing_extensions import Final, TypeAlias + +from memray._vendor.textual._cells import cell_len +from memray._vendor.textual._context import active_app +from memray._vendor.textual._loop import loop_last +from memray._vendor.textual.cache import FIFOCache +from memray._vendor.textual.color import Color +from memray._vendor.textual.css.types import TextAlign, TextOverflow +from memray._vendor.textual.selection import Selection +from memray._vendor.textual.strip import Strip +from memray._vendor.textual.style import Style +from memray._vendor.textual.visual import RenderOptions, RulesMap, Visual + +__all__ = ["ContentType", "Content", "Span"] + +ContentType: TypeAlias = Union["Content", str] +"""Type alias used where content and a str are interchangeable in a function.""" + +ContentText: TypeAlias = Union["Content", Text, str] +"""A type that may be used to construct Text.""" + +ANSI_DEFAULT = Style( + background=Color(0, 0, 0, 0, ansi=-1), + foreground=Color(0, 0, 0, 0, ansi=-1), +) +"""A Style for ansi default background and foreground.""" + +TRANSPARENT_STYLE = Style() +"""A null style.""" + +_re_whitespace = re.compile(r"\s+$") +_STRIP_CONTROL_CODES: Final = [ + 7, # Bell + 8, # Backspace + 11, # Vertical tab + 12, # Form feed + 13, # Carriage return +] +_CONTROL_STRIP_TRANSLATE: Final = { + _codepoint: None for _codepoint in _STRIP_CONTROL_CODES +} + + +def _strip_control_codes( + text: str, _translate_table: dict[int, None] = _CONTROL_STRIP_TRANSLATE +) -> str: + """Remove control codes from text. + + Args: + text (str): A string possibly contain control codes. + + Returns: + str: String with control codes removed. + """ + return text.translate(_translate_table) + + +@rich.repr.auto +class Span(NamedTuple): + """A style applied to a range of character offsets.""" + + start: int + end: int + style: Style | str + + def __rich_repr__(self) -> rich.repr.Result: + yield self.start + yield self.end + yield "style", self.style + + def extend(self, cells: int) -> "Span": + """Extend the span by the given number of cells. + + Args: + cells (int): Additional space to add to end of span. + + Returns: + Span: A span. + """ + if cells: + start, end, style = self + return Span(start, end + cells, style) + return self + + def _shift(self, distance: int) -> "Span": + """Shift a span a given distance. + + Note that the start offset is clamped to 0. + The end offset is not clamped, as it is assumed this has already been checked by the caller. + + Args: + distance: Number of characters to move. + + Returns: + New Span. + """ + if distance < 0: + start, end, style = self + return Span( + offset if (offset := start + distance) > 0 else 0, end + distance, style + ) + else: + start, end, style = self + return Span(start + distance, end + distance, style) + + +@rich.repr.auto +@total_ordering +class Content(Visual): + """Text content with marked up spans. + + This object can be considered immutable, although it might update its internal state + in a way that is consistent with immutability. + + """ + + __slots__ = ["_text", "_spans", "_cell_length"] + + _NORMALIZE_TEXT_ALIGN = {"start": "left", "end": "right", "justify": "full"} + + def __init__( + self, + text: str = "", + spans: list[Span] | None = None, + cell_length: int | None = None, + strip_control_codes: bool = True, + ) -> None: + """ + Initialize a Content object. + + Args: + text: text content. + spans: Optional list of spans. + cell_length: Cell length of text if known, otherwise `None`. + strip_control_codes: Strip control codes that may break output? + """ + + self._text: str = ( + _strip_control_codes(text) if strip_control_codes and text else text + ) + self._spans: list[Span] = [] if spans is None else spans + self._cell_length = cell_length + self._optimal_width_cache: int | None = None + self._minimal_width_cache: int | None = None + self._height_cache: tuple[tuple[int, str, bool] | None, int] = (None, 0) + self._divide_cache: ( + FIFOCache[Sequence[int], list[tuple[Span, int, int]]] | None + ) = None + self._split_cache: FIFOCache[tuple[str, bool, bool], list[Content]] | None = ( + None + ) + # If there are 1 or 0 spans, it can't be simplified further + self._simplified = len(self._spans) <= 1 + + def __str__(self) -> str: + return self._text + + @property + def _is_regular(self) -> bool: + """Check if the line is regular (spans.end > span.start for all spans). + + This is a debugging aid, and unlikely to be useful in your app. + + Returns: + `True` if the content is regular, `False` if it is not (and broken). + """ + for span in self.spans: + if span.end <= span.start: + return False + return True + + @cached_property + def markup(self) -> str: + """Get the content markup that would create this Content instance. + + This is essentially the inverse of [`Content.from_markup`][textual.content.Content.from_markup]. + + Returns: + str: A string potentially creating markup tags. + """ + from memray._vendor.textual.markup import escape + + output: list[str] = [] + + plain = self.plain + markup_spans = [ + (0, False, None), + *((span.start, False, span.style) for span in self._spans), + *((span.end, True, span.style) for span in self._spans), + (len(plain), True, None), + ] + markup_spans.sort(key=itemgetter(0, 1)) + position = 0 + append = output.append + for offset, closing, style in markup_spans: + if offset > position: + append(escape(plain[position:offset])) + position = offset + if style: + append(f"[/{style}]" if closing else f"[{style}]") + markup = "".join(output) + return markup + + @classmethod + def empty(cls) -> Content: + """Get an empty (blank) content""" + return EMPTY_CONTENT + + @classmethod + def from_text( + cls, markup_content_or_text: ContentText, markup: bool = True + ) -> Content: + """Construct content from Text or str. If the argument is already Content, then + return it unmodified. + + This method exists to make (Rich) Text and Content interchangeable. While Content + is preferred, we don't want to make it harder than necessary for apps to use Text. + + Args: + markup_content_or_text: Value to create Content from. + markup: If `True`, then str values will be parsed as markup, otherwise they will + be considered literals. + + Raises: + TypeError: If the supplied argument is not a valid type. + + Returns: + A new Content instance. + """ + if isinstance(markup_content_or_text, Content): + return markup_content_or_text + elif isinstance(markup_content_or_text, str): + if markup: + return cls.from_markup(markup_content_or_text) + else: + return cls(markup_content_or_text) + elif isinstance(markup_content_or_text, Text): + return cls.from_rich_text(markup_content_or_text) + else: + raise TypeError( + "This method expects a str, a Text instance, or a Content instance" + ) + + @classmethod + def from_markup(cls, markup: str | Content, **variables: object) -> Content: + """Create content from markup, optionally combined with template variables. + + If `markup` is already a Content instance, it will be returned unmodified. + + See the guide on [Content](../guide/content.md#content-class) for more details. + + + Example: + ```python + content = Content.from_markup("Hello, [b]$name[/b]!", name="Will") + ``` + + Args: + markup: Content markup, or Content. + **variables: Optional template variables used + + Returns: + New Content instance. + """ + _rich_traceback_omit = True + if isinstance(markup, Content): + if variables: + raise ValueError("A literal string is require to substitute variables.") + return markup + markup = _strip_control_codes(markup) + if "[" not in markup and not variables: + return Content(markup) + from memray._vendor.textual.markup import to_content + + content = to_content(markup, template_variables=variables or None) + return content + + @classmethod + def from_rich_text( + cls, text: str | Text, console: Console | None = None + ) -> Content: + """Create equivalent Visual Content for str or Text. + + Args: + text: String or Rich Text. + console: A Console object to use if parsing Rich Console markup, or `None` to + use app default. + + Returns: + New Content. + """ + if isinstance(text, str): + text = Text.from_markup(text) + + ansi_theme: TerminalTheme | None = None + + if console is not None: + get_style = console.get_style + else: + try: + app = active_app.get() + except LookupError: + get_style = RichStyle.parse + else: + get_style = app.console.get_style + + if text._spans: + try: + ansi_theme = active_app.get().ansi_theme + except LookupError: + ansi_theme = None + spans = [ + Span( + start, + end, + ( + Style.from_rich_style(get_style(style), ansi_theme) + if isinstance(style, str) + else Style.from_rich_style(style, ansi_theme) + ), + ) + for start, end, style in text._spans + ] + else: + spans = [] + + content = cls(text.plain, spans) + if text.style: + try: + ansi_theme = active_app.get().ansi_theme + except LookupError: + ansi_theme = None + content = content.stylize_before( + text.style + if isinstance(text.style, str) + else Style.from_rich_style(text.style, ansi_theme) + ) + return content + + @classmethod + def styled( + cls, + text: str, + style: Style | str = "", + cell_length: int | None = None, + strip_control_codes: bool = True, + ) -> Content: + """Create a Content instance from text and an optional style. + + Args: + text: String content. + style: Desired style. + cell_length: Cell length of text if known, otherwise `None`. + strip_control_codes: Strip control codes that may break output. + + Returns: + New Content instance. + """ + if not text: + return EMPTY_CONTENT + new_content = cls( + text, + [Span(0, len(text), style)] if style else None, + cell_length, + strip_control_codes=strip_control_codes, + ) + return new_content + + @classmethod + def blank(cls, width: int, style: Style | str | None = None) -> Content: + """Get a Content instance consisting of spaces. + + Args: + width: Width of blank content (number of spaces). + style: Style of blank. + + Returns: + Content instance. + """ + if not width: + return EMPTY_CONTENT + blank = cls( + " " * width, + [Span(0, width, style)] if style else None, + cell_length=width, + ) + return blank + + @classmethod + def assemble( + cls, + *parts: str | Content | tuple[str, str | Style], + end: str = "", + strip_control_codes: bool = True, + ) -> Content: + """Construct new content from string, content, or tuples of (TEXT, STYLE). + + This is an efficient way of constructing Content composed of smaller pieces of + text and / or other Content objects. + + Example: + ```python + content = Content.assemble( + Content.from_markup("[b]assemble[/b]: "), # Other content + "pieces of text or content into a", # Simple string of text + ("a single Content instance", "underline"), # A tuple of text and a style + ) + ``` + + Args: + *parts: Parts to join to gether. A *part* may be a simple string, another Content + instance, or tuple containing text and a style. + end: Optional end to the Content. + strip_control_codes: Strip control codes that may break output. + """ + text: list[str] = [] + spans: list[Span] = [] + _Span = Span + text_append = text.append + + position: int = 0 + for part in parts: + if isinstance(part, str): + text_append(part) + position += len(part) + elif isinstance(part, tuple): + part_text, part_style = part + text_append(part_text) + if part_style: + spans.append( + _Span(position, position + len(part_text), part_style), + ) + position += len(part_text) + elif isinstance(part, Content): + text_append(part.plain) + if part.spans: + spans.extend( + [ + _Span(start + position, end + position, style) + for start, end, style in part.spans + ] + ) + position += len(part.plain) + if end: + text_append(end) + assembled_content = cls( + "".join(text), spans, strip_control_codes=strip_control_codes + ) + return assembled_content + + def simplify(self) -> Content: + """Simplify spans by joining contiguous spans together. + + This may produce faster renders if you have concatenated a large number of small pieces + of content with repeating styles. + + Note that this modifies the Content instance in-place, which might appear + to violate the immutability constraints, but it will not change the rendered output, + nor its hash. + + Returns: + Self. + """ + if not (spans := self._spans) or self._simplified: + return self + last_span = Span(-1, -1, "") + new_spans: list[Span] = [] + changed: bool = False + for span in spans: + if span.start == last_span.end and span.style == last_span.style: + last_span = new_spans[-1] = Span(last_span.start, span.end, span.style) + changed = True + else: + new_spans.append(span) + last_span = span + if changed: + self._spans[:] = new_spans + self._simplified = True + return self + + def add_spans(self, spans: Sequence[Span]) -> Content: + """Adds spans to this Content instance. + + Args: + spans: A sequence of spans. + + Returns: + A Content instance. + """ + if spans: + return Content( + self.plain, + [*self._spans, *spans], + self._cell_length, + strip_control_codes=False, + ) + return self + + def __eq__(self, other: object) -> bool: + """Compares text only, so that markup doesn't effect sorting.""" + if isinstance(other, str): + return self.plain == other + elif isinstance(other, Content): + return self.plain == other.plain + return NotImplemented + + def __lt__(self, other: object) -> bool: + if isinstance(other, str): + return self.plain < other + if isinstance(other, Content): + return self.plain < other.plain + return NotImplemented + + def is_same(self, content: Content) -> bool: + """Compare to another Content object. + + Two Content objects are the same if their text *and* spans match. + Note that if you use the `==` operator to compare Content instances, it will only consider + the plain text portion of the content (and not the spans). + + Args: + content: Content instance. + + Returns: + `True` if this is identical to `content`, otherwise `False`. + """ + if self is content: + return True + if self.plain != content.plain: + return False + return self.spans == content.spans + + def get_optimal_width(self, rules: RulesMap, container_width: int) -> int: + """Get optimal width of the Visual to display its content. + + The exact definition of "optimal width" is dependant on the Visual, but + will typically be wide enough to display output without cropping or wrapping, + and without superfluous space. + + Args: + rules: A mapping of style rules, such as the Widgets `styles` object. + + Returns: + A width in cells. + + """ + if self._optimal_width_cache is None: + self._optimal_width_cache = width = max( + cell_len(line) for line in self.plain.split("\n") + ) + else: + width = self._optimal_width_cache + return width + rules.get("line_pad", 0) * 2 + + def get_minimal_width(self, rules: RulesMap) -> int: + """Minimal width is the largest single word.""" + if not self.plain.strip(): + return 0 + if self._minimal_width_cache is None: + self._minimal_width_cache = width = max( + cell_len(word) + for line in self.plain.splitlines() + for word in line.split() + if word.strip() + ) + else: + width = self._minimal_width_cache + return width + rules.get("line_pad", 0) * 2 + + def get_height(self, rules: RulesMap, width: int) -> int: + """Get the height of the Visual if rendered at the given width. + + Args: + rules: A mapping of style rules, such as the Widgets `styles` object. + width: Width of visual in cells. + + Returns: + A height in lines. + """ + get_rule = rules.get + line_pad = get_rule("line_pad", 0) * 2 + overflow = get_rule("text_overflow", "fold") + no_wrap = get_rule("text_wrap", "wrap") == "nowrap" + cache_key = (width + line_pad, overflow, no_wrap) + if self._height_cache[0] == cache_key: + height = self._height_cache[1] + else: + lines = self.without_spans._wrap_and_format( + width - line_pad, overflow=overflow, no_wrap=no_wrap + ) + height = len(lines) + self._height_cache = (cache_key, height) + return height + + def _wrap_and_format( + self, + width: int, + align: TextAlign = "left", + overflow: TextOverflow = "fold", + no_wrap: bool = False, + line_pad: int = 0, + tab_size: int = 8, + selection: Selection | None = None, + selection_style: Style | None = None, + post_style: Style | None = None, + get_style: Callable[[str | Style], Style] = Style.parse, + ) -> list[_FormattedLine]: + """Wraps the text and applies formatting. + + Args: + width: Desired width. + align: Text alignment. + overflow: Overflow method. + no_wrap: Disabled wrapping. + tab_size: Cell with of tabs. + selection: Selection information or `None` if no selection. + selection_style: Selection style, or `None` if no selection. + + Returns: + List of formatted lines. + """ + output_lines: list[_FormattedLine] = [] + + if selection is not None: + get_span = selection.get_span + else: + + def get_span(y: int) -> tuple[int, int] | None: + return None + + for y, line in enumerate(self.split(allow_blank=True)): + if post_style is not None: + line = line.stylize(post_style) + + if selection_style is not None and (span := get_span(y)) is not None: + start, end = span + if end == -1: + end = len(line.plain) + line = line.stylize(selection_style, start, end) + + line = line.expand_tabs(tab_size) + + if no_wrap: + if overflow == "fold": + cuts = list(range(0, line.cell_length, width))[1:] + new_lines = [ + _FormattedLine(get_style, line, width, y=y, align=align) + for line in line.divide(cuts) + ] + else: + line = line.truncate(width, ellipsis=overflow == "ellipsis") + content_line = _FormattedLine( + get_style, line, width, y=y, align=align + ) + new_lines = [content_line] + else: + content_line = _FormattedLine(get_style, line, width, y=y, align=align) + offsets = divide_line( + line.plain, width - line_pad * 2, fold=overflow == "fold" + ) + divided_lines = content_line.content.divide(offsets) + ellipsis = overflow == "ellipsis" + divided_lines = [ + ( + line.truncate(width, ellipsis=ellipsis) + if last + else line.rstrip().truncate(width, ellipsis=ellipsis) + ) + for last, line in loop_last(divided_lines) + ] + + new_lines = [ + _FormattedLine( + get_style, + content.rstrip_end(width).pad(line_pad, line_pad), + width, + offset, + y, + align=align, + ) + for content, offset in zip(divided_lines, [0, *offsets]) + ] + new_lines[-1].line_end = True + + output_lines.extend(new_lines) + + return output_lines + + def render_strips( + self, width: int, height: int | None, style: Style, options: RenderOptions + ) -> list[Strip]: + """Render the Visual into an iterable of strips. Part of the Visual protocol. + + Args: + width: Width of desired render. + height: Height of desired render or `None` for any height. + style: The base style to render on top of. + options: Additional render options. + + Returns: + An list of Strips. + """ + + if not width: + return [] + + get_rule = options.rules.get + lines = self._wrap_and_format( + width, + align=get_rule("text_align", "left"), + overflow=get_rule("text_overflow", "fold"), + no_wrap=get_rule("text_wrap", "wrap") == "nowrap", + line_pad=get_rule("line_pad", 0), + tab_size=8, + selection=options.selection, + selection_style=options.selection_style, + post_style=options.post_style, + get_style=options.get_style, + ) + + if height is not None: + lines = lines[:height] + + strip_lines = [Strip(*line.to_strip(style)) for line in lines] + return strip_lines + + def __len__(self) -> int: + return len(self.plain) + + def __bool__(self) -> bool: + return self._text != "" + + def __hash__(self) -> int: + return hash(self._text) + + def __rich_repr__(self) -> rich.repr.Result: + try: + yield self._text + yield "spans", self._spans, [] + except AttributeError: + pass + + @property + def spans(self) -> Sequence[Span]: + """A sequence of spans used to markup regions of the content. + + !!! warning + Never attempt to mutate the spans, as this would certainly break the output--possibly + in quite subtle ways! + + """ + return self._spans + + @property + def cell_length(self) -> int: + """The cell length of the content.""" + # Calculated on demand + if self._cell_length is None: + self._cell_length = cell_len(self.plain) + return self._cell_length + + @property + def plain(self) -> str: + """Get the text as a single string.""" + return self._text + + @property + def without_spans(self) -> Content: + """The content with no spans""" + if self._spans: + return Content(self.plain, [], self._cell_length, strip_control_codes=False) + return self + + @property + def first_line(self) -> Content: + """The first line of the content.""" + if "\n" not in self.plain: + return self + return self[: self.plain.index("\n")] + + def __getitem__(self, slice: int | slice) -> Content: + def get_text_at(offset: int) -> "Content": + _Span = Span + content = Content( + self.plain[offset], + spans=[ + _Span(0, 1, style) + for start, end, style in self._spans + if end > offset >= start + ], + strip_control_codes=False, + ) + return content + + if isinstance(slice, int): + return get_text_at(slice) + else: + start, stop, step = slice.indices(len(self.plain)) + if step == 1: + if start == 0: + if stop >= len(self.plain): + return self + text = self.plain[:stop] + sliced_content = Content( + text, + self._trim_spans(text, self._spans), + strip_control_codes=False, + ) + else: + text = self.plain[start:stop] + spans = [ + span._shift(-start) + for span in self._spans + if span.end - start > 0 + ] + sliced_content = Content( + text, self._trim_spans(text, spans), strip_control_codes=False + ) + return sliced_content + + else: + # This would be a bit of work to implement efficiently + # For now, its not required + raise TypeError("slices with step!=1 are not supported") + + def __add__(self, other: Content | str) -> Content: + if isinstance(other, str): + return Content(self._text + other, self._spans, strip_control_codes=False) + if isinstance(other, Content): + offset = len(self.plain) + content = Content( + self.plain + other.plain, + ( + self._spans + + [ + Span(start + offset, end + offset, style) + for start, end, style in other._spans + ] + ), + ( + None + if self._cell_length is not None + else (self.cell_length + other.cell_length) + ), + ) + return content + return NotImplemented + + def __radd__(self, other: str) -> Content: + if not isinstance(other, str): + return NotImplemented + return Content(other) + self + + @classmethod + def _trim_spans(cls, text: str, spans: list[Span]) -> list[Span]: + """Remove or modify any spans that are over the end of the text.""" + max_offset = len(text) + _Span = Span + spans = [ + ( + span + if span.end < max_offset + else _Span(span.start, min(max_offset, span.end), span.style) + ) + for span in spans + if span.start < max_offset + ] + return spans + + def append(self, content: Content | str) -> Content: + """Append text or content to this content. + + Note this is a little inefficient, if you have many strings to append, consider [`join`][textual.content.Content.join]. + + Args: + content: A content instance, or a string. + + Returns: + New content. + """ + if isinstance(content, str): + return Content( + f"{self.plain}{content}", + self._spans, + ( + None + if self._cell_length is None + else self._cell_length + cell_len(content) + ), + strip_control_codes=False, + ) + return EMPTY_CONTENT.join([self, content]) + + def append_text(self, text: str, style: Style | str = "") -> Content: + """Append text give as a string, with an optional style. + + Args: + text: Text to append. + style: Optional style for new text. + + Returns: + New content. + """ + return self.append(Content.styled(text, style)) + + def join(self, lines: Iterable[Content | str]) -> Content: + """Join an iterable of content or strings. + + This works much like the join method on `str` objects. + Self is the separator (which maybe empty) placed between each string or Content. + + Args: + lines: An iterable of other Content instances or or strings. + + Returns: + A single Content instance, containing all of the lines. + + """ + text: list[str] = [] + spans: list[Span] = [] + + def iter_content() -> Iterable[Content]: + """Iterate the lines, optionally inserting the separator.""" + if self.plain: + for last, line in loop_last(lines): + yield ( + line + if isinstance(line, Content) + else Content(line, strip_control_codes=False) + ) + if not last: + yield self + else: + for line in lines: + yield ( + line + if isinstance(line, Content) + else Content(line, strip_control_codes=False) + ) + + extend_text = text.extend + extend_spans = spans.extend + offset = 0 + _Span = Span + + total_cell_length: int | None = self._cell_length + + for content in iter_content(): + if not content: + continue + extend_text(content._text) + extend_spans( + _Span(offset + start, offset + end, style) + for start, end, style in content._spans + if style + ) + offset += len(content._text) + if total_cell_length is not None: + total_cell_length = ( + None + if content._cell_length is None + else total_cell_length + content._cell_length + ) + + return Content("".join(text), spans, total_cell_length) + + def wrap( + self, width: int, *, align: TextAlign = "left", overflow: TextOverflow = "fold" + ) -> list[Content]: + """Wrap text so that it fits within the given dimensions. + + Note that Textual will automatically wrap Content in widgets. + This method is only required if you need some additional processing to lines. + + Args: + width: Maximum width of the line (in cells). + align: Alignment of lines. + overflow: Overflow of lines (what happens when the text doesn't fit). + + Returns: + A list of Content objects, one per line. + """ + lines = self._wrap_and_format(width, align, overflow) + content_lines = [line.content for line in lines] + return content_lines + + def fold(self, width: int) -> list[Content]: + """Fold this line into a list of lines which have a cell length no less than 2 and no greater than `width`. + + Folded lines may be 1 less than the width if it contains double width characters (which may + not be subdivided). + + Note that this method will not do any word wrapping. For that, see [wrap()][textual.content.Content.wrap]. + + Args: + width: Desired maximum width (in cells) + + Returns: + List of content instances. + """ + if not self: + return [self] + text = self.plain + lines: list[Content] = [] + position = 0 + width = max(width, 2) + while True: + snip = text[position : position + width] + if not snip: + break + snip_cell_length = cell_len(snip) + if snip_cell_length < width: + # last snip + lines.append(self[position : position + width]) + break + if snip_cell_length == width: + # Cell length is exactly width + lines.append(self[position : position + width]) + position += len(snip) + continue + # TODO: Can this be more efficient? + extra_cells = snip_cell_length - width + if start_snip := extra_cells // 2: + snip_cell_length -= cell_len(snip[-start_snip:]) + snip = snip[: len(snip) - start_snip] + while snip_cell_length > width: + snip_cell_length -= cell_len(snip[-1]) + snip = snip[:-1] + lines.append(self[position : position + len(snip)]) + position += len(snip) + + return lines + + def get_style_at_offset(self, offset: int) -> Style: + """Get the style of a character at give offset. + + Args: + offset (int): Offset into text (negative indexing supported) + + Returns: + Style: A Style instance. + """ + # TODO: This is a little inefficient, it is only used by full justify + if offset < 0: + offset = len(self) + offset + + style = Style() + for start, end, span_style in self._spans: + if end > offset >= start: + style += span_style + return style + + def truncate( + self, + max_width: int, + *, + ellipsis=False, + pad: bool = False, + ) -> Content: + """Truncate the content at a given cell width. + + Args: + max_width: The maximum width in cells. + ellipsis: Insert an ellipsis when cropped. + pad: Pad the content if less than `max_width`. + + Returns: + New Content. + """ + + length = self.cell_length + if length == max_width: + return self + + text = self.plain + spans = self._spans + if pad and length < max_width: + spaces = max_width - length + text = f"{self.plain}{' ' * spaces}" + return Content(text, spans, max_width, strip_control_codes=False) + elif length > max_width: + if ellipsis and max_width: + text = set_cell_size(self.plain, max_width - 1) + "…" + else: + text = set_cell_size(self.plain, max_width) + spans = self._trim_spans(text, self._spans) + return Content(text, spans, max_width, strip_control_codes=False) + else: + return self + + def pad_left(self, count: int, character: str = " ") -> Content: + """Pad the left with a given character. + + Args: + count (int): Number of characters to pad. + character (str, optional): Character to pad with. Defaults to " ". + """ + assert len(character) == 1, "Character must be a string of length 1" + if count: + text = f"{character * count}{self.plain}" + _Span = Span + spans = [ + _Span(start + count, end + count, style) + for start, end, style in self._spans + ] + content = Content( + text, + spans, + None if self._cell_length is None else self._cell_length + count, + strip_control_codes=False, + ) + return content + + return self + + def extend_right(self, count: int, character: str = " ") -> Content: + """Add repeating characters (typically spaces) to the content with the style(s) of the last character. + + Args: + count: Number of spaces. + character: Character to add with. + + Returns: + A Content instance. + """ + if count: + plain = self.plain + plain_len = len(plain) + return Content( + f"{plain}{character * count}", + [ + (span.extend(count) if span.end == plain_len else span) + for span in self._spans + ], + None if self._cell_length is None else self._cell_length + count, + strip_control_codes=False, + ) + return self + + def pad_right(self, count: int, character: str = " ") -> Content: + """Pad the right with a given character. + + Args: + count (int): Number of characters to pad. + character (str, optional): Character to pad with. Defaults to " ". + """ + assert len(character) == 1, "Character must be a string of length 1" + if count: + return Content( + f"{self.plain}{character * count}", + self._spans, + None if self._cell_length is None else self._cell_length + count, + strip_control_codes=False, + ) + return self + + def pad(self, left: int, right: int, character: str = " ") -> Content: + """Pad both the left and right edges with a given number of characters. + + Args: + left (int): Number of characters to pad on the left. + right (int): Number of characters to pad on the right. + character (str, optional): Character to pad with. Defaults to " ". + """ + assert len(character) == 1, "Character must be a string of length 1" + if left or right: + text = f"{character * left}{self.plain}{character * right}" + _Span = Span + if left: + spans = [ + _Span(start + left, end + left, style) + for start, end, style in self._spans + ] + else: + spans = self._spans + content = Content( + text, + spans, + None if self._cell_length is None else self._cell_length + left + right, + strip_control_codes=False, + ) + return content + + return self + + def center(self, width: int, ellipsis: bool = False) -> Content: + """Align a line to the center. + + Args: + width: Desired width of output. + ellipsis: Insert ellipsis if content is truncated. + + Returns: + New line Content. + """ + content = self.rstrip().truncate(width, ellipsis=ellipsis) + left = (width - content.cell_length) // 2 + right = width - left + content = content.pad(left, right) + return content + + def right(self, width: int, ellipsis: bool = False) -> Content: + """Align a line to the right. + + Args: + width: Desired width of output. + ellipsis: Insert ellipsis if content is truncated. + + Returns: + New line Content. + """ + content = self.rstrip().truncate(width, ellipsis=ellipsis) + content = content.pad_left(width - content.cell_length) + return content + + def right_crop(self, amount: int = 1) -> Content: + """Remove a number of characters from the end of the text. + + Args: + amount: Number of characters to crop. + + Returns: + New Content + + """ + max_offset = len(self.plain) - amount + _Span = Span + spans = [ + ( + span + if span.end < max_offset + else _Span(span.start, min(max_offset, span.end), span.style) + ) + for span in self._spans + if span.start < max_offset + ] + text = self.plain[:-amount] + length = None if self._cell_length is None else self._cell_length - amount + return Content(text, spans, length, strip_control_codes=False) + + def stylize( + self, style: Style | str, start: int = 0, end: int | None = None + ) -> Content: + """Apply a style to the text, or a portion of the text. + + Args: + style (Union[str, Style]): Style instance or style definition to apply. + start (int): Start offset (negative indexing is supported). Defaults to 0. + end (Optional[int], optional): End offset (negative indexing is supported), or None for end of text. Defaults to None. + """ + if not style: + return self + length = len(self) + if start < 0: + start = length + start + if end is None: + end = length + if end < 0: + end = length + end + if start >= length or end <= start: + # Span not in text or not valid + return self + return Content( + self.plain, + self._spans + [Span(start, length if length < end else end, style)], + self._cell_length, + strip_control_codes=False, + ) + + def stylize_before( + self, + style: Style | str, + start: int = 0, + end: int | None = None, + ) -> Content: + """Apply a style to the text, or a portion of the text. + + Styles applies with this method will be applied *before* other styles already present. + + Args: + style (Union[str, Style]): Style instance or style definition to apply. + start (int): Start offset (negative indexing is supported). Defaults to 0. + end (Optional[int], optional): End offset (negative indexing is supported), or None for end of text. Defaults to None. + """ + if not style: + return self + length = len(self) + if start < 0: + start = length + start + if end is None: + end = length + if end < 0: + end = length + end + if start >= length or end <= start: + # Span not in text or not valid + return self + return Content( + self.plain, + [Span(start, length if length < end else end, style), *self._spans], + self._cell_length, + strip_control_codes=False, + ) + + def render( + self, + base_style: Style = Style.null(), + end: str = "\n", + parse_style: Callable[[str | Style], Style] | None = None, + ) -> Iterable[tuple[str, Style]]: + """Render Content in to an iterable of strings and styles. + + This is typically called by Textual when displaying Content, but may be used if you want to do more advanced + processing of the output. + + Args: + base_style: The style used as a base. This will typically be the style of the widget underneath the content. + end: Text to end the output, such as a new line. + parse_style: Method to parse a style. Use `App.parse_style` to apply CSS variables in styles. + + Returns: + An iterable of string and styles, which make up the content. + + """ + if not self._spans: + yield (self._text, base_style) + if end: + yield end, base_style + return + + get_style: Callable[[str | Style], Style] + if parse_style is None: + + def _get_style(style: str | Style) -> Style: + """The default get_style method.""" + if isinstance(style, Style): + return style + try: + visual_style = Style.parse(style) + except Exception: + visual_style = Style.null() + return visual_style + + get_style = _get_style + + else: + get_style = parse_style + + enumerated_spans = list(enumerate(self._spans, 1)) + style_map = { + index: ( + get_style(span.style) if isinstance(span.style, str) else span.style + ) + for index, span in enumerated_spans + } + style_map[0] = base_style + text = self.plain + + spans = [ + (0, False, 0), + *((span.start, False, index) for index, span in enumerated_spans), + *((span.end, True, index) for index, span in enumerated_spans), + (len(text), True, 0), + ] + spans.sort(key=itemgetter(0, 1)) + + stack: list[int] = [] + stack_append = stack.append + stack_pop = stack.remove + + style_cache: dict[tuple[int, ...], Style] = {} + style_cache_get = style_cache.get + combine = Style.combine + + def get_current_style() -> Style: + """Construct current style from stack.""" + cache_key = tuple(stack) + cached_style = style_cache_get(cache_key) + if cached_style is not None: + return cached_style + styles = [style_map[_style_id] for _style_id in cache_key] + current_style = combine(styles) + style_cache[cache_key] = current_style + return current_style + + for (offset, leaving, style_id), (next_offset, _, _) in zip(spans, spans[1:]): + if leaving: + stack_pop(style_id) + else: + stack_append(style_id) + if next_offset > offset: + yield text[offset:next_offset], get_current_style() + if end: + yield end, base_style + + def render_segments( + self, base_style: Style = Style.null(), end: str = "" + ) -> list[Segment]: + """Render the Content in to a list of segments. + + Args: + base_style: Base style for render (style under the content). Defaults to Style.null(). + end: Character to end the segments with. Defaults to "". + + Returns: + A list of segments. + """ + _Segment = Segment + segments = [ + _Segment(text, (style.rich_style if style else None)) + for text, style in self.render(base_style, end) + ] + return segments + + def __rich__(self): + """Allow Content to be rendered with rich.print.""" + from rich.segment import Segments + + return Segments(self.render_segments(Style(), "\n")) + + def _divide_spans(self, offsets: tuple[int, ...]) -> list[tuple[Span, int, int]]: + """Divide content from a list of offset to cut. + + Args: + offsets: A tuple of indices in to the text. + + Returns: + A list of tuples containing Spans and their line offsets. + """ + if self._divide_cache is None: + self._divide_cache = FIFOCache(4) + if (cached_result := self._divide_cache.get(offsets)) is not None: + return cached_result + + line_ranges = list(zip(offsets, offsets[1:])) + text_length = len(self.plain) + line_count = len(line_ranges) + span_ranges: list[tuple[Span, int, int]] = [] + for span in self._spans: + span_start, span_end, _style = span + if span_start >= text_length: + continue + span_end = min(text_length, span_end) + lower_bound = 0 + upper_bound = line_count + start_line_no = (lower_bound + upper_bound) // 2 + + while True: + line_start, line_end = line_ranges[start_line_no] + if span_start < line_start: + upper_bound = start_line_no - 1 + elif span_start > line_end: + lower_bound = start_line_no + 1 + else: + break + start_line_no = (lower_bound + upper_bound) // 2 + + if span_end < line_end: + end_line_no = start_line_no + else: + end_line_no = lower_bound = start_line_no + upper_bound = line_count + + while True: + line_start, line_end = line_ranges[end_line_no] + if span_end < line_start: + upper_bound = end_line_no - 1 + elif span_end > line_end: + lower_bound = end_line_no + 1 + else: + break + end_line_no = (lower_bound + upper_bound) // 2 + + span_ranges.append((span, start_line_no, end_line_no + 1)) + self._divide_cache[offsets] = span_ranges + return span_ranges + + def divide(self, offsets: Sequence[int]) -> list[Content]: + """Divide the content at the given offsets. + + This will cut the content in to pieces, and return those pieces. Note that the number of pieces + return will be one greater than the number of cuts. + + Args: + offsets: Sequence of offsets (in characters) of where to apply the cuts. + + Returns: + List of Content instances which combined would be equal to the whole. + """ + if not offsets: + return [self] + + offsets = sorted(offsets) + text = self.plain + divide_offsets = tuple([0, *offsets, len(text)]) + line_ranges = list(zip(divide_offsets, divide_offsets[1:])) + line_text = [text[start:end] for start, end in line_ranges] + new_lines = [Content(line, None) for line in line_text] + + if not self._spans: + return new_lines + + _line_appends = [line._spans.append for line in new_lines] + _Span = Span + + for ( + (span_start, span_end, style), + start_line, + end_line, + ) in self._divide_spans(divide_offsets): + for line_no in range(start_line, end_line): + line_start, line_end = line_ranges[line_no] + new_start = max(0, span_start - line_start) + new_end = min(span_end - line_start, line_end - line_start) + if new_end > new_start: + _line_appends[line_no](_Span(new_start, new_end, style)) + + return new_lines + + def split( + self, + separator: str = "\n", + *, + include_separator: bool = False, + allow_blank: bool = False, + ) -> list[Content]: + """Split rich text into lines, preserving styles. + + Args: + separator (str, optional): String to split on. Defaults to "\\\\n". + include_separator (bool, optional): Include the separator in the lines. Defaults to False. + allow_blank (bool, optional): Return a blank line if the text ends with a separator. Defaults to False. + + Returns: + List[Content]: A list of Content, one per line of the original. + """ + assert separator, "separator must not be empty" + text = self.plain + if separator not in text: + return [self] + + cache_key = (separator, include_separator, allow_blank) + if self._split_cache is None: + self._split_cache = FIFOCache(4) + if (cached_result := self._split_cache.get(cache_key)) is not None: + return cached_result.copy() + + if include_separator: + lines = self.divide( + [match.end() for match in re.finditer(re.escape(separator), text)], + ) + else: + + def flatten_spans() -> Iterable[int]: + for match in re.finditer(re.escape(separator), text): + yield from match.span() + + lines = [ + line + for line in self.divide(list(flatten_spans())) + if line.plain != separator + ] + + if not allow_blank and text.endswith(separator): + lines.pop() + + self._split_cache[cache_key] = lines + return lines + + def rstrip(self, chars: str | None = None) -> Content: + """Strip characters from end of text.""" + text = self.plain.rstrip(chars) + return Content(text, self._trim_spans(text, self._spans)) + + def rstrip_end(self, size: int) -> Content: + """Remove whitespace beyond a certain width at the end of the text. + + Args: + size (int): The desired size of the text. + """ + text_length = len(self) + if text_length > size: + excess = text_length - size + whitespace_match = _re_whitespace.search(self.plain) + if whitespace_match is not None: + whitespace_count = len(whitespace_match.group(0)) + return self.right_crop(min(whitespace_count, excess)) + return self + + def extend_style(self, spaces: int) -> Content: + """Extend the Text given number of spaces where the spaces have the same style as the last character. + + Args: + spaces (int): Number of spaces to add to the Text. + + Returns: + New content with additional spaces at the end. + """ + if spaces <= 0: + return self + spans = self._spans + new_spaces = " " * spaces + if spans: + end_offset = len(self) + spans = [ + span.extend(spaces) if span.end >= end_offset else span + for span in spans + ] + return Content(self._text + new_spaces, spans, self.cell_length + spaces) + return Content(self._text + new_spaces, self._spans, self._cell_length) + + def expand_tabs(self, tab_size: int = 8) -> Content: + """Converts tabs to spaces. + + Args: + tab_size (int, optional): Size of tabs. Defaults to 8. + + """ + if "\t" not in self.plain: + return self + + if not self._spans: + return Content(self.plain.expandtabs(tab_size)) + + new_text: list[Content] = [] + append = new_text.append + + for line in self.split("\n", include_separator=True): + if "\t" not in line.plain: + append(line) + else: + cell_position = 0 + parts = line.split("\t", include_separator=True) + for part in parts: + if part.plain.endswith("\t"): + part = Content( + part._text[:-1] + " ", part._spans, part._cell_length + ) + cell_position += part.cell_length + tab_remainder = cell_position % tab_size + if tab_remainder: + spaces = tab_size - tab_remainder + part = part.extend_style(spaces) + cell_position += spaces + else: + cell_position += part.cell_length + append(part) + + content = EMPTY_CONTENT.join(new_text) + return content + + def highlight_regex( + self, + highlight_regex: re.Pattern[str] | str, + *, + style: Style | str, + maximum_highlights: int | None = None, + ) -> Content: + """Apply a style to text that matches a regular expression. + + Args: + highlight_regex: Regular expression as a string, or compiled. + style: Style to apply. + maximum_highlights: Maximum number of matches to highlight, or `None` for no maximum. + + Returns: + new content. + """ + spans: list[Span] = self._spans.copy() + append_span = spans.append + _Span = Span + plain = self.plain + if isinstance(highlight_regex, str): + re_highlight = re.compile(highlight_regex) + else: + re_highlight = highlight_regex + count = 0 + for match in re_highlight.finditer(plain): + start, end = match.span() + if end > start: + append_span(_Span(start, end, style)) + if ( + maximum_highlights is not None + and (count := count + 1) >= maximum_highlights + ): + break + return Content(self._text, spans, cell_length=self._cell_length) + + +class _FormattedLine: + """A line of content with additional formatting information. + + This class is used internally within Content, and you are unlikely to need it an an app. + """ + + def __init__( + self, + get_style: Callable[[str | Style], Style], + content: Content, + width: int, + x: int = 0, + y: int = 0, + align: TextAlign = "left", + line_end: bool = False, + link_style: Style | None = None, + ) -> None: + self.get_style = get_style + self.content = content + self.width = width + self.x = x + self.y = y + self.align = align + self.line_end = line_end + self.link_style = link_style + + @property + def plain(self) -> str: + return self.content.plain + + def to_strip(self, style: Style) -> tuple[list[Segment], int]: + _Segment = Segment + align = self.align + width = self.width + pad_left = pad_right = 0 + content = self.content + x = self.x + y = self.y + get_style = self.get_style + + if align in ("start", "left") or (align == "justify" and self.line_end): + pass + + elif align == "center": + excess_space = width - self.content.cell_length + pad_left = excess_space // 2 + pad_right = excess_space - pad_left + + elif align in ("end", "right"): + pad_left = width - self.content.cell_length + + elif align == "justify": + words = content.split(" ", include_separator=False) + words_size = sum(cell_len(word.plain.rstrip(" ")) for word in words) + num_spaces = len(words) - 1 + spaces = [1] * num_spaces + index = 0 + if spaces: + while words_size + num_spaces < width: + spaces[len(spaces) - index - 1] += 1 + num_spaces += 1 + index = (index + 1) % len(spaces) + + segments: list[Segment] = [] + add_segment = segments.append + x = self.x + for index, word in enumerate(words): + for text, text_style in word.render( + style, end="", parse_style=get_style + ): + add_segment( + _Segment( + text, (style + text_style).rich_style_with_offset(x, y) + ) + ) + x += len(text) + 1 + if index < len(spaces) and (pad := spaces[index]): + add_segment(_Segment(" " * pad, (style + text_style).rich_style)) + + return segments, width + + segments = ( + [Segment(" " * pad_left, style.background_style.rich_style)] + if pad_left + else [] + ) + add_segment = segments.append + for text, text_style in content.render(style, end="", parse_style=get_style): + add_segment( + _Segment(text, (style + text_style).rich_style_with_offset(x, y)) + ) + x += len(text) + + if pad_right: + segments.append( + _Segment(" " * pad_right, style.background_style.rich_style) + ) + + return (segments, content.cell_length + pad_left + pad_right) + + def _apply_link_style( + self, link_style: RichStyle, segments: list[Segment] + ) -> list[Segment]: + _Segment = Segment + segments = [ + _Segment( + text, + ( + style + if style._meta is None + else (style + link_style if "@click" in style.meta else style) + ), + control, + ) + for text, style, control in segments + if style is not None + ] + return segments + + +EMPTY_CONTENT: Final = Content("") diff --git a/src/memray/_vendor/textual/coordinate.py b/src/memray/_vendor/textual/coordinate.py new file mode 100644 index 0000000000..a230df039e --- /dev/null +++ b/src/memray/_vendor/textual/coordinate.py @@ -0,0 +1,53 @@ +""" +A class to store a coordinate, used by the [DataTable][textual.widgets.DataTable]. +""" + +from __future__ import annotations + +from typing import NamedTuple + + +class Coordinate(NamedTuple): + """An object representing a row/column coordinate within a grid.""" + + row: int + """The row of the coordinate within a grid.""" + + column: int + """The column of the coordinate within a grid.""" + + def left(self) -> Coordinate: + """Get the coordinate to the left. + + Returns: + The coordinate to the left. + """ + row, column = self + return Coordinate(row, column - 1) + + def right(self) -> Coordinate: + """Get the coordinate to the right. + + Returns: + The coordinate to the right. + """ + row, column = self + return Coordinate(row, column + 1) + + def up(self) -> Coordinate: + """Get the coordinate above. + + Returns: + The coordinate above. + """ + row, column = self + return Coordinate(row - 1, column) + + def down(self) -> Coordinate: + """Get the coordinate below. + + Returns: + The coordinate below. + """ + row, column = self + return Coordinate(row + 1, column) diff --git a/src/memray/_vendor/textual/css/__init__.py b/src/memray/_vendor/textual/css/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/memray/_vendor/textual/css/_error_tools.py b/src/memray/_vendor/textual/css/_error_tools.py new file mode 100644 index 0000000000..a5e0972e03 --- /dev/null +++ b/src/memray/_vendor/textual/css/_error_tools.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from typing import Iterable + + +def friendly_list( + words: Iterable[str], joiner: str = "or", omit_empty: bool = True +) -> str: + """Generate a list of words as readable prose. + + >>> friendly_list(["foo", "bar", "baz"]) + "'foo', 'bar', or 'baz'" + + Args: + words: A list of words. + joiner: The last joiner word. + + Returns: + List as prose. + """ + words = [ + repr(word) for word in sorted(words, key=str.lower) if word or not omit_empty + ] + if len(words) == 1: + return words[0] + elif len(words) == 2: + word1, word2 = words + return f"{word1} {joiner} {word2}" + else: + return f'{", ".join(words[:-1])}, {joiner} {words[-1]}' diff --git a/src/memray/_vendor/textual/css/_help_renderables.py b/src/memray/_vendor/textual/css/_help_renderables.py new file mode 100644 index 0000000000..5957162f53 --- /dev/null +++ b/src/memray/_vendor/textual/css/_help_renderables.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +from typing import Iterable + +import rich.repr +from rich.console import Console, ConsoleOptions, RenderResult +from rich.highlighter import ReprHighlighter +from rich.markup import render +from rich.text import Text + +_highlighter = ReprHighlighter() + + +def _markup_and_highlight(text: str) -> Text: + """Highlight and render markup in a string of text, returning + a styled Text object. + + Args: + text: The text to highlight and markup. + + Returns: + The Text, with highlighting and markup applied. + """ + return _highlighter(render(text)) + + +class Example: + """Renderable for an example, which can appear below bullet points in + the help text. + + Attributes: + markup: The markup to display for this example + """ + + def __init__(self, markup: str) -> None: + self.markup: str = markup + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + yield _markup_and_highlight(f" [dim]e.g. [/][i]{self.markup}[/]") + + +@rich.repr.auto +class Bullet: + """Renderable for a single 'bullet point' containing information and optionally some examples + pertaining to that information. + + Attributes: + markup: The markup to display + examples: An optional list of examples + to display below this bullet. + """ + + def __init__(self, markup: str, examples: Iterable[Example] | None = None) -> None: + self.markup: str = markup + self.examples: Iterable[Example] | None = [] if examples is None else examples + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + yield _markup_and_highlight(self.markup) + if self.examples is not None: + yield from self.examples + + +@rich.repr.auto +class HelpText: + """Renderable for help text - the user is shown this when they + encounter a style-related error (e.g. setting a style property to an invalid + value). + + Attributes: + summary: A succinct summary of the issue. + bullets: Bullet points which provide additional + context around the issue. These are rendered below the summary. + """ + + def __init__( + self, summary: str, *, bullets: Iterable[Bullet] | None = None + ) -> None: + self.summary: str = summary + self.bullets: Iterable[Bullet] | None = bullets or [] + + def __str__(self) -> str: + return self.summary + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + from rich.tree import Tree + + tree = Tree(_markup_and_highlight(f"[b blue]{self.summary}"), guide_style="dim") + if self.bullets is not None: + for bullet in self.bullets: + tree.add(bullet) + yield tree diff --git a/src/memray/_vendor/textual/css/_help_text.py b/src/memray/_vendor/textual/css/_help_text.py new file mode 100644 index 0000000000..57926bbc8b --- /dev/null +++ b/src/memray/_vendor/textual/css/_help_text.py @@ -0,0 +1,858 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterable, Sequence + +from typing_extensions import Literal + +from memray._vendor.textual.color import ColorParseError +from memray._vendor.textual.css._error_tools import friendly_list +from memray._vendor.textual.css._help_renderables import Bullet, Example, HelpText +from memray._vendor.textual.css.constants import ( + VALID_ALIGN_HORIZONTAL, + VALID_ALIGN_VERTICAL, + VALID_BORDER, + VALID_EXPAND, + VALID_KEYLINE, + VALID_LAYOUT, + VALID_POSITION, + VALID_STYLE_FLAGS, + VALID_TEXT_ALIGN, +) +from memray._vendor.textual.css.scalar import SYMBOL_UNIT + +StylingContext = Literal["inline", "css"] +"""The type of styling the user was using when the error was encountered. +Used to give help text specific to the context i.e. we give CSS help if the +user hit an issue with their CSS, and Python help text when the user has an +issue with inline styles.""" + + +@dataclass +class ContextSpecificBullets: + """ + Args: + inline: Information only relevant to users who are using inline styling. + css: Information only relevant to users who are using CSS. + """ + + inline: Sequence[Bullet] + css: Sequence[Bullet] + + def get_by_context(self, context: StylingContext) -> list[Bullet]: + """Get the information associated with the given context + + Args: + context: The context to retrieve info for. + """ + if context == "inline": + return list(self.inline) + else: + return list(self.css) + + +def _python_name(property_name: str) -> str: + """Convert a CSS property name to the corresponding Python attribute name + + Args: + property_name: The CSS property name + + Returns: + The Python attribute name as found on the Styles object + """ + return property_name.replace("-", "_") + + +def _css_name(property_name: str) -> str: + """Convert a Python style attribute name to the corresponding CSS property name + + Args: + property_name: The Python property name + + Returns: + The CSS property name + """ + return property_name.replace("_", "-") + + +def _contextualize_property_name( + property_name: str, + context: StylingContext, +) -> str: + """Convert a property name to CSS or inline by replacing + '-' with '_' or vice-versa + + Args: + property_name: The name of the property + context: The context the property is being used in. + + Returns: + The property name converted to the given context. + """ + return _css_name(property_name) if context == "css" else _python_name(property_name) + + +def _spacing_examples(property_name: str) -> ContextSpecificBullets: + """Returns examples for spacing properties""" + return ContextSpecificBullets( + inline=[ + Bullet( + f"Set [i]{property_name}[/] to a tuple to assign spacing to each edge", + examples=[ + Example( + f"widget.styles.{property_name} = (1, 2) [dim]# Vertical, horizontal" + ), + Example( + f"widget.styles.{property_name} = (1, 2, 3, 4) [dim]# Top, right, bottom, left" + ), + ], + ), + Bullet( + "Or to an integer to assign a single value to all edges", + examples=[Example(f"widget.styles.{property_name} = 2")], + ), + ], + css=[ + Bullet( + "Supply 1, 2 or 4 integers separated by a space", + examples=[ + Example(f"{property_name}: 1;"), + Example(f"{property_name}: 1 2; [dim]# Vertical, horizontal"), + Example( + f"{property_name}: 1 2 3 4; [dim]# Top, right, bottom, left" + ), + ], + ), + ], + ) + + +def property_invalid_value_help_text( + property_name: str, + context: StylingContext, + *, + suggested_property_name: str | None = None, +) -> HelpText: + """Help text to show when the user supplies an invalid value for CSS property + property. + + Args: + property_name: The name of the property. + context: The context the spacing property is being used in. + Keyword Args: + suggested_property_name: A suggested name for the property (e.g. "width" for "wdth"). + + Returns: + Renderable for displaying the help text for this property. + """ + property_name = _contextualize_property_name(property_name, context) + summary = f"Invalid CSS property {property_name!r}" + if suggested_property_name: + suggested_property_name = _contextualize_property_name( + suggested_property_name, context + ) + summary += f". Did you mean '{suggested_property_name}'?" + return HelpText(summary) + + +def spacing_wrong_number_of_values_help_text( + property_name: str, + num_values_supplied: int, + context: StylingContext, +) -> HelpText: + """Help text to show when the user supplies the wrong number of values + for a spacing property (e.g. padding or margin). + + Args: + property_name: The name of the property. + num_values_supplied: The number of values the user supplied (a number other than 1, 2 or 4). + context: The context the spacing property is being used in. + + Returns: + Renderable for displaying the help text for this property. + """ + property_name = _contextualize_property_name(property_name, context) + return HelpText( + summary=f"Invalid number of values for the [i]{property_name}[/] property", + bullets=[ + Bullet( + f"You supplied {num_values_supplied} values for the [i]{property_name}[/] property" + ), + Bullet( + "Spacing properties like [i]margin[/] and [i]padding[/] require either 1, 2 or 4 integer values" + ), + *_spacing_examples(property_name).get_by_context(context), + ], + ) + + +def spacing_invalid_value_help_text( + property_name: str, + context: StylingContext, +) -> HelpText: + """Help text to show when the user supplies an invalid value for a spacing + property. + + Args: + property_name: The name of the property. + context: The context the spacing property is being used in. + + Returns: + Renderable for displaying the help text for this property. + """ + property_name = _contextualize_property_name(property_name, context) + return HelpText( + summary=f"Invalid value for the [i]{property_name}[/] property", + bullets=_spacing_examples(property_name).get_by_context(context), + ) + + +def scalar_help_text( + property_name: str, + context: StylingContext, +) -> HelpText: + """Help text to show when the user supplies an invalid value for + a scalar property. + + Args: + property_name: The name of the property. + num_values_supplied: The number of values the user supplied (a number other than 1, 2 or 4). + context: The context the scalar property is being used in. + + Returns: + Renderable for displaying the help text for this property. + """ + property_name = _contextualize_property_name(property_name, context) + return HelpText( + summary=f"Invalid value for the [i]{property_name}[/] property", + bullets=[ + Bullet( + f"Scalar properties like [i]{property_name}[/] require numerical values and an optional unit" + ), + Bullet(f"Valid units are {friendly_list(SYMBOL_UNIT)}"), + *ContextSpecificBullets( + inline=[ + Bullet( + "Assign a string, int or Scalar object itself", + examples=[ + Example(f'widget.styles.{property_name} = "50%"'), + Example(f"widget.styles.{property_name} = 10"), + Example(f"widget.styles.{property_name} = Scalar(...)"), + ], + ), + ], + css=[ + Bullet( + "Write the number followed by the unit", + examples=[ + Example(f"{property_name}: 50%;"), + Example(f"{property_name}: 5;"), + ], + ), + ], + ).get_by_context(context), + ], + ) + + +def string_enum_help_text( + property_name: str, + valid_values: Iterable[str], + context: StylingContext, +) -> HelpText: + """Help text to show when the user supplies an invalid value for a string + enum property. + + Args: + property_name: The name of the property. + valid_values: A list of the values that are considered valid. + context: The context the property is being used in. + + Returns: + Renderable for displaying the help text for this property. + """ + property_name = _contextualize_property_name(property_name, context) + return HelpText( + summary=f"Invalid value for the [i]{property_name}[/] property", + bullets=[ + Bullet( + f"The [i]{property_name}[/] property can only be set to {friendly_list(valid_values)}" + ), + *ContextSpecificBullets( + inline=[ + Bullet( + "Assign any of the valid strings to the property", + examples=[ + Example(f'widget.styles.{property_name} = "{valid_value}"') + for valid_value in sorted(valid_values) + ], + ) + ], + css=[ + Bullet( + "Assign any of the valid strings to the property", + examples=[ + Example(f"{property_name}: {valid_value};") + for valid_value in sorted(valid_values) + ], + ) + ], + ).get_by_context(context), + ], + ) + + +def color_property_help_text( + property_name: str, + context: StylingContext, + *, + error: Exception | None = None, + value: str | None = None, +) -> HelpText: + """Help text to show when the user supplies an invalid value for a color + property. For example, an unparsable color string. + + Args: + property_name: The name of the property. + context: The context the property is being used in. + error: The error that caused this help text to be displayed. + + Returns: + Renderable for displaying the help text for this property. + """ + property_name = _contextualize_property_name(property_name, context) + if value is None: + summary = f"Invalid value for the [i]{property_name}[/] property" + else: + summary = f"Invalid value ({value!r}) for the [i]{property_name}[/] property" + suggested_color = ( + error.suggested_color if error and isinstance(error, ColorParseError) else None + ) + if suggested_color: + summary += f". Did you mean '{suggested_color}'?" + return HelpText( + summary=summary, + bullets=[ + Bullet( + f"The [i]{property_name}[/] property can only be set to a valid color" + ), + Bullet("Colors can be specified using hex, RGB, or ANSI color names"), + *ContextSpecificBullets( + inline=[ + Bullet( + "Assign colors using strings or Color objects", + examples=[ + Example(f'widget.styles.{property_name} = "#ff00aa"'), + Example( + f'widget.styles.{property_name} = "rgb(12,231,45)"' + ), + Example(f'widget.styles.{property_name} = "red"'), + Example( + f"widget.styles.{property_name} = Color(1, 5, 29, a=0.5)" + ), + ], + ) + ], + css=[ + Bullet( + "Colors can be set as follows", + examples=[ + Example(f"{property_name}: [#ff00aa]#ff00aa[/];"), + Example(f"{property_name}: rgb(12,231,45);"), + Example(f"{property_name}: [rgb(255,0,0)]red[/];"), + ], + ) + ], + ).get_by_context(context), + ], + ) + + +def border_property_help_text(property_name: str, context: StylingContext) -> HelpText: + """Help text to show when the user supplies an invalid value for a border + property (such as border, border-right, outline). + + Args: + property_name: The name of the property. + context: The context the property is being used in. + + Returns: + Renderable for displaying the help text for this property. + """ + property_name = _contextualize_property_name(property_name, context) + return HelpText( + summary=f"Invalid value for [i]{property_name}[/] property", + bullets=[ + *ContextSpecificBullets( + inline=[ + Bullet( + f"Set [i]{property_name}[/] using a tuple of the form (, )", + examples=[ + Example( + f'widget.styles.{property_name} = ("solid", "red")' + ), + Example( + f'widget.styles.{property_name} = ("round", "#f0f0f0")' + ), + Example( + f'widget.styles.{property_name} = [("dashed", "#f0f0f0"), ("solid", "blue")] [dim]# Vertical, horizontal' + ), + ], + ), + Bullet( + f"Valid values for are:\n{friendly_list(VALID_BORDER)}" + ), + Bullet( + "Colors can be specified using hex, RGB, or ANSI color names" + ), + ], + css=[ + Bullet( + f"Set [i]{property_name}[/] using a value of the form [i] [/]", + examples=[ + Example(f"{property_name}: solid red;"), + Example(f"{property_name}: dashed #00ee22;"), + ], + ), + Bullet( + f"Valid values for are:\n{friendly_list(VALID_BORDER)}" + ), + Bullet( + "Colors can be specified using hex, RGB, or ANSI color names" + ), + ], + ).get_by_context(context), + ], + ) + + +def layout_property_help_text(property_name: str, context: StylingContext) -> HelpText: + """Help text to show when the user supplies an invalid value + for a layout property. + + Args: + property_name: The name of the property. + context: The context the property is being used in. + + Returns: + Renderable for displaying the help text for this property. + """ + property_name = _contextualize_property_name(property_name, context) + return HelpText( + summary=f"Invalid value for [i]{property_name}[/] property", + bullets=[ + Bullet( + f"The [i]{property_name}[/] property expects a value of {friendly_list(VALID_LAYOUT)}" + ), + ], + ) + + +def dock_property_help_text(property_name: str, context: StylingContext) -> HelpText: + """Help text to show when the user supplies an invalid value for dock. + + Args: + property_name: The name of the property. + context: The context the property is being used in. + + Returns: + Renderable for displaying the help text for this property. + """ + property_name = _contextualize_property_name(property_name, context) + return HelpText( + summary=f"Invalid value for [i]{property_name}[/] property", + bullets=[ + Bullet( + "The value must be one of 'top', 'right', 'bottom', 'left' or 'none'" + ), + *ContextSpecificBullets( + inline=[ + Bullet( + "The 'dock' rule attaches a widget to the edge of a container.", + examples=[Example('header.styles.dock = "top"')], + ) + ], + css=[ + Bullet( + "The 'dock' rule attaches a widget to the edge of a container.", + examples=[Example("dock: top")], + ) + ], + ).get_by_context(context), + ], + ) + + +def split_property_help_text(property_name: str, context: StylingContext) -> HelpText: + """Help text to show when the user supplies an invalid value for split. + + Args: + property_name: The name of the property. + context: The context the property is being used in. + + Returns: + Renderable for displaying the help text for this property. + """ + property_name = _contextualize_property_name(property_name, context) + return HelpText( + summary=f"Invalid value for [i]{property_name}[/] property", + bullets=[ + Bullet("The value must be one of 'top', 'right', 'bottom' or 'left'"), + *ContextSpecificBullets( + inline=[ + Bullet( + "The 'split' splits the container and aligns the widget to the given edge.", + examples=[Example('header.styles.split = "top"')], + ) + ], + css=[ + Bullet( + "The 'split' splits the container and aligns the widget to the given edge.", + examples=[Example("split: top")], + ) + ], + ).get_by_context(context), + ], + ) + + +def fractional_property_help_text( + property_name: str, context: StylingContext +) -> HelpText: + """Help text to show when the user supplies an invalid value for a fractional property. + + Args: + property_name: The name of the property. + context: The context the property is being used in. + + Returns: + Renderable for displaying the help text for this property. + """ + property_name = _contextualize_property_name(property_name, context) + return HelpText( + summary=f"Invalid value for [i]{property_name}[/] property", + bullets=[ + *ContextSpecificBullets( + inline=[ + Bullet( + f"Set [i]{property_name}[/] to a string or float value", + examples=[ + Example(f'widget.styles.{property_name} = "50%"'), + Example(f"widget.styles.{property_name} = 0.25"), + ], + ) + ], + css=[ + Bullet( + f"Set [i]{property_name}[/] to a string or float", + examples=[ + Example(f"{property_name}: 50%;"), + Example(f"{property_name}: 0.25;"), + ], + ) + ], + ).get_by_context(context) + ], + ) + + +def offset_property_help_text(context: StylingContext) -> HelpText: + """Help text to show when the user supplies an invalid value for the offset property. + + Args: + context: The context the property is being used in. + + Returns: + Renderable for displaying the help text for this property. + """ + return HelpText( + summary="Invalid value for [i]offset[/] property", + bullets=[ + *ContextSpecificBullets( + inline=[ + Bullet( + markup="The [i]offset[/] property expects a tuple of 2 values [i](, )[/]", + examples=[ + Example("widget.styles.offset = (2, '50%')"), + ], + ), + ], + css=[ + Bullet( + markup="The [i]offset[/] property expects a value of the form [i] [/]", + examples=[ + Example( + "offset: 2 3; [dim]# Horizontal offset of 2, vertical offset of 3" + ), + Example( + "offset: 2 50%; [dim]# Horizontal offset of 2, vertical offset of 50%" + ), + ], + ), + ], + ).get_by_context(context), + Bullet(" and can be a number or scalar value"), + ], + ) + + +def scrollbar_size_property_help_text(context: StylingContext) -> HelpText: + """Help text to show when the user supplies an invalid value for the scrollbar-size property. + + Args: + context: The context the property is being used in. + + Returns: + Renderable for displaying the help text for this property. + """ + return HelpText( + summary="Invalid value for [i]scrollbar-size[/] property", + bullets=[ + *ContextSpecificBullets( + inline=[ + Bullet( + markup="The [i]scrollbar_size[/] property expects a tuple of 2 values [i](, )[/]", + examples=[ + Example("widget.styles.scrollbar_size = (2, 1)"), + ], + ), + ], + css=[ + Bullet( + markup="The [i]scrollbar-size[/] property expects a value of the form [i] [/]", + examples=[ + Example( + "scrollbar-size: 2 3; [dim]# Horizontal size of 2, vertical size of 3" + ), + ], + ), + ], + ).get_by_context(context), + Bullet(" and must be non-negative integers."), + ], + ) + + +def scrollbar_size_single_axis_help_text(property_name: str) -> HelpText: + """Help text to show when the user supplies an invalid value for a scrollbar-size-* property. + + Args: + property_name: The name of the property. + + Returns: + Renderable for displaying the help text for this property. + """ + return HelpText( + summary=f"Invalid value for [i]{property_name}[/]", + bullets=[ + Bullet( + markup=f"The [i]{property_name}[/] property can only be set to a positive integer, greater than zero", + examples=[ + Example(f"{property_name}: 2;"), + ], + ), + ], + ) + + +def integer_help_text(property_name: str) -> HelpText: + """Help text to show when the user supplies an invalid integer value. + + Args: + property_name: The name of the property. + + Returns: + Renderable for displaying the help text for this property. + """ + return HelpText( + summary=f"Invalid value for [i]{property_name}[/]", + bullets=[ + Bullet( + markup="An integer value is expected here", + examples=[ + Example(f"{property_name}: 2;"), + ], + ), + ], + ) + + +def align_help_text() -> HelpText: + """Help text to show when the user supplies an invalid value for a `align`. + + Returns: + Renderable for displaying the help text for this property. + """ + return HelpText( + summary="Invalid value for [i]align[/] property", + bullets=[ + Bullet( + markup="The [i]align[/] property expects exactly 2 values", + examples=[ + Example("align: "), + Example( + "align: center middle; [dim]# Center vertically & horizontally within parent" + ), + Example( + "align: left middle; [dim]# Align on the middle left of the parent" + ), + ], + ), + Bullet( + f"Valid values for are {friendly_list(VALID_ALIGN_HORIZONTAL)}" + ), + Bullet( + f"Valid values for are {friendly_list(VALID_ALIGN_VERTICAL)}", + ), + ], + ) + + +def keyline_help_text() -> HelpText: + """Help text to show when the user supplies an invalid value for a `keyline`. + + Returns: + Renderable for displaying the help text for this property. + """ + return HelpText( + summary="Invalid value for [i]keyline[/] property", + bullets=[ + Bullet( + markup="The [i]keyline[/] property expects exactly 2 values", + examples=[ + Example("keyline: "), + ], + ), + Bullet(f"Valid values for are {friendly_list(VALID_KEYLINE)}"), + ], + ) + + +def text_align_help_text() -> HelpText: + """Help text to show when the user supplies an invalid value for the text-align property. + + Returns: + Renderable for displaying the help text for this property. + """ + return HelpText( + summary="Invalid value for the [i]text-align[/] property.", + bullets=[ + Bullet( + f"The [i]text-align[/] property must be one of {friendly_list(VALID_TEXT_ALIGN)}", + examples=[ + Example("text-align: center;"), + Example("text-align: right;"), + ], + ) + ], + ) + + +def offset_single_axis_help_text(property_name: str) -> HelpText: + """Help text to show when the user supplies an invalid value for an offset-* property. + + Args: + property_name: The name of the property. + + Returns: + Renderable for displaying the help text for this property. + """ + return HelpText( + summary=f"Invalid value for [i]{property_name}[/]", + bullets=[ + Bullet( + markup=f"The [i]{property_name}[/] property can be set to a number or scalar value", + examples=[ + Example(f"{property_name}: 10;"), + Example(f"{property_name}: 50%;"), + ], + ), + Bullet(f"Valid scalar units are {friendly_list(SYMBOL_UNIT)}"), + ], + ) + + +def position_help_text(property_name: str) -> HelpText: + """Help text to show when the user supplies the wrong value for position. + + Args: + property_name: The name of the property. + + Returns: + Renderable for displaying the help text for this property. + """ + return HelpText( + summary=f"Invalid value for [i]{property_name}[/]", + bullets=[ + Bullet(f"Valid values are {friendly_list(VALID_POSITION)}"), + ], + ) + + +def expand_help_text(property_name: str) -> HelpText: + """Help text to show when the user supplies the wrong value for expand. + + Args: + property_name: The name of the property. + + Returns: + Renderable for displaying the help text for this property. + """ + return HelpText( + summary=f"Invalid value for [i]{property_name}[/]", + bullets=[ + Bullet(f"Valid values are {friendly_list(VALID_EXPAND)}"), + ], + ) + + +def style_flags_property_help_text( + property_name: str, value: str, context: StylingContext +) -> HelpText: + """Help text to show when the user supplies an invalid value for a style flags property. + + Args: + property_name: The name of the property. + context: The context the property is being used in. + + Returns: + Renderable for displaying the help text for this property. + """ + property_name = _contextualize_property_name(property_name, context) + return HelpText( + summary=f"Invalid value '{value}' in [i]{property_name}[/] property", + bullets=[ + Bullet( + f"Style flag values such as [i]{property_name}[/] expect space-separated values" + ), + Bullet(f"Permitted values are {friendly_list(VALID_STYLE_FLAGS)}"), + Bullet("The value 'none' cannot be mixed with others"), + *ContextSpecificBullets( + inline=[ + Bullet( + markup="Supply a string or Style object", + examples=[ + Example( + f'widget.styles.{property_name} = "bold italic underline"' + ) + ], + ), + ], + css=[ + Bullet( + markup="Supply style flags separated by spaces", + examples=[Example(f"{property_name}: bold italic underline;")], + ) + ], + ).get_by_context(context), + ], + ) + + +def table_rows_or_columns_help_text( + property_name: str, value: str, context: StylingContext +): + property_name = _contextualize_property_name(property_name, context) + return HelpText( + summary=f"Invalid value '{value}' in [i]{property_name}[/] property" + ) diff --git a/src/memray/_vendor/textual/css/_style_properties.py b/src/memray/_vendor/textual/css/_style_properties.py new file mode 100644 index 0000000000..71bd381ed5 --- /dev/null +++ b/src/memray/_vendor/textual/css/_style_properties.py @@ -0,0 +1,1258 @@ +""" +Style properties are descriptors which allow the ``Styles`` object to accept different types when +setting attributes. This gives the developer more freedom in how to express style information. + +Descriptors also play nicely with Mypy, which is aware that attributes can have different types +when setting and getting. +""" + +from __future__ import annotations + +from operator import attrgetter +from typing import ( + TYPE_CHECKING, + Generic, + Iterable, + Literal, + NamedTuple, + Sequence, + TypeVar, + cast, +) + +import rich.errors +import rich.repr +from rich.style import Style +from typing_extensions import TypeAlias + +from memray._vendor.textual._border import normalize_border_value +from memray._vendor.textual._cells import cell_len +from memray._vendor.textual.color import TRANSPARENT, Color, ColorParseError +from memray._vendor.textual.css._error_tools import friendly_list +from memray._vendor.textual.css._help_text import ( + border_property_help_text, + color_property_help_text, + fractional_property_help_text, + layout_property_help_text, + offset_property_help_text, + scalar_help_text, + spacing_wrong_number_of_values_help_text, + string_enum_help_text, + style_flags_property_help_text, +) +from memray._vendor.textual.css.constants import HATCHES, VALID_STYLE_FLAGS +from memray._vendor.textual.css.errors import StyleTypeError, StyleValueError +from memray._vendor.textual.css.scalar import ( + NULL_SCALAR, + UNIT_SYMBOL, + Scalar, + ScalarOffset, + ScalarParseError, + Unit, + get_symbols, + percentage_string_to_float, +) +from memray._vendor.textual.css.transition import Transition +from memray._vendor.textual.geometry import NULL_SPACING, Spacing, SpacingDimensions, clamp + +if TYPE_CHECKING: + from memray._vendor.textual.canvas import CanvasLineType + from memray._vendor.textual.layout import Layout + from memray._vendor.textual.css.styles import StylesBase + +from memray._vendor.textual.css.types import AlignHorizontal, AlignVertical, DockEdge, EdgeType + +BorderDefinition: TypeAlias = ( + "Sequence[tuple[EdgeType, str | Color] | None] | tuple[EdgeType, str | Color] | Literal['none']" +) + +PropertyGetType = TypeVar("PropertyGetType") +PropertySetType = TypeVar("PropertySetType") +EnumType = TypeVar("EnumType", covariant=True) + + +class GenericProperty(Generic[PropertyGetType, PropertySetType]): + """Descriptor that abstracts away common machinery for other style descriptors. + + Args: + default: The default value (or a factory thereof) of the property. + layout: Whether to refresh the node layout on value change. + refresh_children: Whether to refresh the node children on value change. + """ + + def __init__( + self, + default: PropertyGetType, + layout: bool = False, + refresh_children: bool = False, + ) -> None: + self.default = default + self.layout = layout + self.refresh_children = refresh_children + + def validate_value(self, value: object) -> PropertyGetType: + """Validate the setter value. + + Args: + value: The value being set. + + Returns: + The value to be set. + """ + # Raise StyleValueError here + return cast(PropertyGetType, value) + + def __set_name__(self, owner: StylesBase, name: str) -> None: + self.name = name + + def __get__( + self, obj: StylesBase, objtype: type[StylesBase] | None = None + ) -> PropertyGetType: + return obj.get_rule(self.name, self.default) # type: ignore[return-value] + + def __set__(self, obj: StylesBase, value: PropertySetType | None) -> None: + _rich_traceback_omit = True + if value is None: + obj.clear_rule(self.name) + obj.refresh(layout=self.layout, children=self.refresh_children) + return + new_value = self.validate_value(value) + if obj.set_rule(self.name, new_value): + obj.refresh(layout=self.layout, children=self.refresh_children) + + +class IntegerProperty(GenericProperty[int, int]): + def validate_value(self, value: object) -> int: + if isinstance(value, (int, float)): + return int(value) + else: + raise StyleValueError(f"Expected a number here, got {value!r}") + + +class BooleanProperty(GenericProperty[bool, bool]): + """A property that requires a True or False value.""" + + def validate_value(self, value: object) -> bool: + return bool(value) + + +class ScalarProperty: + """Descriptor for getting and setting scalar properties. Scalars are numeric values with a unit, e.g. "50vh".""" + + def __init__( + self, + units: set[Unit] | None = None, + percent_unit: Unit = Unit.WIDTH, + allow_auto: bool = True, + ) -> None: + self.units: set[Unit] = units or {*UNIT_SYMBOL} + self.percent_unit = percent_unit + self.allow_auto = allow_auto + super().__init__() + + def __set_name__(self, owner: StylesBase, name: str) -> None: + self.name = name + + def __get__( + self, obj: StylesBase, objtype: type[StylesBase] | None = None + ) -> Scalar | None: + """Get the scalar property. + + Args: + obj: The ``Styles`` object. + objtype: The ``Styles`` class. + + Returns: + The Scalar object or ``None`` if it's not set. + """ + return obj.get_rule(self.name) # type: ignore[return-value] + + def __set__( + self, obj: StylesBase, value: float | int | Scalar | str | None + ) -> None: + """Set the scalar property. + + Args: + obj: The ``Styles`` object. + value: The value to set the scalar property to. + You can directly pass a float or int value, which will be interpreted with + a default unit of Cells. You may also provide a string such as ``"50%"``, + as you might do when writing CSS. If a string with no units is supplied, + Cells will be used as the unit. Alternatively, you can directly supply + a ``Scalar`` object. + + Raises: + StyleValueError: If the value is of an invalid type, uses an invalid unit, or + cannot be parsed for any other reason. + """ + _rich_traceback_omit = True + if value is None: + obj.clear_rule(self.name) + obj.refresh(layout=True) + return + if isinstance(value, (int, float)): + new_value = Scalar(float(value), Unit.CELLS, Unit.WIDTH) + elif isinstance(value, Scalar): + new_value = value + elif isinstance(value, str): + try: + new_value = Scalar.parse(value) + except ScalarParseError: + raise StyleValueError( + f"unable to parse scalar from {value!r}", + help_text=scalar_help_text( + property_name=self.name, context="inline" + ), + ) + else: + raise StyleValueError("expected float, int, Scalar, or None") + + if ( + new_value is not None + and new_value.unit == Unit.AUTO + and not self.allow_auto + ): + raise StyleValueError("'auto' not allowed here") + + if new_value is not None and new_value.unit != Unit.AUTO: + if new_value.unit not in self.units: + raise StyleValueError( + f"{self.name} units must be one of {friendly_list(get_symbols(self.units))}" + ) + if new_value.is_percent: + new_value = Scalar( + float(new_value.value), self.percent_unit, Unit.WIDTH + ) + if obj.set_rule(self.name, new_value): + obj.refresh(layout=True) + + +class ScalarListProperty: + """Descriptor for lists of scalars. + + Args: + percent_unit: The dimension to which percentage scalars will be relative to. + refresh_children: Whether to refresh the node children on value change. + """ + + def __init__(self, percent_unit: Unit, refresh_children: bool = False) -> None: + self.percent_unit = percent_unit + self.refresh_children = refresh_children + + def __set_name__(self, owner: StylesBase, name: str) -> None: + self.name = name + + def __get__( + self, obj: StylesBase, objtype: type[StylesBase] | None = None + ) -> tuple[Scalar, ...] | None: + return obj.get_rule(self.name) # type: ignore[return-value] + + def __set__( + self, obj: StylesBase, value: str | Iterable[str | float] | None + ) -> None: + if value is None: + obj.clear_rule(self.name) + obj.refresh(layout=True, children=self.refresh_children) + return + parse_values: Iterable[str | float] + if isinstance(value, str): + parse_values = value.split() + else: + parse_values = value + + scalars = [] + for parse_value in parse_values: + if isinstance(parse_value, (int, float)): + scalars.append(Scalar.from_number(parse_value)) + else: + scalars.append( + Scalar.parse(parse_value, self.percent_unit) + if isinstance(parse_value, str) + else parse_value + ) + if obj.set_rule(self.name, tuple(scalars)): + obj.refresh(layout=True, children=self.refresh_children) + + +class BoxProperty: + """Descriptor for getting and setting outlines and borders along a single edge. + For example "border-right", "outline-bottom", etc. + """ + + def __init__(self, default_color: Color) -> None: + self._default_color = default_color + + def __set_name__(self, owner: StylesBase, name: str) -> None: + self.name = name + _type, edge = name.split("_") + self._type = _type + self.edge = edge + + def __get__( + self, obj: StylesBase, objtype: type[StylesBase] | None = None + ) -> tuple[EdgeType, Color]: + """Get the box property. + + Args: + obj: The ``Styles`` object. + objtype: The ``Styles`` class. + + Returns: + A ``tuple[EdgeType, Style]`` containing the string type of the box and + its style. Example types are "round", "solid", and "dashed". + """ + return obj.get_rule(self.name) or ("", self._default_color) # type: ignore[return-value] + + def __set__( + self, + obj: StylesBase, + border: tuple[EdgeType, str | Color] | Literal["none"] | None, + ): + """Set the box property. + + Args: + obj: The ``Styles`` object. + value: A 2-tuple containing the type of box to use, + e.g. "dashed", and the ``Style`` to be used. You can supply the ``Style`` directly, or pass a + ``str`` (e.g. ``"blue on #f0f0f0"`` ) or ``Color`` instead. + + Raises: + StyleValueError: If the string supplied for the color is not a valid color. + """ + + if border is None: + if obj.clear_rule(self.name): + obj.refresh(layout=True) + elif border == "none": + obj.set_rule(self.name, ("", obj.get_rule(self.name)[1])) + else: + _type, color = border + if _type in ("none", "hidden"): + _type = "" + new_value = border + if isinstance(color, str): + try: + new_value = (_type, Color.parse(color)) + except ColorParseError as error: + raise StyleValueError( + str(error), + help_text=border_property_help_text( + self.name, context="inline" + ), + ) + elif isinstance(color, Color): + new_value = (_type, color) + current_value: tuple[str, Color] = cast( + "tuple[str, Color]", obj.get_rule(self.name) + ) + has_edge = bool(current_value and current_value[0]) + new_edge = bool(_type) + if obj.set_rule(self.name, new_value): + obj.refresh(layout=has_edge != new_edge) + + +@rich.repr.auto +class Edges(NamedTuple): + """Stores edges for border / outline.""" + + top: tuple[EdgeType, Color] + right: tuple[EdgeType, Color] + bottom: tuple[EdgeType, Color] + left: tuple[EdgeType, Color] + + def __bool__(self) -> bool: + (top, _), (right, _), (bottom, _), (left, _) = self + return bool(top or right or bottom or left) + + def __rich_repr__(self) -> rich.repr.Result: + top, right, bottom, left = self + if top[0]: + yield "top", top + if right[0]: + yield "right", right + if bottom[0]: + yield "bottom", bottom + if left[0]: + yield "left", left + + @property + def spacing(self) -> Spacing: + """Get spacing created by borders. + + Returns: + Spacing for top, right, bottom, and left. + """ + (top, _), (right, _), (bottom, _), (left, _) = self + return Spacing( + 1 if top else 0, + 1 if right else 0, + 1 if bottom else 0, + 1 if left else 0, + ) + + +class BorderProperty: + """Descriptor for getting and setting full borders and outlines. + + Args: + layout: True if the layout should be refreshed after setting, False otherwise. + """ + + def __init__(self, layout: bool) -> None: + self._layout = layout + + def __set_name__(self, owner: StylesBase, name: str) -> None: + self.name = name + self._properties = ( + f"{name}_top", + f"{name}_right", + f"{name}_bottom", + f"{name}_left", + ) + self._get_properties = attrgetter(*self._properties) + + def __get__( + self, obj: StylesBase, objtype: type[StylesBase] | None = None + ) -> Edges: + """Get the border. + + Args: + obj: The ``Styles`` object. + objtype: The ``Styles`` class. + + Returns: + An ``Edges`` object describing the type and style of each edge. + """ + + return Edges(*self._get_properties(obj)) + + def __set__( + self, + obj: StylesBase, + border: BorderDefinition | None, + ) -> None: + """Set the border. + + Args: + obj: The ``Styles`` object. + border: + A ``tuple[EdgeType, str | Color | Style]`` representing the type of box to use and the ``Style`` to apply + to the box. + Alternatively, you can supply a sequence of these tuples and they will be applied per-edge. + If the sequence is of length 1, all edges will be decorated according to the single element. + If the sequence is length 2, the first ``tuple`` will be applied to the top and bottom edges. + If the sequence is length 4, the tuples will be applied to the edges in the order: top, right, bottom, left. + + Raises: + StyleValueError: When the supplied ``tuple`` is not of valid length (1, 2, or 4). + """ + _rich_traceback_omit = True + top, right, bottom, left = self._properties + + border_spacing = Edges(*self._get_properties(obj)).spacing + + def check_refresh() -> None: + """Check if an update requires a layout""" + if not self._layout: + obj.refresh() + else: + layout = Edges(*self._get_properties(obj)).spacing != border_spacing + obj.refresh(layout=layout) + + if border is None: + clear_rule = obj.clear_rule + clear_rule(top) + clear_rule(right) + clear_rule(bottom) + clear_rule(left) + check_refresh() + return + elif border == "none": + set_rule = obj.set_rule + get_rule = obj.get_rule + set_rule(top, ("", get_rule(top)[1])) + set_rule(right, ("", get_rule(right)[1])) + set_rule(bottom, ("", get_rule(bottom)[1])) + set_rule(left, ("", get_rule(left)[1])) + check_refresh() + return + + if isinstance(border, tuple) and len(border) == 2: + _border = normalize_border_value(border) # type: ignore + setattr(obj, top, _border) + setattr(obj, right, _border) + setattr(obj, bottom, _border) + setattr(obj, left, _border) + check_refresh() + return + + count = len(border) + if count == 1: + _border = normalize_border_value(border[0]) # type: ignore + setattr(obj, top, _border) + setattr(obj, right, _border) + setattr(obj, bottom, _border) + setattr(obj, left, _border) + elif count == 2: + _border1, _border2 = ( + normalize_border_value(border[0]), # type: ignore + normalize_border_value(border[1]), # type: ignore + ) + setattr(obj, top, _border1) + setattr(obj, bottom, _border1) + setattr(obj, right, _border2) + setattr(obj, left, _border2) + elif count == 4: + _border1, _border2, _border3, _border4 = ( + normalize_border_value(border[0]), # type: ignore + normalize_border_value(border[1]), # type: ignore + normalize_border_value(border[2]), # type: ignore + normalize_border_value(border[3]), # type: ignore + ) + setattr(obj, top, _border1) + setattr(obj, right, _border2) + setattr(obj, bottom, _border3) + setattr(obj, left, _border4) + else: + raise StyleValueError( + "expected 1, 2, or 4 values", + help_text=border_property_help_text(self.name, context="inline"), + ) + check_refresh() + + +class KeylineProperty: + """Descriptor for getting and setting keyline information.""" + + def __get__( + self, obj: StylesBase, objtype: type[StylesBase] | None = None + ) -> tuple[CanvasLineType, Color]: + return obj.get_rule("keyline", ("none", TRANSPARENT)) # type: ignore[return-value] + + def __set__(self, obj: StylesBase, keyline: tuple[str, Color] | None): + if keyline is None: + if obj.clear_rule("keyline"): + obj.refresh(layout=True) + else: + if obj.set_rule("keyline", keyline): + obj.refresh(layout=True) + + +class SpacingProperty: + """Descriptor for getting and setting spacing properties (e.g. padding and margin).""" + + def __set_name__(self, owner: StylesBase, name: str) -> None: + self.name = name + + def __get__( + self, obj: StylesBase, objtype: type[StylesBase] | None = None + ) -> Spacing: + """Get the Spacing. + + Args: + obj: The ``Styles`` object. + objtype: The ``Styles`` class. + + Returns: + The Spacing. If unset, returns the null spacing ``(0, 0, 0, 0)``. + """ + return obj.get_rule(self.name, NULL_SPACING) # type: ignore[return-value] + + def __set__(self, obj: StylesBase, spacing: SpacingDimensions | None): + """Set the Spacing. + + Args: + obj: The ``Styles`` object. + style: You can supply the ``Style`` directly, or a + string (e.g. ``"blue on #f0f0f0"``). + + Raises: + ValueError: When the value is malformed, + e.g. a ``tuple`` with a length that is not 1, 2, or 4. + """ + _rich_traceback_omit = True + if spacing is None: + if obj.clear_rule(self.name): + obj.refresh(layout=True) + else: + try: + unpacked_spacing = Spacing.unpack(spacing) + except ValueError as error: + raise StyleValueError( + str(error), + help_text=spacing_wrong_number_of_values_help_text( + property_name=self.name, + num_values_supplied=( + 1 if isinstance(spacing, int) else len(spacing) + ), + context="inline", + ), + ) + if obj.set_rule(self.name, unpacked_spacing): + obj.refresh(layout=True) + + +class DockProperty: + """Descriptor for getting and setting the dock property. The dock property + allows you to specify which edge you want to fix a Widget to. + """ + + def __get__( + self, obj: StylesBase, objtype: type[StylesBase] | None = None + ) -> DockEdge: + """Get the Dock property. + + Args: + obj: The ``Styles`` object. + objtype: The ``Styles`` class. + + Returns: + The edge name as a string. Returns "none" if unset or if "none" has been explicitly set. + """ + return obj.get_rule("dock", "none") # type: ignore[return-value] + + def __set__(self, obj: StylesBase, dock_name: str): + """Set the Dock property. + + Args: + obj: The ``Styles`` object. + dock_name: The name of the dock to attach this widget to. + """ + _rich_traceback_omit = True + if obj.set_rule("dock", dock_name): + obj.refresh(layout=True) + + +class SplitProperty: + """Descriptor for getting and setting the split property. + The split property allows you to specify which edge you want to split. + """ + + def __get__( + self, obj: StylesBase, objtype: type[StylesBase] | None = None + ) -> DockEdge: + """Get the Split property. + + Args: + obj: The ``Styles`` object. + objtype: The ``Styles`` class. + + Returns: + The edge name as a string. Returns "none" if unset or if "none" has been explicitly set. + """ + return obj.get_rule("split", "none") # type: ignore[return-value] + + def __set__(self, obj: StylesBase, dock_name: str): + """Set the Dock property. + + Args: + obj: The ``Styles`` object. + dock_name: The name of the dock to attach this widget to. + """ + _rich_traceback_omit = True + if obj.set_rule("split", dock_name): + obj.refresh(layout=True) + + +class LayoutProperty: + """Descriptor for getting and setting layout.""" + + def __set_name__(self, owner: StylesBase, name: str) -> None: + self.name = name + + def __get__( + self, obj: StylesBase, objtype: type[StylesBase] | None = None + ) -> Layout | None: + """ + Args: + obj: The Styles object. + objtype: The Styles class. + + Returns: + The `Layout` object. + """ + return obj.get_rule(self.name) # type: ignore[return-value] + + def __set__(self, obj: StylesBase, layout: str | Layout | None): + """ + Args: + obj: The Styles object. + layout: The layout to use. You can supply the name of the layout + or a `Layout` object. + """ + + from memray._vendor.textual.layouts.factory import Layout # Prevents circular import + from memray._vendor.textual.layouts.factory import MissingLayout, get_layout + + _rich_traceback_omit = True + if layout is None: + if obj.clear_rule("layout"): + obj.refresh(layout=True, children=True) + return + + if isinstance(layout, Layout): + layout = layout.name + + if obj.layout is not None and obj.layout.name == layout: + return + + try: + layout_object = get_layout(layout) + except MissingLayout as error: + raise StyleValueError( + str(error), + help_text=layout_property_help_text(self.name, context="inline"), + ) + if obj.set_rule("layout", layout_object): + obj.refresh(layout=True, children=True) + + +class OffsetProperty: + """Descriptor for getting and setting the offset property. + Offset consists of two values, x and y, that a widget's position + will be adjusted by before it is rendered. + """ + + def __set_name__(self, owner: StylesBase, name: str) -> None: + self.name = name + + def __get__( + self, obj: StylesBase, objtype: type[StylesBase] | None = None + ) -> ScalarOffset: + """Get the offset. + + Args: + obj: The ``Styles`` object. + objtype: The ``Styles`` class. + + Returns: + The ``ScalarOffset`` indicating the adjustment that + will be made to widget position prior to it being rendered. + """ + return obj.get_rule(self.name, NULL_SCALAR) # type: ignore[return-value] + + def __set__( + self, obj: StylesBase, offset: tuple[int | str, int | str] | ScalarOffset | None + ): + """Set the offset. + + Args: + obj: The ``Styles`` class. + offset: A ScalarOffset object, or a 2-tuple of the form ``(x, y)`` indicating + the x and y offsets. When the ``tuple`` form is used, x and y can be specified + as either ``int`` or ``str``. The string format allows you to also specify + any valid scalar unit e.g. ``("0.5vw", "0.5vh")``. + + Raises: + ScalarParseError: If any of the string values supplied in the 2-tuple cannot + be parsed into a Scalar. For example, if you specify a non-existent unit. + """ + _rich_traceback_omit = True + if offset is None: + if obj.clear_rule(self.name): + obj.refresh(layout=True, repaint=False) + elif isinstance(offset, ScalarOffset): + if obj.set_rule(self.name, offset): + obj.refresh(layout=True, repaint=False) + else: + x, y = offset + + try: + scalar_x = ( + Scalar.parse(x, Unit.WIDTH) + if isinstance(x, str) + else Scalar(float(x), Unit.CELLS, Unit.WIDTH) + ) + scalar_y = ( + Scalar.parse(y, Unit.HEIGHT) + if isinstance(y, str) + else Scalar(float(y), Unit.CELLS, Unit.HEIGHT) + ) + except ScalarParseError as error: + raise StyleValueError( + str(error), help_text=offset_property_help_text(context="inline") + ) + + _offset = ScalarOffset(scalar_x, scalar_y) + + if obj.set_rule(self.name, _offset): + obj.refresh(layout=True, repaint=False) + + +class StringEnumProperty(Generic[EnumType]): + """Descriptor for getting and setting string properties and ensuring that the set + value belongs in the set of valid values. + + Args: + valid_values: The set of valid values that the descriptor can take. + default: The default value (or a factory thereof) of the property. + layout: Whether to refresh the node layout on value change. + refresh_children: Whether to refresh the node children on value change. + display: Does this property change display? + """ + + def __init__( + self, + valid_values: set[str], + default: EnumType, + layout: bool = False, + refresh_children: bool = False, + refresh_parent: bool = False, + display: bool = False, + pointer: bool = False, + ) -> None: + self._valid_values = valid_values + self._default = default + self._layout = layout + self._refresh_children = refresh_children + self._refresh_parent = refresh_parent + self._display = display + self._pointer = pointer + + def __set_name__(self, owner: StylesBase, name: str) -> None: + self.name = name + + def __get__( + self, obj: StylesBase, objtype: type[StylesBase] | None = None + ) -> EnumType: + """Get the string property, or the default value if it's not set. + + Args: + obj: The `Styles` object. + objtype: The `Styles` class. + + Returns: + The string property value. + """ + return obj.get_rule(self.name, self._default) # type: ignore + + def _before_refresh(self, obj: StylesBase, value: str | None) -> None: + """Do any housekeeping before asking for a layout refresh after a value change.""" + + def __set__(self, obj: StylesBase, value: EnumType | None = None): + """Set the string property and ensure it is in the set of allowed values. + + Args: + obj: The `Styles` object. + value: The string value to set the property to. + + Raises: + StyleValueError: If the value is not in the set of valid values. + """ + _rich_traceback_omit = True + if value is None: + if obj.clear_rule(self.name): + self._before_refresh(obj, value) + obj.refresh( + layout=self._layout, + children=self._refresh_children, + parent=self._refresh_parent, + ) + + if self._display: + node = obj.node + if node is not None and node.parent: + node._nodes.updated() + + else: + if value not in self._valid_values: + raise StyleValueError( + f"{self.name} must be one of {friendly_list(self._valid_values)} (received {value!r})", + help_text=string_enum_help_text( + self.name, + valid_values=list(self._valid_values), + context="inline", + ), + ) + if obj.set_rule(self.name, value): + if self._display and obj.node is not None: + node = obj.node + if node.parent: + node._nodes.updated() + + self._before_refresh(obj, value) + obj.refresh( + layout=self._layout, + children=self._refresh_children, + parent=self._refresh_parent, + ) + if self._pointer and obj.node is not None: + from memray._vendor.textual.dom import NoScreen + + try: + obj.node.screen.update_pointer_shape() + except NoScreen: + pass + + +class OverflowProperty(StringEnumProperty): + """Descriptor for overflow styles that forces widgets to refresh scrollbars.""" + + def _before_refresh(self, obj: StylesBase, value: str | None) -> None: + from memray._vendor.textual.widget import Widget # Avoid circular import + + if isinstance(obj.node, Widget): + obj.node._refresh_scrollbars() + + +class NameProperty: + """Descriptor for getting and setting name properties.""" + + def __set_name__(self, owner: StylesBase, name: str) -> None: + self.name = name + + def __get__(self, obj: StylesBase, objtype: type[StylesBase] | None) -> str: + """Get the name property. + + Args: + obj: The ``Styles`` object. + objtype: The ``Styles`` class. + + Returns: + The name. + """ + return obj.get_rule(self.name, "") # type: ignore[return-value] + + def __set__(self, obj: StylesBase, name: str | None): + """Set the name property. + + Args: + obj: The ``Styles`` object. + name: The name to set the property to. + + Raises: + StyleTypeError: If the value is not a ``str``. + """ + _rich_traceback_omit = True + if name is None: + if obj.clear_rule(self.name): + obj.refresh(layout=True) + else: + if not isinstance(name, str): + raise StyleTypeError(f"{self.name} must be a str") + if obj.set_rule(self.name, name): + obj.refresh(layout=True) + + +class NameListProperty: + def __set_name__(self, owner: StylesBase, name: str) -> None: + self.name = name + + def __get__( + self, obj: StylesBase, objtype: type[StylesBase] | None = None + ) -> tuple[str, ...]: + return obj.get_rule(self.name, ()) # type: ignore[return-value] + + def __set__(self, obj: StylesBase, names: str | tuple[str] | None = None): + _rich_traceback_omit = True + if names is None: + if obj.clear_rule(self.name): + obj.refresh(layout=True) + elif isinstance(names, str): + if obj.set_rule( + self.name, tuple(name.strip().lower() for name in names.split(" ")) + ): + obj.refresh(layout=True) + elif isinstance(names, tuple): + if obj.set_rule(self.name, names): + obj.refresh(layout=True) + + +class ColorProperty: + """Descriptor for getting and setting color properties.""" + + def __init__(self, default_color: Color | str) -> None: + self._default_color = Color.parse(default_color) + + def __set_name__(self, owner: StylesBase, name: str) -> None: + self.name = name + + def __get__( + self, obj: StylesBase, objtype: type[StylesBase] | None = None + ) -> Color: + """Get a ``Color``. + + Args: + obj: The ``Styles`` object. + objtype: The ``Styles`` class. + + Returns: + The Color. + """ + return obj.get_rule(self.name, self._default_color) # type: ignore[return-value] + + def __set__(self, obj: StylesBase, color: Color | str | None) -> None: + """Set the Color. + + Args: + obj: The ``Styles`` object. + color: The color to set. Pass a ``Color`` instance directly, + or pass a ``str`` which will be parsed into a color (e.g. ``"red""``, ``"rgb(20, 50, 80)"``, + ``"#f4e32d"``). + + Raises: + ColorParseError: When the color string is invalid. + """ + _rich_traceback_omit = True + if color is None: + if obj.clear_rule(self.name): + obj.refresh(children=True) + elif isinstance(color, Color): + if obj.set_rule(self.name, color): + obj.refresh(children=True) + elif isinstance(color, str): + alpha = 1.0 + parsed_color = Color(255, 255, 255) + for token in color.split(): + if token.endswith("%"): + try: + alpha = percentage_string_to_float(token) + except ValueError: + raise StyleValueError(f"invalid percentage value '{token}'") + continue + try: + parsed_color = Color.parse(token) + except ColorParseError as error: + raise StyleValueError( + f"Invalid color value '{token}'", + help_text=color_property_help_text( + self.name, context="inline", error=error, value=token + ), + ) + parsed_color = parsed_color.multiply_alpha(alpha) + + if obj.set_rule(self.name, parsed_color): + obj.refresh(children=True) + else: + raise StyleValueError(f"Invalid color value {color}") + + +class ScrollbarColorProperty(ColorProperty): + """A descriptor to set scrollbar color(s).""" + + def __set__(self, obj: StylesBase, color: Color | str | None) -> None: + super().__set__(obj, color) + + if obj.node is None: + return + + from memray._vendor.textual.widget import Widget + + if isinstance(obj.node, Widget): + widget = obj.node + + if widget.show_horizontal_scrollbar: + widget.horizontal_scrollbar.refresh() + + if widget.show_vertical_scrollbar: + widget.vertical_scrollbar.refresh() + + +class StyleFlagsProperty: + """Descriptor for getting and set style flag properties (e.g. ``bold italic underline``).""" + + def __set_name__(self, owner: StylesBase, name: str) -> None: + self.name = name + + def __get__( + self, obj: StylesBase, objtype: type[StylesBase] | None = None + ) -> Style: + """Get the ``Style``. + + Args: + obj: The ``Styles`` object. + objtype: The ``Styles`` class. + + Returns: + The ``Style`` object. + """ + return obj.get_rule(self.name, Style.null()) # type: ignore[return-value] + + def __set__(self, obj: StylesBase, style_flags: Style | str | None) -> None: + """Set the style using a style flag string. + + Args: + obj: The ``Styles`` object. + style_flags: The style flags to set as a string. For example, + ``"bold italic"``. + + Raises: + StyleValueError: If the value is an invalid style flag. + """ + _rich_traceback_omit = True + if style_flags is None: + if obj.clear_rule(self.name): + obj.refresh(children=True) + elif isinstance(style_flags, Style): + if obj.set_rule(self.name, style_flags): + obj.refresh(children=True) + else: + words = [word.strip() for word in style_flags.split(" ")] + valid_word = VALID_STYLE_FLAGS.__contains__ + for word in words: + if not valid_word(word): + raise StyleValueError( + f"unknown word {word!r} in style flags", + help_text=style_flags_property_help_text( + self.name, word, context="inline" + ), + ) + try: + style = Style.parse(style_flags) + except rich.errors.StyleSyntaxError as error: + if "none" in words and len(words) > 1: + raise StyleValueError( + "cannot mix 'none' with other style flags", + help_text=style_flags_property_help_text( + self.name, " ".join(words), context="inline" + ), + ) from None + raise error from None + if obj.set_rule(self.name, style): + obj.refresh(children=True) + + +class TransitionsProperty: + """Descriptor for getting transitions properties""" + + def __get__( + self, obj: StylesBase, objtype: type[StylesBase] | None = None + ) -> dict[str, Transition]: + """Get a mapping of properties to the transitions applied to them. + + Args: + obj: The ``Styles`` object. + objtype: The ``Styles`` class. + + Returns: + A ``dict`` mapping property names to the ``Transition`` applied to them. + e.g. ``{"offset": Transition(...), ...}``. If no transitions have been set, an empty ``dict`` + is returned. + """ + return obj.get_rule("transitions", {}) # type: ignore[return-value] + + def __set__( + self, obj: StylesBase, transitions: dict[str, Transition] | None + ) -> None: + _rich_traceback_omit = True + if transitions is None: + obj.clear_rule("transitions") + else: + obj.set_rule("transitions", transitions.copy()) + + +class FractionalProperty: + """Property that can be set either as a float (e.g. 0.1) or a + string percentage (e.g. '10%'). Values will be clamped to the range (0, 1). + """ + + def __init__(self, default: float = 1.0, children: bool = False): + """ + Args: + default: Default value if the rule wasn't explicitly set. + children: If `True`, then updating this value will also refresh children. + Otherwise only this widget will be refreshed. + """ + self.default = default + self.children = children + + def __set_name__(self, owner: StylesBase, name: str) -> None: + self.name = name + + def __get__(self, obj: StylesBase, type: type[StylesBase]) -> float: + """Get the property value as a float between 0 and 1. + + Args: + obj: The ``Styles`` object. + objtype: The ``Styles`` class. + + Returns: + The value of the property (in the range (0, 1)). + """ + return obj.get_rule(self.name, self.default) # type: ignore[return-value] + + def __set__(self, obj: StylesBase, value: float | str | None) -> None: + """Set the property value, clamping it between 0 and 1. + + Args: + obj: The Styles object. + value: The value to set as a float between 0 and 1, or + as a percentage string such as '10%'. + """ + _rich_traceback_omit = True + name = self.name + if value is None: + if obj.clear_rule(name): + obj.refresh(children=self.children) + return + + if isinstance(value, (int, float)): + float_value = float(value) + elif isinstance(value, str) and value.endswith("%"): + float_value = float(Scalar.parse(value).value) / 100 + else: + raise StyleValueError( + f"{self.name} must be a str (e.g. '10%') or a float (e.g. 0.1)", + help_text=fractional_property_help_text(name, context="inline"), + ) + if obj.set_rule(name, clamp(float_value, 0, 1)): + obj.refresh(children=self.children) + + +class AlignProperty: + """Combines the horizontal and vertical alignment properties into a single property.""" + + def __set_name__(self, owner: StylesBase, name: str) -> None: + self.horizontal = f"{name}_horizontal" + self.vertical = f"{name}_vertical" + + def __get__( + self, obj: StylesBase, type: type[StylesBase] + ) -> tuple[AlignHorizontal, AlignVertical]: + horizontal = getattr(obj, self.horizontal) + vertical = getattr(obj, self.vertical) + return (horizontal, vertical) + + def __set__( + self, obj: StylesBase, value: tuple[AlignHorizontal, AlignVertical] + ) -> None: + horizontal, vertical = value + setattr(obj, self.horizontal, horizontal) + setattr(obj, self.vertical, vertical) + + +class HatchProperty: + """Property to expose hatch style.""" + + def __get__( + self, obj: StylesBase, type: type[StylesBase] + ) -> tuple[str, Color] | Literal["none"]: + return obj.get_rule("hatch") # type: ignore[return-value] + + def __set__( + self, obj: StylesBase, value: tuple[str, Color | str] | Literal["none"] | None + ) -> None: + _rich_traceback_omit = True + if value is None: + if obj.clear_rule("hatch"): + obj.refresh(children=True) + return + + if value == "none": + hatch = "none" + else: + character, color = value + if len(character) != 1: + try: + character = HATCHES[character] + except KeyError: + raise ValueError( + f"Expected a character or hatch value here; found {character!r}" + ) from None + if cell_len(character) != 1: + raise ValueError("Hatch character must have a cell length of 1") + if isinstance(color, str): + color = Color.parse(color) + hatch = (character, color) + + obj.set_rule("hatch", hatch) diff --git a/src/memray/_vendor/textual/css/_styles_builder.py b/src/memray/_vendor/textual/css/_styles_builder.py new file mode 100644 index 0000000000..4cbb507138 --- /dev/null +++ b/src/memray/_vendor/textual/css/_styles_builder.py @@ -0,0 +1,1316 @@ +from __future__ import annotations + +from typing import Iterable, NoReturn, cast + +import rich.repr + +from memray._vendor.textual._border import BorderValue, normalize_border_value +from memray._vendor.textual._cells import cell_len +from memray._vendor.textual._duration import _duration_as_seconds +from memray._vendor.textual._easing import EASING +from memray._vendor.textual.color import TRANSPARENT, Color, ColorParseError +from memray._vendor.textual.css._error_tools import friendly_list +from memray._vendor.textual.css._help_renderables import HelpText +from memray._vendor.textual.css._help_text import ( + align_help_text, + border_property_help_text, + color_property_help_text, + dock_property_help_text, + expand_help_text, + fractional_property_help_text, + integer_help_text, + keyline_help_text, + layout_property_help_text, + offset_property_help_text, + offset_single_axis_help_text, + position_help_text, + property_invalid_value_help_text, + scalar_help_text, + scrollbar_size_property_help_text, + scrollbar_size_single_axis_help_text, + spacing_invalid_value_help_text, + spacing_wrong_number_of_values_help_text, + split_property_help_text, + string_enum_help_text, + style_flags_property_help_text, + table_rows_or_columns_help_text, + text_align_help_text, +) +from memray._vendor.textual.css.constants import ( + HATCHES, + VALID_ALIGN_HORIZONTAL, + VALID_ALIGN_VERTICAL, + VALID_BORDER, + VALID_BOX_SIZING, + VALID_CONSTRAIN, + VALID_DISPLAY, + VALID_EDGE, + VALID_EXPAND, + VALID_HATCH, + VALID_KEYLINE, + VALID_OVERFLOW, + VALID_OVERLAY, + VALID_POINTER, + VALID_POSITION, + VALID_SCROLLBAR_GUTTER, + VALID_SCROLLBAR_VISIBILITY, + VALID_STYLE_FLAGS, + VALID_TEXT_ALIGN, + VALID_TEXT_OVERFLOW, + VALID_TEXT_WRAP, + VALID_VISIBILITY, +) +from memray._vendor.textual.css.errors import DeclarationError, StyleValueError +from memray._vendor.textual.css.model import Declaration +from memray._vendor.textual.css.scalar import ( + Scalar, + ScalarError, + ScalarOffset, + ScalarParseError, + Unit, + percentage_string_to_float, +) +from memray._vendor.textual.css.styles import Styles +from memray._vendor.textual.css.tokenize import Token +from memray._vendor.textual.css.transition import Transition +from memray._vendor.textual.css.types import ( + BoxSizing, + Display, + EdgeType, + Overflow, + ScrollbarVisibility, + TextOverflow, + TextWrap, + Visibility, +) +from memray._vendor.textual.geometry import Spacing, SpacingDimensions, clamp +from memray._vendor.textual.suggestions import get_suggestion + + +class StylesBuilder: + """ + The StylesBuilder object takes tokens parsed from the CSS and converts + to the appropriate internal types. + """ + + def __init__(self) -> None: + self.styles = Styles() + + def __rich_repr__(self) -> rich.repr.Result: + yield "styles", self.styles + + def __repr__(self) -> str: + return "StylesBuilder()" + + def error(self, name: str, token: Token, message: str | HelpText) -> NoReturn: + raise DeclarationError(name, token, message) + + def add_declaration(self, declaration: Declaration) -> None: + if not declaration.name: + return + rule_name = declaration.name.replace("-", "_") + + if not declaration.tokens: + self.error( + rule_name, + declaration.token, + f"Missing property value for '{declaration.name}:'", + ) + + process_method = getattr(self, f"process_{rule_name}", None) + + if process_method is None: + suggested_property_name = self._get_suggested_property_name_for_rule( + declaration.name + ) + self.error( + declaration.name, + declaration.token, + property_invalid_value_help_text( + declaration.name, + "css", + suggested_property_name=suggested_property_name, + ), + ) + + tokens = declaration.tokens + + important = tokens[-1].name == "important" + if important: + tokens = tokens[:-1] + self.styles.important.add(rule_name) + + # Check for special token(s) + if tokens[0].name == "token": + value = tokens[0].value + if value == "initial": + self.styles._rules[rule_name] = None + return + try: + process_method(declaration.name, tokens) + except DeclarationError: + raise + except Exception as error: + self.error(declaration.name, declaration.token, str(error)) + + def _process_enum_multiple( + self, name: str, tokens: list[Token], valid_values: set[str], count: int + ) -> tuple[str, ...]: + """Generic code to process a declaration with two enumerations, like overflow: auto auto""" + if len(tokens) > count or not tokens: + self.error(name, tokens[0], f"expected 1 to {count} tokens here") + results: list[str] = [] + append = results.append + for token in tokens: + token_name, value, _, _, location, _ = token + if token_name != "token": + self.error( + name, + token, + f"invalid token {value!r}; expected {friendly_list(valid_values)}", + ) + append(value) + + short_results = results[:] + + while len(results) < count: + results.extend(short_results) + results = results[:count] + + return tuple(results) + + def _process_enum( + self, name: str, tokens: list[Token], valid_values: set[str] + ) -> str: + """Process a declaration that expects an enum. + + Args: + name: Name of declaration. + tokens: Tokens from parser. + valid_values: A set of valid values. + + Returns: + True if the value is valid or False if it is invalid (also generates an error) + """ + + if len(tokens) != 1: + self.error( + name, + tokens[0], + string_enum_help_text( + name, valid_values=list(valid_values), context="css" + ), + ) + + token = tokens[0] + token_name, value, _, _, location, _ = token + if token_name != "token": + self.error( + name, + token, + string_enum_help_text( + name, valid_values=list(valid_values), context="css" + ), + ) + if value not in valid_values: + self.error( + name, + token, + string_enum_help_text( + name, valid_values=list(valid_values), context="css" + ), + ) + return value + + def process_display(self, name: str, tokens: list[Token]) -> None: + for token in tokens: + name, value, _, _, location, _ = token + + if name == "token": + value = value.lower() + if value in VALID_DISPLAY: + self.styles._rules["display"] = cast(Display, value) + else: + self.error( + name, + token, + string_enum_help_text( + "display", valid_values=list(VALID_DISPLAY), context="css" + ), + ) + else: + self.error( + name, + token, + string_enum_help_text( + "display", valid_values=list(VALID_DISPLAY), context="css" + ), + ) + + def _process_scalar(self, name: str, tokens: list[Token]) -> None: + def scalar_error(): + self.error( + name, tokens[0], scalar_help_text(property_name=name, context="css") + ) + + if not tokens: + return + if len(tokens) == 1: + try: + self.styles._rules[name.replace("-", "_")] = Scalar.parse( # type: ignore + tokens[0].value + ) + except ScalarParseError: + scalar_error() + else: + scalar_error() + + def _distribute_importance(self, prefix: str, suffixes: tuple[str, ...]) -> None: + """Distribute importance amongst all aspects of the given style. + + Args: + prefix: The prefix of the style. + suffixes: The suffixes to distribute amongst. + + A number of styles can be set with the 'prefix' of the style, + providing the values as a series of parameters; or they can be set + with specific suffixes. Think `border` vs `border-left`, etc. This + method is used to ensure that if the former is set, `!important` is + distributed amongst all the suffixes. + """ + if prefix in self.styles.important: + self.styles.important.remove(prefix) + self.styles.important.update(f"{prefix}_{suffix}" for suffix in suffixes) + + def process_box_sizing(self, name: str, tokens: list[Token]) -> None: + for token in tokens: + name, value, _, _, location, _ = token + + if name == "token": + value = value.lower() + if value in VALID_BOX_SIZING: + self.styles._rules["box_sizing"] = cast(BoxSizing, value) + else: + self.error( + name, + token, + string_enum_help_text( + "box-sizing", + valid_values=list(VALID_BOX_SIZING), + context="css", + ), + ) + else: + self.error( + name, + token, + string_enum_help_text( + "box-sizing", valid_values=list(VALID_BOX_SIZING), context="css" + ), + ) + + def process_width(self, name: str, tokens: list[Token]) -> None: + self._process_scalar(name, tokens) + + def process_height(self, name: str, tokens: list[Token]) -> None: + self._process_scalar(name, tokens) + + def process_min_width(self, name: str, tokens: list[Token]) -> None: + self._process_scalar(name, tokens) + + def process_min_height(self, name: str, tokens: list[Token]) -> None: + self._process_scalar(name, tokens) + + def process_max_width(self, name: str, tokens: list[Token]) -> None: + self._process_scalar(name, tokens) + + def process_max_height(self, name: str, tokens: list[Token]) -> None: + self._process_scalar(name, tokens) + + def process_overflow(self, name: str, tokens: list[Token]) -> None: + rules = self.styles._rules + overflow_x, overflow_y = self._process_enum_multiple( + name, tokens, VALID_OVERFLOW, 2 + ) + rules["overflow_x"] = cast(Overflow, overflow_x) + rules["overflow_y"] = cast(Overflow, overflow_y) + self._distribute_importance("overflow", ("x", "y")) + + def process_overflow_x(self, name: str, tokens: list[Token]) -> None: + self.styles._rules["overflow_x"] = cast( + Overflow, self._process_enum(name, tokens, VALID_OVERFLOW) + ) + + def process_overflow_y(self, name: str, tokens: list[Token]) -> None: + self.styles._rules["overflow_y"] = cast( + Overflow, self._process_enum(name, tokens, VALID_OVERFLOW) + ) + + def process_visibility(self, name: str, tokens: list[Token]) -> None: + for token in tokens: + name, value, _, _, location, _ = token + if name == "token": + value = value.lower() + if value in VALID_VISIBILITY: + self.styles._rules["visibility"] = cast(Visibility, value) + else: + self.error( + name, + token, + string_enum_help_text( + "visibility", + valid_values=list(VALID_VISIBILITY), + context="css", + ), + ) + else: + string_enum_help_text( + "visibility", valid_values=list(VALID_VISIBILITY), context="css" + ) + + def process_text_wrap(self, name: str, tokens: list[Token]) -> None: + for token in tokens: + name, value, _, _, location, _ = token + if name == "token": + value = value.lower() + if value in VALID_TEXT_WRAP: + self.styles._rules["text_wrap"] = cast(TextWrap, value) + else: + self.error( + name, + token, + string_enum_help_text( + "text-wrap", + valid_values=list(VALID_TEXT_WRAP), + context="css", + ), + ) + else: + string_enum_help_text( + "text-wrap", valid_values=list(VALID_TEXT_WRAP), context="css" + ) + + def process_text_overflow(self, name: str, tokens: list[Token]) -> None: + for token in tokens: + name, value, _, _, location, _ = token + if name == "token": + value = value.lower() + if value in VALID_TEXT_OVERFLOW: + self.styles._rules["text_overflow"] = cast(TextOverflow, value) + else: + self.error( + name, + token, + string_enum_help_text( + "text-overflow", + valid_values=list(VALID_TEXT_OVERFLOW), + context="css", + ), + ) + else: + string_enum_help_text( + "text-overflow", + valid_values=list(VALID_TEXT_OVERFLOW), + context="css", + ) + + def _process_fractional(self, name: str, tokens: list[Token]) -> None: + if not tokens: + return + token = tokens[0] + error = False + if len(tokens) != 1: + error = True + else: + token_name = token.name + value = token.value + rule_name = name.replace("-", "_") + if token_name == "scalar" and value.endswith("%"): + try: + text_opacity = percentage_string_to_float(value) + self.styles.set_rule(rule_name, text_opacity) + except ValueError: + error = True + elif token_name == "number": + try: + text_opacity = clamp(float(value), 0, 1) + self.styles.set_rule(rule_name, text_opacity) + except ValueError: + error = True + else: + error = True + + if error: + self.error(name, token, fractional_property_help_text(name, context="css")) + + process_opacity = _process_fractional + process_text_opacity = _process_fractional + + def _process_space(self, name: str, tokens: list[Token]) -> None: + space: list[int] = [] + append = space.append + for token in tokens: + token_name, value, _, _, _, _ = token + if token_name == "number": + try: + append(int(value)) + except ValueError: + self.error( + name, + token, + spacing_invalid_value_help_text(name, context="css"), + ) + else: + self.error( + name, token, spacing_invalid_value_help_text(name, context="css") + ) + if len(space) not in (1, 2, 4): + self.error( + name, + tokens[0], + spacing_wrong_number_of_values_help_text( + name, num_values_supplied=len(space), context="css" + ), + ) + self.styles._rules[name] = Spacing.unpack(cast(SpacingDimensions, tuple(space))) # type: ignore + + def _process_space_partial(self, name: str, tokens: list[Token]) -> None: + """Process granular margin / padding declarations.""" + if len(tokens) != 1: + self.error( + name, tokens[0], spacing_invalid_value_help_text(name, context="css") + ) + + _EDGE_SPACING_MAP = {"top": 0, "right": 1, "bottom": 2, "left": 3} + token = tokens[0] + token_name, value, _, _, _, _ = token + if token_name == "number": + space = int(value) + else: + self.error( + name, token, spacing_invalid_value_help_text(name, context="css") + ) + style_name, _, edge = name.replace("-", "_").partition("_") + + current_spacing = cast( + "tuple[int, int, int, int]", + self.styles._rules.get(style_name, (0, 0, 0, 0)), + ) + + spacing_list = list(current_spacing) + spacing_list[_EDGE_SPACING_MAP[edge]] = space + + self.styles._rules[style_name] = Spacing(*spacing_list) # type: ignore + + process_padding = _process_space + process_margin = _process_space + + process_margin_top = _process_space_partial + process_margin_right = _process_space_partial + process_margin_bottom = _process_space_partial + process_margin_left = _process_space_partial + + process_padding_top = _process_space_partial + process_padding_right = _process_space_partial + process_padding_bottom = _process_space_partial + process_padding_left = _process_space_partial + + def _parse_border(self, name: str, tokens: list[Token]) -> BorderValue: + border_type: EdgeType = "solid" + border_color = Color(0, 255, 0) + border_alpha: float | None = None + + def border_value_error(): + self.error(name, token, border_property_help_text(name, context="css")) + + for token in tokens: + token_name, value, _, _, _, _ = token + if token_name == "token": + if value in VALID_BORDER: + border_type = value # type: ignore + else: + try: + border_color = Color.parse(value) + except ColorParseError: + border_value_error() + + elif token_name == "color": + try: + border_color = Color.parse(value) + except ColorParseError: + border_value_error() + + elif token_name == "scalar": + alpha_scalar = Scalar.parse(token.value) + if alpha_scalar.unit != Unit.PERCENT: + self.error(name, token, "alpha must be given as a percentage.") + border_alpha = alpha_scalar.value / 100.0 + + else: + border_value_error() + + if border_alpha is not None: + border_color = border_color.multiply_alpha(border_alpha) + + return normalize_border_value((border_type, border_color)) + + def _process_border_edge(self, edge: str, name: str, tokens: list[Token]) -> None: + border = self._parse_border(name, tokens) + self.styles._rules[f"border_{edge}"] = border # type: ignore + + def process_border(self, name: str, tokens: list[Token]) -> None: + border = self._parse_border(name, tokens) + rules = self.styles._rules + rules["border_top"] = rules["border_right"] = border + rules["border_bottom"] = rules["border_left"] = border + self._distribute_importance("border", ("top", "left", "bottom", "right")) + + def process_border_top(self, name: str, tokens: list[Token]) -> None: + self._process_border_edge("top", name, tokens) + + def process_border_right(self, name: str, tokens: list[Token]) -> None: + self._process_border_edge("right", name, tokens) + + def process_border_bottom(self, name: str, tokens: list[Token]) -> None: + self._process_border_edge("bottom", name, tokens) + + def process_border_left(self, name: str, tokens: list[Token]) -> None: + self._process_border_edge("left", name, tokens) + + def _process_outline(self, edge: str, name: str, tokens: list[Token]) -> None: + border = self._parse_border(name, tokens) + self.styles._rules[f"outline_{edge}"] = border # type: ignore + + def process_outline(self, name: str, tokens: list[Token]) -> None: + border = self._parse_border(name, tokens) + rules = self.styles._rules + rules["outline_top"] = rules["outline_right"] = border + rules["outline_bottom"] = rules["outline_left"] = border + self._distribute_importance("outline", ("top", "left", "bottom", "right")) + + def process_outline_top(self, name: str, tokens: list[Token]) -> None: + self._process_outline("top", name, tokens) + + def process_outline_right(self, name: str, tokens: list[Token]) -> None: + self._process_outline("right", name, tokens) + + def process_outline_bottom(self, name: str, tokens: list[Token]) -> None: + self._process_outline("bottom", name, tokens) + + def process_outline_left(self, name: str, tokens: list[Token]) -> None: + self._process_outline("left", name, tokens) + + def process_keyline(self, name: str, tokens: list[Token]) -> None: + if not tokens: + return + if len(tokens) > 3: + self.error(name, tokens[0], keyline_help_text()) + keyline_style = "none" + keyline_color = Color.parse("green") + keyline_alpha = 1.0 + for token in tokens: + if token.name == "color": + try: + keyline_color = Color.parse(token.value) + except Exception as error: + self.error( + name, + token, + color_property_help_text( + name, context="css", error=error, value=token.value + ), + ) + elif token.name == "token": + try: + keyline_color = Color.parse(token.value) + except Exception: + keyline_style = token.value + if keyline_style not in VALID_KEYLINE: + self.error(name, token, keyline_help_text()) + + elif token.name == "scalar": + alpha_scalar = Scalar.parse(token.value) + if alpha_scalar.unit != Unit.PERCENT: + self.error(name, token, "alpha must be given as a percentage.") + keyline_alpha = alpha_scalar.value / 100.0 + + self.styles._rules["keyline"] = ( + keyline_style, + keyline_color.multiply_alpha(keyline_alpha), + ) + + def process_offset(self, name: str, tokens: list[Token]) -> None: + def offset_error(name: str, token: Token) -> None: + self.error(name, token, offset_property_help_text(context="css")) + + if not tokens: + return + if len(tokens) != 2: + offset_error(name, tokens[0]) + else: + token1, token2 = tokens + + if token1.name not in ("scalar", "number"): + offset_error(name, token1) + if token2.name not in ("scalar", "number"): + offset_error(name, token2) + + scalar_x = Scalar.parse(token1.value, Unit.WIDTH) + scalar_y = Scalar.parse(token2.value, Unit.HEIGHT) + self.styles._rules["offset"] = ScalarOffset(scalar_x, scalar_y) + + def process_offset_x(self, name: str, tokens: list[Token]) -> None: + if not tokens: + return + if len(tokens) != 1: + self.error(name, tokens[0], offset_single_axis_help_text(name)) + else: + token = tokens[0] + if token.name not in ("scalar", "number"): + self.error(name, token, offset_single_axis_help_text(name)) + x = Scalar.parse(token.value, Unit.WIDTH) + y = self.styles.offset.y + self.styles._rules["offset"] = ScalarOffset(x, y) + + def process_offset_y(self, name: str, tokens: list[Token]) -> None: + if not tokens: + return + if len(tokens) != 1: + self.error(name, tokens[0], offset_single_axis_help_text(name)) + else: + token = tokens[0] + if token.name not in ("scalar", "number"): + self.error(name, token, offset_single_axis_help_text(name)) + y = Scalar.parse(token.value, Unit.HEIGHT) + x = self.styles.offset.x + self.styles._rules["offset"] = ScalarOffset(x, y) + + def process_position(self, name: str, tokens: list[Token]): + if not tokens: + return + if len(tokens) != 1: + self.error(name, tokens[0], offset_single_axis_help_text(name)) + else: + token = tokens[0] + if token.value not in VALID_POSITION: + self.error(name, tokens[0], position_help_text(name)) + self.styles._rules["position"] = token.value + + def process_layout(self, name: str, tokens: list[Token]) -> None: + from memray._vendor.textual.layouts.factory import MissingLayout, get_layout + + if tokens: + if len(tokens) != 1: + self.error( + name, tokens[0], layout_property_help_text(name, context="css") + ) + else: + value = tokens[0].value + layout_name = value + try: + self.styles._rules["layout"] = get_layout(layout_name) + except MissingLayout: + self.error( + name, + tokens[0], + layout_property_help_text(name, context="css"), + ) + + def process_color(self, name: str, tokens: list[Token]) -> None: + """Processes a simple color declaration.""" + name = name.replace("-", "_") + + color: Color | None = None + alpha: float | None = None + + self.styles._rules[f"auto_{name}"] = False # type: ignore + for token in tokens: + if ( + "background" not in name + and token.name == "token" + and token.value == "auto" + ): + self.styles._rules[f"auto_{name}"] = True # type: ignore + elif token.name == "scalar": + alpha_scalar = Scalar.parse(token.value) + if alpha_scalar.unit != Unit.PERCENT: + self.error(name, token, "alpha must be given as a percentage.") + alpha = alpha_scalar.value / 100.0 + + elif token.name in ("color", "token"): + try: + color = Color.parse(token.value) + except Exception as error: + self.error( + name, + token, + color_property_help_text( + name, context="css", error=error, value=token.value + ), + ) + else: + self.error( + name, + token, + color_property_help_text(name, context="css", value=token.value), + ) + + if color is not None or alpha is not None: + if alpha is not None: + color = (color or Color(255, 255, 255)).multiply_alpha(alpha) + self.styles._rules[name] = color # type: ignore + + process_tint = process_color + process_background = process_color + process_background_tint = process_color + process_scrollbar_color = process_color + process_scrollbar_color_hover = process_color + process_scrollbar_color_active = process_color + process_scrollbar_corner_color = process_color + process_scrollbar_background = process_color + process_scrollbar_background_hover = process_color + process_scrollbar_background_active = process_color + + def process_scrollbar_visibility(self, name: str, tokens: list[Token]) -> None: + """Process scrollbar visibility rules.""" + self.styles._rules["scrollbar_visibility"] = cast( + ScrollbarVisibility, + self._process_enum(name, tokens, VALID_SCROLLBAR_VISIBILITY), + ) + + process_link_color = process_color + process_link_background = process_color + process_link_color_hover = process_color + process_link_background_hover = process_color + + process_border_title_color = process_color + process_border_title_background = process_color + process_border_subtitle_color = process_color + process_border_subtitle_background = process_color + + def process_text_style(self, name: str, tokens: list[Token]) -> None: + for token in tokens: + value = token.value + if value not in VALID_STYLE_FLAGS: + self.error( + name, + token, + style_flags_property_help_text(name, value, context="css"), + ) + + style_definition = " ".join(token.value for token in tokens) + self.styles._rules[name.replace("-", "_")] = style_definition # type: ignore + + process_link_style = process_text_style + process_link_style_hover = process_text_style + + process_border_title_style = process_text_style + process_border_subtitle_style = process_text_style + + def process_text_align(self, name: str, tokens: list[Token]) -> None: + """Process a text-align declaration""" + if not tokens: + return + + if len(tokens) > 1 or tokens[0].value not in VALID_TEXT_ALIGN: + self.error( + name, + tokens[0], + text_align_help_text(), + ) + + self.styles._rules["text_align"] = tokens[0].value # type: ignore + + def process_dock(self, name: str, tokens: list[Token]) -> None: + if not tokens: + return + + if len(tokens) > 1 or tokens[0].value not in VALID_EDGE: + self.error( + name, + tokens[0], + dock_property_help_text(name, context="css"), + ) + + dock_value = tokens[0].value + self.styles._rules["dock"] = dock_value + + def process_split(self, name: str, tokens: list[Token]) -> None: + if not tokens: + return + + if len(tokens) > 1 or tokens[0].value not in VALID_EDGE: + self.error( + name, + tokens[0], + split_property_help_text(name, context="css"), + ) + + split_value = tokens[0].value + self.styles._rules["split"] = split_value + + def process_layer(self, name: str, tokens: list[Token]) -> None: + if len(tokens) > 1: + self.error(name, tokens[1], "unexpected tokens in dock-edge declaration") + self.styles._rules["layer"] = tokens[0].value + + def process_layers(self, name: str, tokens: list[Token]) -> None: + layers: list[str] = [] + for token in tokens: + if token.name not in {"token", "string"}: + self.error(name, token, f"{token.name} not expected here") + layers.append(token.value) + self.styles._rules["layers"] = tuple(layers) + + def process_transition(self, name: str, tokens: list[Token]) -> None: + transitions: dict[str, Transition] = {} + + def make_groups() -> Iterable[list[Token]]: + """Batch tokens into comma-separated groups.""" + group: list[Token] = [] + for token in tokens: + if token.name == "comma": + if group: + yield group + group = [] + else: + group.append(token) + if group: + yield group + + valid_duration_token_names = ("duration", "number") + for tokens in make_groups(): + css_property = "" + duration = 1.0 + easing = "linear" + delay = 0.0 + + try: + iter_tokens = iter(tokens) + token = next(iter_tokens) + if token.name != "token": + self.error(name, token, "expected property") + + css_property = token.value + token = next(iter_tokens) + if token.name not in valid_duration_token_names: + self.error(name, token, "expected duration or number") + try: + duration = _duration_as_seconds(token.value) + except ScalarError as error: + self.error(name, token, str(error)) + + token = next(iter_tokens) + if token.name != "token": + self.error(name, token, "easing function expected") + + if token.value not in EASING: + self.error( + name, + token, + f"expected easing function; found {token.value!r}", + ) + easing = token.value + + token = next(iter_tokens) + if token.name not in valid_duration_token_names: + self.error(name, token, "expected duration or number") + try: + delay = _duration_as_seconds(token.value) + except ScalarError as error: + self.error(name, token, str(error)) + except StopIteration: + pass + transitions[css_property] = Transition(duration, easing, delay) + + self.styles._rules["transitions"] = transitions + + def process_align(self, name: str, tokens: list[Token]) -> None: + def align_error(name, token): + self.error(name, token, align_help_text()) + + if len(tokens) != 2: + self.error(name, tokens[0], align_help_text()) + + token_horizontal = tokens[0] + token_vertical = tokens[1] + + if token_horizontal.name != "token": + align_error(name, token_horizontal) + elif token_horizontal.value not in VALID_ALIGN_HORIZONTAL: + align_error(name, token_horizontal) + + if token_vertical.name != "token": + align_error(name, token_vertical) + elif token_vertical.value not in VALID_ALIGN_VERTICAL: + align_error(name, token_horizontal) + + name = name.replace("-", "_") + self.styles._rules[f"{name}_horizontal"] = token_horizontal.value # type: ignore + self.styles._rules[f"{name}_vertical"] = token_vertical.value # type: ignore + + self._distribute_importance(name, ("horizontal", "vertical")) + + def process_align_horizontal(self, name: str, tokens: list[Token]) -> None: + try: + value = self._process_enum(name, tokens, VALID_ALIGN_HORIZONTAL) + except StyleValueError: + self.error( + name, + tokens[0], + string_enum_help_text(name, VALID_ALIGN_HORIZONTAL, context="css"), + ) + else: + self.styles._rules[name.replace("-", "_")] = value # type: ignore + + def process_align_vertical(self, name: str, tokens: list[Token]) -> None: + try: + value = self._process_enum(name, tokens, VALID_ALIGN_VERTICAL) + except StyleValueError: + self.error( + name, + tokens[0], + string_enum_help_text(name, VALID_ALIGN_VERTICAL, context="css"), + ) + else: + self.styles._rules[name.replace("-", "_")] = value # type: ignore + + process_content_align = process_align + process_content_align_horizontal = process_align_horizontal + process_content_align_vertical = process_align_vertical + + process_border_title_align = process_align_horizontal + process_border_subtitle_align = process_align_horizontal + + def process_scrollbar_gutter(self, name: str, tokens: list[Token]) -> None: + try: + value = self._process_enum(name, tokens, VALID_SCROLLBAR_GUTTER) + except StyleValueError: + self.error( + name, + tokens[0], + string_enum_help_text(name, VALID_SCROLLBAR_GUTTER, context="css"), + ) + else: + self.styles._rules[name.replace("-", "_")] = value # type: ignore + + def process_scrollbar_size(self, name: str, tokens: list[Token]) -> None: + def scrollbar_size_error(name: str, token: Token) -> None: + self.error(name, token, scrollbar_size_property_help_text(context="css")) + + if not tokens: + return + if len(tokens) != 2: + scrollbar_size_error(name, tokens[0]) + else: + token1, token2 = tokens + + if token1.name != "number" or not token1.value.isdigit(): + scrollbar_size_error(name, token1) + if token2.name != "number" or not token2.value.isdigit(): + scrollbar_size_error(name, token2) + + horizontal = int(token1.value) + vertical = int(token2.value) + self.styles._rules["scrollbar_size_horizontal"] = horizontal + self.styles._rules["scrollbar_size_vertical"] = vertical + self._distribute_importance("scrollbar_size", ("horizontal", "vertical")) + + def process_scrollbar_size_vertical(self, name: str, tokens: list[Token]) -> None: + if not tokens: + return + if len(tokens) != 1: + self.error(name, tokens[0], scrollbar_size_single_axis_help_text(name)) + else: + token = tokens[0] + if token.name != "number" or not token.value.isdigit(): + self.error(name, token, scrollbar_size_single_axis_help_text(name)) + value = int(token.value) + self.styles._rules["scrollbar_size_vertical"] = value + + def process_scrollbar_size_horizontal(self, name: str, tokens: list[Token]) -> None: + if not tokens: + return + if len(tokens) != 1: + self.error(name, tokens[0], scrollbar_size_single_axis_help_text(name)) + else: + token = tokens[0] + if token.name != "number" or not token.value.isdigit(): + self.error(name, token, scrollbar_size_single_axis_help_text(name)) + value = int(token.value) + self.styles._rules["scrollbar_size_horizontal"] = value + + def _process_grid_rows_or_columns(self, name: str, tokens: list[Token]) -> None: + scalars: list[Scalar] = [] + percent_unit = Unit.WIDTH if name == "grid-columns" else Unit.HEIGHT + for token in tokens: + if token.name == "number": + scalars.append(Scalar.from_number(float(token.value))) + elif token.name == "scalar": + scalars.append(Scalar.parse(token.value, percent_unit=percent_unit)) + elif token.name == "token" and token.value == "auto": + scalars.append(Scalar.parse("auto")) + else: + self.error( + name, + token, + table_rows_or_columns_help_text(name, token.value, context="css"), + ) + self.styles._rules[name.replace("-", "_")] = scalars # type: ignore + + process_grid_rows = _process_grid_rows_or_columns + process_grid_columns = _process_grid_rows_or_columns + + def _process_integer(self, name: str, tokens: list[Token]) -> None: + if not tokens: + return + if len(tokens) != 1: + self.error(name, tokens[0], integer_help_text(name)) + else: + token = tokens[0] + if token.name != "number" or not token.value.isdigit(): + self.error(name, token, integer_help_text(name)) + value = int(token.value) + if value == 0: + self.error(name, token, integer_help_text(name)) + self.styles._rules[name.replace("-", "_")] = value # type: ignore + + process_grid_gutter_horizontal = _process_integer + process_grid_gutter_vertical = _process_integer + process_column_span = _process_integer + process_row_span = _process_integer + process_grid_size_columns = _process_integer + process_grid_size_rows = _process_integer + process_line_pad = _process_integer + + def process_grid_gutter(self, name: str, tokens: list[Token]) -> None: + if not tokens: + return + if len(tokens) == 1: + token = tokens[0] + if token.name != "number": + self.error(name, token, integer_help_text(name)) + value = max(0, int(token.value)) + self.styles._rules["grid_gutter_horizontal"] = value + self.styles._rules["grid_gutter_vertical"] = value + + elif len(tokens) == 2: + token = tokens[0] + if token.name != "number": + self.error(name, token, integer_help_text(name)) + value = max(0, int(token.value)) + self.styles._rules["grid_gutter_horizontal"] = value + token = tokens[1] + if token.name != "number": + self.error(name, token, integer_help_text(name)) + value = max(0, int(token.value)) + self.styles._rules["grid_gutter_vertical"] = value + + else: + self.error(name, tokens[0], "expected two integers here") + + def process_grid_size(self, name: str, tokens: list[Token]) -> None: + if not tokens: + return + if len(tokens) == 1: + token = tokens[0] + if token.name != "number": + self.error(name, token, integer_help_text(name)) + value = max(0, int(token.value)) + self.styles._rules["grid_size_columns"] = value + self.styles._rules["grid_size_rows"] = 0 + + elif len(tokens) == 2: + token = tokens[0] + if token.name != "number": + self.error(name, token, integer_help_text(name)) + value = max(0, int(token.value)) + self.styles._rules["grid_size_columns"] = value + token = tokens[1] + if token.name != "number": + self.error(name, token, integer_help_text(name)) + value = max(0, int(token.value)) + self.styles._rules["grid_size_rows"] = value + + else: + self.error(name, tokens[0], "expected two integers here") + + def process_overlay(self, name: str, tokens: list[Token]) -> None: + try: + value = self._process_enum(name, tokens, VALID_OVERLAY) + except StyleValueError: + self.error( + name, + tokens[0], + string_enum_help_text(name, VALID_OVERLAY, context="css"), + ) + else: + self.styles._rules[name] = value # type: ignore + + def process_constrain(self, name: str, tokens: list[Token]) -> None: + if len(tokens) == 1: + try: + value = self._process_enum(name, tokens, VALID_CONSTRAIN) + except StyleValueError: + self.error( + name, + tokens[0], + string_enum_help_text(name, VALID_CONSTRAIN, context="css"), + ) + else: + self.styles._rules["constrain_x"] = value # type: ignore + self.styles._rules["constrain_y"] = value # type: ignore + elif len(tokens) == 2: + constrain_x, constrain_y = self._process_enum_multiple( + name, tokens, VALID_CONSTRAIN, 2 + ) + self.styles._rules["constrain_x"] = constrain_x # type: ignore + self.styles._rules["constrain_y"] = constrain_y # type: ignore + else: + self.error(name, tokens[0], "one or two values expected here") + + def process_constrain_x(self, name: str, tokens: list[Token]) -> None: + try: + value = self._process_enum(name, tokens, VALID_CONSTRAIN) + except StyleValueError: + self.error( + name, + tokens[0], + string_enum_help_text(name, VALID_CONSTRAIN, context="css"), + ) + else: + self.styles._rules[name] = value # type: ignore + + def process_constrain_y(self, name: str, tokens: list[Token]) -> None: + try: + value = self._process_enum(name, tokens, VALID_CONSTRAIN) + except StyleValueError: + self.error( + name, + tokens[0], + string_enum_help_text(name, VALID_CONSTRAIN, context="css"), + ) + else: + self.styles._rules[name] = value # type: ignore + + def process_hatch(self, name: str, tokens: list[Token]) -> None: + if not tokens: + return + character: str | None = None + color = TRANSPARENT + opacity = 1.0 + + if len(tokens) == 1 and tokens[0].value == "none": + self.styles._rules[name] = "none" + return + + if len(tokens) not in (2, 3): + self.error(name, tokens[0], "2 or 3 values expected here") + + character_token, color_token, *opacity_tokens = tokens + + if character_token.name == "token": + if character_token.value not in VALID_HATCH: + self.error( + name, + tokens[0], + string_enum_help_text(name, VALID_HATCH, context="css"), + ) + character = HATCHES[character_token.value] + elif character_token.name == "string": + character = character_token.value[1:-1] + if len(character) != 1: + self.error( + name, + character_token, + f"Hatch type requires a string of length 1; got {character_token.value}", + ) + if cell_len(character) != 1: + self.error( + name, + character_token, + f"Hatch type requires a string with a *cell length* of 1; got {character_token.value}", + ) + + if color_token.name in ("color", "token"): + try: + color = Color.parse(color_token.value) + except Exception as error: + self.error( + name, + color_token, + color_property_help_text( + name, context="css", error=error, value=color_token.value + ), + ) + else: + self.error( + name, color_token, f"Expected a color; found {color_token.value!r}" + ) + + if opacity_tokens: + opacity_token = opacity_tokens[0] + if opacity_token.name == "scalar": + opacity_scalar = opacity = Scalar.parse(opacity_token.value) + if opacity_scalar.unit != Unit.PERCENT: + self.error( + name, + opacity_token, + "hatch alpha must be given as a percentage.", + ) + opacity = clamp(opacity_scalar.value / 100.0, 0, 1.0) + else: + self.error( + name, + opacity_token, + f"expected a percentage here; found {opacity_token.value!r}", + ) + + self.styles._rules[name] = (character or " ", color.multiply_alpha(opacity)) + + def process_expand(self, name: str, tokens: list[Token]): + if not tokens: + return + if len(tokens) != 1: + self.error(name, tokens[0], offset_single_axis_help_text(name)) + else: + token = tokens[0] + if token.value not in VALID_EXPAND: + self.error(name, tokens[0], expand_help_text(name)) + self.styles._rules["expand"] = token.value + + def process_pointer(self, name: str, tokens: list[Token]) -> None: + for token in tokens: + name, value, _, _, location, _ = token + if name == "token": + value = value.lower() + if value in VALID_POINTER: + self.styles._rules["pointer"] = value + else: + self.error( + name, + token, + string_enum_help_text( + "pointer", + valid_values=list(VALID_POINTER), + context="css", + ), + ) + + def _get_suggested_property_name_for_rule(self, rule_name: str) -> str | None: + """ + Returns a valid CSS property "Python" name, or None if no close matches could be found. + + Args: + rule_name: An invalid "Python-ised" CSS property (i.e. "offst_x" rather than "offst-x") + + Returns: + The closest valid "Python-ised" CSS property. + Returns `None` if no close matches could be found. + + Example: returns "background" for rule_name "bkgrund", "offset_x" for "ofset_x" + """ + processable_rules_name = [ + attr[8:] for attr in dir(self) if attr.startswith("process_") + ] + return get_suggestion(rule_name, processable_rules_name) diff --git a/src/memray/_vendor/textual/css/constants.py b/src/memray/_vendor/textual/css/constants.py new file mode 100644 index 0000000000..8242ecd942 --- /dev/null +++ b/src/memray/_vendor/textual/css/constants.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +import typing + +if typing.TYPE_CHECKING: + from typing_extensions import Final + +VALID_VISIBILITY: Final = {"visible", "hidden"} +VALID_DISPLAY: Final = {"block", "none"} +VALID_BORDER: Final = { + "ascii", + "blank", + "dashed", + "double", + "heavy", + "hidden", + "hkey", + "inner", + "none", + "outer", + "panel", + "round", + "solid", + "tall", + "tab", + "thick", + "block", + "vkey", + "wide", +} +VALID_EDGE: Final = {"top", "right", "bottom", "left", "none"} +VALID_LAYOUT: Final = {"vertical", "horizontal", "grid", "stream"} + +VALID_BOX_SIZING: Final = {"border-box", "content-box"} +VALID_OVERFLOW: Final = {"scroll", "hidden", "auto"} +VALID_ALIGN_HORIZONTAL: Final = {"left", "center", "right"} +VALID_ALIGN_VERTICAL: Final = {"top", "middle", "bottom"} +VALID_POSITION: Final = {"relative", "absolute"} +VALID_TEXT_ALIGN: Final = { + "start", + "end", + "left", + "right", + "center", + "justify", +} +VALID_SCROLLBAR_GUTTER: Final = {"auto", "stable"} +VALID_STYLE_FLAGS: Final = { + "b", + "blink", + "bold", + "dim", + "i", + "italic", + "none", + "not", + "o", + "overline", + "reverse", + "strike", + "u", + "underline", + "uu", +} +VALID_PSEUDO_CLASSES: Final = { + "ansi", + "blur", + "can-focus", + "dark", + "disabled", + "enabled", + "focus-within", + "focus", + "hover", + "inline", + "light", + "nocolor", + "first-of-type", + "last-of-type", + "first-child", + "last-child", + "odd", + "even", + "empty", +} +VALID_OVERLAY: Final = {"none", "screen"} +VALID_CONSTRAIN: Final = {"inflect", "inside", "none"} +VALID_KEYLINE: Final = {"none", "thin", "heavy", "double"} +VALID_HATCH: Final = {"left", "right", "cross", "vertical", "horizontal"} +VALID_TEXT_WRAP: Final = {"wrap", "nowrap"} +VALID_TEXT_OVERFLOW: Final = {"clip", "fold", "ellipsis"} +VALID_EXPAND: Final = {"greedy", "optimal"} +VALID_SCROLLBAR_VISIBILITY: Final = {"visible", "hidden"} +VALID_POINTER: Final = { + "alias", + "cell", + "copy", + "crosshair", + "default", + "e-resize", + "ew-resize", + "grab", + "grabbing", + "help", + "move", + "n-resize", + "ne-resize", + "nesw-resize", + "no-drop", + "not-allowed", + "ns-resize", + "nw-resize", + "nwse-resize", + "pointer", + "progress", + "s-resize", + "se-resize", + "sw-resize", + "text", + "vertical-text", + "w-resize", + "wait", + "zoom-in", + "zoom-out", +} + +HATCHES: Final = { + "left": "╲", + "right": "╱", + "cross": "╳", + "horizontal": "─", + "vertical": "│", +} diff --git a/src/memray/_vendor/textual/css/errors.py b/src/memray/_vendor/textual/css/errors.py new file mode 100644 index 0000000000..7db0507ac3 --- /dev/null +++ b/src/memray/_vendor/textual/css/errors.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from rich.console import Console, ConsoleOptions, RenderResult + +from memray._vendor.textual.css._help_renderables import HelpText +from memray._vendor.textual.css.tokenizer import Token, TokenError + + +class DeclarationError(Exception): + def __init__(self, name: str, token: Token, message: str | HelpText) -> None: + self.name = name + self.token = token + self.message = message + super().__init__(str(message)) + + +class StyleTypeError(TypeError): + pass + + +class UnresolvedVariableError(TokenError): + pass + + +class StyleValueError(ValueError): + """Raised when the value of a style property is not valid + + Attributes: + help_text: Optional HelpText to be rendered when this + error is raised. + """ + + def __init__(self, *args: object, help_text: HelpText | None = None): + super().__init__(*args) + self.help_text: HelpText | None = help_text + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + from rich.traceback import Traceback + + yield Traceback.from_exception(type(self), self, self.__traceback__) + if self.help_text is not None: + yield "" + yield self.help_text + yield "" + + +class StylesheetError(Exception): + pass diff --git a/src/memray/_vendor/textual/css/match.py b/src/memray/_vendor/textual/css/match.py new file mode 100644 index 0000000000..b2afb4c0c2 --- /dev/null +++ b/src/memray/_vendor/textual/css/match.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterable + +from memray._vendor.textual.css.model import CombinatorType, Selector, SelectorSet + +if TYPE_CHECKING: + from memray._vendor.textual.dom import DOMNode + + +def match(selector_sets: Iterable[SelectorSet], node: DOMNode) -> bool: + """Check if a given node matches any of the given selector sets. + + Args: + selector_sets: Iterable of selector sets. + node: DOM node. + + Returns: + True if the node matches the selector, otherwise False. + """ + return any( + _check_selectors(selector_set.selectors, node.css_path_nodes) + for selector_set in selector_sets + ) + + +def _check_selectors(selectors: list[Selector], css_path_nodes: list[DOMNode]) -> bool: + """Match a list of selectors against DOM nodes. + + Args: + selectors: A list of selectors. + css_path_nodes: The DOM nodes to check the selectors against. + + Returns: + True if any node in css_path_nodes matches a selector. + """ + + DESCENDENT = CombinatorType.DESCENDENT + + node = css_path_nodes[-1] + path_count = len(css_path_nodes) + selector_count = len(selectors) + + stack: list[tuple[int, int]] = [(0, 0)] + + push = stack.append + pop = stack.pop + selector_index = 0 + + while stack: + selector_index, node_index = stack[-1] + if selector_index == selector_count or node_index == path_count: + pop() + else: + path_node = css_path_nodes[node_index] + selector = selectors[selector_index] + if selector.combinator == DESCENDENT: + # Find a matching descendent + if selector.check(path_node): + if path_node is node and selector_index == selector_count - 1: + return True + stack[-1] = (selector_index + 1, node_index + selector.advance) + push((selector_index, node_index + 1)) + else: + stack[-1] = (selector_index, node_index + 1) + else: + # Match the next node + if selector.check(path_node): + if path_node is node and selector_index == selector_count - 1: + return True + stack[-1] = (selector_index + 1, node_index + selector.advance) + else: + pop() + return False diff --git a/src/memray/_vendor/textual/css/model.py b/src/memray/_vendor/textual/css/model.py new file mode 100644 index 0000000000..119019786f --- /dev/null +++ b/src/memray/_vendor/textual/css/model.py @@ -0,0 +1,306 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from functools import partial +from typing import TYPE_CHECKING, Iterable + +import rich.repr + +from memray._vendor.textual.css._help_renderables import HelpText +from memray._vendor.textual.css.styles import Styles +from memray._vendor.textual.css.tokenize import Token +from memray._vendor.textual.css.types import Specificity3 + +if TYPE_CHECKING: + from typing import Callable + + from typing_extensions import Self + + from memray._vendor.textual.dom import DOMNode + + +class SelectorType(Enum): + """Type of selector.""" + + UNIVERSAL = 1 + """i.e. * operator""" + TYPE = 2 + """A CSS type, e.g Label""" + CLASS = 3 + """CSS class, e.g. .loaded""" + ID = 4 + """CSS ID, e.g. #main""" + NESTED = 5 + """Placeholder for nesting operator, i.e &""" + + +class CombinatorType(Enum): + """Type of combinator.""" + + SAME = 1 + """Selector is combined with previous selector""" + DESCENDENT = 2 + """Selector is a descendant of the previous selector""" + CHILD = 3 + """Selector is an immediate child of the previous selector""" + + +def _check_universal(name: str, node: DOMNode) -> bool: + """Check node matches universal selector. + + Args: + name: Selector name. + node: A DOM node. + + Returns: + `True` if the selector matches. + """ + return not node.has_class("-textual-system") + + +def _check_type(name: str, node: DOMNode) -> bool: + """Check node matches a type selector. + + Args: + name: Selector name. + node: A DOM node. + + Returns: + `True` if the selector matches. + """ + return name in node._css_type_names + + +def _check_class(name: str, node: DOMNode) -> bool: + """Check node matches a class selector. + + Args: + name: Selector name. + node: A DOM node. + + Returns: + `True` if the selector matches. + """ + return name in node._classes + + +def _check_id(name: str, node: DOMNode) -> bool: + """Check node matches an ID selector. + + Args: + name: Selector name. + node: A DOM node. + + Returns: + `True` if the selector matches. + """ + return node.id == name + + +_CHECKS = { + SelectorType.UNIVERSAL: _check_universal, + SelectorType.TYPE: _check_type, + SelectorType.CLASS: _check_class, + SelectorType.ID: _check_id, + SelectorType.NESTED: _check_universal, +} + + +@dataclass +class Selector: + """Represents a CSS selector. + + Some examples of selectors: + + * + Header.title + App > Content + """ + + name: str + combinator: CombinatorType = CombinatorType.DESCENDENT + type: SelectorType = SelectorType.TYPE + pseudo_classes: set[str] = field(default_factory=set) + specificity: Specificity3 = field(default_factory=lambda: (0, 0, 0)) + advance: int = 1 + + def __post_init__(self) -> None: + self._check: Callable[[DOMNode], bool] = partial(_CHECKS[self.type], self.name) + + @property + def css(self) -> str: + """Rebuilds the selector as it would appear in CSS.""" + pseudo_suffix = "".join(f":{name}" for name in sorted(self.pseudo_classes)) + if self.type == SelectorType.UNIVERSAL: + return "*" + elif self.type == SelectorType.TYPE: + return f"{self.name}{pseudo_suffix}" + elif self.type == SelectorType.CLASS: + return f".{self.name}{pseudo_suffix}" + else: + return f"#{self.name}{pseudo_suffix}" + + def _add_pseudo_class(self, pseudo_class: str) -> None: + """Adds a pseudo class and updates specificity. + + Args: + pseudo_class: Name of pseudo class. + """ + self.pseudo_classes.add(pseudo_class) + specificity1, specificity2, specificity3 = self.specificity + self.specificity = (specificity1, specificity2 + 1, specificity3) + + def check(self, node: DOMNode) -> bool: + """Check if a given node matches the selector. + + Args: + node: A DOM node. + + Returns: + True if the selector matches, otherwise False. + """ + return self._check(node) and ( + node.has_pseudo_classes(self.pseudo_classes) + if self.pseudo_classes + else True + ) + + +@dataclass +class Declaration: + """A single CSS declaration (not yet processed).""" + + token: Token + name: str + tokens: list[Token] = field(default_factory=list) + + +@rich.repr.auto(angular=True) +@dataclass +class SelectorSet: + """A set of selectors associated with a rule set.""" + + selectors: list[Selector] = field(default_factory=list) + specificity: Specificity3 = (0, 0, 0) + + def __post_init__(self) -> None: + SAME = CombinatorType.SAME + for selector, next_selector in zip(self.selectors, self.selectors[1:]): + selector.advance = int(next_selector.combinator != SAME) + + @property + def css(self) -> str: + return RuleSet._selector_to_css(self.selectors) + + @property + def is_simple(self) -> bool: + """Are all the selectors simple (i.e. only dependent on static DOM state).""" + simple_types = {SelectorType.ID, SelectorType.TYPE} + return all( + (selector.type in simple_types and not selector.pseudo_classes) + for selector in self.selectors + ) + + def __rich_repr__(self) -> rich.repr.Result: + selectors = RuleSet._selector_to_css(self.selectors) + yield selectors + yield None, self.specificity + + def _total_specificity(self) -> Self: + """Calculate total specificity of selectors. + + Returns: + Self. + """ + id_total = class_total = type_total = 0 + for selector in self.selectors: + _id, _class, _type = selector.specificity + id_total += _id + class_total += _class + type_total += _type + self.specificity = (id_total, class_total, type_total) + return self + + @classmethod + def from_selectors(cls, selectors: list[list[Selector]]) -> Iterable[SelectorSet]: + for selector_list in selectors: + id_total = class_total = type_total = 0 + for selector in selector_list: + _id, _class, _type = selector.specificity + id_total += _id + class_total += _class + type_total += _type + yield SelectorSet(selector_list, (id_total, class_total, type_total)) + + +@dataclass +class RuleSet: + selector_set: list[SelectorSet] = field(default_factory=list) + styles: Styles = field(default_factory=Styles) + errors: list[tuple[Token, str | HelpText]] = field(default_factory=list) + + is_default_rules: bool = False + tie_breaker: int = 0 + selector_names: set[str] = field(default_factory=set) + pseudo_classes: set[str] = field(default_factory=set) + + def __hash__(self): + return id(self) + + @classmethod + def _selector_to_css(cls, selectors: list[Selector]) -> str: + tokens: list[str] = [] + for selector in selectors: + if selector.combinator == CombinatorType.DESCENDENT: + tokens.append(" ") + elif selector.combinator == CombinatorType.CHILD: + tokens.append(" > ") + tokens.append(selector.css) + + return "".join(tokens).strip() + + @property + def selectors(self): + return ", ".join( + self._selector_to_css(selector_set.selectors) + for selector_set in self.selector_set + ) + + @property + def css(self) -> str: + """Generate the CSS this RuleSet + + Returns: + A string containing CSS code. + """ + declarations = "\n".join(f" {line}" for line in self.styles.css_lines) + css = f"{self.selectors} {{\n{declarations}\n}}" + return css + + def _post_parse(self) -> None: + """Called after the RuleSet is parsed.""" + # Build a set of the class names that have been updated + + class_type = SelectorType.CLASS + id_type = SelectorType.ID + type_type = SelectorType.TYPE + universal_type = SelectorType.UNIVERSAL + + add_selector = self.selector_names.add + add_pseudo_classes = self.pseudo_classes.update + + for selector_set in self.selector_set: + for selector in selector_set.selectors: + add_pseudo_classes(selector.pseudo_classes) + + selector = selector_set.selectors[-1] + selector_type = selector.type + if selector_type == universal_type: + add_selector("*") + elif selector_type == type_type: + add_selector(selector.name) + elif selector_type == class_type: + add_selector(f".{selector.name}") + elif selector_type == id_type: + add_selector(f"#{selector.name}") diff --git a/src/memray/_vendor/textual/css/parse.py b/src/memray/_vendor/textual/css/parse.py new file mode 100644 index 0000000000..000637356d --- /dev/null +++ b/src/memray/_vendor/textual/css/parse.py @@ -0,0 +1,488 @@ +from __future__ import annotations + +import dataclasses +import re +from functools import lru_cache +from typing import Iterable, Iterator, NoReturn + +from memray._vendor.textual.css._help_renderables import HelpText +from memray._vendor.textual.css._styles_builder import DeclarationError, StylesBuilder +from memray._vendor.textual.css.errors import UnresolvedVariableError +from memray._vendor.textual.css.model import ( + CombinatorType, + Declaration, + RuleSet, + Selector, + SelectorSet, + SelectorType, +) +from memray._vendor.textual.css.styles import Styles +from memray._vendor.textual.css.tokenize import ( + IDENTIFIER, + Token, + tokenize, + tokenize_declarations, + tokenize_values, +) +from memray._vendor.textual.css.tokenizer import ReferencedBy, UnexpectedEnd +from memray._vendor.textual.css.types import CSSLocation, Specificity3 +from memray._vendor.textual.suggestions import get_suggestion + +SELECTOR_MAP: dict[str, tuple[SelectorType, Specificity3]] = { + "selector": (SelectorType.TYPE, (0, 0, 1)), + "selector_start": (SelectorType.TYPE, (0, 0, 1)), + "selector_class": (SelectorType.CLASS, (0, 1, 0)), + "selector_start_class": (SelectorType.CLASS, (0, 1, 0)), + "selector_id": (SelectorType.ID, (1, 0, 0)), + "selector_start_id": (SelectorType.ID, (1, 0, 0)), + "selector_universal": (SelectorType.UNIVERSAL, (0, 0, 0)), + "selector_start_universal": (SelectorType.UNIVERSAL, (0, 0, 0)), + "nested": (SelectorType.NESTED, (0, 0, 0)), +} + +RE_ID_SELECTOR = re.compile("#" + IDENTIFIER) + + +@lru_cache(maxsize=128) +def is_id_selector(selector: str) -> bool: + """Is the selector a single ID selector, i.e. "#foo"? + + Args: + selector: A CSS selector. + + Returns: + `True` if the selector is a simple ID selector, otherwise `False`. + """ + return RE_ID_SELECTOR.fullmatch(selector) is not None + + +def _add_specificity( + specificity1: Specificity3, specificity2: Specificity3 +) -> Specificity3: + """Add specificity tuples together. + + Args: + specificity1: Specificity triple. + specificity2: Specificity triple. + + Returns: + Combined specificity. + """ + + a1, b1, c1 = specificity1 + a2, b2, c2 = specificity2 + return (a1 + a2, b1 + b2, c1 + c2) + + +@lru_cache(maxsize=1024) +def parse_selectors(css_selectors: str) -> tuple[SelectorSet, ...]: + if not css_selectors.strip(): + return () + tokens = iter(tokenize(css_selectors, ("", ""))) + + get_selector = SELECTOR_MAP.get + combinator: CombinatorType | None = CombinatorType.DESCENDENT + selectors: list[Selector] = [] + rule_selectors: list[list[Selector]] = [] + + while True: + try: + token = next(tokens, None) + except UnexpectedEnd: + break + if token is None: + break + token_name = token.name + + if token_name == "pseudo_class": + selectors[-1]._add_pseudo_class(token.value.lstrip(":")) + elif token_name == "whitespace": + if combinator is None or combinator == CombinatorType.SAME: + combinator = CombinatorType.DESCENDENT + elif token_name == "new_selector": + rule_selectors.append(selectors[:]) + selectors.clear() + combinator = None + elif token_name == "declaration_set_start": + break + elif token_name == "combinator_child": + combinator = CombinatorType.CHILD + else: + _selector, specificity = get_selector( + token_name, (SelectorType.TYPE, (0, 0, 0)) + ) + selectors.append( + Selector( + name=token.value.lstrip(".#"), + combinator=combinator or CombinatorType.DESCENDENT, + type=_selector, + specificity=specificity, + ) + ) + combinator = CombinatorType.SAME + if selectors: + rule_selectors.append(selectors[:]) + + selector_set = tuple(SelectorSet.from_selectors(rule_selectors)) + return selector_set + + +def parse_rule_set( + scope: str, + tokens: Iterator[Token], + token: Token, + is_default_rules: bool = False, + tie_breaker: int = 0, +) -> Iterable[RuleSet]: + get_selector = SELECTOR_MAP.get + combinator: CombinatorType | None = CombinatorType.DESCENDENT + selectors: list[Selector] = [] + rule_selectors: list[list[Selector]] = [] + styles_builder = StylesBuilder() + + while True: + if token.name == "pseudo_class": + selectors[-1]._add_pseudo_class(token.value.lstrip(":")) + elif token.name == "whitespace": + if combinator is None or combinator == CombinatorType.SAME: + combinator = CombinatorType.DESCENDENT + elif token.name == "new_selector": + rule_selectors.append(selectors[:]) + selectors.clear() + combinator = None + elif token.name == "declaration_set_start": + break + elif token.name == "combinator_child": + combinator = CombinatorType.CHILD + else: + _selector, specificity = get_selector( + token.name, (SelectorType.TYPE, (0, 0, 0)) + ) + selectors.append( + Selector( + name=token.value.lstrip(".#"), + combinator=combinator or CombinatorType.DESCENDENT, + type=_selector, + specificity=specificity, + ) + ) + combinator = CombinatorType.SAME + + token = next(tokens) + + if selectors: + if scope and selectors[0].name != scope: + scope_selector, scope_specificity = get_selector( + scope, (SelectorType.TYPE, (0, 0, 0)) + ) + selectors.insert( + 0, + Selector( + name=scope, + combinator=CombinatorType.DESCENDENT, + type=scope_selector, + specificity=scope_specificity, + ), + ) + rule_selectors.append(selectors[:]) + + declaration = Declaration(token, "") + errors: list[tuple[Token, str | HelpText]] = [] + nested_rules: list[RuleSet] = [] + + while True: + token = next(tokens) + token_name = token.name + if token_name in ("whitespace", "declaration_end"): + continue + if token_name in { + "selector_start_id", + "selector_start_class", + "selector_start_universal", + "selector_start", + "nested", + }: + recursive_parse: list[RuleSet] = list( + parse_rule_set( + "", + tokens, + token, + is_default_rules=is_default_rules, + tie_breaker=tie_breaker, + ) + ) + + def combine_selectors( + selectors1: list[Selector], selectors2: list[Selector] + ) -> list[Selector]: + """Combine lists of selectors together, processing any nesting. + + Args: + selectors1: List of selectors. + selectors2: Second list of selectors. + + Returns: + Combined selectors. + """ + if selectors2 and selectors2[0].type == SelectorType.NESTED: + final_selector = selectors1[-1] + nested_selector = selectors2[0] + merged_selector = dataclasses.replace( + final_selector, + pseudo_classes=( + final_selector.pseudo_classes + | nested_selector.pseudo_classes + ), + specificity=_add_specificity( + final_selector.specificity, nested_selector.specificity + ), + ) + return [*selectors1[:-1], merged_selector, *selectors2[1:]] + else: + return selectors1 + selectors2 + + for rule_selector in rule_selectors: + for rule_set in recursive_parse: + nested_rule_set = RuleSet( + [ + SelectorSet( + combine_selectors( + rule_selector, recursive_selectors.selectors + ) + )._total_specificity() + for recursive_selectors in rule_set.selector_set + ], + rule_set.styles, + rule_set.errors, + rule_set.is_default_rules, + rule_set.tie_breaker + tie_breaker, + ) + nested_rules.append(nested_rule_set) + continue + if token_name == "declaration_name": + try: + styles_builder.add_declaration(declaration) + except DeclarationError as error: + errors.append((error.token, error.message)) + declaration = Declaration(token, "") + declaration.name = token.value.rstrip(":") + elif token_name == "declaration_set_end": + break + else: + declaration.tokens.append(token) + + try: + styles_builder.add_declaration(declaration) + except DeclarationError as error: + errors.append((error.token, error.message)) + + rule_set = RuleSet( + list(SelectorSet.from_selectors(rule_selectors)), + styles_builder.styles, + errors, + is_default_rules=is_default_rules, + tie_breaker=tie_breaker, + ) + + rule_set._post_parse() + yield rule_set + + for nested_rule_set in nested_rules: + nested_rule_set._post_parse() + yield nested_rule_set + + +def parse_declarations(css: str, read_from: CSSLocation) -> Styles: + """Parse declarations and return a Styles object. + + Args: + css: String containing CSS. + read_from: The location where the CSS was read from. + + Returns: + A styles object. + """ + + tokens = iter(tokenize_declarations(css, read_from)) + styles_builder = StylesBuilder() + + declaration: Declaration | None = None + errors: list[tuple[Token, str | HelpText]] = [] + while True: + token = next(tokens, None) + if token is None: + break + token_name = token.name + if token_name in ("whitespace", "declaration_end", "eof"): + continue + if token_name == "declaration_name": + if declaration: + try: + styles_builder.add_declaration(declaration) + except DeclarationError as error: + errors.append((error.token, error.message)) + raise + declaration = Declaration(token, "") + declaration.name = token.value.rstrip(":") + elif token_name == "declaration_set_end": + break + else: + if declaration: + declaration.tokens.append(token) + + if declaration: + try: + styles_builder.add_declaration(declaration) + except DeclarationError as error: + errors.append((error.token, error.message)) + raise + + return styles_builder.styles + + +def _unresolved(variable_name: str, variables: Iterable[str], token: Token) -> NoReturn: + """Raise a TokenError regarding an unresolved variable. + + Args: + variable_name: A variable name. + variables: Possible choices used to generate suggestion. + token: The Token. + + Raises: + UnresolvedVariableError: Always raises a TokenError. + """ + message = f"reference to undefined variable '${variable_name}'" + suggested_variable = get_suggestion(variable_name, list(variables)) + if suggested_variable: + message += f"; did you mean '${suggested_variable}'?" + + raise UnresolvedVariableError( + token.read_from, + token.code, + token.start, + message, + end=token.end, + ) + + +def substitute_references( + tokens: Iterable[Token], css_variables: dict[str, list[Token]] | None = None +) -> Iterable[Token]: + """Replace variable references with values by substituting variable reference + tokens with the tokens representing their values. + + Args: + tokens: Iterator of Tokens which may contain tokens + with the name "variable_ref". + + Returns: + Yields Tokens such that any variable references (tokens where + token.name == "variable_ref") have been replaced with the tokens representing + the value. In other words, an Iterable of Tokens similar to the original input, + but with variables resolved. Substituted tokens will have their referenced_by + attribute populated with information about where the tokens are being substituted to. + """ + variables: dict[str, list[Token]] = css_variables.copy() if css_variables else {} + iter_tokens = iter(tokens) + + while True: + token = next(iter_tokens, None) + if token is None: + break + if token.name == "variable_name": + variable_name = token.value[1:-1] # Trim the $ and the :, i.e. "$x:" -> "x" + variable_tokens = variables.setdefault(variable_name, []) + yield token + + while True: + token = next(iter_tokens, None) + if token is not None and token.name == "whitespace": + yield token + else: + break + + # Store the tokens for any variable definitions, and substitute + # any variable references we encounter with them. + while True: + if not token: + break + elif token.name == "whitespace": + variable_tokens.append(token) + yield token + elif token.name == "variable_value_end": + yield token + break + # For variables referring to other variables + elif token.name == "variable_ref": + ref_name = token.value[1:] + if ref_name in variables: + reference_tokens = variables[ref_name] + variable_tokens.extend(reference_tokens) + ref_location = token.location + ref_length = len(token.value) + for _token in reference_tokens: + yield _token.with_reference( + ReferencedBy( + ref_name, ref_location, ref_length, token.code + ) + ) + else: + _unresolved(ref_name, variables.keys(), token) + else: + variable_tokens.append(token) + yield token + token = next(iter_tokens, None) + elif token.name == "variable_ref": + variable_name = token.value[1:] # Trim the $, so $x -> x + if variable_name in variables: + variable_tokens = variables[variable_name] + ref_location = token.location + ref_length = len(token.value) + ref_code = token.code + for _token in variable_tokens: + yield _token.with_reference( + ReferencedBy(variable_name, ref_location, ref_length, ref_code) + ) + else: + _unresolved(variable_name, variables.keys(), token) + else: + yield token + + +def parse( + scope: str, + css: str, + read_from: CSSLocation, + variables: dict[str, str] | None = None, + variable_tokens: dict[str, list[Token]] | None = None, + is_default_rules: bool = False, + tie_breaker: int = 0, +) -> Iterable[RuleSet]: + """Parse CSS by tokenizing it, performing variable substitution, + and generating rule sets from it. + + Args: + scope: CSS type name. + css: The input CSS. + read_from: The source location of the CSS. + variables: Substitution variables to substitute tokens for. + is_default_rules: True if the rules we're extracting are + default (i.e. in Widget.DEFAULT_CSS) rules. False if they're from user defined CSS. + """ + reference_tokens = tokenize_values(variables) if variables is not None else {} + if variable_tokens: + reference_tokens.update(variable_tokens) + + tokens = iter(substitute_references(tokenize(css, read_from), variable_tokens)) + while True: + token = next(tokens, None) + if token is None: + break + if token.name.startswith("selector_start"): + yield from parse_rule_set( + scope, + tokens, + token, + is_default_rules=is_default_rules, + tie_breaker=tie_breaker, + ) diff --git a/src/memray/_vendor/textual/css/query.py b/src/memray/_vendor/textual/css/query.py new file mode 100644 index 0000000000..054de67b00 --- /dev/null +++ b/src/memray/_vendor/textual/css/query.py @@ -0,0 +1,508 @@ +""" +This module contains the `DOMQuery` class and related objects. + +A DOMQuery is a set of DOM nodes returned by [query][textual.dom.DOMNode.query]. + +The set of nodes may be further refined with [filter][textual.css.query.DOMQuery.filter] and [exclude][textual.css.query.DOMQuery.exclude]. +Additional methods apply actions to all nodes in the query. + +!!! info + + If this sounds like JQuery, a (once) popular JS library, it is no coincidence. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, Iterable, Iterator, TypeVar, cast, overload + +import rich.repr + +from memray._vendor.textual._context import active_app +from memray._vendor.textual.await_remove import AwaitRemove +from memray._vendor.textual.css.errors import DeclarationError, TokenError +from memray._vendor.textual.css.match import match +from memray._vendor.textual.css.model import SelectorSet +from memray._vendor.textual.css.parse import parse_declarations, parse_selectors + +if TYPE_CHECKING: + from memray._vendor.textual.dom import DOMNode + from memray._vendor.textual.widget import Widget + + +class QueryError(Exception): + """Base class for a query related error.""" + + +class InvalidQueryFormat(QueryError): + """Query did not parse correctly.""" + + +class NoMatches(QueryError): + """No nodes matched the query.""" + + +class TooManyMatches(QueryError): + """Too many nodes matched the query.""" + + +class WrongType(QueryError): + """Query result was not of the correct type.""" + + +QueryType = TypeVar("QueryType", bound="Widget") +"""Type variable used to type generic queries.""" +ExpectType = TypeVar("ExpectType") +"""Type variable used to further restrict queries.""" + + +@rich.repr.auto(angular=True) +class DOMQuery(Generic[QueryType]): + __slots__ = ["_node", "_nodes", "_filters", "_excludes", "_deep"] + + def __init__( + self, + node: DOMNode, + *, + filter: str | None = None, + exclude: str | None = None, + deep: bool = True, + parent: DOMQuery | None = None, + ) -> None: + """Initialize a query object. + + !!! warning + + You won't need to construct this manually, as `DOMQuery` objects are returned by [query][textual.dom.DOMNode.query]. + + Args: + node: A DOM node. + filter: Query to filter children in the node. + exclude: Query to exclude children in the node. + deep: Query should be deep, i.e. recursive. + parent: The parent query, if this is the result of filtering another query. + + Raises: + InvalidQueryFormat: If the format of the query is invalid. + """ + _rich_traceback_omit = True + self._node = node + self._nodes: list[QueryType] | None = None + self._filters: list[tuple[SelectorSet, ...]] = ( + parent._filters.copy() if parent else [] + ) + self._excludes: list[tuple[SelectorSet, ...]] = ( + parent._excludes.copy() if parent else [] + ) + self._deep = deep + if filter is not None: + try: + self._filters.append(parse_selectors(filter)) + except TokenError: + # TODO: More helpful errors + raise InvalidQueryFormat(f"Unable to parse filter {filter!r} as query") + + if exclude is not None: + try: + self._excludes.append(parse_selectors(exclude)) + except TokenError: + raise InvalidQueryFormat(f"Unable to parse filter {filter!r} as query") + + @property + def node(self) -> DOMNode: + """The node being queried.""" + return self._node + + @property + def nodes(self) -> list[QueryType]: + """Lazily evaluate nodes.""" + from memray._vendor.textual.widget import Widget + + if self._nodes is None: + initial_nodes = list( + self._node.walk_children(Widget) if self._deep else self._node._nodes + ) + nodes = [ + node + for node in initial_nodes + if all(match(selector_set, node) for selector_set in self._filters) + ] + nodes = [ + node + for node in nodes + if not any(match(selector_set, node) for selector_set in self._excludes) + ] + self._nodes = cast("list[QueryType]", nodes) + return self._nodes + + def __len__(self) -> int: + return len(self.nodes) + + def __bool__(self) -> bool: + """True if non-empty, otherwise False.""" + return bool(self.nodes) + + def __iter__(self) -> Iterator[QueryType]: + return iter(self.nodes) + + def __reversed__(self) -> Iterator[QueryType]: + return reversed(self.nodes) + + if TYPE_CHECKING: + + @overload + def __getitem__(self, index: int) -> QueryType: ... + + @overload + def __getitem__(self, index: slice) -> list[QueryType]: ... + + def __getitem__(self, index: int | slice) -> QueryType | list[QueryType]: + return self.nodes[index] + + def __rich_repr__(self) -> rich.repr.Result: + try: + if self._filters: + yield ( + "query", + " AND ".join( + ",".join(selector.css for selector in selectors) + for selectors in self._filters + ), + ) + if self._excludes: + yield ( + "exclude", + " OR ".join( + ",".join(selector.css for selector in selectors) + for selectors in self._excludes + ), + ) + except AttributeError: + pass + + def filter(self, selector: str) -> DOMQuery[QueryType]: + """Filter this set by the given CSS selector. + + Args: + selector: A CSS selector. + + Returns: + New DOM Query. + """ + + return DOMQuery( + self.node, + filter=selector, + deep=self._deep, + parent=self, + ) + + def exclude(self, selector: str) -> DOMQuery[QueryType]: + """Exclude nodes that match a given selector. + + Args: + selector: A CSS selector. + + Returns: + New DOM query. + """ + return DOMQuery( + self.node, + exclude=selector, + deep=self._deep, + parent=self, + ) + + if TYPE_CHECKING: + + @overload + def first(self) -> QueryType: ... + + @overload + def first(self, expect_type: type[ExpectType]) -> ExpectType: ... + + def first( + self, expect_type: type[ExpectType] | None = None + ) -> QueryType | ExpectType: + """Get the *first* matching node. + + Args: + expect_type: Require matched node is of this type, + or None for any type. + + Raises: + WrongType: If the wrong type was found. + NoMatches: If there are no matching nodes in the query. + + Returns: + The matching Widget. + """ + _rich_traceback_omit = True + if self.nodes: + first = self.nodes[0] + if expect_type is not None: + if not isinstance(first, expect_type): + raise WrongType( + f"Query value is the wrong type; expected type {expect_type.__name__!r}, found {first}" + ) + return first + else: + raise NoMatches(f"No nodes match {self!r} on {self.node!r}") + + if TYPE_CHECKING: + + @overload + def only_one(self) -> QueryType: ... + + @overload + def only_one(self, expect_type: type[ExpectType]) -> ExpectType: ... + + def only_one( + self, expect_type: type[ExpectType] | None = None + ) -> QueryType | ExpectType: + """Get the *only* matching node. + + Args: + expect_type: Require matched node is of this type, + or None for any type. + + Raises: + WrongType: If the wrong type was found. + NoMatches: If no node matches the query. + TooManyMatches: If there is more than one matching node in the query. + + Returns: + The matching Widget. + """ + _rich_traceback_omit = True + # Call on first to get the first item. Here we'll use all of the + # testing and checking it provides. + the_one: ExpectType | QueryType = ( + self.first(expect_type) if expect_type is not None else self.first() + ) + try: + # Now see if we can access a subsequent item in the nodes. There + # should *not* be anything there, so we *should* get an + # IndexError. We *could* have just checked the length of the + # query, but the idea here is to do the check as cheaply as + # possible. "There can be only one!" -- Kurgan et al. + _ = self.nodes[1] + raise TooManyMatches( + "Call to only_one resulted in more than one matched node" + ) + except IndexError: + # The IndexError was got, that's a good thing in this case. So + # we return what we found. + pass + return the_one + + if TYPE_CHECKING: + + @overload + def last(self) -> QueryType: ... + + @overload + def last(self, expect_type: type[ExpectType]) -> ExpectType: ... + + def last( + self, expect_type: type[ExpectType] | None = None + ) -> QueryType | ExpectType: + """Get the *last* matching node. + + Args: + expect_type: Require matched node is of this type, + or None for any type. + + Raises: + WrongType: If the wrong type was found. + NoMatches: If there are no matching nodes in the query. + + Returns: + The matching Widget. + """ + if not self.nodes: + raise NoMatches(f"No nodes match {self!r} on dom{self.node!r}") + last = self.nodes[-1] + if expect_type is not None and not isinstance(last, expect_type): + raise WrongType( + f"Query value is the wrong type; expected type {expect_type.__name__!r}, found {last}" + ) + return last + + if TYPE_CHECKING: + + @overload + def results(self) -> Iterator[QueryType]: ... + + @overload + def results(self, filter_type: type[ExpectType]) -> Iterator[ExpectType]: ... + + def results( + self, filter_type: type[ExpectType] | None = None + ) -> Iterator[QueryType | ExpectType]: + """Get query results, optionally filtered by a given type. + + Args: + filter_type: A Widget class to filter results, + or None for no filter. + + Yields: + Iterator[Widget | ExpectType]: An iterator of Widget instances. + """ + if filter_type is None: + yield from self + else: + for node in self: + if isinstance(node, filter_type): + yield node + + def set_class(self, add: bool, *class_names: str) -> DOMQuery[QueryType]: + """Set the given class name(s) according to a condition. + + Args: + add: Add the classes if True, otherwise remove them. + + Returns: + Self. + """ + for node in self: + node.set_class(add, *class_names) + return self + + def set_classes(self, classes: str | Iterable[str]) -> DOMQuery[QueryType]: + """Set the classes on nodes to exactly the given set. + + Args: + classes: A string of space separated classes, or an iterable of class names. + + Returns: + Self. + """ + + if isinstance(classes, str): + for node in self: + node.set_classes(classes) + else: + class_names = list(classes) + for node in self: + node.set_classes(class_names) + return self + + def add_class(self, *class_names: str) -> DOMQuery[QueryType]: + """Add the given class name(s) to nodes.""" + for node in self: + node.add_class(*class_names) + return self + + def remove_class(self, *class_names: str) -> DOMQuery[QueryType]: + """Remove the given class names from the nodes.""" + for node in self: + node.remove_class(*class_names) + return self + + def toggle_class(self, *class_names: str) -> DOMQuery[QueryType]: + """Toggle the given class names from matched nodes.""" + for node in self: + node.toggle_class(*class_names) + return self + + def remove(self) -> AwaitRemove: + """Remove matched nodes from the DOM. + + Returns: + An awaitable object that waits for the widgets to be removed. + """ + app = active_app.get() + return app._prune(*self.nodes, parent=self._node) + + def set_styles( + self, css: str | None = None, **update_styles + ) -> DOMQuery[QueryType]: + """Set styles on matched nodes. + + Args: + css: CSS declarations to parser, or None. + """ + _rich_traceback_omit = True + + for node in self: + node.set_styles(**update_styles) + if css is not None: + try: + new_styles = parse_declarations(css, read_from=("set_styles", "")) + except DeclarationError as error: + raise DeclarationError(error.name, error.token, error.message) from None + for node in self: + node._inline_styles.merge(new_styles) + node.refresh(layout=True) + return self + + def refresh( + self, *, repaint: bool = True, layout: bool = False, recompose: bool = False + ) -> DOMQuery[QueryType]: + """Refresh matched nodes. + + Args: + repaint: Repaint node(s). + layout: Layout node(s). + recompose: Recompose node(s). + + Returns: + Query for chaining. + """ + for node in self: + node.refresh(repaint=repaint, layout=layout, recompose=recompose) + return self + + def focus(self) -> DOMQuery[QueryType]: + """Focus the first matching node that permits focus. + + Returns: + Query for chaining. + """ + for node in self: + if node.allow_focus(): + node.focus() + break + return self + + def blur(self) -> DOMQuery[QueryType]: + """Blur the first matching node that is focused. + + Returns: + Query for chaining. + """ + focused = self._node.screen.focused + if focused is not None: + nodes: list[Widget] = list(self) + if focused in nodes: + self._node.screen._reset_focus(focused, avoiding=nodes) + return self + + def set( + self, + display: bool | None = None, + visible: bool | None = None, + disabled: bool | None = None, + loading: bool | None = None, + ) -> DOMQuery[QueryType]: + """Sets common attributes on matched nodes. + + Args: + display: Set `display` attribute on nodes, or `None` for no change. + visible: Set `visible` attribute on nodes, or `None` for no change. + disabled: Set `disabled` attribute on nodes, or `None` for no change. + loading: Set `loading` attribute on nodes, or `None` for no change. + + Returns: + Query for chaining. + """ + for node in self: + if display is not None: + node.display = display + if visible is not None: + node.visible = visible + if disabled is not None: + node.disabled = disabled + if loading is not None: + node.loading = loading + return self diff --git a/src/memray/_vendor/textual/css/scalar.py b/src/memray/_vendor/textual/css/scalar.py new file mode 100644 index 0000000000..89300ce79a --- /dev/null +++ b/src/memray/_vendor/textual/css/scalar.py @@ -0,0 +1,381 @@ +from __future__ import annotations + +import re +from enum import Enum, unique +from fractions import Fraction +from functools import lru_cache +from typing import Iterable, NamedTuple + +import rich.repr + +from memray._vendor.textual.geometry import Offset, Size, clamp + + +class ScalarError(Exception): + """Base class for exceptions raised by the Scalar class.""" + + +class ScalarResolveError(ScalarError): + """Raised for errors resolving scalars (unlikely to occur in practice).""" + + +class ScalarParseError(ScalarError): + """Raised when a scalar couldn't be parsed from a string.""" + + +@unique +class Unit(Enum): + """Enumeration of the various units inherited from CSS.""" + + CELLS = 1 + FRACTION = 2 + PERCENT = 3 + WIDTH = 4 + HEIGHT = 5 + VIEW_WIDTH = 6 + VIEW_HEIGHT = 7 + AUTO = 8 + + +UNIT_SYMBOL = { + Unit.CELLS: "", + Unit.FRACTION: "fr", + Unit.PERCENT: "%", + Unit.WIDTH: "w", + Unit.HEIGHT: "h", + Unit.VIEW_WIDTH: "vw", + Unit.VIEW_HEIGHT: "vh", +} + +SYMBOL_UNIT = {v: k for k, v in UNIT_SYMBOL.items()} + +_MATCH_SCALAR = re.compile(r"^(-?\d+\.?\d*)(fr|%|w|h|vw|vh)?$").match +_FRACTION_ONE = Fraction(1) + + +def _resolve_cells( + value: float, size: Size, viewport: Size, fraction_unit: Fraction +) -> Fraction: + """Resolves explicit cell size, i.e. width: 10 + + Args: + value: Scalar value. + size: Size of widget. + viewport: Size of viewport. + fraction_unit: Size of fraction, i.e. size of 1fr as a Fraction. + + Returns: + Resolved unit. + """ + return Fraction(value) + + +def _resolve_fraction( + value: float, size: Size, viewport: Size, fraction_unit: Fraction +) -> Fraction: + """Resolves a fraction unit i.e. width: 2fr + + Args: + value: Scalar value. + size: Size of widget. + viewport: Size of viewport. + fraction_unit: Size of fraction, i.e. size of 1fr as a Fraction. + + Returns: + Resolved unit. + """ + return fraction_unit * Fraction(value) + + +def _resolve_width( + value: float, size: Size, viewport: Size, fraction_unit: Fraction +) -> Fraction: + """Resolves width unit i.e. width: 50w. + + Args: + value: Scalar value. + size: Size of widget. + viewport: Size of viewport. + fraction_unit: Size of fraction, i.e. size of 1fr as a Fraction. + + Returns: + Resolved unit. + """ + return Fraction(value) * Fraction(size.width, 100) + + +def _resolve_height( + value: float, size: Size, viewport: Size, fraction_unit: Fraction +) -> Fraction: + """Resolves height unit, i.e. height: 12h. + + Args: + value: Scalar value. + size: Size of widget. + viewport: Size of viewport. + fraction_unit: Size of fraction, i.e. size of 1fr as a Fraction. + + Returns: + Resolved unit. + """ + return Fraction(value) * Fraction(size.height, 100) + + +def _resolve_view_width( + value: float, size: Size, viewport: Size, fraction_unit: Fraction +) -> Fraction: + """Resolves view width unit, i.e. width: 25vw. + + Args: + value: Scalar value. + size: Size of widget. + viewport: Size of viewport. + fraction_unit: Size of fraction, i.e. size of 1fr as a Fraction. + + Returns: + Resolved unit. + """ + return Fraction(value) * Fraction(viewport.width, 100) + + +def _resolve_view_height( + value: float, size: Size, viewport: Size, fraction_unit: Fraction +) -> Fraction: + """Resolves view height unit, i.e. height: 25vh. + + Args: + value: Scalar value. + size: Size of widget. + viewport: Size of viewport. + fraction_unit: Size of fraction, i.e. size of 1fr as a Fraction. + + Returns: + Resolved unit. + """ + return Fraction(value) * Fraction(viewport.height, 100) + + +RESOLVE_MAP = { + Unit.CELLS: _resolve_cells, + Unit.FRACTION: _resolve_fraction, + Unit.WIDTH: _resolve_width, + Unit.HEIGHT: _resolve_height, + Unit.VIEW_WIDTH: _resolve_view_width, + Unit.VIEW_HEIGHT: _resolve_view_height, +} + + +def get_symbols(units: Iterable[Unit]) -> list[str]: + """Get symbols for an iterable of units. + + Args: + units: A number of units. + + Returns: + List of symbols. + """ + return [UNIT_SYMBOL[unit] for unit in units] + + +class Scalar(NamedTuple): + """A numeric value and a unit.""" + + value: float + unit: Unit + percent_unit: Unit + + def __str__(self) -> str: + value, unit, _ = self + if unit == Unit.AUTO: + return "auto" + return f"{int(value) if value.is_integer() else value}{self.symbol}" + + @property + def is_cells(self) -> bool: + """Check if the Scalar is explicit cells.""" + return self.unit == Unit.CELLS + + @property + def is_percent(self) -> bool: + """Check if the Scalar is a percentage unit.""" + return self.unit == Unit.PERCENT + + @property + def is_fraction(self) -> bool: + """Check if the unit is a fraction.""" + return self.unit == Unit.FRACTION + + @property + def cells(self) -> int | None: + """Check if the unit is explicit cells.""" + value, unit, _ = self + return int(value) if unit == Unit.CELLS else None + + @property + def fraction(self) -> int | None: + """Get the fraction value, or None if not a value.""" + value, unit, _ = self + return int(value) if unit == Unit.FRACTION else None + + @property + def symbol(self) -> str: + """Get the symbol of this unit.""" + return UNIT_SYMBOL[self.unit] + + @property + def is_auto(self) -> bool: + """Check if this is an auto unit.""" + return self.unit == Unit.AUTO + + @classmethod + def from_number(cls, value: float) -> Scalar: + """Create a scalar with cells unit. + + Args: + value: A number of cells. + + Returns: + New Scalar. + """ + return cls(float(value), Unit.CELLS, Unit.WIDTH) + + @classmethod + @lru_cache(maxsize=1024) + def parse(cls, token: str, percent_unit: Unit = Unit.WIDTH) -> Scalar: + """Parse a string into a Scalar + + Args: + token: A string containing a scalar, e.g. "3.14fr" + + Raises: + ScalarParseError: If the value is not a valid scalar + + Returns: + New scalar + """ + if token.lower() == "auto": + scalar = cls(1.0, Unit.AUTO, Unit.AUTO) + else: + match = _MATCH_SCALAR(token) + if match is None: + raise ScalarParseError(f"{token!r} is not a valid scalar") + value, unit_name = match.groups() + scalar = cls(float(value), SYMBOL_UNIT[unit_name or ""], percent_unit) + return scalar + + @lru_cache(maxsize=4096) + def resolve( + self, size: Size, viewport: Size, fraction_unit: Fraction | None = None + ) -> Fraction: + """Resolve scalar with units into a dimensions. + + Args: + size: Size of the container. + viewport: Size of the viewport (typically terminal size) + + Raises: + ScalarResolveError: If the unit is unknown. + + Returns: + A size (in cells) + """ + value, unit, percent_unit = self + + if unit == Unit.PERCENT: + unit = percent_unit + try: + dimension = RESOLVE_MAP[unit]( + value, size, viewport, fraction_unit or _FRACTION_ONE + ) + except KeyError: + raise ScalarResolveError(f"expected dimensions; found {str(self)!r}") + return dimension + + def copy_with( + self, + value: float | None = None, + unit: Unit | None = None, + percent_unit: Unit | None = None, + ) -> Scalar: + """Get a copy of this Scalar, with values optionally modified + + Args: + value: The new value, or None to keep the same value + unit: The new unit, or None to keep the same unit + percent_unit: The new percent_unit, or None to keep the same percent_unit + """ + return Scalar( + value if value is not None else self.value, + unit if unit is not None else self.unit, + percent_unit if percent_unit is not None else self.percent_unit, + ) + + +@rich.repr.auto(angular=True) +class ScalarOffset(NamedTuple): + """An Offset with two scalars, used to animate between to Scalars.""" + + x: Scalar + y: Scalar + + @classmethod + def null(cls) -> ScalarOffset: + """Get a null scalar offset (0, 0).""" + return NULL_SCALAR + + @classmethod + def from_offset(cls, offset: tuple[int, int]) -> ScalarOffset: + """Create a Scalar offset from a tuple of integers. + + Args: + offset: Offset in cells. + + Returns: + New offset. + """ + x, y = offset + return cls( + Scalar(x, Unit.CELLS, Unit.WIDTH), + Scalar(y, Unit.CELLS, Unit.HEIGHT), + ) + + def __bool__(self) -> bool: + x, y = self + return bool(x.value or y.value) + + def __rich_repr__(self) -> rich.repr.Result: + yield None, str(self.x) + yield None, str(self.y) + + def resolve(self, size: Size, viewport: Size) -> Offset: + """Resolve the offset into cells. + + Args: + size: Size of container. + viewport: Size of viewport. + + Returns: + Offset in cells. + """ + x, y = self + return Offset( + round(x.resolve(size, viewport)), + round(y.resolve(size, viewport)), + ) + + +NULL_SCALAR = ScalarOffset(Scalar.from_number(0), Scalar.from_number(0)) + + +def percentage_string_to_float(string: str) -> float: + """Convert a string percentage e.g. '20%' to a float e.g. 20.0. + + Args: + string: The percentage string to convert. + """ + string = string.strip() + if string.endswith("%"): + float_percentage = clamp(float(string[:-1]) / 100.0, 0.0, 1.0) + else: + float_percentage = float(string) + return float_percentage diff --git a/src/memray/_vendor/textual/css/scalar_animation.py b/src/memray/_vendor/textual/css/scalar_animation.py new file mode 100644 index 0000000000..ddc8498db5 --- /dev/null +++ b/src/memray/_vendor/textual/css/scalar_animation.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from memray._vendor.textual._animator import Animation, EasingFunction +from memray._vendor.textual._types import AnimationLevel, CallbackType +from memray._vendor.textual.css.scalar import Scalar, ScalarOffset + +if TYPE_CHECKING: + from memray._vendor.textual.css.styles import StylesBase + from memray._vendor.textual.widget import Widget + + +class ScalarAnimation(Animation): + def __init__( + self, + widget: Widget, + styles: StylesBase, + start_time: float, + attribute: str, + value: ScalarOffset | Scalar, + duration: float | None, + speed: float | None, + easing: EasingFunction, + on_complete: CallbackType | None = None, + level: AnimationLevel = "full", + ): + assert ( + speed is not None or duration is not None + ), "One of speed or duration required" + self.widget = widget + self.styles = styles + self.start_time = start_time + self.attribute = attribute + self.final_value = value + self.easing = easing + self.on_complete = on_complete + self.level = level + + size = widget.outer_size + viewport = widget.app.size + + self.start = getattr(styles, attribute).resolve(size, viewport) + self.destination = value.resolve(size, viewport) + + if speed is not None: + distance = self.start.get_distance_to(self.destination) + self.duration = distance / speed + else: + assert duration is not None, "Duration expected to be non-None" + self.duration = duration + + def __call__( + self, time: float, app_animation_level: AnimationLevel = "full" + ) -> bool: + factor = min(1.0, (time - self.start_time) / self.duration) + eased_factor = self.easing(factor) + + if ( + eased_factor >= 1 + or app_animation_level == "none" + or app_animation_level == "basic" + and self.level == "full" + ): + setattr(self.styles, self.attribute, self.final_value) + return True + + if hasattr(self.start, "blend"): + value = self.start.blend(self.destination, eased_factor) + else: + value = self.start + (self.destination - self.start) * eased_factor + current = self.styles.get_rule(self.attribute) + if current != value: + setattr(self.styles, self.attribute, value) + + return False + + async def stop(self, complete: bool = True) -> None: + """Stop the animation. + + Args: + complete: Flag to say if the animation should be taken to completion. + + Note: + [`on_complete`][Animation.on_complete] will be called regardless + of the value provided for `complete`. + """ + if complete: + setattr(self.styles, self.attribute, self.final_value) + await self.invoke_callback() + + def __eq__(self, other: object) -> bool: + if isinstance(other, ScalarAnimation): + return ( + self.final_value == other.final_value + and self.duration == other.duration + ) + return False diff --git a/src/memray/_vendor/textual/css/styles.py b/src/memray/_vendor/textual/css/styles.py new file mode 100644 index 0000000000..67724c9496 --- /dev/null +++ b/src/memray/_vendor/textual/css/styles.py @@ -0,0 +1,1528 @@ +from __future__ import annotations + +import weakref +from dataclasses import dataclass, field +from functools import partial +from operator import attrgetter +from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Literal, cast + +import rich.repr +from rich.style import Style +from typing_extensions import TypedDict + +from memray._vendor.textual._animator import DEFAULT_EASING, Animatable, BoundAnimator, EasingFunction +from memray._vendor.textual._types import AnimationLevel, CallbackType +from memray._vendor.textual.color import Color +from memray._vendor.textual.css._style_properties import ( + AlignProperty, + BooleanProperty, + BorderProperty, + BoxProperty, + ColorProperty, + DockProperty, + FractionalProperty, + HatchProperty, + IntegerProperty, + KeylineProperty, + LayoutProperty, + NameListProperty, + NameProperty, + OffsetProperty, + OverflowProperty, + ScalarListProperty, + ScalarProperty, + ScrollbarColorProperty, + SpacingProperty, + SplitProperty, + StringEnumProperty, + StyleFlagsProperty, + TransitionsProperty, +) +from memray._vendor.textual.css.constants import ( + VALID_ALIGN_HORIZONTAL, + VALID_ALIGN_VERTICAL, + VALID_BOX_SIZING, + VALID_CONSTRAIN, + VALID_DISPLAY, + VALID_EXPAND, + VALID_OVERFLOW, + VALID_OVERLAY, + VALID_POINTER, + VALID_POSITION, + VALID_SCROLLBAR_GUTTER, + VALID_SCROLLBAR_VISIBILITY, + VALID_TEXT_ALIGN, + VALID_TEXT_OVERFLOW, + VALID_TEXT_WRAP, + VALID_VISIBILITY, +) +from memray._vendor.textual.css.scalar import Scalar, ScalarOffset, Unit +from memray._vendor.textual.css.scalar_animation import ScalarAnimation +from memray._vendor.textual.css.transition import Transition +from memray._vendor.textual.css.types import ( + AlignHorizontal, + AlignVertical, + BoxSizing, + Constrain, + Display, + Expand, + Overflow, + Overlay, + PointerShape, + ScrollbarGutter, + Specificity3, + Specificity6, + TextAlign, + TextOverflow, + TextWrap, + Visibility, +) +from memray._vendor.textual.geometry import Offset, Spacing + +if TYPE_CHECKING: + from memray._vendor.textual.css.types import CSSLocation + from memray._vendor.textual.dom import DOMNode + from memray._vendor.textual.layout import Layout + + +class RulesMap(TypedDict, total=False): + """A typed dict for CSS rules. + + Any key may be absent, indicating that rule has not been set. + + Does not define composite rules, that is a rule that is made of a combination of other rules. + """ + + display: Display + visibility: Visibility + layout: "Layout" + + auto_color: bool + color: Color + background: Color + text_style: Style + + background_tint: Color + + opacity: float + text_opacity: float + + padding: Spacing + margin: Spacing + offset: ScalarOffset + position: str + + border_top: tuple[str, Color] + border_right: tuple[str, Color] + border_bottom: tuple[str, Color] + border_left: tuple[str, Color] + + border_title_align: AlignHorizontal + border_subtitle_align: AlignHorizontal + + outline_top: tuple[str, Color] + outline_right: tuple[str, Color] + outline_bottom: tuple[str, Color] + outline_left: tuple[str, Color] + + keyline: tuple[str, Color] + + box_sizing: BoxSizing + width: Scalar + height: Scalar + min_width: Scalar + min_height: Scalar + max_width: Scalar + max_height: Scalar + + dock: str + split: str + + overflow_x: Overflow + overflow_y: Overflow + + layers: tuple[str, ...] + layer: str + + transitions: dict[str, Transition] + + tint: Color + + scrollbar_color: Color + scrollbar_color_hover: Color + scrollbar_color_active: Color + + scrollbar_corner_color: Color + + scrollbar_background: Color + scrollbar_background_hover: Color + scrollbar_background_active: Color + scrollbar_gutter: ScrollbarGutter + scrollbar_size_vertical: int + scrollbar_size_horizontal: int + scrollbar_visibility: ScrollbarVisibility + + align_horizontal: AlignHorizontal + align_vertical: AlignVertical + + content_align_horizontal: AlignHorizontal + content_align_vertical: AlignVertical + + grid_size_rows: int + grid_size_columns: int + grid_gutter_horizontal: int + grid_gutter_vertical: int + grid_rows: tuple[Scalar, ...] + grid_columns: tuple[Scalar, ...] + + row_span: int + column_span: int + + text_align: TextAlign + + link_color: Color + auto_link_color: bool + link_background: Color + link_style: Style + + link_color_hover: Color + auto_link_color_hover: bool + link_background_hover: Color + link_style_hover: Style + + auto_border_title_color: bool + border_title_color: Color + border_title_background: Color + border_title_style: Style + + auto_border_subtitle_color: bool + border_subtitle_color: Color + border_subtitle_background: Color + border_subtitle_style: Style + + hatch: tuple[str, Color] | Literal["none"] + + overlay: Overlay + constrain_x: Constrain + constrain_y: Constrain + + text_wrap: TextWrap + text_overflow: TextOverflow + expand: Expand + + line_pad: int + + pointer: PointerShape + + +RULE_NAMES = list(RulesMap.__annotations__.keys()) +RULE_NAMES_SET = frozenset(RULE_NAMES) +_rule_getter = attrgetter(*RULE_NAMES) + + +class StylesBase: + """A common base class for Styles and RenderStyles""" + + ANIMATABLE = { + "offset", + "padding", + "margin", + "width", + "height", + "min_width", + "min_height", + "max_width", + "max_height", + "auto_color", + "color", + "background", + "background_tint", + "opacity", + "position", + "text_opacity", + "tint", + "scrollbar_color", + "scrollbar_color_hover", + "scrollbar_color_active", + "scrollbar_background", + "scrollbar_background_hover", + "scrollbar_background_active", + "scrollbar_visibility", + "link_color", + "link_background", + "link_color_hover", + "link_background_hover", + "text_wrap", + "text_overflow", + "line_pad", + } + + display = StringEnumProperty(VALID_DISPLAY, "block", layout=True, display=True) + """Set the display of the widget, defining how it's rendered. + + Valid values are "block" or "none". + + "none" will hide and allow other widgets to fill the space that this widget would occupy. + + Set to None to clear any value that was set at runtime. + + Raises: + StyleValueError: If an invalid display is specified. + """ + + visibility = StringEnumProperty(VALID_VISIBILITY, "visible", layout=True) + """Set the visibility of the widget. + + Valid values are "visible" or "hidden". + + "hidden" will hide the widget, but reserve the space for this widget. + If you want to hide the widget and allow another widget to fill the space, + set the display attribute to "none" instead. + + Set to None to clear any value that was set at runtime. + + Raises: + StyleValueError: If an invalid visibility is specified. + """ + + layout = LayoutProperty() + """Set the layout of the widget, defining how its children are laid out. + + Valid values are "grid", "stream", "horizontal", or "vertical" or None to clear any layout + that was set at runtime. + + Raises: + MissingLayout: If an invalid layout is specified. + """ + + auto_color = BooleanProperty(default=False) + """Enable automatic picking of best contrasting color.""" + color = ColorProperty(Color(255, 255, 255)) + """Set the foreground (text) color of the widget. + Supports `Color` objects but also strings e.g. "red" or "#ff0000". + You can also specify an opacity after a color e.g. "blue 10%" + """ + background = ColorProperty(Color(0, 0, 0, 0)) + """Set the background color of the widget. + Supports `Color` objects but also strings e.g. "red" or "#ff0000" + You can also specify an opacity after a color e.g. "blue 10%" + """ + background_tint = ColorProperty(Color(0, 0, 0, 0)) + """Set a color to tint (blend) with the background. + Supports `Color` objects but also strings e.g. "red" or "#ff0000" + You can also specify an opacity after a color e.g. "blue 10%" + """ + text_style = StyleFlagsProperty() + """Set the text style of the widget using Rich StyleFlags. + e.g. `"bold underline"` or `"b u strikethrough"`. + """ + opacity = FractionalProperty(children=True) + """Set the opacity of the widget, defining how it blends with the parent.""" + text_opacity = FractionalProperty() + """Set the opacity of the content within the widget against the widget's background.""" + padding = SpacingProperty() + """Set the padding (spacing between border and content) of the widget.""" + margin = SpacingProperty() + """Set the margin (spacing outside the border) of the widget.""" + offset = OffsetProperty() + """Set the offset of the widget relative to where it would have been otherwise.""" + position = StringEnumProperty(VALID_POSITION, "relative") + """If `relative` offset is applied to widgets current position, if `absolute` it is applied to (0, 0).""" + + border = BorderProperty(layout=True) + """Set the border of the widget e.g. ("round", "green") or "none".""" + + border_top = BoxProperty(Color(0, 255, 0)) + """Set the top border of the widget e.g. ("round", "green") or "none".""" + border_right = BoxProperty(Color(0, 255, 0)) + """Set the right border of the widget e.g. ("round", "green") or "none".""" + border_bottom = BoxProperty(Color(0, 255, 0)) + """Set the bottom border of the widget e.g. ("round", "green") or "none".""" + border_left = BoxProperty(Color(0, 255, 0)) + """Set the left border of the widget e.g. ("round", "green") or "none".""" + + border_title_align = StringEnumProperty(VALID_ALIGN_HORIZONTAL, "left") + """The alignment of the border title text.""" + border_subtitle_align = StringEnumProperty(VALID_ALIGN_HORIZONTAL, "right") + """The alignment of the border subtitle text.""" + + outline = BorderProperty(layout=False) + """Set the outline of the widget e.g. ("round", "green") or "none". + The outline is drawn *on top* of the widget, rather than around it like border. + """ + outline_top = BoxProperty(Color(0, 255, 0)) + """Set the top outline of the widget e.g. ("round", "green") or "none".""" + outline_right = BoxProperty(Color(0, 255, 0)) + """Set the right outline of the widget e.g. ("round", "green") or "none".""" + outline_bottom = BoxProperty(Color(0, 255, 0)) + """Set the bottom outline of the widget e.g. ("round", "green") or "none".""" + outline_left = BoxProperty(Color(0, 255, 0)) + """Set the left outline of the widget e.g. ("round", "green") or "none".""" + + keyline = KeylineProperty() + """Keyline parameters.""" + + box_sizing = StringEnumProperty(VALID_BOX_SIZING, "border-box", layout=True) + """Box sizing method ("border-box" or "conetnt-box")""" + width = ScalarProperty(percent_unit=Unit.WIDTH) + """Set the width of the widget.""" + height = ScalarProperty(percent_unit=Unit.HEIGHT) + """Set the height of the widget.""" + min_width = ScalarProperty(percent_unit=Unit.WIDTH, allow_auto=False) + """Set the minimum width of the widget.""" + min_height = ScalarProperty(percent_unit=Unit.HEIGHT, allow_auto=False) + """Set the minimum height of the widget.""" + max_width = ScalarProperty(percent_unit=Unit.WIDTH, allow_auto=False) + """Set the maximum width of the widget.""" + max_height = ScalarProperty(percent_unit=Unit.HEIGHT, allow_auto=False) + """Set the maximum height of the widget.""" + dock = DockProperty() + """Set which edge of the parent to dock this widget to e.g. "top", "left", "right", "bottom", "none". + """ + split = SplitProperty() + + overflow_x = OverflowProperty(VALID_OVERFLOW, "hidden") + """Control what happens when the content extends horizontally beyond the widget's width. + + Valid values are "scroll", "hidden", or "auto". + """ + + overflow_y = OverflowProperty(VALID_OVERFLOW, "hidden") + """Control what happens when the content extends vertically beyond the widget's height. + + Valid values are "scroll", "hidden", or "auto". + """ + + layer = NameProperty() + layers = NameListProperty() + transitions = TransitionsProperty() + + tint = ColorProperty("transparent") + """Set the tint of the widget. This allows you apply an opaque color above the widget. + + You can specify an opacity after a color e.g. "blue 10%" + """ + scrollbar_color = ScrollbarColorProperty("ansi_bright_magenta") + """Set the color of the handle of the scrollbar.""" + scrollbar_color_hover = ScrollbarColorProperty("ansi_yellow") + """Set the color of the handle of the scrollbar when hovered.""" + scrollbar_color_active = ScrollbarColorProperty("ansi_bright_yellow") + """Set the color of the handle of the scrollbar when active (being dragged).""" + scrollbar_corner_color = ScrollbarColorProperty("#666666") + """Set the color of the space between the horizontal and vertical scrollbars.""" + scrollbar_background = ScrollbarColorProperty("#555555") + """Set the background color of the scrollbar (the track that the handle sits on).""" + scrollbar_background_hover = ScrollbarColorProperty("#444444") + """Set the background color of the scrollbar when hovered.""" + scrollbar_background_active = ScrollbarColorProperty("black") + """Set the background color of the scrollbar when active (being dragged).""" + + scrollbar_gutter = StringEnumProperty( + VALID_SCROLLBAR_GUTTER, "auto", layout=True, refresh_children=True + ) + """Set to "stable" to reserve space for the scrollbar even when it's not visible. + This can prevent content from shifting when a scrollbar appears. + """ + + scrollbar_size_vertical = IntegerProperty(default=2, layout=True) + """Set the width of the vertical scrollbar (measured in cells).""" + scrollbar_size_horizontal = IntegerProperty(default=1, layout=True) + """Set the height of the horizontal scrollbar (measured in cells).""" + scrollbar_visibility = StringEnumProperty( + VALID_SCROLLBAR_VISIBILITY, "visible", layout=True + ) + """Sets the visibility of the scrollbar.""" + + align_horizontal = StringEnumProperty( + VALID_ALIGN_HORIZONTAL, "left", layout=True, refresh_children=True + ) + align_vertical = StringEnumProperty( + VALID_ALIGN_VERTICAL, "top", layout=True, refresh_children=True + ) + align = AlignProperty() + + content_align_horizontal = StringEnumProperty(VALID_ALIGN_HORIZONTAL, "left") + content_align_vertical = StringEnumProperty(VALID_ALIGN_VERTICAL, "top") + content_align = AlignProperty() + + grid_rows = ScalarListProperty(percent_unit=Unit.HEIGHT, refresh_children=True) + grid_columns = ScalarListProperty(percent_unit=Unit.WIDTH, refresh_children=True) + + grid_size_columns = IntegerProperty(default=1, layout=True, refresh_children=True) + grid_size_rows = IntegerProperty(default=0, layout=True, refresh_children=True) + grid_gutter_horizontal = IntegerProperty( + default=0, layout=True, refresh_children=True + ) + grid_gutter_vertical = IntegerProperty( + default=0, layout=True, refresh_children=True + ) + + row_span = IntegerProperty(default=1, layout=True) + column_span = IntegerProperty(default=1, layout=True) + + text_align: StringEnumProperty[TextAlign] = StringEnumProperty( + VALID_TEXT_ALIGN, "start" + ) + + link_color = ColorProperty("transparent") + auto_link_color = BooleanProperty(False) + link_background = ColorProperty("transparent") + link_style = StyleFlagsProperty() + + link_color_hover = ColorProperty("transparent") + auto_link_color_hover = BooleanProperty(False) + link_background_hover = ColorProperty("transparent") + link_style_hover = StyleFlagsProperty() + + auto_border_title_color = BooleanProperty(default=False) + border_title_color = ColorProperty(Color(255, 255, 255, 0)) + border_title_background = ColorProperty(Color(0, 0, 0, 0)) + border_title_style = StyleFlagsProperty() + + auto_border_subtitle_color = BooleanProperty(default=False) + border_subtitle_color = ColorProperty(Color(255, 255, 255, 0)) + border_subtitle_background = ColorProperty(Color(0, 0, 0, 0)) + border_subtitle_style = StyleFlagsProperty() + + hatch = HatchProperty() + """Add a hatched background effect e.g. ("right", "yellow") or "none" to use no hatch. + """ + + overlay = StringEnumProperty( + VALID_OVERLAY, "none", layout=True, refresh_parent=True + ) + constrain_x: StringEnumProperty[Constrain] = StringEnumProperty( + VALID_CONSTRAIN, "none" + ) + constrain_y: StringEnumProperty[Constrain] = StringEnumProperty( + VALID_CONSTRAIN, "none" + ) + text_wrap: StringEnumProperty[TextWrap] = StringEnumProperty( + VALID_TEXT_WRAP, "wrap" + ) + text_overflow: StringEnumProperty[TextOverflow] = StringEnumProperty( + VALID_TEXT_OVERFLOW, "fold" + ) + expand: StringEnumProperty[Expand] = StringEnumProperty(VALID_EXPAND, "greedy") + line_pad = IntegerProperty(default=0, layout=True) + """Padding added to left and right of lines.""" + + pointer: StringEnumProperty[PointerShape] = StringEnumProperty( + VALID_POINTER, "default", pointer=True + ) + """Set the pointer (cursor) shape when the mouse is over this widget. + + Valid values include "default", "pointer", "text", "crosshair", "help", "wait", + "move", "grab", "grabbing", and various resize cursors. + + Requires terminal support for Kitty pointer shapes protocol. + """ + + @property + def node(self) -> DOMNode | None: + """The DOM node the styles will be applied to, or `None` if it is not set.""" + return None + + def __textual_animation__( + self, + attribute: str, + start_value: object, + value: object, + start_time: float, + duration: float | None, + speed: float | None, + easing: EasingFunction, + on_complete: CallbackType | None = None, + level: AnimationLevel = "full", + ) -> ScalarAnimation | None: + if self.node is None: + return None + + # Check we are animating a Scalar or Scalar offset + if isinstance(start_value, (Scalar, ScalarOffset)): + # If destination is a number, we can convert that to a scalar + if isinstance(value, (int, float)): + value = Scalar(value, Unit.CELLS, Unit.CELLS) + + # We can only animate to Scalar + if not isinstance(value, (Scalar, ScalarOffset)): + return None + + from memray._vendor.textual.widget import Widget + + assert isinstance(self.node, Widget) + return ScalarAnimation( + self.node, + self, + start_time, + attribute, + value, + duration=duration, + speed=speed, + easing=easing, + on_complete=( + partial(self.node.app.call_later, on_complete) + if on_complete is not None + else None + ), + level=level, + ) + return None + + def __eq__(self, styles: object) -> bool: + """Check that Styles contains the same rules.""" + if not isinstance(styles, StylesBase): + return NotImplemented + return self.get_rules() == styles.get_rules() + + def __getitem__(self, key: str) -> object: + if key not in RULE_NAMES_SET: + raise KeyError(key) + return getattr(self, key) + + def get(self, key: str, default: object | None = None) -> object: + return getattr(self, key) if key in RULE_NAMES_SET else default + + def __len__(self) -> int: + return len(RULE_NAMES) + + def __iter__(self) -> Iterator[str]: + return iter(RULE_NAMES) + + def __contains__(self, key: object) -> bool: + return key in RULE_NAMES_SET + + def keys(self) -> Iterable[str]: + return RULE_NAMES + + def values(self) -> Iterable[object]: + for key in RULE_NAMES: + yield getattr(self, key) + + def items(self) -> Iterable[tuple[str, object]]: + for key in RULE_NAMES: + yield (key, getattr(self, key)) + + @property + def gutter(self) -> Spacing: + """Get space around widget. + + Returns: + Space around widget content. + """ + return self.padding + self.border.spacing + + @property + def auto_dimensions(self) -> bool: + """Check if width or height are set to 'auto'.""" + has_rule = self.has_rule + return (has_rule("width") and self.width.is_auto) or ( # type: ignore + has_rule("height") and self.height.is_auto # type: ignore + ) + + @property + def is_relative_width(self, _relative_units={Unit.FRACTION, Unit.PERCENT}) -> bool: + """Does the node have a relative width?""" + width = self.width + return width is not None and width.unit in _relative_units + + @property + def is_relative_height(self, _relative_units={Unit.FRACTION, Unit.PERCENT}) -> bool: + """Does the node have a relative width?""" + height = self.height + return height is not None and height.unit in _relative_units + + @property + def is_auto_width(self, _auto=Unit.AUTO) -> bool: + """Does the node have automatic width?""" + width = self.width + return width is not None and width.unit == _auto + + @property + def is_auto_height(self, _auto=Unit.AUTO) -> bool: + """Does the node have automatic height?""" + height = self.height + return height is not None and height.unit == _auto + + @property + def is_dynamic_height( + self, _dynamic_units={Unit.AUTO, Unit.FRACTION, Unit.PERCENT} + ) -> bool: + """Does the node have a dynamic (not fixed) height?""" + height = self.height + return height is not None and height.unit in _dynamic_units + + @property + def is_docked(self) -> bool: + """Is the node docked?""" + return self.dock != "none" + + @property + def is_split(self) -> bool: + """Is the node split?""" + return self.split != "none" + + def has_rule(self, rule_name: str) -> bool: + """Check if a rule is set on this Styles object. + + Args: + rule_name: Rule name. + + Returns: + ``True`` if the rules is present, otherwise ``False``. + """ + raise NotImplementedError() + + def clear_rule(self, rule_name: str) -> bool: + """Removes the rule from the Styles object, as if it had never been set. + + Args: + rule_name: Rule name. + + Returns: + ``True`` if a rule was cleared, or ``False`` if the rule is already not set. + """ + raise NotImplementedError() + + def get_rules(self) -> RulesMap: + """Get the rules in a mapping. + + Returns: + A TypedDict of the rules. + """ + raise NotImplementedError() + + def set_rule(self, rule_name: str, value: object | None) -> bool: + """Set a rule. + + Args: + rule_name: Rule name. + value: New rule value. + + Returns: + ``True`` if the rule changed, otherwise ``False``. + """ + raise NotImplementedError() + + def get_rule(self, rule_name: str, default: object = None) -> object: + """Get an individual rule. + + Args: + rule_name: Name of rule. + default: Default if rule does not exists. + + Returns: + Rule value or default. + """ + raise NotImplementedError() + + def refresh( + self, + *, + layout: bool = False, + children: bool = False, + parent: bool = False, + repaint: bool = True, + ) -> None: + """Mark the styles as requiring a refresh. + + Args: + layout: Also require a layout. + children: Also refresh children. + parent: Also refresh the parent. + repaint: Repaint the widgets. + """ + + def reset(self) -> None: + """Reset the rules to initial state.""" + + def merge(self, other: StylesBase) -> None: + """Merge values from another Styles. + + Args: + other: A Styles object. + """ + + def merge_rules(self, rules: RulesMap) -> None: + """Merge rules into Styles. + + Args: + rules: A mapping of rules. + """ + + def get_render_rules(self) -> RulesMap: + """Get rules map with defaults.""" + # Get a dictionary of rules, going through the properties + rules = dict(zip(RULE_NAMES, _rule_getter(self))) + return cast(RulesMap, rules) + + @classmethod + def is_animatable(cls, rule: str) -> bool: + """Check if a given rule may be animated. + + Args: + rule: Name of the rule. + + Returns: + ``True`` if the rule may be animated, otherwise ``False``. + """ + return rule in cls.ANIMATABLE + + @classmethod + def parse( + cls, css: str, read_from: CSSLocation, *, node: DOMNode | None = None + ) -> Styles: + """Parse CSS and return a Styles object. + + Args: + css: Textual CSS. + read_from: Location where the CSS was read from. + node: Node to associate with the Styles. + + Returns: + A Styles instance containing result of parsing CSS. + """ + from memray._vendor.textual.css.parse import parse_declarations + + styles = parse_declarations(css, read_from) + styles.node = node + return styles + + def _get_transition(self, key: str) -> Transition | None: + """Get a transition. + + Args: + key: Transition key. + + Returns: + Transition object or None it no transition exists. + """ + if key in self.ANIMATABLE: + return self.transitions.get(key, None) + else: + return None + + def _align_width(self, width: int, parent_width: int) -> int: + """Align the width dimension. + + Args: + width: Width of the content. + parent_width: Width of the parent container. + + Returns: + An offset to add to the X coordinate. + """ + offset_x = 0 + align_horizontal = self.align_horizontal + if align_horizontal != "left": + if align_horizontal == "center": + offset_x = (parent_width - width) // 2 + else: + offset_x = parent_width - width + + return offset_x + + def _align_height(self, height: int, parent_height: int) -> int: + """Align the height dimensions + + Args: + height: Height of the content. + parent_height: Height of the parent container. + + Returns: + An offset to add to the Y coordinate. + """ + offset_y = 0 + align_vertical = self.align_vertical + if align_vertical != "top": + if align_vertical == "middle": + offset_y = (parent_height - height) // 2 + else: + offset_y = parent_height - height + return offset_y + + def _align_size(self, child: tuple[int, int], parent: tuple[int, int]) -> Offset: + """Align a size according to alignment rules. + + Args: + child: The size of the child (width, height) + parent: The size of the parent (width, height) + + Returns: + Offset required to align the child. + """ + width, height = child + parent_width, parent_height = parent + return Offset( + self._align_width(width, parent_width), + self._align_height(height, parent_height), + ) + + @property + def partial_rich_style(self) -> Style: + """Get the style properties associated with this node only (not including parents in the DOM). + + Returns: + Rich Style object. + """ + style = Style( + color=( + self.color.rich_color + if self.has_rule("color") and self.color.a > 0 + else None + ), + bgcolor=( + self.background.rich_color + if self.has_rule("background") and self.background.a > 0 + else None + ), + ) + style += self.text_style + return style + + +@rich.repr.auto +@dataclass +class Styles(StylesBase): + node: DOMNode | None = None + _rules: RulesMap = field(default_factory=RulesMap) + _updates: int = 0 + + important: set[str] = field(default_factory=set) + + def __post_init__(self) -> None: + self.get_rule: Callable[[str, object], object] = self._rules.get # type: ignore[assignment] + self.has_rule: Callable[[str], bool] = self._rules.__contains__ # type: ignore[assignment] + + def copy(self) -> Styles: + """Get a copy of this Styles object.""" + return Styles( + node=self.node, + _rules=self.get_rules(), + important=self.important, + ) + + def clear_rule(self, rule_name: str) -> bool: + """Removes the rule from the Styles object, as if it had never been set. + + Args: + rule_name: Rule name. + + Returns: + ``True`` if a rule was cleared, or ``False`` if it was already not set. + """ + changed = self._rules.pop(rule_name, None) is not None # type: ignore + if changed: + self._updates += 1 + return changed + + def get_rules(self) -> RulesMap: + return self._rules.copy() + + def set_rule(self, rule: str, value: object | None) -> bool: + """Set a rule. + + Args: + rule: Rule name. + value: New rule value. + + Returns: + ``True`` if the rule changed, otherwise ``False``. + """ + if value is None: + changed = self._rules.pop(rule, None) is not None # type: ignore + if changed: + self._updates += 1 + return changed + current = self._rules.get(rule) + self._rules[rule] = value # type: ignore + changed = current != value + if changed: + self._updates += 1 + return changed + + def refresh( + self, + *, + layout: bool = False, + children: bool = False, + parent: bool = False, + repaint=True, + ) -> None: + node = self.node + if node is None or not node._is_mounted: + return + if parent and node._parent is not None: + node._parent.refresh(repaint=repaint) + node.refresh(layout=layout) + if children: + for child in node.walk_children(with_self=False, reverse=True): + child.refresh(layout=layout, repaint=repaint) + + def reset(self) -> None: + """Reset the rules to initial state.""" + self._updates += 1 + self._rules.clear() # type: ignore + + def merge(self, other: StylesBase) -> None: + """Merge values from another Styles. + + Args: + other: A Styles object. + """ + self._updates += 1 + self._rules.update(other.get_rules()) + + def merge_rules(self, rules: RulesMap) -> None: + self._updates += 1 + self._rules.update(rules) + + def extract_rules( + self, + specificity: Specificity3, + is_default_rules: bool = False, + tie_breaker: int = 0, + ) -> list[tuple[str, Specificity6, Any]]: + """Extract rules from Styles object, and apply !important css specificity as + well as higher specificity of user CSS vs widget CSS. + + Args: + specificity: A node specificity. + is_default_rules: True if the rules we're extracting are + default (i.e. in Widget.DEFAULT_CSS) rules. False if they're from user defined CSS. + + Returns: + A list containing a tuple of , . + """ + is_important = self.important.__contains__ + default_rules = 0 if is_default_rules else 1 + rules: list[tuple[str, Specificity6, Any]] = [ + ( + rule_name, + ( + default_rules, + 1 if is_important(rule_name) else 0, + *specificity, + tie_breaker, + ), + rule_value, + ) + for rule_name, rule_value in self._rules.items() + ] + + return rules + + def __rich_repr__(self) -> rich.repr.Result: + has_rule = self.has_rule + for name in RULE_NAMES: + if has_rule(name): + yield name, getattr(self, name) + if self.important: + yield "important", self.important + + def _get_border_css_lines( + self, rules: RulesMap, name: str + ) -> Iterable[tuple[str, str]]: + """Get pairs of strings containing , for border css declarations. + + Args: + rules: A rules map. + name: Name of rules (border or outline) + + Returns: + An iterable of CSS declarations. + """ + + has_rule = rules.__contains__ + get_rule = rules.__getitem__ + + has_top = has_rule(f"{name}_top") + has_right = has_rule(f"{name}_right") + has_bottom = has_rule(f"{name}_bottom") + has_left = has_rule(f"{name}_left") + if not any((has_top, has_right, has_bottom, has_left)): + # No border related rules + return + + if all((has_top, has_right, has_bottom, has_left)): + # All rules are set + # See if we can set them with a single border: declaration + top = get_rule(f"{name}_top") + right = get_rule(f"{name}_right") + bottom = get_rule(f"{name}_bottom") + left = get_rule(f"{name}_left") + + if top == right and right == bottom and bottom == left: + border_type, border_color = rules[f"{name}_top"] # type: ignore + yield name, f"{border_type} {border_color.hex}" + return + + # Check for edges + if has_top: + border_type, border_color = rules[f"{name}_top"] # type: ignore + yield f"{name}-top", f"{border_type} {border_color.hex}" + + if has_right: + border_type, border_color = rules[f"{name}_right"] # type: ignore + yield f"{name}-right", f"{border_type} {border_color.hex}" + + if has_bottom: + border_type, border_color = rules[f"{name}_bottom"] # type: ignore + yield f"{name}-bottom", f"{border_type} {border_color.hex}" + + if has_left: + border_type, border_color = rules[f"{name}_left"] # type: ignore + yield f"{name}-left", f"{border_type} {border_color.hex}" + + @property + def css_lines(self) -> list[str]: + lines: list[str] = [] + append = lines.append + + def append_declaration(name: str, value: str) -> None: + if name in self.important: + append(f"{name}: {value} !important;") + else: + append(f"{name}: {value};") + + rules = self.get_rules() + get_rule = rules.get + + if "display" in rules: + append_declaration("display", rules["display"]) + if "visibility" in rules: + append_declaration("visibility", rules["visibility"]) + if "padding" in rules: + append_declaration("padding", rules["padding"].css) + if "margin" in rules: + append_declaration("margin", rules["margin"].css) + + for name, rule in self._get_border_css_lines(rules, "border"): + append_declaration(name, rule) + + for name, rule in self._get_border_css_lines(rules, "outline"): + append_declaration(name, rule) + + if "offset" in rules: + x, y = self.offset + append_declaration("offset", f"{x} {y}") + if "position" in rules: + append_declaration("position", self.position) + if "dock" in rules: + append_declaration("dock", rules["dock"]) + if "split" in rules: + append_declaration("split", rules["split"]) + if "layers" in rules: + append_declaration("layers", " ".join(self.layers)) + if "layer" in rules: + append_declaration("layer", self.layer) + if "layout" in rules: + assert self.layout is not None + append_declaration("layout", self.layout.name) + + if "color" in rules: + append_declaration("color", self.color.hex) + if "background" in rules: + append_declaration("background", self.background.hex) + if "background_tint" in rules: + append_declaration("background-tint", self.background_tint.hex) + if "text_style" in rules: + append_declaration("text-style", str(get_rule("text_style"))) + if "tint" in rules: + append_declaration("tint", self.tint.css) + + if "overflow_x" in rules: + append_declaration("overflow-x", self.overflow_x) + if "overflow_y" in rules: + append_declaration("overflow-y", self.overflow_y) + + if "scrollbar_color" in rules: + append_declaration("scrollbar-color", self.scrollbar_color.css) + if "scrollbar_color_hover" in rules: + append_declaration("scrollbar-color-hover", self.scrollbar_color_hover.css) + if "scrollbar_color_active" in rules: + append_declaration( + "scrollbar-color-active", self.scrollbar_color_active.css + ) + + if "scrollbar_corner_color" in rules: + append_declaration( + "scrollbar-corner-color", self.scrollbar_corner_color.css + ) + + if "scrollbar_background" in rules: + append_declaration("scrollbar-background", self.scrollbar_background.css) + if "scrollbar_background_hover" in rules: + append_declaration( + "scrollbar-background-hover", self.scrollbar_background_hover.css + ) + if "scrollbar_background_active" in rules: + append_declaration( + "scrollbar-background-active", self.scrollbar_background_active.css + ) + + if "scrollbar_gutter" in rules: + append_declaration("scrollbar-gutter", self.scrollbar_gutter) + if "scrollbar_size" in rules: + append_declaration( + "scrollbar-size", + f"{self.scrollbar_size_horizontal} {self.scrollbar_size_vertical}", + ) + else: + if "scrollbar_size_horizontal" in rules: + append_declaration( + "scrollbar-size-horizontal", str(self.scrollbar_size_horizontal) + ) + if "scrollbar_size_vertical" in rules: + append_declaration( + "scrollbar-size-vertical", str(self.scrollbar_size_vertical) + ) + if "scrollbar_visibility" in rules: + append_declaration("scrollbar-visibility", self.scrollbar_visibility) + + if "box_sizing" in rules: + append_declaration("box-sizing", self.box_sizing) + if "width" in rules: + append_declaration("width", str(self.width)) + if "height" in rules: + append_declaration("height", str(self.height)) + if "min_width" in rules: + append_declaration("min-width", str(self.min_width)) + if "min_height" in rules: + append_declaration("min-height", str(self.min_height)) + if "max_width" in rules: + append_declaration("max-width", str(self.max_width)) + if "max_height" in rules: + append_declaration("max-height", str(self.max_height)) + if "transitions" in rules: + append_declaration( + "transition", + ", ".join( + f"{name} {transition}" + for name, transition in self.transitions.items() + ), + ) + + if "align_horizontal" in rules and "align_vertical" in rules: + append_declaration( + "align", f"{self.align_horizontal} {self.align_vertical}" + ) + elif "align_horizontal" in rules: + append_declaration("align-horizontal", self.align_horizontal) + elif "align_vertical" in rules: + append_declaration("align-vertical", self.align_vertical) + + if "content_align_horizontal" in rules and "content_align_vertical" in rules: + append_declaration( + "content-align", + f"{self.content_align_horizontal} {self.content_align_vertical}", + ) + elif "content_align_horizontal" in rules: + append_declaration( + "content-align-horizontal", self.content_align_horizontal + ) + elif "content_align_vertical" in rules: + append_declaration("content-align-vertical", self.content_align_vertical) + + if "text_align" in rules: + append_declaration("text-align", self.text_align) + + if "border_title_align" in rules: + append_declaration("border-title-align", self.border_title_align) + if "border_subtitle_align" in rules: + append_declaration("border-subtitle-align", self.border_subtitle_align) + + if "opacity" in rules: + append_declaration("opacity", str(self.opacity)) + if "text_opacity" in rules: + append_declaration("text-opacity", str(self.text_opacity)) + + if "grid_columns" in rules: + append_declaration( + "grid-columns", + " ".join(str(scalar) for scalar in self.grid_columns or ()), + ) + if "grid_rows" in rules: + append_declaration( + "grid-rows", + " ".join(str(scalar) for scalar in self.grid_rows or ()), + ) + if "grid_size_columns" in rules: + append_declaration("grid-size-columns", str(self.grid_size_columns)) + if "grid_size_rows" in rules: + append_declaration("grid-size-rows", str(self.grid_size_rows)) + + if "grid_gutter_horizontal" in rules: + append_declaration( + "grid-gutter-horizontal", str(self.grid_gutter_horizontal) + ) + if "grid_gutter_vertical" in rules: + append_declaration("grid-gutter-vertical", str(self.grid_gutter_vertical)) + + if "row_span" in rules: + append_declaration("row-span", str(self.row_span)) + if "column_span" in rules: + append_declaration("column-span", str(self.column_span)) + + if "link_color" in rules: + append_declaration("link-color", self.link_color.css) + if "link_background" in rules: + append_declaration("link-background", self.link_background.css) + if "link_style" in rules: + append_declaration("link-style", str(self.link_style)) + + if "link_color_hover" in rules: + append_declaration("link-color-hover", self.link_color_hover.css) + if "link_background_hover" in rules: + append_declaration("link-background-hover", self.link_background_hover.css) + if "link_style_hover" in rules: + append_declaration("link-style-hover", str(self.link_style_hover)) + + if "border_title_color" in rules: + append_declaration("title-color", self.border_title_color.css) + if "border_title_background" in rules: + append_declaration("title-background", self.border_title_background.css) + if "border_title_style" in rules: + append_declaration("title-text-style", str(self.border_title_style)) + + if "border_subtitle_color" in rules: + append_declaration("subtitle-color", self.border_subtitle_color.css) + if "border_subtitle_background" in rules: + append_declaration( + "subtitle-background", self.border_subtitle_background.css + ) + if "border_subtitle_text_style" in rules: + append_declaration("subtitle-text-style", str(self.border_subtitle_style)) + if "overlay" in rules: + append_declaration("overlay", str(self.overlay)) + if "constrain_x" in rules and "constrain_y" in rules: + if self.constrain_x == self.constrain_y: + append_declaration("constrain", self.constrain_x) + else: + append_declaration( + "constrain", f"{self.constrain_x} {self.constrain_y}" + ) + elif "constrain_x" in rules: + append_declaration("constrain-x", self.constrain_x) + elif "constrain_y" in rules: + append_declaration("constrain-y", self.constrain_y) + + if "keyline" in rules: + keyline_type, keyline_color = self.keyline + if keyline_type != "none": + append_declaration("keyline", f"{keyline_type}, {keyline_color.css}") + if "hatch" in rules: + hatch_character, hatch_color = self.hatch + append_declaration("hatch", f'"{hatch_character}" {hatch_color.css}') + if "text_wrap" in rules: + append_declaration("text-wrap", self.text_wrap) + if "text_overflow" in rules: + append_declaration("text-overflow", self.text_overflow) + if "expand" in rules: + append_declaration("expand", self.expand) + if "line_pad" in rules: + append_declaration("line-pad", str(self.line_pad)) + lines.sort() + return lines + + @property + def css(self) -> str: + return "\n".join(self.css_lines) + + +@rich.repr.auto +class RenderStyles(StylesBase): + """Presents a combined view of two Styles object: a base Styles and inline Styles.""" + + def __init__(self, node: DOMNode, base: Styles, inline_styles: Styles) -> None: + self._node = weakref.ref(node) + self._base_styles = base + self._inline_styles = inline_styles + self._animate: BoundAnimator | None = None + self._updates: int = 0 + self._rich_style: tuple[int, Style] | None = None + self._gutter: tuple[int, Spacing] | None = None + + def _update_node(self, node: DOMNode) -> None: + """Update the associated DOM node. + + Args: + node: New node for the styles. + """ + self._node = weakref.ref(node) + + def __eq__(self, other: object) -> bool: + if isinstance(other, RenderStyles): + return ( + self._base_styles._rules == other._base_styles._rules + and self._inline_styles._rules == other._inline_styles._rules + ) + return NotImplemented + + @property + def _cache_key(self) -> int: + """A cache key, that changes when any style is changed. + + Returns: + An opaque integer. + """ + return self._updates + self._base_styles._updates + self._inline_styles._updates + + @property + def node(self) -> DOMNode | None: + """The DOM node the styles will be applied to, or `None` if it is not set.""" + return self._node() + + @property + def base(self) -> Styles: + """Quick access to base (css) style.""" + return self._base_styles + + @property + def inline(self) -> Styles: + """Quick access to the inline styles.""" + return self._inline_styles + + @property + def rich_style(self) -> Style: + """Get a Rich style for this Styles object.""" + assert self.node is not None + return self.node.rich_style + + @property + def gutter(self) -> Spacing: + """Get space around widget (padding + border) + + Returns: + Space around widget content. + """ + # This is (surprisingly) a bit of a bottleneck + if self._gutter is not None: + cache_key, gutter = self._gutter + if cache_key == self._cache_key: + return gutter + gutter = self.padding + self.border.spacing + self._gutter = (self._cache_key, gutter) + return gutter + + def animate( + self, + attribute: str, + value: str | float | Animatable, + *, + final_value: object = ..., + duration: float | None = None, + speed: float | None = None, + delay: float = 0.0, + easing: EasingFunction | str = DEFAULT_EASING, + on_complete: CallbackType | None = None, + level: AnimationLevel = "full", + ) -> None: + """Animate an attribute. + + Args: + attribute: Name of the attribute to animate. + value: The value to animate to. + final_value: The final value of the animation. Defaults to `value` if not set. + duration: The duration (in seconds) of the animation. + speed: The speed of the animation. + delay: A delay (in seconds) before the animation starts. + easing: An easing method. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + """ + if self._animate is None: + assert self.node is not None + self._animate = self.node.app.animator.bind(self) + assert self._animate is not None + self._animate( + attribute, + value, + final_value=final_value, + duration=duration, + speed=speed, + delay=delay, + easing=easing, + on_complete=on_complete, + level=level, + ) + + def __rich_repr__(self) -> rich.repr.Result: + yield self.node + for rule_name in RULE_NAMES: + if self.has_rule(rule_name): + yield rule_name, getattr(self, rule_name) + + def refresh( + self, + *, + layout: bool = False, + children: bool = False, + parent: bool = False, + repaint: bool = True, + ) -> None: + self._inline_styles.refresh( + layout=layout, children=children, parent=parent, repaint=repaint + ) + + def merge(self, other: StylesBase) -> None: + """Merge values from another Styles. + + Args: + other: A Styles object. + """ + self._inline_styles.merge(other) + + def merge_rules(self, rules: RulesMap) -> None: + self._inline_styles.merge_rules(rules) + self._updates += 1 + + def reset(self) -> None: + """Reset the rules to initial state.""" + self._inline_styles.reset() + self._updates += 1 + + def has_rule(self, rule_name: str) -> bool: + """Check if a rule has been set.""" + return self._inline_styles.has_rule(rule_name) or self._base_styles.has_rule( + rule_name + ) + + def has_any_rules(self, *rule_names: str) -> bool: + """Check if any of the supplied rules have been set. + + Args: + rule_names: Number of rules. + + Returns: + `True` if any of the supplied rules have been set, `False` if none have. + """ + inline_has_rule = self._inline_styles.has_rule + base_has_rule = self._base_styles.has_rule + return any(inline_has_rule(name) or base_has_rule(name) for name in rule_names) + + def set_rule(self, rule_name: str, value: object | None) -> bool: + return self._inline_styles.set_rule(rule_name, value) + + def get_rule(self, rule_name: str, default: object = None) -> object: + if self._inline_styles.has_rule(rule_name): + return self._inline_styles.get_rule(rule_name, default) + return self._base_styles.get_rule(rule_name, default) + + def clear_rule(self, rule_name: str) -> bool: + """Clear a rule (from inline).""" + return self._inline_styles.clear_rule(rule_name) + + def get_rules(self) -> RulesMap: + """Get rules as a dictionary""" + rules = {**self._base_styles._rules, **self._inline_styles._rules} + return cast(RulesMap, rules) + + @property + def css(self) -> str: + """Get the CSS for the combined styles.""" + styles = Styles() + styles.merge(self._base_styles) + styles.merge(self._inline_styles) + combined_css = styles.css + return combined_css diff --git a/src/memray/_vendor/textual/css/stylesheet.py b/src/memray/_vendor/textual/css/stylesheet.py new file mode 100644 index 0000000000..521dc076e0 --- /dev/null +++ b/src/memray/_vendor/textual/css/stylesheet.py @@ -0,0 +1,737 @@ +from __future__ import annotations + +import os +from collections import defaultdict +from itertools import chain +from operator import itemgetter +from pathlib import Path, PurePath +from typing import Final, Iterable, NamedTuple, Sequence, cast + +import rich.repr +from rich.console import Console, ConsoleOptions, RenderableType, RenderResult +from rich.markup import render +from rich.padding import Padding +from rich.panel import Panel +from rich.text import Text + +from memray._vendor.textual.cache import LRUCache +from memray._vendor.textual.css.errors import StylesheetError +from memray._vendor.textual.css.match import _check_selectors +from memray._vendor.textual.css.model import RuleSet +from memray._vendor.textual.css.parse import parse +from memray._vendor.textual.css.styles import RulesMap, Styles +from memray._vendor.textual.css.tokenize import Token, tokenize_values +from memray._vendor.textual.css.tokenizer import TokenError +from memray._vendor.textual.css.types import CSSLocation, Specificity3, Specificity6 +from memray._vendor.textual.dom import DOMNode +from memray._vendor.textual.markup import parse_style +from memray._vendor.textual.style import Style +from memray._vendor.textual.widget import Widget + +_DEFAULT_STYLES = Styles() + + +class StylesheetParseError(StylesheetError): + """Raised when the stylesheet could not be parsed.""" + + def __init__(self, errors: StylesheetErrors) -> None: + self.errors = errors + + def __rich__(self) -> RenderableType: + return self.errors + + +class StylesheetErrors: + """A renderable for stylesheet errors.""" + + def __init__(self, rules: list[RuleSet]) -> None: + self.rules = rules + self.variables: dict[str, str] = {} + + @classmethod + def _get_snippet(cls, code: str, line_no: int) -> RenderableType: + from rich.syntax import Syntax + + syntax = Syntax( + code, + lexer="scss", + theme="ansi_light", + line_numbers=True, + indent_guides=True, + line_range=(max(0, line_no - 2), line_no + 2), + highlight_lines={line_no}, + ) + return syntax + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + error_count = 0 + errors = list( + dict.fromkeys(chain.from_iterable(_rule.errors for _rule in self.rules)) + ) + + for token, message in errors: + error_count += 1 + + if token.referenced_by: + line_idx, col_idx = token.referenced_by.location + else: + line_idx, col_idx = token.location + line_no, col_no = line_idx + 1, col_idx + 1 + + display_path, widget_var = token.read_from + if display_path: + link_path = str(Path(display_path).absolute()) + filename = Path(link_path).name + else: + link_path = "" + filename = "" + # If we have a widget/variable from where the CSS was read, then line/column + # numbers are relative to the inline CSS and we'll display them next to the + # widget/variable. + # Otherwise, they're absolute positions in a TCSS file and we can show them + # next to the file path. + if widget_var: + path_string = link_path or filename + widget_string = f" in {widget_var}:{line_no}:{col_no}" + else: + path_string = f"{link_path or filename}:{line_no}:{col_no}" + widget_string = "" + + title = Text.assemble( + "Error at ", path_string, widget_string, style="bold red" + ) + yield "" + yield Panel( + self._get_snippet( + token.referenced_by.code if token.referenced_by else token.code, + line_no, + ), + title=title, + title_align="left", + border_style="red", + ) + yield Padding(message, pad=(0, 0, 1, 3)) + + yield "" + yield render( + f" [b][red]CSS parsing failed:[/] {error_count} error{'s' if error_count != 1 else ''}[/] found in stylesheet" + ) + + +class CssSource(NamedTuple): + """Contains the CSS content and whether or not the CSS comes from user defined stylesheets + vs widget-level stylesheets. + + Args: + content: The CSS as a string. + is_defaults: True if the CSS is default (i.e. that defined at the widget level). + False if it's user CSS (which will override the defaults). + tie_breaker: Specificity tie breaker. + scope: Scope of CSS. + """ + + content: str + is_defaults: bool + tie_breaker: int = 0 + scope: str = "" + + +@rich.repr.auto(angular=True) +class Stylesheet: + """A Stylesheet generated from Textual CSS.""" + + def __init__(self, *, variables: dict[str, str] | None = None) -> None: + self._rules: list[RuleSet] = [] + self._rules_map: dict[str, list[RuleSet]] | None = None + self._variables = variables or {} + self.__variable_tokens: dict[str, list[Token]] | None = None + self.source: dict[CSSLocation, CssSource] = {} + self._require_parse = False + self._invalid_css: set[str] = set() + self._parse_cache: LRUCache[tuple, list[RuleSet]] = LRUCache(64) + self._style_parse_cache: LRUCache[str, Style] = LRUCache(1024 * 4) + + def __rich_repr__(self) -> rich.repr.Result: + yield list(self.source.keys()) + + @property + def _variable_tokens(self) -> dict[str, list[Token]]: + if self.__variable_tokens is None: + self.__variable_tokens = tokenize_values(self._variables) + return self.__variable_tokens + + @property + def rules(self) -> list[RuleSet]: + """List of rule sets. + + Returns: + List of rules sets for this stylesheet. + """ + if self._require_parse: + self.parse() + self._require_parse = False + assert self._rules is not None + return self._rules + + @property + def rules_map(self) -> dict[str, list[RuleSet]]: + """Structure that maps a selector on to a list of rules. + + Returns: + Mapping of selector to rule sets. + """ + if self._rules_map is None: + rules_map: dict[str, list[RuleSet]] = defaultdict(list) + for rule in self.rules: + for name in rule.selector_names: + rules_map[name].append(rule) + self._rules_map = dict(rules_map) + return self._rules_map + + @property + def css(self) -> str: + """The equivalent TCSS for this stylesheet. + + Note that this may not produce the same content as the file(s) used to generate the stylesheet. + """ + return "\n\n".join(rule_set.css for rule_set in self.rules) + + def copy(self) -> Stylesheet: + """Create a copy of this stylesheet. + + Returns: + New stylesheet. + """ + stylesheet = Stylesheet(variables=self._variables.copy()) + stylesheet.source = self.source.copy() + return stylesheet + + def set_variables(self, variables: dict[str, str]) -> None: + """Set CSS variables. + + Args: + variables: A mapping of name to variable. + """ + self._variables = variables + self.__variable_tokens = None + self._invalid_css = set() + self._parse_cache.clear() + self._style_parse_cache.clear() + + def parse_style(self, style_text: str | Style) -> Style: + """Parse a (visual) Style. + + Args: + style_text: Visual style, such as "bold white 90% on $primary" + + Returns: + New Style instance. + """ + if isinstance(style_text, Style): + return style_text + if style_text in self._style_parse_cache: + return self._style_parse_cache[style_text] + style = parse_style(style_text) + self._style_parse_cache[style_text] = style + return style + + def _parse_rules( + self, + css: str, + read_from: CSSLocation, + is_default_rules: bool = False, + tie_breaker: int = 0, + scope: str = "", + ) -> list[RuleSet]: + """Parse CSS and return rules. + + Args: + css: String containing Textual CSS. + read_from: Original CSS location. + is_default_rules: True if the rules we're extracting are + default (i.e. in Widget.DEFAULT_CSS) rules. False if they're from user defined CSS. + scope: Scope of rules, or empty string for global scope. + + Raises: + StylesheetError: If the CSS is invalid. + + Returns: + List of RuleSets. + """ + cache_key = (css, read_from, is_default_rules, tie_breaker, scope) + try: + return self._parse_cache[cache_key] + except KeyError: + pass + try: + rules = list( + parse( + scope, + css, + read_from, + variable_tokens=self._variable_tokens, + is_default_rules=is_default_rules, + tie_breaker=tie_breaker, + ) + ) + + except TokenError: + raise + except Exception as error: + raise StylesheetError(f"failed to parse css; {error}") from None + + self._parse_cache[cache_key] = rules + return rules + + def read(self, filename: str | PurePath) -> None: + """Read Textual CSS file. + + Args: + filename: Filename of CSS. + + Raises: + StylesheetError: If the CSS could not be read. + StylesheetParseError: If the CSS is invalid. + """ + filename = os.path.expanduser(filename) + try: + with open(filename, "rt", encoding="utf-8") as css_file: + css = css_file.read() + path = os.path.abspath(filename) + except Exception: + raise StylesheetError(f"unable to read CSS file {filename!r}") from None + self.source[(str(path), "")] = CssSource(css, False, 0) + self._require_parse = True + + def read_all(self, paths: Sequence[PurePath]) -> None: + """Read multiple CSS files, in order. + + Args: + paths: The paths of the CSS files to read, in order. + + Raises: + StylesheetError: If the CSS could not be read. + StylesheetParseError: If the CSS is invalid. + """ + for path in paths: + self.read(path) + + def has_source(self, path: str, class_var: str = "") -> bool: + """Check if the stylesheet has this CSS source already. + + Args: + path: The file path of the source in question. + class_var: The widget class variable we might be reading the CSS from. + + Returns: + Whether the stylesheet is aware of this CSS source or not. + """ + return (path, class_var) in self.source + + def add_source( + self, + css: str, + read_from: CSSLocation | None = None, + is_default_css: bool = False, + tie_breaker: int = 0, + scope: str = "", + ) -> None: + """Parse CSS from a string. + + Args: + css: String with CSS source. + read_from: The original source location of the CSS. + path: The path of the source if a file, or some other identifier. + is_default_css: True if the CSS is defined in the Widget, False if the CSS is defined + in a user stylesheet. + tie_breaker: Integer representing the priority of this source. + scope: CSS type name to limit scope or empty string for no scope. + + Raises: + StylesheetError: If the CSS could not be read. + StylesheetParseError: If the CSS is invalid. + """ + + if read_from is None: + read_from = ("", str(hash(css))) + + if read_from in self.source and self.source[read_from].content == css: + # Location already in source and CSS is identical. + content, is_defaults, source_tie_breaker, scope = self.source[read_from] + if source_tie_breaker > tie_breaker: + self.source[read_from] = CssSource( + content, is_defaults, tie_breaker, scope + ) + return + self.source[read_from] = CssSource(css, is_default_css, tie_breaker, scope) + self._require_parse = True + self._rules_map = None + + def parse(self) -> None: + """Parse the source in the stylesheet. + + Raises: + StylesheetParseError: If there are any CSS related errors. + """ + rules: list[RuleSet] = [] + add_rules = rules.extend + + for read_from, ( + css, + is_default_rules, + tie_breaker, + scope, + ) in self.source.items(): + if css in self._invalid_css: + continue + try: + css_rules = self._parse_rules( + css, + read_from=read_from, + is_default_rules=is_default_rules, + tie_breaker=tie_breaker, + scope=scope, + ) + except Exception: + self._invalid_css.add(css) + raise + if any(rule.errors for rule in css_rules): + error_renderable = StylesheetErrors(css_rules) + self._invalid_css.add(css) + raise StylesheetParseError(error_renderable) + add_rules(css_rules) + self._rules = rules + self._require_parse = False + self._rules_map = None + + def reparse(self) -> None: + """Re-parse source, applying new variables. + + Raises: + StylesheetError: If the CSS could not be read. + StylesheetParseError: If the CSS is invalid. + """ + # Do this in a fresh Stylesheet so if there are errors we don't break self. + stylesheet = Stylesheet(variables=self._variables) + for read_from, (css, is_defaults, tie_breaker, scope) in self.source.items(): + stylesheet.add_source( + css, + read_from=read_from, + is_default_css=is_defaults, + tie_breaker=tie_breaker, + scope=scope, + ) + try: + stylesheet.parse() + except Exception: + # If we don't update self's invalid CSS, we might end up reparsing this CSS + # before Textual quits application mode. + # See https://github.com/Textualize/textual/issues/3581. + self._invalid_css.update(stylesheet._invalid_css) + raise + else: + self._rules = stylesheet.rules + self._rules_map = None + self.source = stylesheet.source + self._require_parse = False + + @classmethod + def _check_rule( + cls, rule_set: RuleSet, css_path_nodes: list[DOMNode] + ) -> Iterable[Specificity3]: + """Check a rule set, return specificity of applicable rules. + + Args: + rule_set: A rule set. + css_path_nodes: A list of the nodes from the App to the node being checked. + + Yields: + Specificity of any matching selectors. + """ + for selector_set in rule_set.selector_set: + if _check_selectors(selector_set.selectors, css_path_nodes): + yield selector_set.specificity + + # pseudo classes which iterate over multiple nodes + # These shouldn't be used in a cache key + _EXCLUDE_PSEUDO_CLASSES_FROM_CACHE: Final[set[str]] = { + "first-of-type", + "last-of-type", + "first-child", + "last-child", + "odd", + "even", + "focus-within", + "empty", + } + + def apply( + self, + node: DOMNode, + *, + animate: bool = False, + cache: dict[tuple, RulesMap] | None = None, + ) -> None: + """Apply the stylesheet to a DOM node. + + Args: + node: The `DOMNode` to apply the stylesheet to. + Applies the styles defined in this `Stylesheet` to the node. + If the same rule is defined multiple times for the node (e.g. multiple + classes modifying the same CSS property), then only the most specific + rule will be applied. + animate: Animate changed rules. + cache: An optional cache when applying a group of nodes. + """ + # Dictionary of rule attribute names e.g. "text_background" to list of tuples. + # The tuples contain the rule specificity, and the value for that rule. + # We can use this to determine, for a given rule, whether we should apply it + # or not by examining the specificity. If we have two rules for the + # same attribute, then we can choose the most specific rule and use that. + rule_attributes: defaultdict[str, list[tuple[Specificity6, object]]] + rule_attributes = defaultdict(list) + + rules_map = self.rules_map + + # Discard rules which are not applicable early + limit_rules = { + rule + for name in rules_map.keys() & node._selector_names + for rule in rules_map[name] + } + rules = list(filter(limit_rules.__contains__, reversed(self.rules))) + all_pseudo_classes = set().union(*[rule.pseudo_classes for rule in rules]) + node._has_hover_style = "hover" in all_pseudo_classes + node._has_focus_within = "focus-within" in all_pseudo_classes + node._has_order_style = not all_pseudo_classes.isdisjoint( + {"first-of-type", "last-of-type", "first-child", "last-child", "empty"} + ) + node._has_odd_or_even = ( + "odd" in all_pseudo_classes or "even" in all_pseudo_classes + ) + + cache_key: tuple | None = None + + if cache is not None and all_pseudo_classes.isdisjoint( + self._EXCLUDE_PSEUDO_CLASSES_FROM_CACHE + ): + cache_key = ( + node._parent, + ( + None + if node._id is None + else (node._id if f"#{node._id}" in rules_map else None) + ), + node.classes, + node._pseudo_classes_cache_key, + node._css_type_name, + ) + cached_result: RulesMap | None = cache.get(cache_key) + if cached_result is not None: + self.replace_rules(node, cached_result, animate=animate) + self._process_component_classes(node) + return + + _check_rule = self._check_rule + css_path_nodes = node.css_path_nodes + + # Rules that may be set to the special value `initial` + initial: set[str] = set() + # Rules in DEFAULT_CSS set to the special value `initial` + initial_defaults: set[str] = set() + + for rule in rules: + is_default_rules = rule.is_default_rules + tie_breaker = rule.tie_breaker + for base_specificity in _check_rule(rule, css_path_nodes): + for key, rule_specificity, value in rule.styles.extract_rules( + base_specificity, is_default_rules, tie_breaker + ): + if value is None: + if is_default_rules: + initial_defaults.add(key) + else: + initial.add(key) + rule_attributes[key].append((rule_specificity, value)) + + if rule_attributes: + # For each rule declared for this node, keep only the most specific one + get_first_item = itemgetter(0) + node_rules: RulesMap = cast( + RulesMap, + { + name: max(specificity_rules, key=get_first_item)[1] + for name, specificity_rules in rule_attributes.items() + }, + ) + + # Set initial values + for initial_rule_name in initial: + # Rules with a value of None should be set to the default value + if node_rules[initial_rule_name] is None: # type: ignore[literal-required] + # Exclude non default values + # rule[0] is the specificity, rule[0][0] is 0 for default rules + default_rules = [ + rule + for rule in rule_attributes[initial_rule_name] + if not rule[0][0] + ] + if default_rules: + # There is a default value + new_value = max(default_rules, key=get_first_item)[1] + node_rules[initial_rule_name] = new_value # type: ignore[literal-required] + else: + # No default value + initial_defaults.add(initial_rule_name) + + # Rules in DEFAULT_CSS set to initial + for initial_rule_name in initial_defaults: + if node_rules[initial_rule_name] is None: # type: ignore[literal-required] + default_rules = [ + rule + for rule in rule_attributes[initial_rule_name] + if rule[0][0] + ] + if default_rules: + # There is a default value + rule_value = max(default_rules, key=get_first_item)[1] + else: + rule_value = getattr(_DEFAULT_STYLES, initial_rule_name) + node_rules[initial_rule_name] = rule_value # type: ignore[literal-required] + + if cache_key is not None: + cache[cache_key] = node_rules + self.replace_rules(node, node_rules, animate=animate) + self._process_component_classes(node) + + def _process_component_classes(self, node: DOMNode) -> None: + """Process component classes for the given node. + + Args: + node: A DOM Node. + """ + component_classes = node._get_component_classes() + if component_classes: + # Create virtual nodes that exist to extract styles + refresh_node = False + old_component_styles = node._component_styles.copy() + node._component_styles.clear() + for component in sorted(component_classes): + virtual_node = DOMNode(classes=component) + virtual_node._attach(node) + self.apply(virtual_node, animate=False) + if ( + not refresh_node + and old_component_styles.get(component) != virtual_node.styles + ): + # If the styles have changed we want to refresh the node + refresh_node = True + node._component_styles[component] = virtual_node.styles + if refresh_node: + node.refresh() + + @classmethod + def replace_rules( + cls, node: DOMNode, rules: RulesMap, animate: bool = False + ) -> None: + """Replace style rules on a node, animating as required. + + Args: + node: A DOM node. + rules: Mapping of rules. + animate: Enable animation. + """ + + # Alias styles and base styles + styles = node.styles + base_styles = styles.base + + # Styles currently used on new rules + modified_rule_keys = base_styles._rules.keys() | rules.keys() + + if animate: + new_styles = Styles(node, rules) + if new_styles == base_styles: + # Nothing to animate, return early + return + current_render_rules = styles.get_render_rules() + is_animatable = styles.is_animatable + get_current_render_rule = current_render_rules.get + new_render_rules = new_styles.get_render_rules() + get_new_render_rule = new_render_rules.get + animator = node.app.animator + base = node.styles.base + for key in modified_rule_keys: + # Get old and new render rules + old_render_value = get_current_render_rule(key) + new_render_value = get_new_render_rule(key) + # Get new rule value (may be None) + new_value = rules.get(key) + + # Check if this can / should be animated. It doesn't suffice to check + # if the current and target values are different because a previous + # animation may have been scheduled but may have not started yet. + if is_animatable(key) and ( + new_render_value != old_render_value + or animator.is_being_animated(base, key) + ): + transition = new_styles._get_transition(key) + if transition is not None: + duration, easing, delay = transition + animator.animate( + base, + key, + new_render_value, + final_value=new_value, + duration=duration, + delay=delay, + easing=easing, + ) + continue + # Default is to set value (if new_value is None, rule will be removed) + setattr(base_styles, key, new_value) + else: + # Not animated, so we apply the rules directly + get_rule = rules.get + + for key in modified_rule_keys: + setattr(base_styles, key, get_rule(key)) + node.notify_style_update() + + def update(self, root: DOMNode, animate: bool = False) -> None: + """Update styles on node and its children. + + Args: + root: Root note to update. + animate: Enable CSS animation. + """ + + self.update_nodes(root.walk_children(with_self=True), animate=animate) + + def update_nodes(self, nodes: Iterable[DOMNode], animate: bool = False) -> None: + """Update styles for nodes. + + Args: + nodes: Nodes to update. + animate: Enable CSS animation. + """ + cache: dict[tuple, RulesMap] = {} + apply = self.apply + + for node in nodes: + apply(node, animate=animate, cache=cache) + if isinstance(node, Widget) and node.is_scrollable: + show_vertical_scrollbar = ( + node.show_vertical_scrollbar and node.scrollbar_size_vertical + ) + show_horizontal_scrollbar = ( + node.show_horizontal_scrollbar and node.scrollbar_size_horizontal + ) + if show_vertical_scrollbar: + apply(node.vertical_scrollbar, cache=cache) + if show_horizontal_scrollbar: + apply(node.horizontal_scrollbar, cache=cache) + if show_horizontal_scrollbar and show_vertical_scrollbar: + apply(node.scrollbar_corner, cache=cache) diff --git a/src/memray/_vendor/textual/css/tokenize.py b/src/memray/_vendor/textual/css/tokenize.py new file mode 100644 index 0000000000..fdbcece2c4 --- /dev/null +++ b/src/memray/_vendor/textual/css/tokenize.py @@ -0,0 +1,349 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, ClassVar, Iterable + +from memray._vendor.textual.css.tokenizer import Expect, Token, Tokenizer + +if TYPE_CHECKING: + from memray._vendor.textual.css.types import CSSLocation + +PERCENT = r"-?\d+\.?\d*%" +DECIMAL = r"-?\d+\.?\d*" +COMMA = r"\s*,\s*" +OPEN_BRACE = r"\(\s*" +CLOSE_BRACE = r"\s*\)" + +HEX_COLOR = r"\#[0-9a-fA-F]{8}|\#[0-9a-fA-F]{6}|\#[0-9a-fA-F]{4}|\#[0-9a-fA-F]{3}" +RGB_COLOR = rf"rgb{OPEN_BRACE}{DECIMAL}{COMMA}{DECIMAL}{COMMA}{DECIMAL}{CLOSE_BRACE}|rgba{OPEN_BRACE}{DECIMAL}{COMMA}{DECIMAL}{COMMA}{DECIMAL}{COMMA}{DECIMAL}{CLOSE_BRACE}" +HSL_COLOR = rf"hsl{OPEN_BRACE}{DECIMAL}{COMMA}{PERCENT}{COMMA}{PERCENT}{CLOSE_BRACE}|hsla{OPEN_BRACE}{DECIMAL}{COMMA}{PERCENT}{COMMA}{PERCENT}{COMMA}{DECIMAL}{CLOSE_BRACE}" + +COMMENT_LINE = r"\# .*$" +COMMENT_START = r"\/\*" +SCALAR = rf"{DECIMAL}(?:fr|%|w|h|vw|vh)" +DURATION = r"\d+\.?\d*(?:ms|s)" +NUMBER = r"\-?\d+\.?\d*" +COLOR = rf"{HEX_COLOR}|{RGB_COLOR}|{HSL_COLOR}" +KEY_VALUE = r"[a-zA-Z_-][a-zA-Z0-9_-]*=[0-9a-zA-Z_\-\/]+" +TOKEN = "[a-zA-Z_][a-zA-Z0-9_-]*" +STRING = r"\".*?\"" +VARIABLE_REF = r"\$[a-zA-Z0-9_\-]+" + +IDENTIFIER = r"[a-zA-Z_\-][a-zA-Z0-9_\-]*" +SELECTOR_TYPE_NAME = r"[A-Z_][a-zA-Z0-9_]*" +"""Selectors representing Widget type names should start with upper case or '_'. + +The fact that a selector starts with an upper case letter or '_' is relevant in the +context of nested CSS to help determine whether xxx:yyy is a declaration + value or a +selector + pseudo-class.""" +DECLARATION_NAME = r"[a-z][a-zA-Z0-9_\-]*" +"""Declaration of TCSS rules start with lowercase. + +The fact that a declaration starts with a lower case letter is relevant in the context +of nested CSS to help determine whether xxx:yyy is a declaration + value or a selector ++ pseudo-class. +""" + +# Values permitted in variable and rule declarations. +DECLARATION_VALUES = { + "scalar": SCALAR, + "duration": DURATION, + "number": NUMBER, + "color": COLOR, + "key_value": KEY_VALUE, + "token": TOKEN, + "string": STRING, + "variable_ref": VARIABLE_REF, +} + +# The tokenizers "expectation" while at the root/highest level of scope +# in the CSS file. At this level we might expect to see selectors, comments, +# variable definitions etc. +expect_root_scope = Expect( + "selector or end of file", + whitespace=r"\s+", + comment_start=COMMENT_START, + comment_line=COMMENT_LINE, + selector_start_id=r"\#" + IDENTIFIER, + selector_start_class=r"\." + IDENTIFIER, + selector_start_universal=r"\*", + selector_start=SELECTOR_TYPE_NAME, + variable_name=rf"{VARIABLE_REF}:", + declaration_set_end=r"\}", +).expect_eof(True) + +expect_root_nested = Expect( + "selector or end of file", + whitespace=r"\s+", + comment_start=COMMENT_START, + comment_line=COMMENT_LINE, + declaration_name=DECLARATION_NAME + r"\:", + selector_start_id=r"\#" + IDENTIFIER, + selector_start_class=r"\." + IDENTIFIER, + selector_start_universal=r"\*", + selector_start=SELECTOR_TYPE_NAME, + variable_name=rf"{VARIABLE_REF}:", + declaration_set_end=r"\}", + nested=r"\&", +) + +# After a variable declaration e.g. "$warning-text: TOKENS;" +# for tokenizing variable value ------^~~~~~~^ +expect_variable_name_continue = Expect( + "variable value", + variable_value_end=r"\n|;", + whitespace=r"\s+", + comment_start=COMMENT_START, + comment_line=COMMENT_LINE, + **DECLARATION_VALUES, +).expect_eof(True) + +expect_comment_end = Expect( + "comment end", + comment_end=re.escape("*/"), +) + +# After we come across a selector in CSS e.g. ".my-class", we may +# find other selectors, pseudo-classes... e.g. ".my-class :hover" +expect_selector_continue = Expect( + "selector or {", + whitespace=r"\s+", + comment_start=COMMENT_START, + comment_line=COMMENT_LINE, + pseudo_class=r"\:[a-zA-Z_-]+", + selector_id=r"\#" + IDENTIFIER, + selector_class=r"\." + IDENTIFIER, + selector_universal=r"\*", + selector=SELECTOR_TYPE_NAME, + combinator_child=">", + new_selector=r",", + declaration_set_start=r"\{", + declaration_set_end=r"\}", + nested=r"\&", +).expect_eof(True) + +# A rule declaration e.g. "text: red;" +# ^---^ +expect_declaration = Expect( + "rule or selector", + nested=r"\&", + whitespace=r"\s+", + comment_start=COMMENT_START, + comment_line=COMMENT_LINE, + declaration_name=DECLARATION_NAME + r"\:", + declaration_set_end=r"\}", + # + selector_start_id=r"\#" + IDENTIFIER, + selector_start_class=r"\." + IDENTIFIER, + selector_start_universal=r"\*", + selector_start=SELECTOR_TYPE_NAME, +) + +expect_declaration_solo = Expect( + "rule declaration", + whitespace=r"\s+", + comment_start=COMMENT_START, + comment_line=COMMENT_LINE, + declaration_name=DECLARATION_NAME + r"\:", + declaration_set_end=r"\}", +).expect_eof(True) + +# The value(s)/content from a rule declaration e.g. "text: red;" +# ^---^ +expect_declaration_content = Expect( + "rule value or end of declaration", + declaration_end=r";", + whitespace=r"\s+", + comment_start=COMMENT_START, + comment_line=COMMENT_LINE, + **DECLARATION_VALUES, + important=r"\!important", + comma=",", + declaration_set_end=r"\}", +) + +expect_declaration_content_solo = Expect( + "rule value or end of declaration", + declaration_end=r";", + whitespace=r"\s+", + comment_start=COMMENT_START, + comment_line=COMMENT_LINE, + **DECLARATION_VALUES, + important=r"\!important", + comma=",", + declaration_set_end=r"\}", +).expect_eof(True) + + +class TokenizerState: + EXPECT: ClassVar[Expect] = expect_root_scope + STATE_MAP: ClassVar[dict[str, Expect]] = {} + STATE_PUSH: ClassVar[dict[str, Expect]] = {} + STATE_POP: ClassVar[dict[str, str]] = {} + + def __init__(self) -> None: + self._expect: Expect = self.EXPECT + super().__init__() + + def expect(self, expect: Expect) -> None: + self._expect = expect + + def __call__(self, code: str, read_from: CSSLocation) -> Iterable[Token]: + tokenizer = Tokenizer(code, read_from=read_from) + get_token = tokenizer.get_token + get_state = self.STATE_MAP.get + state_stack: list[Expect] = [] + + while True: + expect = self._expect + token = get_token(expect) + name = token.name + if name in self.STATE_MAP: + self._expect = get_state(token.name, expect) + elif name in self.STATE_PUSH: + self._expect = self.STATE_PUSH[name] + state_stack.append(expect) + elif name in self.STATE_POP: + if state_stack: + self._expect = state_stack.pop() + else: + self._expect = self.EXPECT + token = token._replace(name="end_tag") + yield token + continue + + yield token + if name == "eof": + break + + +class TCSSTokenizerState: + """State machine for the tokenizer. + + Attributes: + EXPECT: The initial expectation of the tokenizer. Since we start tokenizing + at the root scope, we might expect to see either a variable or selector, for example. + STATE_MAP: Maps token names to Expects, defines the sets of valid tokens + that we'd expect to see next, given the current token. For example, if + we've just processed a variable declaration name, we next expect to see + the value of that variable. + """ + + EXPECT = expect_root_scope + STATE_MAP = { + "variable_name": expect_variable_name_continue, + "variable_value_end": expect_root_scope, + "selector_start": expect_selector_continue, + "selector_start_id": expect_selector_continue, + "selector_start_class": expect_selector_continue, + "selector_start_universal": expect_selector_continue, + "selector_id": expect_selector_continue, + "selector_class": expect_selector_continue, + "selector_universal": expect_selector_continue, + "declaration_set_start": expect_declaration, + "declaration_name": expect_declaration_content, + "declaration_end": expect_declaration, + "declaration_set_end": expect_root_nested, + "nested": expect_selector_continue, + } + + def __call__(self, code: str, read_from: CSSLocation) -> Iterable[Token]: + tokenizer = Tokenizer(code, read_from=read_from) + expect = self.EXPECT + get_token = tokenizer.get_token + get_state = self.STATE_MAP.get + nest_level = 0 + while True: + token = get_token(expect) + name = token.name + if name == "comment_line": + continue + elif name == "comment_start": + tokenizer.skip_to(expect_comment_end) + continue + elif name == "eof": + break + elif name == "declaration_set_start": + nest_level += 1 + elif name == "declaration_set_end": + nest_level -= 1 + expect = expect_declaration if nest_level else expect_root_scope + yield token + continue + expect = get_state(name, expect) + yield token + + +class DeclarationTokenizerState(TCSSTokenizerState): + EXPECT = expect_declaration_solo + STATE_MAP = { + "declaration_name": expect_declaration_content, + "declaration_end": expect_declaration_solo, + } + + +class ValueTokenizerState(TCSSTokenizerState): + EXPECT = expect_declaration_content_solo + + +class StyleTokenizerState(TCSSTokenizerState): + EXPECT = ( + Expect( + "style token", + key_value=r"[@a-zA-Z_-][a-zA-Z0-9_-]*=.*", + key_value_quote=r"[@a-zA-Z_-][a-zA-Z0-9_-]*='.*'", + key_value_double_quote=r"""[@a-zA-Z_-][a-zA-Z0-9_-]*=".*\"""", + percent=PERCENT, + color=COLOR, + token=TOKEN, + variable_ref=VARIABLE_REF, + whitespace=r"\s+", + ) + .expect_eof(True) + .expect_semicolon(False) + ) + + +tokenize = TCSSTokenizerState() +tokenize_declarations = DeclarationTokenizerState() +tokenize_value = ValueTokenizerState() +tokenize_style = StyleTokenizerState() + + +def tokenize_values(values: dict[str, str]) -> dict[str, list[Token]]: + """Tokenizes the values in a dict of strings. + + Args: + values: A mapping of CSS variable name on to a value, to be + added to the CSS context. + + Returns: + A mapping of name on to a list of tokens, + """ + value_tokens = { + name: list(tokenize_value(value, ("__name__", ""))) + for name, value in values.items() + } + return value_tokens + + +if __name__ == "__main__": + text = "[@click=app.notify(['foo', 500])] Click me! [/] :-)" + + # text = "[@click=hello]Click" + from rich.console import Console + + c = Console(markup=False) + + from memray._vendor.textual._profile import timer + + with timer("tokenize"): + list(tokenize_markup(text, read_from=("", ""))) + + from memray._vendor.textual.markup import _parse + + with timer("_parse"): + list(_parse(text)) + + for token in tokenize_markup(text, read_from=("", "")): + c.print(repr(token)) diff --git a/src/memray/_vendor/textual/css/tokenizer.py b/src/memray/_vendor/textual/css/tokenizer.py new file mode 100644 index 0000000000..9a8fcee43a --- /dev/null +++ b/src/memray/_vendor/textual/css/tokenizer.py @@ -0,0 +1,384 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, NamedTuple + +import rich.repr +from rich.console import Group, RenderableType +from rich.highlighter import ReprHighlighter +from rich.padding import Padding +from rich.panel import Panel +from rich.text import Text + +from memray._vendor.textual.css._error_tools import friendly_list +from memray._vendor.textual.css.constants import VALID_PSEUDO_CLASSES +from memray._vendor.textual.suggestions import get_suggestion + +if TYPE_CHECKING: + from memray._vendor.textual.css.types import CSSLocation + + +class TokenError(Exception): + """Error raised when the CSS cannot be tokenized (syntax error).""" + + def __init__( + self, + read_from: CSSLocation, + code: str, + start: tuple[int, int], + message: str, + end: tuple[int, int] | None = None, + ) -> None: + """ + Args: + read_from: The location where the CSS was read from. + code: The code being parsed. + start: Line and column number of the error (1-indexed). + message: A message associated with the error. + end: End location of token (1-indexed), or None if not known. + """ + + self.read_from = read_from + self.code = code + self.start = start + self.end = end or start + super().__init__(message) + + def _get_snippet(self) -> Panel: + """Get a short snippet of code around a given line number. + + Returns: + A renderable. + """ + from rich.syntax import Syntax + + line_no = self.start[0] + # TODO: Highlight column number + syntax = Syntax( + self.code, + lexer="scss", + theme="ansi_light", + line_numbers=True, + indent_guides=True, + line_range=(max(0, line_no - 2), line_no + 2), + highlight_lines={line_no}, + ) + syntax.stylize_range( + "reverse bold", + (self.start[0], self.start[1] - 1), + (self.end[0], self.end[1] - 1), + ) + return Panel(syntax, border_style="red") + + def __rich__(self) -> RenderableType: + highlighter = ReprHighlighter() + errors: list[RenderableType] = [] + + message = str(self) + errors.append(Text(" Error in stylesheet:", style="bold red")) + + line_no, col_no = self.start + + path, widget_variable = self.read_from + if widget_variable: + css_location = f" {path}, {widget_variable}:{line_no}:{col_no}" + else: + css_location = f" {path}:{line_no}:{col_no}" + errors.append(highlighter(css_location)) + errors.append(self._get_snippet()) + + final_message = "\n".join( + f"• {message_part.strip()}" for message_part in message.split(";") + ) + errors.append( + Padding( + highlighter( + Text(final_message, "red"), + ), + pad=(0, 1), + ) + ) + + return Group(*errors) + + +class UnexpectedEnd(TokenError): + """Indicates that the text being tokenized ended prematurely.""" + + +@rich.repr.auto +class Expect: + """Object that describes the format of tokens.""" + + def __init__(self, description: str, **tokens: str) -> None: + """Create Expect object. + + Args: + description: Description of this class of tokens, used in errors. + """ + self.description = f"Expected {description}" + self.names = list(tokens.keys()) + self.regexes = list(tokens.values()) + self._regex = re.compile( + "(" + + "|".join(f"(?P<{name}>{regex})" for name, regex in tokens.items()) + + ")" + ) + self.match = self._regex.match + self.search = self._regex.search + self._expect_eof = False + self._expect_semicolon = True + self._extract_text = False + + def expect_eof(self, eof: bool = True) -> Expect: + """Expect an end of file.""" + self._expect_eof = eof + return self + + def expect_semicolon(self, semicolon: bool = True) -> Expect: + """Tokenizer expects text to be terminated with a semi-colon.""" + self._expect_semicolon = semicolon + return self + + def extract_text(self, extract: bool = True) -> Expect: + self._extract_text = extract + return self + + def __rich_repr__(self) -> rich.repr.Result: + yield from zip(self.names, self.regexes) + + +class ReferencedBy(NamedTuple): + name: str + location: tuple[int, int] + length: int + code: str + + +@rich.repr.auto(angular=True) +class Token(NamedTuple): + name: str + value: str + read_from: CSSLocation + code: str + location: tuple[int, int] + """Token starting location, 0-indexed.""" + referenced_by: ReferencedBy | None = None + + @property + def start(self) -> tuple[int, int]: + """Start line and column (1-indexed).""" + line, offset = self.location + return (line + 1, offset + 1) + + @property + def end(self) -> tuple[int, int]: + """End line and column (1-indexed).""" + line, offset = self.location + return (line + 1, offset + len(self.value) + 1) + + def with_reference(self, by: ReferencedBy | None) -> "Token": + """Return a copy of the Token, with reference information attached. + This is used for variable substitution, where a variable reference + can refer to tokens which were defined elsewhere. With the additional + ReferencedBy data attached, we can track where the token we are referring + to is used. + """ + return Token( + name=self.name, + value=self.value, + read_from=self.read_from, + code=self.code, + location=self.location, + referenced_by=by, + ) + + def __str__(self) -> str: + return self.value + + def __rich_repr__(self) -> rich.repr.Result: + yield "name", self.name + yield "value", self.value + yield ( + "read_from", + self.read_from[0] if not self.read_from[1] else self.read_from, + ) + yield "code", self.code if len(self.code) < 40 else self.code[:40] + "..." + yield "location", self.location + yield "referenced_by", self.referenced_by, None + + +class Tokenizer: + """Tokenizes Textual CSS.""" + + def __init__(self, text: str, read_from: CSSLocation = ("", "")) -> None: + """Initialize the tokenizer. + + Args: + text: String containing CSS. + read_from: Information regarding where the CSS was read from. + """ + self.read_from = read_from + self.code = text + self.lines = text.splitlines(keepends=True) + self.line_no = 0 + self.col_no = 0 + + def get_token(self, expect: Expect) -> Token: + """Get the next token. + + Args: + expect: Expect object which describes which tokens may be read. + + Raises: + UnexpectedEnd: If there is an unexpected end of file. + TokenError: If there is an error with the token. + + Returns: + A new Token. + """ + + line_no = self.line_no + col_no = self.col_no + if line_no >= len(self.lines): + if expect._expect_eof: + return Token( + "eof", + "", + self.read_from, + self.code, + (line_no, col_no), + None, + ) + else: + raise UnexpectedEnd( + self.read_from, + self.code, + (line_no + 1, col_no + 1), + ( + "Unexpected end of file; did you forget a '}' ?" + if expect._expect_semicolon + else "Unexpected end of text" + ), + ) + line = self.lines[line_no] + preceding_text: str = "" + if expect._extract_text: + match = expect.search(line, col_no) + if match is None: + preceding_text = line[self.col_no :] + self.line_no += 1 + self.col_no = 0 + else: + col_no = match.start() + preceding_text = line[self.col_no : col_no] + self.col_no = col_no + if preceding_text: + token = Token( + "text", + preceding_text, + self.read_from, + self.code, + (line_no, col_no), + referenced_by=None, + ) + + return token + + else: + match = expect.match(line, col_no) + + if match is None: + error_line = line[col_no:] + error_message = ( + f"{expect.description} (found {error_line.split(';')[0]!r})." + ) + if expect._expect_semicolon and not error_line.endswith(";"): + error_message += "; Did you forget a semicolon at the end of a line?" + raise TokenError( + self.read_from, self.code, (line_no + 1, col_no + 1), error_message + ) + + for name, value in zip(expect.names, match.groups()[1:]): + if value is not None: + break + else: + # For MyPy's benefit + raise AssertionError("can't reach here") + + token = Token( + name, + value, + self.read_from, + self.code, + (line_no, col_no), + referenced_by=None, + ) + + if ( + token.name == "pseudo_class" + and token.value.strip(":") not in VALID_PSEUDO_CLASSES + ): + pseudo_class = token.value.strip(":") + suggestion = get_suggestion(pseudo_class, list(VALID_PSEUDO_CLASSES)) + all_valid = f"must be one of {friendly_list(VALID_PSEUDO_CLASSES)}" + if suggestion: + raise TokenError( + self.read_from, + self.code, + (line_no + 1, col_no + 1), + f"unknown pseudo-class {pseudo_class!r}; did you mean {suggestion!r}?; {all_valid}", + ) + else: + raise TokenError( + self.read_from, + self.code, + (line_no + 1, col_no + 1), + f"unknown pseudo-class {pseudo_class!r}; {all_valid}", + ) + + col_no += len(value) + if col_no >= len(line): + line_no += 1 + col_no = 0 + self.line_no = line_no + self.col_no = col_no + return token + + def skip_to(self, expect: Expect) -> Token: + """Skip tokens. + + Args: + expect: Expect object describing the expected token. + + Raises: + UnexpectedEndOfText: If end of file is reached. + + Returns: + A new token. + """ + line_no = self.line_no + col_no = self.col_no + + while True: + if line_no >= len(self.lines): + raise UnexpectedEnd( + self.read_from, + self.code, + (line_no, col_no), + ( + "Unexpected end of file; did you forget a '}' ?" + if expect._expect_semicolon + else "Unexpected end of markup" + ), + ) + line = self.lines[line_no] + match = expect.search(line, col_no) + + if match is None: + line_no += 1 + col_no = 0 + else: + self.line_no = line_no + self.col_no = match.span(0)[0] + return self.get_token(expect) diff --git a/src/memray/_vendor/textual/css/transition.py b/src/memray/_vendor/textual/css/transition.py new file mode 100644 index 0000000000..9d9facf102 --- /dev/null +++ b/src/memray/_vendor/textual/css/transition.py @@ -0,0 +1,16 @@ +from typing import NamedTuple + + +class Transition(NamedTuple): + duration: float = 1.0 + easing: str = "linear" + delay: float = 0.0 + + def __str__(self) -> str: + duration, easing, delay = self + if delay: + return f"{duration:.1f}s {easing} {delay:.1f}" + elif easing != "linear": + return f"{duration:.1f}s {easing}" + else: + return f"{duration:.1f}s" diff --git a/src/memray/_vendor/textual/css/types.py b/src/memray/_vendor/textual/css/types.py new file mode 100644 index 0000000000..56d0fbb212 --- /dev/null +++ b/src/memray/_vendor/textual/css/types.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from typing import Tuple + +from typing_extensions import Literal + +from memray._vendor.textual.color import Color + +DockEdge = Literal["none", "top", "right", "bottom", "left"] +EdgeType = Literal[ + "", + "ascii", + "none", + "hidden", + "blank", + "round", + "solid", + "thick", + "block", + "double", + "dashed", + "heavy", + "inner", + "outer", + "hkey", + "vkey", + "tall", + "tab", + "panel", + "wide", +] +Visibility = Literal["visible", "hidden", "initial", "inherit"] +Display = Literal["block", "none"] +AlignHorizontal = Literal["left", "center", "right"] +AlignVertical = Literal["top", "middle", "bottom"] +ScrollbarGutter = Literal["auto", "stable"] +BoxSizing = Literal["border-box", "content-box"] +Overflow = Literal["scroll", "hidden", "auto"] +EdgeStyle = Tuple[EdgeType, Color] +TextAlign = Literal["left", "start", "center", "right", "end", "justify"] +Constrain = Literal["none", "inflect", "inside"] +Overlay = Literal["none", "screen"] +Position = Literal["relative", "absolute"] +PointerShape = Literal[ + "alias", + "cell", + "copy", + "crosshair", + "default", + "e-resize", + "ew-resize", + "grab", + "grabbing", + "help", + "move", + "n-resize", + "ne-resize", + "nesw-resize", + "no-drop", + "not-allowed", + "ns-resize", + "nw-resize", + "nwse-resize", + "pointer", + "progress", + "s-resize", + "se-resize", + "sw-resize", + "text", + "vertical-text", + "w-resize", + "wait", + "zoom-in", + "zoom-out", +] + +TextWrap = Literal["wrap", "nowrap"] +TextOverflow = Literal["clip", "fold", "ellipsis"] +Expand = Literal["greedy", "expand"] +ScrollbarVisibility = Literal["visible", "hidden"] + +Specificity3 = Tuple[int, int, int] +Specificity6 = Tuple[int, int, int, int, int, int] + +CSSLocation = Tuple[str, str] +"""Represents the definition location of a piece of CSS code. + +The first element of the tuple is the file path from where the CSS was read. +If the CSS was read from a Python source file, the second element contains the class +variable from where the CSS was read (e.g., "Widget.DEFAULT_CSS"), otherwise it's an +empty string. +""" diff --git a/src/memray/_vendor/textual/demo/__main__.py b/src/memray/_vendor/textual/demo/__main__.py new file mode 100644 index 0000000000..db96cfb63b --- /dev/null +++ b/src/memray/_vendor/textual/demo/__main__.py @@ -0,0 +1,5 @@ +from memray._vendor.textual.demo.demo_app import DemoApp + +if __name__ == "__main__": + app = DemoApp() + app.run() diff --git a/src/memray/_vendor/textual/demo/_project_data.py b/src/memray/_vendor/textual/demo/_project_data.py new file mode 100644 index 0000000000..6dcb00c8f7 --- /dev/null +++ b/src/memray/_vendor/textual/demo/_project_data.py @@ -0,0 +1,107 @@ +from dataclasses import dataclass + + +@dataclass +class ProjectInfo: + """Dataclass for storing project information.""" + + title: str + author: str + url: str + description: str + repo_url_part: str + + +PROJECTS = [ + ProjectInfo( + "Posting", + "Darren Burns", + "https://posting.sh/", + "Posting is an HTTP client, not unlike Postman and Insomnia. As a TUI application, it can be used over SSH and enables efficient keyboard-centric workflows. ", + "darrenburns/posting", + ), + ProjectInfo( + "Memray", + "Bloomberg", + "https://github.com/bloomberg/memray", + "Memray is a memory profiler for Python. It can track memory allocations in Python code, in native extension modules, and in the Python interpreter itself.", + "bloomberg/memray", + ), + ProjectInfo( + "Toolong", + "Will McGugan", + "https://github.com/Textualize/toolong", + "A terminal application to view, tail, merge, and search log files (plus JSONL).", + "Textualize/toolong", + ), + ProjectInfo( + "Dolphie", + "Charles Thompson", + "https://github.com/charles-001/dolphie", + "Your single pane of glass for real-time analytics into MySQL/MariaDB & ProxySQL", + "charles-001/dolphie", + ), + ProjectInfo( + "Harlequin", + "Ted Conbeer", + "https://harlequin.sh/", + "Portable, powerful, colorful. An easy, fast, and beautiful database client for the terminal.", + "tconbeer/harlequin", + ), + ProjectInfo( + "Elia", + "Darren Burns", + "https://github.com/darrenburns/elia", + "A snappy, keyboard-centric terminal user interface for interacting with large language models.", + "darrenburns/elia", + ), + ProjectInfo( + "Trogon", + "Textualize", + "https://github.com/Textualize/trogon", + "Auto-generate friendly terminal user interfaces for command line apps.", + "Textualize/trogon", + ), + ProjectInfo( + "TFTUI - The Terraform textual UI", + "Ido Avraham", + "https://github.com/idoavrah/terraform-tui", + "TFTUI is a powerful textual UI that empowers users to effortlessly view and interact with their Terraform state.", + "idoavrah/terraform-tui", + ), + ProjectInfo( + "RecoverPy", + "Pablo Lecolinet", + "https://github.com/PabloLec/RecoverPy", + "RecoverPy is a powerful tool that leverages your system capabilities to recover lost files.", + "PabloLec/RecoverPy", + ), + ProjectInfo( + "Frogmouth", + "Dave Pearson", + "https://github.com/Textualize/frogmouth", + "Frogmouth is a Markdown viewer / browser for your terminal, built with Textual.", + "Textualize/frogmouth", + ), + ProjectInfo( + "oterm", + "Yiorgis Gozadinos", + "https://github.com/ggozad/oterm", + "The text-based terminal client for Ollama.", + "ggozad/oterm", + ), + ProjectInfo( + "logmerger", + "Paul McGuire", + "https://github.com/ptmcg/logmerger", + "logmerger is a TUI for viewing a merged display of multiple log files, merged by timestamp.", + "ptmcg/logmerger", + ), + ProjectInfo( + "doit", + "Murli Tawari", + "https://github.com/dooit-org/dooit", + "A todo manager that you didn't ask for, but needed!", + "dooit-org/dooit", + ), +] diff --git a/src/memray/_vendor/textual/demo/_project_stargazer_updater.py b/src/memray/_vendor/textual/demo/_project_stargazer_updater.py new file mode 100644 index 0000000000..ab2d0f0c92 --- /dev/null +++ b/src/memray/_vendor/textual/demo/_project_stargazer_updater.py @@ -0,0 +1,50 @@ +import httpx +import os +import json +from rich.console import Console + +# Not using the Absolute reference because +# I can't get python to run it. +from _project_data import PROJECTS + +console = Console() +error_console = Console(stderr=True, style="bold red") + + +def main() -> None: + STARS = {} + + for project in PROJECTS: + # get each repo + console.log(f"Checking {project.repo_url_part}") + response = httpx.get(f"https://api.github.com/repos/{project.repo_url_part}") + if response.status_code == 200: + # get stargazers + stargazers = response.json()["stargazers_count"] + if stargazers // 1000 != 0: + # humanize them + stargazers = f"{stargazers / 1000:.1f}k" + else: + stargazers = str(stargazers) + STARS[project.title] = stargazers + elif response.status_code == 403: + # gh api rate limited + error_console.log( + "GitHub has received too many requests and started rate limiting." + ) + exit(1) + else: + # any other reason + print( + f"GET https://api.github.com/repos/{project.repo_url_part} returned status code {response.status_code}" + ) + # replace + with open( + os.path.join(os.path.dirname(__file__), "_project_stars.py"), "w" + ) as file: + file.write("STARS = " + json.dumps(STARS, indent=4)) + console.log("Done!") + + +if __name__ == "__main__": + main() diff --git a/src/memray/_vendor/textual/demo/_project_stars.py b/src/memray/_vendor/textual/demo/_project_stars.py new file mode 100644 index 0000000000..e187559ee3 --- /dev/null +++ b/src/memray/_vendor/textual/demo/_project_stars.py @@ -0,0 +1,15 @@ +STARS = { + "Posting": "11.4k", + "Memray": "14.9k", + "Toolong": "3.9k", + "Dolphie": "1.1k", + "Harlequin": "5.8k", + "Elia": "2.4k", + "Trogon": "2.8k", + "TFTUI - The Terraform textual UI": "1.3k", + "RecoverPy": "1.7k", + "Frogmouth": "3.1k", + "oterm": "2.3k", + "logmerger": "250", + "doit": "2.8k", +} diff --git a/src/memray/_vendor/textual/demo/data.py b/src/memray/_vendor/textual/demo/data.py new file mode 100644 index 0000000000..01ca66e58d --- /dev/null +++ b/src/memray/_vendor/textual/demo/data.py @@ -0,0 +1,458 @@ +import json + +COUNTRIES = [ + "Afghanistan", + "Albania", + "Algeria", + "Andorra", + "Angola", + "Antigua and Barbuda", + "Argentina", + "Armenia", + "Australia", + "Austria", + "Azerbaijan", + "Bahamas", + "Bahrain", + "Bangladesh", + "Barbados", + "Belarus", + "Belgium", + "Belize", + "Benin", + "Bhutan", + "Bolivia", + "Bosnia and Herzegovina", + "Botswana", + "Brazil", + "Brunei", + "Bulgaria", + "Burkina Faso", + "Burundi", + "Cabo Verde", + "Cambodia", + "Cameroon", + "Canada", + "Central African Republic", + "Chad", + "Chile", + "China", + "Colombia", + "Comoros", + "Congo", + "Costa Rica", + "Croatia", + "Cuba", + "Cyprus", + "Czech Republic", + "Democratic Republic of the Congo", + "Denmark", + "Djibouti", + "Dominica", + "Dominican Republic", + "East Timor", + "Ecuador", + "Egypt", + "El Salvador", + "Equatorial Guinea", + "Eritrea", + "Estonia", + "Eswatini", + "Ethiopia", + "Fiji", + "Finland", + "France", + "Gabon", + "Gambia", + "Georgia", + "Germany", + "Ghana", + "Greece", + "Grenada", + "Guatemala", + "Guinea", + "Guinea-Bissau", + "Guyana", + "Haiti", + "Honduras", + "Hungary", + "Iceland", + "India", + "Indonesia", + "Iran", + "Iraq", + "Ireland", + "Israel", + "Italy", + "Ivory Coast", + "Jamaica", + "Japan", + "Jordan", + "Kazakhstan", + "Kenya", + "Kiribati", + "Kuwait", + "Kyrgyzstan", + "Laos", + "Latvia", + "Lebanon", + "Lesotho", + "Liberia", + "Libya", + "Liechtenstein", + "Lithuania", + "Luxembourg", + "Madagascar", + "Malawi", + "Malaysia", + "Maldives", + "Mali", + "Malta", + "Marshall Islands", + "Mauritania", + "Mauritius", + "Mexico", + "Micronesia", + "Moldova", + "Monaco", + "Mongolia", + "Montenegro", + "Morocco", + "Mozambique", + "Myanmar", + "Namibia", + "Nauru", + "Nepal", + "Netherlands", + "New Zealand", + "Nicaragua", + "Niger", + "Nigeria", + "North Korea", + "North Macedonia", + "Norway", + "Oman", + "Pakistan", + "Palau", + "Palestine", + "Panama", + "Papua New Guinea", + "Paraguay", + "Peru", + "Philippines", + "Poland", + "Portugal", + "Qatar", + "Romania", + "Russia", + "Rwanda", + "Saint Kitts and Nevis", + "Saint Lucia", + "Saint Vincent and the Grenadines", + "Samoa", + "San Marino", + "Sao Tome and Principe", + "Saudi Arabia", + "Senegal", + "Serbia", + "Seychelles", + "Sierra Leone", + "Singapore", + "Slovakia", + "Slovenia", + "Solomon Islands", + "Somalia", + "South Africa", + "South Korea", + "South Sudan", + "Spain", + "Sri Lanka", + "Sudan", + "Suriname", + "Sweden", + "Switzerland", + "Syria", + "Taiwan", + "Tajikistan", + "Tanzania", + "Thailand", + "Togo", + "Tonga", + "Trinidad and Tobago", + "Tunisia", + "Turkey", + "Turkmenistan", + "Tuvalu", + "Uganda", + "Ukraine", + "United Arab Emirates", + "United Kingdom", + "United States", + "Uruguay", + "Uzbekistan", + "Vanuatu", + "Vatican City", + "Venezuela", + "Vietnam", + "Yemen", + "Zambia", + "Zimbabwe", +] +# Sort by length for auto-complete +COUNTRIES.sort(key=str.__len__) + +# Thanks, Claude +MOVIES = """\ +Date,Title,Genre,Director,Box Office (millions),Rating,Runtime (min) +1980-01-18,The Fog,Horror,John Carpenter,21,R,89 +1980-02-15,Coal Miner's Daughter,Biography,Michael Apted,67,PG,124 +1980-03-07,Little Miss Marker,Comedy,Walter Bernstein,12,PG,103 +1980-04-11,The Long Riders,Western,Walter Hill,15,R,100 +1980-05-21,The Empire Strikes Back,Sci-Fi,Irvin Kershner,538,PG,124 +1980-06-13,The Blues Brothers,Comedy,John Landis,115,R,133 +1980-07-02,Airplane!,Comedy,Jim Abrahams,83,PG,88 +1980-08-01,Caddyshack,Comedy,Harold Ramis,39,R,98 +1980-09-19,The Big Red One,War,Samuel Fuller,24,PG,113 +1980-10-10,Private Benjamin,Comedy,Howard Zieff,69,R,109 +1980-11-07,The Stunt Man,Action,Richard Rush,7,R,131 +1980-12-19,Nine to Five,Comedy,Colin Higgins,103,PG,109 +1981-01-23,Scanners,Horror,David Cronenberg,14,R,103 +1981-02-20,The Final Conflict,Horror,Graham Baker,20,R,108 +1981-03-20,Raiders of the Lost Ark,Action,Steven Spielberg,389,PG,115 +1981-04-10,Excalibur,Fantasy,John Boorman,35,R,140 +1981-05-22,Outland,Sci-Fi,Peter Hyams,17,R,109 +1981-06-19,Superman II,Action,Richard Lester,108,PG,127 +1981-07-17,Escape from New York,Sci-Fi,John Carpenter,25,R,99 +1981-08-07,An American Werewolf in London,Horror,John Landis,30,R,97 +1981-09-25,Continental Divide,Romance,Michael Apted,15,PG,103 +1981-10-16,True Confessions,Drama,Ulu Grosbard,12,R,108 +1981-11-20,Time Bandits,Fantasy,Terry Gilliam,42,PG,116 +1981-12-04,Rollover,Drama,Alan J. Pakula,11,R,116 +1982-01-15,The Beast Within,Horror,Philippe Mora,7,R,98 +1982-02-12,Quest for Fire,Adventure,Jean-Jacques Annaud,20,R,100 +1982-03-19,Porky's,Comedy,Bob Clark,105,R,94 +1982-04-16,The Sword and the Sorcerer,Fantasy,Albert Pyun,39,R,99 +1982-05-14,Conan the Barbarian,Fantasy,John Milius,68,R,129 +1982-06-04,Star Trek II: The Wrath of Khan,Sci-Fi,Nicholas Meyer,97,PG,113 +1982-06-11,E.T. the Extra-Terrestrial,Sci-Fi,Steven Spielberg,792,PG,115 +1982-06-25,Blade Runner,Sci-Fi,Ridley Scott,33,R,117 +1982-07-16,The World According to Garp,Comedy-Drama,George Roy Hill,29,R,136 +1982-08-13,Fast Times at Ridgemont High,Comedy,Amy Heckerling,27,R,90 +1982-09-17,The Challenge,Action,John Frankenheimer,9,R,108 +1982-10-22,First Blood,Action,Ted Kotcheff,47,R,93 +1982-11-12,The Man from Snowy River,Western,George Miller,20,PG,102 +1982-12-08,48 Hrs.,Action,Walter Hill,79,R,96 +1983-01-21,The Entity,Horror,Sidney J. Furie,13,R,125 +1983-02-18,The Year of Living Dangerously,Drama,Peter Weir,10,PG,115 +1983-03-25,The Outsiders,Drama,Francis Ford Coppola,25,PG,91 +1983-04-22,Something Wicked This Way Comes,Horror,Jack Clayton,5,PG,95 +1983-05-25,Return of the Jedi,Sci-Fi,Richard Marquand,475,PG,131 +1983-06-17,Superman III,Action,Richard Lester,60,PG,125 +1983-07-15,Class,Comedy,Lewis John Carlino,21,R,98 +1983-08-19,Curse of the Pink Panther,Comedy,Blake Edwards,9,PG,109 +1983-09-23,The Big Chill,Drama,Lawrence Kasdan,56,R,105 +1983-10-07,The Right Stuff,Drama,Philip Kaufman,21,PG,193 +1983-11-04,Deal of the Century,Comedy,William Friedkin,10,PG,99 +1983-12-09,Scarface,Crime,Brian De Palma,65,R,170 +1984-01-13,Terms of Endearment,Drama,James L. Brooks,108,PG,132 +1984-02-17,Unfaithfully Yours,Comedy,Howard Zieff,12,PG,96 +1984-03-16,Splash,Romance,Ron Howard,69,PG,111 +1984-04-13,Friday the 13th: The Final Chapter,Horror,Joseph Zito,32,R,91 +1984-05-04,Sixteen Candles,Comedy,John Hughes,23,PG,93 +1984-06-08,Ghostbusters,Comedy,Ivan Reitman,295,PG,105 +1984-07-06,The Last Starfighter,Sci-Fi,Nick Castle,28,PG,101 +1984-08-10,Red Dawn,Action,John Milius,38,PG-13,114 +1984-09-14,All of Me,Comedy,Carl Reiner,40,PG,93 +1984-10-26,The Terminator,Sci-Fi,James Cameron,78,R,107 +1984-11-16,Missing in Action,Action,Joseph Zito,22,R,101 +1984-12-14,Dune,Sci-Fi,David Lynch,30,PG-13,137 +1985-01-18,A Nightmare on Elm Street,Horror,Wes Craven,25,R,91 +1985-02-15,The Breakfast Club,Drama,John Hughes,45,R,97 +1985-03-29,Mask,Drama,Peter Bogdanovich,42,PG-13,120 +1985-04-26,Code of Silence,Action,Andrew Davis,20,R,101 +1985-05-22,Rambo: First Blood Part II,Action,George P. Cosmatos,150,R,96 +1985-06-07,The Goonies,Adventure,Richard Donner,61,PG,114 +1985-07-03,Back to the Future,Sci-Fi,Robert Zemeckis,381,PG,116 +1985-08-16,Year of the Dragon,Crime,Michael Cimino,18,R,134 +1985-09-20,Invasion U.S.A.,Action,Joseph Zito,17,R,107 +1985-10-18,Silver Bullet,Horror,Daniel Attias,12,R,95 +1985-11-22,Rocky IV,Drama,Sylvester Stallone,127,PG,91 +1985-12-20,The Color Purple,Drama,Steven Spielberg,142,PG-13,154 +1986-01-17,Iron Eagle,Action,Sidney J. Furie,24,PG-13,117 +1986-02-21,Crossroads,Drama,Walter Hill,5,R,99 +1986-03-21,Highlander,Fantasy,Russell Mulcahy,12,R,116 +1986-04-18,Legend,Fantasy,Ridley Scott,15,PG,89 +1986-05-16,Top Gun,Action,Tony Scott,357,PG,110 +1986-06-27,Running Scared,Action,Peter Hyams,38,R,107 +1986-07-18,Aliens,Sci-Fi,James Cameron,131,R,137 +1986-08-08,Stand By Me,Drama,Rob Reiner,52,R,89 +1986-09-19,Blue Velvet,Mystery,David Lynch,8,R,120 +1986-10-24,The Name of the Rose,Mystery,Jean-Jacques Annaud,7,R,130 +1986-11-21,An American Tail,Animation,Don Bluth,47,G,80 +1986-12-19,Star Trek IV: The Voyage Home,Sci-Fi,Leonard Nimoy,109,PG,119 +1987-01-23,Critical Condition,Comedy,Michael Apted,19,R,98 +1987-02-20,Death Before Dishonor,Action,Terry Leonard,3,R,91 +1987-03-13,Lethal Weapon,Action,Richard Donner,65,R,110 +1987-04-10,Project X,Drama,Jonathan Kaplan,28,PG,108 +1987-05-22,Beverly Hills Cop II,Action,Tony Scott,276,R,100 +1987-06-19,Predator,Sci-Fi,John McTiernan,98,R,107 +1987-07-17,RoboCop,Action,Paul Verhoeven,53,R,102 +1987-08-14,No Way Out,Thriller,Roger Donaldson,35,R,114 +1987-09-18,Fatal Beauty,Action,Tom Holland,12,R,104 +1987-10-23,Fatal Attraction,Thriller,Adrian Lyne,320,R,119 +1987-11-13,Running Man,Sci-Fi,Paul Michael Glaser,38,R,101 +1987-12-18,Wall Street,Drama,Oliver Stone,43,R,126 +1988-01-15,Return of the Living Dead Part II,Horror,Ken Wiederhorn,9,R,89 +1988-02-12,Action Jackson,Action,Craig R. Baxley,20,R,96 +1988-03-18,D.O.A.,Thriller,Rocky Morton,12,R,96 +1988-04-29,Colors,Crime,Dennis Hopper,46,R,120 +1988-05-20,Willow,Fantasy,Ron Howard,57,PG,126 +1988-06-21,Big,Comedy,Penny Marshall,151,PG,104 +1988-07-15,Die Hard,Action,John McTiernan,140,R,132 +1988-08-05,Young Guns,Western,Christopher Cain,45,R,107 +1988-09-16,Moon Over Parador,Comedy,Paul Mazursky,11,PG-13,103 +1988-10-21,Halloween 4,Horror,Dwight H. Little,17,R,88 +1988-11-11,Child's Play,Horror,Tom Holland,33,R,87 +1988-12-21,Rain Man,Drama,Barry Levinson,172,R,133 +1989-01-13,Deep Star Six,Sci-Fi,Sean S. Cunningham,8,R,99 +1989-02-17,Bill & Ted's Excellent Adventure,Comedy,Stephen Herek,40,PG,90 +1989-03-24,Leviathan,Sci-Fi,George P. Cosmatos,15,R,98 +1989-04-14,Major League,Comedy,David S. Ward,49,R,107 +1989-05-24,Indiana Jones and the Last Crusade,Action,Steven Spielberg,474,PG-13,127 +1989-06-23,Batman,Action,Tim Burton,411,PG-13,126 +1989-07-07,Lethal Weapon 2,Action,Richard Donner,227,R,114 +1989-08-11,A Nightmare on Elm Street 5,Horror,Stephen Hopkins,22,R,89 +1989-09-22,Black Rain,Action,Ridley Scott,46,R,125 +1989-10-20,Look Who's Talking,Comedy,Amy Heckerling,140,PG-13,93 +1989-11-17,All Dogs Go to Heaven,Animation,Don Bluth,27,G,84 +1989-12-20,Tango & Cash,Action,Andrei Konchalovsky,63,R,104 +""" + +MOVIES_JSON = """{ + "decades": { + "1980s": { + "genres": { + "action": { + "franchises": { + "terminator": { + "name": "The Terminator", + "movies": [ + { + "title": "The Terminator", + "year": 1984, + "director": "James Cameron", + "stars": ["Arnold Schwarzenegger", "Linda Hamilton", "Michael Biehn"], + "boxOffice": 78371200, + "quotes": ["I'll be back", "Come with me if you want to live"] + } + ] + }, + "rambo": { + "name": "Rambo", + "movies": [ + { + "title": "First Blood", + "year": 1982, + "director": "Ted Kotcheff", + "stars": ["Sylvester Stallone", "Richard Crenna", "Brian Dennehy"], + "boxOffice": 47212904 + }, + { + "title": "Rambo: First Blood Part II", + "year": 1985, + "director": "George P. Cosmatos", + "stars": ["Sylvester Stallone", "Richard Crenna", "Charles Napier"], + "boxOffice": 150415432 + } + ] + } + }, + "standalone_classics": { + "die_hard": { + "title": "Die Hard", + "year": 1988, + "director": "John McTiernan", + "stars": ["Bruce Willis", "Alan Rickman", "Reginald VelJohnson"], + "boxOffice": 140700000, + "location": "Nakatomi Plaza", + "quotes": ["Yippee-ki-yay, motherf***er"] + }, + "predator": { + "title": "Predator", + "year": 1987, + "director": "John McTiernan", + "stars": ["Arnold Schwarzenegger", "Carl Weathers", "Jesse Ventura"], + "boxOffice": 98267558, + "location": "Val Verde jungle", + "quotes": ["Get to the chopper!"] + } + }, + "common_themes": [ + "Cold War politics", + "One man army", + "Revenge plots", + "Military operations", + "Law enforcement" + ], + "typical_elements": { + "weapons": ["M60 machine gun", "Desert Eagle", "Explosive arrows"], + "vehicles": ["Military helicopters", "Muscle cars", "Tanks"], + "locations": ["Urban jungle", "Actual jungle", "Industrial facilities"] + } + } + } + } + }, + "metadata": { + "total_movies": 4, + "date_compiled": "2024", + "box_office_total": 467654094, + "most_frequent_actor": "Arnold Schwarzenegger", + "most_frequent_director": "John McTiernan" + } +}""" + +MOVIES_TREE = json.loads(MOVIES_JSON) + +DUNE_BIOS = [ + { + "name": "Paul Atreides", + "description": "Heir to House Atreides who becomes the Fremen messiah Muad'Dib. Born with extraordinary mental abilities due to Bene Gesserit breeding program.", + }, + { + "name": "Lady Jessica", + "description": "Bene Gesserit concubine to Duke Leto and mother of Paul. Defied her order by bearing a son instead of a daughter, disrupting centuries of careful breeding.", + }, + { + "name": "Baron Vladimir Harkonnen", + "description": "Cruel and corpulent leader of House Harkonnen, sworn enemy of House Atreides. Known for his cunning and brutality in pursuing power.", + }, + { + "name": "Leto Atreides", + "description": "Noble Duke and father of Paul, known for his honor and just rule. Accepts governorship of Arrakis despite knowing it's likely a trap.", + }, + { + "name": "Stilgar", + "description": "Leader of the Fremen Sietch Tabr, becomes a loyal supporter of Paul. Skilled warrior who helps train Paul in Fremen ways.", + }, + { + "name": "Chani", + "description": "Fremen warrior and daughter of planetologist Liet-Kynes. Becomes Paul's concubine and true love after appearing in his prescient visions.", + }, + { + "name": "Thufir Hawat", + "description": "Mentat and Master of Assassins for House Atreides. Serves three generations of Atreides with his superhuman computational skills.", + }, + { + "name": "Duncan Idaho", + "description": "Swordmaster of the Ginaz, loyal to House Atreides. Known for his exceptional fighting skills and sacrifice to save Paul and Jessica.", + }, + { + "name": "Gurney Halleck", + "description": "Warrior-troubadour of House Atreides, skilled with sword and baliset. Serves as Paul's weapons teacher and loyal friend.", + }, + { + "name": "Dr. Yueh", + "description": "Suk doctor conditioned against taking human life, but betrays House Atreides after the Harkonnens torture his wife. Imperial Conditioning broken.", + }, +] diff --git a/src/memray/_vendor/textual/demo/demo_app.py b/src/memray/_vendor/textual/demo/demo_app.py new file mode 100644 index 0000000000..757b4e35b8 --- /dev/null +++ b/src/memray/_vendor/textual/demo/demo_app.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from memray._vendor.textual.app import App +from memray._vendor.textual.binding import Binding +from memray._vendor.textual.demo.game import GameScreen +from memray._vendor.textual.demo.home import HomeScreen +from memray._vendor.textual.demo.projects import ProjectsScreen +from memray._vendor.textual.demo.widgets import WidgetsScreen + + +class DemoApp(App): + """The demo app defines the modes and sets a few bindings.""" + + CSS = """ + .column { + align: center top; + &>*{ max-width: 100; } + } + Screen .-maximized { + margin: 1 2; + max-width: 100%; + &.column { margin: 1 2; padding: 1 2; } + &.column > * { + max-width: 100%; + } + } + """ + + MODES = { + "game": GameScreen, + "home": HomeScreen, + "projects": ProjectsScreen, + "widgets": WidgetsScreen, + } + DEFAULT_MODE = "home" + BINDINGS = [ + Binding( + "h", + "app.switch_mode('home')", + "Home", + tooltip="Show the home screen", + ), + Binding( + "g", + "app.switch_mode('game')", + "Game", + tooltip="Unwind with a Textual game", + ), + Binding( + "p", + "app.switch_mode('projects')", + "Projects", + tooltip="A selection of Textual projects", + ), + Binding( + "w", + "app.switch_mode('widgets')", + "Widgets", + tooltip="Test the builtin widgets", + ), + Binding( + "ctrl+s", + "app.screenshot", + "Screenshot", + tooltip="Save an SVG 'screenshot' of the current screen", + ), + Binding( + "ctrl+a", + "app.maximize", + "Maximize", + tooltip="Maximize the focused widget (if possible)", + ), + ] + + def action_maximize(self) -> None: + if self.screen.is_maximized: + return + if self.screen.focused is None: + self.notify( + "Nothing to be maximized (try pressing [b]tab[/b])", + title="Maximize", + severity="warning", + ) + else: + if self.screen.maximize(self.screen.focused): + self.notify( + "You are now in the maximized view. Press [b]escape[/b] to return.", + title="Maximize", + ) + else: + self.notify( + "This widget may not be maximized.", + title="Maximize", + severity="warning", + ) + + def check_action(self, action: str, parameters: tuple[object, ...]) -> bool | None: + """Disable switching to a mode we are already on.""" + if ( + action == "switch_mode" + and parameters + and self.current_mode == parameters[0] + ): + return None + return True diff --git a/src/memray/_vendor/textual/demo/game.py b/src/memray/_vendor/textual/demo/game.py new file mode 100644 index 0000000000..1ab3b41f74 --- /dev/null +++ b/src/memray/_vendor/textual/demo/game.py @@ -0,0 +1,589 @@ +""" +An implementation of the "Sliding Tile" puzzle. + +Textual isn't a game engine exactly, but it wasn't hard to build this. + +""" + +from __future__ import annotations + +from asyncio import sleep +from collections import defaultdict +from dataclasses import dataclass +from itertools import product +from random import choice +from time import monotonic + +from rich.console import ConsoleRenderable +from rich.syntax import Syntax + +from memray._vendor.textual import containers, events, on, work +from memray._vendor.textual._loop import loop_last +from memray._vendor.textual.app import ComposeResult +from memray._vendor.textual.binding import Binding +from memray._vendor.textual.demo.page import PageScreen +from memray._vendor.textual.geometry import Offset, Size +from memray._vendor.textual.reactive import reactive +from memray._vendor.textual.screen import ModalScreen, Screen +from memray._vendor.textual.timer import Timer +from memray._vendor.textual.widgets import Button, Digits, Footer, Markdown, Select, Static + + +@dataclass +class NewGame: + """A dataclass to report the desired game type.""" + + language: str + code: str + size: tuple[int, int] + + +PYTHON_CODE = '''\ +class SpatialMap(Generic[ValueType]): + """A spatial map allows for data to be associated with rectangular regions + in Euclidean space, and efficiently queried. + + When the SpatialMap is populated, a reference to each value is placed into one or + more buckets associated with a regular grid that covers 2D space. + + The SpatialMap is able to quickly retrieve the values under a given "window" region + by combining the values in the grid squares under the visible area. + """ + + def __init__(self, grid_width: int = 100, grid_height: int = 20) -> None: + """Create a spatial map with the given grid size. + + Args: + grid_width: Width of a grid square. + grid_height: Height of a grid square. + """ + self._grid_size = (grid_width, grid_height) + self.total_region = Region() + self._map: defaultdict[GridCoordinate, list[ValueType]] = defaultdict(list) + self._fixed: list[ValueType] = [] + + def _region_to_grid_coordinates(self, region: Region) -> Iterable[GridCoordinate]: + """Get the grid squares under a region. + + Args: + region: A region. + + Returns: + Iterable of grid coordinates (tuple of 2 values). + """ + # (x1, y1) is the coordinate of the top left cell + # (x2, y2) is the coordinate of the bottom right cell + x1, y1, width, height = region + x2 = x1 + width - 1 + y2 = y1 + height - 1 + grid_width, grid_height = self._grid_size + + return product( + range(x1 // grid_width, x2 // grid_width + 1), + range(y1 // grid_height, y2 // grid_height + 1), + ) +''' + +XML_CODE = """\ + + + + Back to the Future 1985 Robert Zemeckis + Science Fiction PG + + Michael J. Fox Marty McFly + Christopher Lloyd Dr. Emmett Brown + + + + The Breakfast Club 1985 John Hughes + Drama R + + Emilio Estevez Andrew Clark + Molly Ringwald Claire Standish + + + + Ghostbusters 1984 Ivan Reitman + Comedy PG + + Bill Murray Dr. Peter Venkman + Dan Aykroyd Dr. Raymond Stantz + + + + Die Hard 1988 John McTiernan + Action R + + Bruce Willis John McClane + Alan Rickman Hans Gruber + + + + E.T. the Extra-Terrestrial 1982 Steven Spielberg + Science Fiction PG + + Henry Thomas Elliott + Drew Barrymore Gertie + + +""" + +BF_CODE = """\ +[life.b -- John Horton Conway's Game of Life +(c) 2021 Daniel B. Cristofani +] + +>>>->+>+++++>(++++++++++)[[>>>+<<<-]>+++++>+>>+[<<+>>>>>+<<<-]<-]>>>>[ + [>>>+>+<<<<-]+++>>+[<+>>>+>+<<<-]>>[>[[>>>+<<<-]<]<<++>+>>>>>>-]<- +]+++>+>[[-]<+<[>+++++++++++++++++<-]<+]>>[ + [+++++++++.-------->>>]+[-<<<]>>>[>>,----------[>]<]<<[ + <<<[ + >--[<->>+>-<<-]<[[>>>]+>-[+>>+>-]+[<<<]<-]>++>[<+>-] + >[[>>>]+[<<<]>>>-]+[->>>]<-[++>]>[------<]>+++[<<<]> + ]< + ]>[ + -[+>>+>-]+>>+>>>+>[<<<]>->+>[ + >[->+>+++>>++[>>>]+++<<<++<<<++[>>>]>>>]<<<[>[>>>]+>>>] + <<<<<<<[<<++<+[-<<<+]->++>>>++>>>++<<<<]<<<+[-<<<+]+>->>->> + ]<<+<<+<<<+<<-[+<+<<-]+<+[ + ->+>[-<-<<[<<<]>[>>[>>>]<<+<[<<<]>-]] + <[<[<[<<<]>+>>[>>>]<<-]<[<<<]]>>>->>>[>>>]+> + ]>+[-<<[-]<]-[ + [>>>]<[<<[<<<]>>>>>+>[>>>]<-]>>>[>[>>>]<<<<+>[<<<]>>-]> + ]<<<<<<[---<-----[-[-[<->>+++<+++++++[-]]]]<+<+]> + ]>> +] + +[This program simulates the Game of Life cellular automaton. + +Type e.g. "be" to toggle the fifth cell in the second row, "q" to quit, +or a bare linefeed to advance one generation. + +Grid wraps toroidally. Board size in parentheses in first line (2-166 work). + +This program is licensed under a Creative Commons Attribution-ShareAlike 4.0 +International License (http://creativecommons.org/licenses/by-sa/4.0/).] +""" + + +LEVELS = {"Python": PYTHON_CODE, "XML": XML_CODE, "BF": BF_CODE} + + +class Tile(containers.Vertical): + """An individual tile in the puzzle. + + A Tile is a container with a static inside it. + The static contains the code (as a Rich Syntax object), scrolled so the + relevant portion is visible. + """ + + DEFAULT_CSS = """ + Tile { + position: absolute; + Static { + width: auto; + height: auto; + &:hover { tint: $primary 30%; } + } + &#blank { visibility: hidden; } + } + """ + + position: reactive[Offset] = reactive(Offset) + + def __init__( + self, + renderable: ConsoleRenderable, + tile: int | None, + size: Size, + position: Offset, + ) -> None: + self.renderable = renderable + self.tile = tile + self.tile_size = size + self.start_position = position + + super().__init__(id="blank" if tile is None else f"tile{self.tile}") + self.set_reactive(Tile.position, position) + + def compose(self) -> ComposeResult: + static = Static( + self.renderable, + classes="tile", + name="blank" if self.tile is None else str(self.tile), + ) + assert self.parent is not None + static.styles.width = self.parent.styles.width + static.styles.height = self.parent.styles.height + yield static + + def on_mount(self) -> None: + if self.tile is not None: + width, height = self.tile_size + self.styles.width = width + self.styles.height = height + column, row = self.position + self.set_scroll(column * width, row * height) + self.offset = self.position * self.tile_size + + def watch_position(self, position: Offset) -> None: + """The 'position' is in tile coordinate. + When it changes we animate it to the cell coordinates.""" + self.animate("offset", position * self.tile_size, duration=0.2) + + +class GameDialog(containers.VerticalGroup): + """A dialog to ask the user for the initial game parameters.""" + + DEFAULT_CSS = """ + GameDialog { + background: $boost; + border: thick $primary-muted; + padding: 0 2; + width: 50; + #values { + width: 1fr; + Select { margin: 1 0;} + } + Button { + margin: 0 1 1 1; + width: 1fr; + } + } + """ + + def compose(self) -> ComposeResult: + with containers.VerticalGroup(id="values"): + yield Select.from_values( + LEVELS.keys(), + prompt="Language", + value="Python", + id="language", + allow_blank=False, + ) + yield Select( + [ + ("Easy (3x3)", (3, 3)), + ("Medium (4x4)", (4, 4)), + ("Hard (5x5)", (5, 5)), + ], + prompt="Level", + value=(4, 4), + id="level", + allow_blank=False, + ) + yield Button("Start", variant="primary") + + @on(Button.Pressed) + def on_button_pressed(self) -> None: + language = self.query_one("#language", Select).selection + level = self.query_one("#level", Select).selection + assert language is not None and level is not None + self.screen.dismiss(NewGame(language, LEVELS[language], level)) + + +class GameDialogScreen(ModalScreen): + """Modal screen containing the dialog.""" + + CSS = """ + GameDialogScreen { + align: center middle; + } + """ + + BINDINGS = [("escape", "dismiss")] + + def compose(self) -> ComposeResult: + yield GameDialog() + + +class Game(containers.Vertical, can_focus=True): + """Widget for the game board.""" + + ALLOW_MAXIMIZE = False + DEFAULT_CSS = """ + Game { + visibility: hidden; + align: center middle; + hatch: right $panel; + border: heavy transparent; + &:focus { + border: heavy $success; + } + #grid { + border: heavy $primary; + hatch: right $panel; + box-sizing: content-box; + } + Digits { + width: auto; + color: $foreground; + } + } + """ + + BINDINGS = [ + Binding("up", "move('up')", "up", priority=True), + Binding("down", "move('down')", "down", priority=True), + Binding("left", "move('left')", "left", priority=True), + Binding("right", "move('right')", "right", priority=True), + ] + + state = reactive("waiting") + play_start_time: reactive[float] = reactive(monotonic) + play_time = reactive(0.0, init=False) + code = reactive("") + dimensions = reactive(Size(3, 3)) + code = reactive("") + language = reactive("") + + def __init__( + self, + code: str, + language: str, + dimensions: tuple[int, int], + tile_size: tuple[int, int], + ) -> None: + self.set_reactive(Game.code, code) + self.set_reactive(Game.language, language) + self.locations: defaultdict[Offset, int | None] = defaultdict(None) + super().__init__() + self.dimensions = Size(*dimensions) + self.tile_size = Size(*tile_size) + self.play_timer: Timer | None = None + + def check_win(self) -> bool: + return all(tile.start_position == tile.position for tile in self.query(Tile)) + + def watch_dimensions(self, dimensions: Size) -> None: + self.locations.clear() + tile_width, tile_height = dimensions + for last, tile_no in loop_last(range(0, tile_width * tile_height)): + position = Offset(*divmod(tile_no, tile_width)) + self.locations[position] = None if last else tile_no + + def compose(self) -> ComposeResult: + syntax = Syntax( + self.code, + self.language.lower(), + indent_guides=True, + line_numbers=True, + theme="material", + ) + tile_width, tile_height = self.dimensions + self.state = "waiting" + yield Digits("") + with containers.HorizontalGroup(id="grid") as grid: + grid.styles.width = tile_width * self.tile_size[0] + grid.styles.height = tile_height * self.tile_size[1] + for row, column in product(range(tile_width), range(tile_height)): + position = Offset(row, column) + tile_no = self.locations[position] + yield Tile(syntax, tile_no, self.tile_size, position) + if self.language: + self.call_after_refresh(self.shuffle) + + def update_clock(self) -> None: + if self.state == "playing": + elapsed = monotonic() - self.play_start_time + self.play_time = elapsed + + def watch_play_time(self, play_time: float) -> None: + minutes, seconds = divmod(play_time, 60) + hours, minutes = divmod(minutes, 60) + self.query_one(Digits).update(f"{hours:02,.0f}:{minutes:02.0f}:{seconds:04.1f}") + + def watch_state(self, old_state: str, new_state: str) -> None: + if self.play_timer is not None: + self.play_timer.stop() + + if new_state == "playing": + self.play_start_time = monotonic() + self.play_timer = self.set_interval(1 / 10, self.update_clock) + + def get_tile(self, tile: int | None) -> Tile: + """Get a tile (int) or the blank (None).""" + return self.query_one("#blank" if tile is None else f"#tile{tile}", Tile) + + def get_tile_at(self, position: Offset) -> Tile: + """Get a tile at the given position, or raise an IndexError.""" + if position not in self.locations: + raise IndexError("No tile") + return self.get_tile(self.locations[position]) + + def move_tile(self, tile_no: int | None) -> None: + """Move a tile to the blank. + Note: this doesn't do any validation of legal moves. + """ + tile = self.get_tile(tile_no) + blank = self.get_tile(None) + blank_position = blank.position + + self.locations[tile.position] = None + blank.position = tile.position + + self.locations[blank_position] = tile_no + tile.position = blank_position + + if self.state == "playing" and self.check_win(): + self.state = "won" + self.notify("You won!", title="Sliding Tile Puzzle") + + def can_move(self, tile: int) -> bool: + """Check if a tile may move.""" + blank_position = self.get_tile(None).position + tile_position = self.get_tile(tile).position + return blank_position in ( + tile_position + (1, 0), + tile_position - (1, 0), + tile_position + (0, 1), + tile_position - (0, 1), + ) + + def action_move(self, direction: str) -> None: + if self.state != "playing": + self.app.bell() + return + blank = self.get_tile(None).position + if direction == "up": + position = blank + (0, +1) + elif direction == "down": + position = blank + (0, -1) + elif direction == "left": + position = blank + (+1, 0) + elif direction == "right": + position = blank + (-1, 0) + try: + tile = self.get_tile_at(position) + except IndexError: + return + self.move_tile(tile.tile) + + def get_legal_moves(self) -> set[Offset]: + """Get the positions of all tiles that can move.""" + blank = self.get_tile(None).position + moves: list[Offset] = [] + + DIRECTIONS = [(-1, 0), (+1, -0), (0, -1), (0, +1)] + moves = [ + blank + direction + for direction in DIRECTIONS + if (blank + direction) in self.locations + ] + return {self.get_tile_at(position).position for position in moves} + + @work(exclusive=True) + async def shuffle(self, shuffles: int = 150) -> None: + """A worker to do the shuffling.""" + self.visible = True + if self.play_timer is not None: + self.play_timer.stop() + self.query_one("#grid").border_title = "[reverse bold] SHUFFLING - Please Wait " + self.state = "shuffling" + previous_move: Offset = Offset(-1, -1) + for _ in range(shuffles): + legal_moves = self.get_legal_moves() + legal_moves.discard(previous_move) + previous_move = self.get_tile(None).position + move_position = choice(list(legal_moves)) + move_tile = self.get_tile_at(move_position) + self.move_tile(move_tile.tile) + await sleep(0.05) + self.query_one("#grid").border_title = "" + self.state = "playing" + + @on(events.Click, ".tile") + def on_tile_clicked(self, event: events.Click) -> None: + assert event.widget is not None + tile = int(event.widget.name or 0) + if self.state != "playing" or not self.can_move(tile): + self.app.bell() + return + self.move_tile(tile) + + +class GameInstructions(containers.VerticalGroup): + DEFAULT_CSS = """\ + GameInstructions { + layer: instructions; + width: 60; + background: $panel; + border: thick $primary-darken-2; + Markdown { + background: $panel; + } + + } + +""" + INSTRUCTIONS = """\ +# Instructions + +This is an implementation of the *sliding tile puzzle*. + +The board consists of a number of tiles and a blank space. +After shuffling, the goal is to restore the original "image" by moving a square either horizontally or vertically into the blank space. + +This version is like the physical game, but rather than an image, you need to restore code. + """ + + def compose(self) -> ComposeResult: + yield Markdown(self.INSTRUCTIONS) + with containers.Center(): + yield Button("New Game", action="screen.new_game", variant="success") + + +class GameScreen(PageScreen): + """The screen containing the game.""" + + DEFAULT_CSS = """ + GameScreen{ + #container { + align: center middle; + layers: instructions game; + } + } + """ + + BINDINGS = [("n", "new_game", "New Game")] + + def compose(self) -> ComposeResult: + with containers.Vertical(id="container"): + yield GameInstructions() + yield Game("\n" * 100, "", dimensions=(4, 4), tile_size=(16, 8)) + yield Footer() + + def action_shuffle(self) -> None: + self.query_one(Game).shuffle() + + def action_new_game(self) -> None: + self.app.push_screen(GameDialogScreen(), callback=self.new_game) + + async def new_game(self, new_game: NewGame | None) -> None: + if new_game is None: + return + self.query_one(GameInstructions).display = False + game = self.query_one(Game) + game.state = "waiting" + game.code = new_game.code + game.language = new_game.language + game.dimensions = Size(*new_game.size) + await game.recompose() + game.focus() + + def check_action(self, action: str, parameters: tuple[object, ...]) -> bool | None: + if action == "shuffle" and self.query_one(Game).state == "waiting": + return None + return True + + +if __name__ == "__main__": + from memray._vendor.textual.app import App + + class GameApp(App): + def get_default_screen(self) -> Screen: + return GameScreen() + + app = GameApp() + app.run() diff --git a/src/memray/_vendor/textual/demo/home.py b/src/memray/_vendor/textual/demo/home.py new file mode 100644 index 0000000000..3a2c98c91c --- /dev/null +++ b/src/memray/_vendor/textual/demo/home.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +import asyncio +from importlib.metadata import version + +try: + import httpx + + HTTPX_AVAILABLE = True +except ImportError: + HTTPX_AVAILABLE = False + +from memray._vendor.textual import work +from memray._vendor.textual.app import ComposeResult +from memray._vendor.textual.containers import Horizontal, Vertical, VerticalScroll +from memray._vendor.textual.demo.page import PageScreen +from memray._vendor.textual.reactive import reactive +from memray._vendor.textual.widgets import Collapsible, Digits, Footer, Label, Markdown + +WHAT_IS_TEXTUAL_MD = """\ +# What is Textual? + +Snappy, keyboard-centric, applications that run in the terminal and [the web](https://github.com/Textualize/textual-web). + +🐍 All you need is Python! + +""" + +WELCOME_MD = """\ +## Welcome keyboard warriors! + +This is a Textual app. Here's what you need to know: + +* **enter** `toggle this collapsible widget` +* **tab** `focus the next widget` +* **shift+tab** `focus the previous widget` +* **ctrl+p** `summon the command palette` + + +👇 Also see the footer below. + +`Or… click away with the mouse (no judgement).` + +""" + +ABOUT_MD = """\ +The retro look is not just an aesthetic choice! Textual apps have some unique properties that make them preferable for many tasks. + +## Textual interfaces are *snappy* +Even the most modern of web apps can leave the user waiting hundreds of milliseconds or more for a response. +Given their low graphical requirements, Textual interfaces can be far more responsive — no waiting required. + +## Reward repeated use +Use the mouse to explore, but Textual apps are keyboard-centric and reward repeated use. +An experienced user can operate a Textual app far faster than their web / GUI counterparts. + +## Command palette +A builtin command palette with fuzzy searching puts powerful commands at your fingertips. + +**Try it:** Press **ctrl+p** now. + +""" + +API_MD = """\ +A modern Python API from the developer of [Rich](https://github.com/Textualize/rich). + +```python +# Start building! +from memray._vendor.textual.app import App, ComposeResult +from memray._vendor.textual.widgets import Label + +class MyApp(App): + def compose(self) -> ComposeResult: + yield Label("Hello, World!") + +MyApp().run() +``` + +* Intuitive, batteries-included, API. +* Well documented: See the [tutorial](https://textual.textualize.io/tutorial/), [guide](https://textual.textualize.io/guide/app/), and [reference](https://textual.textualize.io/reference/). +* Fully typed, with modern type annotations. +* Accessible to Python developers of all skill levels. + +**Hint:** press **C** to view the code for this page. + +## Built on Rich + +With over 3.1 *billion* downloads, Rich is the most popular terminal library out there. +Textual builds on Rich to add interactivity, and is fully-compatible with Rich renderables. + +## Re-usable widgets + +Textual's widgets are self-contained and re-usable across projects. +Virtually all aspects of a widget's look and feel can be customized to your requirements. + +## Builtin widgets + +A large [library of builtin widgets](https://textual.textualize.io/widget_gallery/), and a growing ecosystem of third party widgets on PyPI +(this content is generated by the builtin [Markdown](https://textual.textualize.io/widget_gallery/#markdown) widget). + +## Reactive variables + +[Reactivity](https://textual.textualize.io/guide/reactivity/) using Python idioms, keeps your logic separate from display code. + +## Async support + +Built on asyncio, you can easily integrate async libraries while keeping your UI responsive. + +## Concurrency + +Textual's [Workers](https://textual.textualize.io/guide/workers/) provide a far-less error prone interface to +concurrency: both async and threads. + +## Testing + +With a comprehensive [testing framework](https://textual.textualize.io/guide/testing/), you can release reliable software, that can be maintained indefinitely. + +## Docs + +Textual has [amazing docs](https://textual.textualize.io/)! + +""" + +DEPLOY_MD = """\ +Textual apps have extremely low system requirements, and will run on virtually any OS and hardware; locally or remotely via SSH. + +There are a number of ways to deploy and share Textual apps. + +## As a Python library + +Textual apps may be pip installed, via tools such as `pipx` or `uvx`, and other package managers. + +## As a web application + +It takes two lines of code to [serve your Textual app](https://github.com/Textualize/textual-serve) as a web application. + +## Managed web application + +With [Textual web](https://github.com/Textualize/textual-web) you can serve multiple Textual apps on the web, +with zero configuration. Even behind a firewall. +""" + + +class StarCount(Vertical): + """Widget to get and display GitHub star count.""" + + DEFAULT_CSS = """ + StarCount { + dock: top; + height: 6; + border-bottom: hkey $background; + border-top: hkey $background; + layout: horizontal; + background: $boost; + padding: 0 1; + color: $text-warning; + #stars { align: center top; } + #forks { align: right top; } + Label { text-style: bold; color: $foreground; } + LoadingIndicator { background: transparent !important; } + Digits { width: auto; margin-right: 1; } + Label { margin-right: 1; } + align: center top; + &>Horizontal { max-width: 100;} + } + """ + stars = reactive(34455, recompose=True) + forks = reactive(1108, recompose=True) + + @work + async def get_stars(self): + """Worker to get stars from GitHub API.""" + if not HTTPX_AVAILABLE: + self.notify( + "Install httpx to update stars from the GitHub API.\n\n$ [b]pip install httpx[/b]", + title="GitHub Stars", + ) + return + self.loading = True + try: + await asyncio.sleep(1) # Time to admire the loading indicator + async with httpx.AsyncClient() as client: + repository_json = ( + await client.get("https://api.github.com/repos/textualize/textual") + ).json() + self.stars = repository_json["stargazers_count"] + self.forks = repository_json["forks"] + except Exception: + self.notify( + "Unable to update star count (maybe rate-limited)", + title="GitHub stars", + severity="error", + ) + self.loading = False + + def compose(self) -> ComposeResult: + with Horizontal(): + with Vertical(id="version"): + yield Label("Version") + yield Digits(version("textual")) + with Vertical(id="stars"): + yield Label("GitHub ★") + stars = f"{self.stars / 1000:.1f}K" + yield Digits(stars).with_tooltip(f"{self.stars} GitHub stars") + with Vertical(id="forks"): + yield Label("Forks") + yield Digits(str(self.forks)).with_tooltip(f"{self.forks} Forks") + + def on_mount(self) -> None: + self.tooltip = "Click to refresh" + self.get_stars() + + def on_click(self) -> None: + self.get_stars() + + +class Content(VerticalScroll, can_focus=False): + """Non focusable vertical scroll.""" + + +class HomeScreen(PageScreen): + DEFAULT_CSS = """ + HomeScreen { + + Content { + align-horizontal: center; + & > * { + max-width: 100; + } + margin: 0 1; + overflow-y: auto; + height: 1fr; + scrollbar-gutter: stable; + MarkdownFence { + height: auto; + max-height: initial; + } + Collapsible { + padding-right: 0; + &.-collapsed { padding-bottom: 1; } + } + Markdown { + margin-right: 1; + padding-right: 1; + background: transparent; + } + } + } + """ + + def compose(self) -> ComposeResult: + yield StarCount() + with Content(): + yield Markdown(WHAT_IS_TEXTUAL_MD) + with Collapsible(title="Welcome", collapsed=False): + yield Markdown(WELCOME_MD) + with Collapsible(title="Textual Interfaces"): + yield Markdown(ABOUT_MD) + with Collapsible(title="Textual API"): + yield Markdown(API_MD) + with Collapsible(title="Deploying Textual apps"): + yield Markdown(DEPLOY_MD) + yield Footer() diff --git a/src/memray/_vendor/textual/demo/page.py b/src/memray/_vendor/textual/demo/page.py new file mode 100644 index 0000000000..48a2d1f9f3 --- /dev/null +++ b/src/memray/_vendor/textual/demo/page.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import inspect + +from rich.syntax import Syntax + +from memray._vendor.textual import work +from memray._vendor.textual.app import ComposeResult +from memray._vendor.textual.binding import Binding +from memray._vendor.textual.containers import ScrollableContainer +from memray._vendor.textual.screen import ModalScreen, Screen +from memray._vendor.textual.widgets import Static + + +class CodeScreen(ModalScreen): + DEFAULT_CSS = """ + CodeScreen { + #code { + border: heavy $accent; + margin: 2 4; + scrollbar-gutter: stable; + Static { + width: auto; + } + } + } + """ + BINDINGS = [("escape", "dismiss", "Dismiss code")] + + def __init__(self, title: str, code: str) -> None: + super().__init__() + self.code = code + self.title = title + + def compose(self) -> ComposeResult: + with ScrollableContainer(id="code"): + yield Static( + Syntax( + self.code, lexer="python", indent_guides=True, line_numbers=True + ), + expand=True, + ) + + def on_mount(self): + code_widget = self.query_one("#code") + code_widget.border_title = self.title + code_widget.border_subtitle = "Escape to close" + + +class PageScreen(Screen): + DEFAULT_CSS = """ + PageScreen { + width: 100%; + height: 1fr; + overflow-y: auto; + } + """ + BINDINGS = [ + Binding( + "c", + "show_code", + "Code", + tooltip="Show the code used to generate this screen", + ) + ] + + @work(thread=True) + def get_code(self, source_file: str) -> str | None: + """Read code from disk, or return `None` on error.""" + try: + with open(source_file, "rt", encoding="utf-8") as file_: + return file_.read() + except Exception: + return None + + async def action_show_code(self): + source_file = inspect.getsourcefile(self.__class__) + if source_file is None: + self.notify( + "Could not get the code for this page", + title="Show code", + severity="error", + ) + return + + code = await self.get_code(source_file).wait() + if code is None: + self.notify( + "Could not get the code for this page", + title="Show code", + severity="error", + ) + else: + self.app.push_screen(CodeScreen("Code for this page", code)) diff --git a/src/memray/_vendor/textual/demo/projects.py b/src/memray/_vendor/textual/demo/projects.py new file mode 100644 index 0000000000..aa47e6bf24 --- /dev/null +++ b/src/memray/_vendor/textual/demo/projects.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from memray._vendor.textual import events, on +from memray._vendor.textual.app import ComposeResult +from memray._vendor.textual.binding import Binding +from memray._vendor.textual.containers import Center, Horizontal, ItemGrid, Vertical, VerticalScroll +from memray._vendor.textual.demo.page import PageScreen +from memray._vendor.textual.widgets import Footer, Label, Link, Markdown, Static +from memray._vendor.textual.demo._project_stars import STARS +from memray._vendor.textual.demo._project_data import PROJECTS, ProjectInfo + +PROJECTS_MD = """\ +# Projects + +There are many amazing Open Source Textual apps available for download. +And many more still in development. + +See below for a small selection! +""" + + +class Project(Vertical, can_focus=True, can_focus_children=False): + """Display project information and open repo links.""" + + ALLOW_MAXIMIZE = True + DEFAULT_CSS = """ + Project { + width: 1fr; + height: auto; + padding: 0 1; + border: tall transparent; + box-sizing: border-box; + &:focus { + border: tall $text-primary; + background: $primary 20%; + &.link { + color: red !important; + } + } + #title { text-style: bold; width: 1fr; } + #author { text-style: italic; } + .stars { + color: $text-accent; + text-align: right; + text-style: bold; + width: auto; + } + .header { height: 1; } + .link { + color: $text-accent; + text-style: underline; + } + .description { color: $text-muted; } + &.-hover { opacity: 1; } + } + """ + + BINDINGS = [ + Binding( + "enter", + "open_repository", + "open repo", + tooltip="Open the GitHub repository in your browser", + ) + ] + + def __init__(self, project_info: ProjectInfo) -> None: + self.project_info = project_info + super().__init__() + + def compose(self) -> ComposeResult: + info = self.project_info + with Horizontal(classes="header"): + yield Label(info.title, id="title") + yield Label(f"★ {STARS[info.title]}", classes="stars") + yield Label(info.author, id="author") + yield Link(info.url, tooltip="Click to open project repository") + yield Static(info.description, classes="description") + + @on(events.Enter) + @on(events.Leave) + def on_enter(self, event: events.Enter): + event.stop() + self.set_class(self.is_mouse_over, "-hover") + + def action_open_repository(self) -> None: + self.app.open_url(self.project_info.url) + + +class ProjectsScreen(PageScreen): + AUTO_FOCUS = None + CSS = """ + ProjectsScreen { + align-horizontal: center; + ItemGrid { + margin: 2 4; + padding: 1 2; + background: $boost; + width: 1fr; + height: auto; + grid-gutter: 1 1; + grid-rows: auto; + keyline:thin $foreground 30%; + } + Markdown { margin: 0; padding: 0 2; max-width: 100; background: transparent; } + } + """ + + def compose(self) -> ComposeResult: + with VerticalScroll() as container: + container.can_focus = False + with Center(): + yield Markdown(PROJECTS_MD) + with ItemGrid(min_column_width=40): + for project in PROJECTS: + yield Project(project) + yield Footer() + + +if __name__ == "__main__": + from memray._vendor.textual.app import App + + class GameApp(App): + def get_default_screen(self) -> Screen: + return ProjectsScreen() + + app = GameApp() + app.run() diff --git a/src/memray/_vendor/textual/demo/widgets.py b/src/memray/_vendor/textual/demo/widgets.py new file mode 100644 index 0000000000..6fe90cc902 --- /dev/null +++ b/src/memray/_vendor/textual/demo/widgets.py @@ -0,0 +1,820 @@ +from __future__ import annotations + +import csv +import io +from math import sin + +from rich.syntax import Syntax +from rich.table import Table +from rich.traceback import Traceback + +from memray._vendor.textual import containers, events, lazy, on +from memray._vendor.textual.app import ComposeResult +from memray._vendor.textual.binding import Binding +from memray._vendor.textual.demo.data import COUNTRIES, DUNE_BIOS, MOVIES, MOVIES_TREE +from memray._vendor.textual.demo.page import PageScreen +from memray._vendor.textual.reactive import reactive, var +from memray._vendor.textual.suggester import SuggestFromList +from memray._vendor.textual.theme import BUILTIN_THEMES +from memray._vendor.textual.widgets import ( + Button, + Checkbox, + DataTable, + Digits, + Footer, + Input, + Label, + ListItem, + ListView, + Log, + Markdown, + MaskedInput, + OptionList, + RadioButton, + RadioSet, + RichLog, + Select, + Sparkline, + Static, + Switch, + TabbedContent, + TextArea, + Tree, +) + +WIDGETS_MD = """\ +# Widgets + +The Textual library includes a large number of builtin widgets. + +The following list is *not* exhaustive… + +""" + + +class Buttons(containers.VerticalGroup): + """Buttons demo.""" + + ALLOW_MAXIMIZE = True + DEFAULT_CLASSES = "column" + DEFAULT_CSS = """ + Buttons { + ItemGrid { margin-bottom: 1;} + Button { width: 1fr; } + } + """ + + BUTTONS_MD = """\ +## Buttons + +A simple button, with a number of semantic styles. +May be rendered unclickable by setting `disabled=True`. + +Press `return` to active a button when focused (or click it). + + """ + + def compose(self) -> ComposeResult: + yield Markdown(self.BUTTONS_MD) + with containers.ItemGrid(min_column_width=20, regular=True): + yield Button( + "Default", + tooltip="The default button style", + action="notify('you pressed Default')", + ) + yield Button( + "Primary", + variant="primary", + tooltip="The primary button style - carry out the core action of the dialog", + action="notify('you pressed Primary')", + ) + yield Button( + "Warning", + variant="warning", + tooltip="The warning button style - warn the user that this isn't a typical button", + action="notify('you pressed Warning')", + ) + yield Button( + "Error", + variant="error", + tooltip="The error button style - clicking is a destructive action", + action="notify('you pressed Error')", + ) + with containers.ItemGrid(min_column_width=20, regular=True): + yield Button("Default", disabled=True) + yield Button("Primary", variant="primary", disabled=True) + yield Button("Warning", variant="warning", disabled=True) + yield Button("Error", variant="error", disabled=True) + + +class Checkboxes(containers.VerticalGroup): + """Demonstrates Checkboxes.""" + + DEFAULT_CLASSES = "column" + DEFAULT_CSS = """ + Checkboxes { + height: auto; + Checkbox, RadioButton { width: 1fr; } + &>HorizontalGroup > * { width: 1fr; } + } + + """ + + CHECKBOXES_MD = """\ +## Checkboxes, Radio buttons, and Radio sets + +Checkboxes to toggle booleans. +Radio buttons for exclusive booleans. + +Hit `return` to toggle an checkbox / radio button, when focused. + + """ + RADIOSET_MD = """\ +### Radio Sets + +A *radio set* is a list of mutually exclusive options. +Use the `up` and `down` keys to navigate the list. +Press `return` to toggle a radio button. + +""" + + def compose(self) -> ComposeResult: + yield Markdown(self.CHECKBOXES_MD) + yield Checkbox("A Checkbox") + yield RadioButton("A Radio Button") + yield Markdown(self.RADIOSET_MD) + yield RadioSet( + "Amanda", + "Connor MacLeod", + "Duncan MacLeod", + "Heather MacLeod", + "Joe Dawson", + "Kurgan, [bold italic red]The[/]", + "Methos", + "Rachel Ellenstein", + "Ramírez", + ) + + +class Datatables(containers.VerticalGroup): + """Demonstrates DataTables.""" + + DEFAULT_CLASSES = "column" + DATATABLES_MD = """\ +## Datatables + +A fully-featured DataTable, with cell, row, and columns cursors. +Cells may be individually styled, and may include Rich renderables. + +**Tip:** Focus the table and press `ctrl+a` + +""" + DEFAULT_CSS = """ + DataTable { + height: 16 !important; + &.-maximized { + height: auto !important; + } + } + + """ + + def compose(self) -> ComposeResult: + yield Markdown(self.DATATABLES_MD) + with containers.Center(): + yield DataTable(fixed_columns=1) + + def on_mount(self) -> None: + ROWS = list(csv.reader(io.StringIO(MOVIES))) + table = self.query_one(DataTable) + table.add_columns(*ROWS[0]) + table.add_rows(ROWS[1:]) + + +class Inputs(containers.VerticalGroup): + """Demonstrates Inputs.""" + + ALLOW_MAXIMIZE = True + DEFAULT_CLASSES = "column" + INPUTS_MD = """\ +## Inputs and MaskedInputs + +Text input fields, with placeholder text, validation, and auto-complete. +Build for intuitive and user-friendly forms. + +""" + DEFAULT_CSS = """ + Inputs { + Grid { + background: $boost; + padding: 1 2; + height: auto; + grid-size: 2; + grid-gutter: 1; + grid-columns: auto 1fr; + border: tall blank; + &:focus-within { + border: tall $accent; + } + Label { + width: 100%; + padding: 1; + text-align: right; + } + } + } + """ + + def compose(self) -> ComposeResult: + yield Markdown(self.INPUTS_MD) + with containers.Grid(): + yield Label("Free") + yield Input(placeholder="Type anything here") + yield Label("Number") + yield Input( + type="number", placeholder="Type a number here", valid_empty=True + ) + yield Label("Credit card") + yield MaskedInput( + "9999-9999-9999-9999;0", + tooltip="Obviously not your real credit card!", + valid_empty=True, + ) + yield Label("Country") + yield Input( + suggester=SuggestFromList(COUNTRIES, case_sensitive=False), + placeholder="Country", + ) + + +class ListViews(containers.VerticalGroup): + """Demonstrates List Views and Option Lists.""" + + ALLOW_MAXIMIZE = True + DEFAULT_CLASSES = "column" + LISTS_MD = """\ +## List Views and Option Lists + +A List View turns any widget into a user-navigable and selectable list. +An Option List for a field to present a list of strings to select from. + + """ + + DEFAULT_CSS = """ + ListViews { + ListView { + width: 1fr; + height: auto; + margin: 0 2; + background: $panel; + } + OptionList { max-height: 15; } + Digits { padding: 1 2; width: 1fr; } + } + + """ + + def compose(self) -> ComposeResult: + yield Markdown(self.LISTS_MD) + with containers.HorizontalGroup(): + yield ListView( + ListItem(Digits("$50.00")), + ListItem(Digits("£100.00")), + ListItem(Digits("€500.00")), + ) + yield OptionList(*COUNTRIES) + + +class Logs(containers.VerticalGroup): + """Demonstrates Logs.""" + + DEFAULT_CLASSES = "column" + LOGS_MD = """\ +## Logs and Rich Logs + +A Log widget to efficiently display a scrolling view of text, with optional highlighting. +And a RichLog widget to display Rich renderables. + +""" + DEFAULT_CSS = """ + Logs { + Log, RichLog { + width: 1fr; + height: 20; + padding: 1; + overflow-x: auto; + border: wide transparent; + &:focus { + border: wide $border; + } + } + TabPane { padding: 0; } + TabbedContent.-maximized { + height: 1fr; + Log, RichLog { height: 1fr; } + } + } + """ + + TEXT = """I must not fear. +Fear is the mind-killer. +Fear is the little-death that brings total obliteration. +I will face my fear. +I will permit it to pass over me and through me. +And when it has gone past, I will turn the inner eye to see its path. +Where the fear has gone there will be nothing. Only I will remain.""".splitlines() + + CSV = """lane,swimmer,country,time +4,Joseph Schooling,Singapore,50.39 +2,Michael Phelps,United States,51.14 +5,Chad le Clos,South Africa,51.14 +6,László Cseh,Hungary,51.14 +3,Li Zhuhao,China,51.26 +8,Mehdy Metella,France,51.58 +7,Tom Shields,United States,51.73 +1,Aleksandr Sadovnikov,Russia,51.84""" + CSV_ROWS = list(csv.reader(io.StringIO(CSV))) + + CODE = '''\ +def loop_first_last(values: Iterable[T]) -> Iterable[tuple[bool, bool, T]]: + """Iterate and generate a tuple with a flag for first and last value.""" + iter_values = iter(values) + try: + previous_value = next(iter_values) + except StopIteration: + return + first = True + for value in iter_values: + yield first, False, previous_value + first = False + previous_value = value + yield first, True, previous_value\ +''' + log_count = var(0) + rich_log_count = var(0) + + def compose(self) -> ComposeResult: + yield Markdown(self.LOGS_MD) + with TabbedContent("Log", "RichLog"): + yield Log(max_lines=10_000, highlight=True) + yield RichLog(max_lines=10_000) + + def on_mount(self) -> None: + log = self.query_one(Log) + rich_log = self.query_one(RichLog) + log.write("I am a Log Widget") + rich_log.write("I am a Rich Log Widget") + self.set_interval(0.25, self.update_log) + self.set_interval(1, self.update_rich_log) + + def update_log(self) -> None: + """Update the Log with new content.""" + log = self.query_one(Log) + if self.is_scrolling: + return + if not self.app.screen.can_view_entire(log) and not log.is_in_maximized_view: + return + self.log_count += 1 + line_no = self.log_count % len(self.TEXT) + line = self.TEXT[self.log_count % len(self.TEXT)] + log.write_line(f"fear[{line_no}] = {line!r}") + + def update_rich_log(self) -> None: + """Update the Rich Log with content.""" + rich_log = self.query_one(RichLog) + if self.is_scrolling: + return + if ( + not self.app.screen.can_view_entire(rich_log) + and not rich_log.is_in_maximized_view + ): + return + self.rich_log_count += 1 + log_option = self.rich_log_count % 3 + if log_option == 0: + rich_log.write("Syntax highlighted code", animate=True) + rich_log.write(Syntax(self.CODE, lexer="python"), animate=True) + elif log_option == 1: + rich_log.write("A Rich Table", animate=True) + table = Table(*self.CSV_ROWS[0]) + for row in self.CSV_ROWS[1:]: + table.add_row(*row) + rich_log.write(table, animate=True) + elif log_option == 2: + rich_log.write("A Rich Traceback", animate=True) + try: + 1 / 0 + except Exception: + traceback = Traceback() + rich_log.write(traceback, animate=True) + + +class Markdowns(containers.VerticalGroup): + DEFAULT_CLASSES = "column" + DEFAULT_CSS = """ + Markdowns { + #container { + background: $boost; + border: tall transparent; + height: 16; + padding: 0 1; + &:focus { border: tall $border; } + &.-maximized { height: 1fr; } + } + #movies { + padding: 0 1; + MarkdownBlock { padding: 0 1 0 0; } + } + } + """ + MD_MD = """\ +## Markdown + +Display Markdown in your apps with the Markdown widget. +Most of the text on this page is Markdown. + +Here's an AI generated Markdown document: + +""" + MOVIES_MD = """\ +# The Golden Age of Action Cinema: The 1980s + +The 1980s marked a transformative era in action cinema, defined by **excessive machismo**, explosive practical effects, and unforgettable one-liners. This decade gave birth to many of Hollywood's most enduring action franchises, from _Die Hard_ to _Rambo_, setting templates that filmmakers still reference today. + +## Technical Innovation + +Technologically, the 80s represented a sweet spot between practical effects and early CGI. Filmmakers relied heavily on: + +* Practical stunts +* Pyrotechnics +* Hand-built models + +These elements lent the films a tangible quality that many argue remains superior to modern digital effects. + +## The Action Hero Archetype + +The quintessential action hero emerged during this period, with key characteristics: + +1. Impressive physique +2. Military background +3. Anti-authority attitude +4. Memorable catchphrases + +> "I'll be back" - The Terminator (1984) + +Heroes like Arnold Schwarzenegger and Sylvester Stallone became global icons. However, the decade also saw more nuanced characters emerge, like Bruce Willis's everyman John McClane in *Die Hard*, and powerful female protagonists like Sigourney Weaver's Ellen Ripley in *Aliens*. + +### Political Influence + +Cold War politics heavily influenced these films' narratives, with many plots featuring American heroes facing off against Soviet adversaries. This political subtext, combined with themes of individual triumph over bureaucratic systems, perfectly captured the era's zeitgeist. + +--- + +While often dismissed as simple entertainment, 80s action films left an indelible mark on cinema history, influencing everything from filming techniques to narrative structures, and continuing to inspire filmmakers and delight audiences decades later. + +""" + + def compose(self) -> ComposeResult: + yield Markdown(self.MD_MD) + with containers.VerticalScroll( + id="container", can_focus=True, can_maximize=True + ): + yield Markdown(self.MOVIES_MD, id="movies") + + +class Selects(containers.VerticalGroup): + DEFAULT_CLASSES = "column" + SELECTS_MD = """\ +## Selects + +Selects (AKA *Combo boxes*), present a list of options in a menu that may be expanded by the user. +""" + HEROS = [ + "Arnold Schwarzenegger", + "Brigitte Nielsen", + "Bruce Willis", + "Carl Weathers", + "Chuck Norris", + "Dolph Lundgren", + "Grace Jones", + "Harrison Ford", + "Jean-Claude Van Damme", + "Kurt Russell", + "Linda Hamilton", + "Mel Gibson", + "Michelle Yeoh", + "Sigourney Weaver", + "Sylvester Stallone", + ] + + def compose(self) -> ComposeResult: + yield Markdown(self.SELECTS_MD) + yield Select.from_values(self.HEROS, prompt="80s action hero") + + +class Sparklines(containers.VerticalGroup): + """Demonstrates sparklines.""" + + DEFAULT_CLASSES = "column" + LOGS_MD = """\ +## Sparklines + +A low-res summary of time-series data. + +For detailed graphs, see [textual-plotext](https://github.com/Textualize/textual-plotext). +""" + DEFAULT_CSS = """ + Sparklines { + Sparkline { + width: 1fr; + margin: 1; + &#first > .sparkline--min-color { color: $success; } + &#first > .sparkline--max-color { color: $warning; } + &#second > .sparkline--min-color { color: $warning; } + &#second > .sparkline--max-color { color: $error; } + &#third > .sparkline--min-color { color: $primary; } + &#third > .sparkline--max-color { color: $accent; } + } + VerticalScroll { + height: auto; + border: heavy transparent; + &:focus { border: heavy $border; } + } + } + + """ + + count = var(0) + data: reactive[list[float]] = reactive(list) + + def compose(self) -> ComposeResult: + yield Markdown(self.LOGS_MD) + with containers.VerticalScroll( + id="container", can_focus=True, can_maximize=True + ): + yield Sparkline([], summary_function=max, id="first").data_bind( + Sparklines.data, + ) + yield Sparkline([], summary_function=max, id="second").data_bind( + Sparklines.data, + ) + yield Sparkline([], summary_function=max, id="third").data_bind( + Sparklines.data, + ) + + def on_mount(self) -> None: + self.set_interval(0.1, self.update_sparks) + + def update_sparks(self) -> None: + """Update the sparks data.""" + if self.is_scrolling: + return + if ( + not self.app.screen.can_view_partial(self) + and not self.query_one(Sparkline).is_in_maximized_view + ): + return + self.count += 1 + offset = self.count * 40 + self.data = [abs(sin(x / 3.14)) for x in range(offset, offset + 360 * 6, 20)] + + +class Switches(containers.VerticalGroup): + """Demonstrate the Switch widget.""" + + ALLOW_MAXIMIZE = True + DEFAULT_CLASSES = "column" + SWITCHES_MD = """\ +## Switches + +Functionally almost identical to a Checkbox, but displays more prominently in the UI. +""" + DEFAULT_CSS = """\ +Switches { + Label { + padding: 1; + &:hover {text-style:underline; } + } +} +""" + + def compose(self) -> ComposeResult: + yield Markdown(self.SWITCHES_MD) + with containers.ItemGrid(min_column_width=32): + for theme in BUILTIN_THEMES: + if theme.endswith("-ansi"): + continue + with containers.HorizontalGroup(): + yield Switch(id=theme) + yield Label(theme, name=theme) + + @on(events.Click, "Label") + def on_click(self, event: events.Click) -> None: + """Make the label toggle the switch.""" + # TODO: Add a dedicated form label + event.stop() + if event.widget is not None: + self.query_one(f"#{event.widget.name}", Switch).toggle() + + def on_switch_changed(self, event: Switch.Changed) -> None: + # Don't issue more Changed events + if not event.value: + self.query_one("#textual-dark", Switch).value = True + return + + with self.prevent(Switch.Changed): + # Reset all other switches + for switch in self.query("Switch").results(Switch): + if switch.id != event.switch.id: + switch.value = False + assert event.switch.id is not None + theme_id = event.switch.id + + def switch_theme() -> None: + """Callback to switch the theme.""" + self.app.theme = theme_id + + # Call after a short delay, so we see the Switch animation + self.set_timer(0.3, switch_theme) + + +class TabsDemo(containers.VerticalGroup): + DEFAULT_CLASSES = "column" + TABS_MD = """\ +## Tabs + +A navigable list of section headers. + +Typically used with `ContentTabs`, to display additional content associate with each tab. + +Use the cursor keys to navigate. + +""" + DEFAULT_CSS = """ + .bio { padding: 1 2; background: $boost; color: $foreground-muted; } + """ + + def compose(self) -> ComposeResult: + yield Markdown(self.TABS_MD) + with TabbedContent(*[bio["name"] for bio in DUNE_BIOS]): + for bio in DUNE_BIOS: + yield Static(bio["description"], classes="bio") + + +class Trees(containers.VerticalGroup): + DEFAULT_CLASSES = "column" + TREES_MD = """\ +## Tree + +The Tree widget displays hierarchical data. + +There is also the Tree widget's cousin, DirectoryTree, to navigate folders and files on the filesystem. + """ + DEFAULT_CSS = """ + Trees { + Tree { + height: 16; + padding: 1; + &.-maximized { height: 1fr; } + border: wide transparent; + &:focus { border: wide $border; } + } + VerticalGroup { + + } + } + + """ + + def compose(self) -> ComposeResult: + yield Markdown(self.TREES_MD) + with containers.VerticalGroup(): + tree = Tree("80s movies") + tree.show_root = False + tree.add_json(MOVIES_TREE) + tree.root.expand() + yield tree + + +class TextAreas(containers.VerticalGroup): + ALLOW_MAXIMIZE = True + DEFAULT_CLASSES = "column" + TEXTAREA_MD = """\ +## TextArea + +A powerful and highly configurable text area that supports syntax highlighting, line numbers, soft wrapping, and more. + +""" + DEFAULT_CSS = """ + TextAreas { + TextArea { + height: 16; + } + &.-maximized { + height: 1fr; + } + } + """ + DEFAULT_TEXT = """\ +# Start building! +from memray._vendor.textual import App, ComposeResult +""" + + def compose(self) -> ComposeResult: + yield Markdown(self.TEXTAREA_MD) + yield Select.from_values( + [ + "Bash", + "Css", + "Go", + "HTML", + "Java", + "Javascript", + "JSON", + "Markdown", + "Python", + "Rust", + "Regex", + "Sql", + "TOML", + "YAML", + ], + value="Python", + prompt="Highlight language", + ) + + yield TextArea(self.DEFAULT_TEXT, show_line_numbers=True, language=None) + + def on_select_changed(self, event: Select.Changed) -> None: + self.query_one(TextArea).language = ( + event.value.lower() if isinstance(event.value, str) else None + ) + + +class YourWidgets(containers.VerticalGroup): + DEFAULT_CLASSES = "column" + YOUR_MD = """\ +## Your Widget Here! + +The Textual API allows you to [build custom re-usable widgets](https://textual.textualize.io/guide/widgets/#custom-widgets) and share them across projects. +Custom widgets can be themed, just like the builtin widget library. + +Combine existing widgets to add new functionality, or use the powerful [Line API](https://textual.textualize.io/guide/widgets/#line-api) for unique creations. + +""" + DEFAULT_CSS = """ + YourWidgets { margin-bottom: 2; } + """ + + def compose(self) -> ComposeResult: + yield Markdown(self.YOUR_MD) + + +class WidgetsScreen(PageScreen): + """The Widgets screen""" + + CSS = """ + WidgetsScreen { + align-horizontal: center; + Markdown { background: transparent; } + & > VerticalScroll { + scrollbar-gutter: stable; + & > * { + &:even { background: $boost; } + padding-bottom: 1; + } + } + } + """ + + BINDINGS = [Binding("escape", "blur", "Unfocus any focused widget", show=False)] + + def compose(self) -> ComposeResult: + with lazy.Reveal(containers.VerticalScroll(can_focus=True)): + yield Markdown(WIDGETS_MD, classes="column") + yield Buttons() + yield Checkboxes() + yield Datatables() + yield Inputs() + yield ListViews() + yield Logs() + yield Markdowns() + yield Selects() + yield Sparklines() + yield Switches() + yield TabsDemo() + yield TextAreas() + yield Trees() + yield YourWidgets() + yield Footer() + + +if __name__ == "__main__": + from memray._vendor.textual.app import App + + class GameApp(App): + def get_default_screen(self) -> Screen: + return WidgetsScreen() + + app = GameApp() + app.run() diff --git a/src/memray/_vendor/textual/design.py b/src/memray/_vendor/textual/design.py new file mode 100644 index 0000000000..d34ee5b76b --- /dev/null +++ b/src/memray/_vendor/textual/design.py @@ -0,0 +1,389 @@ +from __future__ import annotations + +from typing import Iterable + +import rich.repr +from rich.console import group +from rich.padding import Padding +from rich.table import Table +from rich.text import Text + +from memray._vendor.textual.color import WHITE, Color + +NUMBER_OF_SHADES = 3 + +# Where no content exists +DEFAULT_DARK_BACKGROUND = "#121212" +# What text usually goes on top off +DEFAULT_DARK_SURFACE = "#1e1e1e" + +DEFAULT_LIGHT_SURFACE = "#f5f5f5" +DEFAULT_LIGHT_BACKGROUND = "#efefef" + + +@rich.repr.auto +class ColorSystem: + """Defines a standard set of colors and variations for building a UI. + + Primary is the main theme color + Secondary is a second theme color + """ + + COLOR_NAMES = [ + "primary", + "secondary", + "background", + "primary-background", + "secondary-background", + "surface", + "panel", + "boost", + "warning", + "error", + "success", + "accent", + ] + + def __init__( + self, + primary: str, + secondary: str | None = None, + warning: str | None = None, + error: str | None = None, + success: str | None = None, + accent: str | None = None, + foreground: str | None = None, + background: str | None = None, + surface: str | None = None, + panel: str | None = None, + boost: str | None = None, + dark: bool = False, + luminosity_spread: float = 0.15, + text_alpha: float = 0.95, + variables: dict[str, str] | None = None, + ): + def parse(color: str | None) -> Color | None: + if color is None: + return None + return Color.parse(color) + + self.primary = Color.parse(primary) + self.secondary = parse(secondary) + self.warning = parse(warning) + self.error = parse(error) + self.success = parse(success) + self.accent = parse(accent) + self.foreground = parse(foreground) + self.background = parse(background) + self.surface = parse(surface) + self.panel = parse(panel) + self.boost = parse(boost) + self.dark = dark + self.luminosity_spread = luminosity_spread + self.text_alpha = text_alpha + self.variables = variables or {} + """Overrides for specific variables.""" + + @property + def shades(self) -> Iterable[str]: + """The names of the colors and derived shades.""" + for color in self.COLOR_NAMES: + for shade_number in range(-NUMBER_OF_SHADES, NUMBER_OF_SHADES + 1): + if shade_number < 0: + yield f"{color}-darken-{abs(shade_number)}" + elif shade_number > 0: + yield f"{color}-lighten-{shade_number}" + else: + yield color + + def get_or_default(self, name: str, default: str) -> str: + """Get the value of a color variable, or the default value if not set.""" + return self.variables.get(name, default) + + def generate(self) -> dict[str, str]: + """Generate a mapping of color name on to a CSS color. + + Returns: + A mapping of color name on to a CSS-style encoded color + """ + + primary = self.primary + secondary = self.secondary or primary + warning = self.warning or primary + error = self.error or secondary + success = self.success or secondary + accent = self.accent or primary + + dark = self.dark + luminosity_spread = self.luminosity_spread + + colors: dict[str, str] = {} + + if dark: + background = self.background or Color.parse(DEFAULT_DARK_BACKGROUND) + surface = self.surface or Color.parse(DEFAULT_DARK_SURFACE) + else: + background = self.background or Color.parse(DEFAULT_LIGHT_BACKGROUND) + surface = self.surface or Color.parse(DEFAULT_LIGHT_SURFACE) + + foreground = self.foreground or (background.inverse) + contrast_text = background.get_contrast_text(1.0) + boost = self.boost or contrast_text.with_alpha(0.04) + + # Colored text + colors["text-primary"] = contrast_text.tint(primary.with_alpha(0.66)).hex + colors["text-secondary"] = contrast_text.tint(secondary.with_alpha(0.66)).hex + colors["text-warning"] = contrast_text.tint(warning.with_alpha(0.66)).hex + colors["text-error"] = contrast_text.tint(error.with_alpha(0.66)).hex + colors["text-success"] = contrast_text.tint(success.with_alpha(0.66)).hex + colors["text-accent"] = contrast_text.tint(accent.with_alpha(0.66)).hex + + if self.panel is None: + panel = surface.blend(primary, 0.1, alpha=1) + if dark: + panel += boost + else: + panel = self.panel + + def luminosity_range(spread: float) -> Iterable[tuple[str, float]]: + """Get the range of shades from darken2 to lighten2. + + Returns: + Iterable of tuples () + """ + luminosity_step = spread / 2 + for n in range(-NUMBER_OF_SHADES, +NUMBER_OF_SHADES + 1): + if n < 0: + label = "-darken" + elif n > 0: + label = "-lighten" + else: + label = "" + yield (f"{label}{'-' + str(abs(n)) if n else ''}"), n * luminosity_step + + # Color names and color + COLORS: list[tuple[str, Color]] = [ + ("primary", primary), + ("secondary", secondary), + ("primary-background", primary), + ("secondary-background", secondary), + ("background", background), + ("foreground", foreground), + ("panel", panel), + ("boost", boost), + ("surface", surface), + ("warning", warning), + ("error", error), + ("success", success), + ("accent", accent), + ] + + # Colors names that have a dark variant + DARK_SHADES = {"primary-background", "secondary-background"} + + get = self.get_or_default + + for name, color in COLORS: + is_dark_shade = dark and name in DARK_SHADES + spread = luminosity_spread + for shade_name, luminosity_delta in luminosity_range(spread): + key = f"{name}{shade_name}" + if color.ansi is not None: + colors[key] = color.hex + elif is_dark_shade: + dark_background = background.blend(color, 0.15, alpha=1.0) + if key not in self.variables: + shade_color = dark_background.blend( + WHITE, spread + luminosity_delta, alpha=1.0 + ).clamped + colors[key] = shade_color.hex + else: + colors[key] = self.variables[key] + else: + colors[key] = get(key, color.lighten(luminosity_delta).hex) + + if foreground.ansi is None: + colors["text"] = get("text", "auto 87%") + colors["text-muted"] = get("text-muted", "auto 60%") + colors["text-disabled"] = get("text-disabled", "auto 38%") + else: + colors["text"] = "ansi_default" + colors["text-muted"] = "ansi_default" + colors["text-disabled"] = "ansi_default" + + # Muted variants of base colors + colors["primary-muted"] = get( + "primary-muted", primary.blend(background, 0.7).hex + ) + colors["secondary-muted"] = get( + "secondary-muted", secondary.blend(background, 0.7).hex + ) + colors["accent-muted"] = get("accent-muted", accent.blend(background, 0.7).hex) + colors["warning-muted"] = get( + "warning-muted", warning.blend(background, 0.7).hex + ) + colors["error-muted"] = get("error-muted", error.blend(background, 0.7).hex) + colors["success-muted"] = get( + "success-muted", success.blend(background, 0.7).hex + ) + + # Foreground colors + colors["foreground-muted"] = get( + "foreground-muted", foreground.with_alpha(0.6).hex + ) + colors["foreground-disabled"] = get( + "foreground-disabled", foreground.with_alpha(0.38).hex + ) + + # The cursor color for widgets such as OptionList, DataTable, etc. + colors["block-cursor-foreground"] = get( + "block-cursor-foreground", colors["text"] + ) + colors["block-cursor-background"] = get("block-cursor-background", primary.hex) + colors["block-cursor-text-style"] = get("block-cursor-text-style", "bold") + colors["block-cursor-blurred-foreground"] = get( + "block-cursor-blurred-foreground", foreground.hex + ) + colors["block-cursor-blurred-background"] = get( + "block-cursor-blurred-background", primary.with_alpha(0.3).hex + ) + colors["block-cursor-blurred-text-style"] = get( + "block-cursor-blurred-text-style", "none" + ) + colors["block-hover-background"] = get( + "block-hover-background", boost.with_alpha(0.1).hex + ) + + # The border color for focused widgets which have a border. + colors["border"] = get("border", primary.hex) + colors["border-blurred"] = get("border-blurred", surface.darken(0.025).hex) + + # The surface color for builtin focused widgets + colors["surface-active"] = get( + "surface-active", surface.lighten(self.luminosity_spread / 2.5).hex + ) + + # The scrollbar colors + colors["scrollbar"] = get( + "scrollbar", + (Color.parse(colors["background-darken-1"]) + primary.with_alpha(0.4)).hex, + ) + colors["scrollbar-hover"] = get( + "scrollbar-hover", + (Color.parse(colors["background-darken-1"]) + primary.with_alpha(0.5)).hex, + ) + # colors["scrollbar-active"] = get("scrollbar-active", colors["panel-lighten-2"]) + colors["scrollbar-active"] = get("scrollbar-active", primary.hex) + colors["scrollbar-background"] = get( + "scrollbar-background", colors["background-darken-1"] + ) + colors["scrollbar-corner-color"] = get( + "scrollbar-corner-color", colors["scrollbar-background"] + ) + colors["scrollbar-background-hover"] = get( + "scrollbar-background-hover", colors["scrollbar-background"] + ) + colors["scrollbar-background-active"] = get( + "scrollbar-background-active", colors["scrollbar-background"] + ) + + # Links + colors["link-background"] = get("link-background", "initial") + colors["link-background-hover"] = get("link-background-hover", primary.hex) + colors["link-color"] = get("link-color", colors["text"]) + colors["link-style"] = get("link-style", "underline") + colors["link-color-hover"] = get("link-color-hover", colors["text"]) + colors["link-style-hover"] = get("link-style-hover", "bold not underline") + + colors["footer-foreground"] = get("footer-foreground", foreground.hex) + colors["footer-background"] = get("footer-background", panel.hex) + + colors["footer-key-foreground"] = get("footer-key-foreground", accent.hex) + colors["footer-key-background"] = get("footer-key-background", "transparent") + + colors["footer-description-foreground"] = get( + "footer-description-foreground", foreground.hex + ) + colors["footer-description-background"] = get( + "footer-description-background", "transparent" + ) + + colors["footer-item-background"] = get("footer-item-background", "transparent") + + colors["input-cursor-background"] = get( + "input-cursor-background", foreground.hex + ) + colors["input-cursor-foreground"] = get( + "input-cursor-foreground", background.hex + ) + colors["input-cursor-text-style"] = get("input-cursor-text-style", "none") + colors["input-selection-background"] = get( + "input-selection-background", + Color.parse(colors["primary-lighten-1"]).with_alpha(0.4).hex, + ) + + # Markdown header styles + colors["markdown-h1-color"] = get("markdown-h1-color", primary.hex) + colors["markdown-h1-background"] = get("markdown-h1-background", "transparent") + colors["markdown-h1-text-style"] = get("markdown-h1-text-style", "bold") + + colors["markdown-h2-color"] = get("markdown-h2-color", primary.hex) + colors["markdown-h2-background"] = get("markdown-h2-background", "transparent") + colors["markdown-h2-text-style"] = get("markdown-h2-text-style", "underline") + + colors["markdown-h3-color"] = get("markdown-h3-color", primary.hex) + colors["markdown-h3-background"] = get("markdown-h3-background", "transparent") + colors["markdown-h3-text-style"] = get("markdown-h3-text-style", "bold") + + colors["markdown-h4-color"] = get("markdown-h4-color", foreground.hex) + colors["markdown-h4-background"] = get("markdown-h4-background", "transparent") + colors["markdown-h4-text-style"] = get( + "markdown-h4-text-style", "bold underline" + ) + + colors["markdown-h5-color"] = get("markdown-h5-color", foreground.hex) + colors["markdown-h5-background"] = get("markdown-h5-background", "transparent") + colors["markdown-h5-text-style"] = get("markdown-h5-text-style", "bold") + + colors["markdown-h6-color"] = get( + "markdown-h6-color", colors["foreground-muted"] + ) + colors["markdown-h6-background"] = get("markdown-h6-background", "transparent") + colors["markdown-h6-text-style"] = get("markdown-h6-text-style", "bold") + + colors["button-foreground"] = get("button-foreground", foreground.hex) + colors["button-color-foreground"] = get( + "button-color-foreground", colors["text"] + ) + colors["button-focus-text-style"] = get("button-focus-text-style", "b reverse") + + return colors + + +def show_design(light: ColorSystem, dark: ColorSystem) -> Table: + """Generate a renderable to show color systems. + + Args: + light: Light ColorSystem. + dark: Dark ColorSystem + + Returns: + Table showing all colors. + """ + + @group() + def make_shades(system: ColorSystem): + colors = system.generate() + for name in system.shades: + background = Color.parse(colors[name]).with_alpha(1.0) + foreground = background + background.get_contrast_text(0.9) + + text = Text(f"${name}") + + yield Padding(text, 1, style=f"{foreground.hex6} on {background.hex6}") + + table = Table(box=None, expand=True) + table.add_column("Light", justify="center") + table.add_column("Dark", justify="center") + table.add_row(make_shades(light), make_shades(dark)) + return table diff --git a/src/memray/_vendor/textual/document/__init__.py b/src/memray/_vendor/textual/document/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/memray/_vendor/textual/document/_document.py b/src/memray/_vendor/textual/document/_document.py new file mode 100644 index 0000000000..3965114aca --- /dev/null +++ b/src/memray/_vendor/textual/document/_document.py @@ -0,0 +1,473 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from functools import lru_cache +from typing import TYPE_CHECKING, NamedTuple, Tuple, overload + +from typing_extensions import Literal, get_args + +if TYPE_CHECKING: + from tree_sitter import Node, Query + +from memray._vendor.textual._cells import cell_len +from memray._vendor.textual.geometry import Size + +Newline = Literal["\r\n", "\n", "\r"] +"""The type representing valid line separators.""" +VALID_NEWLINES = set(get_args(Newline)) +"""The set of valid line separator strings.""" + + +@dataclass +class EditResult: + """Contains information about an edit that has occurred.""" + + end_location: Location + """The new end Location after the edit is complete.""" + replaced_text: str + """The text that was replaced.""" + + +@lru_cache(maxsize=1024) +def _utf8_encode(text: str) -> bytes: + """Encode the input text as utf-8 bytes. + + The returned encoded bytes may be retrieved from a cache. + + Args: + text: The text to encode. + + Returns: + The utf-8 bytes representing the input string. + """ + return text.encode("utf-8") + + +def _detect_newline_style(text: str) -> Newline: + """Return the newline type used in this document. + + Args: + text: The text to inspect. + + Returns: + The Newline used in the file. + """ + if "\r\n" in text: # Windows newline + return "\r\n" + elif "\n" in text: # Unix/Linux/MacOS newline + return "\n" + elif "\r" in text: # Old MacOS newline + return "\r" + else: + return "\n" # Default to Unix style newline + + +class DocumentBase(ABC): + """Describes the minimum functionality a Document implementation must + provide in order to be used by the TextArea widget.""" + + @abstractmethod + def replace_range(self, start: Location, end: Location, text: str) -> EditResult: + """Replace the text at the given range. + + Args: + start: A tuple (row, column) where the edit starts. + end: A tuple (row, column) where the edit ends. + text: The text to insert between start and end. + + Returns: + The new end location after the edit is complete. + """ + + @property + @abstractmethod + def text(self) -> str: + """The text from the document as a string.""" + + @property + @abstractmethod + def newline(self) -> Newline: + """Return the line separator used in the document.""" + + @property + @abstractmethod + def lines(self) -> list[str]: + """Get the lines of the document as a list of strings. + + The strings should *not* include newline characters. The newline + character used for the document can be retrieved via the newline + property. + """ + + @abstractmethod + def get_line(self, index: int) -> str: + """Returns the line with the given index from the document. + + This is used in rendering lines, and will be called by the + TextArea for each line that is rendered. + + Args: + index: The index of the line in the document. + + Returns: + The str instance representing the line. + """ + + @abstractmethod + def get_text_range(self, start: Location, end: Location) -> str: + """Get the text that falls between the start and end locations. + + Args: + start: The start location of the selection. + end: The end location of the selection. + + Returns: + The text between start (inclusive) and end (exclusive). + """ + + @abstractmethod + def get_size(self, indent_width: int) -> Size: + """Get the size of the document. + + The height is generally the number of lines, and the width + is generally the maximum cell length of all the lines. + + Args: + indent_width: The width to use for tab characters. + + Returns: + The Size of the document bounding box. + """ + + def query_syntax_tree( + self, + query: "Query", + start_point: tuple[int, int] | None = None, + end_point: tuple[int, int] | None = None, + ) -> dict[str, list["Node"]]: + """Query the tree-sitter syntax tree. + + The default implementation always returns an empty list. + + To support querying in a subclass, this must be implemented. + + Args: + query: The tree-sitter Query to perform. + start_point: The (row, column byte) to start the query at. + end_point: The (row, column byte) to end the query at. + + Returns: + A dict mapping captured node names to lists of Nodes with that name. + """ + return {} + + def prepare_query(self, query: str) -> "Query | None": + return None + + @property + @abstractmethod + def line_count(self) -> int: + """Returns the number of lines in the document.""" + + @property + @abstractmethod + def start(self) -> Location: + """Returns the location of the start of the document (0, 0).""" + return (0, 0) + + @property + @abstractmethod + def end(self) -> Location: + """Returns the location of the end of the document.""" + + if TYPE_CHECKING: + + @overload + def __getitem__(self, line_index: int) -> str: ... + + @overload + def __getitem__(self, line_index: slice) -> list[str]: ... + + @abstractmethod + def __getitem__(self, line_index: int | slice) -> str | list[str]: + """Return the content of a line as a string, excluding newline characters. + + Args: + line_index: The index or slice of the line(s) to retrieve. + + Returns: + The line or list of lines requested. + """ + + +class Document(DocumentBase): + """A document which can be opened in a TextArea.""" + + def __init__(self, text: str) -> None: + self._newline: Newline = _detect_newline_style(text) + """The type of newline used in the text.""" + self._lines: list[str] = text.splitlines(keepends=False) + """The lines of the document, excluding newline characters. + + If there's a newline at the end of the file, the final line is an empty string. + """ + if text.endswith(tuple(VALID_NEWLINES)) or not text: + self._lines.append("") + + @property + def lines(self) -> list[str]: + """Get the document as a list of strings, where each string represents a line. + + Newline characters are not included in at the end of the strings. + + The newline character used in this document can be found via the `Document.newline` property. + """ + return self._lines + + @property + def text(self) -> str: + """Get the text from the document.""" + return self._newline.join(self._lines) + + @property + def newline(self) -> Newline: + """Get the Newline used in this document (e.g. '\r\n', '\n'. etc.)""" + return self._newline + + def get_size(self, tab_width: int) -> Size: + """The Size of the document, taking into account the tab rendering width. + + Args: + tab_width: The width to use for tab indents. + + Returns: + The size (width, height) of the document. + """ + lines = self._lines + cell_lengths = [cell_len(line.expandtabs(tab_width)) for line in lines] + max_cell_length = max(cell_lengths, default=0) + height = len(lines) + return Size(max_cell_length, height) + + def replace_range(self, start: Location, end: Location, text: str) -> EditResult: + """Replace text at the given range. + + This is the only method by which a document may be updated. + + Args: + start: A tuple (row, column) where the edit starts. + end: A tuple (row, column) where the edit ends. + text: The text to insert between start and end. + + Returns: + The EditResult containing information about the completed + replace operation. + """ + top, bottom = sorted((start, end)) + top_row, top_column = top + bottom_row, bottom_column = bottom + + insert_lines = text.splitlines() + if text.endswith(tuple(VALID_NEWLINES)): + # Special case where a single newline character is inserted. + insert_lines.append("") + + lines = self._lines + + replaced_text = self.get_text_range(top, bottom) + if bottom_row >= len(lines): + after_selection = "" + else: + after_selection = lines[bottom_row][bottom_column:] + + if top_row >= len(lines): + before_selection = "" + else: + before_selection = lines[top_row][:top_column] + + if insert_lines: + insert_lines[0] = before_selection + insert_lines[0] + destination_column = len(insert_lines[-1]) + insert_lines[-1] = insert_lines[-1] + after_selection + else: + destination_column = len(before_selection) + insert_lines = [before_selection + after_selection] + + lines[top_row : bottom_row + 1] = insert_lines + destination_row = top_row + len(insert_lines) - 1 + + end_location = (destination_row, destination_column) + return EditResult(end_location, replaced_text) + + def get_text_range(self, start: Location, end: Location) -> str: + """Get the text that falls between the start and end locations. + + Returns the text between `start` and `end`, including the appropriate + line separator character as specified by `Document._newline`. Note that + `_newline` is set automatically to the first line separator character + found in the document. + + Args: + start: The start location of the selection. + end: The end location of the selection. + + Returns: + The text between start (inclusive) and end (exclusive). + """ + if start == end: + return "" + + top, bottom = sorted((start, end)) + top_row, top_column = top + bottom_row, bottom_column = bottom + lines = self._lines + if top_row == bottom_row: + line = lines[top_row] + selected_text = line[top_column:bottom_column] + else: + start_line = lines[top_row] + end_line = lines[bottom_row] if bottom_row <= self.line_count - 1 else "" + selected_text = start_line[top_column:] + for row in range(top_row + 1, bottom_row): + selected_text += self._newline + lines[row] + + if bottom_row < self.line_count: + selected_text += self._newline + selected_text += end_line[:bottom_column] + + return selected_text + + @property + def line_count(self) -> int: + """Returns the number of lines in the document.""" + return len(self._lines) + + @property + def start(self) -> Location: + """Returns the location of the start of the document (0, 0).""" + return super().start + + @property + def end(self) -> Location: + """Returns the location of the end of the document.""" + last_line = self._lines[-1] + return (self.line_count - 1, len(last_line)) + + def get_index_from_location(self, location: Location) -> int: + """Given a location, returns the index from the document's text. + + Args: + location: The location in the document. + + Returns: + The index in the document's text. + """ + row, column = location + index = row * len(self.newline) + column + for line_index in range(row): + index += len(self.get_line(line_index)) + return index + + def get_location_from_index(self, index: int) -> Location: + """Given a codepoint index in the document's text, returns the corresponding location. + + Args: + index: The index in the document's text. + + Returns: + The corresponding location. + + Raises: + ValueError: If the index is doesn't correspond to a location in the document. + """ + error_message = ( + f"Index {index!r} does not correspond to a location in the document." + ) + if index < 0 or index > len(self.text): + raise ValueError(error_message) + + column_index = 0 + newline_length = len(self.newline) + for line_index in range(self.line_count): + next_column_index = ( + column_index + len(self.get_line(line_index)) + newline_length + ) + if index < next_column_index: + return (line_index, index - column_index) + elif index == next_column_index: + return (line_index + 1, 0) + column_index = next_column_index + + raise ValueError(error_message) + + def get_line(self, index: int) -> str: + """Returns the line with the given index from the document. + + Args: + index: The index of the line in the document. + + Returns: + The string representing the line. + """ + line_string = self[index] + return line_string + + @overload + def __getitem__(self, line_index: int) -> str: ... + + @overload + def __getitem__(self, line_index: slice) -> list[str]: ... + + def __getitem__(self, line_index: int | slice) -> str | list[str]: + """Return the content of a line as a string, excluding newline characters. + + Args: + line_index: The index or slice of the line(s) to retrieve. + + Returns: + The line or list of lines requested. + """ + return self._lines[line_index] + + +Location = Tuple[int, int] +"""A location (row, column) within the document. Indexing starts at 0.""" + + +class Selection(NamedTuple): + """A range of characters within a document from a start point to the end point. + The location of the cursor is always considered to be the `end` point of the selection. + The selection is inclusive of the minimum point and exclusive of the maximum point. + """ + + start: Location = (0, 0) + """The start location of the selection. + + If you were to click and drag a selection inside a text-editor, this is where you *started* dragging. + """ + end: Location = (0, 0) + """The end location of the selection. + + If you were to click and drag a selection inside a text-editor, this is where you *finished* dragging. + """ + + @classmethod + def cursor(cls, location: Location) -> "Selection": + """Create a Selection with the same start and end point - a "cursor". + + Args: + location: The location to create the zero-width Selection. + """ + return cls(location, location) + + @property + def is_empty(self) -> bool: + """Return True if the selection has 0 width, i.e. it's just a cursor.""" + start, end = self + return start == end + + def contains_line(self, y: int) -> bool: + """Check if the given line is within the selection.""" + top, bottom = sorted((self.start[0], self.end[0])) + return y >= top and y <= bottom diff --git a/src/memray/_vendor/textual/document/_document_navigator.py b/src/memray/_vendor/textual/document/_document_navigator.py new file mode 100644 index 0000000000..315a0090b6 --- /dev/null +++ b/src/memray/_vendor/textual/document/_document_navigator.py @@ -0,0 +1,470 @@ +import re +from bisect import bisect, bisect_left, bisect_right +from typing import Any, Sequence + +from memray._vendor.textual._cells import cell_len +from memray._vendor.textual.document._document import Location +from memray._vendor.textual.document._wrapped_document import WrappedDocument +from memray._vendor.textual.geometry import Offset, clamp + + +class DocumentNavigator: + """Cursor navigation in the TextArea is "wrapping-aware". + + Although the cursor location (the selection) is represented as a location + in the raw document, when you actually *move* the cursor, it must take wrapping + into account (otherwise things start to look really confusing to the user where + wrapping is involved). + + Your cursor visually moves through the wrapped version of the document, rather + than the raw document. So, for example, pressing down on the keyboard + may move your cursor to a position further along the current raw document line, + rather than on to the next line in the raw document. + + The DocumentNavigator class manages that behavior. + + Given a cursor location in the unwrapped document, and a cursor movement action, + this class can inform us of the destination the cursor will move to considering + the current wrapping width and document content. It can also translate between + document-space (a location/(row,col) in the raw document), and visual-space + (x and y offsets) as the user will see them on screen after the document has been + wrapped. + + For this to work correctly, the wrapped_document and document must be synchronised. + This means that if you make an edit to the document, you *must* then update the + wrapped document, and *then* you may query the document navigator. + + Naming conventions: + + A "location" refers to a location, in document-space (in the raw document). It + is entirely unrelated to visually positioning. A location in a document can appear + in any visual position, as it is influenced by scrolling, wrapping, gutter settings, + and the cell width of characters to its left. + + A "wrapped section" refers to a portion of the line accounting for wrapping. + For example the line "ABCDEF" when wrapped at width 3 will result in 2 sections: + "ABC" and "DEF". In this case, we call "ABC" is the first section/wrapped section. + + A "wrap offset" is an integer representing the index at which wrapping occurs in a + document-space line. This is a codepoint index, rather than a visual offset. + In "ABCDEF" with wrapping at width 3, there is a single wrap offset of 3. + + "Smart home" refers to a modification of the "home" key behavior. If smart home is + enabled, the first non-whitespace character is considered to be the home location. + If the cursor is currently at this position, then the normal home behavior applies. + This is designed to make cursor movement more useful to end users. + """ + + def __init__(self, wrapped_document: WrappedDocument) -> None: + """Create a DocumentNavigator. + + Args: + wrapped_document: The WrappedDocument to be used when making navigation decisions. + """ + self._wrapped_document = wrapped_document + self._document = wrapped_document.document + + self._word_pattern = re.compile(r"(?<=\W)(?=\w)|(?<=\w)(?=\W)") + """Compiled regular expression for what we consider to be a 'word'.""" + + self.last_x_offset = 0 + """Remembers the last x offset (cell width) the cursor was moved horizontally to, + so that it can be restored on vertical movement where possible.""" + + def is_start_of_document_line(self, location: Location) -> bool: + """True when the location is at the start of the first document line. + + Args: + location: The location to check. + + Returns: + True if the location is at column index 0. + """ + return location[1] == 0 + + def is_start_of_wrapped_line(self, location: Location) -> bool: + """True when the location is at the start of the first wrapped line. + + Args: + location: The location to check. + + Returns: + True if the location is at column index 0. + """ + if self.is_start_of_document_line(location): + return True + + row, column = location + wrap_offsets = self._wrapped_document.get_offsets(row) + return index(wrap_offsets, column) != -1 + + def is_end_of_document_line(self, location: Location) -> bool: + """True if the location is at the end of a line in the document. + + Note that the "end" of a line is equal to its length (one greater + than the final index), since there is a space at the end of the line + for the cursor to rest. + + Args: + location: The location to examine. + + Returns: + True if and only if the document is at the end of a line in the document. + """ + row, column = location + row_length = len(self._document[row]) + return column == row_length + + def is_end_of_wrapped_line(self, location: Location) -> bool: + """True if the location is at the end of a wrapped line. + + Args: + location: The location to examine. + + Returns: + True if and only if the cursor is on the last wrapped section of *any* line. + """ + if self.is_end_of_document_line(location): + return True + + row, column = location + wrap_offsets = self._wrapped_document.get_offsets(row) + return index(wrap_offsets, column - 1) != -1 + + def is_first_document_line(self, location: Location) -> bool: + """Check if the given location is on the first line in the document. + + Args: + location: The location to examine. + + Returns: + True if and only if the cursor is on the first line of the document. + """ + return location[0] == 0 + + def is_first_wrapped_line(self, location: Location) -> bool: + """Check if the given location is on the first wrapped section of the first line in the document. + + Args: + location: The location to examine. + + Returns: + True if and only if the cursor is on the first wrapped section of the first line. + """ + if not self.is_first_document_line(location): + return False + + row, column = location + wrap_offsets = self._wrapped_document.get_offsets(row) + + if not wrap_offsets: + return True + + if column < wrap_offsets[0]: + return True + return False + + def is_last_document_line(self, location: Location) -> bool: + """Check if the given location is on the last line of the document. + + Args: + location: The location to examine. + + Returns: + True when the location is on the last line of the document. + """ + return location[0] == self._document.line_count - 1 + + def is_last_wrapped_line(self, location: Location) -> bool: + """Check if the given location is on the last wrapped section of the last line. + + That is, the cursor is *visually* on the last rendered row. + + Args: + location: The location to examine. + + Returns: + True if and only if the cursor is on the last section of the last line. + """ + if not self.is_last_document_line(location): + return False + + row, column = location + wrap_offsets = self._wrapped_document.get_offsets(row) + + if not wrap_offsets: + return True + + if column >= wrap_offsets[-1]: + return True + return False + + def is_start_of_document(self, location: Location) -> bool: + """Check if a location is at the start of the document. + + Args: + location: The location to examine. + + Returns: + True if and only if the cursor is at document location (0, 0)""" + return location == (0, 0) + + def is_end_of_document(self, location: Location) -> bool: + """Check if a location is at the end of the document. + + Args: + location: The location to examine. + + Returns: + True if and only if the cursor is at the end of the document. + """ + return self.is_last_document_line(location) and self.is_end_of_document_line( + location + ) + + def get_location_left(self, location: Location) -> Location: + """Get the location to the left of the given location. + + Note that if the given location is at the start of the line, then + this will return the end of the preceding line, since that's where + you would expect the cursor to move. + + Args: + location: The location to start from. + + Returns: + The location to the right. + """ + if location == (0, 0): + return 0, 0 + + row, column = location + length_of_row_above = len(self._document[row - 1]) + target_row = row if column != 0 else row - 1 + target_column = column - 1 if column != 0 else length_of_row_above + return target_row, target_column + + def get_location_right(self, location: Location) -> Location: + """Get the location to the right of the given location. + + Note that if the given location is at the end of the line, then + this will return the start of the following line, since that's where + you would expect the cursor to move. + + Args: + location: The location to start from. + + Returns: + The location to the right. + """ + if self.is_end_of_document(location): + return location + row, column = location + is_end_of_line = self.is_end_of_document_line(location) + target_row = row + 1 if is_end_of_line else row + target_column = 0 if is_end_of_line else column + 1 + return target_row, target_column + + def get_location_above(self, location: Location) -> Location: + """Get the location visually aligned with the cell above the given location. + + Args: + location: The location to start from. + + Returns: + The cell above the given location. + """ + + # Get the wrap offsets of the current line. + line_index, column_index = location + wrap_offsets = self._wrapped_document.get_offsets(line_index) + section_start_columns = [0, *wrap_offsets] + + # We need to find the insertion point to determine which section index we're + # on within the current line. When we know the section index, we can use it + # to find the section which sits above it. + section_index = bisect_right(wrap_offsets, column_index) + offset_within_section = column_index - section_start_columns[section_index] + wrapped_line = self._wrapped_document.get_sections(line_index) + section = wrapped_line[section_index] + + # Convert that cursor offset to a cell (visual) offset + current_visual_offset = cell_len(section[:offset_within_section]) + target_offset = max(current_visual_offset, self.last_x_offset) + + if section_index == 0: + # Moving up from a position on the first visual line moves us to the start. + if self.is_first_wrapped_line(location): + return 0, 0 + # Get the last section from the line above, and find where to move in it. + target_row = line_index - 1 + target_column = self._wrapped_document.get_target_document_column( + target_row, target_offset, -1 + ) + target_location = target_row, target_column + else: + # Stay on the same document line, but move backwards. + # Since the section above could be shorter, we need to clamp the column + # to a valid value. + target_column = self._wrapped_document.get_target_document_column( + line_index, target_offset, section_index - 1 + ) + target_location = line_index, target_column + + return target_location + + def get_location_below(self, location: Location) -> Location: + """Given a location in the raw document, return the raw document + location corresponding to moving down in the wrapped representation + of the document. + + Args: + location: The location in the raw document. + + Returns: + The location which is *visually* below the given location. + """ + line_index, column_index = location + document = self._document + + wrap_offsets = self._wrapped_document.get_offsets(line_index) + section_start_columns = [0, *wrap_offsets] + section_index = bisect(wrap_offsets, column_index) + offset_within_section = column_index - section_start_columns[section_index] + wrapped_line = self._wrapped_document.get_sections(line_index) + section = wrapped_line[section_index] + current_visual_offset = cell_len(section[:offset_within_section]) + target_offset = max(current_visual_offset, self.last_x_offset) + + # If we're at the last section/row of a wrapped line + if section_index == len(wrapped_line) - 1: + # Last section of last line: go to end of file. + if self.is_last_document_line(location): + return line_index, len(document[line_index]) + + # Go to the first section of the line below. + target_row = line_index + 1 + target_column = self._wrapped_document.get_target_document_column( + target_row, target_offset, 0 + ) + target_location = target_row, target_column + else: + # Stay on the same document line, but move forwards to + # the location on the section below with the same visual offset. + target_column = self._wrapped_document.get_target_document_column( + line_index, target_offset, section_index + 1 + ) + target_location = line_index, target_column + + return target_location + + def get_location_end(self, location: Location) -> Location: + """Get the location corresponding to the end of the current section. + + Args: + location: The current location. + + Returns: + The location corresponding to the end of the wrapped line. + """ + line_index, column_offset = location + wrap_offsets = self._wrapped_document.get_offsets(line_index) + if wrap_offsets: + # Get the next wrap offset to the right + next_offset_right = bisect(wrap_offsets, column_offset) + # There's no more wrapping to the right of this location - go to line end. + if next_offset_right == len(wrap_offsets): + return line_index, len(self._document[line_index]) + # We've found a wrap point + return line_index, wrap_offsets[next_offset_right] - 1 + else: + # No wrapping to consider - go to the start/end of the document line. + target_column = len(self._document[line_index]) + return line_index, target_column + + def get_location_home( + self, location: Location, smart_home: bool = False + ) -> Location: + """Get the "home location" corresponding to the given location. + + Args: + location: The location to consider. + smart_home: Enable/disable 'smart home' behavior. + + Returns: + The home location, relative to the given location. + """ + line_index, column_offset = location + wrap_offsets = self._wrapped_document.get_offsets(line_index) + if wrap_offsets: + next_offset_left = bisect(wrap_offsets, column_offset) + if next_offset_left == 0: + return line_index, 0 + return line_index, wrap_offsets[next_offset_left - 1] + else: + # No wrapping to consider, go to the start of the document line + line = self._wrapped_document.document[line_index] + target_column = 0 + if smart_home: + for code_point_index, code_point in enumerate(line): + if not code_point.isspace(): + target_column = code_point_index + break + + if column_offset == 0 or column_offset > target_column: + return line_index, target_column + + return line_index, 0 + + def get_location_at_y_offset( + self, location: Location, vertical_offset: int + ) -> Location: + """Apply a visual vertical offset to a location and check the resulting location. + + Args: + location: The location to start from. + vertical_offset: The vertical offset to move (negative=up, positive=down). + + Returns: + The location after the offset has been applied. + """ + # Convert into offset-space to apply the offset. + x_offset, y_offset = self._wrapped_document.location_to_offset(location) + # Convert the offset with the delta applied back to location-space. + return self._wrapped_document.offset_to_location( + Offset(x_offset, y_offset + vertical_offset), + ) + + def clamp_reachable(self, location: Location) -> Location: + """Given a location, return the nearest location that corresponds to a + reachable location in the document. + + Args: + location: A location. + + Returns: + The nearest reachable location in the document. + """ + document = self._document + row, column = location + clamped_row = clamp(row, 0, document.line_count - 1) + + row_text = self._document[clamped_row] + clamped_column = clamp(column, 0, len(row_text)) + return clamped_row, clamped_column + + +def index(sequence: Sequence, value: Any) -> int: + """Locate the leftmost item in the sequence equal to value via bisection. + + Args: + sequence: The sequence to search in. + value: The value to find. + + Returns: + The index of the value, or -1 if the value is not found in the sequence. + """ + insert_index = bisect_left(sequence, value) + if insert_index != len(sequence) and sequence[insert_index] == value: + return insert_index + return -1 diff --git a/src/memray/_vendor/textual/document/_edit.py b/src/memray/_vendor/textual/document/_edit.py new file mode 100644 index 0000000000..aea18d2dcc --- /dev/null +++ b/src/memray/_vendor/textual/document/_edit.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from memray._vendor.textual.document._document import EditResult, Location, Selection + +if TYPE_CHECKING: + from memray._vendor.textual.widgets import TextArea + + +@dataclass +class Edit: + """Implements the Undoable protocol to replace text at some range within a document.""" + + text: str + """The text to insert. An empty string is equivalent to deletion.""" + + from_location: Location + """The start location of the insert.""" + + to_location: Location + """The end location of the insert""" + + maintain_selection_offset: bool + """If True, the selection will maintain its offset to the replacement range.""" + + _original_selection: Selection | None = field(init=False, default=None) + """The Selection when the edit was originally performed, to be restored on undo.""" + + _updated_selection: Selection | None = field(init=False, default=None) + """Where the selection should move to after the replace happens.""" + + _edit_result: EditResult | None = field(init=False, default=None) + """The result of doing the edit.""" + + def do(self, text_area: TextArea, record_selection: bool = True) -> EditResult: + """Perform the edit operation. + + Args: + text_area: The `TextArea` to perform the edit on. + record_selection: If True, record the current selection in the TextArea + so that it may be restored if this Edit is undone in the future. + + Returns: + An `EditResult` containing information about the replace operation. + """ + if record_selection: + self._original_selection = text_area.selection + + text = self.text + + # This code is mostly handling how we adjust TextArea.selection + # when an edit is made to the document programmatically. + # We want a user who is typing away to maintain their relative + # position in the document even if an insert happens before + # their cursor position. + + edit_bottom_row, edit_bottom_column = self.bottom + + selection_start, selection_end = text_area.selection + selection_start_row, selection_start_column = selection_start + selection_end_row, selection_end_column = selection_end + + edit_result = text_area.document.replace_range(self.top, self.bottom, text) + + new_edit_to_row, new_edit_to_column = edit_result.end_location + + column_offset = new_edit_to_column - edit_bottom_column + target_selection_start_column = ( + selection_start_column + column_offset + if edit_bottom_row == selection_start_row + and edit_bottom_column <= selection_start_column + else selection_start_column + ) + target_selection_end_column = ( + selection_end_column + column_offset + if edit_bottom_row == selection_end_row + and edit_bottom_column <= selection_end_column + else selection_end_column + ) + + row_offset = new_edit_to_row - edit_bottom_row + target_selection_start_row = ( + selection_start_row + row_offset + if edit_bottom_row <= selection_start_row + else selection_start_row + ) + target_selection_end_row = ( + selection_end_row + row_offset + if edit_bottom_row <= selection_end_row + else selection_end_row + ) + + if self.maintain_selection_offset: + self._updated_selection = Selection( + start=(target_selection_start_row, target_selection_start_column), + end=(target_selection_end_row, target_selection_end_column), + ) + else: + self._updated_selection = Selection.cursor(edit_result.end_location) + + self._edit_result = edit_result + return edit_result + + def undo(self, text_area: TextArea) -> EditResult: + """Undo the edit operation. + + Looks at the data stored in the edit, and performs the inverse operation of `Edit.do`. + + Args: + text_area: The `TextArea` to undo the insert operation on. + + Returns: + An `EditResult` containing information about the replace operation. + """ + replaced_text = self._edit_result.replaced_text + edit_end = self._edit_result.end_location + + # Replace the span of the edit with the text that was originally there. + undo_edit_result = text_area.document.replace_range( + self.top, edit_end, replaced_text + ) + self._updated_selection = self._original_selection + + return undo_edit_result + + def after(self, text_area: TextArea) -> None: + """Hook for running code after an Edit has been performed via `Edit.do` *and* + side effects such as re-wrapping the document and refreshing the display + have completed. + + For example, we can't record cursor visual offset until we know where the cursor will + land *after* wrapping has been performed, so we must wait until here to do it. + + Args: + text_area: The `TextArea` this operation was performed on. + """ + if self._updated_selection is not None: + text_area.selection = self._updated_selection + text_area.record_cursor_width() + + @property + def top(self) -> Location: + """The Location impacted by this edit that is nearest the start of the document.""" + return min([self.from_location, self.to_location]) + + @property + def bottom(self) -> Location: + """The Location impacted by this edit that is nearest the end of the document.""" + return max([self.from_location, self.to_location]) diff --git a/src/memray/_vendor/textual/document/_history.py b/src/memray/_vendor/textual/document/_history.py new file mode 100644 index 0000000000..b6014a6974 --- /dev/null +++ b/src/memray/_vendor/textual/document/_history.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +import time +from collections import deque +from dataclasses import dataclass, field + +from memray._vendor.textual.document._edit import Edit + + +class HistoryException(Exception): + """Indicates misuse of the EditHistory API. + + For example, trying to undo() an Edit that has yet to be done. + """ + + +@dataclass +class EditHistory: + """Manages batching/checkpointing of Edits into groups that can be undone/redone in the TextArea.""" + + max_checkpoints: int + + checkpoint_timer: float + """Maximum number of seconds since last edit until a new batch is created.""" + + checkpoint_max_characters: int + """Maximum number of characters that can appear in a batch before a new batch is formed.""" + + _last_edit_time: float = field(init=False, default_factory=time.monotonic) + + _character_count: int = field(init=False, default=0) + """Track number of characters replaced + inserted since last batch creation.""" + + _force_end_batch: bool = field(init=False, default=False) + """Flag to force the creation of a new batch for the next recorded edit.""" + + _previously_replaced: bool = field(init=False, default=False) + """Records whether the most recent edit was a replacement or a pure insertion. + + If an edit removes any text from the document at all, it's considered a replacement. + Every other edit is considered a pure insertion. + """ + + def __post_init__(self) -> None: + self._undo_stack: deque[list[Edit]] = deque(maxlen=self.max_checkpoints) + """Batching Edit operations together (edits are simply grouped together in lists).""" + self._redo_stack: deque[list[Edit]] = deque() + """Stores batches that have been undone, allowing them to be redone.""" + + def record(self, edit: Edit) -> None: + """Record an Edit so that it may be undone and redone. + + Determines whether to batch the Edit with previous Edits, or create a new batch/checkpoint. + + This method must be called exactly once per edit, in chronological order. + + A new batch/checkpoint is created when: + + - The undo stack is empty. + - The checkpoint timer expires. + - The maximum number of characters permitted in a checkpoint is reached. + - A redo is performed (we should not add new edits to a batch that has been redone). + - The programmer has requested a new batch via a call to `force_new_batch`. + - e.g. the TextArea widget may call this method in some circumstances. + - Clicking to move the cursor elsewhere in the document should create a new batch. + - Movement of the cursor via a keyboard action that is NOT an edit. + - Blurring the TextArea creates a new checkpoint. + - The current edit involves a deletion/replacement and the previous edit did not. + - The current edit is a pure insertion and the previous edit was not. + - The edit involves insertion or deletion of one or more newline characters. + - An edit which inserts more than a single character (a paste) gets an isolated batch. + + Args: + edit: The edit to record. + """ + edit_result = edit._edit_result + if edit_result is None: + raise HistoryException( + "Cannot add an edit to history before it has been performed using `Edit.do`." + ) + + if edit.text == "" and edit_result.replaced_text == "": + return None + + is_replacement = bool(edit_result.replaced_text) + undo_stack = self._undo_stack + current_time = self._get_time() + edit_characters = len(edit.text) + contains_newline = "\n" in edit.text or "\n" in edit_result.replaced_text + + # Determine whether to create a new batch, or add to the latest batch. + if ( + not undo_stack + or self._force_end_batch + or edit_characters > 1 + or contains_newline + or is_replacement != self._previously_replaced + or current_time - self._last_edit_time > self.checkpoint_timer + or self._character_count + edit_characters > self.checkpoint_max_characters + ): + # Create a new batch (creating a "checkpoint"). + undo_stack.append([edit]) + self._character_count = edit_characters + self._last_edit_time = current_time + self._force_end_batch = False + else: + # Update the latest batch. + undo_stack[-1].append(edit) + self._character_count += edit_characters + self._last_edit_time = current_time + + self._previously_replaced = is_replacement + self._redo_stack.clear() + + # For some edits, we want to ensure the NEXT edit cannot be added to its batch, + # so enforce a checkpoint now. + if contains_newline or edit_characters > 1: + self.checkpoint() + + def _pop_undo(self) -> list[Edit] | None: + """Pop the latest batch from the undo stack and return it. + + This will also place it on the redo stack. + + Returns: + The batch of Edits from the top of the undo stack or None if it's empty. + """ + undo_stack = self._undo_stack + redo_stack = self._redo_stack + if undo_stack: + batch = undo_stack.pop() + redo_stack.append(batch) + return batch + return None + + def _pop_redo(self) -> list[Edit] | None: + """Redo the latest batch on the redo stack and return it. + + This will also place it on the undo stack (with a forced checkpoint to ensure + this undo does not get batched with other edits). + + Returns: + The batch of Edits from the top of the redo stack or None if it's empty. + """ + undo_stack = self._undo_stack + redo_stack = self._redo_stack + if redo_stack: + batch = redo_stack.pop() + undo_stack.append(batch) + # Ensure edits which follow cannot be added to the redone batch. + self.checkpoint() + return batch + return None + + def clear(self) -> None: + """Completely clear the history.""" + self._undo_stack.clear() + self._redo_stack.clear() + self._last_edit_time = time.monotonic() + self._force_end_batch = False + self._previously_replaced = False + + def checkpoint(self) -> None: + """Ensure the next recorded edit starts a new batch.""" + self._force_end_batch = True + + @property + def undo_stack(self) -> list[list[Edit]]: + """A copy of the undo stack, with references to the original Edits.""" + return list(self._undo_stack) + + @property + def redo_stack(self) -> list[list[Edit]]: + """A copy of the redo stack, with references to the original Edits.""" + return list(self._redo_stack) + + def _get_time(self) -> float: + """Get the time from the monotonic clock. + + Returns: + The result of `time.monotonic()` as a float. + """ + return time.monotonic() diff --git a/src/memray/_vendor/textual/document/_syntax_aware_document.py b/src/memray/_vendor/textual/document/_syntax_aware_document.py new file mode 100644 index 0000000000..4f394722b9 --- /dev/null +++ b/src/memray/_vendor/textual/document/_syntax_aware_document.py @@ -0,0 +1,234 @@ +from __future__ import annotations + +try: + from tree_sitter import Language, Node, Parser, Query, QueryCursor, Tree + + TREE_SITTER = True +except ImportError: + TREE_SITTER = False + + +from memray._vendor.textual.document._document import Document, EditResult, Location, _utf8_encode + +_UINT32_MAX = 0xFFFFFFFF + + +class SyntaxAwareDocumentError(Exception): + """General error raised when SyntaxAwareDocument is used incorrectly.""" + + +class SyntaxAwareDocument(Document): + """A subclass of Document which also maintains a tree-sitter syntax + tree when the document is edited. + """ + + def __init__( + self, + text: str, + language: Language, + ): + """Construct a SyntaxAwareDocument. + + Args: + text: The initial text contained in the document. + language: The tree-sitter language to use. + """ + + if not TREE_SITTER: + raise RuntimeError( + "SyntaxAwareDocument unavailable - tree-sitter is not installed." + ) + + super().__init__(text) + self.language: Language = language + """The tree-sitter Language.""" + + self._parser = Parser(self.language) + """The tree-sitter Parser or None if tree-sitter is unavailable.""" + + self._syntax_tree: Tree = self._parser.parse(self._read_callable) # type: ignore + """The tree-sitter Tree (syntax tree) built from the document.""" + + def prepare_query(self, query: str) -> Query | None: + """Prepare a tree-sitter tree query. + + Queries should be prepared once, then reused. + + To execute a query, call `query_syntax_tree`. + + Args: + query: The string query to prepare. + + Returns: + The prepared query. + """ + return Query(self.language, query) + + def query_syntax_tree( + self, + query: Query, + start_point: tuple[int, int] | None = None, + end_point: tuple[int, int] | None = None, + ) -> dict[str, list["Node"]]: + """Query the tree-sitter syntax tree. + + The default implementation always returns an empty list. + + To support querying in a subclass, this must be implemented. + + Args: + query: The tree-sitter Query to perform. + start_point: The (row, column byte) to start the query at. + end_point: The (row, column byte) to end the query at. + + Returns: + A tuple containing the nodes and text captured by the query. + """ + cursor = QueryCursor(query) + + if start_point is not None or end_point is not None: + if start_point is None: + start_point = (0, 0) + if end_point is None: + end_point = (_UINT32_MAX, _UINT32_MAX) + + cursor.set_point_range(start_point, end_point) + + captures = cursor.captures(self._syntax_tree.root_node) + return captures + + def replace_range(self, start: Location, end: Location, text: str) -> EditResult: + """Replace text at the given range. + + Args: + start: A tuple (row, column) where the edit starts. + end: A tuple (row, column) where the edit ends. + text: The text to insert between start and end. + + Returns: + The new end location after the edit is complete. + """ + top, bottom = sorted((start, end)) + + # An optimisation would be finding the byte offsets as a single operation rather + # than doing two passes over the document content. + start_byte = self._location_to_byte_offset(top) + start_point = self._location_to_point(top) + old_end_byte = self._location_to_byte_offset(bottom) + old_end_point = self._location_to_point(bottom) + + replace_result = super().replace_range(start, end, text) + + text_byte_length = len(_utf8_encode(text)) + end_location = replace_result.end_location + assert self._syntax_tree is not None + assert self._parser is not None + self._syntax_tree.edit( + start_byte=start_byte, + old_end_byte=old_end_byte, + new_end_byte=start_byte + text_byte_length, + start_point=start_point, + old_end_point=old_end_point, + new_end_point=self._location_to_point(end_location), + ) + # Incrementally parse the document. + self._syntax_tree = self._parser.parse( + self._read_callable, + self._syntax_tree, # type: ignore[arg-type] + ) + + return replace_result + + def get_line(self, index: int) -> str: + """Return the string representing the line, not including new line characters. + + Args: + line_index: The index of the line. + + Returns: + The string representing the line. + """ + line_string = self[index] + return line_string + + def _location_to_byte_offset(self, location: Location) -> int: + """Given a document coordinate, return the byte offset of that coordinate. + This method only does work if tree-sitter was imported, otherwise it returns 0. + + Args: + location: The location to convert. + + Returns: + An integer byte offset for the given location. + """ + lines = self._lines + row, column = location + lines_above = lines[:row] + end_of_line_width = len(self.newline) + bytes_lines_above = sum( + len(_utf8_encode(line)) + end_of_line_width for line in lines_above + ) + if row < len(lines): + bytes_on_left = len(_utf8_encode(lines[row][:column])) + else: + bytes_on_left = 0 + byte_offset = bytes_lines_above + bytes_on_left + return byte_offset + + def _location_to_point(self, location: Location) -> tuple[int, int]: + """Convert a document location (row_index, column_index) to a tree-sitter + point (row_index, byte_offset_from_start_of_row). If tree-sitter isn't available + returns (0, 0). + + Args: + location: A location (row index, column codepoint offset) + + Returns: + The point corresponding to that location (row index, column byte offset). + """ + lines = self._lines + row, column = location + if row < len(lines): + bytes_on_left = len(_utf8_encode(lines[row][:column])) + else: + bytes_on_left = 0 + return row, bytes_on_left + + def _read_callable(self, byte_offset: int, point: tuple[int, int]) -> bytes: + """A callable which informs tree-sitter about the document content. + + This is passed to tree-sitter which will call it frequently to retrieve + the bytes from the document. + + Args: + byte_offset: The number of (utf-8) bytes from the start of the document. + point: A tuple (row index, column *byte* offset). Note that this differs + from our Location tuple which is (row_index, column codepoint offset). + + Returns: + All the utf-8 bytes between the byte_offset/point and the end of the current + line _including_ the line separator character(s). Returns None if the + offset/point requested by tree-sitter doesn't correspond to a byte. + """ + row, column = point + lines = self._lines + newline = self.newline + + row_out_of_bounds = row >= len(lines) + if row_out_of_bounds: + return b"" + else: + row_text = lines[row] + + encoded_row = _utf8_encode(row_text) + encoded_row_length = len(encoded_row) + + if column < encoded_row_length: + return encoded_row[column:] + _utf8_encode(newline) + elif column == encoded_row_length: + return _utf8_encode(newline[0]) + elif column == encoded_row_length + 1: + if newline == "\r\n": + return b"\n" + + return b"" diff --git a/src/memray/_vendor/textual/document/_wrapped_document.py b/src/memray/_vendor/textual/document/_wrapped_document.py new file mode 100644 index 0000000000..181f12d48a --- /dev/null +++ b/src/memray/_vendor/textual/document/_wrapped_document.py @@ -0,0 +1,454 @@ +from __future__ import annotations + +from bisect import bisect_right + +from rich.text import Text + +from memray._vendor.textual._cells import cell_len, cell_width_to_column_index +from memray._vendor.textual._wrap import compute_wrap_offsets +from memray._vendor.textual.document._document import DocumentBase, Location +from memray._vendor.textual.expand_tabs import expand_tabs_inline, get_tab_widths +from memray._vendor.textual.geometry import Offset, clamp + +VerticalOffset = int +LineIndex = int +SectionOffset = int + + +class WrappedDocument: + """A view into a Document which wraps the document at a certain + width and can be queried to retrieve lines from the *wrapped* version + of the document. + + Allows for incremental updates, ensuring that we only re-wrap ranges of the document + that were influenced by edits. + """ + + def __init__( + self, + document: DocumentBase, + width: int = 0, + tab_width: int = 4, + ) -> None: + """Construct a WrappedDocument. + + By default, a WrappedDocument is wrapped with width=0 (no wrapping). + To wrap the document, use the wrap() method. + + Args: + document: The document to wrap. + width: The width to wrap at. + tab_width: The maximum width to consider for tab characters. + """ + self.document = document + """The document wrapping is performed on.""" + + self._wrap_offsets: list[list[int]] = [] + """Maps line indices to the offsets within the line where wrapping + breaks should be added.""" + + self._tab_width_cache: list[list[int]] = [] + """Maps line indices to a list of tab widths. `[[2, 4]]` means that on line 0, the first + tab has width 2, and the second tab has width 4.""" + + self._offset_to_line_info: list[tuple[LineIndex, SectionOffset]] = [] + """Maps y_offsets (from the top of the document) to line_index and the offset + of the section within the line.""" + + self._line_index_to_offsets: list[list[VerticalOffset]] = [] + """Maps line indices to all the vertical offsets which correspond to that line.""" + + self._width: int = width + """The width the document is currently wrapped at. This will correspond with + the value last passed into the `wrap` method.""" + + self._tab_width: int = tab_width + """The maximum width to expand tabs to when considering their widths.""" + + self.wrap(width, tab_width) + + @property + def wrapped(self) -> bool: + """True if the content is wrapped. This is not the same as wrapping being "enabled". + For example, an empty document can have wrapping enabled, but no wrapping has actually + occurred. + + In other words, this is True if the length of any line in the document is greater + than the available width.""" + return len(self._line_index_to_offsets) == len(self._offset_to_line_info) + + def wrap(self, width: int, tab_width: int | None = None) -> None: + """Wrap and cache all lines in the document. + + Args: + width: The width to wrap at. 0 for no wrapping. + tab_width: The maximum width to consider for tab characters. If None, + reuse the tab width. + """ + self._width = width + if tab_width: + self._tab_width = tab_width + + # We're starting wrapping from scratch + new_wrap_offsets: list[list[int]] = [] + offset_to_line_info: list[tuple[LineIndex, SectionOffset]] = [] + line_index_to_offsets: list[list[VerticalOffset]] = [] + line_tab_widths: list[list[int]] = [] + + append_wrap_offset = new_wrap_offsets.append + append_line_info = offset_to_line_info.append + append_line_offsets = line_index_to_offsets.append + append_line_tab_widths = line_tab_widths.append + + current_offset = 0 + tab_width = self._tab_width + for line_index, line in enumerate(self.document.lines): + tab_sections = get_tab_widths(line, tab_width) + wrap_offsets = ( + compute_wrap_offsets( + line, + width, + tab_size=tab_width, + precomputed_tab_sections=tab_sections, + ) + if width + else [] + ) + append_line_tab_widths([width for _, width in tab_sections]) + append_wrap_offset(wrap_offsets) + append_line_offsets([]) + for section_y_offset in range(len(wrap_offsets) + 1): + append_line_info((line_index, section_y_offset)) + line_index_to_offsets[line_index].append(current_offset) + current_offset += 1 + + self._wrap_offsets = new_wrap_offsets + self._offset_to_line_info = offset_to_line_info + self._line_index_to_offsets = line_index_to_offsets + self._tab_width_cache = line_tab_widths + + @property + def lines(self) -> list[list[str]]: + """The lines of the wrapped version of the Document. + + Each index in the returned list represents a line index in the raw + document. The list[str] at each index is the content of the raw document line + split into multiple lines via wrapping. + + Note that this is expensive to compute and is not cached. + + Returns: + A list of lines from the wrapped version of the document. + """ + wrapped_lines: list[list[str]] = [] + append = wrapped_lines.append + for line_index, line in enumerate(self.document.lines): + divided = Text(line).divide(self._wrap_offsets[line_index]) + append([section.plain for section in divided]) + + return wrapped_lines + + @property + def height(self) -> int: + """The height of the wrapped document.""" + return sum(len(offsets) + 1 for offsets in self._wrap_offsets) + + def wrap_range( + self, + start: Location, + old_end: Location, + new_end: Location, + ) -> None: + """Incrementally recompute wrapping based on a performed edit. + + This must be called *after* the source document has been edited. + + Args: + start: The start location of the edit that was performed in document-space. + old_end: The old end location of the edit in document-space. + new_end: The new end location of the edit in document-space. + """ + start_line_index, _ = start + old_end_line_index, _ = old_end + new_end_line_index, _ = new_end + + # Although end users should not be able to edit invalid ranges via a TextArea, + # programmers can pass whatever they wish to the edit API, so we need to clamp + # the edit ranges here to ensure we only attempt to update within the bounds + # of the wrapped document. + old_max_index = len(self._line_index_to_offsets) - 1 + new_max_index = self.document.line_count - 1 + + start_line_index = clamp( + start_line_index, 0, min((old_max_index, new_max_index)) + ) + old_end_line_index = clamp(old_end_line_index, 0, old_max_index) + new_end_line_index = clamp(new_end_line_index, 0, new_max_index) + + top_line_index, old_bottom_line_index = sorted( + (start_line_index, old_end_line_index) + ) + new_bottom_line_index = max((start_line_index, new_end_line_index)) + + top_y_offset = self._line_index_to_offsets[top_line_index][0] + old_bottom_y_offset = self._line_index_to_offsets[old_bottom_line_index][-1] + + # Get the new range of the edit from top to bottom. + new_lines = self.document.lines[top_line_index : new_bottom_line_index + 1] + + new_wrap_offsets: list[list[int]] = [] + new_line_index_to_offsets: list[list[VerticalOffset]] = [] + new_offset_to_line_info: list[tuple[LineIndex, SectionOffset]] = [] + new_tab_widths: list[list[int]] = [] + + append_wrap_offsets = new_wrap_offsets.append + append_tab_widths = new_tab_widths.append + + width = self._width + tab_width = self._tab_width + + # Add the new offsets between the top and new bottom (the new post-edit offsets) + current_y_offset = top_y_offset + for line_index, line in enumerate(new_lines, top_line_index): + tab_sections = get_tab_widths(line, tab_width) + wrap_offsets = ( + compute_wrap_offsets( + line, width, tab_width, precomputed_tab_sections=tab_sections + ) + if width + else [] + ) + append_tab_widths([width for _, width in tab_sections]) + append_wrap_offsets(wrap_offsets) + + # Collect up the new y offsets for this document line + y_offsets_for_line: list[int] = [] + for section_offset in range(len(wrap_offsets) + 1): + y_offsets_for_line.append(current_y_offset) + new_offset_to_line_info.append((line_index, section_offset)) + current_y_offset += 1 + + # Save the new y offsets for this line + new_line_index_to_offsets.append(y_offsets_for_line) + + # Replace the range start -> old with the new wrapped lines + self._offset_to_line_info[top_y_offset : old_bottom_y_offset + 1] = ( + new_offset_to_line_info + ) + + self._line_index_to_offsets[top_line_index : old_bottom_line_index + 1] = ( + new_line_index_to_offsets + ) + + self._tab_width_cache[top_line_index : old_bottom_line_index + 1] = ( + new_tab_widths + ) + + # How much did the edit/rewrap alter the offsets? + old_height = old_bottom_y_offset - top_y_offset + 1 + new_height = len(new_offset_to_line_info) + + offset_shift = new_height - old_height + line_shift = new_bottom_line_index - old_bottom_line_index + + # Update the line info at all offsets below the edit region. + if line_shift: + for y_offset in range( + top_y_offset + new_height, len(self._offset_to_line_info) + ): + old_line_index, section_offset = self._offset_to_line_info[y_offset] + new_line_index = old_line_index + line_shift + new_line_info = (new_line_index, section_offset) + self._offset_to_line_info[y_offset] = new_line_info + + # Update the offsets at all lines below the edit region + if offset_shift: + for line_index in range( + top_line_index + len(new_lines), len(self._line_index_to_offsets) + ): + old_offsets = self._line_index_to_offsets[line_index] + new_offsets = [offset + offset_shift for offset in old_offsets] + self._line_index_to_offsets[line_index] = new_offsets + + self._wrap_offsets[top_line_index : old_bottom_line_index + 1] = ( + new_wrap_offsets + ) + + def offset_to_location(self, offset: Offset) -> Location: + """Given an offset within the wrapped/visual display of the document, + return the corresponding location in the document. + + Args: + offset: The y-offset within the document. + + Raises: + ValueError: When the given offset does not correspond to a line + in the document. + + Returns: + The Location in the document corresponding to the given offset. + """ + x, y = offset + x = max(0, x) + y = max(0, y) + + if not self._width: + # No wrapping, so we directly map offset to location and clamp. + line_index = min(y, len(self._wrap_offsets) - 1) + column_index = cell_width_to_column_index( + self.document.get_line(line_index), x, self._tab_width + ) + return line_index, column_index + + # Find the line corresponding to the given y offset in the wrapped document. + get_target_document_column = self.get_target_document_column + + try: + offset_data = self._offset_to_line_info[y] + except IndexError: + # y-offset is too large + offset_data = self._offset_to_line_info[-1] + + if offset_data is not None: + line_index, section_y = offset_data + location = line_index, get_target_document_column( + line_index, + x, + section_y, + ) + else: + location = len(self._wrap_offsets) - 1, get_target_document_column( + -1, x, -1 + ) + + # Offset doesn't match any line => land on bottom wrapped line + return location + + def location_to_offset(self, location: Location) -> Offset: + """ + Convert a location in the document to an offset within the wrapped/visual display of the document. + + Args: + location: The location in the document. + + Returns: + The Offset in the document's visual display corresponding to the given location. + """ + line_index, column_index = location + + # Clamp the line index to the bounds of the document + line_index = clamp(line_index, 0, len(self._line_index_to_offsets)) + + # Find the section index of this location, so that we know which y_offset to use + wrap_offsets = self.get_offsets(line_index) + section_start_columns = [0, *wrap_offsets] + section_index = bisect_right(wrap_offsets, column_index) + + # Get the y-offsets corresponding to this line index + y_offsets = self._line_index_to_offsets[line_index] + section_column_index = column_index - section_start_columns[section_index] + + section = self.get_sections(line_index)[section_index] + x_offset = cell_len( + expand_tabs_inline(section[:section_column_index], self._tab_width) + ) + + return Offset(x_offset, y_offsets[section_index]) + + def get_target_document_column( + self, + line_index: int, + x_offset: int, + y_offset: int, + ) -> int: + """Given a line index and the offsets within the wrapped version of that + line, return the corresponding column index in the raw document. + + Args: + line_index: The index of the line in the document. + x_offset: The x-offset within the wrapped line. + y_offset: The y-offset within the wrapped line (supports negative indexing). + + Returns: + The column index corresponding to the line index and y offset. + """ + + # We've found the relevant line, now find the character by + # looking at the character corresponding to the offset width. + sections = self.get_sections(line_index) + + # wrapped_section is the text that appears on a single y_offset within + # the TextArea. It's a potentially wrapped portion of a larger line from + # the original document. + target_section = sections[y_offset] + + # Add the offsets from the wrapped sections above this one (from the same raw + # document line) + target_section_start = sum( + len(wrapped_section) for wrapped_section in sections[:y_offset] + ) + + # Get the column index within this wrapped section of the line + target_column_index = target_section_start + cell_width_to_column_index( + target_section, x_offset, self._tab_width + ) + + # If we're on the final section of a line, the cursor can legally rest beyond + # the end by a single cell. Otherwise, we'll need to ensure that we're + # keeping the cursor within the bounds of the target section. + if y_offset != len(sections) - 1 and y_offset != -1: + target_column_index = min( + target_column_index, target_section_start + len(target_section) - 1 + ) + + return target_column_index + + def get_sections(self, line_index: int) -> list[str]: + """Return the sections for the given line index. + + When wrapping is enabled, a single line in the document can visually span + multiple lines. The list returned represents that visually (each string in + the list represents a single section (y-offset) after wrapping happens). + + Args: + line_index: The index of the line to get sections for. + + Returns: + The wrapped line as a list of strings. + """ + line_offsets = self._wrap_offsets[line_index] + wrapped_lines = Text(self.document[line_index], end="").divide(line_offsets) + return [line.plain for line in wrapped_lines] + + def get_offsets(self, line_index: int) -> list[int]: + """Given a line index, get the offsets within that line where wrapping + should occur for the current document. + + Args: + line_index: The index of the line within the document. + + Raises: + ValueError: When `line_index` is out of bounds. + + Returns: + The offsets within the line where wrapping should occur. + """ + wrap_offsets = self._wrap_offsets + out_of_bounds = line_index < 0 or line_index >= len(wrap_offsets) + if out_of_bounds: + raise ValueError( + f"The document line index {line_index!r} is out of bounds. " + f"The document contains {len(wrap_offsets)!r} lines." + ) + return wrap_offsets[line_index] + + def get_tab_widths(self, line_index: int) -> list[int]: + """Return a list of the tab widths for the given line index. + + Args: + line_index: The index of the line in the document. + + Returns: + An ordered list of the expanded width of the tabs in the line. + """ + return self._tab_width_cache[line_index] diff --git a/src/memray/_vendor/textual/dom.py b/src/memray/_vendor/textual/dom.py new file mode 100644 index 0000000000..120c91d212 --- /dev/null +++ b/src/memray/_vendor/textual/dom.py @@ -0,0 +1,1916 @@ +""" +The module contains `DOMNode`, the base class for any object within the Textual Document Object Model, +which includes all Widgets, Screens, and Apps. + +""" + +from __future__ import annotations + +import re +import threading +from functools import lru_cache, partial +from inspect import getfile +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Iterable, + Sequence, + Type, + TypeVar, + cast, + overload, +) + +import rich.repr +from rich.highlighter import ReprHighlighter +from rich.style import NULL_STYLE as RICH_NULL_STYLE +from rich.style import Style +from rich.text import Text +from rich.tree import Tree + +from memray._vendor.textual._context import NoActiveAppError, active_message_pump +from memray._vendor.textual._node_list import NodeList +from memray._vendor.textual._types import WatchCallbackType +from memray._vendor.textual.binding import Binding, BindingsMap, BindingType +from memray._vendor.textual.cache import LRUCache +from memray._vendor.textual.color import BLACK, WHITE, Color +from memray._vendor.textual.css._error_tools import friendly_list +from memray._vendor.textual.css.constants import VALID_DISPLAY, VALID_VISIBILITY +from memray._vendor.textual.css.errors import DeclarationError, StyleValueError +from memray._vendor.textual.css.match import match +from memray._vendor.textual.css.parse import is_id_selector, parse_declarations, parse_selectors +from memray._vendor.textual.css.query import InvalidQueryFormat, NoMatches, TooManyMatches, WrongType +from memray._vendor.textual.css.styles import RenderStyles, Styles +from memray._vendor.textual.css.tokenize import IDENTIFIER +from memray._vendor.textual.css.tokenizer import TokenError +from memray._vendor.textual.message_pump import MessagePump +from memray._vendor.textual.reactive import Reactive, ReactiveError, _Mutated, _watch +from memray._vendor.textual.style import Style as VisualStyle +from memray._vendor.textual.timer import Timer +from memray._vendor.textual.walk import walk_breadth_first, walk_breadth_search_id, walk_depth_first +from memray._vendor.textual.worker_manager import WorkerManager + +if TYPE_CHECKING: + from typing_extensions import Self, TypeAlias + from _typeshed import SupportsRichComparison + + from rich.console import RenderableType + from memray._vendor.textual.app import App + from memray._vendor.textual.css.query import DOMQuery, QueryType + from memray._vendor.textual.css.types import CSSLocation + from memray._vendor.textual.message import Message + from memray._vendor.textual.screen import Screen + from memray._vendor.textual.widget import Widget + from memray._vendor.textual.worker import Worker, WorkType, ResultType + +from typing_extensions import Literal + +_re_identifier = re.compile(IDENTIFIER) + + +WalkMethod: TypeAlias = Literal["depth", "breadth"] +"""Valid walking methods for the [`DOMNode.walk_children` method][textual.dom.DOMNode.walk_children].""" + + +ReactiveType = TypeVar("ReactiveType") + + +QueryOneCacheKey: TypeAlias = "tuple[int, str, Type[Widget] | None]" +"""The key used to cache query_one results.""" + + +class BadIdentifier(Exception): + """Exception raised if you supply a `id` attribute or class name in the wrong format.""" + + +def check_identifiers(description: str, *names: str) -> None: + """Validate identifier and raise an error if it fails. + + Args: + description: Description of where identifier is used for error message. + *names: Identifiers to check. + """ + match = _re_identifier.fullmatch + for name in names: + if match(name) is None: + raise BadIdentifier( + f"{name!r} is an invalid {description}; " + "identifiers must contain only letters, numbers, underscores, or hyphens, and must not begin with a number." + ) + + +class DOMError(Exception): + """Base exception class for errors relating to the DOM.""" + + +class NoScreen(DOMError): + """Raised when the node has no associated screen.""" + + +class _ClassesDescriptor: + """A descriptor to manage the `classes` property.""" + + def __get__( + self, obj: DOMNode, objtype: type[DOMNode] | None = None + ) -> frozenset[str]: + """A frozenset of the current classes on the widget.""" + return frozenset(obj._classes) + + def __set__(self, obj: DOMNode, classes: str | Iterable[str]) -> None: + """Replaces classes entirely.""" + if isinstance(classes, str): + class_names = set(classes.split()) + else: + class_names = set(classes) + check_identifiers("class name", *class_names) + obj._classes = class_names + obj.update_node_styles() + + +@rich.repr.auto +class DOMNode(MessagePump): + """The base class for object that can be in the Textual DOM (App and Widget)""" + + DEFAULT_CSS: ClassVar[str] = "" + """Default TCSS.""" + + DEFAULT_CLASSES: ClassVar[str] = "" + """Default classes argument if not supplied.""" + + COMPONENT_CLASSES: ClassVar[set[str]] = set() + """Virtual DOM nodes, used to expose styles to line API widgets.""" + + BINDING_GROUP_TITLE: str | None = None + """Title of widget used where bindings are displayed (such as in the key panel).""" + + BINDINGS: ClassVar[list[BindingType]] = [] + """A list of key bindings.""" + + # Indicates if the CSS should be automatically scoped + SCOPED_CSS: ClassVar[bool] = True + """Should default css be limited to the widget type?""" + + HELP: ClassVar[str | None] = None + """Optional help text shown in help panel (Markdown format).""" + + # True if this node inherits the CSS from the base class. + _inherit_css: ClassVar[bool] = True + + # True if this node inherits the component classes from the base class. + _inherit_component_classes: ClassVar[bool] = True + + # True to inherit bindings from base class + _inherit_bindings: ClassVar[bool] = True + + # List of names of base classes that inherit CSS + _css_type_names: ClassVar[frozenset[str]] = frozenset() + + # Name of the widget in CSS + _css_type_name: str = "" + + # Generated list of bindings + _merged_bindings: ClassVar[BindingsMap | None] = None + + _reactives: ClassVar[dict[str, Reactive]] + + _decorated_handlers: dict[type[Message], list[tuple[Callable, str | None]]] + + # Names of potential computed reactives + _computes: ClassVar[frozenset[str]] + + _PSEUDO_CLASSES: ClassVar[dict[str, Callable[[App[Any]], bool]]] = {} + """Pseudo class checks.""" + + def __init__( + self, + *, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + ) -> None: + self._classes: set[str] = set() + self._name = name + self._id = None + if id is not None: + check_identifiers("id", id) + self._id = id + + _classes = classes.split() if classes else [] + check_identifiers("class name", *_classes) + self._classes.update(_classes) + + self._nodes: NodeList = NodeList(self) + self._css_styles: Styles = Styles(self) + self._inline_styles: Styles = Styles(self) + self.styles: RenderStyles = RenderStyles( + self, self._css_styles, self._inline_styles + ) + # A mapping of class names to Styles set in COMPONENT_CLASSES + self._component_styles: dict[str, RenderStyles] = {} + + self._auto_refresh: float | None = None + self._auto_refresh_timer: Timer | None = None + self._css_types = {cls.__name__ for cls in self._css_bases(self.__class__)} + self._bindings = ( + BindingsMap() + if self._merged_bindings is None + else self._merged_bindings.copy() + ) + self._has_hover_style: bool = False + self._has_focus_within: bool = False + self._has_order_style: bool = False + """The node has an ordered dependent pseudo-style (`:odd`, `:even`, `:first-of-type`, `:last-of-type`, `:first-child`, `:last-child`)""" + self._has_odd_or_even: bool = False + """The node has the pseudo class `odd` or `even`.""" + self._reactive_connect: ( + dict[str, tuple[MessagePump, Reactive[object] | object]] | None + ) = None + self._pruning = False + self._query_one_cache: LRUCache[QueryOneCacheKey, DOMNode] = LRUCache(1024) + self._trap_focus = False + + super().__init__() + + def _get_dom_base(self) -> DOMNode: + """Get the DOM base node (typically self). + + All DOM queries on this node will use the return value as the root node. + This method allows the App to query the default screen, and not the active screen. + + Returns: + DOMNode. + """ + return self + + def set_reactive( + self, reactive: Reactive[ReactiveType], value: ReactiveType + ) -> None: + """Sets a reactive value *without* invoking validators or watchers. + + Example: + ```python + self.set_reactive(App.theme, "textual-light") + ``` + + Args: + reactive: A reactive property (use the class scope syntax, i.e. `MyClass.my_reactive`). + value: New value of reactive. + + Raises: + AttributeError: If the first argument is not a reactive. + """ + name = reactive.name + if not isinstance(reactive, Reactive): + raise TypeError("A Reactive class is required; for example: MyApp.theme") + if name not in self._reactives: + raise AttributeError( + f"No reactive called {name!r}; Have you called super().__init__(...) in the {self.__class__.__name__} constructor?" + ) + setattr(self, f"_reactive_{name}", value) + + def mutate_reactive(self, reactive: Reactive[ReactiveType]) -> None: + """Force an update to a mutable reactive. + + Example: + ```python + self.reactive_name_list.append("Jessica") + self.mutate_reactive(MyClass.reactive_name_list) + ``` + + Textual will automatically detect when a reactive is set to a new value, but it is unable + to detect if a value is _mutated_ (such as updating a list, dict, or attribute of an object). + If you do wish to use a collection or other mutable object in a reactive, then you can call + this method after your reactive is updated. This will ensure that all the reactive _superpowers_ + work. + + !!! note + + This method will cause watchers to be called, even if the value hasn't changed. + + Args: + reactive: A reactive property (use the class scope syntax, i.e. `MyClass.my_reactive`). + """ + + internal_name = f"_reactive_{reactive.name}" + value = getattr(self, internal_name) + reactive._set(self, value, always=True) + + def data_bind( + self, + *reactives: Reactive[Any], + **bind_vars: Reactive[Any] | object, + ) -> Self: + """Bind reactive data so that changes to a reactive automatically change the reactive on another widget. + + Reactives may be given as positional arguments or keyword arguments. + See the [guide on data binding](/guide/reactivity#data-binding). + + Example: + ```python + def compose(self) -> ComposeResult: + yield WorldClock("Europe/London").data_bind(WorldClockApp.time) + yield WorldClock("Europe/Paris").data_bind(WorldClockApp.time) + yield WorldClock("Asia/Tokyo").data_bind(WorldClockApp.time) + ``` + + Raises: + ReactiveError: If the data wasn't bound. + + Returns: + Self. + """ + _rich_traceback_omit = True + + parent = active_message_pump.get() + + if self._reactive_connect is None: + self._reactive_connect = {} + bind_vars = {**{reactive.name: reactive for reactive in reactives}, **bind_vars} + for name, reactive in bind_vars.items(): + if name not in self._reactives: + raise ReactiveError( + f"Unable to bind non-reactive attribute {name!r} on {self}" + ) + if isinstance(reactive, Reactive) and not isinstance( + parent, reactive.owner + ): + raise ReactiveError( + f"Unable to bind data; {reactive.owner.__name__} is not defined on {parent.__class__.__name__}." + ) + self._reactive_connect[name] = (parent, reactive) + if self._is_mounted: + self._initialize_data_bind() + else: + self.call_later(self._initialize_data_bind) + return self + + def _initialize_data_bind(self) -> None: + """initialize a data binding. + + Args: + compose_parent: The node doing the binding. + """ + if not self._reactive_connect: + return + for variable_name, (compose_parent, reactive) in self._reactive_connect.items(): + + def make_setter(variable_name: str) -> Callable[[object], None]: + """Make a setter for the given variable name. + + Args: + variable_name: Name of variable being set. + + Returns: + A callable which takes the value to set. + """ + + def setter(value: object) -> None: + """Set bound data.""" + _rich_traceback_omit = True + Reactive._initialize_object(self) + # Wrap the value in `_Mutated` so the setter knows to invoke watchers etc. + setattr(self, variable_name, _Mutated(value)) + + return setter + + assert isinstance(compose_parent, DOMNode) + setter = make_setter(variable_name) + if isinstance(reactive, Reactive): + self.watch( + compose_parent, + reactive.name, + setter, + init=True, + ) + else: + self.call_later(partial(setter, reactive)) + self._reactive_connect = None + + def compose_add_child(self, widget: Widget) -> None: + """Add a node to children. + + This is used by the compose process when it adds children. + There is no need to use it directly, but you may want to override it in a subclass + if you want children to be attached to a different node. + + Args: + widget: A Widget to add. + """ + self._nodes._append(widget) + + @property + def children(self) -> Sequence["Widget"]: + """A view on to the children. + + Returns: + The node's children. + """ + return self._nodes + + @property + def displayed_children(self) -> Sequence[Widget]: + """The displayed children (where `node.display==True`). + + Returns: + A sequence of widgets. + """ + return self._nodes.displayed + + @property + def displayed_and_visible_children(self) -> Sequence[Widget]: + """The displayed children (where `node.display==True` and `node.visible==True`). + + Returns: + A sequence of widgets. + """ + return self._nodes.displayed_and_visible + + @property + def is_empty(self) -> bool: + """Are there no displayed children?""" + return not any(child.display for child in self._nodes) + + def sort_children( + self, + *, + key: Callable[[Widget], SupportsRichComparison] | None = None, + reverse: bool = False, + ) -> None: + """Sort child widgets with an optional key function. + + If `key` is not provided then widgets will be sorted in the order they are constructed. + + Example: + ```python + # Sort widgets by name + screen.sort_children(key=lambda widget: widget.name or "") + ``` + + Args: + key: A callable which accepts a widget and returns something that can be sorted, + or `None` to sort without a key function. + reverse: Sort in descending order. + """ + self._nodes._sort(key=key, reverse=reverse) + self.refresh(layout=True) + + @property + def auto_refresh(self) -> float | None: + """Number of seconds between automatic refresh, or `None` for no automatic refresh.""" + return self._auto_refresh + + @auto_refresh.setter + def auto_refresh(self, interval: float | None) -> None: + if self._auto_refresh_timer is not None: + self._auto_refresh_timer.stop() + self._auto_refresh_timer = None + if interval is not None: + self._auto_refresh_timer = self.set_interval( + interval, self.automatic_refresh, name=f"auto refresh {self!r}" + ) + self._auto_refresh = interval + + @property + def workers(self) -> WorkerManager: + """The app's worker manager. Shortcut for `self.app.workers`.""" + return self.app.workers + + def trap_focus(self, trap_focus: bool = True) -> None: + """Trap the focus. + + When applied to a container, this will limit tab-to-focus to the children of that + container (once focus is within that container). + + This can be useful for widgets that act like modal dialogs, where you want to restrict + the user to the controls within the dialog. + + Args: + trap_focus: `True` to trap focus. `False` to restore default behavior. + """ + self._trap_focus = trap_focus + + def run_worker( + self, + work: WorkType[ResultType], + name: str | None = "", + group: str = "default", + description: str = "", + exit_on_error: bool = True, + start: bool = True, + exclusive: bool = False, + thread: bool = False, + ) -> Worker[ResultType]: + """Run work in a worker. + + A worker runs a function, coroutine, or awaitable, in the *background* as an async task or as a thread. + + Args: + work: A function, async function, or an awaitable object to run in a worker. + name: A short string to identify the worker (in logs and debugging). + group: A short string to identify a group of workers. + description: A longer string to store longer information on the worker. + exit_on_error: Exit the app if the worker raises an error. Set to `False` to suppress exceptions. + start: Start the worker immediately. + exclusive: Cancel all workers in the same group. + thread: Mark the worker as a thread worker. + + Returns: + New Worker instance. + """ + + # If we're running a worker from inside a secondary thread, + # do so in a thread-safe way. + if self.app._thread_id != threading.get_ident(): + creator = partial(self.app.call_from_thread, self.workers._new_worker) + else: + creator = self.workers._new_worker + worker: Worker[ResultType] = creator( + work, + self, + name=name, + group=group, + description=description, + exit_on_error=exit_on_error, + start=start, + exclusive=exclusive, + thread=thread, + ) + return worker + + @property + def is_modal(self) -> bool: + """Is the node a modal?""" + return False + + @property + def is_on_screen(self) -> bool: + """Check if the node was displayed in the last screen update.""" + return False + + def automatic_refresh(self) -> None: + """Perform an automatic refresh. + + This method is called when you set the `auto_refresh` attribute. + You could implement this method if you want to perform additional work + during an automatic refresh. + + """ + if self.is_on_screen: + self.refresh() + + def __init_subclass__( + cls, + inherit_css: bool = True, + inherit_bindings: bool = True, + inherit_component_classes: bool = True, + ) -> None: + super().__init_subclass__() + + reactives = cls._reactives = {} + for base in reversed(cls.__mro__): + reactives.update( + { + name: reactive + for name, reactive in base.__dict__.items() + if isinstance(reactive, Reactive) + } + ) + + cls._inherit_css = inherit_css + cls._inherit_bindings = inherit_bindings + cls._inherit_component_classes = inherit_component_classes + css_type_names: set[str] = set() + bases = cls._css_bases(cls) + cls._css_type_name = bases[0].__name__ + for base in bases: + css_type_names.add(base.__name__) + cls._merged_bindings = cls._merge_bindings() + cls._css_type_names = frozenset(css_type_names) + cls._computes = frozenset( + [ + name.lstrip("_")[8:] + for name in dir(cls) + if name.startswith(("_compute_", "compute_")) + ] + ) + + def get_component_styles(self, *names: str) -> RenderStyles: + """Get a "component" styles object (must be defined in COMPONENT_CLASSES classvar). + + Args: + names: Names of the components. + + Raises: + KeyError: If the component class doesn't exist. + + Returns: + A Styles object. + """ + + styles = RenderStyles(self, Styles(), Styles()) + + for name in names: + if name not in self._component_styles: + raise KeyError(f"No {name!r} key in COMPONENT_CLASSES") + component_styles = self._component_styles[name] + assert component_styles.node is not None + styles._update_node(component_styles.node) + styles.base.merge(component_styles.base) + styles.inline.merge(component_styles.inline) + styles._updates += 1 + + return styles + + def _post_mount(self): + """Called after the object has been mounted.""" + _rich_traceback_omit = True + Reactive._initialize_object(self) + + def notify_style_update(self) -> None: + """Called after styles are updated. + + Implement this in a subclass if you want to clear any cached data when the CSS is reloaded. + """ + + @property + def _node_bases(self) -> Sequence[Type[DOMNode]]: + """The DOMNode bases classes (including self.__class__)""" + # Node bases are in reversed order so that the base class is lower priority + return self._css_bases(self.__class__) + + @classmethod + @lru_cache(maxsize=None) + def _css_bases(cls, base: Type[DOMNode]) -> Sequence[Type[DOMNode]]: + """Get the DOMNode base classes, which inherit CSS. + + Args: + base: A DOMNode class + + Returns: + An iterable of DOMNode classes. + """ + classes: list[type[DOMNode]] = [] + _class = base + while True: + classes.append(_class) + if not _class._inherit_css: + break + for _base in _class.__bases__: + if issubclass(_base, DOMNode): + _class = _base + break + else: + break + return classes + + @classmethod + def _merge_bindings(cls) -> BindingsMap: + """Merge bindings from base classes. + + Returns: + Merged bindings. + """ + bindings: list[BindingsMap] = [] + + for base in reversed(cls.__mro__): + if issubclass(base, DOMNode): + if not base._inherit_bindings: + bindings.clear() + bindings.append( + BindingsMap( + base.__dict__.get("BINDINGS", []), + ) + ) + + keys: dict[str, list[Binding]] = {} + for bindings_ in bindings: + for key, key_bindings in bindings_.key_to_bindings.items(): + keys[key] = key_bindings + + new_bindings = BindingsMap.from_keys(keys) + return new_bindings + + def _post_register(self, app: App) -> None: + """Called when the widget is registered + + Args: + app: Parent application. + """ + + def __rich_repr__(self) -> rich.repr.Result: + # Being a bit defensive here to guard against errors when calling repr before initialization + if hasattr(self, "_name"): + yield "name", self._name, None + if hasattr(self, "_id"): + yield "id", self._id, None + if hasattr(self, "_classes") and self._classes: + yield "classes", " ".join(self._classes) + + def _get_default_css(self) -> list[tuple[CSSLocation, str, int, str]]: + """Gets the CSS for this class and inherited from bases. + + Default CSS is inherited from base classes, unless `inherit_css` is set to + `False` when subclassing. + + Returns: + A list of tuples containing (LOCATION, SOURCE, SPECIFICITY, SCOPE) for this + class and inherited from base classes. + """ + + css_stack: list[tuple[CSSLocation, str, int, str]] = [] + + def get_location(base: Type[DOMNode]) -> CSSLocation: + """Get the original location of this DEFAULT_CSS. + + Args: + base: The class from which the default css was extracted. + + Returns: + The filename where the class was defined (if possible) and the class + variable the CSS was extracted from. + """ + try: + return (getfile(base), f"{base.__name__}.DEFAULT_CSS") + except (TypeError, OSError): + return ("", f"{base.__name__}.DEFAULT_CSS") + + for tie_breaker, base in enumerate(self._node_bases): + css: str = base.__dict__.get("DEFAULT_CSS", "") + if css: + scoped: bool = base.__dict__.get("SCOPED_CSS", True) + css_stack.append( + ( + get_location(base), + css, + -tie_breaker, + base._css_type_name if scoped else "", + ) + ) + return css_stack + + @classmethod + @lru_cache(maxsize=None) + def _get_component_classes(cls) -> frozenset[str]: + """Gets the component classes for this class and inherited from bases. + + Component classes are inherited from base classes, unless + `inherit_component_classes` is set to `False` when subclassing. + + Returns: + A set with all the component classes available. + """ + + component_classes: set[str] = set() + for base in cls._css_bases(cls): + component_classes.update(base.__dict__.get("COMPONENT_CLASSES", set())) + if not base.__dict__.get("_inherit_component_classes", True): + break + + return frozenset(component_classes) + + @property + def parent(self) -> DOMNode | None: + """The parent node. + + All nodes have parent once added to the DOM, with the exception of the App which is the *root* node. + """ + return cast("DOMNode | None", self._parent) + + @property + def screen(self) -> "Screen[object]": + """The screen containing this node. + + Returns: + A screen object. + + Raises: + NoScreen: If this node isn't mounted (and has no screen). + """ + # Get the node by looking up a chain of parents + # Note that self.screen may not be the same as self.app.screen + from memray._vendor.textual.screen import Screen + + node: MessagePump | None = self + try: + while node is not None and not isinstance(node, Screen): + node = node._parent + except AttributeError: + raise RuntimeError( + "Widget is missing attributes; have you called the constructor in your widget class?" + ) from None + if not isinstance(node, Screen): + raise NoScreen("node has no screen") + return node + + @property + def id(self) -> str | None: + """The ID of this node, or None if the node has no ID.""" + return self._id + + @id.setter + def id(self, new_id: str) -> str: + """Sets the ID (may only be done once). + + Args: + new_id: ID for this node. + + Raises: + ValueError: If the ID has already been set. + """ + check_identifiers("id", new_id) + self._nodes.updated() + if self._id is not None: + raise ValueError( + f"Node 'id' attribute may not be changed once set (current id={self._id!r})" + ) + self._id = new_id + return new_id + + @property + def name(self) -> str | None: + """The name of the node.""" + return self._name + + @property + def css_identifier(self) -> str: + """A CSS selector that identifies this DOM node.""" + tokens = [self.__class__.__name__] + if self.id is not None: + tokens.append(f"#{self.id}") + return "".join(tokens) + + @property + def css_identifier_styled(self) -> Text: + """A syntax highlighted CSS identifier. + + Returns: + A Rich Text object. + """ + tokens = Text.styled(self.__class__.__name__) + if self.id is not None: + tokens.append(f"#{self.id}", style="bold") + if self.classes: + tokens.append(".") + tokens.append(".".join(class_name for class_name in self.classes), "italic") + if self.name: + tokens.append(f"[name={self.name}]", style="underline") + return tokens + + classes = _ClassesDescriptor() + """CSS class names for this node.""" + + @property + def pseudo_classes(self) -> frozenset[str]: + """A (frozen) set of all pseudo classes.""" + return frozenset(self.get_pseudo_classes()) + + @property + def css_path_nodes(self) -> list[DOMNode]: + """A list of nodes from the App to this node, forming a "path". + + Returns: + A list of nodes, where the first item is the App, and the last is this node. + """ + result: list[DOMNode] = [self] + append = result.append + + node: DOMNode = self + while isinstance((node := node._parent), DOMNode): + append(node) + return result[::-1] + + @property + def _selector_names(self) -> set[str]: + """Get a set of selectors applicable to this widget. + + Returns: + Set of selector names. + """ + selectors: set[str] = { + "*", + *(f".{class_name}" for class_name in self._classes), + *self._css_types, + } + if self._id is not None: + selectors.add(f"#{self._id}") + return selectors + + @property + def display(self) -> bool: + """Should the DOM node be displayed? + + May be set to a boolean to show or hide the node, or to any valid value for the `display` rule. + + Example: + ```python + my_widget.display = False # Hide my_widget + ``` + """ + return self.styles.display != "none" and not ( + self._closing or self._closed or self._pruning + ) + + @display.setter + def display(self, new_val: bool | str) -> None: + """ + Args: + new_val: Shortcut to set the ``display`` CSS property. + ``False`` will set ``display: none``. ``True`` will set ``display: block``. + A ``False`` value will prevent the DOMNode from consuming space in the layout. + """ + # TODO: This will forget what the original "display" value was, so if a user + # toggles to False then True, we'll reset to the default "block", rather than + # what the user initially specified. + if isinstance(new_val, bool): + self.styles.display = "block" if new_val else "none" + elif new_val in VALID_DISPLAY: + self.styles.display = new_val + else: + raise StyleValueError( + f"invalid value for display (received {new_val!r}, " + f"expected {friendly_list(VALID_DISPLAY)})", + ) + + @property + def visible(self) -> bool: + """Is this widget visible in the DOM? + + If a widget hasn't had its visibility set explicitly, then it inherits it from its + DOM ancestors. + + This may be set explicitly to override inherited values. + The valid values include the valid values for the `visibility` rule and the booleans + `True` or `False`, to set the widget to be visible or invisible, respectively. + + When a node is invisible, Textual will reserve space for it, but won't display anything. + """ + own_value = self.styles.get_rule("visibility") + if own_value is not None: + return own_value != "hidden" + return self.parent.visible if self.parent else True + + @visible.setter + def visible(self, new_value: bool | str) -> None: + if isinstance(new_value, bool): + self.styles.visibility = "visible" if new_value else "hidden" + elif new_value in VALID_VISIBILITY: + self.styles.visibility = new_value + else: + raise StyleValueError( + f"invalid value for visibility (received {new_value!r}, " + f"expected {friendly_list(VALID_VISIBILITY)})" + ) + + @property + def tree(self) -> Tree: + """A Rich tree to display the DOM. + + Log this to visualize your app in the textual console. + + Example: + ```python + self.log(self.tree) + ``` + + Returns: + A Tree renderable. + """ + from rich.pretty import Pretty + + def render_info(node: DOMNode) -> Pretty: + """Render a node for the tree.""" + return Pretty(node) + + tree = Tree(render_info(self)) + + def add_children(tree, node): + for child in node.children: + info = render_info(child) + branch = tree.add(info) + if tree.children: + add_children(branch, child) + + add_children(tree, self) + return tree + + @property + def css_tree(self) -> Tree: + """A Rich tree to display the DOM, annotated with the node's CSS. + + Log this to visualize your app in the textual console. + + Example: + ```python + self.log(self.css_tree) + ``` + + Returns: + A Tree renderable. + """ + from rich.columns import Columns + from rich.console import Group + from rich.panel import Panel + from rich.pretty import Pretty + + from memray._vendor.textual.widget import Widget + + def render_info(node: DOMNode) -> Columns: + """Render a node for the tree.""" + if isinstance(node, Widget): + info = Columns( + [ + Pretty(node), + highlighter(f"region={node.region!r}"), + highlighter( + f"virtual_size={node.virtual_size!r}", + ), + ] + ) + else: + info = Columns([Pretty(node)]) + return info + + highlighter = ReprHighlighter() + tree = Tree(render_info(self)) + + def add_children(tree: Tree, node: DOMNode) -> None: + """Add children to the tree.""" + for child in node.children: + info: RenderableType = render_info(child) + css = child.styles.css + if css: + info = Group( + info, + Panel.fit( + Text(child.styles.css), + border_style="dim", + title="css", + title_align="left", + ), + ) + branch = tree.add(info) + if tree.children: + add_children(branch, child) + + add_children(tree, self) + return tree + + @property + def text_style(self) -> Style: + """Get the text style object. + + A widget's style is influenced by its parent. for instance if a parent is bold, then + the child will also be bold. + + Returns: + A Rich Style. + """ + return Style.combine( + node.styles.text_style for node in reversed(self.ancestors_with_self) + ) + + @property + def selection_style(self) -> Style: + """The style of selected text.""" + style = self.screen.get_component_rich_style( + "screen--selection", default=RICH_NULL_STYLE + ) + return style + + @property + def rich_style(self) -> Style: + """Get a Rich Style object for this DOMNode. + + Returns: + A Rich style. + """ + background = Color(0, 0, 0, 0) + color = Color(255, 255, 255, 0) + + style = Style() + opacity = 1.0 + + for node in reversed(self.ancestors_with_self): + styles = node.styles + has_rule = styles.has_rule + opacity *= styles.opacity + if has_rule("background"): + text_background = background + styles.background.tint( + styles.background_tint + ) + background += ( + styles.background.tint(styles.background_tint) + ).multiply_alpha(opacity) + else: + text_background = background + if has_rule("color"): + color = styles.color + style += styles.text_style + if has_rule("auto_color") and styles.auto_color: + color = text_background.get_contrast_text(color.a) + + style += Style.from_color( + (background + color).rich_color if (background.a or color.a) else None, + background.rich_color if background.a else None, + ) + return style + + def check_consume_key(self, key: str, character: str | None) -> bool: + """Check if the widget may consume the given key. + + This should be implemented in widgets that handle [`Key`][textual.events.Key] events and + stop propagation (such as Input and TextArea). + + Implementing this method will hide key bindings from the footer and key panel that would + be *consumed* by the focused widget. + + Args: + key: A key identifier. + character: A character associated with the key, or `None` if there isn't one. + + Returns: + `True` if the widget may capture the key in its `Key` event handler, or `False` if it won't. + """ + return False + + def _get_title_style_information( + self, background: Color + ) -> tuple[Color, Color, VisualStyle]: + """Get a Visual Style object for titles. + + Args: + background: The background color. + + Returns: + A Rich style. + """ + + styles = self.styles + if styles.auto_border_title_color: + color = background.get_contrast_text(styles.border_title_color.a) + else: + color = styles.border_title_color + return ( + color, + styles.border_title_background, + VisualStyle.from_rich_style(styles.border_title_style), + ) + + def _get_subtitle_style_information( + self, background: Color + ) -> tuple[Color, Color, VisualStyle]: + """Get a Rich Style object for subtitles. + + Args: + background: The background color. + + Returns: + A Rich style. + """ + styles = self.styles + if styles.auto_border_subtitle_color: + color = background.get_contrast_text(styles.border_subtitle_color.a) + else: + color = styles.border_subtitle_color + return ( + color, + styles.border_subtitle_background, + VisualStyle.from_rich_style(styles.border_subtitle_style), + ) + + @property + def background_colors(self) -> tuple[Color, Color]: + """Background colors adjusted for opacity. + + Returns: + `(, )` + """ + base_background = background = Color(0, 0, 0, 0) + opacity = 1.0 + for node in reversed(self.ancestors_with_self): + styles = node.styles + base_background = background + opacity *= styles.opacity + background += styles.background.tint(styles.background_tint).multiply_alpha( + opacity + ) + return (base_background, background) + + @property + def colors(self) -> tuple[Color, Color, Color, Color]: + """The widget's background and foreground colors, and the parent's background and foreground colors. + + Returns: + `(, , , )` + """ + base_background = background = WHITE + base_color = color = BLACK + for node in reversed(self.ancestors_with_self): + styles = node.styles + base_background = background + background += styles.background.tint(styles.background_tint) + if styles.has_rule("color"): + base_color = color + if styles.auto_color: + color = background.get_contrast_text(color.a) + else: + color = styles.color + + return (base_background, base_color, background, color) + + @property + def ancestors_with_self(self) -> list[DOMNode]: + """A list of ancestor nodes found by tracing a path all the way back to App. + + Note: + This is inclusive of ``self``. + + Returns: + A list of nodes. + """ + nodes: list[MessagePump | None] = [self] + add_node = nodes.append + node: MessagePump | None = self + while (node := node._parent) is not None: + add_node(node) + return cast("list[DOMNode]", nodes) + + @property + def ancestors(self) -> list[DOMNode]: + """A list of ancestor nodes found by tracing a path all the way back to App. + + Returns: + A list of nodes. + """ + nodes: list[MessagePump | None] = [] + add_node = nodes.append + node: MessagePump | None = self + while (node := node._parent) is not None: + add_node(node) + return cast("list[DOMNode]", nodes) + + def watch( + self, + obj: DOMNode, + attribute_name: str, + callback: WatchCallbackType, + init: bool = True, + ) -> None: + """Watches for modifications to reactive attributes on another object. + + Example: + ```python + def on_theme_change(old_value:str, new_value:str) -> None: + # Called when app.theme changes. + print(f"App.theme went from {old_value} to {new_value}") + + self.watch(self.app, "theme", self.on_theme_change, init=False) + ``` + + Args: + obj: Object containing attribute to watch. + attribute_name: Attribute to watch. + callback: A callback to run when attribute changes. + init: Check watchers on first call. + """ + _watch(self, obj, attribute_name, callback, init=init) + + def get_pseudo_classes(self) -> set[str]: + """Pseudo classes for a widget. + + Returns: + Names of the pseudo classes. + """ + + return { + name + for name, check_class in self._PSEUDO_CLASSES.items() + if check_class(self) + } + + def reset_styles(self) -> None: + """Reset styles back to their initial state.""" + from memray._vendor.textual.widget import Widget + + for node in self.walk_children(with_self=True): + node._css_styles.reset() + if isinstance(node, Widget): + node._set_dirty() + node._layout_required = True + + def _add_child(self, node: Widget) -> None: + """Add a new child node. + + !!! note + For tests only. + + Args: + node: A DOM node. + """ + self._nodes._append(node) + node._attach(self) + + def _add_children(self, *nodes: Widget) -> None: + """Add multiple children to this node. + + !!! note + For tests only. + + Args: + *nodes: Positional args should be new DOM nodes. + """ + _append = self._nodes._append + for node in nodes: + node._attach(self) + _append(node) + node._add_children(*node._pending_children) + + WalkType = TypeVar("WalkType", bound="DOMNode") + + if TYPE_CHECKING: + + @overload + def walk_children( + self, + filter_type: type[WalkType], + *, + with_self: bool = False, + method: WalkMethod = "depth", + reverse: bool = False, + ) -> list[WalkType]: ... + + @overload + def walk_children( + self, + *, + with_self: bool = False, + method: WalkMethod = "depth", + reverse: bool = False, + ) -> list[DOMNode]: ... + + def walk_children( + self, + filter_type: type[WalkType] | None = None, + *, + with_self: bool = False, + method: WalkMethod = "depth", + reverse: bool = False, + ) -> list[DOMNode] | list[WalkType]: + """Walk the subtree rooted at this node, and return every descendant encountered in a list. + + Args: + filter_type: Filter only this type, or None for no filter. + with_self: Also yield self in addition to descendants. + method: One of "depth" or "breadth". + reverse: Reverse the order (bottom up). + + Returns: + A list of nodes. + """ + check_type = filter_type or DOMNode + + node_generator = ( + walk_depth_first(self, check_type, with_root=with_self) + if method == "depth" + else walk_breadth_first(self, check_type, with_root=with_self) + ) + + # We want a snapshot of the DOM at this point So that it doesn't + # change mid-walk + nodes = list(node_generator) + if reverse: + nodes.reverse() + return cast("list[DOMNode]", nodes) + + if TYPE_CHECKING: + + @overload + def query(self, selector: str | None = None) -> DOMQuery[Widget]: ... + + @overload + def query(self, selector: type[QueryType]) -> DOMQuery[QueryType]: ... + + def query( + self, selector: str | type[QueryType] | None = None + ) -> DOMQuery[Widget] | DOMQuery[QueryType]: + """Query the DOM for children that match a selector or widget type. + + Args: + selector: A CSS selector, widget type, or `None` for all nodes. + + Returns: + A query object. + """ + from memray._vendor.textual.css.query import DOMQuery, QueryType + from memray._vendor.textual.widget import Widget + + node = self._get_dom_base() + if isinstance(selector, str) or selector is None: + return DOMQuery[Widget](node, filter=selector) + else: + return DOMQuery[QueryType](node, filter=selector.__name__) + + if TYPE_CHECKING: + + @overload + def query_children(self, selector: str | None = None) -> DOMQuery[Widget]: ... + + @overload + def query_children(self, selector: type[QueryType]) -> DOMQuery[QueryType]: ... + + def query_children( + self, selector: str | type[QueryType] | None = None + ) -> DOMQuery[Widget] | DOMQuery[QueryType]: + """Query the DOM for the immediate children that match a selector or widget type. + + Note that this will not return child widgets more than a single level deep. + If you want to a query to potentially match all children in the widget tree, + see [query][textual.dom.DOMNode.query]. + + Args: + selector: A CSS selector, widget type, or `None` for all nodes. + + Returns: + A query object. + """ + from memray._vendor.textual.css.query import DOMQuery, QueryType + from memray._vendor.textual.widget import Widget + + node = self._get_dom_base() + if isinstance(selector, str) or selector is None: + return DOMQuery[Widget](node, deep=False, filter=selector) + else: + return DOMQuery[QueryType](node, deep=False, filter=selector.__name__) + + if TYPE_CHECKING: + + @overload + def query_one(self, selector: str) -> Widget: ... + + @overload + def query_one(self, selector: type[QueryType]) -> QueryType: ... + + @overload + def query_one( + self, selector: str, expect_type: type[QueryType] + ) -> QueryType: ... + + def query_one( + self, + selector: str | type[QueryType], + expect_type: type[QueryType] | None = None, + ) -> QueryType | Widget: + """Get a widget from this widget's children that matches a selector or widget type. + + Args: + selector: A selector or widget type. + expect_type: Require the object be of the supplied type, or None for any type. + + Raises: + WrongType: If the wrong type was found. + NoMatches: If no node matches the query. + + Returns: + A widget matching the selector. + """ + _rich_traceback_omit = True + + base_node = self._get_dom_base() + + if isinstance(selector, str): + query_selector = selector + else: + query_selector = selector.__name__ + + if is_id_selector(query_selector): + cache_key = (base_node._nodes._updates, query_selector, expect_type) + cached_result = base_node._query_one_cache.get(cache_key) + if cached_result is not None: + return cached_result + if ( + node := walk_breadth_search_id( + base_node, query_selector[1:], with_root=False + ) + ) is not None: + if expect_type is not None and not isinstance(node, expect_type): + raise WrongType( + f"Node matching {query_selector!r} is the wrong type; expected type {expect_type.__name__!r}, found {node}" + ) + base_node._query_one_cache[cache_key] = node + return node + raise NoMatches(f"No nodes match {query_selector!r} on {base_node!r}") + + try: + selector_set = parse_selectors(query_selector) + except TokenError: + raise InvalidQueryFormat( + f"Unable to parse {query_selector!r} as a query; check for syntax errors" + ) from None + + if all(selectors.is_simple for selectors in selector_set): + cache_key = (base_node._nodes._updates, query_selector, expect_type) + cached_result = base_node._query_one_cache.get(cache_key) + if cached_result is not None: + return cached_result + else: + cache_key = None + + for node in walk_breadth_first(base_node, with_root=False): + if not match(selector_set, node): + continue + if expect_type is not None and not isinstance(node, expect_type): + raise WrongType( + f"Node matching {query_selector!r} is the wrong type; expected type {expect_type.__name__!r}, found {node}" + ) + if cache_key is not None: + base_node._query_one_cache[cache_key] = node + return node + + raise NoMatches(f"No nodes match {query_selector!r} on {base_node!r}") + + if TYPE_CHECKING: + + @overload + def query_one_optional(self, selector: str) -> Widget | None: ... + + @overload + def query_one_optional(self, selector: type[QueryType]) -> QueryType | None: ... + + @overload + def query_one_optional( + self, selector: str, expect_type: type[QueryType] + ) -> QueryType | None: ... + + def query_one_optional( + self, + selector: str | type[QueryType], + expect_type: type[QueryType] | None = None, + ) -> QueryType | Widget | None: + """Get a widget from this widget's children that matches a selector or widget type, + or `None` if there is no match. + + Args: + selector: A selector or widget type. + expect_type: Require the object be of the supplied type, or None for any type. + + Raises: + WrongType: If the wrong type was found. + + Returns: + A widget matching the selector, or `None`. + """ + try: + widget = self.query_one(selector, expect_type) + except NoMatches: + return None + return widget + + if TYPE_CHECKING: + + @overload + def query_exactly_one(self, selector: str) -> Widget: ... + + @overload + def query_exactly_one(self, selector: type[QueryType]) -> QueryType: ... + + @overload + def query_exactly_one( + self, selector: str, expect_type: type[QueryType] + ) -> QueryType: ... + + def query_exactly_one( + self, + selector: str | type[QueryType], + expect_type: type[QueryType] | None = None, + ) -> QueryType | Widget: + """Get a widget from this widget's children that matches a selector or widget type. + + !!! Note + This method is similar to [query_one][textual.dom.DOMNode.query_one]. + The only difference is that it will raise `TooManyMatches` if there is more than a single match. + + Args: + selector: A selector or widget type. + expect_type: Require the object be of the supplied type, or None for any type. + + Raises: + WrongType: If the wrong type was found. + NoMatches: If no node matches the query. + TooManyMatches: If there is more than one matching node in the query (and `exactly_one==True`). + + Returns: + A widget matching the selector. + """ + _rich_traceback_omit = True + + base_node = self._get_dom_base() + + if isinstance(selector, str): + query_selector = selector + else: + query_selector = selector.__name__ + + try: + selector_set = parse_selectors(query_selector) + except TokenError: + raise InvalidQueryFormat( + f"Unable to parse {query_selector!r} as a query; check for syntax errors" + ) from None + + if all(selectors.is_simple for selectors in selector_set): + cache_key = (base_node._nodes._updates, query_selector, expect_type) + cached_result = base_node._query_one_cache.get(cache_key) + if cached_result is not None: + return cached_result + else: + cache_key = None + + children = walk_breadth_first(base_node, with_root=False) + iter_children = iter(children) + for node in iter_children: + if not match(selector_set, node): + continue + if expect_type is not None and not isinstance(node, expect_type): + raise WrongType( + f"Node matching {query_selector!r} is the wrong type; expected type {expect_type.__name__!r}, found {node}" + ) + for later_node in iter_children: + if match(selector_set, later_node): + raise TooManyMatches( + "Call to query_one resulted in more than one matched node" + ) + if cache_key is not None: + base_node._query_one_cache[cache_key] = node + return node + + raise NoMatches(f"No nodes match {query_selector!r} on {base_node!r}") + + if TYPE_CHECKING: + + @overload + def query_ancestor(self, selector: str) -> DOMNode: ... + + @overload + def query_ancestor(self, selector: type[QueryType]) -> QueryType: ... + + @overload + def query_ancestor( + self, selector: str, expect_type: type[QueryType] + ) -> QueryType: ... + + def query_ancestor( + self, + selector: str | type[QueryType], + expect_type: type[QueryType] | None = None, + ) -> DOMNode: + """Get an ancestor which matches a query. + + Args: + selector: A TCSS selector. + expect_type: Expected type, or `None` for any DOMNode. + + Raises: + InvalidQueryFormat: If the selector is invalid. + NoMatches: If there are no matching ancestors. + + Returns: + A DOMNode or subclass if `expect_type` is provided. + """ + base_node = self._get_dom_base() + if isinstance(selector, str): + query_selector = selector + else: + query_selector = selector.__name__ + + try: + selector_set = parse_selectors(query_selector) + except TokenError: + raise InvalidQueryFormat( + f"Unable to parse {query_selector!r} as a query; check for syntax errors" + ) from None + if base_node.parent is not None: + for node in base_node.parent.ancestors_with_self: + if not match(selector_set, node): + continue + if expect_type is not None and not isinstance(node, expect_type): + continue + return node + raise NoMatches(f"No ancestor matches {selector!r} on {self!r}") + + def set_styles(self, css: str | None = None, **update_styles: Any) -> Self: + """Set custom styles on this object. + + Args: + css: Styles in CSS format. + update_styles: Keyword arguments map style names onto style values. + + Returns: + Self. + """ + + if css is not None: + try: + new_styles = parse_declarations(css, read_from=("set_styles", "")) + except DeclarationError as error: + raise DeclarationError(error.name, error.token, error.message) from None + self._inline_styles.merge(new_styles) + self.refresh(layout=True) + + styles = self.styles + for key, value in update_styles.items(): + setattr(styles, key, value) + return self + + def has_class(self, *class_names: str) -> bool: + """Check if the Node has all the given class names. + + Args: + *class_names: CSS class names to check. + + Returns: + ``True`` if the node has all the given class names, otherwise ``False``. + """ + return self._classes.issuperset(class_names) + + def set_class(self, add: bool, *class_names: str, update: bool = True) -> Self: + """Add or remove class(es) based on a condition. + + This can condense the four lines required to implement the equivalent branch into a single line. + + Example: + ```python + #if foo: + # self.add_class("-foo") + #else: + # self.remove_class("-foo") + self.set_class(foo, "-foo") + ``` + + Args: + add: Add the classes if True, otherwise remove them. + update: Also update styles. + + Returns: + Self. + """ + if add: + self.add_class(*class_names, update=update) + else: + self.remove_class(*class_names, update=update) + return self + + def set_classes(self, classes: str | Iterable[str]) -> Self: + """Replace all classes. + + Args: + classes: A string containing space separated classes, or an + iterable of class names. + + Returns: + Self. + """ + self.classes = classes + return self + + def update_node_styles(self, animate: bool = True) -> None: + """Request an update of this node's styles. + + Called by Textual whenever CSS classes / pseudo classes change. + """ + try: + self.app.update_styles(self, animate=animate) + except NoActiveAppError: + pass + + def add_class(self, *class_names: str, update: bool = True) -> Self: + """Add class names to this Node. + + Args: + *class_names: CSS class names to add. + update: Also update styles. + + Returns: + Self. + """ + check_identifiers("class name", *class_names) + old_classes = self._classes.copy() + self._classes.update(class_names) + if old_classes == self._classes: + return self + if update: + self.update_node_styles() + return self + + def remove_class(self, *class_names: str, update: bool = True) -> Self: + """Remove class names from this Node. + + Args: + *class_names: CSS class names to remove. + update: Also update styles. + + Returns: + Self. + """ + check_identifiers("class name", *class_names) + old_classes = self._classes.copy() + self._classes.difference_update(class_names) + if old_classes == self._classes: + return self + if update: + self.update_node_styles() + return self + + def toggle_class(self, *class_names: str) -> Self: + """Toggle class names on this Node. + + Args: + *class_names: CSS class names to toggle. + + Returns: + Self. + """ + check_identifiers("class name", *class_names) + old_classes = self._classes.copy() + self._classes.symmetric_difference_update(class_names) + if old_classes == self._classes: + return self + self.update_node_styles() + return self + + def has_pseudo_class(self, class_name: str) -> bool: + """Check the node has the given pseudo class. + + Args: + class_name: The pseudo class to check for. + + Returns: + `True` if the DOM node has the pseudo class, `False` if not. + """ + try: + return self._PSEUDO_CLASSES[class_name](self) + except KeyError: + return False + + def has_pseudo_classes(self, class_names: set[str]) -> bool: + """Check the node has all the given pseudo classes. + + Args: + class_names: Set of class names to check for. + + Returns: + `True` if all pseudo class names are present. + """ + PSEUDO_CLASSES = self._PSEUDO_CLASSES + try: + return all(PSEUDO_CLASSES[name](self) for name in class_names) + except KeyError: + return False + + @property + def _pseudo_classes_cache_key(self) -> tuple[int, ...]: + """A cache key used when updating a number of nodes from the stylesheet.""" + return () + + def refresh( + self, *, repaint: bool = True, layout: bool = False, recompose: bool = False + ) -> Self: + return self + + def check_action(self, action: str, parameters: tuple[object, ...]) -> bool | None: + """Check whether an action is enabled. + + Implement this method to add logic for [dynamic actions](/guide/actions#dynamic-actions) / bindings. + + Args: + action: The name of an action. + parameters: A tuple of any action parameters. + + Returns: + `True` if the action is enabled+visible, + `False` if the action is disabled+hidden, + `None` if the action is disabled+visible (grayed out in footer) + """ + return True + + def refresh_bindings(self) -> None: + """Call to prompt widgets such as the [Footer][textual.widgets.Footer] to update + the display of key bindings. + + See [actions](/guide/actions#dynamic-actions) for how to use this method. + + """ + if self._is_mounted: + self.screen.refresh_bindings() + + async def action_toggle(self, attribute_name: str) -> None: + """Toggle an attribute on the node. + + Assumes the attribute is a bool. + + Args: + attribute_name: Name of the attribute. + """ + value = getattr(self, attribute_name) + setattr(self, attribute_name, not value) diff --git a/src/memray/_vendor/textual/driver.py b/src/memray/_vendor/textual/driver.py new file mode 100644 index 0000000000..3919bd2ab3 --- /dev/null +++ b/src/memray/_vendor/textual/driver.py @@ -0,0 +1,301 @@ +from __future__ import annotations + +import asyncio +import threading +from abc import ABC, abstractmethod +from contextlib import contextmanager +from pathlib import Path +from typing import TYPE_CHECKING, Any, BinaryIO, Iterator, Literal, TextIO + +from memray._vendor.textual import events, log, messages +from memray._vendor.textual.events import MouseUp + +if TYPE_CHECKING: + from memray._vendor.textual.app import App + + +class Driver(ABC): + """A base class for drivers.""" + + def __init__( + self, + app: App[Any], + *, + debug: bool = False, + mouse: bool = True, + size: tuple[int, int] | None = None, + ) -> None: + """Initialize a driver. + + Args: + app: The App instance. + debug: Enable debug mode. + mouse: Enable mouse support, + size: Initial size of the terminal or `None` to detect. + """ + self._app = app + self._debug = debug + self._mouse = mouse + self._size = size + self._loop = asyncio.get_running_loop() + self._down_buttons: list[int] = [] + self._last_move_event: events.MouseMove | None = None + self._auto_restart = True + """Should the application auto-restart (where appropriate)?""" + self.cursor_origin: tuple[int, int] | None = None + + @property + def is_headless(self) -> bool: + """Is the driver 'headless' (no output)?""" + return False + + @property + def is_inline(self) -> bool: + """Is the driver 'inline' (not full-screen)?""" + return False + + @property + def is_web(self) -> bool: + """Is the driver 'web' (running via a browser)?""" + return False + + @property + def can_suspend(self) -> bool: + """Can this driver be suspended?""" + return False + + def send_message(self, message: messages.Message) -> None: + """Send a message to the target app. + + Args: + message: A message. + """ + asyncio.run_coroutine_threadsafe( + self._app._post_message(message), loop=self._loop + ) + + def process_message(self, message: messages.Message) -> None: + """Perform additional processing on a message, prior to sending. + + Args: + event: A message to process. + """ + # NOTE: This runs in a thread. + # Avoid calling methods on the app. + message.set_sender(self._app) + if self.cursor_origin is None: + offset_x = 0 + offset_y = 0 + else: + offset_x, offset_y = self.cursor_origin + if isinstance(message, events.MouseEvent): + message._x -= offset_x + message._y -= offset_y + message._screen_x -= offset_x + message._screen_y -= offset_y + + if isinstance(message, events.MouseDown): + if message.button: + self._down_buttons.append(message.button) + elif isinstance(message, events.MouseUp): + if message.button and message.button in self._down_buttons: + self._down_buttons.remove(message.button) + elif isinstance(message, events.MouseMove): + if ( + self._down_buttons + and not message.button + and self._last_move_event is not None + ): + # Deduplicate self._down_buttons while preserving order. + buttons = list(dict.fromkeys(self._down_buttons).keys()) + self._down_buttons.clear() + move_event = self._last_move_event + for button in buttons: + self.send_message( + MouseUp( + message.widget, + x=move_event.x, + y=move_event.y, + delta_x=0, + delta_y=0, + button=button, + shift=message.shift, + meta=message.meta, + ctrl=message.ctrl, + screen_x=move_event.screen_x, + screen_y=move_event.screen_y, + style=message.style, + ) + ) + self._last_move_event = message + + self.send_message(message) + + @abstractmethod + def write(self, data: str) -> None: + """Write data to the output device. + + Args: + data: Raw data. + """ + + def flush(self) -> None: + """Flush any buffered data.""" + + @abstractmethod + def start_application_mode(self) -> None: + """Start application mode.""" + + @abstractmethod + def disable_input(self) -> None: + """Disable further input.""" + + @abstractmethod + def stop_application_mode(self) -> None: + """Stop application mode, restore state.""" + + def suspend_application_mode(self) -> None: + """Suspend application mode. + + Used to suspend application mode and allow uninhibited access to the + terminal. + """ + self.stop_application_mode() + self.close() + + def resume_application_mode(self) -> None: + """Resume application mode. + + Used to resume application mode after it has been previously + suspended. + """ + self.start_application_mode() + + class SignalResume(events.Event): + """Event sent to the app when a resume signal should be published.""" + + @contextmanager + def no_automatic_restart(self) -> Iterator[None]: + """A context manager used to tell the driver to not auto-restart. + + For drivers that support the application being suspended by the + operating system, this context manager is used to mark a body of + code as one that will manage its own stop and start. + """ + auto_restart = self._auto_restart + self._auto_restart = False + try: + yield + finally: + self._auto_restart = auto_restart + + def close(self) -> None: + """Perform any final cleanup.""" + + def open_url(self, url: str, new_tab: bool = True) -> None: + """Open a URL in the default web browser. + + Args: + url: The URL to open. + new_tab: Whether to open the URL in a new tab. + This is only relevant when running via the WebDriver, + and is ignored when called while running through the terminal. + """ + import webbrowser + + webbrowser.open(url) + + def deliver_binary( + self, + binary: BinaryIO | TextIO, + *, + delivery_key: str, + save_path: Path, + open_method: Literal["browser", "download"] = "download", + encoding: str | None = None, + mime_type: str | None = None, + name: str | None = None, + ) -> None: + """Save the file `path_or_file` to `save_path`. + + If running via web through Textual Web or Textual Serve, + this will initiate a download in the web browser. + + Args: + binary: The binary file to save. + delivery_key: The unique key that was used to deliver the file. + save_path: The location to save the file to. + open_method: *web only* Whether to open the file in the browser or + to prompt the user to download it. When running via a standard + (non-web) terminal, this is ignored. + encoding: *web only* The text encoding to use when saving the file. + This will be passed to Python's `open()` built-in function. + When running via web, this will be used to set the charset + in the `Content-Type` header. + mime_type: *web only* The MIME type of the file. This will be used to + set the `Content-Type` header in the HTTP response. + name: A user-defined name which will be returned in [`DeliveryComplete`][textual.events.DeliveryComplete] + and [`DeliveryFailed`][textual.events.DeliveryFailed]. + + """ + + def save_file_thread(binary: BinaryIO | TextIO, mode: str) -> None: + try: + with open( + save_path, mode, encoding=encoding or "utf-8" + ) as destination_file: + read = binary.read + write = destination_file.write + chunk_size = 1024 * 64 + while True: + data = read(chunk_size) + if not data: + # No data left to read - delivery is complete. + self._delivery_complete( + delivery_key, save_path=save_path, name=name + ) + break + write(data) + except Exception as error: + # If any exception occurs during the delivery, pass + # it on to the app via a DeliveryFailed event. + log.error(f"Failed to deliver file: {error}") + import traceback + + log.error(str(traceback.format_exc())) + self._delivery_failed(delivery_key, exception=error, name=name) + finally: + if not binary.closed: + binary.close() + + if isinstance(binary, BinaryIO): + mode = "wb" + else: + mode = "w" + + thread = threading.Thread(target=save_file_thread, args=(binary, mode)) + thread.start() + + def _delivery_complete( + self, delivery_key: str, save_path: Path | None, name: str | None + ) -> None: + """Called when a file has been delivered successfully. + + Delivers a DeliveryComplete event to the app. + """ + self._app.call_from_thread( + self._app.post_message, + events.DeliveryComplete(key=delivery_key, path=save_path, name=name), + ) + + def _delivery_failed( + self, delivery_key: str, exception: BaseException, name: str | None + ) -> None: + """Called when a file delivery fails. + + Delivers a DeliveryFailed event to the app. + """ + self._app.call_from_thread( + self._app.post_message, + events.DeliveryFailed(key=delivery_key, exception=exception, name=name), + ) diff --git a/src/memray/_vendor/textual/drivers/__init__.py b/src/memray/_vendor/textual/drivers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/memray/_vendor/textual/drivers/_byte_stream.py b/src/memray/_vendor/textual/drivers/_byte_stream.py new file mode 100644 index 0000000000..4c6bb602f0 --- /dev/null +++ b/src/memray/_vendor/textual/drivers/_byte_stream.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +import io +from collections import deque +from typing import ( + Callable, + Deque, + Generator, + Generic, + Iterable, + NamedTuple, + Tuple, + TypeVar, +) + +from typing_extensions import TypeAlias + + +class ParseError(Exception): + """Parse related errors.""" + + +class ParseEOF(ParseError): + """End of Stream.""" + + +class Awaitable: + """Base class for an parser awaitable.""" + + __slots__: list[str] = [] + + +class _Read(Awaitable): + """Read a predefined number of bytes.""" + + __slots__ = ["remaining"] + + def __init__(self, count: int) -> None: + self.remaining = count + + +class _Read1(Awaitable): + """Read a single byte.""" + + __slots__: list[str] = [] + + +TokenType = TypeVar("TokenType") + +ByteStreamTokenCallback: TypeAlias = Callable[[TokenType], None] + + +class ByteStreamParser(Generic[TokenType]): + """A parser to feed in binary data and generate a sequence of tokens.""" + + read = _Read + read1 = _Read1 + + def __init__(self) -> None: + """Initialize the parser.""" + self._buffer = io.BytesIO() + self._eof = False + self._tokens: Deque[TokenType] = deque() + self._gen = self.parse(self._tokens.append) + self._awaiting: Awaitable | TokenType = next(self._gen) + + @property + def is_eof(self) -> bool: + """Is the parser at the end of file?""" + return self._eof + + def feed(self, data: bytes) -> Iterable[TokenType]: + """Feed the parser some data, return an iterable of tokens.""" + if self._eof: + raise ParseError("end of file reached") from None + if not data: + self._eof = True + try: + self._gen.send(self._buffer.getvalue()) + except StopIteration: + raise ParseError("end of file reached") from None + while self._tokens: + yield self._tokens.popleft() + + self._buffer.truncate(0) + return + + _buffer = self._buffer + pos = 0 + tokens = self._tokens + popleft = tokens.popleft + data_size = len(data) + + while tokens: + yield popleft() + + while pos < data_size: + _awaiting = self._awaiting + if isinstance(_awaiting, _Read1): + self._awaiting = self._gen.send(data[pos : pos + 1]) + pos += 1 + elif isinstance(_awaiting, _Read): + remaining = _awaiting.remaining + chunk = data[pos : pos + remaining] + chunk_size = len(chunk) + pos += chunk_size + _buffer.write(chunk) + remaining -= chunk_size + if remaining: + _awaiting.remaining = remaining + else: + self._awaiting = self._gen.send(_buffer.getvalue()) + _buffer.seek(0) + _buffer.truncate() + + while tokens: + yield popleft() + + def parse( + self, on_token: ByteStreamTokenCallback + ) -> Generator[Awaitable, bytes, None]: + """Implement in a sub-class to define parse behavior. + + Args: + on_token: A callable which accepts the token type, and returns None. + + """ + yield from () + + +class BytePacket(NamedTuple): + """A type and payload.""" + + type: str + payload: bytes + + +class ByteStream(ByteStreamParser[Tuple[str, bytes]]): + """A stream of packets in the following format. + + 1 Byte for the type. + 4 Bytes for the big endian encoded size + Arbitrary payload + + """ + + def parse( + self, on_token: ByteStreamTokenCallback + ) -> Generator[Awaitable, bytes, None]: + read1 = self.read1 + read = self.read + from_bytes = int.from_bytes + while not self.is_eof: + packet_type = (yield read1()).decode("utf-8", "ignore") + size = from_bytes((yield read(4)), "big") + payload = (yield read(size)) if size else b"" + on_token(BytePacket(packet_type, payload)) diff --git a/src/memray/_vendor/textual/drivers/_input_reader.py b/src/memray/_vendor/textual/drivers/_input_reader.py new file mode 100644 index 0000000000..fa8755e9b2 --- /dev/null +++ b/src/memray/_vendor/textual/drivers/_input_reader.py @@ -0,0 +1,10 @@ +import sys + +__all__ = ["InputReader"] + +WINDOWS = sys.platform == "win32" + +if WINDOWS: + from memray._vendor.textual.drivers._input_reader_windows import InputReader +else: + from memray._vendor.textual.drivers._input_reader_linux import InputReader diff --git a/src/memray/_vendor/textual/drivers/_input_reader_linux.py b/src/memray/_vendor/textual/drivers/_input_reader_linux.py new file mode 100644 index 0000000000..a4bae77935 --- /dev/null +++ b/src/memray/_vendor/textual/drivers/_input_reader_linux.py @@ -0,0 +1,40 @@ +import os +import selectors +import sys +from threading import Event +from typing import Iterator + + +class InputReader: + """Read input from stdin.""" + + def __init__(self, timeout: float = 0.1) -> None: + """ + + Args: + timeout: Seconds to block for input. + """ + self._fileno = sys.__stdin__.fileno() + self.timeout = timeout + self._selector = selectors.DefaultSelector() + self._selector.register(self._fileno, selectors.EVENT_READ) + self._exit_event = Event() + + def close(self) -> None: + """Close the reader (will exit the iterator).""" + self._exit_event.set() + + def __iter__(self) -> Iterator[bytes]: + """Read input, yield bytes.""" + fileno = self._fileno + read = os.read + exit_set = self._exit_event.is_set + EVENT_READ = selectors.EVENT_READ + while not exit_set(): + for _key, events in self._selector.select(self.timeout): + if events & EVENT_READ: + data = read(fileno, 1024) + if not data: + return + yield data + yield b"" diff --git a/src/memray/_vendor/textual/drivers/_input_reader_windows.py b/src/memray/_vendor/textual/drivers/_input_reader_windows.py new file mode 100644 index 0000000000..c001c728e2 --- /dev/null +++ b/src/memray/_vendor/textual/drivers/_input_reader_windows.py @@ -0,0 +1,33 @@ +import os +import sys +from threading import Event +from typing import Iterator + + +class InputReader: + """Read input from stdin.""" + + def __init__(self, timeout: float = 0.1) -> None: + """ + + Args: + timeout: Seconds to block for input. + """ + self._fileno = sys.__stdin__.fileno() + self.timeout = timeout + self._exit_event = Event() + + def close(self) -> None: + """Close the reader (will exit the iterator).""" + self._exit_event.set() + + def __iter__(self) -> Iterator[bytes]: + """Read input, yield bytes.""" + while not self._exit_event.is_set(): + try: + data = os.read(self._fileno, 1024) or None + except Exception: + break + if not data: + break + yield data diff --git a/src/memray/_vendor/textual/drivers/_writer_thread.py b/src/memray/_vendor/textual/drivers/_writer_thread.py new file mode 100644 index 0000000000..a26ef46fbb --- /dev/null +++ b/src/memray/_vendor/textual/drivers/_writer_thread.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import threading +from queue import Queue +from typing import IO + +from typing_extensions import Final + +MAX_QUEUED_WRITES: Final[int] = 30 + + +class WriterThread(threading.Thread): + """A thread / file-like to do writes to stdout in the background.""" + + def __init__(self, file: IO[str]) -> None: + super().__init__(daemon=True, name="textual-output") + self._queue: Queue[str | None] = Queue(MAX_QUEUED_WRITES) + self._file = file + + def write(self, text: str) -> None: + """Write text. Text will be enqueued for writing. + + Args: + text: Text to write to the file. + """ + self._queue.put(text) + + def isatty(self) -> bool: + """Pretend to be a terminal. + + Returns: + True. + """ + return True + + def fileno(self) -> int: + """Get file handle number. + + Returns: + File number of proxied file. + """ + return self._file.fileno() + + def flush(self) -> None: + """Flush the file (a no-op, because flush is done in the thread).""" + return + + def run(self) -> None: + """Run the thread.""" + write = self._file.write + flush = self._file.flush + get = self._queue.get + qsize = self._queue.qsize + # Read from the queue, write to the file. + # Flush when there is a break. + while True: + text: str | None = get() + if text is None: + break + write(text) + if qsize() == 0: + flush() + flush() + + def stop(self) -> None: + """Stop the thread, and block until it finished.""" + self._queue.put(None) + self.join() diff --git a/src/memray/_vendor/textual/drivers/headless_driver.py b/src/memray/_vendor/textual/drivers/headless_driver.py new file mode 100644 index 0000000000..49e6df38ec --- /dev/null +++ b/src/memray/_vendor/textual/drivers/headless_driver.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import asyncio + +from memray._vendor.textual import events +from memray._vendor.textual.driver import Driver +from memray._vendor.textual.geometry import Size + + +class HeadlessDriver(Driver): + """A do-nothing driver for testing.""" + + @property + def is_headless(self) -> bool: + """Is the driver running in 'headless' mode?""" + return True + + def _get_terminal_size(self) -> tuple[int, int]: + if self._size is not None: + return self._size + width: int | None = 80 + height: int | None = 25 + import shutil + + try: + width, height = shutil.get_terminal_size() + except (AttributeError, ValueError, OSError): + try: + width, height = shutil.get_terminal_size() + except (AttributeError, ValueError, OSError): + pass + width = width or 80 + height = height or 25 + return width, height + + def write(self, data: str) -> None: + """Write data to the output device. + + Args: + data: Raw data. + """ + # Nothing to write as this is a headless driver. + + def start_application_mode(self) -> None: + """Start application mode.""" + loop = asyncio.get_running_loop() + + def send_size_event() -> None: + """Send first resize event.""" + terminal_size = self._get_terminal_size() + width, height = terminal_size + textual_size = Size(width, height) + event = events.Resize(textual_size, textual_size) + asyncio.run_coroutine_threadsafe( + self._app._post_message(event), + loop=loop, + ) + + send_size_event() + + def disable_input(self) -> None: + """Disable further input.""" + + def stop_application_mode(self) -> None: + """Stop application mode, restore state.""" + # Nothing to do diff --git a/src/memray/_vendor/textual/drivers/linux_driver.py b/src/memray/_vendor/textual/drivers/linux_driver.py new file mode 100644 index 0000000000..0807bb68c9 --- /dev/null +++ b/src/memray/_vendor/textual/drivers/linux_driver.py @@ -0,0 +1,469 @@ +from __future__ import annotations + +import asyncio +import os +import selectors +import signal +import sys +import termios +import tty +from codecs import getincrementaldecoder +from threading import Event, Thread +from typing import TYPE_CHECKING, Any + +import rich.repr + +from memray._vendor.textual import events +from memray._vendor.textual._loop import loop_last +from memray._vendor.textual._parser import ParseError +from memray._vendor.textual._xterm_parser import XTermParser +from memray._vendor.textual.driver import Driver +from memray._vendor.textual.drivers._writer_thread import WriterThread +from memray._vendor.textual.geometry import Size +from memray._vendor.textual.message import Message +from memray._vendor.textual.messages import InBandWindowResize + +if TYPE_CHECKING: + from memray._vendor.textual.app import App + + +@rich.repr.auto(angular=True) +class LinuxDriver(Driver): + """Powers display and input for Linux / MacOS""" + + def __init__( + self, + app: App, + *, + debug: bool = False, + mouse: bool = True, + size: tuple[int, int] | None = None, + ) -> None: + """Initialize Linux driver. + + Args: + app: The App instance. + debug: Enable debug mode. + mouse: Enable mouse support. + size: Initial size of the terminal or `None` to detect. + """ + super().__init__(app, debug=debug, mouse=mouse, size=size) + self._file = sys.__stderr__ + self.fileno = sys.__stdin__.fileno() + self.input_tty = sys.__stdin__.isatty() + self.attrs_before: list[Any] | None = None + self.exit_event = Event() + self._key_thread: Thread | None = None + self._writer_thread: WriterThread | None = None + + # If we've finally and properly come back from a SIGSTOP we want to + # be able to ask the app to publish its resume signal; to do that we + # need to know that we came in here via a SIGTSTP; this flag helps + # keep track of this. + self._must_signal_resume = False + self._in_band_window_resize = False + self._mouse_pixels = False + + # Put handlers for SIGTSTP and SIGCONT in place. These are necessary + # to support the user pressing Ctrl+Z (or whatever the dev might + # have bound to call the relevant action on App) to suspend the + # application. + signal.signal(signal.SIGTSTP, self._sigtstp_application) + signal.signal(signal.SIGCONT, self._sigcont_application) + + def _sigtstp_application(self, *_) -> None: + """Handle a SIGTSTP signal.""" + # If we're supposed to auto-restart, that means we need to shut down + # first. + if self._auto_restart: + self.suspend_application_mode() + # Flag that we'll need to signal a resume on successful startup + # again. + self._must_signal_resume = True + # Now send a SIGSTOP to our process to *actually* suspend the + # process. + os.kill(os.getpid(), signal.SIGSTOP) + + def _sigcont_application(self, *_) -> None: + """Handle a SICONT application.""" + if self._auto_restart: + self.resume_application_mode() + + @property + def can_suspend(self) -> bool: + """Can this driver be suspended?""" + return True + + def __rich_repr__(self) -> rich.repr.Result: + yield self._app + + def _get_terminal_size(self) -> tuple[int, int]: + """Detect the terminal size. + + Returns: + The size of the terminal as a tuple of (WIDTH, HEIGHT). + """ + width: int | None = 80 + height: int | None = 25 + import shutil + + try: + width, height = shutil.get_terminal_size() + except (AttributeError, ValueError, OSError): + try: + width, height = shutil.get_terminal_size() + except (AttributeError, ValueError, OSError): + pass + width = width or 80 + height = height or 25 + return width, height + + def _enable_mouse_support(self) -> None: + """Enable reporting of mouse events.""" + if not self._mouse: + return + + write = self.write + write("\x1b[?1000h") # SET_VT200_MOUSE + write("\x1b[?1003h") # SET_ANY_EVENT_MOUSE + write("\x1b[?1015h") # SET_VT200_HIGHLIGHT_MOUSE + write("\x1b[?1006h") # SET_SGR_EXT_MODE_MOUSE + + # write("\x1b[?1007h") + self.flush() + + # Note: E.g. lxterminal understands 1000h, but not the urxvt or sgr + # extensions. + + def _enable_mouse_pixels(self) -> None: + """Enable mouse reporting as pixels.""" + if not self._mouse: + return + self.write("\x1b[?1016h") + self._mouse_pixels = True + + def _enable_bracketed_paste(self) -> None: + """Enable bracketed paste mode.""" + self.write("\x1b[?2004h") + + def _query_in_band_window_resize(self) -> None: + self.write("\x1b[?2048$p") + + def _enable_in_band_window_resize(self) -> None: + self.write("\x1b[?2048h") + + def _enable_line_wrap(self) -> None: + self.write("\x1b[?7h") + + def _disable_line_wrap(self) -> None: + self.write("\x1b[?7l") + + def _disable_in_band_window_resize(self) -> None: + if self._in_band_window_resize: + self.write("\x1b[?2048l") + + def _disable_bracketed_paste(self) -> None: + """Disable bracketed paste mode.""" + self.write("\x1b[?2004l") + + def _disable_mouse_support(self) -> None: + """Disable reporting of mouse events.""" + if not self._mouse: + return + write = self.write + write("\x1b[?1000l") # + write("\x1b[?1003l") # + write("\x1b[?1015l") + write("\x1b[?1006l") + self.flush() + + def write(self, data: str) -> None: + """Write data to the output device. + + Args: + data: Raw data. + """ + assert self._writer_thread is not None, "Driver must be in application mode" + self._writer_thread.write(data) + + def start_application_mode(self): + """Start application mode.""" + + def _stop_again(*_) -> None: + """Signal handler that will put the application back to sleep.""" + os.kill(os.getpid(), signal.SIGSTOP) + + # If we're working with an actual tty... + # https://github.com/Textualize/textual/issues/4104 + if os.isatty(self.fileno): + # Set up handlers to ensure that, if there's a SIGTTOU or a SIGTTIN, + # we go back to sleep. + signal.signal(signal.SIGTTOU, _stop_again) + signal.signal(signal.SIGTTIN, _stop_again) + try: + # Here we perform a NOP tcsetattr. The reason for this is + # that, if we're suspended and the user has performed a `bg` + # in the shell, we'll SIGCONT *but* we won't be allowed to + # do terminal output; so rather than get into the business + # of spinning up application mode again and then finding + # out, we perform a no-consequence change and detect the + # problem right away. + termios.tcsetattr( + self.fileno, termios.TCSANOW, termios.tcgetattr(self.fileno) + ) + except termios.error: + # There was an error doing the tcsetattr; there is no sense + # in carrying on because we'll be doing a SIGSTOP (see + # above). + return + finally: + # We don't need to be hooking SIGTTOU or SIGTTIN any more. + signal.signal(signal.SIGTTOU, signal.SIG_DFL) + signal.signal(signal.SIGTTIN, signal.SIG_DFL) + + loop = asyncio.get_running_loop() + + def send_size_event() -> None: + terminal_size = self._get_terminal_size() + width, height = terminal_size + textual_size = Size(width, height) + event = events.Resize(textual_size, textual_size) + asyncio.run_coroutine_threadsafe( + self._app._post_message(event), + loop=loop, + ) + + self._writer_thread = WriterThread(self._file) + self._writer_thread.start() + + def on_terminal_resize(signum, stack) -> None: + if not self._in_band_window_resize: + send_size_event() + + signal.signal(signal.SIGWINCH, on_terminal_resize) + + self.write("\x1b[?1049h") # Alt screen + + self._enable_mouse_support() + try: + self.attrs_before = termios.tcgetattr(self.fileno) + except termios.error: + # Ignore attribute errors. + self.attrs_before = None + + try: + newattr = termios.tcgetattr(self.fileno) + except termios.error: + pass + else: + newattr[tty.LFLAG] = self._patch_lflag(newattr[tty.LFLAG]) + newattr[tty.IFLAG] = self._patch_iflag(newattr[tty.IFLAG]) + + # VMIN defines the number of characters read at a time in + # non-canonical mode. It seems to default to 1 on Linux, but on + # Solaris and derived operating systems it defaults to 4. (This is + # because the VMIN slot is the same as the VEOF slot, which + # defaults to ASCII EOT = Ctrl-D = 4.) + newattr[tty.CC][termios.VMIN] = 1 + + try: + termios.tcsetattr(self.fileno, termios.TCSANOW, newattr) + except termios.error: + pass + + self.write("\x1b[?25l") # Hide cursor + self.write("\x1b[?1004h") # Enable FocusIn/FocusOut. + self.write("\x1b[>1u") # https://sw.kovidgoyal.net/kitty/keyboard-protocol/ + + self.flush() + self._key_thread = Thread(target=self._run_input_thread, name="textual-input") + send_size_event() + self._key_thread.start() + self._request_terminal_sync_mode_support() + self._query_in_band_window_resize() + self._enable_bracketed_paste() + self._disable_line_wrap() + + # Appears to fix an issue enabling mouse support in iTerm 3.5.0 + self._enable_mouse_support() + + # If we need to ask the app to signal that we've come back from a + # SIGTSTP... + if self._must_signal_resume: + self._must_signal_resume = False + asyncio.run_coroutine_threadsafe( + self._app._post_message(self.SignalResume()), + loop=loop, + ) + + def _request_terminal_sync_mode_support(self) -> None: + """Writes an escape sequence to query the terminal support for the sync protocol.""" + # Terminals should ignore this sequence if not supported. + # Apple terminal doesn't, and writes a single 'p' into the terminal, + # so we will make a special case for Apple terminal (which doesn't support sync anyway). + if not self.input_tty: + return + if os.environ.get("TERM_PROGRAM", "") != "Apple_Terminal": + self.write("\033[?2026$p") + self.flush() + + @classmethod + def _patch_lflag(cls, attrs: int) -> int: + """Patch termios lflag. + + Args: + attributes: New set attributes. + + Returns: + New lflag. + + """ + # if TEXTUAL_ALLOW_SIGNALS env var is set, then allow Ctrl+C to send signals + ISIG = 0 if os.environ.get("TEXTUAL_ALLOW_SIGNALS") else termios.ISIG + + return attrs & ~(termios.ECHO | termios.ICANON | termios.IEXTEN | ISIG) + + @classmethod + def _patch_iflag(cls, attrs: int) -> int: + return attrs & ~( + # Disable XON/XOFF flow control on output and input. + # (Don't capture Ctrl-S and Ctrl-Q.) + # Like executing: "stty -ixon." + termios.IXON + | termios.IXOFF + | + # Don't translate carriage return into newline on input. + termios.ICRNL + | termios.INLCR + | termios.IGNCR + ) + + def disable_input(self) -> None: + """Disable further input.""" + try: + if not self.exit_event.is_set(): + signal.signal(signal.SIGWINCH, signal.SIG_DFL) + self._disable_mouse_support() + self.exit_event.set() + if self._key_thread is not None: + self._key_thread.join() + self.exit_event.clear() + try: + termios.tcflush(self.fileno, termios.TCIFLUSH) + except termios.error: + pass + except Exception: + # TODO: log this + pass + + def stop_application_mode(self) -> None: + """Stop application mode, restore state.""" + self._disable_bracketed_paste() + self._enable_line_wrap() + self._disable_in_band_window_resize() + self.disable_input() + + if self.attrs_before is not None: + try: + termios.tcsetattr(self.fileno, termios.TCSANOW, self.attrs_before) + except termios.error: + pass + + # Disable the Kitty keyboard protocol. This must be done before leaving + # the alt screen. https://sw.kovidgoyal.net/kitty/keyboard-protocol/ + self.write("\x1b[ None: + """Perform cleanup.""" + if self._writer_thread is not None: + self._writer_thread.stop() + + def _run_input_thread(self) -> None: + """ + Key thread target that wraps run_input_thread() to die gracefully if it raises + an exception + """ + try: + self.run_input_thread() + except BaseException: + import rich.traceback + + self._app.call_later( + self._app.panic, + rich.traceback.Traceback(), + ) + + def run_input_thread(self) -> None: + """Wait for input and dispatch events.""" + selector = selectors.SelectSelector() + selector.register(self.fileno, selectors.EVENT_READ) + + fileno = self.fileno + EVENT_READ = selectors.EVENT_READ + + parser = XTermParser(self._debug) + feed = parser.feed + tick = parser.tick + + utf8_decoder = getincrementaldecoder("utf-8")().decode + decode = utf8_decoder + read = os.read + + def process_selector_events( + selector_events: list[tuple[selectors.SelectorKey, int]], + final: bool = False, + ) -> None: + """Process events from selector. + + Args: + selector_events: List of selector events. + final: True if this is the last call. + + """ + for last, (_selector_key, mask) in loop_last(selector_events): + if mask & EVENT_READ: + unicode_data = decode(read(fileno, 1024 * 4), final=final and last) + if not unicode_data: + # This can occur if the stdin is piped + break + for event in feed(unicode_data): + self.process_message(event) + for event in tick(): + self.process_message(event) + + try: + while not self.exit_event.is_set(): + process_selector_events(selector.select(0.1)) + selector.unregister(self.fileno) + process_selector_events(selector.select(0.1), final=True) + + finally: + selector.close() + try: + for event in feed(""): + pass + except (EOFError, ParseError): + pass + + def process_message(self, message: Message) -> None: + # intercept in-band window resize + if isinstance(message, InBandWindowResize): + if message.supported: + self._in_band_window_resize = True + if message.enabled: + # Supported and enabled + super().process_message(message) + else: + # Supported, but not enabled + self._enable_in_band_window_resize() + super().process_message(InBandWindowResize(True, True)) + self._enable_mouse_pixels() + return + + super().process_message(message) diff --git a/src/memray/_vendor/textual/drivers/linux_inline_driver.py b/src/memray/_vendor/textual/drivers/linux_inline_driver.py new file mode 100644 index 0000000000..e39d5d10c4 --- /dev/null +++ b/src/memray/_vendor/textual/drivers/linux_inline_driver.py @@ -0,0 +1,325 @@ +from __future__ import annotations + +import asyncio +import os +import selectors +import signal +import sys +import termios +import tty +from codecs import getincrementaldecoder +from threading import Event, Thread +from typing import TYPE_CHECKING, Any + +import rich.repr + +from memray._vendor.textual import events +from memray._vendor.textual._loop import loop_last +from memray._vendor.textual._parser import ParseError +from memray._vendor.textual._xterm_parser import XTermParser +from memray._vendor.textual.driver import Driver +from memray._vendor.textual.geometry import Size + +if TYPE_CHECKING: + from memray._vendor.textual.app import App + + +@rich.repr.auto(angular=True) +class LinuxInlineDriver(Driver): + def __init__( + self, + app: App, + *, + debug: bool = False, + mouse: bool = True, + size: tuple[int, int] | None = None, + ): + super().__init__(app, debug=debug, mouse=mouse, size=size) + self._file = sys.__stderr__ + self.fileno = sys.__stdin__.fileno() + self.attrs_before: list[Any] | None = None + self.exit_event = Event() + + def __rich_repr__(self) -> rich.repr.Result: + yield self._app + + @property + def is_inline(self) -> bool: + return True + + def _enable_bracketed_paste(self) -> None: + """Enable bracketed paste mode.""" + self.write("\x1b[?2004h") + + def _disable_bracketed_paste(self) -> None: + """Disable bracketed paste mode.""" + self.write("\x1b[?2004l") + + def _get_terminal_size(self) -> tuple[int, int]: + """Detect the terminal size. + + Returns: + The size of the terminal as a tuple of (WIDTH, HEIGHT). + """ + width: int | None = 80 + height: int | None = 25 + import shutil + + try: + width, height = shutil.get_terminal_size() + except (AttributeError, ValueError, OSError): + try: + width, height = shutil.get_terminal_size() + except (AttributeError, ValueError, OSError): + pass + width = width or 80 + height = height or 25 + return width, height + + def _enable_mouse_support(self) -> None: + """Enable reporting of mouse events.""" + if not self._mouse: + return + write = self.write + write("\x1b[?1000h") # SET_VT200_MOUSE + write("\x1b[?1003h") # SET_ANY_EVENT_MOUSE + write("\x1b[?1015h") # SET_VT200_HIGHLIGHT_MOUSE + write("\x1b[?1006h") # SET_SGR_EXT_MODE_MOUSE + + # write("\x1b[?1007h") + self.flush() + + def _disable_mouse_support(self) -> None: + """Disable reporting of mouse events.""" + if not self._mouse: + return + write = self.write + write("\x1b[?1000l") # + write("\x1b[?1003l") # + write("\x1b[?1015l") + write("\x1b[?1006l") + self.flush() + + def write(self, data: str) -> None: + self._file.write(data) + + def _run_input_thread(self) -> None: + """ + Key thread target that wraps run_input_thread() to die gracefully if it raises + an exception + """ + try: + self.run_input_thread() + except BaseException: + import rich.traceback + + self._app.call_later( + self._app.panic, + rich.traceback.Traceback(), + ) + + def run_input_thread(self) -> None: + """Wait for input and dispatch events.""" + selector = selectors.SelectSelector() + selector.register(self.fileno, selectors.EVENT_READ) + + fileno = self.fileno + EVENT_READ = selectors.EVENT_READ + + parser = XTermParser(self._debug) + feed = parser.feed + tick = parser.tick + + utf8_decoder = getincrementaldecoder("utf-8")().decode + decode = utf8_decoder + read = os.read + + def process_selector_events( + selector_events: list[tuple[selectors.SelectorKey, int]], + final: bool = False, + ) -> None: + """Process events from selector. + + Args: + selector_events: List of selector events. + final: True if this is the last call. + + """ + for last, (_selector_key, mask) in loop_last(selector_events): + if mask & EVENT_READ: + unicode_data = decode(read(fileno, 1024 * 4), final=final and last) + if not unicode_data: + # This can occur if the stdin is piped + break + for event in feed(unicode_data): + if isinstance(event, events.CursorPosition): + self.cursor_origin = (event.x, event.y) + else: + self.process_message(event) + for event in tick(): + if isinstance(event, events.CursorPosition): + self.cursor_origin = (event.x, event.y) + else: + self.process_message(event) + + try: + while not self.exit_event.is_set(): + process_selector_events(selector.select(0.1)) + selector.unregister(self.fileno) + process_selector_events(selector.select(0.1), final=True) + + finally: + selector.close() + try: + for event in feed(""): + pass + except ParseError: + pass + + def start_application_mode(self) -> None: + loop = asyncio.get_running_loop() + + def send_size_event(clear: bool = False) -> None: + """Send the resize event, optionally clearing the screen. + + Args: + clear: Clear the screen. + """ + terminal_size = self._get_terminal_size() + width, height = terminal_size + textual_size = Size(width, height) + event = events.Resize(textual_size, textual_size) + + async def update_size() -> None: + """Update the screen size.""" + if clear: + self.write("\x1b[2J") + await self._app._post_message(event) + + asyncio.run_coroutine_threadsafe( + update_size(), + loop=loop, + ) + + def on_terminal_resize(signum, stack) -> None: + send_size_event(clear=True) + + signal.signal(signal.SIGWINCH, on_terminal_resize) + + self.write("\x1b[?25l") # Hide cursor + self.write("\033[?1004h") # Enable FocusIn/FocusOut. + self.write("\x1b[>1u") # https://sw.kovidgoyal.net/kitty/keyboard-protocol/ + self.flush() + + self._enable_mouse_support() + self.write("\n" * self._app.INLINE_PADDING) + self.flush() + try: + self.attrs_before = termios.tcgetattr(self.fileno) + except termios.error: + # Ignore attribute errors. + self.attrs_before = None + + try: + newattr = termios.tcgetattr(self.fileno) + except termios.error: + pass + else: + newattr[tty.LFLAG] = self._patch_lflag(newattr[tty.LFLAG]) + newattr[tty.IFLAG] = self._patch_iflag(newattr[tty.IFLAG]) + + # VMIN defines the number of characters read at a time in + # non-canonical mode. It seems to default to 1 on Linux, but on + # Solaris and derived operating systems it defaults to 4. (This is + # because the VMIN slot is the same as the VEOF slot, which + # defaults to ASCII EOT = Ctrl-D = 4.) + newattr[tty.CC][termios.VMIN] = 1 + + termios.tcsetattr(self.fileno, termios.TCSANOW, newattr) + + self._key_thread = Thread(target=self._run_input_thread, name="textual-input") + send_size_event() + self._key_thread.start() + self._request_terminal_sync_mode_support() + self._enable_bracketed_paste() + + def _request_terminal_sync_mode_support(self) -> None: + """Writes an escape sequence to query the terminal support for the sync protocol.""" + # Terminals should ignore this sequence if not supported. + # Apple terminal doesn't, and writes a single 'p' into the terminal, + # so we will make a special case for Apple terminal (which doesn't support sync anyway). + if os.environ.get("TERM_PROGRAM", "") != "Apple_Terminal": + self.write("\033[?2026$p") + self.flush() + + @classmethod + def _patch_lflag(cls, attrs: int) -> int: + """Patch termios lflag. + + Args: + attributes: New set attributes. + + Returns: + New lflag. + + """ + # if TEXTUAL_ALLOW_SIGNALS env var is set, then allow Ctrl+C to send signals + ISIG = 0 if os.environ.get("TEXTUAL_ALLOW_SIGNALS") else termios.ISIG + + return attrs & ~(termios.ECHO | termios.ICANON | termios.IEXTEN | ISIG) + + @classmethod + def _patch_iflag(cls, attrs: int) -> int: + return attrs & ~( + # Disable XON/XOFF flow control on output and input. + # (Don't capture Ctrl-S and Ctrl-Q.) + # Like executing: "stty -ixon." + termios.IXON + | termios.IXOFF + | + # Don't translate carriage return into newline on input. + termios.ICRNL + | termios.INLCR + | termios.IGNCR + ) + + def disable_input(self) -> None: + """Disable further input.""" + try: + if not self.exit_event.is_set(): + signal.signal(signal.SIGWINCH, signal.SIG_DFL) + self._disable_mouse_support() + self.exit_event.set() + if self._key_thread is not None: + self._key_thread.join() + self.exit_event.clear() + try: + termios.tcflush(self.fileno, termios.TCIFLUSH) + except termios.error: + pass + + except Exception as error: + # TODO: log this + pass + + def flush(self): + """Flush any buffered data.""" + self._file.flush() + + def stop_application_mode(self) -> None: + """Stop application mode, restore state.""" + self._disable_bracketed_paste() + self.disable_input() + self.write("\x1b[ bool: + return True + + def write(self, data: str) -> None: + """Write string data to the output device, which may be piped to + the parent process (i.e. textual-web/textual-serve). + + Args: + data: Raw data. + """ + + data_bytes = data.encode("utf-8") + self._write(b"D%s%s" % (len(data_bytes).to_bytes(4, "big"), data_bytes)) + + def write_meta(self, data: dict[str, object]) -> None: + """Write a dictionary containing some metadata to stdout, which + may be piped to the parent process (i.e. textual-web/textual-serve). + + Args: + data: Meta dict. + """ + meta_bytes = json.dumps(data).encode("utf-8", errors="ignore") + self._write(b"M%s%s" % (len(meta_bytes).to_bytes(4, "big"), meta_bytes)) + + def write_binary_encoded(self, data: tuple[str | bytes, ...]) -> None: + """Binary encode a data-structure and write to stdout. + + Args: + data: The data to binary encode and write. + """ + packed_bytes = binary_dump(data) + self._write(b"P%s%s" % (len(packed_bytes).to_bytes(4, "big"), packed_bytes)) + + def flush(self) -> None: + pass + + def _enable_mouse_support(self) -> None: + """Enable reporting of mouse events.""" + write = self.write + write("\x1b[?1000h") # SET_VT200_MOUSE + write("\x1b[?1003h") # SET_ANY_EVENT_MOUSE + write("\x1b[?1015h") # SET_VT200_HIGHLIGHT_MOUSE + write("\x1b[?1006h") # SET_SGR_EXT_MODE_MOUSE + + def _enable_bracketed_paste(self) -> None: + """Enable bracketed paste mode.""" + self.write("\x1b[?2004h") + + def _disable_bracketed_paste(self) -> None: + """Disable bracketed paste mode.""" + self.write("\x1b[?2004l") + + def _disable_mouse_support(self) -> None: + """Disable reporting of mouse events.""" + write = self.write + write("\x1b[?1000l") # + write("\x1b[?1003l") # + write("\x1b[?1015l") + write("\x1b[?1006l") + + def _request_terminal_sync_mode_support(self) -> None: + """Writes an escape sequence to query the terminal support for the sync protocol.""" + self.write("\033[?2026$p") + + def start_application_mode(self) -> None: + """Start application mode.""" + + loop = asyncio.get_running_loop() + + def do_exit() -> None: + """Callback to force exit.""" + asyncio.run_coroutine_threadsafe( + self._app._post_message(messages.ExitApp()), loop=loop + ) + + if not WINDOWS: + for _signal in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(_signal, do_exit) + + self._write(b"__GANGLION__\n") + + self.write("\x1b[?1049h") # Alt screen + self._enable_mouse_support() + + self.write("\x1b[?25l") # Hide cursor + self.write("\033[?1003h") + + size = Size(80, 24) if self._size is None else Size(*self._size) + event = events.Resize(size, size) + asyncio.run_coroutine_threadsafe( + self._app._post_message(event), + loop=loop, + ) + + self._request_terminal_sync_mode_support() + self._enable_bracketed_paste() + self.flush() + self._key_thread.start() + self._app.call_later(self._app.post_message, events.AppBlur()) + + def disable_input(self) -> None: + """Disable further input.""" + + def stop_application_mode(self) -> None: + """Stop application mode, restore state.""" + self.exit_event.set() + self._input_reader.close() + self.write_meta({"type": "exit"}) + + def run_input_thread(self) -> None: + """Wait for input and dispatch events.""" + input_reader = self._input_reader + parser = XTermParser(debug=self._debug) + utf8_decoder = getincrementaldecoder("utf-8")().decode + decode = utf8_decoder + # The server sends us a stream of bytes, which contains the equivalent of stdin, plus + # in band data packets. + byte_stream = ByteStream() + try: + for data in input_reader: + if data: + for packet_type, payload in byte_stream.feed(data): + if packet_type == "D": + # Treat as stdin + for event in parser.feed(decode(payload)): + self.process_message(event) + else: + # Process meta information separately + self._on_meta(packet_type, payload) + for event in parser.tick(): + self.process_message(event) + except _ExitInput: + pass + except Exception: + from traceback import format_exc + + log(format_exc()) + finally: + input_reader.close() + + def _on_meta(self, packet_type: str, payload: bytes) -> None: + """Private method to dispatch meta. + + Args: + packet_type: Packet type (currently always "M") + payload: Meta payload (JSON encoded as bytes). + """ + payload_map: dict[str, object] = json.loads(payload) + _type = payload_map.get("type", {}) + if isinstance(_type, str): + self.on_meta(_type, payload_map) + else: + log.error( + f"Protocol error: type field value is not a string. Value is {_type!r}" + ) + + def on_meta(self, packet_type: str, payload: dict[str, object]) -> None: + """Process a dictionary containing information received from the controlling process. + + Args: + packet_type: The type of the packet. + payload: meta dict. + """ + if packet_type == "resize": + self._size = (payload["width"], payload["height"]) + requested_size = Size(*self._size) + self._app.post_message(events.Resize(requested_size, requested_size)) + elif packet_type == "focus": + self._app.post_message(events.AppFocus()) + elif packet_type == "blur": + self._app.post_message(events.AppBlur()) + elif packet_type == "quit": + self._app.post_message(messages.ExitApp()) + elif packet_type == "exit": + raise _ExitInput() + elif packet_type == "deliver_chunk_request": + # A request from the server to deliver another chunk of a file + log.debug(f"Deliver chunk request: {payload}") + try: + delivery_key = cast(str, payload["key"]) + requested_size = cast(int, payload["size"]) + except KeyError: + log.error("Protocol error: deliver_chunk_request missing key or size") + return + + deliveries = self._deliveries + + file_like: BinaryIO | TextIO | None = None + try: + file_like = deliveries[delivery_key] + except KeyError: + log.error( + f"Protocol error: deliver_chunk_request invalid key {delivery_key!r}" + ) + else: + # Read the requested amount of data from the file + name: str | None = payload.get("name", None) + try: + log.debug(f"Reading {requested_size} bytes from {delivery_key}") + chunk = file_like.read(requested_size) + log.debug(f"Delivering chunk {delivery_key!r} of len {len(chunk)}") + self.write_binary_encoded(("deliver_chunk", delivery_key, chunk)) + # We've hit an empty chunk, so we're done + if not chunk: + log.info(f"Delivery complete for {delivery_key}") + file_like.close() + del deliveries[delivery_key] + self._delivery_complete(delivery_key, save_path=None, name=name) + except Exception as error: + file_like.close() + del deliveries[delivery_key] + + log.error( + f"Error delivering file chunk for key {delivery_key!r}. " + "Cancelling delivery." + ) + import traceback + + log.error(str(traceback.format_exc())) + + self._delivery_failed(delivery_key, exception=error, name=name) + + def open_url(self, url: str, new_tab: bool = True) -> None: + """Open a URL in the default web browser. + + Args: + url: The URL to open. + new_tab: Whether to open the URL in a new tab. + """ + self.write_meta({"type": "open_url", "url": url, "new_tab": new_tab}) + + def deliver_binary( + self, + binary: BinaryIO | TextIO, + *, + delivery_key: str, + save_path: Path, + open_method: Literal["browser", "download"] = "download", + encoding: str | None = None, + mime_type: str | None = None, + name: str | None = None, + ) -> None: + self._deliver_file( + binary, + delivery_key=delivery_key, + save_path=save_path, + open_method=open_method, + encoding=encoding, + mime_type=mime_type, + name=name, + ) + + def _deliver_file( + self, + binary: BinaryIO | TextIO, + *, + delivery_key: str, + save_path: Path, + open_method: Literal["browser", "download"], + encoding: str | None = None, + mime_type: str | None = None, + name: str | None = None, + ) -> None: + """Deliver a file to the end-user of the application.""" + binary.seek(0) + + self._deliveries[delivery_key] = binary + + # Inform the server that we're starting a new file delivery + meta: dict[str, object] = { + "type": "deliver_file_start", + "key": delivery_key, + "path": str(save_path.resolve()), + "open_method": open_method, + "encoding": encoding or "", + "mime_type": mime_type or "", + "name": name, + } + self.write_meta(meta) + log.info(f"Delivering file {meta['path']!r}: {meta!r}") diff --git a/src/memray/_vendor/textual/drivers/win32.py b/src/memray/_vendor/textual/drivers/win32.py new file mode 100644 index 0000000000..ab5c66e29d --- /dev/null +++ b/src/memray/_vendor/textual/drivers/win32.py @@ -0,0 +1,304 @@ +from __future__ import annotations + +import ctypes +import msvcrt +import sys +import threading +from asyncio import AbstractEventLoop, run_coroutine_threadsafe +from ctypes import Structure, Union, byref, wintypes +from ctypes.wintypes import BOOL, CHAR, DWORD, HANDLE, SHORT, UINT, WCHAR, WORD +from typing import IO, TYPE_CHECKING, Callable, List, Optional + +from memray._vendor.textual import constants +from memray._vendor.textual._xterm_parser import XTermParser +from memray._vendor.textual.events import Event, Resize +from memray._vendor.textual.geometry import Size + +if TYPE_CHECKING: + from memray._vendor.textual.app import App + +KERNEL32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore + +# Console input modes +ENABLE_ECHO_INPUT = 0x0004 +ENABLE_EXTENDED_FLAGS = 0x0080 +ENABLE_INSERT_MODE = 0x0020 +ENABLE_LINE_INPUT = 0x0002 +ENABLE_MOUSE_INPUT = 0x0010 +ENABLE_PROCESSED_INPUT = 0x0001 +ENABLE_QUICK_EDIT_MODE = 0x0040 +ENABLE_WINDOW_INPUT = 0x0008 +ENABLE_VIRTUAL_TERMINAL_INPUT = 0x0200 + +# Console output modes +ENABLE_PROCESSED_OUTPUT = 0x0001 +ENABLE_WRAP_AT_EOL_OUTPUT = 0x0002 +ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x0004 +DISABLE_NEWLINE_AUTO_RETURN = 0x0008 +ENABLE_LVB_GRID_WORLDWIDE = 0x0010 + +STD_INPUT_HANDLE = -10 +STD_OUTPUT_HANDLE = -11 + +WAIT_TIMEOUT = 0x00000102 + +GetStdHandle = KERNEL32.GetStdHandle +GetStdHandle.argtypes = [wintypes.DWORD] +GetStdHandle.restype = wintypes.HANDLE + + +class COORD(Structure): + """https://docs.microsoft.com/en-us/windows/console/coord-str""" + + _fields_ = [ + ("X", SHORT), + ("Y", SHORT), + ] + + +class uChar(Union): + """https://docs.microsoft.com/en-us/windows/console/key-event-record-str""" + + _fields_ = [ + ("AsciiChar", CHAR), + ("UnicodeChar", WCHAR), + ] + + +class KEY_EVENT_RECORD(Structure): + """https://docs.microsoft.com/en-us/windows/console/key-event-record-str""" + + _fields_ = [ + ("bKeyDown", BOOL), + ("wRepeatCount", WORD), + ("wVirtualKeyCode", WORD), + ("wVirtualScanCode", WORD), + ("uChar", uChar), + ("dwControlKeyState", DWORD), + ] + + +class MOUSE_EVENT_RECORD(Structure): + """https://docs.microsoft.com/en-us/windows/console/mouse-event-record-str""" + + _fields_ = [ + ("dwMousePosition", COORD), + ("dwButtonState", DWORD), + ("dwControlKeyState", DWORD), + ("dwEventFlags", DWORD), + ] + + +class WINDOW_BUFFER_SIZE_RECORD(Structure): + """https://docs.microsoft.com/en-us/windows/console/window-buffer-size-record-str""" + + _fields_ = [("dwSize", COORD)] + + +class MENU_EVENT_RECORD(Structure): + """https://docs.microsoft.com/en-us/windows/console/menu-event-record-str""" + + _fields_ = [("dwCommandId", UINT)] + + +class FOCUS_EVENT_RECORD(Structure): + """https://docs.microsoft.com/en-us/windows/console/focus-event-record-str""" + + _fields_ = [("bSetFocus", BOOL)] + + +class InputEvent(Union): + """https://docs.microsoft.com/en-us/windows/console/input-record-str""" + + _fields_ = [ + ("KeyEvent", KEY_EVENT_RECORD), + ("MouseEvent", MOUSE_EVENT_RECORD), + ("WindowBufferSizeEvent", WINDOW_BUFFER_SIZE_RECORD), + ("MenuEvent", MENU_EVENT_RECORD), + ("FocusEvent", FOCUS_EVENT_RECORD), + ] + + +class INPUT_RECORD(Structure): + """https://docs.microsoft.com/en-us/windows/console/input-record-str""" + + _fields_ = [("EventType", wintypes.WORD), ("Event", InputEvent)] + + +def set_console_mode(file: IO, mode: int) -> bool: + """Set the console mode for a given file (stdout or stdin). + + Args: + file: A file like object. + mode: New mode. + + Returns: + True on success, otherwise False. + """ + windows_filehandle = msvcrt.get_osfhandle(file.fileno()) # type: ignore + success = KERNEL32.SetConsoleMode(windows_filehandle, mode) + return success + + +def get_console_mode(file: IO) -> int: + """Get the console mode for a given file (stdout or stdin) + + Args: + file: A file-like object. + + Returns: + The current console mode. + """ + windows_filehandle = msvcrt.get_osfhandle(file.fileno()) # type: ignore + mode = wintypes.DWORD() + KERNEL32.GetConsoleMode(windows_filehandle, ctypes.byref(mode)) + return mode.value + + +def enable_application_mode() -> Callable[[], None]: + """Enable application mode. + + Returns: + A callable that will restore terminal to previous state. + """ + + terminal_in = sys.__stdin__ + terminal_out = sys.__stdout__ + + current_console_mode_in = get_console_mode(terminal_in) + current_console_mode_out = get_console_mode(terminal_out) + + def restore() -> None: + """Restore console mode to previous settings""" + set_console_mode(terminal_in, current_console_mode_in) + set_console_mode(terminal_out, current_console_mode_out) + + set_console_mode( + terminal_out, current_console_mode_out | ENABLE_VIRTUAL_TERMINAL_PROCESSING + ) + set_console_mode(terminal_in, ENABLE_VIRTUAL_TERMINAL_INPUT) + return restore + + +def wait_for_handles(handles: List[HANDLE], timeout: int = -1) -> Optional[HANDLE]: + """ + Waits for multiple handles. (Similar to 'select') Returns the handle which is ready. + Returns `None` on timeout. + http://msdn.microsoft.com/en-us/library/windows/desktop/ms687025(v=vs.85).aspx + Note that handles should be a list of `HANDLE` objects, not integers. See + this comment in the patch by @quark-zju for the reason why: + ''' Make sure HANDLE on Windows has a correct size + Previously, the type of various HANDLEs are native Python integer + types. The ctypes library will treat them as 4-byte integer when used + in function arguments. On 64-bit Windows, HANDLE is 8-byte and usually + a small integer. Depending on whether the extra 4 bytes are zero-ed out + or not, things can happen to work, or break. ''' + This function returns either `None` or one of the given `HANDLE` objects. + (The return value can be tested with the `is` operator.) + """ + arrtype = HANDLE * len(handles) + handle_array = arrtype(*handles) + + ret: int = KERNEL32.WaitForMultipleObjects( + len(handle_array), handle_array, BOOL(False), DWORD(timeout) + ) + + if ret == WAIT_TIMEOUT: + return None + else: + return handles[ret] + + +class EventMonitor(threading.Thread): + """A thread to send key / window events to Textual loop.""" + + def __init__( + self, + loop: AbstractEventLoop, + app: App, + exit_event: threading.Event, + process_event: Callable[[Event], None], + ) -> None: + self.loop = loop + self.app = app + self.exit_event = exit_event + self.process_event = process_event + super().__init__(name="textual-input") + + def run(self) -> None: + exit_requested = self.exit_event.is_set + parser = XTermParser(debug=constants.DEBUG) + + try: + read_count = wintypes.DWORD(0) + hIn = GetStdHandle(STD_INPUT_HANDLE) + + MAX_EVENTS = 1024 + KEY_EVENT = 0x0001 + WINDOW_BUFFER_SIZE_EVENT = 0x0004 + + arrtype = INPUT_RECORD * MAX_EVENTS + input_records = arrtype() + ReadConsoleInputW = KERNEL32.ReadConsoleInputW + keys: List[str] = [] + append_key = keys.append + + while not exit_requested(): + + for event in parser.tick(): + self.process_event(event) + + # Wait for new events + if wait_for_handles([hIn], 100) is None: + # No new events + continue + + # Get new events + ReadConsoleInputW( + hIn, byref(input_records), MAX_EVENTS, byref(read_count) + ) + read_input_records = input_records[: read_count.value] + + del keys[:] + new_size: Optional[tuple[int, int]] = None + + for input_record in read_input_records: + event_type = input_record.EventType + + if event_type == KEY_EVENT: + # Key event, store unicode char in keys list + key_event = input_record.Event.KeyEvent + key = key_event.uChar.UnicodeChar + if key_event.bKeyDown: + if ( + key_event.dwControlKeyState + and key_event.wVirtualKeyCode == 0 + ): + continue + append_key(key) + elif event_type == WINDOW_BUFFER_SIZE_EVENT: + # Window size changed, store size + size = input_record.Event.WindowBufferSizeEvent.dwSize + new_size = (size.X, size.Y) + + if keys: + # Process keys + # + # https://github.com/Textualize/textual/issues/3178 has + # the context for the encode/decode here. + for event in parser.feed( + "".join(keys).encode("utf-16", "surrogatepass").decode("utf-16") + ): + self.process_event(event) + if new_size is not None: + # Process changed size + self.on_size_change(*new_size) + + except Exception as error: + self.app.log.error("EVENT MONITOR ERROR", error) + + def on_size_change(self, width: int, height: int) -> None: + """Called when terminal size changes.""" + size = Size(width, height) + event = Resize(size, size) + run_coroutine_threadsafe(self.app._post_message(event), loop=self.loop) diff --git a/src/memray/_vendor/textual/drivers/windows_driver.py b/src/memray/_vendor/textual/drivers/windows_driver.py new file mode 100644 index 0000000000..a94dd64f11 --- /dev/null +++ b/src/memray/_vendor/textual/drivers/windows_driver.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import asyncio +import sys +from threading import Event, Thread +from typing import TYPE_CHECKING, Callable + +from memray._vendor.textual.driver import Driver +from memray._vendor.textual.drivers import win32 +from memray._vendor.textual.drivers._writer_thread import WriterThread + +if TYPE_CHECKING: + from memray._vendor.textual.app import App + + +class WindowsDriver(Driver): + """Powers display and input for Windows.""" + + def __init__( + self, + app: App, + *, + debug: bool = False, + mouse: bool = True, + size: tuple[int, int] | None = None, + ) -> None: + """Initialize Windows driver. + + Args: + app: The App instance. + debug: Enable debug mode. + mouse: Enable mouse support. + size: Initial size of the terminal or `None` to detect. + """ + super().__init__(app, debug=debug, mouse=mouse, size=size) + self._file = sys.__stdout__ + self.exit_event = Event() + self._event_thread: Thread | None = None + self._restore_console: Callable[[], None] | None = None + self._writer_thread: WriterThread | None = None + + @property + def can_suspend(self) -> bool: + """Can this driver be suspended?""" + return True + + def write(self, data: str) -> None: + """Write data to the output device. + + Args: + data: Raw data. + """ + assert self._writer_thread is not None, "Driver must be in application mode" + self._writer_thread.write(data) + + def _enable_mouse_support(self) -> None: + """Enable reporting of mouse events.""" + if not self._mouse: + return + write = self.write + write("\x1b[?1000h") # SET_VT200_MOUSE + write("\x1b[?1003h") # SET_ANY_EVENT_MOUSE + write("\x1b[?1015h") # SET_VT200_HIGHLIGHT_MOUSE + write("\x1b[?1006h") # SET_SGR_EXT_MODE_MOUSE + self.flush() + + def _disable_mouse_support(self) -> None: + """Disable reporting of mouse events.""" + if not self._mouse: + return + write = self.write + write("\x1b[?1000l") + write("\x1b[?1003l") + write("\x1b[?1015l") + write("\x1b[?1006l") + self.flush() + + def _enable_bracketed_paste(self) -> None: + """Enable bracketed paste mode.""" + self.write("\x1b[?2004h") + + def _disable_bracketed_paste(self) -> None: + """Disable bracketed paste mode.""" + self.write("\x1b[?2004l") + + def start_application_mode(self) -> None: + """Start application mode.""" + loop = asyncio.get_running_loop() + + self._restore_console = win32.enable_application_mode() + + self._writer_thread = WriterThread(self._file) + self._writer_thread.start() + + self.write("\x1b[?1049h") # Enable alt screen + self._enable_mouse_support() + self.write("\x1b[?25l") # Hide cursor + self.write("\033[?1004h") # Enable FocusIn/FocusOut. + self.write("\x1b[>1u") # https://sw.kovidgoyal.net/kitty/keyboard-protocol/ + self.flush() + self._enable_bracketed_paste() + + self._event_thread = win32.EventMonitor( + loop, self._app, self.exit_event, self.process_message + ) + self._event_thread.start() + + def disable_input(self) -> None: + """Disable further input.""" + try: + if not self.exit_event.is_set(): + self._disable_mouse_support() + self.exit_event.set() + if self._event_thread is not None: + self._event_thread.join() + self._event_thread = None + self.exit_event.clear() + except Exception as error: + # TODO: log this + pass + + def stop_application_mode(self) -> None: + """Stop application mode, restore state.""" + self._disable_bracketed_paste() + self.disable_input() + + # Disable the Kitty keyboard protocol. This must be done before leaving + # the alt screen. https://sw.kovidgoyal.net/kitty/keyboard-protocol/ + self.write("\x1b[ None: + """Perform cleanup.""" + if self._writer_thread is not None: + self._writer_thread.stop() + if self._restore_console: + self._restore_console() diff --git a/src/memray/_vendor/textual/errors.py b/src/memray/_vendor/textual/errors.py new file mode 100644 index 0000000000..034139e204 --- /dev/null +++ b/src/memray/_vendor/textual/errors.py @@ -0,0 +1,26 @@ +""" +General exception classes. + +""" + +from __future__ import annotations + + +class TextualError(Exception): + """Base class for Textual errors.""" + + +class NoWidget(TextualError): + """Specified widget was not found.""" + + +class RenderError(TextualError): + """An object could not be rendered.""" + + +class DuplicateKeyHandlers(TextualError): + """More than one handler for a single key press. + + For example, if the handlers `key_ctrl_i` and `key_tab` were defined on the same + widget, then this error would be raised. + """ diff --git a/src/memray/_vendor/textual/eta.py b/src/memray/_vendor/textual/eta.py new file mode 100644 index 0000000000..3edb004696 --- /dev/null +++ b/src/memray/_vendor/textual/eta.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +import bisect +from math import ceil +from time import monotonic + +import rich.repr + + +@rich.repr.auto(angular=True) +class ETA: + """Calculate speed and estimate time to arrival.""" + + def __init__( + self, estimation_period: float = 60, extrapolate_period: float = 30 + ) -> None: + """Create an ETA. + + Args: + estimation_period: Period in seconds, used to calculate speed. + extrapolate_period: Maximum number of seconds used to estimate progress after last sample. + """ + self.estimation_period = estimation_period + self.max_extrapolate = extrapolate_period + self._samples: list[tuple[float, float]] = [(0.0, 0.0)] + self._add_count = 0 + + def __rich_repr__(self) -> rich.repr.Result: + yield "speed", self.speed + yield "eta", self.get_eta(monotonic()) + + @property + def first_sample(self) -> tuple[float, float]: + """First sample.""" + assert self._samples, "Assumes samples not empty" + return self._samples[0] + + @property + def last_sample(self) -> tuple[float, float]: + """Last sample.""" + assert self._samples, "Assumes samples not empty" + return self._samples[-1] + + def reset(self) -> None: + """Start ETA calculations from current time.""" + del self._samples[:] + + def add_sample(self, time: float, progress: float) -> None: + """Add a new sample. + + Args: + time: Time when sample occurred. + progress: Progress ratio (0 is start, 1 is complete). + """ + if self._samples and self.last_sample[1] > progress: + # If progress goes backwards, we need to reset calculations + self.reset() + self._samples.append((time, progress)) + self._add_count += 1 + if self._add_count % 100 == 0: + # Prune periodically so we don't accumulate vast amounts of samples + self._prune() + + def _prune(self) -> None: + """Prune old samples.""" + if len(self._samples) <= 10: + # Keep at least 10 samples + return + prune_time = self._samples[-1][0] - self.estimation_period + index = bisect.bisect_left(self._samples, (prune_time, 0)) + del self._samples[:index] + + def _get_progress_at(self, time: float) -> tuple[float, float]: + """Get the progress at a specific time.""" + + index = bisect.bisect_left(self._samples, (time, 0)) + if index >= len(self._samples): + return self.last_sample + if index == 0: + return self.first_sample + # Linearly interpolate progress between two samples + time1, progress1 = self._samples[index - 1] + time2, progress2 = self._samples[index] + factor = (time - time1) / (time2 - time1) + intermediate_progress = progress1 + (progress2 - progress1) * factor + return time, intermediate_progress + + @property + def speed(self) -> float | None: + """The current speed, or `None` if it couldn't be calculated.""" + + if len(self._samples) < 2: + # Need at least 2 samples to calculate speed + return None + + recent_sample_time, progress2 = self.last_sample + progress_start_time, progress1 = self._get_progress_at( + recent_sample_time - self.estimation_period + ) + if recent_sample_time - progress_start_time < 1: + # Require at least a second span to calculate speed. + return None + time_delta = recent_sample_time - progress_start_time + distance = progress2 - progress1 + speed = distance / time_delta if time_delta else 0 + return speed + + def get_eta(self, time: float) -> int | None: + """Estimated seconds until completion, or `None` if no estimate can be made. + + Args: + time: Current time. + """ + speed = self.speed + if not speed: + # Not enough samples to guess + return None + recent_time, recent_progress = self.last_sample + remaining = 1.0 - recent_progress + if remaining <= 0: + # Complete + return 0 + # The bar is not complete, so we will extrapolate progress + # This will give us a countdown, even with no samples + time_since_sample = min(self.max_extrapolate, time - recent_time) + extrapolate_progress = speed * time_since_sample + # We don't want to extrapolate all the way to 0, as that would erroneously suggest it is finished + eta = max(1.0, (remaining - extrapolate_progress) / speed) + return ceil(eta) diff --git a/src/memray/_vendor/textual/events.py b/src/memray/_vendor/textual/events.py new file mode 100644 index 0000000000..7928d52cf0 --- /dev/null +++ b/src/memray/_vendor/textual/events.py @@ -0,0 +1,989 @@ +""" + +Builtin events sent by Textual. + +Events may be marked as "Bubbles" and "Verbose". +See the [events guide](/guide/events/#bubbling) for an explanation of bubbling. +Verbose events are excluded from the textual console, unless you explicitly request them with the `-v` switch as follows: + +``` +textual console -v +``` +""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Type, TypeVar + +import rich.repr +from rich.style import Style +from typing_extensions import Self + +from memray._vendor.textual._types import CallbackType +from memray._vendor.textual.geometry import Offset, Size +from memray._vendor.textual.keys import _get_key_aliases +from memray._vendor.textual.message import Message + +MouseEventT = TypeVar("MouseEventT", bound="MouseEvent") + +if TYPE_CHECKING: + from memray._vendor.textual.dom import DOMNode + from memray._vendor.textual.timer import Timer as TimerClass + from memray._vendor.textual.timer import TimerCallback + from memray._vendor.textual.widget import Widget + + +@rich.repr.auto +class Event(Message): + """The base class for all events.""" + + +@rich.repr.auto +class Callback(Event, bubble=False, verbose=True): + """Sent by Textual to invoke a callback + (see [call_next][textual.message_pump.MessagePump.call_next] and + [call_later][textual.message_pump.MessagePump.call_later]). + """ + + def __init__(self, callback: CallbackType) -> None: + self.callback = callback + super().__init__() + + def __rich_repr__(self) -> rich.repr.Result: + yield "callback", self.callback + + +@dataclass +class CursorPosition(Event, bubble=False): + """Internal event used to retrieve the terminal's cursor position.""" + + x: int + y: int + + +class Load(Event, bubble=False): + """ + Sent when the App is running but *before* the terminal is in application mode. + + Use this event to run any setup that doesn't require any visuals such as loading + configuration and binding keys. + + - [ ] Bubbles + - [ ] Verbose + """ + + +class Idle(Event, bubble=False): + """Sent when there are no more items in the message queue. + + This is a pseudo-event in that it is created by the Textual system and doesn't go + through the usual message queue. + + - [ ] Bubbles + - [ ] Verbose + """ + + +class Action(Event): + __slots__ = ["action"] + + def __init__(self, action: str) -> None: + super().__init__() + self.action = action + + def __rich_repr__(self) -> rich.repr.Result: + yield "action", self.action + + +class Resize(Event, bubble=False): + """Sent when the app or widget has been resized. + + - [ ] Bubbles + - [ ] Verbose + + Args: + size: The new size of the Widget. + virtual_size: The virtual size (scrollable size) of the Widget. + container_size: The size of the Widget's container widget. + """ + + __slots__ = ["size", "virtual_size", "container_size"] + + def __init__( + self, + size: Size, + virtual_size: Size, + container_size: Size | None = None, + pixel_size: Size | None = None, + ) -> None: + self.size = size + """The new size of the Widget.""" + self.virtual_size = virtual_size + """The virtual size (scrollable size) of the Widget.""" + self.container_size = size if container_size is None else container_size + """The size of the Widget's container widget.""" + self.pixel_size = pixel_size + """Size of terminal window in pixels if known, or `None` if not known.""" + super().__init__() + + @classmethod + def from_dimensions( + cls, cells: tuple[int, int], pixels: tuple[int, int] | None + ) -> Resize: + """Construct from basic dimensions. + + Args: + cells: tuple of (, ) in cells. + pixels: tuple of (, ) in pixels if known, or `None` if not known. + + """ + size = Size(*cells) + pixel_size = Size(*pixels) if pixels is not None else None + return Resize(size, size, size, pixel_size) + + def can_replace(self, message: "Message") -> bool: + return isinstance(message, Resize) + + def __rich_repr__(self) -> rich.repr.Result: + yield "size", self.size + yield "virtual_size", self.virtual_size, self.size + yield "container_size", self.container_size, self.size + yield "pixel_size", self.pixel_size, None + + +class Compose(Event, bubble=False, verbose=True): + """Sent to a widget to request it to compose and mount children. + + This event is used internally by Textual. + You won't typically need to explicitly handle it, + + - [ ] Bubbles + - [X] Verbose + """ + + +class Mount(Event, bubble=False, verbose=False): + """Sent when a widget is *mounted* and may receive messages. + + - [ ] Bubbles + - [ ] Verbose + """ + + +class Unmount(Event, bubble=False, verbose=False): + """Sent when a widget is unmounted and may no longer receive messages. + + - [ ] Bubbles + - [ ] Verbose + """ + + +class Show(Event, bubble=False): + """Sent when a widget is first displayed. + + - [ ] Bubbles + - [ ] Verbose + """ + + +class Hide(Event, bubble=False): + """Sent when a widget has been hidden. + + - [ ] Bubbles + - [ ] Verbose + + Sent when any of the following conditions apply: + + - The widget is removed from the DOM. + - The widget is no longer displayed because it has been scrolled or clipped from the terminal or its container. + - The widget has its `display` attribute set to `False`. + - The widget's `display` style is set to `"none"`. + """ + + +class Ready(Event, bubble=False): + """Sent to the `App` when the DOM is ready and the first frame has been displayed. + + - [ ] Bubbles + - [ ] Verbose + """ + + +@rich.repr.auto +class MouseCapture(Event, bubble=False): + """Sent when the mouse has been captured. + + - [ ] Bubbles + - [ ] Verbose + + When a mouse has been captured, all further mouse events will be sent to the capturing widget. + + Args: + mouse_position: The position of the mouse when captured. + """ + + def __init__(self, mouse_position: Offset) -> None: + super().__init__() + self.mouse_position = mouse_position + """The position of the mouse when captured.""" + + def __rich_repr__(self) -> rich.repr.Result: + yield None, self.mouse_position + + +@rich.repr.auto +class MouseRelease(Event, bubble=False): + """Mouse has been released. + + - [ ] Bubbles + - [ ] Verbose + + Args: + mouse_position: The position of the mouse when released. + """ + + def __init__(self, mouse_position: Offset) -> None: + super().__init__() + self.mouse_position = mouse_position + """The position of the mouse when released.""" + + def __rich_repr__(self) -> rich.repr.Result: + yield None, self.mouse_position + + +class InputEvent(Event): + """Base class for input events.""" + + +@rich.repr.auto +class Key(InputEvent): + """Sent when the user hits a key on the keyboard. + + - [X] Bubbles + - [ ] Verbose + + Args: + key: The key that was pressed. + character: A printable character or `None` if it is not printable. + """ + + __slots__ = ["key", "character", "aliases"] + + def __init__(self, key: str, character: str | None) -> None: + super().__init__() + self.key = key + """The key that was pressed.""" + self.character = ( + (key if len(key) == 1 else None) if character is None else character + ) + """A printable character or ``None`` if it is not printable.""" + self.aliases: list[str] = _get_key_aliases(key) + """The aliases for the key, including the key itself.""" + + def __rich_repr__(self) -> rich.repr.Result: + yield "key", self.key + yield "character", self.character + yield "name", self.name + yield "is_printable", self.is_printable + yield "aliases", self.aliases, [self.key] + + @property + def name(self) -> str: + """Name of a key suitable for use as a Python identifier.""" + return _key_to_identifier(self.key).lower() + + @property + def name_aliases(self) -> list[str]: + """The corresponding name for every alias in `aliases` list.""" + return [_key_to_identifier(key) for key in self.aliases] + + @property + def is_printable(self) -> bool: + """Check if the key is printable (produces a unicode character). + + Returns: + `True` if the key is printable. + """ + return False if self.character is None else self.character.isprintable() + + +def _key_to_identifier(key: str) -> str: + """Convert the key string to a name suitable for use as a Python identifier.""" + key_no_modifiers = key.split("+")[-1] + if len(key_no_modifiers) == 1 and key_no_modifiers.isupper(): + if "+" in key: + key = f"{key.rpartition('+')[0]}+upper_{key_no_modifiers}" + else: + key = f"upper_{key_no_modifiers}" + return key.replace("+", "_").lower() + + +@rich.repr.auto +class MouseEvent(InputEvent, bubble=True): + """Sent in response to a mouse event. + + - [X] Bubbles + - [ ] Verbose + + Args: + widget: The widget under the mouse. + x: The relative x coordinate. + y: The relative y coordinate. + delta_x: Change in x since the last message. + delta_y: Change in y since the last message. + button: Indexed of the pressed button. + shift: True if the shift key is pressed. + meta: True if the meta key is pressed. + ctrl: True if the ctrl key is pressed. + screen_x: The absolute x coordinate. + screen_y: The absolute y coordinate. + style: The Rich Style under the mouse cursor. + """ + + __slots__ = [ + "widget", + "_x", + "_y", + "_delta_x", + "_delta_y", + "button", + "shift", + "meta", + "ctrl", + "_screen_x", + "_screen_y", + "_style", + ] + + def __init__( + self, + widget: Widget | None, + x: float, + y: float, + delta_x: int, + delta_y: int, + button: int, + shift: bool, + meta: bool, + ctrl: bool, + screen_x: float | None = None, + screen_y: float | None = None, + style: Style | None = None, + ) -> None: + super().__init__() + self.widget: Widget | None = widget + """The widget under the mouse at the time of a click.""" + self._x = x + """The relative x coordinate.""" + self._y = y + """The relative y coordinate.""" + self._delta_x = delta_x + """Change in x since the last message.""" + self._delta_y = delta_y + """Change in y since the last message.""" + self.button = button + """Indexed of the pressed button.""" + self.shift = shift + """`True` if the shift key is pressed.""" + self.meta = meta + """`True` if the meta key is pressed.""" + self.ctrl = ctrl + """`True` if the ctrl key is pressed.""" + self._screen_x = x if screen_x is None else screen_x + """The absolute x coordinate.""" + self._screen_y = y if screen_y is None else screen_y + """The absolute y coordinate.""" + self._style = style or Style() + + @property + def x(self) -> int: + """The relative X coordinate of the cell under the mouse.""" + return int(self._x) + + @property + def y(self) -> int: + """The relative Y coordinate of the cell under the mouse.""" + return int(self._y) + + @property + def delta_x(self) -> int: + """Change in `x` since last message.""" + return self._delta_x + + @property + def delta_y(self) -> int: + """Change in `y` since the last message.""" + return self._delta_y + + @property + def screen_x(self) -> int: + """X coordinate of the cell relative to top left of screen.""" + return int(self._screen_x) + + @property + def screen_y(self) -> int: + """Y coordinate of the cell relative to top left of screen.""" + return int(self._screen_y) + + @property + def pointer_x(self) -> float: + """The relative X coordinate of the pointer.""" + return self._x + + @property + def pointer_y(self) -> float: + """The relative Y coordinate of the pointer.""" + return self._y + + @property + def pointer_screen_x(self) -> float: + """The X coordinate of the pointer relative to the screen.""" + return self._screen_x + + @property + def pointer_screen_y(self) -> float: + """The Y coordinate of the pointer relative to the screen.""" + return self._screen_y + + @classmethod + def from_event( + cls: Type[MouseEventT], widget: Widget, event: MouseEvent + ) -> MouseEventT: + new_event = cls( + widget, + event._x, + event._y, + event._delta_x, + event._delta_y, + event.button, + event.shift, + event.meta, + event.ctrl, + event._screen_x, + event._screen_y, + event._style, + ) + return new_event + + def __rich_repr__(self) -> rich.repr.Result: + yield self.widget + yield "x", self.x + yield "y", self.y + yield "pointer_x", self.pointer_x + yield "pointer_y", self.pointer_y + yield "delta_x", self.delta_x, 0 + yield "delta_y", self.delta_y, 0 + if self.screen_x != self.x: + yield "screen_x", self._screen_x + if self.screen_y != self.y: + yield "screen_y", self._screen_y + yield "button", self.button, 0 + yield "shift", self.shift, False + yield "meta", self.meta, False + yield "ctrl", self.ctrl, False + if self.style: + yield "style", self.style + + @property + def control(self) -> Widget | None: + return self.widget + + @property + def offset(self) -> Offset: + """The mouse coordinate as an offset. + + Returns: + Mouse coordinate. + """ + return Offset(self.x, self.y) + + @property + def screen_offset(self) -> Offset: + """Mouse coordinate relative to the screen.""" + return Offset(self.screen_x, self.screen_y) + + @property + def delta(self) -> Offset: + """Mouse coordinate delta (change since last event).""" + return Offset(self.delta_x, self.delta_y) + + @property + def style(self) -> Style: + """The (Rich) Style under the cursor.""" + return self._style or Style() + + @style.setter + def style(self, style: Style) -> None: + self._style = style + + def get_content_offset(self, widget: Widget) -> Offset | None: + """Get offset within a widget's content area, or None if offset is not in content (i.e. padding or border). + + Args: + widget: Widget receiving the event. + + Returns: + An offset where the origin is at the top left of the content area. + """ + if self.screen_offset not in widget.content_region: + return None + return self.get_content_offset_capture(widget) + + def get_content_offset_capture(self, widget: Widget) -> Offset: + """Get offset from a widget's content area. + + This method works even if the offset is outside the widget content region. + + Args: + widget: Widget receiving the event. + + Returns: + An offset where the origin is at the top left of the content area. + """ + return self.offset - widget.gutter.top_left + + def _apply_offset(self, x: int, y: int) -> MouseEvent: + return self.__class__( + self.widget, + x=self._x + x, + y=self._y + y, + delta_x=self._delta_x, + delta_y=self._delta_y, + button=self.button, + shift=self.shift, + meta=self.meta, + ctrl=self.ctrl, + screen_x=self._screen_x, + screen_y=self._screen_y, + style=self.style, + ) + + +@rich.repr.auto +class MouseMove(MouseEvent, bubble=True, verbose=True): + """Sent when the mouse cursor moves. + + - [X] Bubbles + - [X] Verbose + """ + + +@rich.repr.auto +class MouseDown(MouseEvent, bubble=True, verbose=True): + """Sent when a mouse button is pressed. + + - [X] Bubbles + - [X] Verbose + """ + + +@rich.repr.auto +class MouseUp(MouseEvent, bubble=True, verbose=True): + """Sent when a mouse button is released. + + - [X] Bubbles + - [X] Verbose + """ + + +@rich.repr.auto +class MouseScrollDown(MouseEvent, bubble=True, verbose=True): + """Sent when the mouse wheel is scrolled *down*. + + - [X] Bubbles + - [X] Verbose + """ + + +@rich.repr.auto +class MouseScrollUp(MouseEvent, bubble=True, verbose=True): + """Sent when the mouse wheel is scrolled *up*. + + - [X] Bubbles + - [X] Verbose + """ + + +@rich.repr.auto +class MouseScrollRight(MouseEvent, bubble=True, verbose=True): + """Sent when the mouse wheel is scrolled *right*. + + - [X] Bubbles + - [X] Verbose + """ + + +@rich.repr.auto +class MouseScrollLeft(MouseEvent, bubble=True, verbose=True): + """Sent when the mouse wheel is scrolled *left*. + + - [X] Bubbles + - [X] Verbose + """ + + +class Click(MouseEvent, bubble=True): + """Sent when a widget is clicked. + + - [X] Bubbles + - [ ] Verbose + + Args: + chain: The number of clicks in the chain. 2 is a double click, 3 is a triple click, etc. + """ + + def __init__( + self, + widget: Widget | None, + x: int, + y: int, + delta_x: int, + delta_y: int, + button: int, + shift: bool, + meta: bool, + ctrl: bool, + screen_x: int | None = None, + screen_y: int | None = None, + style: Style | None = None, + chain: int = 1, + ) -> None: + super().__init__( + widget, + x, + y, + delta_x, + delta_y, + button, + shift, + meta, + ctrl, + screen_x, + screen_y, + style, + ) + self.chain = chain + + @classmethod + def from_event( + cls: Type[Self], + widget: Widget, + event: MouseEvent, + chain: int = 1, + ) -> Self: + new_event = cls( + widget, + event.x, + event.y, + event.delta_x, + event.delta_y, + event.button, + event.shift, + event.meta, + event.ctrl, + event.screen_x, + event.screen_y, + event._style, + chain=chain, + ) + return new_event + + def _apply_offset(self, x: int, y: int) -> Self: + return self.__class__( + self.widget, + x=self.x + x, + y=self.y + y, + delta_x=self.delta_x, + delta_y=self.delta_y, + button=self.button, + shift=self.shift, + meta=self.meta, + ctrl=self.ctrl, + screen_x=self.screen_x, + screen_y=self.screen_y, + style=self.style, + chain=self.chain, + ) + + def __rich_repr__(self) -> rich.repr.Result: + yield from super().__rich_repr__() + yield "chain", self.chain + + +@rich.repr.auto +class Timer(Event, bubble=False, verbose=True): + """Sent in response to a timer. + + - [ ] Bubbles + - [X] Verbose + """ + + __slots__ = ["timer", "time", "count", "callback"] + + def __init__( + self, + timer: "TimerClass", + time: float, + count: int = 0, + callback: TimerCallback | None = None, + ) -> None: + super().__init__() + self.timer = timer + self.time = time + self.count = count + self.callback = callback + + def __rich_repr__(self) -> rich.repr.Result: + yield self.timer.name + yield "count", self.count + + +class Enter(Event, bubble=True, verbose=True): + """Sent when the mouse is moved over a widget. + + Note that this event bubbles, so a widget may receive this event when the mouse + moves over a child widget. Check the `node` attribute for the widget directly under + the mouse. + + - [X] Bubbles + - [X] Verbose + """ + + __slots__ = ["node"] + + def __init__(self, node: DOMNode) -> None: + self.node = node + """The node directly under the mouse.""" + super().__init__() + + @property + def control(self) -> DOMNode: + """Alias for the `node` under the mouse.""" + return self.node + + +class Leave(Event, bubble=True, verbose=True): + """Sent when the mouse is moved away from a widget, or if a widget is + programmatically disabled while hovered. + + Note that this widget bubbles, so a widget may receive Leave events for any child widgets. + Check the `node` parameter for the original widget that was previously under the mouse. + + + - [X] Bubbles + - [X] Verbose + """ + + __slots__ = ["node"] + + def __init__(self, node: DOMNode) -> None: + self.node = node + """The node that was previously directly under the mouse.""" + super().__init__() + + @property + def control(self) -> DOMNode: + """Alias for the `node` that was previously under the mouse.""" + return self.node + + +class Focus(Event, bubble=False): + """Sent when a widget is focussed. + + - [ ] Bubbles + - [ ] Verbose + + Args: + from_app_focus: True if this focus event has been sent because the app itself has + regained focus (via an AppFocus event). False if the focus came from within + the Textual app (e.g. via the user pressing tab or a programmatic setting + of the focused widget). + """ + + def __init__(self, from_app_focus: bool = False) -> None: + self.from_app_focus = from_app_focus + super().__init__() + + def __rich_repr__(self) -> rich.repr.Result: + yield from super().__rich_repr__() + yield "from_app_focus", self.from_app_focus + + +class Blur(Event, bubble=False): + """Sent when a widget is blurred (un-focussed). + + - [ ] Bubbles + - [ ] Verbose + """ + + +class AppFocus(Event, bubble=False): + """Sent when the app has focus. + + - [ ] Bubbles + - [ ] Verbose + + Note: + Only available when running within a terminal that supports + `FocusIn`, or when running via textual-web. + """ + + +class AppBlur(Event, bubble=False): + """Sent when the app loses focus. + + - [ ] Bubbles + - [ ] Verbose + + Note: + Only available when running within a terminal that supports + `FocusOut`, or when running via textual-web. + """ + + +@dataclass +class DescendantFocus(Event, bubble=True, verbose=True): + """Sent when a child widget is focussed. + + - [X] Bubbles + - [X] Verbose + """ + + widget: Widget + """The widget that was focused.""" + + @property + def control(self) -> Widget: + """The widget that was focused (alias of `widget`).""" + return self.widget + + +@dataclass +class DescendantBlur(Event, bubble=True, verbose=True): + """Sent when a child widget is blurred. + + - [X] Bubbles + - [X] Verbose + """ + + widget: Widget + """The widget that was blurred.""" + + @property + def control(self) -> Widget: + """The widget that was blurred (alias of `widget`).""" + return self.widget + + +@rich.repr.auto +class Paste(Event, bubble=True): + """Event containing text that was pasted into the Textual application. + This event will only appear when running in a terminal emulator that supports + bracketed paste mode. Textual will enable bracketed pastes when an app starts, + and disable it when the app shuts down. + + - [X] Bubbles + - [ ] Verbose + + + Args: + text: The text that has been pasted. + """ + + def __init__(self, text: str) -> None: + super().__init__() + self.text = text + """The text that was pasted.""" + + def __rich_repr__(self) -> rich.repr.Result: + yield "text", self.text + + +@dataclass +class ScreenResume(Event, bubble=False): + """Sent to screen that has been made active. + + - [ ] Bubbles + - [ ] Verbose + """ + + refresh_styles: bool = True + """Should the resuming screen refresh its styles?""" + + def __rich_repr__(self) -> rich.repr.Result: + yield self.refresh_styles + + +class ScreenSuspend(Event, bubble=False): + """Sent to screen when it is no longer active. + + - [ ] Bubbles + - [ ] Verbose + """ + + +@rich.repr.auto +class Print(Event, bubble=False): + """Sent to a widget that is capturing [`print`][print]. + + - [ ] Bubbles + - [ ] Verbose + + Args: + text: Text that was printed. + stderr: `True` if the print was to stderr, or `False` for stdout. + + Note: + Python's [`print`][print] output can be captured with + [`App.begin_capture_print`][textual.app.App.begin_capture_print]. + """ + + def __init__(self, text: str, stderr: bool = False) -> None: + super().__init__() + self.text = text + """The text that was printed.""" + self.stderr = stderr + """`True` if the print was to stderr, or `False` for stdout.""" + + def __rich_repr__(self) -> rich.repr.Result: + yield self.text + yield self.stderr + + +@dataclass +class DeliveryComplete(Event, bubble=False): + """Sent to App when a file has been delivered.""" + + key: str + """The delivery key associated with the delivery. + + This is the same key that was returned by `App.deliver_text`/`App.deliver_binary`. + """ + + path: Path | None = None + """The path where the file was saved, or `None` if the path is not available, for + example if the file was delivered via web browser. + """ + + name: str | None = None + """Optional name returned to the app to identify the download.""" + + +@dataclass +class DeliveryFailed(Event, bubble=False): + """Sent to App when a file delivery fails.""" + + key: str + """The delivery key associated with the delivery.""" + + exception: BaseException + """The exception that was raised during the delivery.""" + + name: str | None = None + """Optional name returned to the app to identify the download.""" + + +class TextSelected(Event, bubble=True): + """Sent from the screen when text is selected (Not Input and TextArea)""" diff --git a/src/memray/_vendor/textual/expand_tabs.py b/src/memray/_vendor/textual/expand_tabs.py new file mode 100644 index 0000000000..a019aa11c0 --- /dev/null +++ b/src/memray/_vendor/textual/expand_tabs.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import re + +from rich.cells import cell_len +from rich.text import Text + +_TABS_SPLITTER_RE = re.compile(r"(.*?\t|.+?$)") + + +def get_tab_widths(line: str, tab_size: int = 4) -> list[tuple[str, int]]: + """Splits a string line into tuples (str, int). + + Each tuple represents a section of the line which precedes a tab character. + The string is the string text that appears before the tab character (excluding the tab). + The integer is the width that the tab character is expanded to. + + Args: + line: The text to expand tabs in. + tab_size: Number of cells in a tab. + + Returns: + A list of tuples representing the line split on tab characters, + and the widths of the tabs after tab expansion is applied. + """ + + parts: list[tuple[str, int]] = [] + add_part = parts.append + cell_position = 0 + matches = _TABS_SPLITTER_RE.findall(line) + + for match in matches: + expansion_width = 0 + if match.endswith("\t"): + # Remove the tab, and check the width of the rest of the line. + match = match[:-1] + cell_position += cell_len(match) + + # Now move along the line by the width of the tab. + tab_remainder = cell_position % tab_size + expansion_width = tab_size - tab_remainder + cell_position += expansion_width + + add_part((match, expansion_width)) + + return parts + + +def expand_tabs_inline(line: str, tab_size: int = 4) -> str: + """Expands tabs, taking into account double cell characters. + + Args: + line: The text to expand tabs in. + tab_size: Number of cells in a tab. + Returns: + New string with tabs replaced with spaces. + """ + tab_widths = get_tab_widths(line, tab_size) + return "".join( + [part + expansion_width * " " for part, expansion_width in tab_widths] + ) + + +def expand_text_tabs_from_widths(line: Text, tab_widths: list[int]) -> Text: + """Expand tabs to the widths defined in the `tab_widths` list. + + This will return a new Text instance with tab characters expanded into a + number of spaces. Each time a tab is encountered, it's expanded into the + next integer encountered in the `tab_widths` list. Consequently, the length + of `tab_widths` should match the number of tab characters in `line`. + + Args: + line: The `Text` instance to expand tabs in. + tab_widths: The widths to expand tabs to. + + Returns: + A new text instance with tab characters converted to spaces. + """ + if "\t" not in line.plain: + return line + + parts = line.split("\t", include_separator=True) + tab_widths_iter = iter(tab_widths) + + new_parts: list[Text] = [] + append_part = new_parts.append + for part in parts: + if part.plain.endswith("\t"): + part._text[-1] = part._text[-1][:-1] + " " + spaces = next(tab_widths_iter) + part.extend_style(spaces - 1) + append_part(part) + + return Text("", end="").join(new_parts) + + +if __name__ == "__main__": + print(expand_tabs_inline("\tbar")) + print(expand_tabs_inline("\tbar\t")) + print(expand_tabs_inline("1\tbar")) + print(expand_tabs_inline("12\tbar")) + print(expand_tabs_inline("123\tbar")) + print(expand_tabs_inline("1234\tbar")) + print(expand_tabs_inline("💩\tbar")) + print(expand_tabs_inline("💩💩\tbar")) + print(expand_tabs_inline("💩💩💩\tbar")) + print(expand_tabs_inline("F💩\tbar")) + print(expand_tabs_inline("F💩O\tbar")) diff --git a/src/memray/_vendor/textual/features.py b/src/memray/_vendor/textual/features.py new file mode 100644 index 0000000000..9799363cd2 --- /dev/null +++ b/src/memray/_vendor/textual/features.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + +from typing_extensions import Literal + +if TYPE_CHECKING: + from typing_extensions import Final + +FEATURES: Final = {"devtools", "debug", "headless"} + +FeatureFlag = Literal["devtools", "debug", "headless"] + + +def parse_features(features: str) -> frozenset[FeatureFlag]: + """Parse features env var + + Args: + features: Comma separated feature flags + + Returns: + A frozen set of known features. + """ + + features_set = frozenset( + feature.strip().lower() for feature in features.split(",") if feature.strip() + ).intersection(FEATURES) + + return cast("frozenset[FeatureFlag]", features_set) diff --git a/src/memray/_vendor/textual/file_monitor.py b/src/memray/_vendor/textual/file_monitor.py new file mode 100644 index 0000000000..cd2b6d7821 --- /dev/null +++ b/src/memray/_vendor/textual/file_monitor.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import os +from pathlib import Path +from typing import Callable, Iterable, Sequence + +import rich.repr + +from memray._vendor.textual._callback import invoke + + +@rich.repr.auto +class FileMonitor: + """Monitors files for changes and invokes a callback when it does.""" + + _paths: set[Path] + + def __init__(self, paths: Sequence[Path], callback: Callable[[], None]) -> None: + """Monitor the given file paths for changes. + + Args: + paths: Paths to monitor. + callback: Callback to invoke if any of the paths change. + """ + self._paths = set(paths) + self.callback = callback + self._modified = self._get_last_modified_time() + + def __rich_repr__(self) -> rich.repr.Result: + yield self._paths + + def _get_last_modified_time(self) -> float: + """Get the most recent modified time out of all files being watched.""" + modified_times = [] + for path in self._paths: + try: + modified_time = os.stat(path).st_mtime + except FileNotFoundError: + modified_time = 0 + modified_times.append(modified_time) + return max(modified_times, default=0) + + def check(self) -> bool: + """Check the monitored files. Return True if any were changed since the last modification time.""" + modified = self._get_last_modified_time() + changed = modified != self._modified + self._modified = modified + return changed + + def add_paths(self, paths: Iterable[Path]) -> None: + """Adds paths to start being monitored. + + Args: + paths: The paths to be monitored. + """ + self._paths.update(paths) + + async def __call__(self) -> None: + if self.check(): + await self.on_change() + + async def on_change(self) -> None: + """Called when any of the monitored files change.""" + await invoke(self.callback) diff --git a/src/memray/_vendor/textual/filter.py b/src/memray/_vendor/textual/filter.py new file mode 100644 index 0000000000..db77d37c48 --- /dev/null +++ b/src/memray/_vendor/textual/filter.py @@ -0,0 +1,288 @@ +"""Filter classes. + +!!! note + + Filters are used internally, and not recommended for use by Textual app developers. + +Filters are used internally to process terminal output after it has been rendered. +Currently this is used internally to convert the application to monochrome, when the NO_COLOR env var is set. + +In the future, this system will be used to implement accessibility features. + +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from functools import lru_cache + +from rich.color import Color as RichColor +from rich.segment import Segment +from rich.style import Style +from rich.terminal_theme import TerminalTheme + +from memray._vendor.textual.color import Color +from memray._vendor.textual.constants import DIM_FACTOR + + +class LineFilter(ABC): + """Base class for a line filter.""" + + def __init__(self, enabled: bool = True) -> None: + """ + + Args: + enabled: If `enabled` is `False` then the filter will not be applied. + """ + self.enabled = enabled + + @abstractmethod + def apply(self, segments: list[Segment], background: Color) -> list[Segment]: + """Transform a list of segments. + + Args: + segments: A list of segments. + background: The background color. + + Returns: + A new list of segments. + """ + + +@lru_cache(1024) +def monochrome_style(style: Style) -> Style: + """Convert colors in a style to monochrome. + + Args: + style: A Rich Style. + + Returns: + A new Rich style. + """ + style_color = style.color + style_background = style.bgcolor + color = ( + None + if style_color is None + else Color.from_rich_color(style_color).monochrome.rich_color + ) + background = ( + None + if style_background is None + else Color.from_rich_color(style_background).monochrome.rich_color + ) + return style + Style.from_color(color, background) + + +class Monochrome(LineFilter): + """Convert all colors to monochrome.""" + + def apply(self, segments: list[Segment], background: Color) -> list[Segment]: + """Transform a list of segments. + + Args: + segments: A list of segments. + background: The background color. + + Returns: + A new list of segments. + """ + _monochrome_style = monochrome_style + _Segment = Segment + return [ + _Segment(text, _monochrome_style(style), None) + for text, style, _ in segments + ] + + +class NoColor(LineFilter): + """Remove all color information from segments.""" + + DEFAULT_COLORS = Style.from_color( + RichColor.parse("default"), RichColor.parse("default") + ) + + def apply(self, segments: list[Segment], background: Color) -> list[Segment]: + """Transform a list of segments. + + Args: + segments: A list of segments. + background: The background color. + + Returns: + A new list of segments. + """ + + _Segment = Segment + default_colors = self.DEFAULT_COLORS + return [ + _Segment(text, None if style is None else (style + default_colors), control) + for text, style, control in segments + ] + + +NO_DIM = Style(dim=False) +"""A Style to set dim to False.""" + + +@lru_cache(1024) +def dim_color( + background: RichColor, color: RichColor, factor: float = DIM_FACTOR +) -> RichColor: + """Dim a color by blending towards the background + + Args: + background: background color. + color: Foreground color. + factor: Blend factor + + Returns: + New dimmer color. + """ + red1, green1, blue1 = background.triplet + red2, green2, blue2 = color.triplet + + return RichColor.from_rgb( + red1 + (red2 - red1) * factor, + green1 + (green2 - green1) * factor, + blue1 + (blue2 - blue1) * factor, + ) + + +DEFAULT_COLOR = RichColor.default() + + +@lru_cache(1024) +def dim_style(style: Style, background: Color, factor: float) -> Style: + """Replace dim attribute with a dim color. + + Args: + style: Style to dim. + factor: Blend factor. + + Returns: + New dimmed style. + """ + return ( + style + + Style.from_color( + dim_color( + (background.rich_color if style.bgcolor.is_default else style.bgcolor), + style.color, + factor, + ), + None, + ) + ) + NO_DIM + + +# Can be used as a workaround for https://github.com/xtermjs/xterm.js/issues/4161 +class DimFilter(LineFilter): + """Replace dim attributes with modified colors.""" + + def __init__(self, dim_factor: float = 0.5, enabled: bool = True) -> None: + """Initialize the filter. + + Args: + dim_factor: The factor to dim by; 0 is 100% background (i.e. invisible), 1.0 is no change. + """ + self.dim_factor = dim_factor + super().__init__(enabled=enabled) + + def apply(self, segments: list[Segment], background: Color) -> list[Segment]: + """Transform a list of segments. + + Args: + segments: A list of segments. + background: The background color. + + Returns: + A new list of segments. + """ + _Segment = Segment + _dim_style = dim_style + factor = self.dim_factor + return [ + ( + _Segment( + segment.text, + _dim_style(segment.style, background, factor), + None, + ) + if segment.style is not None and segment.style.dim + else segment + ) + for segment in segments + ] + + +class ANSIToTruecolor(LineFilter): + """Convert ANSI colors to their truecolor equivalents.""" + + def __init__(self, terminal_theme: TerminalTheme, enabled: bool = True): + """Initialise filter. + + Args: + terminal_theme: A rich terminal theme. + """ + self._terminal_theme = terminal_theme + super().__init__(enabled=enabled) + + @lru_cache(1024) + def truecolor_style(self, style: Style, background: RichColor) -> Style: + """Replace system colors with truecolor equivalent. + + Args: + style: Style to apply truecolor filter to. + + Returns: + New style. + """ + terminal_theme = self._terminal_theme + + changed = False + if (color := style.color) is not None: + if color.triplet is None: + color = RichColor.from_triplet( + color.get_truecolor(terminal_theme, foreground=True) + ) + changed = True + + if (bgcolor := style.bgcolor) is not None and bgcolor.triplet is None: + bgcolor = RichColor.from_triplet( + bgcolor.get_truecolor(terminal_theme, foreground=False) + ) + changed = True + + if style.dim and color is not None: + color = dim_color(background if bgcolor is None else bgcolor, color) + style += NO_DIM + changed = True + + return style + Style.from_color(color, bgcolor) if changed else style + + def apply(self, segments: list[Segment], background: Color) -> list[Segment]: + """Transform a list of segments. + + Args: + segments: A list of segments. + background: The background color. + + Returns: + A new list of segments. + """ + _Segment = Segment + truecolor_style = self.truecolor_style + background_rich_color = background.rich_color + return [ + _Segment( + text, + ( + None + if style is None + else truecolor_style(style, background_rich_color) + ), + None, + ) + for text, style, _ in segments + ] diff --git a/src/memray/_vendor/textual/fuzzy.py b/src/memray/_vendor/textual/fuzzy.py new file mode 100644 index 0000000000..bb563b6da0 --- /dev/null +++ b/src/memray/_vendor/textual/fuzzy.py @@ -0,0 +1,224 @@ +""" +Fuzzy matcher. + +This class is used by the [command palette](/guide/command_palette) to match search terms. + +""" + +from __future__ import annotations + +from functools import lru_cache +from operator import itemgetter +from re import finditer +from typing import Iterable, Sequence + +import rich.repr + +from memray._vendor.textual.cache import LRUCache +from memray._vendor.textual.content import Content +from memray._vendor.textual.visual import Style + + +class FuzzySearch: + """Performs a fuzzy search. + + Unlike a regex solution, this will finds all possible matches. + """ + + def __init__( + self, case_sensitive: bool = False, *, cache_size: int = 1024 * 4 + ) -> None: + """Initialize fuzzy search. + + Args: + case_sensitive: Is the match case sensitive? + cache_size: Number of queries to cache. + """ + + self.case_sensitive = case_sensitive + self.cache: LRUCache[tuple[str, str], tuple[float, Sequence[int]]] = LRUCache( + cache_size + ) + + def match(self, query: str, candidate: str) -> tuple[float, Sequence[int]]: + """Match against a query. + + Args: + query: The fuzzy query. + candidate: A candidate to check,. + + Returns: + A pair of (score, tuple of offsets). `(0, ())` for no result. + """ + + cache_key = (query, candidate) + if cache_key in self.cache: + return self.cache[cache_key] + default: tuple[float, Sequence[int]] = (0.0, []) + result = max(self._match(query, candidate), key=itemgetter(0), default=default) + self.cache[cache_key] = result + return result + + @classmethod + @lru_cache(maxsize=1024) + def get_first_letters(cls, candidate: str) -> frozenset[int]: + return frozenset({match.start() for match in finditer(r"\w+", candidate)}) + + def score(self, candidate: str, positions: Sequence[int]) -> float: + """Score a search. + + Args: + search: Search object. + + Returns: + Score. + """ + first_letters = self.get_first_letters(candidate) + # This is a heuristic, and can be tweaked for better results + # Boost first letter matches + offset_count = len(positions) + score: float = offset_count + len(first_letters.intersection(positions)) + + groups = 1 + last_offset, *offsets = positions + for offset in offsets: + if offset != last_offset + 1: + groups += 1 + last_offset = offset + + # Boost to favor less groups + normalized_groups = (offset_count - (groups - 1)) / offset_count + score *= 1 + (normalized_groups * normalized_groups) + return score + + def _match( + self, query: str, candidate: str + ) -> Iterable[tuple[float, Sequence[int]]]: + letter_positions: list[list[int]] = [] + position = 0 + + if not self.case_sensitive: + candidate = candidate.lower() + query = query.lower() + score = self.score + if query in candidate: + # Quick exit when the query exists as a substring + query_location = candidate.find(query) + offsets = list(range(query_location, query_location + len(query))) + yield ( + score(candidate, offsets) * (2.0 if candidate == query else 1.5), + offsets, + ) + return + + for offset, letter in enumerate(query): + last_index = len(candidate) - offset + positions: list[int] = [] + letter_positions.append(positions) + index = position + while (location := candidate.find(letter, index)) != -1: + positions.append(location) + index = location + 1 + if index >= last_index: + break + if not positions: + yield (0.0, ()) + return + position = positions[0] + 1 + + possible_offsets: list[list[int]] = [] + query_length = len(query) + + def get_offsets(offsets: list[int], positions_index: int) -> None: + """Recursively match offsets. + + Args: + offsets: A list of offsets. + positions_index: Index of query letter. + + """ + for offset in letter_positions[positions_index]: + if not offsets or offset > offsets[-1]: + new_offsets = [*offsets, offset] + if len(new_offsets) == query_length: + possible_offsets.append(new_offsets) + else: + get_offsets(new_offsets, positions_index + 1) + + get_offsets([], 0) + + for offsets in possible_offsets: + yield score(candidate, offsets), offsets + + +@rich.repr.auto +class Matcher: + """A fuzzy matcher.""" + + def __init__( + self, + query: str, + *, + match_style: Style | None = None, + case_sensitive: bool = False, + ) -> None: + """Initialise the fuzzy matching object. + + Args: + query: A query as typed in by the user. + match_style: The style to use to highlight matched portions of a string. + case_sensitive: Should matching be case sensitive? + """ + self._query = query + self._match_style = Style(reverse=True) if match_style is None else match_style + self._case_sensitive = case_sensitive + self.fuzzy_search = FuzzySearch() + + @property + def query(self) -> str: + """The query string to look for.""" + return self._query + + @property + def match_style(self) -> Style: + """The style that will be used to highlight hits in the matched text.""" + return self._match_style + + @property + def case_sensitive(self) -> bool: + """Is this matcher case sensitive?""" + return self._case_sensitive + + def match(self, candidate: str) -> float: + """Match the candidate against the query. + + Args: + candidate: Candidate string to match against the query. + + Returns: + Strength of the match from 0 to 1. + """ + return self.fuzzy_search.match(self.query, candidate)[0] + + def highlight(self, candidate: str) -> Content: + """Highlight the candidate with the fuzzy match. + + Args: + candidate: The candidate string to match against the query. + + Returns: + A [`Text`][rich.text.Text] object with highlighted matches. + """ + content = Content.from_markup(candidate) + score, offsets = self.fuzzy_search.match(self.query, candidate) + if not score: + return content + for offset in offsets: + if not candidate[offset].isspace(): + content = content.stylize(self._match_style, offset, offset + 1) + return content + + +if __name__ == "__main__": + fuzzy_search = FuzzySearch() + fuzzy_search.match("foo.bar", "foo/egg.bar") diff --git a/src/memray/_vendor/textual/geometry.py b/src/memray/_vendor/textual/geometry.py new file mode 100644 index 0000000000..7577d2cb43 --- /dev/null +++ b/src/memray/_vendor/textual/geometry.py @@ -0,0 +1,1487 @@ +""" + +Functions and classes to manage terminal geometry (anything involving coordinates or dimensions). +""" + +from __future__ import annotations + +import os +from functools import lru_cache +from operator import attrgetter, itemgetter +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Iterable, + Literal, + NamedTuple, + Tuple, + TypeVar, + Union, + cast, +) + +from typing_extensions import Final + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + +import rich.repr + +SpacingDimensions: TypeAlias = Union[ + int, Tuple[int], Tuple[int, int], Tuple[int, int, int, int] +] +"""The valid ways in which you can specify spacing.""" + +T = TypeVar("T", int, float) + + +def clamp(value: T, minimum: T, maximum: T) -> T: + """Restrict a value to a given range. + + If `value` is less than the minimum, return the minimum. + If `value` is greater than the maximum, return the maximum. + Otherwise, return `value`. + + The `minimum` and `maximum` arguments values may be given in reverse order. + + Args: + value: A value. + minimum: Minimum value. + maximum: Maximum value. + + Returns: + New value that is not less than the minimum or greater than the maximum. + """ + if minimum > maximum: + # It is common for the min and max to be in non-intuitive order. + # Rather than force the caller to get it right, it is simpler to handle it here. + if value < maximum: + return maximum + if value > minimum: + return minimum + return value + else: + if value < minimum: + return minimum + if value > maximum: + return maximum + return value + + +class Offset(NamedTuple): + """A cell offset defined by x and y coordinates. + + Offsets are typically relative to the top left of the terminal or other container. + + Textual prefers the names `x` and `y`, but you could consider `x` to be the _column_ and `y` to be the _row_. + + Offsets support addition, subtraction, multiplication, and negation. + + Example: + ```python + >>> from textual.geometry import Offset + >>> offset = Offset(3, 2) + >>> offset + Offset(x=3, y=2) + >>> offset += Offset(10, 0) + >>> offset + Offset(x=13, y=2) + >>> -offset + Offset(x=-13, y=-2) + ``` + """ + + x: int = 0 + """Offset in the x-axis (horizontal)""" + y: int = 0 + """Offset in the y-axis (vertical)""" + + @property + def is_origin(self) -> bool: + """Is the offset at (0, 0)?""" + return self == (0, 0) + + @property + def clamped(self) -> Offset: + """This offset with `x` and `y` restricted to values above zero.""" + x, y = self + return Offset(0 if x < 0 else x, 0 if y < 0 else y) + + @property + def transpose(self) -> tuple[int, int]: + """A tuple of x and y, in reverse order, i.e. (y, x).""" + x, y = self + return y, x + + def __bool__(self) -> bool: + return self != (0, 0) + + def __add__(self, other: object) -> Offset: + if isinstance(other, tuple): + _x, _y = self + x, y = other + return Offset(_x + x, _y + y) + return NotImplemented + + def __sub__(self, other: object) -> Offset: + if isinstance(other, tuple): + _x, _y = self + x, y = other + return Offset(_x - x, _y - y) + return NotImplemented + + def __mul__(self, other: object) -> Offset: + if isinstance(other, (float, int)): + x, y = self + return Offset(int(x * other), int(y * other)) + if isinstance(other, tuple): + x, y = self + return Offset(int(x * other[0]), int(y * other[1])) + return NotImplemented + + def __neg__(self) -> Offset: + x, y = self + return Offset(-x, -y) + + def blend(self, destination: Offset, factor: float) -> Offset: + """Calculate a new offset on a line between this offset and a destination offset. + + Args: + destination: Point where factor would be 1.0. + factor: A value between 0 and 1.0. + + Returns: + A new point on a line between self and destination. + """ + x1, y1 = self + x2, y2 = destination + return Offset( + int(x1 + (x2 - x1) * factor), + int(y1 + (y2 - y1) * factor), + ) + + def get_distance_to(self, other: Offset) -> float: + """Get the distance to another offset. + + Args: + other: An offset. + + Returns: + Distance to other offset. + """ + x1, y1 = self + x2, y2 = other + distance: float = ((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1)) ** 0.5 + return distance + + def clamp(self, width: int, height: int) -> Offset: + """Clamp the offset to fit within a rectangle of width x height. + + Args: + width: Width to clamp. + height: Height to clamp. + + Returns: + A new offset. + """ + x, y = self + return Offset(clamp(x, 0, width - 1), clamp(y, 0, height - 1)) + + +class Size(NamedTuple): + """The dimensions (width and height) of a rectangular region. + + Example: + ```python + >>> from textual.geometry import Size + >>> size = Size(2, 3) + >>> size + Size(width=2, height=3) + >>> size.area + 6 + >>> size + Size(10, 20) + Size(width=12, height=23) + ``` + """ + + width: int = 0 + """The width in cells.""" + + height: int = 0 + """The height in cells.""" + + def __bool__(self) -> bool: + """A Size is Falsy if it has area 0.""" + return self.width * self.height != 0 + + @property + def area(self) -> int: + """The area occupied by a region of this size.""" + return self.width * self.height + + @property + def region(self) -> Region: + """A region of the same size, at the origin.""" + width, height = self + return Region(0, 0, width, height) + + @property + def line_range(self) -> range: + """A range object that covers values between 0 and `height`.""" + return range(self.height) + + def with_width(self, width: int) -> Size: + """Get a new Size with just the width changed. + + Args: + width: New width. + + Returns: + New Size instance. + """ + return Size(width, self.height) + + def with_height(self, height: int) -> Size: + """Get a new Size with just the height changed. + + Args: + height: New height. + + Returns: + New Size instance. + """ + return Size(self.width, height) + + def __add__(self, other: object) -> Size: + if isinstance(other, tuple): + width, height = self + width2, height2 = other + return Size(max(0, width + width2), max(0, height + height2)) + return NotImplemented + + def __sub__(self, other: object) -> Size: + if isinstance(other, tuple): + width, height = self + width2, height2 = other + return Size(max(0, width - width2), max(0, height - height2)) + return NotImplemented + + def contains(self, x: int, y: int) -> bool: + """Check if a point is in area defined by the size. + + Args: + x: X coordinate. + y: Y coordinate. + + Returns: + True if the point is within the region. + """ + width, height = self + return width > x >= 0 and height > y >= 0 + + def contains_point(self, point: tuple[int, int]) -> bool: + """Check if a point is in the area defined by the size. + + Args: + point: A tuple of x and y coordinates. + + Returns: + True if the point is within the region. + """ + x, y = point + width, height = self + return width > x >= 0 and height > y >= 0 + + def __contains__(self, other: Any) -> bool: + try: + x: int + y: int + x, y = other + except Exception: + raise TypeError( + "Dimensions.__contains__ requires an iterable of two integers" + ) + width, height = self + return width > x >= 0 and height > y >= 0 + + def clamp_offset(self, offset: Offset) -> Offset: + """Clamp an offset to fit within the width x height. + + Args: + offset: An offset. + + Returns: + A new offset that will fit inside the dimensions defined in the Size. + """ + return offset.clamp(self.width, self.height) + + +class Region(NamedTuple): + """Defines a rectangular region. + + A Region consists of a coordinate (x and y) and dimensions (width and height). + + ``` + (x, y) + ┌────────────────────┐ ▲ + │ │ │ + │ │ │ + │ │ height + │ │ │ + │ │ │ + └────────────────────┘ ▼ + ◀─────── width ──────▶ + ``` + + Example: + ```python + >>> from textual.geometry import Region + >>> region = Region(4, 5, 20, 10) + >>> region + Region(x=4, y=5, width=20, height=10) + >>> region.area + 200 + >>> region.size + Size(width=20, height=10) + >>> region.offset + Offset(x=4, y=5) + >>> region.contains(1, 2) + False + >>> region.contains(10, 8) + True + ``` + """ + + x: int = 0 + """Offset in the x-axis (horizontal).""" + y: int = 0 + """Offset in the y-axis (vertical).""" + width: int = 0 + """The width of the region.""" + height: int = 0 + """The height of the region.""" + + @classmethod + def from_union(cls, regions: Collection[Region]) -> Region: + """Create a Region from the union of other regions. + + Args: + regions: One or more regions. + + Returns: + A Region that encloses all other regions. + """ + if not regions: + raise ValueError("At least one region expected") + min_x = min(regions, key=itemgetter(0)).x + max_x = max(regions, key=attrgetter("right")).right + min_y = min(regions, key=itemgetter(1)).y + max_y = max(regions, key=attrgetter("bottom")).bottom + return cls(min_x, min_y, max_x - min_x, max_y - min_y) + + @classmethod + def from_corners(cls, x1: int, y1: int, x2: int, y2: int) -> Region: + """Construct a Region form the top left and bottom right corners. + + Args: + x1: Top left x. + y1: Top left y. + x2: Bottom right x. + y2: Bottom right y. + + Returns: + A new region. + """ + return cls(x1, y1, x2 - x1, y2 - y1) + + @classmethod + def from_offset(cls, offset: tuple[int, int], size: tuple[int, int]) -> Region: + """Create a region from offset and size. + + Args: + offset: Offset (top left point). + size: Dimensions of region. + + Returns: + A region instance. + """ + x, y = offset + width, height = size + return cls(x, y, width, height) + + @classmethod + def get_scroll_to_visible( + cls, window_region: Region, region: Region, *, top: bool = False + ) -> Offset: + """Calculate the smallest offset required to translate a window so that it contains + another region. + + This method is used to calculate the required offset to scroll something into view. + + Args: + window_region: The window region. + region: The region to move inside the window. + top: Get offset to top of window. + + Returns: + An offset required to add to region to move it inside window_region. + """ + + if region in window_region and not top: + # Region is already inside the window, so no need to move it. + return NULL_OFFSET + + window_left, window_top, window_right, window_bottom = window_region.corners + region = region.crop_size(window_region.size) + left, top_, right, bottom = region.corners + delta_x = delta_y = 0 + + if not ( + (window_right > left >= window_left) + and (window_right > right >= window_left) + ): + # The region does not fit + # The window needs to scroll on the X axis to bring region into view + delta_x = min( + left - window_left, + left - (window_right - region.width), + key=abs, + ) + + if top: + delta_y = top_ - window_top + + elif not ( + (window_bottom > top_ >= window_top) + and (window_bottom > bottom >= window_top) + ): + # The window needs to scroll on the Y axis to bring region into view + delta_y = min( + top_ - window_top, + top_ - (window_bottom - region.height), + key=abs, + ) + return Offset(delta_x, delta_y) + + def __bool__(self) -> bool: + """A Region is considered False when it has no area.""" + _, _, width, height = self + return width * height > 0 + + @property + def column_span(self) -> tuple[int, int]: + """A pair of integers for the start and end columns (x coordinates) in this region. + + The end value is *exclusive*. + """ + return (self.x, self.x + self.width) + + @property + def line_span(self) -> tuple[int, int]: + """A pair of integers for the start and end lines (y coordinates) in this region. + + The end value is *exclusive*. + """ + return (self.y, self.y + self.height) + + @property + def right(self) -> int: + """Maximum X value (non inclusive).""" + return self.x + self.width + + @property + def bottom(self) -> int: + """Maximum Y value (non inclusive).""" + return self.y + self.height + + @property + def area(self) -> int: + """The area under the region.""" + return self.width * self.height + + @property + def offset(self) -> Offset: + """The top left corner of the region. + + Returns: + An offset. + """ + return Offset(*self[:2]) + + @property + def center(self) -> tuple[float, float]: + """The center of the region. + + Note, that this does *not* return an `Offset`, because the center may not be an integer coordinate. + + Returns: + Tuple of floats. + """ + x, y, width, height = self + return (x + width / 2.0, y + height / 2.0) + + @property + def bottom_left(self) -> Offset: + """Bottom left offset of the region. + + Returns: + An offset. + """ + x, y, _width, height = self + return Offset(x, y + height) + + @property + def top_right(self) -> Offset: + """Top right offset of the region. + + Returns: + An offset. + """ + x, y, width, _height = self + return Offset(x + width, y) + + @property + def bottom_right(self) -> Offset: + """Bottom right offset of the region. + + Returns: + An offset. + """ + x, y, width, height = self + return Offset(x + width, y + height) + + @property + def bottom_right_inclusive(self) -> Offset: + """Bottom right corner of the region, within its boundaries.""" + x, y, width, height = self + return Offset(x + width - 1, y + height - 1) + + @property + def size(self) -> Size: + """Get the size of the region.""" + return Size(*self[2:]) + + @property + def corners(self) -> tuple[int, int, int, int]: + """The top left and bottom right coordinates as a tuple of four integers.""" + x, y, width, height = self + return x, y, x + width, y + height + + @property + def column_range(self) -> range: + """A range object for X coordinates.""" + return range(self.x, self.x + self.width) + + @property + def line_range(self) -> range: + """A range object for Y coordinates.""" + return range(self.y, self.y + self.height) + + @property + def reset_offset(self) -> Region: + """An region of the same size at (0, 0). + + Returns: + A region at the origin. + """ + _, _, width, height = self + return Region(0, 0, width, height) + + def __add__(self, other: object) -> Region: + if isinstance(other, tuple): + ox, oy = other + x, y, width, height = self + return Region(x + ox, y + oy, width, height) + return NotImplemented + + def __sub__(self, other: object) -> Region: + if isinstance(other, tuple): + ox, oy = other + x, y, width, height = self + return Region(x - ox, y - oy, width, height) + return NotImplemented + + def get_spacing_between(self, region: Region) -> Spacing: + """Get spacing between two regions. + + Args: + region: Another region. + + Returns: + Spacing that if subtracted from `self` produces `region`. + """ + return Spacing( + region.y - self.y, + self.right - region.right, + self.bottom - region.bottom, + region.x - self.x, + ) + + def at_offset(self, offset: tuple[int, int]) -> Region: + """Get a new Region with the same size at a given offset. + + Args: + offset: An offset. + + Returns: + New Region with adjusted offset. + """ + x, y = offset + _x, _y, width, height = self + return Region(x, y, width, height) + + def crop_size(self, size: tuple[int, int]) -> Region: + """Get a region with the same offset, with a size no larger than `size`. + + Args: + size: Maximum width and height (WIDTH, HEIGHT). + + Returns: + New region that could fit within `size`. + """ + x, y, width1, height1 = self + width2, height2 = size + return Region(x, y, min(width1, width2), min(height1, height2)) + + def expand(self, size: tuple[int, int]) -> Region: + """Increase the size of the region by adding a border. + + Args: + size: Additional width and height. + + Returns: + A new region. + """ + expand_width, expand_height = size + x, y, width, height = self + return Region( + x - expand_width, + y - expand_height, + width + expand_width * 2, + height + expand_height * 2, + ) + + @lru_cache(maxsize=1024) + def overlaps(self, other: Region) -> bool: + """Check if another region overlaps this region. + + Args: + other: A Region. + + Returns: + True if other region shares any cells with this region. + """ + x, y, x2, y2 = self.corners + ox, oy, ox2, oy2 = other.corners + + return ((x2 > ox >= x) or (x2 > ox2 > x) or (ox < x and ox2 >= x2)) and ( + (y2 > oy >= y) or (y2 > oy2 > y) or (oy < y and oy2 >= y2) + ) + + def contains(self, x: int, y: int) -> bool: + """Check if a point is in the region. + + Args: + x: X coordinate. + y: Y coordinate. + + Returns: + True if the point is within the region. + """ + self_x, self_y, width, height = self + return (self_x + width > x >= self_x) and (self_y + height > y >= self_y) + + def contains_point(self, point: tuple[int, int]) -> bool: + """Check if a point is in the region. + + Args: + point: A tuple of x and y coordinates. + + Returns: + True if the point is within the region. + """ + x1, y1, x2, y2 = self.corners + try: + ox, oy = point + except Exception: + raise TypeError(f"a tuple of two integers is required, not {point!r}") + return (x2 > ox >= x1) and (y2 > oy >= y1) + + @lru_cache(maxsize=1024) + def contains_region(self, other: Region) -> bool: + """Check if a region is entirely contained within this region. + + Args: + other: A region. + + Returns: + True if the other region fits perfectly within this region. + """ + x1, y1, x2, y2 = self.corners + ox, oy, ox2, oy2 = other.corners + return ( + (x2 >= ox >= x1) + and (y2 >= oy >= y1) + and (x2 >= ox2 >= x1) + and (y2 >= oy2 >= y1) + ) + + @lru_cache(maxsize=1024) + def translate(self, offset: tuple[int, int]) -> Region: + """Move the offset of the Region. + + Args: + offset: Offset to add to region. + + Returns: + A new region shifted by (x, y). + """ + + self_x, self_y, width, height = self + offset_x, offset_y = offset + return Region(self_x + offset_x, self_y + offset_y, width, height) + + @lru_cache(maxsize=4096) + def __contains__(self, other: Any) -> bool: + """Check if a point is in this region.""" + if isinstance(other, Region): + return self.contains_region(other) + else: + try: + return self.contains_point(other) + except TypeError: + return False + + def clip(self, width: int, height: int) -> Region: + """Clip this region to fit within width, height. + + Args: + width: Width of bounds. + height: Height of bounds. + + Returns: + Clipped region. + """ + x1, y1, x2, y2 = self.corners + + _clamp = clamp + new_region = Region.from_corners( + _clamp(x1, 0, width), + _clamp(y1, 0, height), + _clamp(x2, 0, width), + _clamp(y2, 0, height), + ) + return new_region + + @lru_cache(maxsize=4096) + def grow(self, margin: tuple[int, int, int, int]) -> Region: + """Grow a region by adding spacing. + + Args: + margin: Grow space by `(, , , )`. + + Returns: + New region. + """ + if not any(margin): + return self + top, right, bottom, left = margin + x, y, width, height = self + return Region( + x=x - left, + y=y - top, + width=max(0, width + left + right), + height=max(0, height + top + bottom), + ) + + @lru_cache(maxsize=4096) + def shrink(self, margin: tuple[int, int, int, int]) -> Region: + """Shrink a region by subtracting spacing. + + Args: + margin: Shrink space by `(, , , )`. + + Returns: + The new, smaller region. + """ + if not any(margin): + return self + top, right, bottom, left = margin + x, y, width, height = self + return Region( + x=x + left, + y=y + top, + width=max(0, width - (left + right)), + height=max(0, height - (top + bottom)), + ) + + @lru_cache(maxsize=4096) + def intersection(self, region: Region) -> Region: + """Get the overlapping portion of the two regions. + + Args: + region: A region that overlaps this region. + + Returns: + A new region that covers when the two regions overlap. + """ + # Unrolled because this method is used a lot + x1, y1, w1, h1 = self + cx1, cy1, w2, h2 = region + x2 = x1 + w1 + y2 = y1 + h1 + cx2 = cx1 + w2 + cy2 = cy1 + h2 + + rx1 = cx2 if x1 > cx2 else (cx1 if x1 < cx1 else x1) + ry1 = cy2 if y1 > cy2 else (cy1 if y1 < cy1 else y1) + rx2 = cx2 if x2 > cx2 else (cx1 if x2 < cx1 else x2) + ry2 = cy2 if y2 > cy2 else (cy1 if y2 < cy1 else y2) + + return Region(rx1, ry1, rx2 - rx1, ry2 - ry1) + + @lru_cache(maxsize=4096) + def union(self, region: Region) -> Region: + """Get the smallest region that contains both regions. + + Args: + region: Another region. + + Returns: + An optimally sized region to cover both regions. + """ + x1, y1, x2, y2 = self.corners + ox1, oy1, ox2, oy2 = region.corners + + union_region = self.from_corners( + min(x1, ox1), min(y1, oy1), max(x2, ox2), max(y2, oy2) + ) + return union_region + + @lru_cache(maxsize=1024) + def split(self, cut_x: int, cut_y: int) -> tuple[Region, Region, Region, Region]: + """Split a region into 4 from given x and y offsets (cuts). + + ``` + cut_x ↓ + ┌────────┐ ┌───┐ + │ │ │ │ + │ 0 │ │ 1 │ + │ │ │ │ + cut_y → └────────┘ └───┘ + ┌────────┐ ┌───┐ + │ 2 │ │ 3 │ + └────────┘ └───┘ + ``` + + Args: + cut_x: Offset from self.x where the cut should be made. If negative, the cut + is taken from the right edge. + cut_y: Offset from self.y where the cut should be made. If negative, the cut + is taken from the lower edge. + + Returns: + Four new regions which add up to the original (self). + """ + + x, y, width, height = self + if cut_x < 0: + cut_x = width + cut_x + if cut_y < 0: + cut_y = height + cut_y + + _Region = Region + return ( + _Region(x, y, cut_x, cut_y), + _Region(x + cut_x, y, width - cut_x, cut_y), + _Region(x, y + cut_y, cut_x, height - cut_y), + _Region(x + cut_x, y + cut_y, width - cut_x, height - cut_y), + ) + + @lru_cache(maxsize=1024) + def split_vertical(self, cut: int) -> tuple[Region, Region]: + """Split a region into two, from a given x offset. + + ``` + cut ↓ + ┌────────┐┌───┐ + │ 0 ││ 1 │ + │ ││ │ + └────────┘└───┘ + ``` + + Args: + cut: An offset from self.x where the cut should be made. If cut is negative, + it is taken from the right edge. + + Returns: + Two regions, which add up to the original (self). + """ + + x, y, width, height = self + if cut < 0: + cut = width + cut + + return ( + Region(x, y, cut, height), + Region(x + cut, y, width - cut, height), + ) + + @lru_cache(maxsize=1024) + def split_horizontal(self, cut: int) -> tuple[Region, Region]: + """Split a region into two, from a given y offset. + + ``` + ┌─────────┐ + │ 0 │ + │ │ + cut → └─────────┘ + ┌─────────┐ + │ 1 │ + └─────────┘ + ``` + + Args: + cut: An offset from self.y where the cut should be made. May be negative, + for the offset to start from the lower edge. + + Returns: + Two regions, which add up to the original (self). + """ + x, y, width, height = self + if cut < 0: + cut = height + cut + + return ( + Region(x, y, width, cut), + Region(x, y + cut, width, height - cut), + ) + + def translate_inside( + self, container: Region, x_axis: bool = True, y_axis: bool = True + ) -> Region: + """Translate this region, so it fits within a container. + + This will ensure that there is as little overlap as possible. + The top left of the returned region is guaranteed to be within the container. + + ``` + ┌──────────────────┐ ┌──────────────────┐ + │ container │ │ container │ + │ │ │ ┌─────────────┤ + │ │ ──▶ │ │ return │ + │ ┌──────────┴──┐ │ │ │ + │ │ self │ │ │ │ + └───────┤ │ └────┴─────────────┘ + │ │ + └─────────────┘ + ``` + + + Args: + container: A container region. + x_axis: Allow translation of X axis. + y_axis: Allow translation of Y axis. + + Returns: + A new region with same dimensions that fits with inside container. + """ + x1, y1, width1, height1 = container + x2, y2, width2, height2 = self + return Region( + max(min(x2, x1 + width1 - width2), x1) if x_axis else x2, + max(min(y2, y1 + height1 - height2), y1) if y_axis else y2, + width2, + height2, + ) + + def inflect( + self, x_axis: int = +1, y_axis: int = +1, margin: Spacing | None = None + ) -> Region: + """Inflect a region around one or both axis. + + The `x_axis` and `y_axis` parameters define which direction to move the region. + A positive value will move the region right or down, a negative value will move + the region left or up. A value of `0` will leave that axis unmodified. + + If a margin is provided, it will add space between the resulting region. + + Note that if margin is specified it *overlaps*, so the space will be the maximum + of two edges, and not the total. + + ``` + ╔══════════╗ │ + ║ ║ + ║ Self ║ │ + ║ ║ + ╚══════════╝ │ + + ─ ─ ─ ─ ─ ─ ─ ─ ┌──────────┐ + │ │ + │ Result │ + │ │ + └──────────┘ + ``` + + Args: + x_axis: +1 to inflect in the positive direction, -1 to inflect in the negative direction. + y_axis: +1 to inflect in the positive direction, -1 to inflect in the negative direction. + margin: Additional margin. + + Returns: + A new region. + """ + inflect_margin = NULL_SPACING if margin is None else margin + x, y, width, height = self + if x_axis: + x += (width + inflect_margin.max_width) * x_axis + if y_axis: + y += (height + inflect_margin.max_height) * y_axis + return Region(x, y, width, height) + + def constrain( + self, + constrain_x: Literal["none", "inside", "inflect"], + constrain_y: Literal["none", "inside", "inflect"], + margin: Spacing, + container: Region, + ) -> Region: + """Constrain a region to fit within a container, using different methods per axis. + + Args: + constrain_x: Constrain method for the X-axis. + constrain_y: Constrain method for the Y-axis. + margin: Margin to maintain around region. + container: Container to constrain to. + + Returns: + New widget, that fits inside the container (if possible). + """ + margin_region = self.grow(margin) + region = self + + def compare_span( + span_start: int, span_end: int, container_start: int, container_end: int + ) -> int: + """Compare a span with a container + + Args: + span_start: Start of the span. + span_end: end of the span. + container_start: Start of the container. + container_end: End of the container. + + Returns: + 0 if the span fits, -1 if it is less that the container, otherwise +1 + """ + if span_start >= container_start and span_end <= container_end: + return 0 + if span_start < container_start: + return -1 + return +1 + + # Apply any inflected constraints + if constrain_x == "inflect" or constrain_y == "inflect": + region = region.inflect( + ( + -compare_span( + margin_region.x, + margin_region.right, + container.x, + container.right, + ) + if constrain_x == "inflect" + else 0 + ), + ( + -compare_span( + margin_region.y, + margin_region.bottom, + container.y, + container.bottom, + ) + if constrain_y == "inflect" + else 0 + ), + margin, + ) + + # Apply translate inside constrains + # Note this is also applied, if a previous inflect constrained has been applied + # This is so that the origin is always inside the container + region = region.translate_inside( + container.shrink(margin), + constrain_x != "none", + constrain_y != "none", + ) + + return region + + +class Spacing(NamedTuple): + """Stores spacing around a widget, such as padding and border. + + Spacing is defined by four integers for the space at the top, right, bottom, and left of a region. + + ``` + ┌ ─ ─ ─ ─ ─ ─ ─▲─ ─ ─ ─ ─ ─ ─ ─ ┐ + │ top + │ ┏━━━━━▼━━━━━━┓ │ + ◀──────▶┃ ┃◀───────▶ + │ left ┃ ┃ right │ + ┃ ┃ + │ ┗━━━━━▲━━━━━━┛ │ + │ bottom + └ ─ ─ ─ ─ ─ ─ ─▼─ ─ ─ ─ ─ ─ ─ ─ ┘ + ``` + + Example: + ```python + >>> from textual.geometry import Region, Spacing + >>> region = Region(2, 3, 20, 10) + >>> spacing = Spacing(1, 2, 3, 4) + >>> region.grow(spacing) + Region(x=-2, y=2, width=26, height=14) + >>> region.shrink(spacing) + Region(x=6, y=4, width=14, height=6) + >>> spacing.css + '1 2 3 4' + ``` + """ + + top: int = 0 + """Space from the top of a region.""" + right: int = 0 + """Space from the right of a region.""" + bottom: int = 0 + """Space from the bottom of a region.""" + left: int = 0 + """Space from the left of a region.""" + + def __bool__(self) -> bool: + return self != (0, 0, 0, 0) + + @property + def width(self) -> int: + """Total space in the x axis.""" + return self.left + self.right + + @property + def height(self) -> int: + """Total space in the y axis.""" + return self.top + self.bottom + + @property + def max_width(self) -> int: + """The space between regions in the X direction if margins overlap, i.e. `max(self.left, self.right)`.""" + _top, right, _bottom, left = self + return left if left > right else right + + @property + def max_height(self) -> int: + """The space between regions in the Y direction if margins overlap, i.e. `max(self.top, self.bottom)`.""" + top, _right, bottom, _left = self + return top if top > bottom else bottom + + @property + def top_left(self) -> tuple[int, int]: + """A pair of integers for the left, and top space.""" + return (self.left, self.top) + + @property + def bottom_right(self) -> tuple[int, int]: + """A pair of integers for the right, and bottom space.""" + return (self.right, self.bottom) + + @property + def totals(self) -> tuple[int, int]: + """A pair of integers for the total horizontal and vertical space.""" + top, right, bottom, left = self + return (left + right, top + bottom) + + @property + def css(self) -> str: + """A string containing the spacing in CSS format. + + For example: "1" or "2 4" or "4 2 8 2". + """ + top, right, bottom, left = self + if top == right == bottom == left: + return f"{top}" + if (top, right) == (bottom, left): + return f"{top} {right}" + else: + return f"{top} {right} {bottom} {left}" + + @classmethod + def unpack(cls, pad: SpacingDimensions) -> Spacing: + """Unpack padding specified in CSS style. + + Args: + pad: An integer, or tuple of 1, 2, or 4 integers. + + Raises: + ValueError: If `pad` is an invalid value. + + Returns: + New Spacing object. + """ + if isinstance(pad, int): + return cls(pad, pad, pad, pad) + pad_len = len(pad) + if pad_len == 1: + _pad = pad[0] + return cls(_pad, _pad, _pad, _pad) + if pad_len == 2: + pad_top, pad_right = cast(Tuple[int, int], pad) + return cls(pad_top, pad_right, pad_top, pad_right) + if pad_len == 4: + top, right, bottom, left = cast(Tuple[int, int, int, int], pad) + return cls(top, right, bottom, left) + raise ValueError( + f"1, 2 or 4 integers required for spacing properties; {pad_len} given" + ) + + @classmethod + def vertical(cls, amount: int) -> Spacing: + """Construct a Spacing with a given amount of spacing on vertical edges, + and no horizontal spacing. + + Args: + amount: The magnitude of spacing to apply to vertical edges. + + Returns: + `Spacing(amount, 0, amount, 0)` + """ + return Spacing(amount, 0, amount, 0) + + @classmethod + def horizontal(cls, amount: int) -> Spacing: + """Construct a Spacing with a given amount of spacing on horizontal edges, + and no vertical spacing. + + Args: + amount: The magnitude of spacing to apply to horizontal edges. + + Returns: + `Spacing(0, amount, 0, amount)` + """ + return Spacing(0, amount, 0, amount) + + @classmethod + def all(cls, amount: int) -> Spacing: + """Construct a Spacing with a given amount of spacing on all edges. + + Args: + amount: The magnitude of spacing to apply to all edges. + + Returns: + `Spacing(amount, amount, amount, amount)` + """ + return Spacing(amount, amount, amount, amount) + + def __add__(self, other: object) -> Spacing: + if isinstance(other, tuple): + top1, right1, bottom1, left1 = self + top2, right2, bottom2, left2 = other + return Spacing( + top1 + top2, right1 + right2, bottom1 + bottom2, left1 + left2 + ) + return NotImplemented + + def __sub__(self, other: object) -> Spacing: + if isinstance(other, tuple): + top1, right1, bottom1, left1 = self + top2, right2, bottom2, left2 = other + return Spacing( + top1 - top2, right1 - right2, bottom1 - bottom2, left1 - left2 + ) + return NotImplemented + + def grow_maximum(self, other: Spacing) -> Spacing: + """Grow spacing with a maximum. + + Args: + other: Spacing object. + + Returns: + New spacing where the values are maximum of the two values. + """ + top, right, bottom, left = self + other_top, other_right, other_bottom, other_left = other + return Spacing( + max(top, other_top), + max(right, other_right), + max(bottom, other_bottom), + max(left, other_left), + ) + + +class Shape: + """An arbitrary shape defined by a sequence of regions. + + This class currently exists to filter widgets within a shape defined when the user is slecting text. + + """ + + __slots__ = [ + "_regions", + "_bounds", + ] + + def __init__(self, regions: Iterable[Region]): + """ + + Args: + regions: Regions which will define the shape. + """ + self._regions = tuple(regions) + self._bounds = Region.from_union(self._regions) if regions else NULL_REGION + + def __bool__(self) -> bool: + return bool(self._bounds) + + def __hash__(self) -> int: + return hash(self._regions) + + def __rich_repr__(self) -> rich.repr.Result: + yield self._regions + + @property + def regions(self) -> tuple[Region, ...]: + """The regions in the shape.""" + return self._regions + + @property + def bounds(self) -> Region: + """A region that encloses the shape.""" + return self._bounds + + @property + def area(self) -> int: + """Cells covered by the shape.""" + # TODO: Currently doesn't handle overlapping regions + return sum(region.area for region in self._regions) + + @classmethod + def selection_bounds(cls, container: Region, start: Offset, end: Offset) -> Shape: + """Get a shape that would be constructed by a user selecting text between two points. + + The shape would look something like this: + + ``` + XXXXXXXXXX <- top + XXXXXXXXXXXXXX + XXXXXXXXXXXXXX <- middle + XXXXXXXXXXXXXX + XXXXXXXXX <- bottom + ``` + + Args: + container: The container region for the selection. + start: The start offset. + end: The end offset. + + Returns: + A new shape covering the selection bounds. + """ + if start.transpose > end.transpose: + end, start = start, end + start_x, start_y = start + end_x, end_y = end + + def get_regions() -> Iterable[Region]: + """Get regions to cover selection bounds. + + Yields: + Regions to cover bounds. + """ + # Special case where start and end offsets are on the edges, and the shape + # becomes a single region + if start_x == 0 and end_x == container.width: + yield Region( + 0, + start_y, + container.width, + end_y - start_y, + ) + + # Simple case: all on one line + elif start.y == end.y: + yield Region( + start_x, + start_y, + end_x - start_x, + 1, + ) + + # Shape is on two or more lines + else: + # top + yield Region( + start_x, + start_y, + container.width - start_x, + 1, + ) + # middle + if end.y - start.y > 2: + # We need a middle region between the top and the bottom + yield Region( + 0, + start_y + 1, + container.width, + end_y - start_y - 1, + ) + # bottom + yield Region( + container.x, + end_y, + end_x, + 1, + ) + + return Shape(get_regions()) + + def overlaps(self, region: Region) -> bool: + """Does a region overlap this shape? + + Args: + region: A Region to check. + + Returns: + `True` if any part of the shape overlaps the region, `False` if there is no overlap. + """ + return any(shape_region.overlaps(region) for shape_region in self._regions) + + def contains_point(self, offset: Offset) -> bool: + """Check if the given offset is within the shape. + + Args: + offset: An offset. + + Returns: + `True` if the given offset is anywhere within the shape, otherwise `False`. + """ + return any(region.contains_point(offset) for region in self._regions) + + +if not TYPE_CHECKING and os.environ.get("TEXTUAL_SPEEDUPS", "1") == "1": + try: + from textual_speedups import Offset, Region, Size, Spacing + except ImportError: + pass + + +NULL_OFFSET: Final = Offset(0, 0) +"""An [Offset][textual.geometry.Offset] constant for (0, 0).""" + +NULL_REGION: Final = Region(0, 0, 0, 0) +"""A [Region][textual.geometry.Region] constant for a null region (at the origin, with both width and height set to zero).""" + +NULL_SIZE: Final = Size(0, 0) +"""A [Size][textual.geometry.Size] constant for a null size (with zero area).""" + +NULL_SPACING: Final = Spacing(0, 0, 0, 0) +"""A [Spacing][textual.geometry.Spacing] constant for no space.""" diff --git a/src/memray/_vendor/textual/getters.py b/src/memray/_vendor/textual/getters.py new file mode 100644 index 0000000000..5e54038dde --- /dev/null +++ b/src/memray/_vendor/textual/getters.py @@ -0,0 +1,225 @@ +""" +Descriptors to define properties on your widget, screen, or App. + +""" + +from __future__ import annotations + +from inspect import isclass +from typing import TYPE_CHECKING, Callable, Generic, TypeVar, overload + +from memray._vendor.textual._context import NoActiveAppError, active_app +from memray._vendor.textual.css.query import NoMatches, QueryType, WrongType +from memray._vendor.textual.widget import Widget + +if TYPE_CHECKING: + from memray._vendor.textual.app import App + from memray._vendor.textual.dom import DOMNode + from memray._vendor.textual.message_pump import MessagePump + + +AppType = TypeVar("AppType", bound="App") + + +class app(Generic[AppType]): + """Create a property to return the active app. + + All widgets have a default `app` property which returns an App instance. + Type checkers will complain if you try to access attributes defined on your App class, which aren't + present in the base class. To keep the type checker happy you can add this property to get your + specific App subclass. + + Example: + ```python + class MyWidget(Widget): + app = getters.app(MyApp) + ``` + + Args: + app_type: The App subclass, or a callable which returns an App subclass. + """ + + def __init__(self, app_type: type[AppType] | Callable[[], type[AppType]]) -> None: + self._app_type = app_type if isclass(app_type) else app_type() + + def __get__(self, obj: MessagePump, obj_type: type[MessagePump]) -> AppType: + try: + app = active_app.get() + except LookupError: + from memray._vendor.textual.app import App + + node: MessagePump | None = obj + while not isinstance(node, App): + if node is None: + raise NoActiveAppError() + node = node._parent + app = node + + assert isinstance(app, self._app_type) + return app + + +class query_one(Generic[QueryType]): + """Create a query one property. + + A query one property calls [Widget.query_one][textual.dom.DOMNode.query_one] when accessed, and returns + a widget. If the widget doesn't exist, then the property will raise the same exceptions as `Widget.query_one`. + + + Example: + ```python + from memray._vendor.textual import getters + + class MyScreen(screen): + + # Note this is at the class level + output_log = getters.query_one("#output", RichLog) + + def compose(self) -> ComposeResult: + with containers.Vertical(): + yield RichLog(id="output") + + def on_mount(self) -> None: + self.output_log.write("Screen started") + # Equivalent to the following line: + # self.query_one("#output", RichLog).write("Screen started") + ``` + + Args: + selector: A TCSS selector, e.g. "#mywidget". Or a widget type, i.e. `Input`. + expect_type: The type of the expected widget, e.g. `Input`, if the first argument is a selector. + + """ + + selector: str + expect_type: type["Widget"] + + @overload + def __init__(self, selector: str) -> None: + """ + + Args: + selector: A TCSS selector, e.g. "#mywidget" + """ + + @overload + def __init__(self, selector: type[QueryType]) -> None: ... + + @overload + def __init__(self, selector: str, expect_type: type[QueryType]) -> None: ... + + @overload + def __init__( + self, selector: type[QueryType], expect_type: type[QueryType] + ) -> None: ... + + def __init__( + self, + selector: str | type[QueryType], + expect_type: type[QueryType] | None = None, + ) -> None: + if expect_type is None: + from memray._vendor.textual.widget import Widget + + self.expect_type = Widget + else: + self.expect_type = expect_type + if isinstance(selector, str): + self.selector = selector + else: + self.selector = selector.__name__ + self.expect_type = selector + + @overload + def __get__( + self: "query_one[QueryType]", obj: DOMNode, obj_type: type[DOMNode] + ) -> QueryType: ... + + @overload + def __get__( + self: "query_one[QueryType]", obj: None, obj_type: type[DOMNode] + ) -> "query_one[QueryType]": ... + + def __get__( + self: "query_one[QueryType]", obj: DOMNode | None, obj_type: type[DOMNode] + ) -> QueryType | Widget | "query_one": + """Get the widget matching the selector and/or type.""" + if obj is None: + return self + query_node = obj.query_one(self.selector, self.expect_type) + return query_node + + +class child_by_id(Generic[QueryType]): + """Create a child_by_id property, which returns the child with the given ID. + + This is similar using [query_one][textual.getters.query_one] with an id selector, except that + only the immediate children are considered. It is also more efficient as it doesn't need to search the DOM. + + + Example: + ```python + from memray._vendor.textual import getters + + class MyScreen(screen): + + # Note this is at the class level + output_log = getters.child_by_id("output", RichLog) + + def compose(self) -> ComposeResult: + yield RichLog(id="output") + + def on_mount(self) -> None: + self.output_log.write("Screen started") + ``` + + Args: + child_id: The `id` of the widget to get (not a selector). + expect_type: The type of the expected widget, e.g. `Input`. + + """ + + child_id: str + expect_type: type[Widget] + + @overload + def __init__(self, child_id: str) -> None: ... + + @overload + def __init__(self, child_id: str, expect_type: type[QueryType]) -> None: ... + + def __init__( + self, + child_id: str, + expect_type: type[QueryType] | None = None, + ) -> None: + if expect_type is None: + self.expect_type = Widget + else: + self.expect_type = expect_type + self.child_id = child_id + + @overload + def __get__( + self: "child_by_id[QueryType]", obj: DOMNode, obj_type: type[DOMNode] + ) -> QueryType: ... + + @overload + def __get__( + self: "child_by_id[QueryType]", obj: None, obj_type: type[DOMNode] + ) -> "child_by_id[QueryType]": ... + + def __get__( + self: "child_by_id[QueryType]", obj: DOMNode | None, obj_type: type[DOMNode] + ) -> QueryType | Widget | "child_by_id": + """Get the widget matching the selector and/or type.""" + if obj is None: + return self + child = obj._get_dom_base()._nodes._get_by_id(self.child_id) + if child is None: + raise NoMatches(f"No child found with id={self.child_id!r}") + if not isinstance(child, self.expect_type): + raise WrongType( + f"Child with id={self.child_id!r} is the wrong type; expected type {self.expect_type.__name__!r}, found {child}" + ) + return child diff --git a/src/memray/_vendor/textual/highlight.py b/src/memray/_vendor/textual/highlight.py new file mode 100644 index 0000000000..455d2bc918 --- /dev/null +++ b/src/memray/_vendor/textual/highlight.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import os +from typing import Tuple + +from pygments.lexer import Lexer +from pygments.lexers import get_lexer_by_name, guess_lexer_for_filename +from pygments.token import Token +from pygments.util import ClassNotFound + +from memray._vendor.textual.content import Content, Span + +TokenType = Tuple[str, ...] + + +class HighlightTheme: + """Contains the style definition for user with the highlight method.""" + + STYLES: dict[TokenType, str] = { + Token.Comment: "$text 60%", + Token.Error: "$text-error on $error-muted", + Token.Generic.Strong: "bold", + Token.Generic.Emph: "italic", + Token.Generic.Error: "$text-error on $error-muted", + Token.Generic.Heading: "$text-primary underline", + Token.Generic.Subheading: "$text-primary", + Token.Keyword: "$text-accent", + Token.Keyword.Constant: "bold $text-success 80%", + Token.Keyword.Namespace: "$text-error", + Token.Keyword.Type: "bold", + Token.Literal.Number: "$text-warning", + Token.Literal.String.Backtick: "$text 60%", + Token.Literal.String: "$text-success 90%", + Token.Literal.String.Doc: "$text-success 80% italic", + Token.Literal.String.Double: "$text-success 90%", + Token.Name: "$text-primary", + Token.Name.Attribute: "$text-warning", + Token.Name.Builtin: "$text-accent", + Token.Name.Builtin.Pseudo: "italic", + Token.Name.Class: "$text-warning bold", + Token.Name.Constant: "$text-error", + Token.Name.Decorator: "$text-primary bold", + Token.Name.Function: "$text-warning underline", + Token.Name.Function.Magic: "$text-warning underline", + Token.Name.Tag: "$text-primary bold", + Token.Name.Variable: "$text-secondary", + Token.Number: "$text-warning", + Token.Operator: "bold", + Token.Operator.Word: "bold $text-error", + Token.String: "$text-success", + Token.Whitespace: "", + } + + +def guess_language(code: str, path: str | None) -> str: + """Guess the language based on the code and path. + The result may be used in the [highlight][textual.highlight.highlight] function. + + Args: + code: The code to guess from. + path: A path to the code. + + Returns: + The language, suitable for use with Pygments. + """ + + if path and os.path.splitext(path)[-1] == ".tcss": + # A special case for TCSS files which aren't known outside of Textual + return "scss" + + lexer: Lexer | None = None + lexer_name = "default" + if code: + if path: + try: + lexer = guess_lexer_for_filename(path, code) + except ClassNotFound: + pass + + if lexer is None: + from pygments.lexers import guess_lexer + + try: + lexer = guess_lexer(code) + except Exception: + pass + + if not lexer and path: + try: + _, ext = os.path.splitext(path) + if ext: + extension = ext.lstrip(".").lower() + lexer = get_lexer_by_name(extension) + except ClassNotFound: + pass + + if lexer: + if lexer.aliases: + lexer_name = lexer.aliases[0] + else: + lexer_name = lexer.name + + return lexer_name + + +def highlight( + code: str, + *, + language: str | None = None, + path: str | None = None, + theme: type[HighlightTheme] = HighlightTheme, + tab_size: int = 8, +) -> Content: + """Apply syntax highlighting to a string. + + Args: + code: A string to highlight. + language: The language to highlight. + theme: A HighlightTheme class (type not instance). + tab_size: Number of spaces in a tab. + + Returns: + A Content instance which may be used in a widget. + """ + if not language: + language = guess_language(code, path) + + assert language is not None + code = "\n".join(code.splitlines()) + try: + lexer = get_lexer_by_name( + language, + stripnl=False, + ensurenl=True, + tabsize=tab_size, + ) + except ClassNotFound: + lexer = get_lexer_by_name( + "text", + stripnl=False, + ensurenl=True, + tabsize=tab_size, + ) + + token_start = 0 + spans: list[Span] = [] + styles = theme.STYLES + + for token_type, token in lexer.get_tokens(code): + token_end = token_start + len(token) + while True: + if style := styles.get(token_type): + spans.append(Span(token_start, token_end, style)) + break + if (token_type := token_type.parent) is None: + break + token_start = token_end + + highlighted_code = Content(code, spans=spans).stylize_before("$text") + return highlighted_code diff --git a/src/memray/_vendor/textual/keys.py b/src/memray/_vendor/textual/keys.py new file mode 100644 index 0000000000..09c7f17acb --- /dev/null +++ b/src/memray/_vendor/textual/keys.py @@ -0,0 +1,368 @@ +from __future__ import annotations + +import unicodedata +from enum import Enum +from functools import lru_cache + + +# Adapted from prompt toolkit https://github.com/prompt-toolkit/python-prompt-toolkit/blob/master/prompt_toolkit/keys.py +class Keys(str, Enum): # type: ignore[no-redef] + """ + List of keys for use in key bindings. + + Note that this is an "StrEnum", all values can be compared against + strings. + """ + + @property + def value(self) -> str: + return super().value + + Escape = "escape" # Also Control-[ + ShiftEscape = "shift+escape" + Return = "return" + + ControlAt = "ctrl+@" # Also Control-Space. + + ControlA = "ctrl+a" + ControlB = "ctrl+b" + ControlC = "ctrl+c" + ControlD = "ctrl+d" + ControlE = "ctrl+e" + ControlF = "ctrl+f" + ControlG = "ctrl+g" + ControlH = "ctrl+h" + ControlI = "ctrl+i" # Tab + ControlJ = "ctrl+j" # Newline + ControlK = "ctrl+k" + ControlL = "ctrl+l" + ControlM = "ctrl+m" # Carriage return + ControlN = "ctrl+n" + ControlO = "ctrl+o" + ControlP = "ctrl+p" + ControlQ = "ctrl+q" + ControlR = "ctrl+r" + ControlS = "ctrl+s" + ControlT = "ctrl+t" + ControlU = "ctrl+u" + ControlV = "ctrl+v" + ControlW = "ctrl+w" + ControlX = "ctrl+x" + ControlY = "ctrl+y" + ControlZ = "ctrl+z" + + Control1 = "ctrl+1" + Control2 = "ctrl+2" + Control3 = "ctrl+3" + Control4 = "ctrl+4" + Control5 = "ctrl+5" + Control6 = "ctrl+6" + Control7 = "ctrl+7" + Control8 = "ctrl+8" + Control9 = "ctrl+9" + Control0 = "ctrl+0" + + ControlShift1 = "ctrl+shift+1" + ControlShift2 = "ctrl+shift+2" + ControlShift3 = "ctrl+shift+3" + ControlShift4 = "ctrl+shift+4" + ControlShift5 = "ctrl+shift+5" + ControlShift6 = "ctrl+shift+6" + ControlShift7 = "ctrl+shift+7" + ControlShift8 = "ctrl+shift+8" + ControlShift9 = "ctrl+shift+9" + ControlShift0 = "ctrl+shift+0" + + ControlBackslash = "ctrl+backslash" + ControlSquareClose = "ctrl+right_square_bracket" + ControlCircumflex = "ctrl+circumflex_accent" + ControlUnderscore = "ctrl+underscore" + + Left = "left" + Right = "right" + Up = "up" + Down = "down" + Home = "home" + End = "end" + Insert = "insert" + Delete = "delete" + PageUp = "pageup" + PageDown = "pagedown" + + ControlLeft = "ctrl+left" + ControlRight = "ctrl+right" + ControlUp = "ctrl+up" + ControlDown = "ctrl+down" + ControlHome = "ctrl+home" + ControlEnd = "ctrl+end" + ControlInsert = "ctrl+insert" + ControlDelete = "ctrl+delete" + ControlPageUp = "ctrl+pageup" + ControlPageDown = "ctrl+pagedown" + + ShiftLeft = "shift+left" + ShiftRight = "shift+right" + ShiftUp = "shift+up" + ShiftDown = "shift+down" + ShiftHome = "shift+home" + ShiftEnd = "shift+end" + ShiftInsert = "shift+insert" + ShiftDelete = "shift+delete" + ShiftPageUp = "shift+pageup" + ShiftPageDown = "shift+pagedown" + + ControlShiftLeft = "ctrl+shift+left" + ControlShiftRight = "ctrl+shift+right" + ControlShiftUp = "ctrl+shift+up" + ControlShiftDown = "ctrl+shift+down" + ControlShiftHome = "ctrl+shift+home" + ControlShiftEnd = "ctrl+shift+end" + ControlShiftInsert = "ctrl+shift+insert" + ControlShiftDelete = "ctrl+shift+delete" + ControlShiftPageUp = "ctrl+shift+pageup" + ControlShiftPageDown = "ctrl+shift+pagedown" + + BackTab = "shift+tab" # shift + tab + + F1 = "f1" + F2 = "f2" + F3 = "f3" + F4 = "f4" + F5 = "f5" + F6 = "f6" + F7 = "f7" + F8 = "f8" + F9 = "f9" + F10 = "f10" + F11 = "f11" + F12 = "f12" + F13 = "f13" + F14 = "f14" + F15 = "f15" + F16 = "f16" + F17 = "f17" + F18 = "f18" + F19 = "f19" + F20 = "f20" + F21 = "f21" + F22 = "f22" + F23 = "f23" + F24 = "f24" + + ControlF1 = "ctrl+f1" + ControlF2 = "ctrl+f2" + ControlF3 = "ctrl+f3" + ControlF4 = "ctrl+f4" + ControlF5 = "ctrl+f5" + ControlF6 = "ctrl+f6" + ControlF7 = "ctrl+f7" + ControlF8 = "ctrl+f8" + ControlF9 = "ctrl+f9" + ControlF10 = "ctrl+f10" + ControlF11 = "ctrl+f11" + ControlF12 = "ctrl+f12" + ControlF13 = "ctrl+f13" + ControlF14 = "ctrl+f14" + ControlF15 = "ctrl+f15" + ControlF16 = "ctrl+f16" + ControlF17 = "ctrl+f17" + ControlF18 = "ctrl+f18" + ControlF19 = "ctrl+f19" + ControlF20 = "ctrl+f20" + ControlF21 = "ctrl+f21" + ControlF22 = "ctrl+f22" + ControlF23 = "ctrl+f23" + ControlF24 = "ctrl+f24" + + # Matches any key. + Any = "" + + # Special. + ScrollUp = "" + ScrollDown = "" + + # For internal use: key which is ignored. + # (The key binding for this key should not do anything.) + Ignore = "" + + # Some 'Key' aliases (for backwardshift+compatibility). + ControlSpace = "ctrl-at" + Tab = "tab" + Space = "space" + Enter = "enter" + Backspace = "backspace" + + # ShiftControl was renamed to ControlShift in + # 888fcb6fa4efea0de8333177e1bbc792f3ff3c24 (20 Feb 2020). + ShiftControlLeft = ControlShiftLeft + ShiftControlRight = ControlShiftRight + ShiftControlHome = ControlShiftHome + ShiftControlEnd = ControlShiftEnd + + +# Unicode db contains some obscure names +# This mapping replaces them with more common terms +KEY_NAME_REPLACEMENTS = { + "solidus": "slash", + "reverse_solidus": "backslash", + "commercial_at": "at", + "hyphen_minus": "minus", + "plus_sign": "plus", + "low_line": "underscore", +} +REPLACED_KEYS = {value: key for key, value in KEY_NAME_REPLACEMENTS.items()} + +# Convert the friendly versions of character key Unicode names +# back to their original names. +# This is because we go from Unicode to friendly by replacing spaces and dashes +# with underscores, which cannot be undone by replacing underscores with spaces/dashes. +KEY_TO_UNICODE_NAME = { + "exclamation_mark": "EXCLAMATION MARK", + "quotation_mark": "QUOTATION MARK", + "number_sign": "NUMBER SIGN", + "dollar_sign": "DOLLAR SIGN", + "percent_sign": "PERCENT SIGN", + "left_parenthesis": "LEFT PARENTHESIS", + "right_parenthesis": "RIGHT PARENTHESIS", + "plus_sign": "PLUS SIGN", + "hyphen_minus": "HYPHEN-MINUS", + "full_stop": "FULL STOP", + "less_than_sign": "LESS-THAN SIGN", + "equals_sign": "EQUALS SIGN", + "greater_than_sign": "GREATER-THAN SIGN", + "question_mark": "QUESTION MARK", + "commercial_at": "COMMERCIAL AT", + "left_square_bracket": "LEFT SQUARE BRACKET", + "reverse_solidus": "REVERSE SOLIDUS", + "right_square_bracket": "RIGHT SQUARE BRACKET", + "circumflex_accent": "CIRCUMFLEX ACCENT", + "low_line": "LOW LINE", + "grave_accent": "GRAVE ACCENT", + "left_curly_bracket": "LEFT CURLY BRACKET", + "vertical_line": "VERTICAL LINE", + "right_curly_bracket": "RIGHT CURLY BRACKET", +} + +# Some keys have aliases. For example, if you press `ctrl+m` on your keyboard, +# it's treated the same way as if you press `enter`. Key handlers `key_ctrl_m` and +# `key_enter` are both valid in this case. +KEY_ALIASES = { + "tab": ["ctrl+i"], + "enter": ["ctrl+m"], + "escape": ["ctrl+left_square_brace"], + "ctrl+at": ["ctrl+space"], + "ctrl+j": ["newline"], +} + +KEY_DISPLAY_ALIASES = { + "up": "↑", + "down": "↓", + "left": "←", + "right": "→", + "backspace": "⌫", + "escape": "esc", + "enter": "⏎", + "minus": "-", + "space": "space", + "pagedown": "pgdn", + "pageup": "pgup", + "delete": "del", +} + + +ASCII_KEY_NAMES = {"\t": "tab"} + + +def _get_unicode_name_from_key(key: str) -> str: + """Get the best guess for the Unicode name of the char corresponding to the key. + + This function can be seen as a pseudo-inverse of the function `_character_to_key`. + """ + return KEY_TO_UNICODE_NAME.get(key, key) + + +def _get_key_aliases(key: str) -> list[str]: + """Return all aliases for the given key, including the key itself""" + return [key] + KEY_ALIASES.get(key, []) + + +@lru_cache(1024) +def format_key(key: str) -> str: + """Given a key (i.e. the `key` string argument to Binding __init__), + return the value that should be displayed in the app when referring + to this key (e.g. in the Footer widget).""" + + display_alias = KEY_DISPLAY_ALIASES.get(key) + if display_alias: + return display_alias + + original_key = REPLACED_KEYS.get(key, key) + tentative_unicode_name = _get_unicode_name_from_key(original_key) + try: + unicode_name = unicodedata.lookup(tentative_unicode_name) + except KeyError: + pass + else: + if unicode_name.isprintable(): + return unicode_name + return tentative_unicode_name + + +@lru_cache(1024) +def key_to_character(key: str) -> str | None: + """Given a key identifier, return the character associated with it. + + Args: + key: The key identifier. + + Returns: + A key if one could be found, otherwise `None`. + """ + _, separator, key = key.rpartition("+") + if separator: + # If there is a separator, then it means a modifier (other than shift) is applied. + # Keys with modifiers, don't come from printable keys. + return None + if len(key) == 1: + # Key identifiers with a length of one, are also characters. + return key + try: + return unicodedata.lookup(KEY_TO_UNICODE_NAME[key]) + except KeyError: + pass + try: + return unicodedata.lookup(key.replace("_", " ").upper()) + except KeyError: + pass + # Return None if we couldn't identify the key. + return None + + +def _character_to_key(character: str) -> str: + """Convert a single character to a key value. + + This transformation can be undone by the function `_get_unicode_name_from_key`. + """ + if not character.isalnum(): + try: + key = ( + unicodedata.name(character).lower().replace("-", "_").replace(" ", "_") + ) + except ValueError: + key = ASCII_KEY_NAMES.get(character, character) + else: + key = character + key = KEY_NAME_REPLACEMENTS.get(key, key) + return key + + +def _normalize_key_list(keys: str) -> str: + """Normalizes a comma separated list of keys. + + Replaces single letter keys with full name. + """ + + keys_list = [key.strip() for key in keys.split(",")] + return ",".join( + _character_to_key(key) if len(key) == 1 else key for key in keys_list + ) diff --git a/src/memray/_vendor/textual/layout.py b/src/memray/_vendor/textual/layout.py new file mode 100644 index 0000000000..20ba65a2e3 --- /dev/null +++ b/src/memray/_vendor/textual/layout.py @@ -0,0 +1,314 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, ClassVar, Iterable, NamedTuple + +from memray._vendor.textual._spatial_map import SpatialMap +from memray._vendor.textual.canvas import Canvas, Rectangle +from memray._vendor.textual.geometry import Offset, Region, Size, Spacing +from memray._vendor.textual.strip import StripRenderable + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from memray._vendor.textual.widget import Widget + +ArrangeResult: TypeAlias = "list[WidgetPlacement]" + + +@dataclass +class DockArrangeResult: + """Result of [Layout.arrange][textual.layout.Layout.arrange].""" + + placements: list[WidgetPlacement] + """A `WidgetPlacement` for every widget to describe its location on screen.""" + widgets: set[Widget] + """A set of widgets in the arrangement.""" + scroll_spacing: Spacing + """Spacing to reduce scrollable area.""" + + _spatial_map: SpatialMap[WidgetPlacement] | None = None + """A Spatial map to query widget placements.""" + + @property + def spatial_map(self) -> SpatialMap[WidgetPlacement]: + """A lazy-calculated spatial map.""" + if self._spatial_map is None: + self._spatial_map = SpatialMap() + self._spatial_map.insert( + ( + placement.region.grow(placement.margin), + placement.offset, + placement.fixed, + placement.overlay, + placement, + ) + for placement in self.placements + ) + + return self._spatial_map + + @property + def total_region(self) -> Region: + """The total area occupied by the arrangement. + + Returns: + A Region. + """ + _top, right, bottom, _left = self.scroll_spacing + return self.spatial_map.total_region.grow((0, right, bottom, 0)) + + def get_visible_placements(self, region: Region) -> list[WidgetPlacement]: + """Get the placements visible within the given region. + + Args: + region: A region. + + Returns: + Set of placements. + """ + if self.total_region in region: + # Short circuit for when we want all the placements + return self.placements + visible_placements = self.spatial_map.get_values_in_region(region) + overlaps = region.overlaps + culled_placements = [ + placement + for placement in visible_placements + if placement.fixed or overlaps(placement.region + placement.offset) + ] + return culled_placements + + +class WidgetPlacement(NamedTuple): + """The position, size, and relative order of a widget within its parent.""" + + region: Region + offset: Offset + margin: Spacing + widget: Widget + order: int = 0 + fixed: bool = False + overlay: bool = False + absolute: bool = False + + @property + def reset_origin(self) -> WidgetPlacement: + """Reset the origin in the placement (moves it to (0, 0)).""" + return self._replace(region=self.region.reset_offset) + + @classmethod + def translate( + cls, placements: list[WidgetPlacement], translate_offset: Offset + ) -> list[WidgetPlacement]: + """Move all non-absolute placements by a given offset. + + Args: + placements: List of placements. + offset: Offset to add to placements. + + Returns: + Placements with adjusted region, or same instance if offset is null. + """ + if translate_offset: + return [ + cls( + ( + region + translate_offset + if layout_widget.absolute_offset is None + else region + ), + offset, + margin, + layout_widget, + order, + fixed, + overlay, + absolute, + ) + for region, offset, margin, layout_widget, order, fixed, overlay, absolute in placements + ] + return placements + + @classmethod + def apply_absolute(cls, placements: list[WidgetPlacement]) -> None: + """Applies absolute offsets (in place). + + Args: + placements: A list of placements. + """ + for index, placement in enumerate(placements): + if placement.absolute: + placements[index] = placement.reset_origin + + @classmethod + def get_bounds(cls, placements: Iterable[WidgetPlacement]) -> Region: + """Get a bounding region around all placements. + + Args: + placements: A number of placements. + + Returns: + An optimal binding box around all placements. + """ + bounding_region = Region.from_union( + [placement.region.grow(placement.margin) for placement in placements] + ) + return bounding_region + + def process_offset( + self, constrain_region: Region, absolute_offset: Offset + ) -> WidgetPlacement: + """Apply any absolute offset or constrain rules to the placement. + + Args: + constrain_region: The container region when applying constrain rules. + absolute_offset: Default absolute offset that moves widget into screen coordinates. + + Returns: + Processes placement, may be the same instance. + """ + widget = self.widget + styles = widget.styles + if not widget.absolute_offset and not styles.has_any_rules( + "constrain_x", "constrain_y" + ): + # Bail early if there is nothing to do + return self + region = self.region + margin = self.margin + if widget.absolute_offset is not None: + region = region.at_offset( + widget.absolute_offset + margin.top_left - absolute_offset + ) + + region = region.translate(self.offset).constrain( + styles.constrain_x, + styles.constrain_y, + self.margin, + constrain_region - absolute_offset, + ) + + offset = region.offset - self.region.offset + if offset != self.offset: + region, _offset, margin, widget, order, fixed, overlay, absolute = self + placement = WidgetPlacement( + region, offset, margin, widget, order, fixed, overlay, absolute + ) + return placement + return self + + +class Layout(ABC): + """Base class of the object responsible for arranging Widgets within a container.""" + + name: ClassVar[str] = "" + + def __repr__(self) -> str: + return f"<{self.name}>" + + @abstractmethod + def arrange( + self, + parent: Widget, + children: list[Widget], + size: Size, + greedy: bool = True, + ) -> ArrangeResult: + """Generate a layout map that defines where on the screen the widgets will be drawn. + + Args: + parent: Parent widget. + size: Size of container. + + Returns: + An iterable of widget location + """ + + def get_content_width(self, widget: Widget, container: Size, viewport: Size) -> int: + """Get the optimal content width by arranging children. + + Args: + widget: The container widget. + container: The container size. + viewport: The viewport size. + + Returns: + Width of the content. + """ + if not widget._nodes: + width = 0 + else: + arrangement = widget.arrange( + Size(0 if widget.shrink else container.width, 0), + optimal=True, + ) + width = arrangement.total_region.right + return width + + def get_content_height( + self, widget: Widget, container: Size, viewport: Size, width: int + ) -> int: + """Get the content height. + + Args: + widget: The container widget. + container: The container size. + viewport: The viewport. + width: The content width. + + Returns: + Content height (in lines). + """ + if widget._nodes: + if not widget.styles.is_docked and all( + child.styles.is_dynamic_height for child in widget.displayed_children + ): + # An exception for containers with all dynamic height widgets + arrangement = widget.arrange(Size(width, container.height)) + else: + arrangement = widget.arrange(Size(width, 0)) + height = arrangement.total_region.height + else: + height = 0 + return height + + def render_keyline(self, container: Widget) -> StripRenderable: + """Render keylines around all widgets. + + Args: + container: The container widget. + + Returns: + A renderable to draw the keylines. + """ + width, height = container.outer_size + canvas = Canvas(width, height) + + line_style, keyline_color = container.styles.keyline + if keyline_color: + keyline_color = container.background_colors[0] + keyline_color + + container_offset = container.content_region.offset + + def get_rectangle(region: Region) -> Rectangle: + """Get a canvas Rectangle that wraps a region. + + Args: + region: Widget region. + + Returns: + A Rectangle that encloses the widget. + """ + offset = region.offset - container_offset - (1, 1) + width, height = region.size + return Rectangle(offset, width + 2, height + 2, keyline_color, line_style) + + primitives = [ + get_rectangle(widget.region) + for widget in container.children + if widget.visible + ] + canvas_renderable = canvas.render(primitives, container.rich_style) + return canvas_renderable diff --git a/src/memray/_vendor/textual/layouts/__init__.py b/src/memray/_vendor/textual/layouts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/memray/_vendor/textual/layouts/factory.py b/src/memray/_vendor/textual/layouts/factory.py new file mode 100644 index 0000000000..26e98ece24 --- /dev/null +++ b/src/memray/_vendor/textual/layouts/factory.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from memray._vendor.textual.layout import Layout +from memray._vendor.textual.layouts.grid import GridLayout +from memray._vendor.textual.layouts.horizontal import HorizontalLayout +from memray._vendor.textual.layouts.stream import StreamLayout +from memray._vendor.textual.layouts.vertical import VerticalLayout + +LAYOUT_MAP: dict[str, type[Layout]] = { + "horizontal": HorizontalLayout, + "grid": GridLayout, + "vertical": VerticalLayout, + "stream": StreamLayout, +} + + +class MissingLayout(Exception): + pass + + +def get_layout(name: str) -> Layout: + """Get a named layout object. + + Args: + name: Name of the layout. + + Raises: + MissingLayout: If the named layout doesn't exist. + + Returns: + A layout object. + """ + + layout_class = LAYOUT_MAP.get(name) + if layout_class is None: + raise MissingLayout(f"no layout called {name!r}, valid layouts") + return layout_class() diff --git a/src/memray/_vendor/textual/layouts/grid.py b/src/memray/_vendor/textual/layouts/grid.py new file mode 100644 index 0000000000..0a04f3c218 --- /dev/null +++ b/src/memray/_vendor/textual/layouts/grid.py @@ -0,0 +1,375 @@ +from __future__ import annotations + +from fractions import Fraction +from typing import TYPE_CHECKING, Iterable + +from memray._vendor.textual._resolve import resolve +from memray._vendor.textual.css.scalar import Scalar +from memray._vendor.textual.geometry import NULL_OFFSET, Region, Size, Spacing +from memray._vendor.textual.layout import ArrangeResult, Layout, WidgetPlacement +from memray._vendor.textual.visual import visualize + +if TYPE_CHECKING: + from memray._vendor.textual.widget import Widget + + +class GridLayout(Layout): + """Used to layout Widgets into a grid.""" + + name = "grid" + + def __init__(self) -> None: + self.min_column_width: int | None = None + """Maintain a minimum column width, or `None` for no minimum.""" + self.max_column_width: int | None = None + """Maintain a maximum column width, or `None` for no maximum.""" + self.stretch_height: bool = False + """Stretch the height of cells to be equal in each row.""" + self.regular: bool = False + """Grid should be regular (no remainder in last row).""" + self.expand: bool = False + """Expand the grid to fit the container if it is smaller.""" + self.shrink: bool = False + """Shrink the grid to fit the container if it is larger.""" + self.auto_minimum: bool = False + """If self.shrink is `True`, auto-detect and limit the width.""" + self._grid_size: tuple[int, int] | None = None + """Grid size after last arrange call.""" + + @property + def grid_size(self) -> tuple[int, int] | None: + """The grid size after the last arrange call. + + Returns: + A tuple of (WIDTH, HEIGHT) or `None` prior to the first `arrange`. + """ + return self._grid_size + + def arrange( + self, parent: Widget, children: list[Widget], size: Size, greedy: bool = True + ) -> ArrangeResult: + parent.pre_layout(self) + styles = parent.styles + row_scalars = styles.grid_rows or ( + [Scalar.parse("1fr")] + if (size.height and not parent.styles.is_auto_height) + else [Scalar.parse("auto")] + ) + column_scalars = styles.grid_columns or [Scalar.parse("1fr")] + gutter_horizontal = styles.grid_gutter_horizontal + gutter_vertical = styles.grid_gutter_vertical + + table_size_columns = max(1, styles.grid_size_columns) + min_column_width = self.min_column_width + max_column_width = self.max_column_width + + container_width = size.width + if max_column_width is not None: + container_width = ( + max(1, min(len(children), (container_width // max_column_width))) + * max_column_width + ) + size = Size(container_width, size.height) + + if min_column_width is not None: + table_size_columns = max( + 1, + (container_width + gutter_horizontal) + // (min_column_width + gutter_horizontal), + ) + + table_size_columns = min(table_size_columns, len(children)) + if self.regular: + while len(children) % table_size_columns and table_size_columns > 1: + table_size_columns -= 1 + + table_size_rows = styles.grid_size_rows + + viewport = parent.app.viewport_size + keyline_style, _keyline_color = styles.keyline + offset = (0, 0) + gutter_spacing: Spacing | None + if keyline_style == "none": + gutter_spacing = None + else: + size -= (2, 2) + offset = (1, 1) + gutter_spacing = Spacing( + gutter_vertical, + gutter_horizontal, + gutter_vertical, + gutter_horizontal, + ) + + def cell_coords(column_count: int) -> Iterable[tuple[int, int]]: + """Iterate over table coordinates ad infinitum. + + Args: + column_count: Number of columns + """ + row = 0 + while True: + for column in range(column_count): + yield (column, row) + row += 1 + + def widget_coords( + column_start: int, row_start: int, columns: int, rows: int + ) -> set[tuple[int, int]]: + """Get coords occupied by a cell. + + Args: + column_start: Start column. + row_start: Start_row. + columns: Number of columns. + rows: Number of rows. + + Returns: + Set of coords. + """ + return { + (column, row) + for column in range(column_start, column_start + columns) + for row in range(row_start, row_start + rows) + } + + def repeat_scalars(scalars: Iterable[Scalar], count: int) -> list[Scalar]: + """Repeat an iterable of scalars as many times as required to return + a list of `count` values. + + Args: + scalars: Iterable of values. + count: Number of values to return. + + Returns: + A list of values. + """ + limited_values = list(scalars)[:] + while len(limited_values) < count: + limited_values.extend(scalars) + return limited_values[:count] + + cell_map: dict[tuple[int, int], tuple[Widget, bool]] = {} + cell_size_map: dict[Widget, tuple[int, int, int, int]] = {} + + next_coord = iter(cell_coords(table_size_columns)).__next__ + cell_coord = (0, 0) + column = row = 0 + + for child in children: + child_styles = child.styles + column_span = child_styles.column_span or 1 + row_span = child_styles.row_span or 1 + # Find a slot where this cell fits + # A cell on a previous row may have a row span + while True: + column, row = cell_coord + coords = widget_coords(column, row, column_span, row_span) + if cell_map.keys().isdisjoint(coords): + for coord in coords: + cell_map[coord] = (child, coord == cell_coord) + cell_size_map[child] = ( + column, + row, + column_span - 1, + row_span - 1, + ) + break + else: + cell_coord = next_coord() + continue + cell_coord = next_coord() + + column_scalars = repeat_scalars(column_scalars, table_size_columns) + table_size_rows = table_size_rows if table_size_rows else row + 1 + row_scalars = repeat_scalars(row_scalars, table_size_rows) + self._grid_size = (table_size_columns, table_size_rows) + + def apply_width_limits(widget: Widget, width: int) -> int: + """Apply min and max widths to dimension. + + Args: + widget: A Widget. + width: A width. + + Returns: + New width. + """ + styles = widget.styles + if styles.min_width is not None: + width = max( + width, + int(styles.min_width.resolve(size, viewport, Fraction(width))), + ) + if styles.max_width is not None: + width = min( + width, + int(styles.max_width.resolve(size, viewport, Fraction(width))), + ) + return width + + def apply_height_limits(widget: Widget, height: int) -> int: + """Apply min and max height to a dimension. + + Args: + widget: A widget. + height: A height. + + Returns: + New height + """ + styles = widget.styles + if styles.min_height is not None: + height = max( + height, + int(styles.min_height.resolve(size, viewport, Fraction(height))), + ) + if styles.max_height is not None: + height = min( + height, + int(styles.max_height.resolve(size, viewport, Fraction(height))), + ) + return height + + # Handle any auto columns + for column, scalar in enumerate(column_scalars): + if scalar.is_auto: + width = 0.0 + for row in range(len(row_scalars)): + coord = (column, row) + try: + widget, _ = cell_map[coord] + except KeyError: + pass + else: + if widget.styles.column_span != 1: + continue + width = max( + width, + apply_width_limits( + widget, + widget.get_content_width(size, viewport) + + widget.styles.gutter.width, + ), + ) + column_scalars[column] = Scalar.from_number(width) + + column_minimums: list[int] | None = None + if self.auto_minimum and self.shrink: + column_minimums = [1] * table_size_columns + for column_index in range(table_size_columns): + for row_index in range(len(row_scalars)): + if ( + cell_info := cell_map.get((column_index, row_index)) + ) is not None: + widget = cell_info[0] + column_minimums[column_index] = max( + visualize(widget, widget.render()).get_minimal_width( + widget.styles + ) + + widget.styles.gutter.width, + column_minimums[column_index], + ) + + columns = resolve( + column_scalars, + size.width, + gutter_vertical, + size, + viewport, + expand=self.expand, + shrink=self.shrink, + minimums=column_minimums, + ) + + # Handle any auto rows + for row, scalar in enumerate(row_scalars): + if scalar.is_auto: + height = 0.0 + for column in range(len(column_scalars)): + coord = (column, row) + try: + widget, _ = cell_map[coord] + except KeyError: + pass + else: + if widget.styles.row_span != 1: + continue + column_width = columns[column][1] + gutter_width, gutter_height = widget.styles.gutter.totals + widget_height = apply_height_limits( + widget, + widget.get_content_height( + size, + viewport, + column_width - gutter_width, + ) + + gutter_height, + ) + height = max(height, widget_height) + + row_scalars[row] = Scalar.from_number(height) + + rows = resolve(row_scalars, size.height, gutter_horizontal, size, viewport) + + placements: list[WidgetPlacement] = [] + _WidgetPlacement = WidgetPlacement + add_placement = placements.append + max_column = len(columns) - 1 + max_row = len(rows) - 1 + + for widget, (column, row, column_span, row_span) in cell_size_map.items(): + x = columns[column][0] + if row > max_row: + break + y = rows[row][0] + x2, cell_width = columns[min(max_column, column + column_span)] + y2, cell_height = rows[min(max_row, row + row_span)] + cell_size = Size(cell_width + x2 - x, cell_height + y2 - y) + + box_width, box_height, margin = widget._get_box_model( + cell_size, + viewport, + Fraction(cell_size.width), + Fraction(cell_size.height), + constrain_width=True, + greedy=greedy, + ) + + if self.stretch_height and len(children) > 1: + if box_height <= cell_size.height: + box_height = Fraction(cell_size.height) + + region = ( + Region( + x, y, int(box_width + margin.width), int(box_height + margin.height) + ) + .crop_size(cell_size) + .shrink(margin) + ) + offset + + widget_styles = widget.styles + placement_offset = ( + widget_styles.offset.resolve(cell_size, viewport) + if widget_styles.has_rule("offset") + else NULL_OFFSET + ) + + absolute = ( + widget_styles.has_rule("position") and styles.position == "absolute" + ) + add_placement( + _WidgetPlacement( + region, + placement_offset, + ( + margin + if gutter_spacing is None + else margin.grow_maximum(gutter_spacing) + ), + widget, + absolute, + ) + ) + + return placements diff --git a/src/memray/_vendor/textual/layouts/horizontal.py b/src/memray/_vendor/textual/layouts/horizontal.py new file mode 100644 index 0000000000..513f27ba43 --- /dev/null +++ b/src/memray/_vendor/textual/layouts/horizontal.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from fractions import Fraction +from typing import TYPE_CHECKING + +from memray._vendor.textual._resolve import resolve_box_models +from memray._vendor.textual.geometry import NULL_OFFSET, Region, Size +from memray._vendor.textual.layout import ArrangeResult, Layout, WidgetPlacement + +if TYPE_CHECKING: + from memray._vendor.textual.geometry import Spacing + from memray._vendor.textual.widget import Widget + + +class HorizontalLayout(Layout): + """Used to layout Widgets horizontally on screen, from left to right. Since Widgets naturally + fill the space of their parent container, all widgets used in a horizontal layout should have a specified. + """ + + name = "horizontal" + + def arrange( + self, parent: Widget, children: list[Widget], size: Size, greedy: bool = True + ) -> ArrangeResult: + parent.pre_layout(self) + placements: list[WidgetPlacement] = [] + add_placement = placements.append + viewport = parent.app.viewport_size + + child_styles = [child.styles for child in children] + box_margins: list[Spacing] = [ + styles.margin for styles in child_styles if styles.overlay != "screen" + ] + if box_margins: + resolve_margin = Size( + sum( + [ + max(margin1[1], margin2[3]) + for margin1, margin2 in zip(box_margins, box_margins[1:]) + ] + ) + + (box_margins[0].left + box_margins[-1].right), + max( + [ + margin_top + margin_bottom + for margin_top, _, margin_bottom, _ in box_margins + ] + ), + ) + else: + resolve_margin = Size(0, 0) + + box_models = resolve_box_models( + [styles.width for styles in child_styles], + children, + size, + viewport, + resolve_margin, + resolve_dimension="width", + greedy=greedy, + ) + + margins = [ + max((box1.margin.right, box2.margin.left)) + for box1, box2 in zip(box_models, box_models[1:]) + ] + if box_models: + margins.append(box_models[-1].margin.right) + + x = next( + ( + Fraction(box_model.margin.left) + for box_model, child in zip(box_models, children) + if child.styles.overlay != "screen" + ), + Fraction(0), + ) + + _Region = Region + _WidgetPlacement = WidgetPlacement + _Size = Size + for widget, (content_width, content_height, box_margin), margin in zip( + children, box_models, margins + ): + styles = widget.styles + overlay = styles.overlay == "screen" + offset = ( + styles.offset.resolve( + _Size(content_width.__floor__(), content_height.__floor__()), + viewport, + ) + if styles.has_rule("offset") + else NULL_OFFSET + ) + offset_y = box_margin.top + next_x = x + content_width + + region = _Region( + x.__floor__(), + offset_y, + (next_x - x.__floor__()).__floor__(), + content_height.__floor__(), + ) + absolute = styles.has_rule("position") and styles.position == "absolute" + add_placement( + _WidgetPlacement( + region, + offset, + box_margin, + widget, + 0, + False, + overlay, + absolute, + ) + ) + if not overlay and not absolute: + x = next_x + margin + + return placements diff --git a/src/memray/_vendor/textual/layouts/stream.py b/src/memray/_vendor/textual/layouts/stream.py new file mode 100644 index 0000000000..82c0e42590 --- /dev/null +++ b/src/memray/_vendor/textual/layouts/stream.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +from itertools import zip_longest +from typing import TYPE_CHECKING + +from memray._vendor.textual.geometry import NULL_OFFSET, Region, Size +from memray._vendor.textual.layout import ArrangeResult, Layout, WidgetPlacement + +if TYPE_CHECKING: + from memray._vendor.textual.widget import Widget + + +class StreamLayout(Layout): + """A cut down version of the vertical layout. + + The stream layout is faster, but has a few limitations compared to the vertical layout. + + - All widgets are the full width (as if their widget is `1fr`). + - All widgets have an effective height of `auto`. + - `max-height` is supported, but only if it is a units value, all other extrema rules are ignored. + - No absolute positioning. + - No overlay: screen. + - Layers are ignored. + - Non TCSS styles are ignored. + + The primary use of `layout: stream` is for a long list of widgets in a scrolling container, such as + what you might expect from a LLM chat-bot. The speed improvement will only be significant with a lot of + child widgets, so stick to vertical layouts unless you see any slowdown. + + """ + + name = "stream" + + def __init__(self) -> None: + self._cached_placements: list[WidgetPlacement] | None = None + self._cached_width = 0 + super().__init__() + + def arrange( + self, parent: Widget, children: list[Widget], size: Size, greedy: bool = True + ) -> ArrangeResult: + parent.pre_layout(self) + if not children: + return [] + viewport = parent.app.viewport_size + + if size.width != self._cached_width: + self._cached_placements = None + previous_results = self._cached_placements or [] + + layout_widgets = parent.screen._layout_widgets.get(parent, []) + + _Region = Region + _WidgetPlacement = WidgetPlacement + + placements: list[WidgetPlacement] = [] + width = size.width + first_child_styles = children[0].styles + y = 0 + previous_margin = first_child_styles.margin.top + null_offset = NULL_OFFSET + + pre_populate = bool(previous_results and layout_widgets) + for widget, placement in zip_longest(children, previous_results): + if pre_populate and placement is not None and widget is placement.widget: + if widget in layout_widgets: + pre_populate = False + else: + placements.append(placement) + y = placement.region.bottom + styles = widget.styles._base_styles + previous_margin = styles.margin.bottom + continue + if widget is None: + break + + styles = widget.styles._base_styles + margin = styles.margin + gutter_width, gutter_height = styles.gutter.totals + top, right, bottom, left = margin + y += top if top > previous_margin else previous_margin + previous_margin = bottom + height = ( + widget.get_content_height(size, viewport, width - gutter_width) + + gutter_height + ) + if (max_height := styles.max_height) is not None and max_height.is_cells: + height = ( + height + if height < (max_height_value := int(max_height.value)) + else max_height_value + ) + if (min_height := styles.min_height) is not None and min_height.is_cells: + height = ( + height + if height > (min_height_value := int(min_height.value)) + else min_height_value + ) + placements.append( + _WidgetPlacement( + _Region(left, y, width - (left + right), height), + null_offset, + margin, + widget, + 0, + False, + False, + False, + ) + ) + y += height + + self._cached_width = size.width + self._cached_placements = placements + return placements + + def get_content_width(self, widget: Widget, container: Size, viewport: Size) -> int: + """Get the optimal content width by arranging children. + + Args: + widget: The container widget. + container: The container size. + viewport: The viewport size. + + Returns: + Width of the content. + """ + return widget.scrollable_content_region.width + + def get_content_height( + self, widget: Widget, container: Size, viewport: Size, width: int + ) -> int: + """Get the content height. + + Args: + widget: The container widget. + container: The container size. + viewport: The viewport. + width: The content width. + + Returns: + Content height (in lines). + """ + if widget._nodes: + arrangement = widget.arrange(Size(width, 0)) + height = arrangement.total_region.height + else: + height = 0 + return height diff --git a/src/memray/_vendor/textual/layouts/vertical.py b/src/memray/_vendor/textual/layouts/vertical.py new file mode 100644 index 0000000000..b6eb184917 --- /dev/null +++ b/src/memray/_vendor/textual/layouts/vertical.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +from fractions import Fraction +from typing import TYPE_CHECKING + +from memray._vendor.textual._resolve import resolve_box_models +from memray._vendor.textual.geometry import NULL_OFFSET, Region, Size +from memray._vendor.textual.layout import ArrangeResult, Layout, WidgetPlacement + +if TYPE_CHECKING: + from memray._vendor.textual.geometry import Spacing + from memray._vendor.textual.widget import Widget + + +class VerticalLayout(Layout): + """Used to layout Widgets vertically on screen, from top to bottom.""" + + name = "vertical" + + def arrange( + self, parent: Widget, children: list[Widget], size: Size, greedy: bool = True + ) -> ArrangeResult: + parent.pre_layout(self) + placements: list[WidgetPlacement] = [] + add_placement = placements.append + viewport = parent.app.viewport_size + + child_styles = [child.styles for child in children] + box_margins: list[Spacing] = [ + styles.margin for styles in child_styles if styles.overlay != "screen" + ] + if box_margins: + resolve_margin = Size( + max( + [ + margin_right + margin_left + for _, margin_right, _, margin_left in box_margins + ] + ), + sum( + [ + bottom if bottom > top else top + for (_, _, bottom, _), (top, _, _, _) in zip( + box_margins, box_margins[1:] + ) + ] + ) + + (box_margins[0].top + box_margins[-1].bottom), + ) + else: + resolve_margin = Size(0, 0) + + box_models = resolve_box_models( + [styles.height for styles in child_styles], + children, + size, + parent.app.size, + resolve_margin, + resolve_dimension="height", + greedy=greedy, + ) + + margins = [ + ( + margin_bottom + if (margin_bottom := margin1.bottom) > (margin_top := margin2.top) + else margin_top + ) + for (_, _, margin1), (_, _, margin2) in zip(box_models, box_models[1:]) + ] + + if box_models: + margins.append(box_models[-1].margin.bottom) + + y = next( + ( + Fraction(box_model.margin.top) + for box_model, child in zip(box_models, children) + if child.styles.overlay != "screen" + ), + Fraction(0), + ) + + _Region = Region + _WidgetPlacement = WidgetPlacement + _Size = Size + for widget, (content_width, content_height, box_margin), margin in zip( + children, box_models, margins + ): + styles = widget.styles + overlay = styles.overlay == "screen" + next_y = y + content_height + offset = ( + styles.offset.resolve( + _Size(content_width.__floor__(), content_height.__floor__()), + viewport, + ) + if styles.has_rule("offset") + else NULL_OFFSET + ) + + region = _Region( + box_margin.left, + y.__floor__(), + content_width.__floor__(), + next_y.__floor__() - y.__floor__(), + ) + absolute = styles.has_rule("position") and styles.position == "absolute" + add_placement( + _WidgetPlacement( + region, + offset, + box_margin, + widget, + 0, + False, + overlay, + absolute, + ) + ) + if not overlay and not absolute: + y = next_y + margin + + return placements diff --git a/src/memray/_vendor/textual/lazy.py b/src/memray/_vendor/textual/lazy.py new file mode 100644 index 0000000000..66a37c4222 --- /dev/null +++ b/src/memray/_vendor/textual/lazy.py @@ -0,0 +1,141 @@ +""" +Tools for lazy loading widgets. +""" + +from __future__ import annotations + +from memray._vendor.textual.widget import Widget + + +class Lazy(Widget): + """Wraps a widget so that it is mounted *lazily*. + + Lazy widgets are mounted after the first refresh. This can be used to display some parts of + the UI very quickly, followed by the lazy widgets. Technically, this won't make anything + faster, but it reduces the time the user sees a blank screen and will make apps feel + more responsive. + + Making a widget lazy is beneficial for widgets which start out invisible, such as tab panes. + + Note that since lazy widgets aren't mounted immediately (by definition), they will not appear + in queries for a brief interval until they are mounted. Your code should take this into account. + + Example: + ```python + def compose(self) -> ComposeResult: + yield Footer() + with ColorTabs("Theme Colors", "Named Colors"): + yield Content(ThemeColorButtons(), ThemeColorsView(), id="theme") + yield Lazy(NamedColorsView()) + ``` + + """ + + DEFAULT_CSS = """ + Lazy { + display: none; + } + """ + + def __init__(self, widget: Widget) -> None: + """Create a lazy widget. + + Args: + widget: A widget that should be mounted after a refresh. + """ + self._replace_widget = widget + super().__init__() + + def compose_add_child(self, widget: Widget) -> None: + self._replace_widget.compose_add_child(widget) + + async def mount_composed_widgets(self, widgets: list[Widget]) -> None: + parent = self.parent + if parent is None: + return + assert isinstance(parent, Widget) + + async def mount() -> None: + """Perform the mount and discard the lazy widget.""" + await parent.mount(self._replace_widget, after=self) + await self.remove() + + self.call_after_refresh(mount) + + +class Reveal(Widget): + """Similar to [Lazy][textual.lazy.Lazy], but mounts children sequentially. + + This is useful when you have so many child widgets that there is a noticeable delay before + you see anything. By mounting the children over several frames, the user will feel that + something is happening. + + Example: + ```python + def compose(self) -> ComposeResult: + with lazy.Reveal(containers.VerticalScroll(can_focus=False)): + yield Markdown(WIDGETS_MD, classes="column") + yield Buttons() + yield Checkboxes() + yield Datatables() + yield Inputs() + yield ListViews() + yield Logs() + yield Sparklines() + yield Footer() + ``` + """ + + DEFAULT_CSS = """ + Reveal { + display: none; + } + """ + + def __init__(self, widget: Widget) -> None: + """ + Args: + widget: A widget to mount. + """ + self._replace_widget = widget + self._widgets: list[Widget] = [] + super().__init__() + + @classmethod + def _reveal(cls, parent: Widget, widgets: list[Widget]) -> None: + """Reveal children lazily. + + Args: + parent: The parent widget. + widgets: Child widgets. + """ + + async def check_children() -> None: + """Check for pending children""" + if not widgets: + return + widget = widgets.pop(0) + try: + await parent.mount(widget) + except Exception: + # I think this can occur if the parent is removed before all children are added + # Only noticed this on shutdown + return + + if widgets: + parent.set_timer(0.02, check_children) + + parent.call_next(check_children) + + def compose_add_child(self, widget: Widget) -> None: + self._widgets.append(widget) + + async def mount_composed_widgets(self, widgets: list[Widget]) -> None: + parent = self.parent + if parent is None: + return + assert isinstance(parent, Widget) + await parent.mount(self._replace_widget, after=self) + await self.remove() + self._reveal(self._replace_widget, self._widgets.copy()) + self._widgets.clear() diff --git a/src/memray/_vendor/textual/logging.py b/src/memray/_vendor/textual/logging.py new file mode 100644 index 0000000000..9ecdfc59cc --- /dev/null +++ b/src/memray/_vendor/textual/logging.py @@ -0,0 +1,40 @@ +""" +A Textual Logging handler. + +If there is an active Textual app, then log messages will go via the app (and logged via textual console). + +If there is *no* active app, then log messages will go to stderr or stdout, depending on configuration. +""" + +import sys +from logging import Handler, LogRecord + +from memray._vendor.textual._context import active_app + + +class TextualHandler(Handler): + """A Logging handler for Textual apps.""" + + def __init__(self, stderr: bool = True, stdout: bool = False) -> None: + """Initialize a Textual logging handler. + + Args: + stderr: Log to stderr when there is no active app. + stdout: Log to stdout when there is no active app. + """ + super().__init__() + self._stderr = stderr + self._stdout = stdout + + def emit(self, record: LogRecord) -> None: + """Invoked by logging.""" + message = self.format(record) + try: + app = active_app.get() + except LookupError: + if self._stderr: + print(message, file=sys.stderr) + elif self._stdout: + print(message, file=sys.stdout) + else: + app.log.logging(message) diff --git a/src/memray/_vendor/textual/map_geometry.py b/src/memray/_vendor/textual/map_geometry.py new file mode 100644 index 0000000000..d1a51bb68c --- /dev/null +++ b/src/memray/_vendor/textual/map_geometry.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from typing import NamedTuple + +from memray._vendor.textual.geometry import Region, Size, Spacing + + +class MapGeometry(NamedTuple): + """Defines the absolute location of a Widget.""" + + region: Region + """The (screen) [region][textual.geometry.Region] occupied by the widget.""" + order: tuple[tuple[int, int, int], ...] + """Tuple of tuples defining the painting order of the widget. + + Each successive triple represents painting order information with regards to + ancestors in the DOM hierarchy and the last triple provides painting order + information for this specific widget. + """ + clip: Region + """A [region][textual.geometry.Region] to clip the widget by (if a Widget is within a container).""" + virtual_size: Size + """The virtual [size][textual.geometry.Size] (scrollable area) of a widget if it is a container.""" + container_size: Size + """The container [size][textual.geometry.Size] (area not occupied by scrollbars).""" + virtual_region: Region + """The [region][textual.geometry.Region] relative to the container (but not necessarily visible).""" + dock_gutter: Spacing + """Space from the container reserved by docked widgets.""" + + @property + def visible_region(self) -> Region: + """The Widget region after clipping.""" + return self.clip.intersection(self.region) diff --git a/src/memray/_vendor/textual/markup.py b/src/memray/_vendor/textual/markup.py new file mode 100644 index 0000000000..625683d42a --- /dev/null +++ b/src/memray/_vendor/textual/markup.py @@ -0,0 +1,463 @@ +""" +Utilities related to content markup. + +""" + +from __future__ import annotations + +from operator import itemgetter + +from memray._vendor.textual.css.parse import substitute_references +from memray._vendor.textual.css.tokenizer import UnexpectedEnd + +__all__ = ["MarkupError", "escape", "to_content"] + +import re +from string import Template +from typing import TYPE_CHECKING, Callable, Mapping, Match + +from memray._vendor.textual._context import active_app +from memray._vendor.textual.color import Color +from memray._vendor.textual.css.tokenize import ( + COLOR, + PERCENT, + TOKEN, + VARIABLE_REF, + Expect, + TokenizerState, + tokenize_values, +) +from memray._vendor.textual.style import Style + +if TYPE_CHECKING: + from memray._vendor.textual.content import Content + + +class MarkupError(Exception): + """An error occurred parsing content markup.""" + + +expect_markup_tag = ( + Expect( + "markup style value", + end_tag=r"(? str: + """Escapes text so that it won't be interpreted as markup. + + Args: + markup (str): Content to be inserted in to markup. + + Returns: + str: Markup with square brackets escaped. + """ + + def escape_backslashes(match: Match[str]) -> str: + """Called by re.sub replace matches.""" + backslashes, text = match.groups() + return f"{backslashes}{backslashes}\\{text}" + + markup = _escape(escape_backslashes, markup) + if markup.endswith("\\") and not markup.endswith("\\\\"): + return markup + "\\" + + return markup + + +def parse_style(style: str, variables: dict[str, str] | None = None) -> Style: + """Parse a style with substituted variables. + + Args: + style: Style encoded in a string. + variables: Mapping of variables, or `None` to import from active app. + + Returns: + A Style object. + """ + + styles: dict[str, bool | None] = {} + color: Color | None = None + background: Color | None = None + is_background: bool = False + style_state: bool = True + + tokenizer = StyleTokenizer() + meta = {} + + if variables is None: + try: + app = active_app.get() + except LookupError: + reference_tokens = {} + else: + reference_tokens = app.stylesheet._variable_tokens + else: + reference_tokens = tokenize_values(variables) + + iter_tokens = iter( + substitute_references( + tokenizer(style, ("inline style", "")), + reference_tokens, + ) + ) + + for token in iter_tokens: + token_name = token.name + token_value = token.value + if token_name == "key": + key = token_value.rstrip("=") + parenthesis: list[str] = [] + value_text: list[str] = [] + first_token = next(iter_tokens) + if first_token.name in {"double_string", "single_string"}: + meta[key] = first_token.value[1:-1] + break + else: + value_text.append(first_token.value) + for token in iter_tokens: + if token.name == "whitespace" and not parenthesis: + break + value_text.append(token.value) + if token.name in {"round_start", "square_start", "curly_start"}: + parenthesis.append(token.value) + elif token.name in {"round_end", "square_end", "curly_end"}: + parenthesis.pop() + if not parenthesis: + break + tokenizer.expect(StyleTokenizer.EXPECT) + + value = "".join(value_text) + meta[key] = value + + elif token_name == "color": + if is_background: + background = Color.parse(token.value) + else: + color = Color.parse(token.value) + + elif token_name == "token": + if token_value == "link": + if "link" not in meta: + meta["link"] = "" + elif token_value == "on": + is_background = True + elif token_value == "auto": + if is_background: + background = Color.automatic() + else: + color = Color.automatic() + elif token_value == "not": + style_state = False + elif token_value in STYLES: + styles[token_value] = style_state + style_state = True + elif token_value in STYLE_ABBREVIATIONS: + styles[STYLE_ABBREVIATIONS[token_value]] = style_state + style_state = True + else: + if is_background: + background = Color.parse(token_value) + else: + color = Color.parse(token_value) + + elif token_name == "percent": + percent = int(token_value.rstrip("%")) / 100.0 + if is_background: + if background is not None: + background = background.multiply_alpha(percent) + else: + if color is not None: + color = color.multiply_alpha(percent) + + parsed_style = Style(background, color, link=meta.pop("link", None), **styles) + + if meta: + parsed_style += Style.from_meta(meta) + return parsed_style + + +def to_content( + markup: str, + style: str | Style = "", + template_variables: Mapping[str, object] | None = None, +) -> Content: + """Convert markup to Content. + + Args: + markup: String containing markup. + style: Optional base style. + template_variables: Mapping of string.Template variables + + Raises: + MarkupError: If the markup is invalid. + + Returns: + Content that renders the markup. + """ + _rich_traceback_omit = True + try: + return _to_content(markup, style, template_variables) + except UnexpectedEnd: + raise MarkupError( + "Unexpected end of markup; are you missing a closing square bracket?" + ) from None + except Exception as error: + # Ensure all errors are wrapped in a MarkupError + raise MarkupError(str(error)) from None + + +def _to_content( + markup: str, + style: str | Style = "", + template_variables: Mapping[str, object] | None = None, +) -> Content: + """Internal function to convert markup to Content. + + Args: + markup: String containing markup. + style: Optional base style. + template_variables: Mapping of string.Template variables + + Raises: + MarkupError: If the markup is invalid. + + Returns: + Content that renders the markup. + """ + + from memray._vendor.textual.content import Content, Span + + tokenizer = MarkupTokenizer() + text: list[str] = [] + text_append = text.append + iter_tokens = iter(tokenizer(markup, ("inline", ""))) + + style_stack: list[tuple[int, str, str]] = [] + + spans: list[Span] = [] + + position = 0 + tag_text: list[str] + + normalize_markup_tag = Style._normalize_markup_tag + + if template_variables is None: + process_text = lambda text: text + + else: + + def process_text(template_text: str, /) -> str: + if "$" in template_text: + return Template(template_text).safe_substitute(template_variables) + return template_text + + for token in iter_tokens: + token_name = token.name + if token_name == "text": + value = process_text(token.value.replace("\\[", "[")) + text_append(value) + position += len(value) + + elif token_name == "open_tag": + tag_text = [] + + eof = False + contains_text = False + for token in iter_tokens: + if token.name == "end_tag": + break + elif token.name == "text": + contains_text = True + elif token.name == "eof": + eof = True + tag_text.append(token.value) + if contains_text or eof: + # "tag" was unparsable + text_content = f"[{''.join(tag_text)}" + ("" if eof else "]") + text_append(text_content) + position += len(text_content) + else: + opening_tag = "".join(tag_text) + + if not opening_tag.strip(): + blank_tag = f"[{opening_tag}]" + text_append(blank_tag) + position += len(blank_tag) + else: + style_stack.append( + ( + position, + opening_tag, + normalize_markup_tag(opening_tag.strip()), + ) + ) + + elif token_name == "open_closing_tag": + tag_text = [] + for token in iter_tokens: + if token.name == "end_tag": + break + tag_text.append(token.value) + closing_tag = "".join(tag_text).strip() + normalized_closing_tag = normalize_markup_tag(closing_tag) + if normalized_closing_tag: + for index, (tag_position, tag_body, normalized_tag_body) in enumerate( + reversed(style_stack), 1 + ): + if normalized_tag_body == normalized_closing_tag: + style_stack.pop(-index) + if tag_position != position: + spans.append(Span(tag_position, position, tag_body)) + break + else: + raise MarkupError( + f"closing tag '[/{closing_tag}]' does not match any open tag" + ) + + else: + if not style_stack: + raise MarkupError("auto closing tag ('[/]') has nothing to close") + open_position, tag_body, _ = style_stack.pop() + if open_position != position: + spans.append(Span(open_position, position, tag_body)) + + content_text = "".join(text) + text_length = len(content_text) + if style_stack and text_length: + spans.extend( + [ + Span(position, text_length, tag_body) + for position, tag_body, _ in reversed(style_stack) + if position != text_length + ] + ) + spans.reverse() + spans.sort(key=itemgetter(0)) # Zeroth item of Span is 'start' attribute + + content = Content( + content_text, + [Span(0, text_length, style), *spans] if (style and text_length) else spans, + ) + + return content + + +if __name__ == "__main__": # pragma: no cover + from memray._vendor.textual._markup_playground import MarkupPlayground + + app = MarkupPlayground() + app.run() diff --git a/src/memray/_vendor/textual/message.py b/src/memray/_vendor/textual/message.py new file mode 100644 index 0000000000..2e69f2cd90 --- /dev/null +++ b/src/memray/_vendor/textual/message.py @@ -0,0 +1,158 @@ +""" + +The base class for all messages (including events). +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar + +import rich.repr +from typing_extensions import Self + +from memray._vendor.textual import _time +from memray._vendor.textual._context import active_message_pump +from memray._vendor.textual.case import camel_to_snake + +if TYPE_CHECKING: + from memray._vendor.textual.dom import DOMNode + from memray._vendor.textual.message_pump import MessagePump + + +@rich.repr.auto +class Message: + """Base class for a message.""" + + __slots__ = [ + "_sender", + "time", + "_forwarded", + "_no_default_action", + "_stop_propagation", + "_prevent", + ] + + ALLOW_SELECTOR_MATCH: ClassVar[set[str]] = set() + """Additional attributes that can be used with the [`on` decorator][textual.on]. + + These attributes must be widgets. + """ + bubble: ClassVar[bool] = True # Message will bubble to parent + verbose: ClassVar[bool] = False # Message is verbose + no_dispatch: ClassVar[bool] = False # Message may not be handled by client code + namespace: ClassVar[str] = "" # Namespace to disambiguate messages + handler_name: ClassVar[str] + """Name of the default message handler.""" + + def __init__(self) -> None: + self.__post_init__() + + def __post_init__(self) -> None: + """Allow dataclasses to initialize the object.""" + self._sender: MessagePump | None = active_message_pump.get(None) + self.time: float = _time.get_time() + self._forwarded = False + self._no_default_action = False + self._stop_propagation = False + self._prevent: set[type[Message]] = set() + + def __rich_repr__(self) -> rich.repr.Result: + yield from () + + def __init_subclass__( + cls, + bubble: bool | None = True, + verbose: bool = False, + no_dispatch: bool | None = False, + namespace: str | None = None, + ) -> None: + super().__init_subclass__() + if bubble is not None: + cls.bubble = bubble + cls.verbose = verbose + if no_dispatch is not None: + cls.no_dispatch = no_dispatch + if namespace is not None: + cls.namespace = namespace + name = f"{namespace}_{camel_to_snake(cls.__name__)}" + else: + # a class defined inside of a function will have a qualified name like func..Class, + # so make sure we only use the actual class name(s) + qualname = cls.__qualname__.rsplit(".", 1)[-1] + # only keep the last two parts of the qualified name of deeply nested classes + # for backwards compatibility, e.g. A.B.C.D becomes C.D + namespace = qualname.rsplit(".", 2)[-2:] + name = "_".join(camel_to_snake(part) for part in namespace) + cls.handler_name = f"on_{name}" + + @property + def control(self) -> DOMNode | None: + """The widget associated with this message, or None by default.""" + return None + + @property + def is_forwarded(self) -> bool: + """Has the message been forwarded?""" + return self._forwarded + + def _set_forwarded(self) -> None: + """Mark this event as being forwarded.""" + self._forwarded = True + + def set_sender(self, sender: MessagePump) -> Self: + """Set the sender of the message. + + Args: + sender: The sender. + + Note: + When creating a message the sender is automatically set. + Normally there will be no need for this method to be called. + This method will be used when strict control is required over + the sender of a message. + + Returns: + Self. + """ + self._sender = sender + return self + + def can_replace(self, message: "Message") -> bool: + """Check if another message may supersede this one. + + Args: + message: Another message. + + Returns: + True if this message may replace the given message + """ + return False + + def prevent_default(self, prevent: bool = True) -> Message: + """Suppress the default action(s). This will prevent handlers in any base classes + from being called. + + Args: + prevent: True if the default action should be suppressed, + or False if the default actions should be performed. + """ + self._no_default_action = prevent + return self + + def stop(self, stop: bool = True) -> Message: + """Stop propagation of the message to parent. + + Args: + stop: The stop flag. + """ + self._stop_propagation = stop + return self + + def _bubble_to(self, widget: MessagePump) -> None: + """Bubble to a widget (typically the parent). + + Args: + widget: Target of bubble. + """ + self._no_default_action = False + widget.post_message(self) diff --git a/src/memray/_vendor/textual/message_pump.py b/src/memray/_vendor/textual/message_pump.py new file mode 100644 index 0000000000..7ddd2e51ab --- /dev/null +++ b/src/memray/_vendor/textual/message_pump.py @@ -0,0 +1,920 @@ +""" + +A `MessagePump` is a base class for any object which processes messages, which includes Widget, Screen, and App. + +!!! tip + + Most of the method here are useful in general app development. + +""" + +from __future__ import annotations + +import asyncio +import threading +from asyncio import CancelledError, QueueEmpty, Task, create_task +from contextlib import contextmanager +from functools import partial +from time import perf_counter +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Generator, + Iterable, + Type, + TypeVar, + cast, +) +from weakref import WeakSet, ref + +from memray._vendor.textual import Logger, events, log, messages +from memray._vendor.textual._callback import invoke +from memray._vendor.textual._compat import cached_property +from memray._vendor.textual._context import NoActiveAppError, active_app, active_message_pump +from memray._vendor.textual._context import message_hook as message_hook_context_var +from memray._vendor.textual._context import prevent_message_types_stack +from memray._vendor.textual._on import OnNoWidget +from memray._vendor.textual._queue import Queue +from memray._vendor.textual._time import time +from memray._vendor.textual.constants import SLOW_THRESHOLD +from memray._vendor.textual.css.match import match +from memray._vendor.textual.events import Event +from memray._vendor.textual.message import Message +from memray._vendor.textual.reactive import Reactive, TooManyComputesError +from memray._vendor.textual.signal import Signal +from memray._vendor.textual.timer import Timer, TimerCallback + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from memray._vendor.textual.app import App + from memray._vendor.textual.css.model import SelectorSet + + +Callback: TypeAlias = "Callable[..., Any] | Callable[..., Awaitable[Any]]" + + +class CallbackError(Exception): + pass + + +class MessagePumpClosed(Exception): + pass + + +_MessagePumpMetaSub = TypeVar("_MessagePumpMetaSub", bound="_MessagePumpMeta") + + +class _MessagePumpMeta(type): + """Metaclass for message pump. This exists to populate a Message inner class of a Widget with the + parent classes' name. + """ + + def __new__( + cls: Type[_MessagePumpMetaSub], + name: str, + bases: tuple[type, ...], + class_dict: dict[str, Any], + **kwargs: Any, + ) -> _MessagePumpMetaSub: + handlers: dict[ + type[Message], list[tuple[Callable, dict[str, tuple[SelectorSet, ...]]]] + ] = class_dict.get("_decorated_handlers", {}) + + class_dict["_decorated_handlers"] = handlers + + for value in class_dict.values(): + if callable(value) and hasattr(value, "_textual_on"): + textual_on: list[ + tuple[type[Message], dict[str, tuple[SelectorSet, ...]]] + ] = getattr(value, "_textual_on") + for message_type, selectors in textual_on: + handlers.setdefault(message_type, []).append((value, selectors)) + + # Look for reactives with public AND private compute methods. + prefix = "compute_" + prefix_len = len(prefix) + for attr_name, value in class_dict.items(): + if attr_name.startswith(prefix) and callable(value): + reactive_name = attr_name[prefix_len:] + if ( + reactive_name in class_dict + and isinstance(class_dict[reactive_name], Reactive) + and f"_{attr_name}" in class_dict + ): + raise TooManyComputesError( + f"reactive {reactive_name!r} can't have two computes." + ) + + class_obj = super().__new__(cls, name, bases, class_dict, **kwargs) + return class_obj + + +class MessagePump(metaclass=_MessagePumpMeta): + """Base class which supplies a message pump.""" + + def __init__(self, parent: MessagePump | None = None) -> None: + self._parent = parent + self._running: bool = False + self._closing: bool = False + self._closed: bool = False + self._disabled_messages: set[type[Message]] = set() + self._pending_message: Message | None = None + self._task: Task | None = None + self._timers: WeakSet[Timer] = WeakSet() + self._last_idle: float = time() + self._max_idle: float | None = None + self._is_mounted = False + """Having this explicit Boolean is an optimization. + + The same information could be retrieved from `self._mounted_event.is_set()`, but + we need to access this frequently in the compositor and the attribute with the + explicit Boolean value is faster than the two lookups and the function call. + """ + self._next_callbacks: list[events.Callback] = [] + self._thread_id: int = threading.get_ident() + self._prevented_messages_on_mount = self._prevent_message_types_stack[-1] + self.message_signal: Signal[Message] = Signal(self, "messages") + """Subscribe to this signal to be notified of all messages sent to this widget. + + This is a fairly low-level mechanism, and shouldn't replace regular message handling. + + """ + + @property + def _parent(self) -> MessagePump | None: + """The current parent message pump (if set).""" + return None if self.__parent is None else self.__parent() + + @_parent.setter + def _parent(self, parent: MessagePump | None) -> None: + self.__parent = None if parent is None else ref(parent) + + @cached_property + def _message_queue(self) -> Queue[Message | None]: + return Queue() + + @cached_property + def _mounted_event(self) -> asyncio.Event: + return asyncio.Event() + + @property + def _prevent_message_types_stack(self) -> list[set[type[Message]]]: + """The stack that manages prevented messages.""" + try: + stack = prevent_message_types_stack.get() + except LookupError: + stack = [set()] + prevent_message_types_stack.set(stack) + return stack + + def _thread_init(self): + """Initialize threading primitives for the current thread. + + Require for Python3.8 https://github.com/Textualize/textual/issues/5845 + + """ + self._message_queue + self._mounted_event + + def _get_prevented_messages(self) -> set[type[Message]]: + """A set of all the prevented message types.""" + return self._prevent_message_types_stack[-1] + + def _is_prevented(self, message_type: type[Message]) -> bool: + """Check if a message type has been prevented via the + [prevent][textual.message_pump.MessagePump.prevent] context manager. + + Args: + message_type: A message type. + + Returns: + `True` if the message has been prevented from sending, or `False` if it will be sent as normal. + """ + return message_type in self._prevent_message_types_stack[-1] + + @contextmanager + def prevent(self, *message_types: type[Message]) -> Generator[None, None, None]: + """A context manager to *temporarily* prevent the given message types from being posted. + + Example: + ```python + input = self.query_one(Input) + with self.prevent(Input.Changed): + input.value = "foo" + ``` + """ + if message_types: + prevent_stack = self._prevent_message_types_stack + prevent_stack.append(prevent_stack[-1].union(message_types)) + try: + yield + finally: + prevent_stack.pop() + else: + yield + + @property + def task(self) -> Task: + assert self._task is not None + return self._task + + @property + def has_parent(self) -> bool: + """Does this object have a parent?""" + return self._parent is not None + + @property + def message_queue_size(self) -> int: + """The current size of the message queue.""" + return self._message_queue.qsize() + + @property + def is_dom_root(self): + """Is this a root node (i.e. the App)?""" + return False + + if TYPE_CHECKING: + from memray._vendor.textual import getters + + app = getters.app(App) + else: + + @property + def app(self) -> "App[object]": + """ + Get the current app. + + Returns: + The current app. + + Raises: + NoActiveAppError: if no active app could be found for the current asyncio context + """ + try: + return active_app.get() + except LookupError: + from memray._vendor.textual.app import App + + node: MessagePump | None = self + while not isinstance(node, App): + if node is None: + raise NoActiveAppError() + node = node._parent + + return node + + @property + def is_attached(self) -> bool: + """Is this node linked to the app through the DOM?""" + try: + if self.app._exit: + return False + except NoActiveAppError: + return False + node: MessagePump | None = self + while (node := node._parent) is not None: + if node.is_dom_root: + return True + return False + + @property + def is_parent_active(self) -> bool: + """Is the parent active?""" + parent = self._parent + return bool(parent is not None and not parent._closed and not parent._closing) + + @property + def is_running(self) -> bool: + """Is the message pump running (potentially processing messages)?""" + return self._running + + @property + def log(self) -> Logger: + """Get a logger for this object. + + Returns: + A logger. + """ + return self.app._logger + + def _attach(self, parent: MessagePump) -> None: + """Set the parent, and therefore attach this node to the tree. + + Args: + parent: Parent node. + """ + self._parent = parent + + def _detach(self) -> None: + """Set the parent to None to remove the node from the tree.""" + self._parent = None + + def check_message_enabled(self, message: Message) -> bool: + """Check if a given message is enabled (allowed to be sent). + + Args: + message: A message object. + + Returns: + `True` if the message will be sent, or `False` if it is disabled. + """ + + return type(message) not in self._disabled_messages + + def disable_messages(self, *messages: type[Message]) -> None: + """Disable message types from being processed.""" + self._disabled_messages.update(messages) + + def enable_messages(self, *messages: type[Message]) -> None: + """Enable processing of messages types.""" + self._disabled_messages.difference_update(messages) + + async def _get_message(self) -> Message: + """Get the next event on the queue, or None if queue is closed. + + Returns: + Event object or None. + """ + if self._closed: + raise MessagePumpClosed("The message pump is closed") + if self._pending_message is not None: + try: + return self._pending_message + finally: + self._pending_message = None + + message = await self._message_queue.get() + + if message is None: + self._closed = True + raise MessagePumpClosed("The message pump is now closed") + return message + + def _peek_message(self) -> Message | None: + """Peek the message at the head of the queue (does not remove it from the queue), + or return None if the queue is empty. + + Returns: + The message or None. + """ + if self._pending_message is None: + try: + message = self._message_queue.get_nowait() + except QueueEmpty: + pass + else: + if message is None: + self._closed = True + raise MessagePumpClosed("The message pump is now closed") + self._pending_message = message + + if self._pending_message is not None: + return self._pending_message + return None + + def set_timer( + self, + delay: float, + callback: TimerCallback | None = None, + *, + name: str | None = None, + pause: bool = False, + ) -> Timer: + """Call a function after a delay. + + Example: + ```python + def ready(): + self.notify("Your soft boiled egg is ready!") + # Call ready() after 3 minutes + self.set_timer(3 * 60, ready) + ``` + + Args: + delay: Time (in seconds) to wait before invoking callback. + callback: Callback to call after time has expired. + name: Name of the timer (for debug). + pause: Start timer paused. + + Returns: + A timer object. + """ + + timer = Timer( + self, + delay, + name=name or f"set_timer#{Timer._timer_count}", + callback=None if callback is None else partial(self.call_next, callback), + repeat=0, + pause=pause, + ) + timer._start() + self._timers.add(timer) + return timer + + def set_interval( + self, + interval: float, + callback: TimerCallback | None = None, + *, + name: str | None = None, + repeat: int = 0, + pause: bool = False, + ) -> Timer: + """Call a function at periodic intervals. + + Args: + interval: Time (in seconds) between calls. + callback: Function to call. + name: Name of the timer object. + repeat: Number of times to repeat the call or 0 for continuous. + pause: Start the timer paused. + + Returns: + A timer object. + """ + timer = Timer( + self, + interval, + name=name or f"set_interval#{Timer._timer_count}", + callback=callback, + repeat=repeat or None, + pause=pause, + ) + timer._start() + self._timers.add(timer) + return timer + + def call_after_refresh(self, callback: Callback, *args: Any, **kwargs: Any) -> bool: + """Schedule a callback to run after all messages are processed and the screen + has been refreshed. Positional and keyword arguments are passed to the callable. + + Args: + callback: A callable. + + Returns: + `True` if the callback was scheduled, or `False` if the callback could not be + scheduled (may occur if the message pump was closed or closing). + + """ + # We send the InvokeLater message to ourselves first, to ensure we've cleared + # out anything already pending in our own queue. + + message = messages.InvokeLater(partial(callback, *args, **kwargs)) + return self.post_message(message) + + async def wait_for_refresh(self) -> bool: + """Wait for the next refresh. + + This method should only be called from a task other than the one running this widget. + If called from the same task, it will return immediately to avoid blocking the event loop. + + Returns: + `True` if waiting for refresh was successful, or `False` if the call was a null-op + due to calling it within the node's own task. + + """ + assert ( + self._task is not None + ), "Node must be running before calling wait_for_refresh" + if asyncio.current_task() is self._task: + return False + refreshed_event = asyncio.Event() + self.call_after_refresh(refreshed_event.set) + await refreshed_event.wait() + return True + + def call_later(self, callback: Callback, *args: Any, **kwargs: Any) -> bool: + """Schedule a callback to run after all messages are processed in this object. + Positional and keywords arguments are passed to the callable. + + Args: + callback: Callable to call next. + *args: Positional arguments to pass to the callable. + **kwargs: Keyword arguments to pass to the callable. + + Returns: + `True` if the callback was scheduled, or `False` if the callback could not be + scheduled (may occur if the message pump was closed or closing). + + """ + message = events.Callback(callback=partial(callback, *args, **kwargs)) + return self.post_message(message) + + def call_next(self, callback: Callback, *args: Any, **kwargs: Any) -> None: + """Schedule a callback to run immediately after processing the current message. + + Args: + callback: Callable to run after current event. + *args: Positional arguments to pass to the callable. + **kwargs: Keyword arguments to pass to the callable. + """ + assert callback is not None, "Callback must not be None" + callback_message = events.Callback(callback=partial(callback, *args, **kwargs)) + callback_message._prevent.update(self._get_prevented_messages()) + self._next_callbacks.append(callback_message) + self.check_idle() + + def _on_invoke_later(self, message: messages.InvokeLater) -> None: + # Forward InvokeLater message to the Screen + if self.app._running: + self.app.screen._invoke_later( + message.callback, message._sender or active_message_pump.get() + ) + + async def _close_messages(self, wait: bool = True) -> None: + """Close message queue, and optionally wait for queue to finish processing.""" + if self._closed or self._closing: + return + self._closing = True + if self._timers: + await Timer._stop_all(self._timers) + self._timers.clear() + Reactive._reset_object(self) + self._message_queue.put_nowait(None) + if wait and self._task is not None and asyncio.current_task() != self._task: + try: + running_widget = active_message_pump.get() + except LookupError: + running_widget = None + + if running_widget is None or running_widget is not self: + try: + await self._task + except CancelledError: + pass + + def _start_messages(self) -> None: + """Start messages task.""" + self._thread_init() + + if self.app._running: + self._task = create_task( + self._process_messages(), name=f"message pump {self}" + ) + else: + self._closing = True + self._closed = True + + async def _process_messages(self) -> None: + self._running = True + + with self._context(): + if not await self._pre_process(): + self._running = False + return + + try: + await self._process_messages_loop() + except CancelledError: + pass + finally: + self._running = False + try: + if self._timers: + await Timer._stop_all(self._timers) + self._timers.clear() + Reactive._clear_watchers(self) + finally: + await self._message_loop_exit() + self._task = None + + async def _message_loop_exit(self) -> None: + """Called when the message loop has completed.""" + + async def _pre_process(self) -> bool: + """Procedure to run before processing messages. + + Returns: + `True` if successful, or `False` if any exception occurred. + + """ + # Dispatch compose and mount messages without going through loop + # These events must occur in this order, and at the start. + + try: + await self._dispatch_message(events.Compose()) + if self._prevented_messages_on_mount: + with self.prevent(*self._prevented_messages_on_mount): + await self._dispatch_message(events.Mount()) + else: + await self._dispatch_message(events.Mount()) + self._post_mount() + except Exception as error: + self.app._handle_exception(error) + return False + finally: + # This is critical, mount may be waiting + self._mounted_event.set() + self._is_mounted = True + return True + + def _post_mount(self): + """Called after the object has been mounted.""" + + def _close_messages_no_wait(self) -> None: + """Request the message queue to immediately exit.""" + self._message_queue.put_nowait(messages.CloseMessages()) + + @contextmanager + def _context(self) -> Generator[None, None, None]: + """Context manager to set ContextVars.""" + reset_token = active_message_pump.set(self) + try: + yield + finally: + active_message_pump.reset(reset_token) + + async def _on_close_messages(self, message: messages.CloseMessages) -> None: + await self._close_messages() + + async def _process_messages_loop(self) -> None: + """Process messages until the queue is closed.""" + _rich_traceback_guard = True + self._thread_id = threading.get_ident() + await asyncio.sleep(0) + while not self._closed: + try: + message = await self._get_message() + except MessagePumpClosed: + break + except CancelledError: + raise + except Exception as error: + raise error from None + + # Combine any pending messages that may supersede this one + while not (self._closed or self._closing): + try: + pending = self._peek_message() + except MessagePumpClosed: + break + if pending is None or not message.can_replace(pending): + break + try: + message = await self._get_message() + except MessagePumpClosed: + break + + try: + await self._dispatch_message(message) + except CancelledError: + raise + except Exception as error: + self._mounted_event.set() + self._is_mounted = True + self.app._handle_exception(error) + break + finally: + self.message_signal.publish(message) + self._message_queue.task_done() + + current_time = time() + + # Insert idle events + if self._message_queue.empty() or ( + self._max_idle is not None + and current_time - self._last_idle > self._max_idle + ): + self._last_idle = current_time + if not self._closed: + event = events.Idle() + for _cls, method in self._get_dispatch_methods( + "on_idle", event + ): + try: + await invoke(method, event) + except Exception as error: + self.app._handle_exception(error) + break + await self._flush_next_callbacks() + + async def _flush_next_callbacks(self) -> None: + """Invoke pending callbacks in next callbacks queue.""" + callbacks = self._next_callbacks.copy() + self._next_callbacks.clear() + for callback in callbacks: + try: + with self.prevent(*callback._prevent): + await invoke(callback.callback) + except Exception as error: + self.app._handle_exception(error) + break + + async def _dispatch_message(self, message: Message) -> None: + """Dispatch a message received from the message queue. + + Args: + message: A message object + """ + _rich_traceback_guard = True + if message.no_dispatch: + return + + try: + message_hook = message_hook_context_var.get() + except LookupError: + pass + else: + message_hook(message) + + with self.prevent(*message._prevent): + # Allow apps to treat events and messages separately + if isinstance(message, Event): + await self.on_event(message) + elif "debug" in self.app.features: + start = perf_counter() + await self._on_message(message) + if perf_counter() - start > SLOW_THRESHOLD / 1000: + log.warning( + f"method=<{self.__class__.__name__}." + f"{message.handler_name}>", + f"Took over {SLOW_THRESHOLD}ms to process.", + "\nTo avoid screen freezes, consider using a worker.", + ) + else: + await self._on_message(message) + if self._next_callbacks: + await self._flush_next_callbacks() + + def _get_dispatch_methods( + self, method_name: str, message: Message + ) -> Iterable[tuple[type, Callable[[Message], Awaitable]]]: + """Gets handlers from the MRO + + Args: + method_name: Handler method name. + message: Message object. + """ + from memray._vendor.textual.widget import Widget + + methods_dispatched: set[Callable] = set() + message_mro = [ + _type for _type in message.__class__.__mro__ if issubclass(_type, Message) + ] + for cls in self.__class__.__mro__: + if message._no_default_action: + break + # Try decorated handlers first + decorated_handlers = cast( + "dict[type[Message], list[tuple[Callable, dict[str, tuple[SelectorSet, ...]]]]] | None", + cls.__dict__.get("_decorated_handlers"), + ) + + if decorated_handlers: + for message_class in message_mro: + handlers = decorated_handlers.get(message_class, []) + + for method, selectors in handlers: + if method in methods_dispatched: + continue + if not selectors: + yield cls, method.__get__(self, cls) + methods_dispatched.add(method) + else: + if not message._sender: + continue + for attribute, selector in selectors.items(): + node = getattr(message, attribute) + if node is None: + break + if not isinstance(node, Widget): + raise OnNoWidget( + f"on decorator can't match against {attribute!r} as it is not a widget." + ) + if not match(selector, node): + break + else: + yield cls, method.__get__(self, cls) + methods_dispatched.add(method) + + # Fall back to the naming convention + # But avoid calling the handler if it was decorated + method = cls.__dict__.get(f"_{method_name}") or cls.__dict__.get( + method_name + ) + if method is not None and not getattr(method, "_textual_on", None): + yield cls, method.__get__(self, cls) + + async def on_event(self, event: events.Event) -> None: + """Called to process an event. + + Args: + event: An Event object. + """ + await self._on_message(event) + + async def _on_message(self, message: Message) -> None: + """Called to process a message. + + Args: + message: A Message object. + """ + _rich_traceback_guard = True + handler_name = message.handler_name + + # Look through the MRO to find a handler + dispatched = False + for cls, method in self._get_dispatch_methods(handler_name, message): + log.event.verbosity(message.verbose)( + message, + ">>>", + self, + f"method=<{cls.__name__}.{handler_name}>", + ) + dispatched = True + await invoke(method, message) + if not dispatched: + log.event.verbosity(message.verbose)(message, ">>>", self, "method=None") + + # Bubble messages up the DOM (if enabled on the message) + if message.bubble and self._parent and not message._stop_propagation: + if message._sender is not None and message._sender == self._parent: + # parent is sender, so we stop propagation after parent + message.stop() + if self.is_parent_active and self.is_attached: + message._bubble_to(self._parent) + + def check_idle(self) -> None: + """Prompt the message pump to call idle if the queue is empty.""" + if self._running and self._message_queue.empty(): + self.post_message(messages.Prompt()) + + async def _post_message(self, message: Message) -> bool: + """Post a message or an event to this message pump. + + This is an internal method for use where a coroutine is required. + + Args: + message: A message object. + + Returns: + True if the messages was posted successfully, False if the message was not posted + (because the message pump was in the process of closing). + """ + return self.post_message(message) + + def post_message(self, message: Message) -> bool: + """Posts a message on to this widget's queue. + + Args: + message: A message (including Event). + + Returns: + `True` if the message was queued for processing, otherwise `False`. + """ + _rich_traceback_omit = True + if not hasattr(message, "_prevent"): + # Catch a common error (forgetting to call super) + raise RuntimeError( + "Message is missing attributes; did you forget to call super().__init__() ?" + ) + if self._closing or self._closed: + return False + if not self.check_message_enabled(message): + return False + # Add a copy of the prevented message types to the message + # This is so that prevented messages are honoured by the event's handler + message._prevent.update(self._get_prevented_messages()) + if self._thread_id != threading.get_ident() and self.app._loop is not None: + # If we're not calling from the same thread, make it threadsafe + loop = self.app._loop + loop.call_soon_threadsafe(self._message_queue.put_nowait, message) + else: + self._message_queue.put_nowait(message) + return True + + async def on_callback(self, event: events.Callback) -> None: + if self.app._closing: + return + try: + self.app.screen + except Exception: + self.log.warning( + f"Not invoking timer callback {event.callback!r} because there is no screen." + ) + return + await invoke(event.callback) + + async def on_timer(self, event: events.Timer) -> None: + if not self.app._running: + return + event.prevent_default() + event.stop() + if event.callback is not None: + try: + self.app.screen + except Exception: + self.log.warning( + f"Not invoking timer callback {event.callback!r} because there is no screen." + ) + return + try: + await invoke(event.callback) + except Exception as error: + raise CallbackError( + f"unable to run callback {event.callback!r}; {error}" + ) diff --git a/src/memray/_vendor/textual/messages.py b/src/memray/_vendor/textual/messages.py new file mode 100644 index 0000000000..a31b8396ce --- /dev/null +++ b/src/memray/_vendor/textual/messages.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import rich.repr + +from memray._vendor.textual._types import CallbackType +from memray._vendor.textual.geometry import Region +from memray._vendor.textual.message import Message + +if TYPE_CHECKING: + from memray._vendor.textual.widget import Widget + + +@rich.repr.auto +class CloseMessages(Message, verbose=True): + """Requests message pump to close.""" + + +@rich.repr.auto +class Prune(Message, verbose=True, bubble=False): + """Ask the node to prune (remove from DOM).""" + + +@rich.repr.auto +class ExitApp(Message, verbose=True): + """Exit the app.""" + + +@rich.repr.auto +class Update(Message, verbose=True): + """Sent by Textual to request the update of a widget.""" + + def __init__(self, widget: Widget) -> None: + super().__init__() + self.widget = widget + + def __rich_repr__(self) -> rich.repr.Result: + yield self.widget + + def __eq__(self, other: object) -> bool: + if isinstance(other, Update): + return self.widget == other.widget + return NotImplemented + + def can_replace(self, message: Message) -> bool: + # Update messages can replace update for the same widget + return isinstance(message, Update) and self.widget == message.widget + + +@rich.repr.auto +class Layout(Message, verbose=True): + """Sent by Textual when a layout is required.""" + + def __init__(self, widget: Widget) -> None: + super().__init__() + self.widget = widget + + def can_replace(self, message: Message) -> bool: + return isinstance(message, Layout) + + +@rich.repr.auto +class UpdateScroll(Message, verbose=True): + """Sent by Textual when a scroll update is required.""" + + def can_replace(self, message: Message) -> bool: + return isinstance(message, UpdateScroll) + + +@rich.repr.auto +class InvokeLater(Message, verbose=True, bubble=False): + """Sent by Textual to invoke a callback.""" + + def __init__(self, callback: CallbackType) -> None: + self.callback = callback + super().__init__() + + def __rich_repr__(self) -> rich.repr.Result: + yield "callback", self.callback + + +@rich.repr.auto +class ScrollToRegion(Message, bubble=False): + """Ask the parent to scroll a given region into view.""" + + def __init__(self, region: Region) -> None: + self.region = region + super().__init__() + + +class Prompt(Message, no_dispatch=True): + """Used to 'wake up' an event loop.""" + + def can_replace(self, message: Message) -> bool: + return isinstance(message, Prompt) + + +class TerminalSupportsSynchronizedOutput(Message): + """ + Used to make the App aware that the terminal emulator supports synchronised output. + @link https://gist.github.com/christianparpart/d8a62cc1ab659194337d73e399004036 + """ + + +@rich.repr.auto +class InBandWindowResize(Message): + """Reports if the in-band window resize protocol is supported. + + https://gist.github.com/rockorager/e695fb2924d36b2bcf1fff4a3704bd83""" + + def __init__(self, supported: bool, enabled: bool) -> None: + """Initialize message. + + Args: + supported: Is the protocol supported? + enabled: Is the protocol enabled. + """ + self.supported = supported + self.enabled = enabled + super().__init__() + + def __rich_repr__(self) -> rich.repr.Result: + yield "supported", self.supported + yield "enabled", self.enabled + + @classmethod + def from_setting_parameter(cls, setting_parameter: int) -> InBandWindowResize: + """Construct the message from the setting parameter. + + Args: + setting_parameter: Setting parameter from stdin. + + Returns: + New InBandWindowResize instance. + """ + + supported = setting_parameter not in (0, 4) + enabled = setting_parameter in (1, 3) + return InBandWindowResize(supported, enabled) diff --git a/src/memray/_vendor/textual/notifications.py b/src/memray/_vendor/textual/notifications.py new file mode 100644 index 0000000000..a79fb7ffff --- /dev/null +++ b/src/memray/_vendor/textual/notifications.py @@ -0,0 +1,120 @@ +"""Provides classes for holding and managing notifications.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from time import time +from typing import Iterator +from uuid import uuid4 + +from rich.repr import Result +from typing_extensions import Literal, Self, TypeAlias + +from memray._vendor.textual.message import Message + +SeverityLevel: TypeAlias = Literal["information", "warning", "error"] +"""The severity level for a notification.""" + + +@dataclass +class Notify(Message, bubble=False): + """Message to show a notification.""" + + notification: Notification + + +@dataclass +class Notification: + """Holds the details of a notification.""" + + message: str + """The message for the notification.""" + + title: str = "" + """The title for the notification.""" + + severity: SeverityLevel = "information" + """The severity level for the notification.""" + + timeout: float = 5 + """The timeout (in seconds) for the notification.""" + + markup: bool = False + """Render the notification message as content markup?""" + + raised_at: float = field(default_factory=time) + """The time when the notification was raised (in Unix time).""" + + identity: str = field(default_factory=lambda: str(uuid4())) + """The unique identity of the notification.""" + + @property + def time_left(self) -> float: + """The time left until this notification expires""" + return (self.raised_at + self.timeout) - time() + + @property + def has_expired(self) -> bool: + """Has the notification expired?""" + return self.time_left <= 0 + + def __rich_repr__(self) -> Result: + yield "message", self.message + yield "title", self.title, "" + yield "severity", self.severity + yield "raised_it", self.raised_at + yield "identity", self.identity + yield "time_left", self.time_left + yield "has_expired", self.has_expired + + +class Notifications: + """Class for managing a collection of notifications.""" + + def __init__(self) -> None: + """Initialise the notification collection.""" + self._notifications: dict[str, Notification] = {} + + def _reap(self) -> Self: + """Remove any expired notifications from the notification collection.""" + for notification in list(self._notifications.values()): + if notification.has_expired: + del self._notifications[notification.identity] + return self + + def add(self, notification: Notification) -> Self: + """Add the given notification to the collection of managed notifications. + + Args: + notification: The notification to add. + + Returns: + Self. + """ + self._reap()._notifications[notification.identity] = notification + return self + + def clear(self) -> Self: + """Clear all the notifications.""" + self._notifications.clear() + return self + + def __len__(self) -> int: + """The number of notifications.""" + return len(self._reap()._notifications) + + def __iter__(self) -> Iterator[Notification]: + return iter(self._reap()._notifications.values()) + + def __contains__(self, notification: Notification) -> bool: + return notification.identity in self._notifications + + def __delitem__(self, notification: Notification) -> None: + try: + del self._reap()._notifications[notification.identity] + except KeyError: + # An attempt to remove a notification we don't know about is a + # no-op. What matters here is that the notification is forgotten + # about, and it looks like a caller has tried to be + # belt-and-braces. We're fine with this. + pass diff --git a/src/memray/_vendor/textual/pad.py b/src/memray/_vendor/textual/pad.py new file mode 100644 index 0000000000..e92cf71ccd --- /dev/null +++ b/src/memray/_vendor/textual/pad.py @@ -0,0 +1,79 @@ +from typing import cast + +from rich.align import Align, AlignMethod +from rich.console import ( + Console, + ConsoleOptions, + JustifyMethod, + RenderableType, + RenderResult, +) +from rich.measure import Measurement +from rich.segment import Segment, Segments +from rich.style import Style + + +class HorizontalPad: + """Rich renderable to add padding on the left and right of a renderable. + + Note that unlike Rich's Padding class this align each line independently. + + """ + + def __init__( + self, + renderable: RenderableType, + left: int, + right: int, + pad_style: Style, + justify: JustifyMethod, + ) -> None: + """ + Initialize HorizontalPad. + + Args: + renderable: A Rich renderable. + left: Left padding. + right: Right padding. + pad_style: Style of padding. + justify: Justify method. + """ + self.renderable = renderable + self.left = left + self.right = right + self.pad_style = pad_style + self.justify = justify + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + options = options.update( + width=options.max_width - self.left - self.right, height=None + ) + lines = console.render_lines(self.renderable, options, pad=False) + left_pad = Segment(" " * self.left, self.pad_style) + right_pad = Segment(" " * self.right, self.pad_style) + + align: AlignMethod = cast( + AlignMethod, + self.justify if self.justify in {"left", "right", "center"} else "left", + ) + + for line in lines: + pad_line = line + if self.left: + pad_line = [left_pad, *line] + if self.right: + pad_line.append(right_pad) + segments = Segments(pad_line) + yield Align(segments, align=align) + + def __rich_measure__( + self, console: "Console", options: "ConsoleOptions" + ) -> Measurement: + measurement = Measurement.get(console, options, self.renderable) + total_padding = self.left + self.right + return Measurement( + measurement.minimum + total_padding, + measurement.maximum + total_padding, + ) diff --git a/src/memray/_vendor/textual/pilot.py b/src/memray/_vendor/textual/pilot.py new file mode 100644 index 0000000000..2c2cb4ce86 --- /dev/null +++ b/src/memray/_vendor/textual/pilot.py @@ -0,0 +1,570 @@ +""" + +This module contains the `Pilot` class used by [App.run_test][textual.app.App.run_test] to programmatically operate an app. + +See the guide on how to [test Textual apps](/guide/testing). + +""" + +from __future__ import annotations + +import asyncio +from typing import Any, Generic + +import rich.repr + +from memray._vendor.textual._wait import wait_for_idle +from memray._vendor.textual.app import App, ReturnType +from memray._vendor.textual.drivers.headless_driver import HeadlessDriver +from memray._vendor.textual.events import Click, MouseDown, MouseEvent, MouseMove, MouseUp, Resize +from memray._vendor.textual.geometry import Offset, Size +from memray._vendor.textual.widget import Widget + + +def _get_mouse_message_arguments( + target: Widget, + offset: tuple[int, int] = (0, 0), + button: int = 0, + shift: bool = False, + meta: bool = False, + control: bool = False, +) -> dict[str, Any]: + """Get the arguments to pass into mouse messages for the click and hover methods.""" + click_x, click_y = target.region.offset + offset + message_arguments = { + "widget": target, + "x": click_x, + "y": click_y, + "delta_x": 0, + "delta_y": 0, + "button": button, + "shift": shift, + "meta": meta, + "ctrl": control, + "screen_x": click_x, + "screen_y": click_y, + } + return message_arguments + + +class OutOfBounds(Exception): + """Raised when the pilot mouse target is outside of the (visible) screen.""" + + +class WaitForScreenTimeout(Exception): + """Exception raised if messages aren't being processed quickly enough. + + If this occurs, the most likely explanation is some kind of deadlock in the app code. + """ + + +@rich.repr.auto(angular=True) +class Pilot(Generic[ReturnType]): + """Pilot object to drive an app.""" + + def __init__(self, app: App[ReturnType]) -> None: + self._app = app + + def __rich_repr__(self) -> rich.repr.Result: + yield "app", self._app + + @property + def app(self) -> App[ReturnType]: + """App: A reference to the application.""" + return self._app + + async def press(self, *keys: str) -> None: + """Simulate key-presses. + + Args: + *keys: Keys to press. + """ + if keys: + await self._app._press_keys(keys) + await self._wait_for_screen() + + async def resize_terminal(self, width: int, height: int) -> None: + """Resize the terminal to the given dimensions. + + Args: + width: The new width of the terminal. + height: The new height of the terminal. + """ + size = Size(width, height) + # If we're running with the headless driver, update the inherent app size. + if isinstance(self.app._driver, HeadlessDriver): + self.app._driver._size = size + self.app.post_message(Resize(size, size)) + await self.pause() + + async def mouse_down( + self, + widget: Widget | type[Widget] | str | None = None, + offset: tuple[int, int] = (0, 0), + shift: bool = False, + meta: bool = False, + control: bool = False, + button: int = 1, + ) -> bool: + """Simulate a [`MouseDown`][textual.events.MouseDown] event at a specified position. + + The final position for the event is computed based on the selector provided and + the offset specified and it must be within the visible area of the screen. + + Args: + widget: A widget or selector used as an origin + for the event offset. If this is not specified, the offset is interpreted + relative to the screen. You can use this parameter to try to target a + specific widget. However, if the widget is currently hidden or obscured by + another widget, the event may not land on the widget you specified. + offset: The offset for the event. The offset is relative to the selector / widget + provided or to the screen, if no selector is provided. + shift: Simulate the event with the shift key held down. + meta: Simulate the event with the meta key held down. + control: Simulate the event with the control key held down. + button: The mouse button to press. + + Raises: + OutOfBounds: If the position for the event is outside of the (visible) screen. + + Returns: + True if no selector was specified or if the event landed on the selected + widget, False otherwise. + """ + try: + return await self._post_mouse_events( + [MouseMove, MouseDown], + widget=widget, + offset=offset, + button=button, + shift=shift, + meta=meta, + control=control, + ) + except OutOfBounds as error: + raise error from None + + async def mouse_up( + self, + widget: Widget | type[Widget] | str | None = None, + offset: tuple[int, int] = (0, 0), + shift: bool = False, + meta: bool = False, + control: bool = False, + ) -> bool: + """Simulate a [`MouseUp`][textual.events.MouseUp] event at a specified position. + + The final position for the event is computed based on the selector provided and + the offset specified and it must be within the visible area of the screen. + + Args: + widget: A widget or selector used as an origin + for the event offset. If this is not specified, the offset is interpreted + relative to the screen. You can use this parameter to try to target a + specific widget. However, if the widget is currently hidden or obscured by + another widget, the event may not land on the widget you specified. + offset: The offset for the event. The offset is relative to the widget / selector + provided or to the screen, if no selector is provided. + shift: Simulate the event with the shift key held down. + meta: Simulate the event with the meta key held down. + control: Simulate the event with the control key held down. + + Raises: + OutOfBounds: If the position for the event is outside of the (visible) screen. + + Returns: + True if no selector was specified or if the event landed on the selected + widget, False otherwise. + """ + try: + return await self._post_mouse_events( + [MouseMove, MouseUp], + widget=widget, + offset=offset, + button=1, + shift=shift, + meta=meta, + control=control, + ) + except OutOfBounds as error: + raise error from None + + async def click( + self, + widget: Widget | type[Widget] | str | None = None, + offset: tuple[int, int] = (0, 0), + shift: bool = False, + meta: bool = False, + control: bool = False, + times: int = 1, + button: int = 1, + ) -> bool: + """Simulate clicking with the mouse at a specified position. + + The final position to be clicked is computed based on the selector provided and + the offset specified and it must be within the visible area of the screen. + + Implementation note: This method bypasses the normal event processing in `App.on_event`. + + Example: + The code below runs an app and clicks its only button right in the middle: + ```py + async with SingleButtonApp().run_test() as pilot: + await pilot.click(Button, offset=(8, 1)) + ``` + + Args: + widget: A widget or selector used as an origin + for the click offset. If this is not specified, the offset is interpreted + relative to the screen. You can use this parameter to try to click on a + specific widget. However, if the widget is currently hidden or obscured by + another widget, the click may not land on the widget you specified. + offset: The offset to click. The offset is relative to the widget / selector provided + or to the screen, if no selector is provided. + shift: Click with the shift key held down. + meta: Click with the meta key held down. + control: Click with the control key held down. + times: The number of times to click. 2 will double-click, 3 will triple-click, etc. + button: The mouse button to click. + + Raises: + OutOfBounds: If the position to be clicked is outside of the (visible) screen. + + Returns: + `True` if no selector was specified or if the selected widget was under the mouse + when the click was initiated. `False` is the selected widget was not under the pointer. + """ + try: + return await self._post_mouse_events( + [MouseDown, MouseUp, Click], + widget=widget, + offset=offset, + button=button, + shift=shift, + meta=meta, + control=control, + times=times, + ) + except OutOfBounds as error: + raise error from None + + async def double_click( + self, + widget: Widget | type[Widget] | str | None = None, + offset: tuple[int, int] = (0, 0), + shift: bool = False, + meta: bool = False, + control: bool = False, + button: int = 1, + ) -> bool: + """Simulate double clicking with the mouse at a specified position. + + Alias for `pilot.click(..., times=2)`. + + The final position to be clicked is computed based on the selector provided and + the offset specified and it must be within the visible area of the screen. + + Implementation note: This method bypasses the normal event processing in `App.on_event`. + + Example: + The code below runs an app and double-clicks its only button right in the middle: + ```py + async with SingleButtonApp().run_test() as pilot: + await pilot.double_click(Button, offset=(8, 1)) + ``` + + Args: + widget: A widget or selector used as an origin + for the click offset. If this is not specified, the offset is interpreted + relative to the screen. You can use this parameter to try to click on a + specific widget. However, if the widget is currently hidden or obscured by + another widget, the click may not land on the widget you specified. + offset: The offset to click. The offset is relative to the widget / selector provided + or to the screen, if no selector is provided. + shift: Click with the shift key held down. + meta: Click with the meta key held down. + control: Click with the control key held down. + button: The mouse button to click. + + Raises: + OutOfBounds: If the position to be clicked is outside of the (visible) screen. + + Returns: + `True` if no selector was specified or if the selected widget was under the mouse + when the click was initiated. `False` is the selected widget was not under the pointer. + """ + return await self.click( + widget, offset, shift, meta, control, times=2, button=button + ) + + async def triple_click( + self, + widget: Widget | type[Widget] | str | None = None, + offset: tuple[int, int] = (0, 0), + shift: bool = False, + meta: bool = False, + control: bool = False, + button: int = 1, + ) -> bool: + """Simulate triple clicking with the mouse at a specified position. + + Alias for `pilot.click(..., times=3)`. + + The final position to be clicked is computed based on the selector provided and + the offset specified and it must be within the visible area of the screen. + + Implementation note: This method bypasses the normal event processing in `App.on_event`. + + Example: + The code below runs an app and triple-clicks its only button right in the middle: + ```py + async with SingleButtonApp().run_test() as pilot: + await pilot.triple_click(Button, offset=(8, 1)) + ``` + + Args: + widget: A widget or selector used as an origin + for the click offset. If this is not specified, the offset is interpreted + relative to the screen. You can use this parameter to try to click on a + specific widget. However, if the widget is currently hidden or obscured by + another widget, the click may not land on the widget you specified. + offset: The offset to click. The offset is relative to the widget / selector provided + or to the screen, if no selector is provided. + shift: Click with the shift key held down. + meta: Click with the meta key held down. + control: Click with the control key held down. + button: The mouse button to click. + + Raises: + OutOfBounds: If the position to be clicked is outside of the (visible) screen. + + Returns: + `True` if no selector was specified or if the selected widget was under the mouse + when the click was initiated. `False` is the selected widget was not under the pointer. + """ + return await self.click( + widget, offset, shift, meta, control, times=3, button=button + ) + + async def hover( + self, + widget: Widget | type[Widget] | str | None | None = None, + offset: tuple[int, int] = (0, 0), + ) -> bool: + """Simulate hovering with the mouse cursor at a specified position. + + The final position to be hovered is computed based on the selector provided and + the offset specified and it must be within the visible area of the screen. + + Args: + widget: A widget or selector used as an origin + for the hover offset. If this is not specified, the offset is interpreted + relative to the screen. You can use this parameter to try to hover a + specific widget. However, if the widget is currently hidden or obscured by + another widget, the hover may not land on the widget you specified. + offset: The offset to hover. The offset is relative to the widget / selector provided + or to the screen, if no selector is provided. + + Raises: + OutOfBounds: If the position to be hovered is outside of the (visible) screen. + + Returns: + True if no selector was specified or if the hover landed on the selected + widget, False otherwise. + """ + # This is usually what the user wants because it gives time for the mouse to + # "settle" before moving it to the new hover position. + await self.pause() + try: + return await self._post_mouse_events([MouseMove], widget, offset, button=0) + except OutOfBounds as error: + raise error from None + + async def _post_mouse_events( + self, + events: list[type[MouseEvent]], + widget: Widget | type[Widget] | str | None | None = None, + offset: tuple[int, int] = (0, 0), + button: int = 0, + shift: bool = False, + meta: bool = False, + control: bool = False, + times: int = 1, + ) -> bool: + """Simulate a series of mouse events to be fired at a given position. + + The final position for the events is computed based on the selector provided and + the offset specified and it must be within the visible area of the screen. + + This function abstracts away the commonalities of the other mouse event-related + functions that the pilot exposes. + + Args: + widget: A widget or selector used as the origin + for the event's offset. If this is not specified, the offset is interpreted + relative to the screen. You can use this parameter to try to target a + specific widget. However, if the widget is currently hidden or obscured by + another widget, the events may not land on the widget you specified. + offset: The offset for the events. The offset is relative to the widget / selector + provided or to the screen, if no selector is provided. + shift: Simulate the events with the shift key held down. + meta: Simulate the events with the meta key held down. + control: Simulate the events with the control key held down. + times: The number of times to click. 2 will double-click, 3 will triple-click, etc. + Raises: + OutOfBounds: If the position for the events is outside of the (visible) screen. + + Returns: + True if no selector was specified or if the *final* event landed on the + selected widget, False otherwise. + """ + app = self.app + screen = app.screen + target_widget: Widget + if widget is None: + target_widget = screen + elif isinstance(widget, Widget): + target_widget = widget + else: + target_widget = screen.query_one(widget) + + message_arguments = _get_mouse_message_arguments( + target_widget, + offset, + button=button, + shift=shift, + meta=meta, + control=control, + ) + + offset = Offset(message_arguments["x"], message_arguments["y"]) + if offset not in screen.region: + raise OutOfBounds( + "Target offset is outside of currently-visible screen region." + ) + + widget_at = None + for chain in range(1, times + 1): + for mouse_event_cls in events: + await self.pause() + # Get the widget under the mouse before the event because the app might + # react to the event and move things around. We override on each iteration + # because we assume the final event in `events` is the actual event we care + # about and that all the preceding events are just setup. + # E.g., the click event is preceded by MouseDown/MouseUp to emulate how + # the driver works and emits a click event. + kwargs = message_arguments + if mouse_event_cls is Click: + kwargs = {**kwargs, "chain": chain} + + if widget_at is None: + widget_at, _ = app.get_widget_at(*offset) + event = mouse_event_cls(**kwargs) + # Bypass event processing in App.on_event. Because App.on_event + # is responsible for updating App.mouse_position, and because + # that's useful to other things (tooltip handling, for example), + # we patch the offset in there as well. + app.mouse_position = offset + screen._forward_event(event) + + await self.pause() + return widget is None or widget_at is target_widget + + async def _wait_for_screen(self, timeout: float = 30.0) -> bool: + """Wait for the current screen and its children to have processed all pending events. + + Args: + timeout: A timeout in seconds to wait. + + Returns: + `True` if all events were processed. `False` if an exception occurred, + meaning that not all events could be processed. + + Raises: + WaitForScreenTimeout: If the screen and its children didn't finish processing within the timeout. + """ + try: + screen = self.app.screen + except Exception: + return False + children = [self.app, *screen.walk_children(with_self=True)] + count = 0 + count_zero_event = asyncio.Event() + + def decrement_counter() -> None: + """Decrement internal counter, and set an event if it reaches zero.""" + nonlocal count + count -= 1 + if count == 0: + # When count is zero, all messages queued at the start of the method have been processed + count_zero_event.set() + + # Increase the count for every successful call_later + for child in children: + if child.call_later(decrement_counter): + count += 1 + + if count: + # Wait for the count to return to zero, or a timeout, or an exception + wait_for = [ + asyncio.create_task(count_zero_event.wait()), + asyncio.create_task(self.app._exception_event.wait()), + ] + _, pending = await asyncio.wait( + wait_for, + timeout=timeout, + return_when=asyncio.FIRST_COMPLETED, + ) + + for task in pending: + task.cancel() + + timed_out = len(wait_for) == len(pending) + if timed_out: + raise WaitForScreenTimeout( + "Timed out while waiting for widgets to process pending messages." + ) + + # We've either timed out, encountered an exception, or we've finished + # decrementing all the counters (all events processed in children). + if count > 0: + return False + + return True + + async def pause(self, delay: float | None = None) -> None: + """Insert a pause. + + Args: + delay: Seconds to pause, or None to wait for cpu idle. + """ + # These sleep zeros, are to force asyncio to give up a time-slice. + await self._wait_for_screen() + if delay is None: + await wait_for_idle(0) + else: + await asyncio.sleep(delay) + self.app.screen._on_timer_update() + + async def wait_for_animation(self) -> None: + """Wait for any current animation to complete.""" + await self._app.animator.wait_for_idle() + self.app.screen._on_timer_update() + + async def wait_for_scheduled_animations(self) -> None: + """Wait for any current and scheduled animations to complete.""" + await self._wait_for_screen() + await self._app.animator.wait_until_complete() + await self._wait_for_screen() + await wait_for_idle() + self.app.screen._on_timer_update() + + async def exit(self, result: ReturnType) -> None: + """Exit the app with the given result. + + Args: + result: The app result returned by `run` or `run_async`. + """ + await self._wait_for_screen() + await wait_for_idle() + self.app.exit(result) diff --git a/src/memray/_vendor/textual/py.typed b/src/memray/_vendor/textual/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/memray/_vendor/textual/reactive.py b/src/memray/_vendor/textual/reactive.py new file mode 100644 index 0000000000..aa67009218 --- /dev/null +++ b/src/memray/_vendor/textual/reactive.py @@ -0,0 +1,532 @@ +""" + +This module contains the `Reactive` class which implements [reactivity](/guide/reactivity/). +""" + +from __future__ import annotations + +from functools import partial +from inspect import isawaitable +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + ClassVar, + Generic, + Type, + TypeVar, + cast, + overload, +) + +import rich.repr + +from memray._vendor.textual import events +from memray._vendor.textual._callback import count_parameters +from memray._vendor.textual._types import ( + MessageTarget, + WatchCallbackBothValuesType, + WatchCallbackNewValueType, + WatchCallbackNoArgsType, + WatchCallbackType, +) + +if TYPE_CHECKING: + from memray._vendor.textual.dom import DOMNode + + Reactable = DOMNode + +ReactiveType = TypeVar("ReactiveType") +ReactableType = TypeVar("ReactableType", bound="DOMNode") + + +class _Mutated: + """A wrapper to indicate a value was mutated.""" + + def __init__(self, value: Any) -> None: + self.value = value + + +class ReactiveError(Exception): + """Base class for reactive errors.""" + + +class TooManyComputesError(ReactiveError): + """Raised when an attribute has public and private compute methods.""" + + +class Initialize(Generic[ReactiveType]): + """Initialize a reactive by calling a method parent object. + + Example: + ```python + class InitializeApp(App): + + def get_names(self) -> list[str]: + return ["foo", "bar", "baz"] + + # The `names` property will call `get_names` to get its default when first referenced. + names = reactive(Initialize(get_names)) + ``` + + """ + + def __init__(self, callback: Callable[[ReactableType], ReactiveType]) -> None: + self.callback = callback + + def __call__(self, obj: ReactableType) -> ReactiveType: + return self.callback(obj) + + +async def await_watcher(obj: Reactable, awaitable: Awaitable[object]) -> None: + """Coroutine to await an awaitable returned from a watcher""" + _rich_traceback_omit = True + await awaitable + # Watcher may have changed the state, so run compute again + obj.post_message(events.Callback(callback=partial(Reactive._compute, obj))) + + +def invoke_watcher( + watcher_object: Reactable, + watch_function: WatchCallbackType, + old_value: object, + value: object, +) -> None: + """Invoke a watch function. + + Args: + watcher_object: The object watching for the changes. + watch_function: A watch function, which may be sync or async. + old_value: The old value of the attribute. + value: The new value of the attribute. + """ + _rich_traceback_omit = True + + param_count = count_parameters(watch_function) + + with watcher_object._context(): + if param_count == 2: + watch_result = cast(WatchCallbackBothValuesType, watch_function)( + old_value, value + ) + elif param_count == 1: + watch_result = cast(WatchCallbackNewValueType, watch_function)(value) + else: + watch_result = cast(WatchCallbackNoArgsType, watch_function)() + if isawaitable(watch_result): + # Result is awaitable, so we need to await it within an async context + watcher_object.call_next( + partial(await_watcher, watcher_object, watch_result) + ) + + +@rich.repr.auto +class Reactive(Generic[ReactiveType]): + """Reactive descriptor. + + Args: + default: A default value or callable that returns a default. + layout: Perform a layout on change. + repaint: Perform a repaint on change. + init: Call watchers on initialize (post mount). + always_update: Call watchers even when the new value equals the old value. + compute: Run compute methods when attribute is changed. + recompose: Compose the widget again when the attribute changes. + bindings: Refresh bindings when the reactive changes. + toggle_class: An optional TCSS classname(s) to toggle based on the truthiness of the value. + """ + + _reactives: ClassVar[dict[str, object]] = {} + + def __init__( + self, + default: ReactiveType | Callable[[], ReactiveType] | Initialize[ReactiveType], + *, + layout: bool = False, + repaint: bool = True, + init: bool = False, + always_update: bool = False, + compute: bool = True, + recompose: bool = False, + bindings: bool = False, + toggle_class: str | None = None, + ) -> None: + self._default = default + self._layout = layout + self._repaint = repaint + self._init = init + self._always_update = always_update + self._run_compute = compute + self._recompose = recompose + self._bindings = bindings + self._toggle_class = toggle_class + self._owner: Type[MessageTarget] | None = None + self.name: str + + def __rich_repr__(self) -> rich.repr.Result: + yield None, self._default + yield "layout", self._layout, False + yield "repaint", self._repaint, True + yield "init", self._init, False + yield "always_update", self._always_update, False + yield "compute", self._run_compute, True + yield "recompose", self._recompose, False + yield "bindings", self._bindings, False + yield "name", getattr(self, "name", None), None + + @classmethod + def _clear_watchers(cls, obj: Reactable) -> None: + """Clear any watchers on a given object. + + Args: + obj: A reactive object. + """ + try: + getattr(obj, "__watchers").clear() + except AttributeError: + pass + + @property + def owner(self) -> Type[MessageTarget]: + """The owner (class) where the reactive was declared.""" + assert self._owner is not None + return self._owner + + def _initialize_reactive(self, obj: Reactable, name: str) -> None: + """Initialized a reactive attribute on an object. + + Args: + obj: An object with reactive attributes. + name: Name of attribute. + """ + _rich_traceback_omit = True + + internal_name = f"_reactive_{name}" + if hasattr(obj, internal_name): + # Attribute already has a value + return + + compute_method = getattr(obj, self.compute_name, None) + if compute_method is not None and self._init: + default = compute_method() + else: + default_or_callable = self._default + default = ( + ( + default_or_callable(obj) + if isinstance(default_or_callable, Initialize) + else default_or_callable() + ) + if callable(default_or_callable) + else default_or_callable + ) + setattr(obj, internal_name, default) + if (toggle_class := self._toggle_class) is not None: + obj.set_class(bool(default), *toggle_class.split()) + if self._init: + self._check_watchers(obj, name, default) + + @classmethod + def _initialize_object(cls, obj: Reactable) -> None: + """Set defaults and call any watchers / computes for the first time. + + Args: + obj: An object with Reactive descriptors + """ + _rich_traceback_omit = True + for name, reactive in obj._reactives.items(): + reactive._initialize_reactive(obj, name) + + @classmethod + def _reset_object(cls, obj: object) -> None: + """Reset reactive structures on object (to avoid reference cycles). + + Args: + obj: A reactive object. + """ + getattr(obj, "__watchers", {}).clear() + getattr(obj, "__computes", []).clear() + + def __set_name__(self, owner: Type[MessageTarget], name: str) -> None: + # Check for compute method + self._owner = owner + public_compute = f"compute_{name}" + private_compute = f"_compute_{name}" + compute_name = ( + private_compute if hasattr(owner, private_compute) else public_compute + ) + if hasattr(owner, compute_name): + # Compute methods are stored in a list called `__computes` + try: + computes = getattr(owner, "__computes") + except AttributeError: + computes = [] + setattr(owner, "__computes", computes) + computes.append(name) + + # The name of the attribute + self.name = name + # The internal name where the attribute's value is stored + self.internal_name = f"_reactive_{name}" + self.compute_name = compute_name + default = self._default + setattr(owner, f"_default_{name}", default) + + if TYPE_CHECKING: + + @overload + def __get__( + self: Reactive[ReactiveType], + obj: ReactableType, + obj_type: type[ReactableType], + ) -> ReactiveType: ... + + @overload + def __get__( + self: Reactive[ReactiveType], obj: None, obj_type: type[ReactableType] + ) -> Reactive[ReactiveType]: ... + + def __get__( + self: Reactive[ReactiveType], + obj: Reactable | None, + obj_type: type[ReactableType], + ) -> Reactive[ReactiveType] | ReactiveType: + _rich_traceback_omit = True + if obj is None: + # obj is None means we are invoking the descriptor via the class, and not the instance + return self + if not hasattr(obj, "id"): + raise ReactiveError( + f"Node is missing data; Check you are calling super().__init__(...) in the {obj.__class__.__name__}() constructor, before getting reactives." + ) + if not hasattr(obj, internal_name := self.internal_name): + self._initialize_reactive(obj, self.name) + + if hasattr(obj, self.compute_name): + value: ReactiveType + old_value = getattr(obj, internal_name) + value = getattr(obj, self.compute_name)() + setattr(obj, internal_name, value) + self._check_watchers(obj, self.name, old_value) + return value + else: + return getattr(obj, internal_name) + + def _set(self, obj: Reactable, value: ReactiveType, always: bool = False) -> None: + _rich_traceback_omit = True + + if not hasattr(obj, "_id"): + raise ReactiveError( + f"Node is missing data; Check you are calling super().__init__(...) in the {obj.__class__.__name__}() constructor, before setting reactives." + ) + + if isinstance(value, _Mutated): + value = value.value + always = True + + self._initialize_reactive(obj, self.name) + + if hasattr(obj, self.compute_name): + raise AttributeError( + f"Can't set {obj}.{self.name!r}; reactive attributes with a compute method are read-only" + ) + + name = self.name + current_value = getattr(obj, name) + # Check for private and public validate functions. + private_validate_function = getattr(obj, f"_validate_{name}", None) + if callable(private_validate_function): + value = private_validate_function(value) + public_validate_function = getattr(obj, f"validate_{name}", None) + if callable(public_validate_function): + value = public_validate_function(value) + + # Toggle the classes using the value's truthiness + if (toggle_class := self._toggle_class) is not None: + obj.set_class(bool(value), *toggle_class.split()) + + # If the value has changed, or this is the first time setting the value + if always or self._always_update or current_value != value: + # Store the internal value + setattr(obj, self.internal_name, value) + + # Check all watchers + self._check_watchers(obj, name, current_value) + + if self._run_compute: + self._compute(obj) + + if self._bindings: + obj.refresh_bindings() + + # Refresh according to descriptor flags + if self._layout or self._repaint or self._recompose: + obj.refresh( + repaint=self._repaint, + layout=self._layout, + recompose=self._recompose, + ) + + def __set__(self, obj: Reactable, value: ReactiveType) -> None: + _rich_traceback_omit = True + + self._set(obj, value) + + @classmethod + def _check_watchers(cls, obj: Reactable, name: str, old_value: Any) -> None: + """Check watchers, and call watch methods / computes + + Args: + obj: The reactable object. + name: Attribute name. + old_value: The old (previous) value of the attribute. + """ + _rich_traceback_omit = True + # Get the current value. + internal_name = f"_reactive_{name}" + value = getattr(obj, internal_name) + + private_watch_function = getattr(obj, f"_watch_{name}", None) + if callable(private_watch_function): + invoke_watcher(obj, private_watch_function, old_value, value) + + public_watch_function = getattr(obj, f"watch_{name}", None) + if callable(public_watch_function): + invoke_watcher(obj, public_watch_function, old_value, value) + + # Process "global" watchers + watchers: list[tuple[Reactable, WatchCallbackType]] + watchers = getattr(obj, "__watchers", {}).get(name, []) + # Remove any watchers for reactables that have since closed + if watchers: + watchers[:] = [ + (reactable, callback) + for reactable, callback in watchers + if not reactable._closing + ] + for reactable, callback in watchers: + with reactable.prevent(*obj._prevent_message_types_stack[-1]): + invoke_watcher(reactable, callback, old_value, value) + + @classmethod + def _compute(cls, obj: Reactable) -> None: + """Invoke all computes. + + Args: + obj: Reactable object. + """ + _rich_traceback_guard = True + for compute in obj._reactives.keys() & obj._computes: + try: + compute_method = getattr(obj, f"compute_{compute}") + except AttributeError: + try: + compute_method = getattr(obj, f"_compute_{compute}") + except AttributeError: + continue + current_value = getattr( + obj, f"_reactive_{compute}", getattr(obj, f"_default_{compute}", None) + ) + value = compute_method() + setattr(obj, f"_reactive_{compute}", value) + if value != current_value: + cls._check_watchers(obj, compute, current_value) + + +class reactive(Reactive[ReactiveType]): + """Create a reactive attribute. + + Args: + default: A default value or callable that returns a default. + layout: Perform a layout on change. + repaint: Perform a repaint on change. + init: Call watchers on initialize (post mount). + always_update: Call watchers even when the new value equals the old value. + recompose: Compose the widget again when the attribute changes. + bindings: Refresh bindings when the reactive changes. + toggle_class: An optional TCSS classname(s) to toggle based on the truthiness of the value. + """ + + def __init__( + self, + default: ReactiveType | Callable[[], ReactiveType] | Initialize[ReactiveType], + *, + layout: bool = False, + repaint: bool = True, + init: bool = True, + always_update: bool = False, + recompose: bool = False, + bindings: bool = False, + toggle_class: str | None = None, + ) -> None: + super().__init__( + default, + layout=layout, + repaint=repaint, + init=init, + always_update=always_update, + recompose=recompose, + bindings=bindings, + toggle_class=toggle_class, + ) + + +class var(Reactive[ReactiveType]): + """Create a reactive attribute (with no auto-refresh). + + Args: + default: A default value or callable that returns a default. + init: Call watchers on initialize (post mount). + always_update: Call watchers even when the new value equals the old value. + bindings: Refresh bindings when the reactive changes. + toggle_class: An optional TCSS classname(s) to toggle based on the truthiness of the value. + """ + + def __init__( + self, + default: ReactiveType | Callable[[], ReactiveType] | Initialize[ReactiveType], + init: bool = True, + always_update: bool = False, + bindings: bool = False, + toggle_class: str | None = None, + ) -> None: + super().__init__( + default, + layout=False, + repaint=False, + init=init, + always_update=always_update, + bindings=bindings, + toggle_class=toggle_class, + ) + + +def _watch( + node: DOMNode, + obj: Reactable, + attribute_name: str, + callback: WatchCallbackType, + *, + init: bool = True, +) -> None: + """Watch a reactive variable on an object. + + Args: + node: The node that created the watcher. + obj: The parent object. + attribute_name: The attribute to watch. + callback: A callable to call when the attribute changes. + init: True to call watcher initialization. + """ + if not hasattr(obj, "__watchers"): + setattr(obj, "__watchers", {}) + watchers: dict[str, list[tuple[Reactable, WatchCallbackType]]] + watchers = getattr(obj, "__watchers") + watcher_list = watchers.setdefault(attribute_name, []) + if any(callback == callback_from_list for _, callback_from_list in watcher_list): + return + if init: + current_value = getattr(obj, attribute_name, None) + invoke_watcher(obj, callback, current_value, current_value) + watcher_list.append((node, callback)) diff --git a/src/memray/_vendor/textual/render.py b/src/memray/_vendor/textual/render.py new file mode 100644 index 0000000000..c12bbbd453 --- /dev/null +++ b/src/memray/_vendor/textual/render.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from rich.cells import cell_len +from rich.console import Console, RenderableType +from rich.protocol import rich_cast + + +def measure( + console: Console, + renderable: RenderableType, + default: int, + *, + container_width: int | None = None, +) -> int: + """Measure a rich renderable. + + Args: + console: A console object. + renderable: Rich renderable. + default: Default width to use if renderable does not expose dimensions. + container_width: Width of container or None to use console width. + + Returns: + Width in cells + """ + if isinstance(renderable, str): + return cell_len(renderable) + + width = default + renderable = rich_cast(renderable) + get_console_width = getattr(renderable, "__rich_measure__", None) + if get_console_width is not None: + options = ( + console.options + if container_width is None + else console.options.update_width(container_width) + ) + render_width = get_console_width(console, options).maximum + width = max(0, render_width) + + return width diff --git a/src/memray/_vendor/textual/renderables/__init__.py b/src/memray/_vendor/textual/renderables/__init__.py new file mode 100644 index 0000000000..3d892e1f58 --- /dev/null +++ b/src/memray/_vendor/textual/renderables/__init__.py @@ -0,0 +1 @@ +__all__ = ["bar", "blank", "digits", "gradient", "sparkline"] diff --git a/src/memray/_vendor/textual/renderables/_blend_colors.py b/src/memray/_vendor/textual/renderables/_blend_colors.py new file mode 100644 index 0000000000..24476faf69 --- /dev/null +++ b/src/memray/_vendor/textual/renderables/_blend_colors.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from rich.color import Color + + +def blend_colors(color1: Color, color2: Color, ratio: float) -> Color: + """Given two RGB colors, return a color that sits some distance between + them in RGB color space. + + Args: + color1: The first color. + color2: The second color. + ratio: The ratio of color1 to color2. + + Returns: + A Color representing the blending of the two supplied colors. + """ + if color1.triplet is None or color2.triplet is None: + return color2 + r1, g1, b1 = color1.triplet + r2, g2, b2 = color2.triplet + + return Color.from_rgb( + r1 + (r2 - r1) * ratio, + g1 + (g2 - g1) * ratio, + b1 + (b2 - b1) * ratio, + ) diff --git a/src/memray/_vendor/textual/renderables/background_screen.py b/src/memray/_vendor/textual/renderables/background_screen.py new file mode 100644 index 0000000000..503bde8900 --- /dev/null +++ b/src/memray/_vendor/textual/renderables/background_screen.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterable + +from rich.console import Console, ConsoleOptions, RenderResult +from rich.segment import Segment +from rich.style import Style + +from memray._vendor.textual.color import Color + +if TYPE_CHECKING: + from memray._vendor.textual.screen import Screen + + +class BackgroundScreen: + """Tints a renderable and removes links / meta.""" + + def __init__( + self, + screen: Screen, + color: Color, + ) -> None: + """Initialize a BackgroundScreen instance. + + Args: + screen: A Screen instance. + color: A color (presumably with alpha). + """ + self.screen = screen + """Screen to process.""" + self.color = color + """Color to apply (should have alpha).""" + + @classmethod + def process_segments( + cls, segments: Iterable[Segment], color: Color + ) -> Iterable[Segment]: + """Apply tint to segments and remove meta + styles + + Args: + segments: Incoming segments. + color: Color of tint. + + Returns: + Segments with applied tint. + """ + from_rich_color = Color.from_rich_color + style_from_color = Style.from_color + _Segment = Segment + + NULL_STYLE = Style() + + if color.a == 0: + # Special case for transparent color + for segment in segments: + text, style, control = segment + if control: + yield segment + else: + yield _Segment( + text, + NULL_STYLE if style is None else style.clear_meta_and_links(), + control, + ) + return + + for segment in segments: + text, style, control = segment + if control: + yield segment + else: + style = NULL_STYLE if style is None else style.clear_meta_and_links() + yield _Segment( + text, + ( + style + + style_from_color( + ( + (from_rich_color(style.color) + color).rich_color + if style.color is not None + else None + ), + ( + (from_rich_color(style.bgcolor) + color).rich_color + if style.bgcolor is not None + else None + ), + ) + ), + control, + ) + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + segments = console.render(self.screen._compositor, options) + color = self.color + return self.process_segments(segments, color) diff --git a/src/memray/_vendor/textual/renderables/bar.py b/src/memray/_vendor/textual/renderables/bar.py new file mode 100644 index 0000000000..c604e28c18 --- /dev/null +++ b/src/memray/_vendor/textual/renderables/bar.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +from rich.console import Console, ConsoleOptions, RenderResult +from rich.style import Style, StyleType +from rich.text import Text + +from memray._vendor.textual.color import Gradient + + +class Bar: + """Thin horizontal bar with a portion highlighted. + + Args: + highlight_range: The range to highlight. + highlight_style: The style of the highlighted range of the bar. + background_style: The style of the non-highlighted range(s) of the bar. + width: The width of the bar, or `None` to fill available width. + gradient: Optional gradient object. + """ + + HALF_BAR_LEFT: str = "╺" + BAR: str = "━" + HALF_BAR_RIGHT: str = "╸" + + def __init__( + self, + highlight_range: tuple[float, float] = (0, 0), + highlight_style: StyleType = "magenta", + background_style: StyleType = "grey37", + clickable_ranges: dict[str, tuple[int, int]] | None = None, + width: int | None = None, + gradient: Gradient | None = None, + ) -> None: + self.highlight_range = highlight_range + self.highlight_style = highlight_style + self.background_style = background_style + self.clickable_ranges = clickable_ranges or {} + self.width = width + self.gradient = gradient + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + highlight_style = console.get_style(self.highlight_style) + background_style = console.get_style(self.background_style) + + width = self.width or options.max_width + start, end = self.highlight_range + + start = max(start, 0) + end = min(end, width) + + output_bar = Text("", end="") + + if start == end == 0 or end < 0 or start > end: + output_bar.append(Text(self.BAR * width, style=background_style, end="")) + yield output_bar + return + + # Round start and end to nearest half + start = round(start * 2) / 2 + end = round(end * 2) / 2 + + # Check if we start/end on a number that rounds to a .5 + half_start = start - int(start) > 0 + half_end = end - int(end) > 0 + + # Initial non-highlighted portion of bar + output_bar.append( + Text(self.BAR * (int(start - 0.5)), style=background_style, end="") + ) + if not half_start and start > 0: + output_bar.append(Text(self.HALF_BAR_RIGHT, style=background_style, end="")) + + highlight_bar = Text("", end="") + # The highlighted portion + bar_width = int(end) - int(start) + if half_start: + highlight_bar.append( + Text( + self.HALF_BAR_LEFT + self.BAR * (bar_width - 1), + style=highlight_style, + end="", + ) + ) + else: + highlight_bar.append( + Text(self.BAR * bar_width, style=highlight_style, end="") + ) + if half_end: + highlight_bar.append( + Text(self.HALF_BAR_RIGHT, style=highlight_style, end="") + ) + + if self.gradient is not None: + _apply_gradient(highlight_bar, self.gradient, width) + output_bar.append(highlight_bar) + + # The non-highlighted tail + if not half_end and end - width != 0: + output_bar.append(Text(self.HALF_BAR_LEFT, style=background_style, end="")) + output_bar.append( + Text(self.BAR * (int(width) - int(end) - 1), style=background_style, end="") + ) + + # Fire actions when certain ranges are clicked (e.g. for tabs) + for range_name, (start, end) in self.clickable_ranges.items(): + output_bar.apply_meta( + {"@click": f"range_clicked('{range_name}')"}, start, end + ) + + yield output_bar + + +def _apply_gradient(text: Text, gradient: Gradient, width: int) -> None: + """Apply a gradient to a Rich Text instance. + + Args: + text: A Text object. + gradient: A Textual gradient. + width: Width of gradient. + """ + if not width: + return + assert width > 0 + from_color = Style.from_color + get_rich_color = gradient.get_rich_color + + max_width = width - 1 + if not max_width: + text.stylize(from_color(gradient.get_color(0).rich_color)) + return + text_length = len(text) + for offset in range(text_length): + bar_offset = text_length - offset + text.stylize( + from_color(get_rich_color(bar_offset / max_width)), + offset, + offset + 1, + ) diff --git a/src/memray/_vendor/textual/renderables/blank.py b/src/memray/_vendor/textual/renderables/blank.py new file mode 100644 index 0000000000..1503ddbfd2 --- /dev/null +++ b/src/memray/_vendor/textual/renderables/blank.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from rich.style import Style as RichStyle + +from memray._vendor.textual.color import Color +from memray._vendor.textual.content import Style +from memray._vendor.textual.css.styles import RulesMap +from memray._vendor.textual.strip import Strip +from memray._vendor.textual.visual import RenderOptions, Visual + + +class Blank(Visual): + """Draw solid background color.""" + + def __init__(self, color: Color | str = "transparent") -> None: + self._rich_style = RichStyle.from_color(bgcolor=Color.parse(color).rich_color) + + def visualize(self) -> Blank: + return self + + def get_optimal_width(self, rules: RulesMap, container_width: int) -> int: + return container_width + + def get_height(self, rules: RulesMap, width: int) -> int: + return 1 + + def render_strips( + self, width: int, height: int | None, style: Style, options: RenderOptions + ) -> list[Strip]: + """Render the Visual into an iterable of strips. Part of the Visual protocol. + + Args: + width: Width of desired render. + height: Height of desired render or `None` for any height. + style: The base style to render on top of. + options: Additional render options. + + Returns: + An list of Strips. + """ + line_count = 1 if height is None else height + return [Strip.blank(width, self._rich_style)] * line_count diff --git a/src/memray/_vendor/textual/renderables/digits.py b/src/memray/_vendor/textual/renderables/digits.py new file mode 100644 index 0000000000..8d44f7114d --- /dev/null +++ b/src/memray/_vendor/textual/renderables/digits.py @@ -0,0 +1,253 @@ +from __future__ import annotations + +from rich.console import Console, ConsoleOptions, RenderResult +from rich.measure import Measurement +from rich.segment import Segment +from rich.style import Style, StyleType + +DIGITS = " 0123456789+-^x:ABCDEF$£€()" +DIGITS3X3_BOLD = """\ + + + +┏━┓ +┃ ┃ +┗━┛ +╺┓ + ┃ +╺┻╸ +╺━┓ +┏━┛ +┗━╸ +╺━┓ + ━┫ +╺━┛ +╻ ╻ +┗━┫ + ╹ +┏━╸ +┗━┓ +╺━┛ +┏━╸ +┣━┓ +┗━┛ +╺━┓ + ┃ + ╹ +┏━┓ +┣━┫ +┗━┛ +┏━┓ +┗━┫ +╺━┛ + +╺╋╸ + + +╺━╸ + + ^ + + + + × + + + : + +╭─╮ +├─┤ +╵ ╵ +┌─╮ +├─┤ +└─╯ +╭─╮ +│ +╰─╯ +┌─╮ +│ │ +└─╯ +╭─╴ +├─ +╰─╴ +╭─╴ +├─ +╵ +╭╫╮ +╰╫╮ +╰╫╯ +╭─╮ +╪═ +┷━╸ +╭─╮ +╪═ +╰─╯ +╭╴ +│ +╰╴ + ╶╮ + │ + ╶╯ +""".splitlines() + + +DIGITS3X3 = """\ + + + +╭─╮ +│ │ +╰─╯ +╶╮ + │ +╶┴╴ +╶─╮ +┌─┘ +╰─╴ +╶─╮ + ─┤ +╶─╯ +╷ ╷ +╰─┤ + ╵ +╭─╴ +╰─╮ +╶─╯ +╭─╴ +├─╮ +╰─╯ +╶─┐ + │ + ╵ +╭─╮ +├─┤ +╰─╯ +╭─╮ +╰─┤ +╶─╯ + +╶┼╴ + + +╶─╴ + + ^ + + + + × + + + : + +╭─╮ +├─┤ +╵ ╵ +┌─╮ +├─┤ +└─╯ +╭─╮ +│ +╰─╯ +┌─╮ +│ │ +└─╯ +╭─╴ +├─ +╰─╴ +╭─╴ +├─ +╵ +╭╫╮ +╰╫╮ +╰╫╯ +╭─╮ +╪═ +┷━╸ +╭─╮ +╪═ +╰─╯ +╭╴ +│ +╰╴ + ╶╮ + │ + ╶╯ +""".splitlines() + + +class Digits: + """Renders a 3X3 unicode 'font' for numerical values. + + Args: + text: Text to display. + style: Style to apply to the digits. + + """ + + REPLACEMENTS = str.maketrans({".": "•"}) + + def __init__(self, text: str, style: StyleType = "") -> None: + self._text = text + self._style = style + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + style = console.get_style(self._style) + yield from self.render(style) + + def render(self, style: Style) -> RenderResult: + """Render with the given style + + Args: + style: Rich Style. + + Returns: + Result of render. + """ + digit_pieces: list[list[str]] = [[], [], []] + row1 = digit_pieces[0].append + row2 = digit_pieces[1].append + row3 = digit_pieces[2].append + + if style.bold: + digits = DIGITS3X3_BOLD + else: + digits = DIGITS3X3 + + for character in self._text.translate(self.REPLACEMENTS): + try: + position = DIGITS.index(character) * 3 + except ValueError: + row1(" ") + row2(" ") + row3(character) + else: + row1(digits[position].ljust(3)) + row2(digits[position + 1].ljust(3)) + row3(digits[position + 2].ljust(3)) + + new_line = Segment.line() + for line in digit_pieces: + yield Segment("".join(line), style + Style(meta={"offset": (0, 0)})) + yield new_line + + @classmethod + def get_width(cls, text: str) -> int: + """Calculate the width without rendering. + + Args: + text: Text which may be displayed in the `Digits` widget. + + Returns: + width of the text (in cells). + """ + width = sum(3 if character in DIGITS else 1 for character in text) + return width + + def __rich_measure__( + self, console: Console, options: ConsoleOptions + ) -> Measurement: + width = self.get_width(self._text) + return Measurement(width, width) diff --git a/src/memray/_vendor/textual/renderables/gradient.py b/src/memray/_vendor/textual/renderables/gradient.py new file mode 100644 index 0000000000..6cb73a404f --- /dev/null +++ b/src/memray/_vendor/textual/renderables/gradient.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +from math import cos, pi, sin +from typing import Sequence + +from rich.console import Console, ConsoleOptions, RenderResult +from rich.segment import Segment +from rich.style import Style + +from memray._vendor.textual.color import Color, Gradient + + +class VerticalGradient: + """Draw a vertical gradient.""" + + def __init__(self, color1: str, color2: str) -> None: + self._color1 = Color.parse(color1) + self._color2 = Color.parse(color2) + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + width = options.max_width + height = options.height or options.max_height + color1 = self._color1 + color2 = self._color2 + default_color = Color(0, 0, 0).rich_color + from_color = Style.from_color + blend = color1.blend + rich_color1 = color1.rich_color + for y in range(height): + line_color = from_color( + default_color, + ( + blend(color2, y / (height - 1)).rich_color + if height > 1 + else rich_color1 + ), + ) + yield Segment(f"{width * ' '}\n", line_color) + + +class LinearGradient: + """Render a linear gradient with a rotation. + + Args: + angle: Angle of rotation in degrees. + stops: List of stop consisting of pairs of offset (between 0 and 1) and color. + + """ + + def __init__( + self, angle: float, stops: Sequence[tuple[float, Color | str]] + ) -> None: + self.angle = angle + self._stops = [ + (stop, Color.parse(color) if isinstance(color, str) else color) + for stop, color in stops + ] + self._color_gradient = Gradient(*self._stops) + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + width = options.max_width + height = options.height or options.max_height + + angle_radians = -self.angle * pi / 180.0 + sin_angle = sin(angle_radians) + cos_angle = cos(angle_radians) + + center_x = width / 2 + center_y = height + + new_line = Segment.line() + + _Segment = Segment + get_color = self._color_gradient.get_rich_color + from_color = Style.from_color + + for line_y in range(height): + point_y = float(line_y) * 2 - center_y + point_x = 0 - center_x + + x1 = (center_x + (point_x * cos_angle - point_y * sin_angle)) / width + x2 = ( + center_x + (point_x * cos_angle - (point_y + 1.0) * sin_angle) + ) / width + point_x = width - center_x + end_x1 = (center_x + (point_x * cos_angle - point_y * sin_angle)) / width + delta_x = (end_x1 - x1) / width + + if abs(delta_x) < 0.0001: + # Special case for verticals + yield _Segment( + "▀" * width, + from_color( + get_color(x1), + get_color(x2), + ), + ) + + else: + yield from [ + _Segment( + "▀", + from_color( + get_color(x1 + x * delta_x), + get_color(x2 + x * delta_x), + ), + ) + for x in range(width) + ] + + yield new_line + + +if __name__ == "__main__": + from rich import print + + COLORS = [ + "#881177", + "#aa3355", + "#cc6666", + "#ee9944", + "#eedd00", + "#99dd55", + "#44dd88", + "#22ccbb", + "#00bbcc", + "#0099cc", + "#3366bb", + "#663399", + ] + + stops = [(i / (len(COLORS) - 1), Color.parse(c)) for i, c in enumerate(COLORS)] + + print(LinearGradient(25, stops)) + + from time import time + + from memray._vendor.textual.app import App, ComposeResult + from memray._vendor.textual.widgets import Static + + class GradientApp(App): + CSS = """ + Screen { + background: transparent; + align: center middle; + } + + Static { + padding: 2 4; + background: $panel; + width: 50; + } + + """ + + def compose(self) -> ComposeResult: + yield Static("Gradients are fast now :-) ") + + def render(self): + return LinearGradient(time() * 90, stops) + + def on_mount(self) -> None: + self.set_interval(1 / 30, self.refresh) + + app = GradientApp() + app.run() diff --git a/src/memray/_vendor/textual/renderables/sparkline.py b/src/memray/_vendor/textual/renderables/sparkline.py new file mode 100644 index 0000000000..fd133eeacd --- /dev/null +++ b/src/memray/_vendor/textual/renderables/sparkline.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import statistics +from fractions import Fraction +from typing import Callable, Generic, Iterable, Sequence, TypeVar + +from rich.color import Color +from rich.console import Console, ConsoleOptions, RenderResult +from rich.measure import Measurement +from rich.segment import Segment +from rich.style import Style + +from memray._vendor.textual.renderables._blend_colors import blend_colors + +T = TypeVar("T", int, float) + +SummaryFunction = Callable[[Sequence[T]], float] + + +class Sparkline(Generic[T]): + """A sparkline representing a series of data. + + Args: + data: The sequence of data to render. + width: The width of the sparkline/the number of buckets to partition the data into. + height: The height of the sparkline in lines. + min_color: The color of values equal to the min value in data. + max_color: The color of values equal to the max value in data. + summary_function: Function that will be applied to each bucket. + """ + + BARS = "▁▂▃▄▅▆▇█" + + def __init__( + self, + data: Sequence[T], + *, + width: int | None, + height: int | None = None, + min_color: Color = Color.from_rgb(0, 255, 0), + max_color: Color = Color.from_rgb(255, 0, 0), + summary_function: SummaryFunction[T] = max, + ) -> None: + self.data: Sequence[T] = data + self.width = width + self.height = height + self.min_color = Style.from_color(min_color) + self.max_color = Style.from_color(max_color) + self.summary_function: SummaryFunction[T] = summary_function + + @classmethod + def _buckets(cls, data: list[T], num_buckets: int) -> Iterable[Sequence[T]]: + """Partition ``data`` into ``num_buckets`` buckets. For example, the data + [1, 2, 3, 4] partitioned into 2 buckets is [[1, 2], [3, 4]]. + + Args: + data: The data to partition. + num_buckets: The number of buckets to partition the data into. + """ + bucket_step = Fraction(len(data), num_buckets) + for bucket_no in range(num_buckets): + start = int(bucket_step * bucket_no) + end = int(bucket_step * (bucket_no + 1)) + partition = data[start:end] + if partition: + yield partition + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + width = self.width or options.max_width + height = self.height or 1 + + len_data = len(self.data) + if len_data == 0: + for _ in range(height - 1): + yield Segment.line() + + yield Segment("▁" * width, self.min_color) + return + if len_data == 1: + for i in range(height): + yield Segment("█" * width, self.max_color) + + if i < height - 1: + yield Segment.line() + return + + bar_line_segments = len(self.BARS) + bar_segments = bar_line_segments * height - 1 + + minimum, maximum = min(self.data), max(self.data) + extent = maximum - minimum or 1 + + summary_function = self.summary_function + min_color, max_color = self.min_color.color, self.max_color.color + assert min_color is not None + assert max_color is not None + + buckets = tuple(self._buckets(list(self.data), num_buckets=width)) + + for i in reversed(range(height)): + current_bar_part_low = i * bar_line_segments + current_bar_part_high = (i + 1) * bar_line_segments + + bucket_index = 0.0 + bars_rendered = 0 + step = len(buckets) / width + while bars_rendered < width: + partition = buckets[int(bucket_index)] + partition_summary = summary_function(partition) + height_ratio = (partition_summary - minimum) / extent + bar_index = int(height_ratio * bar_segments) + + if bar_index < current_bar_part_low: + bar = " " + with_color = False + elif bar_index >= current_bar_part_high: + bar = "█" + with_color = True + else: + bar = self.BARS[bar_index % bar_line_segments] + with_color = True + + if with_color: + bar_color = blend_colors(min_color, max_color, height_ratio) + style = Style.from_color(bar_color) + else: + style = None + + bars_rendered += 1 + bucket_index += step + yield Segment(bar, style) + + if i > 0: + yield Segment.line() + + def __rich_measure__( + self, console: "Console", options: "ConsoleOptions" + ) -> Measurement: + return Measurement(self.width or options.max_width, self.height or 1) + + +if __name__ == "__main__": + console = Console() + + def last(l: Sequence[T]) -> T: + return l[-1] + + funcs: Sequence[SummaryFunction[int]] = ( + min, + max, + last, + statistics.median, + statistics.mean, + ) + nums = [10, 2, 30, 60, 45, 20, 7, 8, 9, 10, 50, 13, 10, 6, 5, 4, 3, 7, 20] + console.print(f"data = {nums}\n") + for f in funcs: + console.print( + f"{f.__name__}:\t", + Sparkline(nums, width=12, summary_function=f), + end="", + ) + console.print("\n") diff --git a/src/memray/_vendor/textual/renderables/styled.py b/src/memray/_vendor/textual/renderables/styled.py new file mode 100644 index 0000000000..40ae6af372 --- /dev/null +++ b/src/memray/_vendor/textual/renderables/styled.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from rich.measure import Measurement +from rich.segment import Segment + +if TYPE_CHECKING: + from rich.console import Console, ConsoleOptions, RenderableType, RenderResult + from rich.style import StyleType + + +class Styled: + """A renderable which allows you to apply a style before and after another renderable. + This can be used to layer styles on top of each other, like a style sandwich. This is used, + for example, in the DataTable to layer default CSS styles + user renderables (e.g. Text objects + stored in the cells of the table) + CSS component styles on top of each other.""" + + def __init__( + self, + renderable: "RenderableType", + pre_style: "StyleType", + post_style: "StyleType", + ) -> None: + """Construct a Styled. + + Args: + renderable (RenderableType): Any renderable. + pre_style (StyleType): A style to apply across the entire renderable. + Will be applied before the styles from the renderable itself. + post_style (StyleType): A style to apply across the entire renderable. + Will be applied after the styles from the renderable itself. + """ + self.renderable = renderable + self.pre_style = pre_style + self.post_style = post_style + + def __rich_console__( + self, console: "Console", options: "ConsoleOptions" + ) -> "RenderResult": + pre_style = console.get_style(self.pre_style) + post_style = console.get_style(self.post_style) + rendered_segments = console.render(self.renderable, options) + segments = Segment.apply_style( + rendered_segments, style=pre_style, post_style=post_style + ) + return segments + + def __rich_measure__( + self, console: "Console", options: "ConsoleOptions" + ) -> Measurement: + return Measurement.get(console, options, self.renderable) diff --git a/src/memray/_vendor/textual/renderables/text_opacity.py b/src/memray/_vendor/textual/renderables/text_opacity.py new file mode 100644 index 0000000000..c4193795c0 --- /dev/null +++ b/src/memray/_vendor/textual/renderables/text_opacity.py @@ -0,0 +1,111 @@ +import functools +from typing import Iterable, Tuple, cast + +from rich.cells import cell_len +from rich.color import Color +from rich.console import Console, ConsoleOptions, RenderableType, RenderResult +from rich.segment import Segment +from rich.style import Style +from rich.terminal_theme import TerminalTheme + +from memray._vendor.textual._ansi_theme import DEFAULT_TERMINAL_THEME +from memray._vendor.textual._context import active_app +from memray._vendor.textual.color import TRANSPARENT +from memray._vendor.textual.filter import ANSIToTruecolor +from memray._vendor.textual.renderables._blend_colors import blend_colors + + +@functools.lru_cache(maxsize=1024) +def _get_blended_style_cached( + bg_color: Color, fg_color: Color, opacity: float +) -> Style: + """Blend from one color to another. + + Cached because when a UI is static the opacity will be constant. + + Args: + bg_color: Background color. + fg_color: Foreground color. + opacity: Opacity. + + Returns: + Resulting style. + """ + return Style.from_color( + color=blend_colors(bg_color, fg_color, ratio=opacity), + bgcolor=bg_color, + ) + + +class TextOpacity: + """Blend foreground into background.""" + + def __init__(self, renderable: RenderableType, opacity: float = 1.0) -> None: + """Wrap a renderable to blend foreground color into the background color. + + Args: + renderable: The RenderableType to manipulate. + opacity: The opacity as a float. A value of 1.0 means text is fully visible. + """ + self.renderable = renderable + self.opacity = opacity + + @classmethod + def process_segments( + cls, + segments: Iterable[Segment], + opacity: float, + ansi_theme: TerminalTheme, + ) -> Iterable[Segment]: + """Apply opacity to segments. + + Args: + segments: Incoming segments. + opacity: Opacity to apply. + ansi_theme: Terminal theme. + background: Color of background. + + Returns: + Segments with applied opacity. + """ + + _Segment = Segment + _from_color = Style.from_color + if opacity == 0: + for text, style, _control in cast( + # use Tuple rather than tuple so Python 3.7 doesn't complain + Iterable[Tuple[str, Style, object]], + segments, + ): + invisible_style = _from_color(bgcolor=style.bgcolor) + yield _Segment(cell_len(text) * " ", invisible_style) + elif opacity == 1: + yield from segments + else: + filter = ANSIToTruecolor(ansi_theme) + for segment in filter.apply(list(segments), TRANSPARENT): + # use Tuple rather than tuple so Python 3.7 doesn't complain + text, style, control = cast(Tuple[str, Style, object], segment) + if not style: + yield segment + continue + + color = style.color + bgcolor = style.bgcolor + if color and color.triplet and bgcolor and bgcolor.triplet: + color_style = _get_blended_style_cached(bgcolor, color, opacity) + yield _Segment(text, style + color_style) + else: + yield segment + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + try: + app = active_app.get() + except LookupError: + ansi_theme = DEFAULT_TERMINAL_THEME + else: + ansi_theme = app.ansi_theme + segments = console.render(self.renderable, options) + return self.process_segments(segments, self.opacity, ansi_theme) diff --git a/src/memray/_vendor/textual/renderables/tint.py b/src/memray/_vendor/textual/renderables/tint.py new file mode 100644 index 0000000000..4fa4b4a2cc --- /dev/null +++ b/src/memray/_vendor/textual/renderables/tint.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from typing import Iterable + +from rich.console import RenderableType +from rich.segment import Segment +from rich.style import Style +from rich.terminal_theme import TerminalTheme + +from memray._vendor.textual.color import TRANSPARENT, Color +from memray._vendor.textual.filter import ANSIToTruecolor + + +class Tint: + """Applies a color on top of an existing renderable.""" + + def __init__( + self, + renderable: RenderableType, + color: Color, + ) -> None: + """Wrap a renderable to apply a tint color. + + Args: + renderable: A renderable. + color: A color (presumably with alpha). + """ + self.renderable = renderable + self.color = color + + @classmethod + def process_segments( + cls, + segments: Iterable[Segment], + color: Color, + ansi_theme: TerminalTheme, + background: Color = TRANSPARENT, + ) -> Iterable[Segment]: + """Apply tint to segments. + + Args: + segments: Incoming segments. + color: Color of tint. + ansi_theme: The TerminalTheme defining how to map ansi colors to hex. + background: Background color. + + Returns: + Segments with applied tint. + """ + from_rich_color = Color.from_rich_color + style_from_color = Style.from_color + _Segment = Segment + + truecolor_style = ANSIToTruecolor(ansi_theme).truecolor_style + background_rich_color = background.rich_color + + NULL_STYLE = Style() + for segment in segments: + text, style, control = segment + if control: + yield segment + else: + style = ( + truecolor_style(style, background_rich_color) + if style is not None + else NULL_STYLE + ) + yield _Segment( + text, + ( + style + + style_from_color( + ( + (from_rich_color(style.color) + color).rich_color + if style.color is not None + else None + ), + ( + (from_rich_color(style.bgcolor) + color).rich_color + if style.bgcolor is not None + else None + ), + ) + ), + control, + ) diff --git a/src/memray/_vendor/textual/rlock.py b/src/memray/_vendor/textual/rlock.py new file mode 100644 index 0000000000..d7a6af2d5e --- /dev/null +++ b/src/memray/_vendor/textual/rlock.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from asyncio import Lock, Task, current_task + + +class RLock: + """A re-entrant asyncio lock.""" + + def __init__(self) -> None: + self._owner: Task | None = None + self._count = 0 + self._lock = Lock() + + async def acquire(self) -> None: + """Wait until the lock can be acquired.""" + task = current_task() + assert task is not None + if self._owner is None or self._owner is not task: + await self._lock.acquire() + self._owner = task + self._count += 1 + + def release(self) -> None: + """Release a previously acquired lock.""" + task = current_task() + assert task is not None + self._count -= 1 + if self._count < 0: + # Should not occur if every acquire as a release + raise RuntimeError("RLock.release called too many times") + if self._owner is task: + if not self._count: + self._owner = None + self._lock.release() + + @property + def is_locked(self): + """Return True if lock is acquired.""" + return self._lock.locked() + + async def __aenter__(self) -> None: + """Asynchronous context manager to acquire and release lock.""" + await self.acquire() + + async def __aexit__(self, _type, _value, _traceback) -> None: + """Exit the context manager.""" + self.release() + + +if __name__ == "__main__": + from asyncio import Lock + + async def locks(): + lock = RLock() + async with lock: + async with lock: + print("Hello") + + import asyncio + + asyncio.run(locks()) diff --git a/src/memray/_vendor/textual/screen.py b/src/memray/_vendor/textual/screen.py new file mode 100644 index 0000000000..7e51552f28 --- /dev/null +++ b/src/memray/_vendor/textual/screen.py @@ -0,0 +1,2233 @@ +""" + +This module contains the `Screen` class and related objects. + +The `Screen` class is a special widget which represents the content in the terminal. See [Screens](/guide/screens/) for details. + +""" + +from __future__ import annotations + +import asyncio +from functools import partial +from operator import attrgetter +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + ClassVar, + Generic, + Iterable, + Iterator, + Literal, + NamedTuple, + Optional, + TypeVar, + Union, +) + +import rich.repr +from rich.console import RenderableType +from rich.style import Style + +from memray._vendor.textual import constants, errors, events, messages +from memray._vendor.textual._arrange import arrange +from memray._vendor.textual._auto_scroll import get_auto_scroll_regions +from memray._vendor.textual._callback import invoke +from memray._vendor.textual._compositor import Compositor, MapGeometry +from memray._vendor.textual._context import active_message_pump, visible_screen_stack +from memray._vendor.textual._path import ( + CSSPathType, + _css_path_type_as_list, + _make_path_object_relative, +) +from memray._vendor.textual._types import CallbackType +from memray._vendor.textual.actions import SkipAction +from memray._vendor.textual.await_complete import AwaitComplete +from memray._vendor.textual.binding import ActiveBinding, Binding, BindingsMap +from memray._vendor.textual.css.match import match +from memray._vendor.textual.css.parse import parse_selectors +from memray._vendor.textual.css.query import NoMatches, QueryType +from memray._vendor.textual.css.styles import PointerShape +from memray._vendor.textual.dom import DOMNode +from memray._vendor.textual.errors import NoWidget +from memray._vendor.textual.geometry import Offset, Region, Shape, Size +from memray._vendor.textual.keys import key_to_character +from memray._vendor.textual.layout import DockArrangeResult +from memray._vendor.textual.reactive import Reactive, var +from memray._vendor.textual.renderables.background_screen import BackgroundScreen +from memray._vendor.textual.renderables.blank import Blank +from memray._vendor.textual.selection import SELECT_ALL, Selection +from memray._vendor.textual.signal import Signal +from memray._vendor.textual.timer import Timer +from memray._vendor.textual.walk import walk_selectable_widgets +from memray._vendor.textual.widget import Widget +from memray._vendor.textual.widgets import Tooltip +from memray._vendor.textual.widgets._toast import ToastRack + +if TYPE_CHECKING: + from typing_extensions import Final + + from memray._vendor.textual.command import Provider + + # Unused & ignored imports are needed for the docs to link to these objects: + from memray._vendor.textual.message_pump import MessagePump + +# Screen updates will be batched so that they don't happen more often than 60 times per second: +UPDATE_PERIOD: Final[float] = 1 / constants.MAX_FPS + +ScreenResultType = TypeVar("ScreenResultType") +"""The result type of a screen.""" + +ScreenResultCallbackType = Union[ + Callable[[Optional[ScreenResultType]], None], + Callable[[Optional[ScreenResultType]], Awaitable[None]], +] +"""Type of a screen result callback function.""" + + +class HoverWidgets(NamedTuple): + """Result of [get_hover_widget_at][textual.screen.Screen.get_hover_widget_at]""" + + mouse_over: tuple[Widget, Region] + """Widget and region directly under the mouse.""" + hover_over: tuple[Widget, Region] | None + """Widget with a hover style under the mouse, or `None` for no hover style widget.""" + + @property + def widgets(self) -> tuple[Widget, Widget | None]: + """Just the widgets.""" + return ( + self.mouse_over[0], + None if self.hover_over is None else self.hover_over[0], + ) + + +@rich.repr.auto +class ResultCallback(Generic[ScreenResultType]): + """Holds the details of a callback.""" + + def __init__( + self, + requester: MessagePump, + callback: ScreenResultCallbackType[ScreenResultType] | None, + future: asyncio.Future[ScreenResultType] | None = None, + ) -> None: + """Initialise the result callback object. + + Args: + requester: The object making a request for the callback. + callback: The callback function. + future: A Future to hold the result. + """ + self.requester = requester + """The object in the DOM that requested the callback.""" + self.callback: ScreenResultCallbackType | None = callback + """The callback function.""" + self.future = future + """A future for the result""" + + def __call__(self, result: ScreenResultType) -> None: + """Call the callback, passing the given result. + + Args: + result: The result to pass to the callback. + + Note: + If the requested or the callback are `None` this will be a no-op. + """ + if self.future is not None: + self.future.set_result(result) + if self.requester is not None and self.callback is not None: + self.requester.call_next(self.callback, result) + self.callback = None + + +@rich.repr.auto +class Screen(Generic[ScreenResultType], Widget): + """The base class for screens.""" + + AUTO_FOCUS: ClassVar[str | None] = None + """A selector to determine what to focus automatically when the screen is activated. + + The widget focused is the first that matches the given [CSS selector](/guide/queries/#query-selectors). + Set to `None` to inherit the value from the screen's app. + Set to `""` to disable auto focus. + """ + + CSS: ClassVar[str] = "" + """Inline CSS, useful for quick scripts. Rules here take priority over CSS_PATH. + + Note: + This CSS applies to the whole app. + """ + CSS_PATH: ClassVar[CSSPathType | None] = None + """File paths to load CSS from. + + Note: + This CSS applies to the whole app. + """ + + COMPONENT_CLASSES = {"screen--selection"} + + DEFAULT_CSS = """ + Screen { + layout: vertical; + overflow-y: auto; + background: $background; + + &:inline { + height: auto; + min-height: 1; + border-top: tall $background; + border-bottom: tall $background; + } + + &:ansi { + background: ansi_default; + color: ansi_default; + + &.-screen-suspended { + text-style: dim; + ScrollBar { + text-style: not dim; + } + } + } + .screen--selection { + background: $primary 50%; + } + } + """ + + TITLE: ClassVar[str | None] = None + """A class variable to set the *default* title for the screen. + + This overrides the app title. + To update the title while the screen is running, + you can set the [title][textual.screen.Screen.title] attribute. + """ + + SUB_TITLE: ClassVar[str | None] = None + """A class variable to set the *default* sub-title for the screen. + + This overrides the app sub-title. + To update the sub-title while the screen is running, + you can set the [sub_title][textual.screen.Screen.sub_title] attribute. + """ + + HORIZONTAL_BREAKPOINTS: ClassVar[list[tuple[int, str]]] | None = None + """Horizontal breakpoints, will override [App.HORIZONTAL_BREAKPOINTS][textual.app.App.HORIZONTAL_BREAKPOINTS] if not `None`.""" + VERTICAL_BREAKPOINTS: ClassVar[list[tuple[int, str]]] | None = None + """Vertical breakpoints, will override [App.VERTICAL_BREAKPOINTS][textual.app.App.VERTICAL_BREAKPOINTS] if not `None`.""" + + focused: Reactive[Widget | None] = Reactive(None) + """The focused [widget][textual.widget.Widget] or `None` for no focus. + To set focus, do not update this value directly. Use [set_focus][textual.screen.Screen.set_focus] instead.""" + stack_updates: Reactive[int] = Reactive(0, repaint=False) + """An integer that updates when the screen is resumed.""" + sub_title: Reactive[str | None] = Reactive(None, compute=False) + """Screen sub-title to override [the app sub-title][textual.app.App.sub_title].""" + title: Reactive[str | None] = Reactive(None, compute=False) + """Screen title to override [the app title][textual.app.App.title].""" + + COMMANDS: ClassVar[set[type[Provider] | Callable[[], type[Provider]]]] = set() + """Command providers used by the [command palette](/guide/command_palette), associated with the screen. + + Should be a set of [`command.Provider`][textual.command.Provider] classes. + """ + ALLOW_IN_MAXIMIZED_VIEW: ClassVar[str | None] = None + """A selector for the widgets (direct children of Screen) that are allowed in the maximized view (in addition to maximized widget). Or + `None` to default to [App.ALLOW_IN_MAXIMIZED_VIEW][textual.app.App.ALLOW_IN_MAXIMIZED_VIEW]""" + + ESCAPE_TO_MINIMIZE: ClassVar[bool | None] = None + """Use escape key to minimize (potentially overriding bindings) or `None` to defer to [`App.ESCAPE_TO_MINIMIZE`][textual.app.App.ESCAPE_TO_MINIMIZE].""" + + maximized: Reactive[Widget | None] = Reactive(None, layout=True) + """The currently maximized widget, or `None` for no maximized widget.""" + + selections: var[dict[Widget, Selection]] = var(dict) + """Map of widgets and selected ranges.""" + + _selecting = var(False) + """Indicates mouse selection is in progress.""" + + _box_select = var(False) + """Should text selection be limited to a box?""" + + _select_start: Reactive[tuple[Widget, Offset, Offset] | None] = Reactive(None) + """Tuple of (widget, screen offset, text offset) where selection started.""" + _select_end: Reactive[tuple[Widget, Offset, Offset] | None] = Reactive(None) + """Tuple of (widget, screen offset, text offset) where selection ends.""" + + _mouse_down_offset: var[Offset | None] = var(None) + """Last mouse down screen offset, or `None` if the mouse is up.""" + + _pointer_shape: var[PointerShape] = var("default") + """The current mouse pointer shape.""" + + BINDINGS = [ + Binding("tab", "app.focus_next", "Focus Next", show=False), + Binding("shift+tab", "app.focus_previous", "Focus Previous", show=False), + Binding("ctrl+c,super+c", "screen.copy_text", "Copy selected text", show=False), + ] + + def __init__( + self, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + ) -> None: + """ + Initialize the screen. + + Args: + name: The name of the screen. + id: The ID of the screen in the DOM. + classes: The CSS classes for the screen. + """ + self._modal = False + super().__init__(name=name, id=id, classes=classes) + self._compositor = Compositor() + self._dirty_widgets: set[Widget] = set() + self.__update_timer: Timer | None = None + self._callbacks: list[tuple[CallbackType, MessagePump]] = [] + self._result_callbacks: list[ResultCallback[ScreenResultType | None]] = [] + + self._tooltip_widget: Widget | None = None + self._tooltip_timer: Timer | None = None + + css_paths = [ + _make_path_object_relative(css_path, self) + for css_path in ( + _css_path_type_as_list(self.CSS_PATH) + if self.CSS_PATH is not None + else [] + ) + ] + self.css_path = css_paths + + self.title = self.TITLE + self.sub_title = self.SUB_TITLE + + self.screen_layout_refresh_signal: Signal[Screen] = Signal( + self, "layout-refresh" + ) + """The signal that is published when the screen's layout is refreshed.""" + + self.bindings_updated_signal: Signal[Screen] = Signal(self, "bindings_updated") + """A signal published when the bindings have been updated""" + + self.text_selection_started_signal: Signal[Screen] = Signal( + self, "selection_started" + ) + """A signal published when text selection has started.""" + + self._css_update_count = -1 + """Track updates to CSS.""" + + self._layout_widgets: dict[DOMNode, set[Widget]] = {} + """Widgets whose layout may have changed.""" + + self._auto_select_scroll_timer: Timer | None = None + """A timer to auto scroll a container.""" + + @property + def is_modal(self) -> bool: + """Is the screen modal?""" + return self._modal + + @property + def is_current(self) -> bool: + """Is the screen current (i.e. visible to user)?""" + from memray._vendor.textual.app import ScreenStackError + + try: + return self.app.screen is self or self in self.app._background_screens + except ScreenStackError: + return False + + @property + def _update_timer(self) -> Timer: + """Timer used to perform updates.""" + if self.__update_timer is None: + self.__update_timer = self.set_interval( + UPDATE_PERIOD, self._on_timer_update, name="screen_update", pause=True + ) + return self.__update_timer + + @property + def layers(self) -> tuple[str, ...]: + """Layers from parent. + + Returns: + Tuple of layer names. + """ + extras = ["_loading"] + if not self.app._disable_notifications: + extras.append("_toastrack") + if not self.app._disable_tooltips: + extras.append("_tooltips") + return (*super().layers, *extras) + + @property + def size(self) -> Size: + """The size of the screen.""" + return self.app.size - self.styles.gutter.totals + + def _watch_focused(self): + self.refresh_bindings() + + def _watch_stack_updates(self): + self.refresh_bindings() + + async def _watch_selections( + self, + old_selections: dict[Widget, Selection], + selections: dict[Widget, Selection], + ): + for widget in old_selections.keys() | selections.keys(): + widget.selection_updated(selections.get(widget, None)) + + def refresh_bindings(self) -> None: + """Call to request a refresh of bindings.""" + self.bindings_updated_signal.publish(self) + + def _watch_maximized( + self, previously_maximized: Widget | None, maximized: Widget | None + ) -> None: + # The screen gets a `-maximized-view` class if there is a maximized widget + # The widget gets a `-maximized` class if it is maximized + self.set_class(maximized is not None, "-maximized-view") + if previously_maximized is not None: + previously_maximized.remove_class("-maximized") + if maximized is not None: + maximized.add_class("-maximized") + + @property + def _binding_chain(self) -> list[tuple[DOMNode, BindingsMap]]: + """Binding chain from this screen.""" + + focused = self.focused + if focused is not None and focused.loading: + focused = None + + namespace_bindings: list[tuple[DOMNode, BindingsMap]] + if focused is None: + namespace_bindings = [ + (self, self._bindings.copy()), + (self.app, self.app._bindings.copy()), + ] + else: + namespace_bindings = [ + (node, node._bindings.copy()) for node in focused.ancestors_with_self + ] + + # Filter out bindings that could be captures by widgets (such as Input, TextArea) + filter_namespaces: list[DOMNode] = [] + for namespace, bindings_map in namespace_bindings: + for filter_namespace in filter_namespaces: + check_consume_key = filter_namespace.check_consume_key + for key in list(bindings_map.key_to_bindings): + if check_consume_key(key, key_to_character(key)): + # If the widget consumes the key (e.g. like an Input widget), + # then remove the key from the bindings map. + del bindings_map.key_to_bindings[key] + + filter_namespaces.append(namespace) + + keymap = self.app._keymap + for namespace, bindings_map in namespace_bindings: + if keymap: + result = bindings_map.apply_keymap(keymap) + if result.clashed_bindings: + self.app.handle_bindings_clash(result.clashed_bindings, namespace) + + return namespace_bindings + + @property + def _modal_binding_chain(self) -> list[tuple[DOMNode, BindingsMap]]: + """The binding chain, ignoring everything before the last modal.""" + binding_chain = self._binding_chain + for index, (node, _bindings) in enumerate(binding_chain, 1): + if node.is_modal: + return binding_chain[:index] + return binding_chain + + @property + def active_bindings(self) -> dict[str, ActiveBinding]: + """Get currently active bindings for this screen. + + If no widget is focused, then app-level bindings are returned. + If a widget is focused, then any bindings present in the screen and app are merged and returned. + + This property may be used to inspect current bindings. + + Returns: + A map of keys to a tuple containing (NAMESPACE, BINDING, ENABLED). + """ + bindings_map: dict[str, ActiveBinding] = {} + app = self.app + for namespace, bindings in self._modal_binding_chain: + for key, binding in bindings: + # This will call the nodes `check_action` method. + action_state = app._check_action_state(binding.action, namespace) + if action_state is False: + # An action_state of False indicates the action is disabled and not shown + # Note that None has a different meaning, which is why there is an `is False` + # rather than a truthy check. + continue + + enabled = bool(action_state) + if existing_key_and_binding := bindings_map.get(key): + # This key has already been bound + # Replace priority bindings + if ( + binding.priority + and not existing_key_and_binding.binding.priority + ): + bindings_map[key] = ActiveBinding( + namespace, binding, enabled, binding.tooltip + ) + else: + # New binding + bindings_map[key] = ActiveBinding( + namespace, binding, enabled, binding.tooltip + ) + + return bindings_map + + def arrange(self, size: Size, _optimal: bool = False) -> DockArrangeResult: + """Arrange children. + + Args: + size: Size of container. + optimal: Ignored on screen. + + Returns: + Widget locations. + """ + # This is customized over the base class to allow for a widget to be maximized + cache_key = (size, self._nodes._updates, self.maximized) + cached_result = self._arrangement_cache.get(cache_key) + if cached_result is not None: + return cached_result + + allow_in_maximized_view = ( + self.app.ALLOW_IN_MAXIMIZED_VIEW + if self.ALLOW_IN_MAXIMIZED_VIEW is None + else self.ALLOW_IN_MAXIMIZED_VIEW + ) + + def get_maximize_widgets(maximized: Widget) -> list[Widget]: + """Get widgets to display in maximized view. + + Returns: + A list of widgets. + + """ + # De-duplicate with a set + widgets = { + maximized, + *self.query_children(allow_in_maximized_view), + *self.query_children(".-textual-system"), + } + # Restore order of widgets. + maximize_widgets = [widget for widget in self.children if widget in widgets] + # Add the maximized widget, if its not already included + if maximized not in maximize_widgets: + maximize_widgets.insert(0, maximized) + return maximize_widgets + + arrangement = self._arrangement_cache[cache_key] = arrange( + self, + ( + get_maximize_widgets(self.maximized) + if self.maximized is not None + else self._nodes + ), + size, + self.size, + False, + ) + + return arrangement + + @property + def is_active(self) -> bool: + """Is the screen active (i.e. visible and top of the stack)?""" + try: + return self.app.screen is self + except Exception: + return False + + @property + def allow_select(self) -> bool: + """Check if this widget permits text selection.""" + return self.ALLOW_SELECT + + def get_loading_widget(self) -> Widget: + """Get a widget to display a loading indicator. + + The default implementation will defer to App.get_loading_widget. + + Returns: + A widget in place of this widget to indicate a loading. + """ + loading_widget = self.app.get_loading_widget() + return loading_widget + + def _watch__pointer_shape(self, pointer_shape: PointerShape) -> None: + self.app._set_pointer_shape(pointer_shape) + + def update_pointer_shape(self) -> None: + """Get the screen's current pointer shape.""" + if self._selecting: + self._pointer_shape = "text" + return + widget = self if self.app.mouse_over is None else self.app.mouse_over + pointer_shape = "default" + for node in widget.ancestors_with_self: + if isinstance(node, Widget): + if node.loading: + pointer_shape = "wait" + break + if (pointer_shape := node.styles.pointer) != "default": + break + + self._pointer_shape = pointer_shape + + def render(self) -> RenderableType: + """Render method inherited from widget, used to render the screen's background. + + Returns: + Background renderable. + """ + background = self.styles.background + try: + base_screen = visible_screen_stack.get().pop() + except LookupError: + base_screen = None + + if base_screen is not None and base_screen is not self and background.a < 1: + # If background is translucent, render a background screen + return BackgroundScreen(base_screen, background) + + if background.is_transparent: + # If the background is transparent, defer to App.render + return self.app.render() + # Render a screen of a solid color. + return Blank(background) + + def get_offset(self, widget: Widget) -> Offset: + """Get the absolute offset of a given Widget. + + Args: + widget: A widget + + Returns: + The widget's offset relative to the top left of the terminal. + """ + return self._compositor.get_offset(widget) + + def get_widget_at(self, x: int, y: int) -> tuple[Widget, Region]: + """Get the widget at a given coordinate. + + Args: + x: X Coordinate. + y: Y Coordinate. + + Returns: + Widget and screen region. + + Raises: + NoWidget: If there is no widget under the screen coordinate. + """ + return self._compositor.get_widget_at(x, y) + + def get_hover_widgets_at(self, x: int, y: int) -> HoverWidgets: + """Get the widget, and its region directly under the mouse, and the first + widget, region pair with a hover style. + + Args: + x: X Coordinate. + y: Y Coordinate. + + Returns: + A pair of (WIDGET, REGION) tuples for the top most and first hover style respectively. + + Raises: + NoWidget: If there is no widget under the screen coordinate. + + """ + widgets_under_coordinate = iter(self._compositor.get_widgets_at(x, y)) + try: + top_widget, top_region = next(widgets_under_coordinate) + except StopIteration: + raise errors.NoWidget(f"No hover widget under screen coordinate ({x}, {y})") + if not top_widget._has_hover_style: + for widget, region in widgets_under_coordinate: + if widget._has_hover_style: + return HoverWidgets((top_widget, top_region), (widget, region)) + return HoverWidgets((top_widget, top_region), None) + return HoverWidgets((top_widget, top_region), (top_widget, top_region)) + + def get_widgets_at(self, x: int, y: int) -> Iterable[tuple[Widget, Region]]: + """Get all widgets under a given coordinate. + + Args: + x: X coordinate. + y: Y coordinate. + + Returns: + Sequence of (WIDGET, REGION) tuples. + """ + return self._compositor.get_widgets_at(x, y) + + def get_focusable_widget_at(self, x: int, y: int) -> Widget | None: + """Get the focusable widget under a given coordinate. + + If the widget directly under the given coordinate is not focusable, then this method will check + if any of the ancestors are focusable. If no ancestors are focusable, then `None` will be returned. + + Args: + x: X coordinate. + y: Y coordinate. + + Returns: + A `Widget`, or `None` if there is no focusable widget underneath the coordinate. + """ + try: + widget, _region = self.get_widget_at(x, y) + except NoWidget: + return None + + if widget.has_class("-textual-system") or widget.loading: + # Clicking Textual system widgets should not focus anything + return None + + for node in widget.ancestors_with_self: + if isinstance(node, Widget) and node.focusable: + return node + return None + + def get_style_at(self, x: int, y: int) -> Style: + """Get the style under a given coordinate. + + Args: + x: X Coordinate. + y: Y Coordinate. + + Returns: + Rich Style object. + """ + return self._compositor.get_style_at(x, y) + + def get_widget_and_offset_at( + self, x: int, y: int + ) -> tuple[Widget | None, Offset | None]: + """Get the widget under a given coordinate, and an offset within the original content. + + Args: + x: X Coordinate. + y: Y Coordinate. + + Returns: + Tuple of Widget and Offset, both of which may be None. + """ + return self._compositor.get_widget_and_offset_at(x, y) + + def find_widget(self, widget: Widget) -> MapGeometry: + """Get the screen region of a Widget. + + Args: + widget: A Widget within the composition. + + Returns: + Region relative to screen. + + Raises: + NoWidget: If the widget could not be found in this screen. + """ + return self._compositor.find_widget(widget) + + def clear_selection(self) -> None: + """Clear any selected text.""" + self.selections = {} + self._select_start = None + self._select_end = None + + def _select_all_in_widget(self, widget: Widget) -> None: + """Select a widget and all its children. + + Args: + widget: Widget to select. + """ + select_all = SELECT_ALL + self.selections = { + widget: select_all, + **{child: select_all for child in widget.query("*")}, + } + + @property + def focus_chain(self) -> list[Widget]: + """A list of widgets that may receive focus, in focus order.""" + # TODO: Calculating a focus chain is moderately expensive. + # Suspect we can move focus without calculating the entire thing again. + + widgets: list[Widget] = [] + add_widget = widgets.append + focus_sorter = attrgetter("_focus_sort_key") + # We traverse the DOM and keep track of where we are at with a node stack. + # Additionally, we manually keep track of the visibility of the DOM + # instead of relying on the property `.visible` to save on DOM traversals. + # node_stack: list[tuple[iterator over node children, node visibility]] + + root_node = self.screen + + if (focused := self.focused) is not None: + for node in focused.ancestors_with_self: + if node._trap_focus: + root_node = node + break + + node_stack: list[tuple[Iterator[Widget], bool]] = [ + ( + iter(sorted(root_node.displayed_children, key=focus_sorter)), + self.visible, + ) + ] + pop = node_stack.pop + push = node_stack.append + + while node_stack: + children_iterator, parent_visibility = node_stack[-1] + node = next(children_iterator, None) + if node is None: + pop() + else: + if node._check_disabled(): + continue + node_styles_visibility = node.styles.get_rule("visibility") + node_is_visible = ( + node_styles_visibility != "hidden" + if node_styles_visibility + else parent_visibility # Inherit visibility if the style is unset. + ) + if node.is_container and node.allow_focus_children(): + sorted_displayed_children = sorted( + node.displayed_children, key=focus_sorter + ) + push((iter(sorted_displayed_children), node_is_visible)) + # Same check as `if node.focusable`, but we cached inherited visibility + # and we also skipped disabled nodes altogether. + if node_is_visible and node.allow_focus(): + add_widget(node) + + return widgets + + def _move_focus( + self, direction: int = 0, selector: str | type[QueryType] = "*" + ) -> Widget | None: + """Move the focus in the given direction. + + If no widget is currently focused, this will focus the first focusable widget. + If no focusable widget matches the given CSS selector, focus is set to `None`. + + Args: + direction: 1 to move forward, -1 to move backward, or + 0 to keep the current focus. + selector: CSS selector to filter + what nodes can be focused. + + Returns: + Newly focused widget, or None for no focus. If the return + is not `None`, then it is guaranteed that the widget returned matches + the CSS selectors given in the argument. + """ + + if not isinstance(selector, str): + selector = selector.__name__ + selector_set = parse_selectors(selector) + focus_chain = self.focus_chain + + # If a widget is maximized we want to limit the focus chain to the visible widgets + if self.maximized is not None: + focusable = set(self.maximized.walk_children(with_self=True)) + focus_chain = [widget for widget in focus_chain if widget in focusable] + + filtered_focus_chain = ( + node for node in focus_chain if match(selector_set, node) + ) + + if not focus_chain: + # Nothing focusable, so nothing to do + return self.focused + if self.focused is None: + # Nothing currently focused, so focus the first one. + to_focus = next(filtered_focus_chain, None) + self.set_focus(to_focus) + return self.focused + + # Ensure focus will be in a node that matches the selectors. + if not direction and not match(selector_set, self.focused): + direction = 1 + + try: + # Find the index of the currently focused widget + current_index = focus_chain.index(self.focused) + except ValueError: + # Focused widget was removed in the interim, start again + self.set_focus(next(filtered_focus_chain, None)) + else: + # Only move the focus if we are currently showing the focus + if direction: + to_focus = None + chain_length = len(focus_chain) + for step in range(1, len(focus_chain) + 1): + node = focus_chain[ + (current_index + direction * step) % chain_length + ] + if match(selector_set, node): + to_focus = node + break + self.set_focus(to_focus) + + return self.focused + + def focus_next(self, selector: str | type[QueryType] = "*") -> Widget | None: + """Focus the next widget, optionally filtered by a CSS selector. + + If no widget is currently focused, this will focus the first focusable widget. + If no focusable widget matches the given CSS selector, focus is set to `None`. + + Args: + selector: CSS selector to filter + what nodes can be focused. + + Returns: + Newly focused widget, or None for no focus. If the return + is not `None`, then it is guaranteed that the widget returned matches + the CSS selectors given in the argument. + """ + return self._move_focus(1, selector) + + def focus_previous(self, selector: str | type[QueryType] = "*") -> Widget | None: + """Focus the previous widget, optionally filtered by a CSS selector. + + If no widget is currently focused, this will focus the first focusable widget. + If no focusable widget matches the given CSS selector, focus is set to `None`. + + Args: + selector: CSS selector to filter + what nodes can be focused. + + Returns: + Newly focused widget, or None for no focus. If the return + is not `None`, then it is guaranteed that the widget returned matches + the CSS selectors given in the argument. + """ + return self._move_focus(-1, selector) + + def maximize(self, widget: Widget, container: bool = True) -> bool: + """Maximize a widget, so it fills the screen. + + Args: + widget: Widget to maximize. + container: If one of the widgets ancestors is a maximizeable widget, maximize that instead. + + Returns: + `True` if the widget was maximized, otherwise `False`. + """ + if widget.allow_maximize: + if container: + # If we want to maximize the container, look up the dom to find a suitable widget + for maximize_widget in widget.ancestors: + if not isinstance(maximize_widget, Widget): + break + if maximize_widget.allow_maximize: + self.maximized = maximize_widget + return True + + self.maximized = widget + return True + return False + + def minimize(self) -> None: + """Restore any maximized widget to normal state.""" + self.maximized = None + if self.focused is not None: + self.call_after_refresh( + self.scroll_to_widget, self.focused, animate=False, center=True + ) + + def get_selected_text(self) -> str | None: + """Get text under selection. + + Returns: + Selected text, or `None` if no text was selected. + """ + if not self.selections: + return None + + widget_text: list[str] = [] + for widget, selection in self.selections.items(): + # Filter out widgets that may have been removed since the text was selected + if ( + widget.is_attached + and (selected_text_in_widget := widget.get_selection(selection)) + is not None + ): + widget_text.extend(selected_text_in_widget) + + selected_text = "".join(widget_text).rstrip("\n") + return selected_text + + def action_copy_text(self) -> None: + """Copy selected text to clipboard.""" + selection = self.get_selected_text() + if selection is None: + # No text selected + raise SkipAction() + self.app.copy_to_clipboard(selection) + + def action_maximize(self) -> None: + """Action to maximize the currently focused widget.""" + if self.focused is not None: + self.maximize(self.focused) + + def action_minimize(self) -> None: + """Action to minimize the currently maximized widget.""" + self.minimize() + + def action_blur(self) -> None: + """Action to remove focus (if set).""" + self.set_focus(None) + + async def action_focus(self, selector: str) -> None: + """An [action](/guide/actions) to focus the given widget. + + Args: + selector: Selector of widget to focus (first match). + """ + try: + node = self.query(selector).first() + except NoMatches: + pass + else: + if isinstance(node, Widget): + self.set_focus(node) + + def _reset_focus( + self, widget: Widget, avoiding: list[Widget] | None = None + ) -> None: + """Reset the focus when a widget is removed + + Args: + widget: A widget that is removed. + avoiding: Optional list of nodes to avoid. + """ + + avoiding = avoiding or [] + + # Make this a NOP if we're being asked to deal with a widget that + # isn't actually the currently-focused widget. + if self.focused is not widget: + return + + # Grab the list of widgets that we can set focus to. + focusable_widgets = self.focus_chain + if not focusable_widgets: + # If there's nothing to focus... give up now. + self.set_focus(None) + return + + try: + # Find the location of the widget we're taking focus from, in + # the focus chain. + widget_index = focusable_widgets.index(widget) + except ValueError: + # widget is not in focusable widgets + # It may have been made invisible + # Move to a sibling if possible + for sibling in widget.visible_siblings: + if sibling not in avoiding and sibling.focusable: + self.set_focus(sibling) + break + else: + self.set_focus(None) + return + + # Now go looking for something before it, that isn't about to be + # removed, and which can receive focus, and go focus that. + chosen: Widget | None = None + for candidate in reversed( + focusable_widgets[widget_index + 1 :] + focusable_widgets[:widget_index] + ): + if candidate not in avoiding: + chosen = candidate + break + + # Go with what was found. + self.set_focus(chosen) + + def _update_focus_styles( + self, focused: Widget | None = None, blurred: Widget | None = None + ) -> None: + """Update CSS for focus changes. + + Args: + focused: The widget that was focused. + blurred: The widget that was blurred. + """ + widgets: set[DOMNode] = set() + + if focused is not None: + for widget in reversed(focused.ancestors_with_self): + if widget._has_focus_within: + widgets.update(widget.walk_children(with_self=True)) + break + if blurred is not None: + for widget in reversed(blurred.ancestors_with_self): + if widget._has_focus_within: + widgets.update(widget.walk_children(with_self=True)) + break + if widgets: + self.app.stylesheet.update_nodes(widgets, animate=True) + + def set_focus( + self, + widget: Widget | None, + scroll_visible: bool = True, + from_app_focus: bool = False, + ) -> None: + """Focus (or un-focus) a widget. A focused widget will receive key events first. + + Args: + widget: Widget to focus, or None to un-focus. + scroll_visible: Scroll widget into view. + from_app_focus: True if this focus is due to the app itself having regained + focus. False if the focus is being set because a widget within the app + regained focus. + """ + if widget is self.focused: + # Widget is already focused + return + + focused: Widget | None = None + blurred: Widget | None = None + + if widget is None: + # No focus, so blur currently focused widget if it exists + if self.focused is not None: + self.focused.post_message(events.Blur()) + blurred = self.focused + self.focused = None + self.log.debug("focus was removed") + elif widget.focusable: + if self.focused != widget: + if self.focused is not None: + # Blur currently focused widget + self.focused.post_message(events.Blur()) + blurred = self.focused + # Change focus + self.focused = widget + # Send focus event + widget.post_message(events.Focus(from_app_focus=from_app_focus)) + focused = widget + + if scroll_visible: + + def scroll_to_center(widget: Widget) -> None: + """Scroll to center (after a refresh).""" + if self.focused is widget and not self.can_view_entire(widget): + self.scroll_to_center(widget, origin_visible=True) + + self.call_later(scroll_to_center, widget) + + self.log.debug(widget, "was focused") + + self._update_focus_styles(focused, blurred) + self.call_after_refresh(self.refresh_bindings) + + def _extend_compose(self, widgets: list[Widget]) -> None: + """Insert Textual's own internal widgets. + + Args: + widgets: The list of widgets to be composed. + + This method adds the tooltip, if required, and also adds the + container for `Toast`s. + """ + if not self.app._disable_tooltips: + widgets.insert(0, Tooltip(id="textual-tooltip")) + if not self.app._disable_notifications: + widgets.insert(0, ToastRack(id="textual-toastrack")) + + def _on_mount(self, event: events.Mount) -> None: + """Set up the tooltip-clearing signal when we mount.""" + self.screen_layout_refresh_signal.subscribe( + self, self._maybe_clear_tooltip, immediate=True + ) + + async def _on_idle(self, event: events.Idle) -> None: + # Check for any widgets marked as 'dirty' (needs a repaint) + event.prevent_default() + if not self.app._batch_count and self.is_current: + if ( + self._layout_required + or self._scroll_required + or self._repaint_required + or self._recompose_required + or self._dirty_widgets + ): + self._update_timer.resume() + return + + await self._invoke_and_clear_callbacks() + + def _compositor_refresh(self) -> None: + """Perform a compositor refresh.""" + + app = self.app + + if app.is_inline: + if self is app.screen: + inline_height = app._get_inline_height() + clear = ( + app._previous_inline_height is not None + and inline_height < app._previous_inline_height + ) + app._display( + self, + self._compositor.render_inline( + app.size.with_height(inline_height), + screen_stack=app._background_screens, + clear=clear, + ), + ) + app._previous_inline_height = inline_height + self._dirty_widgets.clear() + self._compositor._dirty_regions.clear() + elif ( + self in self.app._background_screens and self._compositor._dirty_regions + ): + app.screen.refresh(*self._compositor._dirty_regions) + self._compositor._dirty_regions.clear() + self._dirty_widgets.clear() + + else: + if self is app.screen: + # Top screen + update = self._compositor.render_update( + screen_stack=app._background_screens + ) + app._display(self, update) + self._dirty_widgets.clear() + elif ( + self in self.app._background_screens and self._compositor._dirty_regions + ): + self._set_dirty(*self._compositor._dirty_regions) + app.screen.refresh(*self._compositor._dirty_regions) + self._repaint_required = True + self._compositor._dirty_regions.clear() + self._dirty_widgets.clear() + app._update_mouse_over(self) + + def _on_timer_update(self) -> None: + """Called by the _update_timer.""" + self._update_timer.pause() + if self.is_current and not self.app._batch_count: + if self._layout_required: + self._refresh_layout(scroll=self._scroll_required) + self._layout_required = False + self._dirty_widgets.clear() + elif self._scroll_required: + self._refresh_layout(scroll=True) + self._scroll_required = False + + if self._repaint_required: + self._dirty_widgets.clear() + self._dirty_widgets.add(self) + self._repaint_required = False + + if self._dirty_widgets: + self._compositor.update_widgets(self._dirty_widgets) + self._compositor_refresh() + + if self._recompose_required: + self._recompose_required = False + self.call_next(self.recompose) + + if self._callbacks: + self.call_next(self._invoke_and_clear_callbacks) + + async def _invoke_and_clear_callbacks(self) -> None: + """If there are scheduled callbacks to run, call them and clear + the callback queue.""" + if self._callbacks: + callbacks = self._callbacks[:] + self._callbacks.clear() + for callback, message_pump in callbacks: + with message_pump._context(): + await invoke(callback) + + def _invoke_later(self, callback: CallbackType, sender: MessagePump) -> None: + """Enqueue a callback to be invoked after the screen is repainted. + + Args: + callback: A callback. + sender: The sender (active message pump) of the callback. + """ + + self._callbacks.append((callback, sender)) + self.check_idle() + + def _push_result_callback( + self, + requester: MessagePump, + callback: ScreenResultCallbackType[ScreenResultType] | None, + future: asyncio.Future[ScreenResultType | None] | None = None, + ) -> None: + """Add a result callback to the screen. + + Args: + requester: The object requesting the callback. + callback: The callback. + future: A Future to hold the result. + """ + self._result_callbacks.append( + ResultCallback[Optional[ScreenResultType]](requester, callback, future) + ) + + async def _message_loop_exit(self) -> None: + await super()._message_loop_exit() + self._compositor.clear() + self._dirty_widgets.clear() + self._dirty_regions.clear() + self._arrangement_cache.clear() + self.screen_layout_refresh_signal.unsubscribe(self) + self._nodes._clear() + self._task = None + + def _pop_result_callback(self) -> None: + """Remove the latest result callback from the stack.""" + self._result_callbacks.pop() + + def _refresh_layout(self, size: Size | None = None, scroll: bool = False) -> None: + """Refresh the layout (can change size and positions of widgets).""" + size = self.outer_size if size is None else size + if self.app.is_inline: + size = size.with_height(self.app._get_inline_height()) + if not size: + return + self._compositor.update_widgets(self._dirty_widgets) + self._update_timer.pause() + ResizeEvent = events.Resize + + try: + if scroll and not self._layout_widgets: + exposed_widgets = self._compositor.reflow_visible(self, size) + if exposed_widgets: + layers = self._compositor.layers + for widget, ( + region, + _order, + _clip, + virtual_size, + container_size, + _, + _, + ) in layers: + if widget in exposed_widgets: + if widget._size_updated( + region.size, virtual_size, container_size, layout=False + ): + widget.post_message( + ResizeEvent( + region.size, virtual_size, container_size + ) + ) + + else: + hidden, shown, resized = self._compositor.reflow(self, size) + self._layout_widgets.clear() + Hide = events.Hide + Show = events.Show + + for widget in hidden: + widget.post_message(Hide()) + + # We want to send a resize event to widgets that were just added or change since last layout + send_resize = shown | resized + + layers = self._compositor.layers + for widget, ( + region, + _order, + _clip, + virtual_size, + container_size, + _, + _, + ) in layers: + widget._size_updated(region.size, virtual_size, container_size) + if widget in send_resize: + widget.post_message( + ResizeEvent(region.size, virtual_size, container_size) + ) + + for widget in shown: + widget.post_message(Show()) + + except Exception as error: + self.app._handle_exception(error) + return + + if self.is_current: + if self.app._batch_count: + self.call_later(self._compositor_refresh) + else: + self._compositor_refresh() + + if self.app._dom_ready: + self.screen_layout_refresh_signal.publish(self.screen) + else: + self.app.post_message(events.Ready()) + self.app._dom_ready = True + + async def _on_update(self, message: messages.Update) -> None: + message.stop() + message.prevent_default() + widget = message.widget + assert isinstance(widget, Widget) + + if self in self._compositor: + self._dirty_widgets.add(widget) + self.check_idle() + + async def _on_layout(self, message: messages.Layout) -> None: + message.stop() + message.prevent_default() + + layout_required = False + widget: DOMNode = message.widget + for ancestor in message.widget.ancestors: + if not isinstance(ancestor, Widget): + break + if ancestor not in self._layout_widgets: + self._layout_widgets[ancestor] = set() + if widget not in self._layout_widgets: + self._layout_widgets[ancestor].add(widget) + layout_required = True + if not ancestor.styles.auto_dimensions: + break + widget = ancestor + + if layout_required and not self._layout_required: + self._layout_required = True + self.check_idle() + + async def _on_update_scroll(self, message: messages.UpdateScroll) -> None: + message.stop() + message.prevent_default() + self._scroll_required = True + self.check_idle() + + def _get_inline_height(self, size: Size) -> int: + """Get the inline height (number of lines to display when running inline mode). + + Args: + size: Size of the terminal + + Returns: + Height for inline mode. + """ + height_scalar = self.styles.height + if height_scalar is None or height_scalar.is_auto: + inline_height = self.get_content_height(size, size, size.width) + else: + inline_height = int(height_scalar.resolve(size, size)) + inline_height += self.styles.gutter.height + min_height = self.styles.min_height + max_height = self.styles.max_height + if min_height is not None: + inline_height = max(inline_height, int(min_height.resolve(size, size))) + if max_height is not None: + inline_height = min(inline_height, int(max_height.resolve(size, size))) + inline_height = min(self.app.size.height, inline_height) + return inline_height + + def _screen_resized(self, size: Size) -> None: + """Called by App when the screen is resized.""" + if self.stack_updates and self.is_attached: + self._refresh_layout(size) + + def _on_screen_resume(self, event: events.ScreenResume) -> None: + """Screen has resumed.""" + if self.app.SUSPENDED_SCREEN_CLASS: + self.remove_class(self.app.SUSPENDED_SCREEN_CLASS) + + self.stack_updates += 1 + + self.app._refresh_notifications() + size = self.app.size + + self._update_auto_focus() + + if self.is_attached: + + if event.refresh_styles: + self.update_node_styles(animate=False) + if self._size != size: + self._refresh_layout(size) + self.refresh() + + async def _compose(self) -> None: + await super()._compose() + self._update_auto_focus() + + def _update_auto_focus(self) -> None: + """Update auto focus.""" + if self.app.app_focus: + auto_focus = ( + self.app.AUTO_FOCUS if self.AUTO_FOCUS is None else self.AUTO_FOCUS + ) + if auto_focus and self.focused is None: + for widget in self.query(auto_focus): + if widget.focusable: + widget.has_focus = True + self.set_focus(widget) + break + + def _on_screen_suspend(self) -> None: + """Screen has suspended.""" + if self.app.SUSPENDED_SCREEN_CLASS: + self.add_class(self.app.SUSPENDED_SCREEN_CLASS) + self.app._set_mouse_over(None, None) + self._clear_tooltip() + self.stack_updates += 1 + + async def _on_resize(self, event: events.Resize) -> None: + event.stop() + self._screen_resized(event.size) + for screen in self.app._background_screens: + screen._screen_resized(event.size) + + horizontal_breakpoints = ( + self.app.HORIZONTAL_BREAKPOINTS + if self.HORIZONTAL_BREAKPOINTS is None + else self.HORIZONTAL_BREAKPOINTS + ) or [] + + vertical_breakpoints = ( + self.app.VERTICAL_BREAKPOINTS + if self.VERTICAL_BREAKPOINTS is None + else self.VERTICAL_BREAKPOINTS + ) or [] + + width, height = event.size + if horizontal_breakpoints: + self._set_breakpoints(width, horizontal_breakpoints) + if vertical_breakpoints: + self._set_breakpoints(height, vertical_breakpoints) + + def _set_breakpoints( + self, dimension: int, breakpoints: list[tuple[int, str]] + ) -> None: + """Set horizontal or vertical breakpoints. + + Args: + dimension: Either the width or the height. + breakpoints: A list of breakpoints. + + """ + class_names = [class_name for _breakpoint, class_name in breakpoints] + self.remove_class(*class_names) + for breakpoint, class_name in sorted(breakpoints, reverse=True): + if dimension >= breakpoint: + self.add_class(class_name) + return + + def _update_tooltip(self, widget: Widget) -> None: + """Update the content of the tooltip.""" + try: + tooltip = self.get_child_by_type(Tooltip) + except NoMatches: + pass + else: + if tooltip.display and self._tooltip_widget is widget: + self._handle_tooltip_timer(widget) + + def _clear_tooltip(self) -> None: + """Unconditionally clear any existing tooltip.""" + try: + tooltip = self.get_child_by_type(Tooltip) + except NoMatches: + return + if tooltip.display: + if self._tooltip_timer is not None: + self._tooltip_timer.stop() + tooltip.display = False + + def _maybe_clear_tooltip(self, _) -> None: + """Check if the widget under the mouse cursor still pertains to the tooltip. + + If they differ, the tooltip will be removed. + """ + # If there's a widget associated with the tooltip at all... + if self._tooltip_widget is not None: + # ...look at what's currently under the mouse. + try: + under_mouse, _ = self.get_widget_at(*self.app.mouse_position) + except NoWidget: + pass + else: + # If it's not the same widget... + if under_mouse is not self._tooltip_widget: + # ...clear the tooltip. + self._clear_tooltip() + + def _handle_tooltip_timer(self, widget: Widget) -> None: + """Called by a timer from _handle_mouse_move to update the tooltip. + + Args: + widget: The widget under the mouse. + """ + + try: + tooltip = self.get_child_by_type(Tooltip) + except NoMatches: + pass + else: + tooltip_content: RenderableType | None = None + for node in widget.ancestors_with_self: + if not isinstance(node, Widget): + break + if node.tooltip is not None: + tooltip_content = node.tooltip + break + + if tooltip_content is None: + tooltip.display = False + else: + tooltip.display = True + tooltip.absolute_offset = self.app.mouse_position + tooltip.update(tooltip_content) + + def _handle_mouse_move(self, event: events.MouseMove) -> None: + hover_widget: Widget | None = None + try: + if self.app.mouse_captured: + widget = self.app.mouse_captured + region = self.find_widget(widget).region + else: + (widget, region), hover = self.get_hover_widgets_at(event.x, event.y) + if hover is not None: + hover_widget = hover[0] + except errors.NoWidget: + self.app._set_mouse_over(None, None) + if self._tooltip_timer is not None: + self._tooltip_timer.stop() + if not self.app._disable_tooltips: + try: + self.get_child_by_type(Tooltip).display = False + except NoMatches: + pass + else: + self.app._set_mouse_over(widget, hover_widget) + self.update_pointer_shape() + widget.hover_style = event.style + if widget is self: + self.post_message(event) + else: + mouse_event = self._translate_mouse_move_event(event, widget, region) + mouse_event._set_forwarded() + widget._forward_event(mouse_event) + + if not self.app._disable_tooltips: + try: + tooltip = self.get_child_by_type(Tooltip) + except NoMatches: + pass + else: + if self._tooltip_widget != widget or not tooltip.display: + self._tooltip_widget = widget + if self._tooltip_timer is not None: + self._tooltip_timer.stop() + + self._tooltip_timer = self.set_timer( + self.app.TOOLTIP_DELAY, + partial(self._handle_tooltip_timer, widget), + name="tooltip-timer", + ) + else: + tooltip.display = False + self.screen.update_pointer_shape() + + @staticmethod + def _translate_mouse_move_event( + event: events.MouseMove, widget: Widget, region: Region + ) -> events.MouseMove: + """ + Returns a mouse move event whose relative coordinates are translated to + the origin of the specified region. + """ + return events.MouseMove( + widget, + event._x - region.x, + event._y - region.y, + event._delta_x, + event._delta_y, + event.button, + event.shift, + event.meta, + event.ctrl, + screen_x=event._screen_x, + screen_y=event._screen_y, + style=event.style, + ) + + def _start_auto_scroll( + self, + widget: Widget, + direction: Literal[+1, -1], + speed: float = 1.0, + ) -> None: + """Start (or update) auto scrolling. + + Args: + widget: Container widget to scroll. + direction: Direction: `+1` for up, `-1` for down. + speed: The scroll speed as a factor of the maximum. + """ + assert speed > 0, "Speed should be positive and non-zero" + + def _auto_scroll_y(widget: Widget, direction: float) -> None: + """Scroll a container a single line in the given direction. + + Args: + widget: Container widgets to scroll. + direction: Lines to scroll. + """ + if self._select_start is not None: + # Update scroll position + widget.scroll_y += direction + widget.scroll_target_y = widget.scroll_y + # Update selection highlights which may have changed due to the scroll + self._update_select(self.app.mouse_position) + + # Replace current timer + self._stop_auto_scroll() + + # Lines to scroll per frame (may be fractional) + lines_to_scroll = ( + direction * (self.app.SELECT_AUTO_SCROLL_SPEED / constants.MAX_FPS) * speed + ) + # Callable to perform scroll + scroll_callback = partial(_auto_scroll_y, widget, lines_to_scroll) + # Perform initial scroll + scroll_callback() + # Start a timer to perform future scrolling + # This is so the user doesn't have to move the mouse to keep scrolling + self._auto_select_scroll_timer = self.set_interval( + 1 / constants.MAX_FPS, scroll_callback + ) + + def _stop_auto_scroll(self) -> None: + """Stop any auto scrolling.""" + if self._auto_select_scroll_timer is not None: + self._auto_select_scroll_timer.stop() + self._auto_select_scroll_timer = None + + def _check_auto_scroll( + self, + select_widget: Widget, + mouse_coordinate: tuple[float, float], + delta_y: float, + ) -> None: + """Check auto-scrolling when selecting. + + This will start, update, or stop a timer used to move the scroll position. + + Args: + select_widget: The widget under the mouise pointer. + mouse_coordinate: The screen-space mouse pointer. + delta_y: Change in mouse y since previous mouse move. + """ + + if not self.app.ENABLE_SELECT_AUTO_SCROLL: + # Disabled by app + return + + if self._auto_select_scroll_timer is None and abs(delta_y) < 1: + # Mouse has moved horizontally, not vertically, so we assume the user doesn't want to scroll + return + + mouse_x, mouse_y = mouse_coordinate + mouse_offset = Offset(int(mouse_x), int(mouse_y)) + + # We want to find any scrollable regions further up the DOM, + # and apply auto scrolling if we are in a region at the top or bottom + for ancestor in select_widget.ancestors_with_self: + if not isinstance(ancestor, Widget): + break + if not ancestor.allow_vertical_scroll: + # Can't scroll, so check the next ancestor + continue + ancestor_region = ancestor.content_region + scroll_lines = self.app.SELECT_AUTO_SCROLL_LINES + up_region, down_region = get_auto_scroll_regions( + ancestor_region, + auto_scroll_lines=scroll_lines, + ) + if mouse_offset in up_region: + # Mouse is in the up region + if ancestor.scroll_y > 0: + # And there is room to scroll + # Speed increases the closer we are to the edge + speed = (scroll_lines - (mouse_y - up_region.y)) / scroll_lines + if speed: + self._start_auto_scroll(ancestor, -1, speed) + return + elif mouse_offset in down_region: + # Mouse is in the down region + if ancestor.scroll_y < ancestor.max_scroll_y: + # And there is room to scroll + speed = (mouse_y - down_region.y) / scroll_lines + if speed: + self._start_auto_scroll(ancestor, +1, speed) + return + # Nothing to auto scroll, so stop the timer + self._stop_auto_scroll() + + def _update_select(self, screen_offset: Offset) -> None: + """Update select for a screen-space offset (typically the mouse position). + + This updates the `_select_end` reactrive, which will trigger the watch method `watch__select_end`. + + Args: + screen_offset: Screen-space position (i.e. mouse position). + """ + select_widget, select_offset = self.get_widget_and_offset_at( + screen_offset.x, screen_offset.y + ) + if ( + self._select_end is not None + and select_offset is None + and screen_offset.y > self._select_end[1].y + ): + end_widget = self._select_end[0] + select_offset = end_widget.content_region.bottom_right_inclusive + self._select_end = ( + end_widget, + screen_offset, + select_offset, + ) + + elif ( + select_widget is not None + and select_widget.allow_select + and select_offset is not None + ): + self._select_end = ( + select_widget, + screen_offset, + select_offset, + ) + + def _forward_event(self, event: events.Event) -> None: + if event.is_forwarded: + return + event._set_forwarded() + + if isinstance(event, (events.Enter, events.Leave)): + self.post_message(event) + + elif isinstance(event, events.MouseMove): + event.style = self.get_style_at(event.screen_x, event.screen_y) + self._handle_mouse_move(event) + + if self._selecting and self._select_start is not None: + + self._box_select = event.shift + select_widget, select_offset = self.get_widget_and_offset_at( + event.x, event.y + ) + if ( + self._select_end is not None + and select_offset is None + and event.y > self._select_end[1].y + ): + end_widget = self._select_end[0] + select_offset = end_widget.content_region.bottom_right_inclusive + self._select_end = ( + end_widget, + event.screen_offset, + select_offset, + ) + + elif ( + select_widget is not None + and select_widget.allow_select + and select_offset is not None + ): + self._select_end = ( + select_widget, + event.screen_offset, + select_offset, + ) + + if select_widget is not None: + self._check_auto_scroll( + select_widget, + (event.pointer_screen_x, event.pointer_screen_y), + event.delta_y, + ) + else: + self._stop_auto_scroll() + + elif isinstance(event, events.MouseEvent): + if isinstance(event, events.MouseUp): + if ( + self._mouse_down_offset is not None + and self._mouse_down_offset == event.screen_offset + ): + # A click elsewhere should clear the selection + select_widget, select_offset = self.get_widget_and_offset_at( + event.x, event.y + ) + # Exclude scrollbars, so the user may navigate without clearing the selection + if select_widget is None or not select_widget.has_class( + "-textual-system" + ): + self.clear_selection() + + self._mouse_down_offset = None + self._selecting = False + self.post_message(events.TextSelected()) + + elif isinstance(event, events.MouseDown) and not self.app.mouse_captured: + self._box_select = event.shift + self._mouse_down_offset = event.screen_offset + select_widget, select_offset = self.get_widget_and_offset_at( + event.screen_x, event.screen_y + ) + if ( + select_widget is not None + and select_widget.allow_select + and self.screen.allow_select + and self.app.ALLOW_SELECT + ): + self._selecting = True + if select_widget is not None and select_offset is not None: + self.text_selection_started_signal.publish(self) + self._select_start = ( + select_widget, + event.screen_offset, + select_offset, + ) + else: + self._selecting = False + + try: + if self.app.mouse_captured: + widget = self.app.mouse_captured + region = self.find_widget(widget).region + else: + widget, region = self.get_widget_at(event.x, event.y) + except errors.NoWidget: + self.set_focus(None) + else: + if isinstance(event, events.MouseDown): + focusable_widget = self.get_focusable_widget_at(event.x, event.y) + if ( + focusable_widget is not None + and focusable_widget.focus_on_click() + ): + self.set_focus(focusable_widget, scroll_visible=False) + event.style = self.get_style_at(event.screen_x, event.screen_y) + if widget.loading: + return + if widget is self: + event._set_forwarded() + self.post_message(event) + else: + widget._forward_event(event._apply_offset(-region.x, -region.y)) + + else: + self.post_message(event) + self.update_pointer_shape() + + def _key_escape(self) -> None: + self.clear_selection() + + def _watch__selecting(self, selecting: bool) -> None: + if not selecting: + self._stop_auto_scroll() + + @classmethod + def _collect_select_widgets( + cls, + selection_bounds: Shape, + container: Widget, + start_widget: Widget, + end_widget: Widget, + ) -> list[Widget]: + """Get widgets between two widgets in select order. + + Args: + container: A parent widgets. + start_widget: First widget. + end_widget: Second widget. + + Returns: + Widgets between start and end, in select sort order. + """ + + widgets = list( + walk_selectable_widgets( + container, + selection_bounds, + {start_widget, end_widget}, + ) + ) + + index1: int | None = None + try: + index1 = widgets.index(start_widget) + 1 + except ValueError: + pass + + index2: int | None = None + try: + index2 = widgets.index(end_widget) + except ValueError: + pass + + results = widgets[index1:index2] + return results + + def _watch__select_end( + self, select_end: tuple[Widget, Offset, Offset] | None + ) -> None: + """When select_end changes, we need to compute which widgets and regions are selected. + + Args: + select_end: The end selection. + """ + + if select_end is None or self._select_start is None: + # Nothing to select + return + + start_widget, screen_start, start_offset = self._select_start + end_widget, screen_end, end_offset = select_end + + if not start_widget.is_attached or not end_widget.is_attached: + # Widgets may have been removed since selection started + return + + if start_widget is end_widget: + # Simplest case, selection starts and ends on the same widget + if end_offset.transpose < start_offset.transpose: + start_offset, end_offset = end_offset, start_offset + self.selections = { + start_widget: Selection.from_offsets( + start_offset, + end_offset + (1, 0), + ) + } + return + + # The start selection may have been scrolled since it was saved + # We need to adjust to the new screen-space position + select_start = (start_widget, start_widget.region.offset, start_offset) + # Ensure select_start is < select_end in selection order + if select_start[0]._selection_order > select_end[0]._selection_order: + select_start, select_end = select_end, select_start + + start_widget, screen_start, start_offset = select_start + end_widget, screen_end, end_offset = select_end + + if (screen_start + start_offset).transpose > ( + screen_end + end_offset + ).transpose: + start_widget, end_widget = end_widget, start_widget + + # Get a widget which contains both widgets + container_widget = Widget.get_common_ancestor( + start_widget, end_widget, default=self + ) + + # Get a selection bounds shape + selection_bounds = Shape.selection_bounds( + container_widget.region, + select_start[1] + select_start[2], + self.app.mouse_position, + ) + + # Get widgets bounded by the selection bounds + select_widgets = self._collect_select_widgets( + selection_bounds, + container_widget, + start_widget, + end_widget, + ) + + # Build the selection + select_all = SELECT_ALL + self.selections = { + start_widget: Selection(start_offset, None), + **{widget: select_all for widget in select_widgets}, + end_widget: Selection(None, end_offset + (1, 0)), + } + + def dismiss(self, result: ScreenResultType | None = None) -> AwaitComplete: + """Dismiss the screen, optionally with a result. + + Any callback provided in [push_screen][textual.app.App.push_screen] will be invoked with the supplied result. + + !!! warning + + Textual will raise a [`ScreenError`][textual.app.ScreenError] if you await the return value from a + message handler on the Screen being dismissed. If you want to dismiss the current screen, you can + call `self.dismiss()` _without_ awaiting. + + Args: + result: The optional result to be passed to the result callback. + + """ + _rich_traceback_omit = True + if self._result_callbacks: + callback = self._result_callbacks[-1] + callback(result) + await_pop = self.app.pop_screen() + + def pre_await() -> None: + """Called by the AwaitComplete object.""" + _rich_traceback_omit = True + if active_message_pump.get() is self: + from memray._vendor.textual.app import ScreenError + + raise ScreenError( + "Can't await screen.dismiss() from the screen's message handler; try removing the await keyword." + ) + + await_pop.set_pre_await_callback(pre_await) + + return await_pop + + def pop_until_active(self) -> None: + """Pop any screens on top of this one, until this screen is active. + + Raises: + ScreenError: If this screen is not in the current mode. + + """ + from memray._vendor.textual.app import ScreenError + + try: + self.app._pop_to_screen(self) + except ScreenError: + # More specific error message + raise ScreenError( + f"Can't make {self} active as it is not in the current stack." + ) from None + + async def action_dismiss(self, result: ScreenResultType | None = None) -> None: + """A wrapper around [`dismiss`][textual.screen.Screen.dismiss] that can be called as an action. + + Args: + result: The optional result to be passed to the result callback. + """ + await self._flush_next_callbacks() + self.dismiss(result) + + def can_view_entire(self, widget: Widget) -> bool: + """Check if a given widget is fully within the current screen. + + Note: This doesn't necessarily equate to a widget being visible. + There are other reasons why a widget may not be visible. + + Args: + widget: A widget. + + Returns: + `True` if the entire widget is in view, `False` if it is partially visible or not in view. + """ + if widget not in self._compositor.visible_widgets: + return False + # If the widget is one that overlays the screen... + if widget.styles.overlay == "screen": + # ...simply check if it's within the screen's region. + return widget.region in self.region + # Failing that fall back to normal checking. + return super().can_view_entire(widget) + + def can_view_partial(self, widget: Widget) -> bool: + """Check if a given widget is at least partially within the current view. + + Args: + widget: A widget. + + Returns: + `True` if the any part of the widget is in view, `False` if it is completely outside of the screen. + """ + if widget not in self._compositor.visible_widgets: + return False + # If the widget is one that overlays the screen... + if widget.styles.overlay == "screen": + # ...simply check if it's within the screen's region. + return widget.region in self.region + # Failing that fall back to normal checking. + return super().can_view_partial(widget) + + def validate_title(self, title: Any) -> str | None: + """Ensure the title is a string or `None`.""" + return None if title is None else str(title) + + def validate_sub_title(self, sub_title: Any) -> str | None: + """Ensure the sub-title is a string or `None`.""" + return None if sub_title is None else str(sub_title) + + +@rich.repr.auto +class ModalScreen(Screen[ScreenResultType]): + """A screen with bindings that take precedence over the App's key bindings. + + The default styling of a modal screen will dim the screen underneath. + """ + + DEFAULT_CSS = """ + ModalScreen { + layout: vertical; + overflow-y: auto; + background: $background 60%; + &:ansi { + background: transparent; + } + } + """ + + def __init__( + self, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + ) -> None: + super().__init__(name=name, id=id, classes=classes) + self._modal = True + + +class SystemModalScreen(ModalScreen[ScreenResultType], inherit_css=False): + """A variant of `ModalScreen` for internal use. + + This version of `ModalScreen` allows us to build system-level screens; + the type being used to indicate that the screen should be isolated from + the main application. + + Note: + This screen is set to *not* inherit CSS. + """ diff --git a/src/memray/_vendor/textual/scroll_view.py b/src/memray/_vendor/textual/scroll_view.py new file mode 100644 index 0000000000..9b305a8892 --- /dev/null +++ b/src/memray/_vendor/textual/scroll_view.py @@ -0,0 +1,196 @@ +""" +`ScrollView` is a base class for [Line API](/guide/widgets#line-api) widgets. +""" + +from __future__ import annotations + +from rich.console import RenderableType + +from memray._vendor.textual._animator import EasingFunction +from memray._vendor.textual._types import AnimationLevel, CallbackType +from memray._vendor.textual.containers import ScrollableContainer +from memray._vendor.textual.geometry import Region, Size + + +class ScrollView(ScrollableContainer): + """ + A base class for a Widget that handles its own scrolling (i.e. doesn't rely + on the compositor to render children). + + !!! note + + This is the typically wrong class for making something scrollable. If you want to make something scroll, set its + `overflow` style to auto or scroll. Or use one of the pre-defined scrolling containers such as [VerticalScroll][textual.containers.VerticalScroll]. + """ + + ALLOW_MAXIMIZE = True + + DEFAULT_CSS = """ + ScrollView { + overflow-y: auto; + overflow-x: auto; + } + """ + + @property + def is_scrollable(self) -> bool: + """Always scrollable.""" + return True + + @property + def is_container(self) -> bool: + """Since a ScrollView should be a line-api widget, it won't have children, + and therefore isn't a container.""" + return False + + def watch_scroll_x(self, old_value: float, new_value: float) -> None: + if self.show_horizontal_scrollbar: + self.horizontal_scrollbar.position = new_value + if round(old_value) != round(new_value): + self.refresh(self.size.region) + + def watch_scroll_y(self, old_value: float, new_value: float) -> None: + if self.show_vertical_scrollbar: + self.vertical_scrollbar.position = new_value + if round(old_value) != round(new_value): + self.refresh(self.size.region) + + def on_mount(self): + self._refresh_scrollbars() + + def get_content_width(self, container: Size, viewport: Size) -> int: + """Gets the width of the content area. + + Args: + container: Size of the container (immediate parent) widget. + viewport: Size of the viewport. + + Returns: + The optimal width of the content. + """ + return self.virtual_size.width + + def get_content_height(self, container: Size, viewport: Size, width: int) -> int: + """Gets the height (number of lines) in the content area. + + Args: + container: Size of the container (immediate parent) widget. + viewport: Size of the viewport. + width: Width of renderable. + + Returns: + The height of the content. + """ + return self.virtual_size.height + + def _size_updated( + self, size: Size, virtual_size: Size, container_size: Size, layout: bool = True + ) -> bool: + """Called when size is updated. + + Args: + size: New size. + virtual_size: New virtual size. + container_size: New container size. + layout: Perform layout if required. + + Returns: + True if a resize event should be sent, otherwise False. + """ + if size_changed := self._size != size: + self._set_dirty() + if ( + size_changed + or virtual_size != self.virtual_size + or container_size != self.container_size + ): + self._scrollbar_changes.clear() + self._size = size + virtual_size = self.virtual_size + self._container_size = size - self.styles.gutter.totals + self._scroll_update(virtual_size) + + return size_changed or self._container_size != container_size + + def render(self) -> RenderableType: + """Render the scrollable region (if `render_lines` is not implemented). + + Returns: + Renderable object. + """ + from rich.panel import Panel + + return Panel(f"{self.scroll_offset} {self.show_vertical_scrollbar}") + + # Custom scroll to which doesn't require call_after_refresh + def scroll_to( + self, + x: float | None = None, + y: float | None = None, + *, + animate: bool = True, + speed: float | None = None, + duration: float | None = None, + easing: EasingFunction | str | None = None, + force: bool = False, + on_complete: CallbackType | None = None, + level: AnimationLevel = "basic", + immediate: bool = False, + ) -> None: + """Scroll to a given (absolute) coordinate, optionally animating. + + Args: + x: X coordinate (column) to scroll to, or `None` for no change. + y: Y coordinate (row) to scroll to, or `None` for no change. + animate: Animate to new scroll position. + speed: Speed of scroll if `animate` is `True`; or `None` to use `duration`. + duration: Duration of animation, if `animate` is `True` and `speed` is `None`. + easing: An easing method for the scrolling animation. + force: Force scrolling even when prohibited by overflow styling. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + immediate: If `False` the scroll will be deferred until after a screen refresh, + set to `True` to scroll immediately. + """ + + self._scroll_to( + x, + y, + animate=animate, + speed=speed, + duration=duration, + easing=easing, + force=force, + on_complete=on_complete, + level=level, + ) + + def refresh_line(self, y: int) -> None: + """Refresh a single line. + + Args: + y: Coordinate of line. + """ + self.refresh( + Region( + 0, + y - self.scroll_offset.y, + max(self.virtual_size.width, self.size.width), + 1, + ) + ) + + def refresh_lines(self, y_start: int, line_count: int = 1) -> None: + """Refresh one or more lines. + + Args: + y_start: First line to refresh. + line_count: Total number of lines to refresh. + """ + refresh_region = Region( + 0, + y_start - self.scroll_offset.y, + max(self.virtual_size.width, self.size.width), + line_count, + ) + self.refresh(refresh_region) diff --git a/src/memray/_vendor/textual/scrollbar.py b/src/memray/_vendor/textual/scrollbar.py new file mode 100644 index 0000000000..a91922db6c --- /dev/null +++ b/src/memray/_vendor/textual/scrollbar.py @@ -0,0 +1,411 @@ +""" +Contains the widgets that manage Textual scrollbars. + +!!! note + + You will not typically need this for most apps. + +""" + +from __future__ import annotations + +from math import ceil +from typing import ClassVar, Type + +import rich.repr +from rich.color import Color +from rich.console import Console, ConsoleOptions, RenderableType, RenderResult +from rich.segment import Segment, Segments +from rich.style import Style, StyleType + +from memray._vendor.textual import events +from memray._vendor.textual.geometry import Offset +from memray._vendor.textual.message import Message +from memray._vendor.textual.reactive import Reactive +from memray._vendor.textual.renderables.blank import Blank +from memray._vendor.textual.widget import Widget + + +class ScrollMessage(Message, bubble=False): + """Base class for all scrollbar messages.""" + + +@rich.repr.auto +class ScrollUp(ScrollMessage, verbose=True): + """Message sent when clicking above handle.""" + + +@rich.repr.auto +class ScrollDown(ScrollMessage, verbose=True): + """Message sent when clicking below handle.""" + + +@rich.repr.auto +class ScrollLeft(ScrollMessage, verbose=True): + """Message sent when clicking above handle.""" + + +@rich.repr.auto +class ScrollRight(ScrollMessage, verbose=True): + """Message sent when clicking below handle.""" + + +class ScrollTo(ScrollMessage, verbose=True): + """Message sent when click and dragging handle.""" + + __slots__ = ["x", "y", "animate"] + + def __init__( + self, + x: float | None = None, + y: float | None = None, + animate: bool = True, + ) -> None: + self.x = x + self.y = y + self.animate = animate + super().__init__() + + def __rich_repr__(self) -> rich.repr.Result: + yield "x", self.x, None + yield "y", self.y, None + yield "animate", self.animate, True + + +class ScrollBarRender: + VERTICAL_BARS: ClassVar[list[str]] = ["▁", "▂", "▃", "▄", "▅", "▆", "▇", " "] + """Glyphs used for vertical scrollbar ends, for smoother display.""" + HORIZONTAL_BARS: ClassVar[list[str]] = ["▉", "▊", "▋", "▌", "▍", "▎", "▏", " "] + """Glyphs used for horizontal scrollbar ends, for smoother display.""" + BLANK_GLYPH: ClassVar[str] = " " + """Glyph used for the main body of the scrollbar""" + + def __init__( + self, + virtual_size: int = 100, + window_size: int = 0, + position: float = 0, + thickness: int = 1, + vertical: bool = True, + style: StyleType = "bright_magenta on #555555", + ) -> None: + self.virtual_size = virtual_size + self.window_size = window_size + self.position = position + self.thickness = thickness + self.vertical = vertical + self.style = style + + @classmethod + def render_bar( + cls, + size: int = 25, + virtual_size: float = 50, + window_size: float = 20, + position: float = 0, + thickness: int = 1, + vertical: bool = True, + back_color: Color = Color.parse("#555555"), + bar_color: Color = Color.parse("bright_magenta"), + ) -> Segments: + if vertical: + bars = cls.VERTICAL_BARS + else: + bars = cls.HORIZONTAL_BARS + + back = back_color + bar = bar_color + + len_bars = len(bars) + + width_thickness = thickness if vertical else 1 + + _Segment = Segment + _Style = Style + blank = cls.BLANK_GLYPH * width_thickness + + foreground_meta = {"@mouse.down": "grab"} + if window_size and size and virtual_size and size != virtual_size: + bar_ratio = virtual_size / size + thumb_size = max(1, window_size / bar_ratio) + + position_ratio = position / (virtual_size - window_size) + position = (size - thumb_size) * position_ratio + + start = int(position * len_bars) + end = start + ceil(thumb_size * len_bars) + + start_index, start_bar = divmod(max(0, start), len_bars) + end_index, end_bar = divmod(max(0, end), len_bars) + + upper = {"@mouse.down": "scroll_up"} + lower = {"@mouse.down": "scroll_down"} + + upper_back_segment = Segment(blank, _Style(bgcolor=back, meta=upper)) + lower_back_segment = Segment(blank, _Style(bgcolor=back, meta=lower)) + + segments = [upper_back_segment] * int(size) + segments[end_index:] = [lower_back_segment] * (size - end_index) + + segments[start_index:end_index] = [ + _Segment(blank, _Style(color=bar, reverse=True, meta=foreground_meta)) + ] * (end_index - start_index) + + # Apply the smaller bar characters to head and tail of scrollbar for more "granularity" + if start_index < len(segments): + bar_character = bars[len_bars - 1 - start_bar] + if bar_character != " ": + segments[start_index] = _Segment( + bar_character * width_thickness, + ( + _Style(bgcolor=back, color=bar, meta=foreground_meta) + if vertical + else _Style( + bgcolor=back, + color=bar, + meta=foreground_meta, + reverse=True, + ) + ), + ) + if end_index < len(segments): + bar_character = bars[len_bars - 1 - end_bar] + if bar_character != " ": + segments[end_index] = _Segment( + bar_character * width_thickness, + ( + _Style( + bgcolor=back, + color=bar, + meta=foreground_meta, + reverse=True, + ) + if vertical + else _Style(bgcolor=back, color=bar, meta=foreground_meta) + ), + ) + else: + style = _Style(bgcolor=back) + segments = [_Segment(blank, style=style)] * int(size) + if vertical: + return Segments(segments, new_lines=True) + else: + return Segments((segments + [_Segment.line()]) * thickness, new_lines=False) + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + size = ( + (options.height or console.height) + if self.vertical + else (options.max_width or console.width) + ) + thickness = ( + (options.max_width or console.width) + if self.vertical + else (options.height or console.height) + ) + + _style = console.get_style(self.style) + + bar = self.render_bar( + size=size, + window_size=self.window_size, + virtual_size=self.virtual_size, + position=self.position, + vertical=self.vertical, + thickness=thickness, + back_color=_style.bgcolor or Color.parse("#555555"), + bar_color=_style.color or Color.parse("bright_magenta"), + ) + yield bar + + +@rich.repr.auto +class ScrollBar(Widget): + renderer: ClassVar[Type[ScrollBarRender]] = ScrollBarRender + """The class used for rendering scrollbars. + This can be overridden and set to a ScrollBarRender-derived class + in order to delegate all scrollbar rendering to that class. E.g.: + + ``` + class MyScrollBarRender(ScrollBarRender): ... + + app = MyApp() + ScrollBar.renderer = MyScrollBarRender + app.run() + ``` + + Because this variable is accessed through specific instances + (rather than through the class ScrollBar itself) it is also possible + to set this on specific scrollbar instance to change only that + instance: + + ``` + my_widget.horizontal_scrollbar.renderer = MyScrollBarRender + ``` + """ + + DEFAULT_CLASSES = "-textual-system" + + # Nothing to select in scrollbars + ALLOW_SELECT = False + + def __init__( + self, vertical: bool = True, name: str | None = None, *, thickness: int = 1 + ) -> None: + self.vertical = vertical + self.thickness = thickness + self.grabbed_position: float = 0 + super().__init__(name=name) + self.set_reactive(ScrollBar.auto_links, False) + + window_virtual_size: Reactive[int] = Reactive(100) + window_size: Reactive[int] = Reactive(0) + position: Reactive[float] = Reactive(0) + mouse_over: Reactive[bool] = Reactive(False) + grabbed: Reactive[Offset | None] = Reactive(None) + + def __rich_repr__(self) -> rich.repr.Result: + yield from super().__rich_repr__() + yield "window_virtual_size", self.window_virtual_size + yield "window_size", self.window_size + yield "position", self.position + if self.thickness > 1: + yield "thickness", self.thickness + + def validate_position(self, position: float) -> float: + """Position has a granulatory of 1/8 of a cell.""" + return int(position * 8) / 8 + + def render(self) -> RenderableType: + assert self.parent is not None + styles = self.parent.styles + if self.grabbed: + background = styles.scrollbar_background_active + color = styles.scrollbar_color_active + elif self.mouse_over: + background = styles.scrollbar_background_hover + color = styles.scrollbar_color_hover + else: + background = styles.scrollbar_background + color = styles.scrollbar_color + if background.a < 1: + base_background, _ = self.parent.background_colors + background = base_background + background + color = background + color + scrollbar_style = Style.from_color(color.rich_color, background.rich_color) + if self.screen.styles.scrollbar_color.a == 0: + return self.renderer(vertical=self.vertical, style=scrollbar_style) + return self._render_bar(scrollbar_style) + + def _render_bar(self, scrollbar_style: Style) -> RenderableType: + """Get a renderable for the scrollbar with given style. + + Args: + scrollbar_style: Scrollbar style. + + Returns: + Scrollbar renderable. + """ + window_size = ( + self.window_size if self.window_size < self.window_virtual_size else 0 + ) + virtual_size = self.window_virtual_size + + return self.renderer( + virtual_size=ceil(virtual_size), + window_size=ceil(window_size), + position=self.position, + thickness=self.thickness, + vertical=self.vertical, + style=scrollbar_style, + ) + + def _on_hide(self, event: events.Hide) -> None: + if self.grabbed: + self.release_mouse() + self.grabbed = None + + def _on_enter(self, event: events.Enter) -> None: + if event.node is self: + self.mouse_over = True + + def _on_leave(self, event: events.Leave) -> None: + if event.node is self: + self.mouse_over = False + + def action_scroll_down(self) -> None: + """Scroll vertical scrollbars down, horizontal scrollbars right.""" + if not self.grabbed: + self.post_message(ScrollDown() if self.vertical else ScrollRight()) + + def action_scroll_up(self) -> None: + """Scroll vertical scrollbars up, horizontal scrollbars left.""" + if not self.grabbed: + self.post_message(ScrollUp() if self.vertical else ScrollLeft()) + + def action_grab(self) -> None: + """Begin capturing the mouse cursor.""" + self.capture_mouse() + + async def _on_mouse_down(self, event: events.MouseDown) -> None: + # We don't want mouse events on the scrollbar bubbling + event.stop() + + async def _on_mouse_up(self, event: events.MouseUp) -> None: + if self.grabbed: + self.release_mouse() + self.grabbed = None + event.stop() + + def _on_mouse_capture(self, event: events.MouseCapture) -> None: + self.app._realtime_animation_begin() + self.styles.pointer = "grabbing" + if isinstance(self._parent, Widget): + self._parent.release_anchor() + self.grabbed = event.mouse_position + self.grabbed_position = self.position + + def _on_mouse_release(self, event: events.MouseRelease) -> None: + self.app._realtime_animation_complete() + self.styles.pointer = "default" + self.grabbed = None + if self.vertical and isinstance(self.parent, Widget): + self.parent._check_anchor() + event.stop() + + async def _on_mouse_move(self, event: events.MouseMove) -> None: + if self.grabbed and self.window_size: + x: float | None = None + y: float | None = None + if self.vertical: + virtual_size = self.window_virtual_size + y = self.grabbed_position + ( + (event._screen_y - self.grabbed.y) + * (virtual_size / self.window_size) + ) + else: + virtual_size = self.window_virtual_size + x = self.grabbed_position + ( + (event._screen_x - self.grabbed.x) + * (virtual_size / self.window_size) + ) + self.post_message( + ScrollTo(x=x, y=y, animate=not self.app.supports_smooth_scrolling) + ) + event.stop() + + async def _on_click(self, event: events.Click) -> None: + event.stop() + + +class ScrollBarCorner(Widget): + """Widget which fills the gap between horizontal and vertical scrollbars, + should they both be present.""" + + def render(self) -> RenderableType: + assert self.parent is not None + styles = self.parent.styles + color = styles.scrollbar_corner_color + return Blank(color) diff --git a/src/memray/_vendor/textual/selection.py b/src/memray/_vendor/textual/selection.py new file mode 100644 index 0000000000..e73b381fc8 --- /dev/null +++ b/src/memray/_vendor/textual/selection.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +from typing import NamedTuple + +from memray._vendor.textual.geometry import Offset + + +class Selection(NamedTuple): + """A selected range of lines.""" + + start: Offset | None + """Offset or None for `start`.""" + end: Offset | None + """Offset or None for `end`.""" + + @classmethod + def from_offsets(cls, offset1: Offset, offset2: Offset) -> Selection: + """Create selection from 2 offsets. + + Args: + offset1: First offset. + offset2: Second offset. + + Returns: + New Selection. + """ + offsets = sorted([offset1, offset2], key=(lambda offset: (offset.y, offset.x))) + return cls(*offsets) + + def extract(self, text: str) -> str: + """Extract selection from text. + + Args: + text: Raw text pulled from widget. + + Returns: + Extracted text. + """ + lines = text.splitlines() + if not lines: + return "" + if self.start is None: + start_line = 0 + start_offset = 0 + else: + start_line, start_offset = self.start.transpose + + if self.end is None: + end_line = len(lines) + end_offset = len(lines[-1]) + else: + end_line, end_offset = self.end.transpose + end_line = min(len(lines), end_line) + + if start_line == end_line: + return lines[start_line][start_offset:end_offset] + + selection: list[str] = [] + selected_lines = lines[start_line : end_line + 1] + if len(selected_lines) >= 2: + first_line, *mid_lines, last_line = selected_lines + selection.append(first_line[start_offset:]) + selection.extend(mid_lines) + selection.append(last_line[:end_offset]) + else: + return lines[start_line][start_offset:end_offset] + return "\n".join(selection) + + def get_span(self, y: int) -> tuple[int, int] | None: + """Get the selected span in a given line. + + Args: + y: Offset of the line. + + Returns: + A tuple of x start and end offset, or None for no selection. + """ + start, end = self + if start is None and end is None: + # Selection covers everything + return 0, -1 + + if start is not None and end is not None: + if y < start.y or y > end.y: + # Outside + return None + if y == start.y and start.y == end.y: + # Same line + return start.x, end.x + if y == end.y: + # Last line + return 0, end.x + if y == start.y: + return start.x, -1 + # Remaining lines + return 0, -1 + + if start is None and end is not None: + if y == end.y: + return 0, end.x + if y > end.y: + return None + return 0, -1 + + if end is None and start is not None: + if y == start.y: + return start.x, -1 + if y > start.y: + return 0, -1 + return None + return 0, -1 + + +SELECT_ALL = Selection(None, None) diff --git a/src/memray/_vendor/textual/signal.py b/src/memray/_vendor/textual/signal.py new file mode 100644 index 0000000000..fae0cdc1ad --- /dev/null +++ b/src/memray/_vendor/textual/signal.py @@ -0,0 +1,137 @@ +""" +Signals are a simple pub-sub mechanism. + +DOMNodes can subscribe to a signal, which will invoke a callback when the signal is published. + +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Generic, TypeVar, Union +from weakref import WeakKeyDictionary, ref + +import rich.repr + +from memray._vendor.textual import log + +if TYPE_CHECKING: + from memray._vendor.textual.dom import DOMNode + +SignalT = TypeVar("SignalT") + +SignalCallbackType = Union[ + Callable[[SignalT], Awaitable[Any]], Callable[[SignalT], Any] +] + + +class SignalError(Exception): + """Raised for Signal errors.""" + + +@rich.repr.auto(angular=True) +class Signal(Generic[SignalT]): + """A signal that a widget may subscribe to, in order to invoke callbacks when an associated event occurs.""" + + def __init__(self, owner: DOMNode, name: str) -> None: + """Initialize a signal. + + Args: + owner: The owner of this signal. + name: An identifier for debugging purposes. + """ + self._owner = ref(owner) + self._name = name + self._subscriptions: WeakKeyDictionary[ + DOMNode, list[SignalCallbackType[SignalT]] + ] = WeakKeyDictionary() + + def __rich_repr__(self) -> rich.repr.Result: + yield "owner", self.owner + yield "name", self._name + yield "subscriptions", list(self._subscriptions.keys()) + + @property + def owner(self) -> DOMNode | None: + """The owner of this Signal, or `None` if there is no owner.""" + return self._owner() + + def subscribe( + self, + node: DOMNode, + callback: SignalCallbackType[SignalT], + immediate: bool = False, + ) -> None: + """Subscribe a node to this signal. + + When the signal is published, the callback will be invoked. + + Args: + node: Node to subscribe. + callback: A callback function which takes a single argument and returns anything (return type ignored). + immediate: Invoke the callback immediately on publish if `True`, otherwise post it to the DOM node to be + called once existing messages have been processed. + + Raises: + SignalError: Raised when subscribing a non-mounted widget. + """ + if not node.is_running: + raise SignalError( + f"Node must be running to subscribe to a signal (has {node} been mounted)?" + ) + + if immediate: + + def signal_callback(data: SignalT) -> None: + """Invoke the callback immediately.""" + callback(data) + + else: + + def signal_callback(data: SignalT) -> None: + """Post the callback to the node, to call at the next opertunity.""" + node.call_next(callback, data) + + callbacks = self._subscriptions.setdefault(node, []) + callbacks.append(signal_callback) + + def unsubscribe(self, node: DOMNode) -> None: + """Unsubscribe a node from this signal. + + Args: + node: Node to unsubscribe, + """ + self._subscriptions.pop(node, None) + + def publish(self, data: SignalT) -> None: + """Publish the signal (invoke subscribed callbacks). + + Args: + data: An argument to pass to the callbacks. + + """ + if not self._subscriptions: + return + # Don't publish if the DOM is not ready or shutting down + owner = self.owner + if owner is None: + return + + if not owner.is_attached or owner._pruning: + return + for ancestor_node in owner.ancestors_with_self: + if not ancestor_node.is_running: + return + + for node, callbacks in list(self._subscriptions.items()): + if not (node.is_running and node.is_attached) or node._pruning: + # Removed nodes that are no longer running + self._subscriptions.pop(node) + else: + # Call callbacks + for callback in callbacks: + try: + callback(data) + except Exception as error: + log.error( + f"error publishing signal to {node} ignored (callback={callback}); {error}" + ) diff --git a/src/memray/_vendor/textual/strip.py b/src/memray/_vendor/textual/strip.py new file mode 100644 index 0000000000..5b64337127 --- /dev/null +++ b/src/memray/_vendor/textual/strip.py @@ -0,0 +1,817 @@ +""" +This module contains the `Strip` class and related objects. + +A `Strip` contains the result of rendering a widget. +See [Line API](/guide/widgets#line-api) for how to use Strips. +""" + +from __future__ import annotations + +from functools import lru_cache +from typing import Any, Iterable, Iterator, Sequence + +import rich.repr +from rich.cells import cell_len, set_cell_size +from rich.color import ColorSystem +from rich.console import Console, ConsoleOptions, RenderResult +from rich.measure import Measurement +from rich.segment import Segment +from rich.style import Style, StyleType + +from memray._vendor.textual._segment_tools import index_to_cell_position, line_pad +from memray._vendor.textual.cache import FIFOCache +from memray._vendor.textual.color import Color +from memray._vendor.textual.css.types import AlignHorizontal, AlignVertical +from memray._vendor.textual.filter import LineFilter + +SGR_STYLES = ["1", "2", "3", "4", "5", "6", "7", "8", "9", "21", "51", "52", "53"] + + +def get_line_length(segments: Iterable[Segment]) -> int: + """Get the line length (total length of all segments). + + Args: + segments: Iterable of segments. + + Returns: + Length of line in cells. + """ + _cell_len = cell_len + return sum([_cell_len(text) for text, _, control in segments if not control]) + + +class StripRenderable: + """A renderable which renders a list of strips into lines.""" + + def __init__(self, strips: list[Strip], width: int | None = None) -> None: + self._strips = strips + self._width = width + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + new_line = Segment.line() + for strip in self._strips: + yield from strip + yield new_line + + def __rich_measure__( + self, console: "Console", options: "ConsoleOptions" + ) -> Measurement: + if self._width is None: + width = max(strip.cell_length for strip in self._strips) + else: + width = self._width + return Measurement(width, width) + + +@rich.repr.auto +class Strip: + """Represents a 'strip' (horizontal line) of a Textual Widget. + + A Strip is like an immutable list of Segments. The immutability allows for effective caching. + + Args: + segments: An iterable of segments. + cell_length: The cell length if known, or None to calculate on demand. + """ + + __slots__ = [ + "_segments", + "_cell_length", + "_divide_cache", + "_crop_cache", + "_style_cache", + "_filter_cache", + "_render_cache", + "_line_length_cache", + "_crop_extend_cache", + "_offsets_cache", + "_link_ids", + "_cell_count", + ] + + def __init__( + self, segments: Iterable[Segment], cell_length: int | None = None + ) -> None: + self._segments = list(segments) + self._cell_length = cell_length + self._divide_cache: FIFOCache[tuple[int, ...], list[Strip]] = FIFOCache(4) + self._crop_cache: FIFOCache[tuple[int, int], Strip] = FIFOCache(16) + self._style_cache: FIFOCache[Style, Strip] = FIFOCache(16) + self._filter_cache: FIFOCache[tuple[LineFilter, Color], Strip] = FIFOCache(4) + self._line_length_cache: FIFOCache[ + tuple[int, Style | None], + Strip, + ] = FIFOCache(4) + self._crop_extend_cache: FIFOCache[ + tuple[int, int, Style | None], + Strip, + ] = FIFOCache(4) + self._offsets_cache: FIFOCache[tuple[int, int], Strip] = FIFOCache(4) + self._render_cache: str | None = None + self._link_ids: set[str] | None = None + self._cell_count: int | None = None + + def __rich_repr__(self) -> rich.repr.Result: + try: + yield self._segments + yield self.cell_length + except AttributeError: + pass + + @property + def text(self) -> str: + """Segment text.""" + return "".join(segment.text for segment in self._segments) + + @property + def link_ids(self) -> set[str]: + """A set of the link ids in this Strip.""" + if self._link_ids is None: + self._link_ids = { + style._link_id for _, style, _ in self._segments if style is not None + } + return self._link_ids + + @classmethod + @lru_cache(maxsize=1024) + def blank(cls, cell_length: int, style: StyleType | None = None) -> Strip: + """Create a blank strip. + + Args: + cell_length: Desired cell length. + style: Style of blank. + + Returns: + New strip. + """ + segment_style = Style.parse(style) if isinstance(style, str) else style + return cls([Segment(" " * cell_length, segment_style)], cell_length) + + @classmethod + def from_lines( + cls, lines: list[list[Segment]], cell_length: int | None = None + ) -> list[Strip]: + """Convert lines (lists of segments) to a list of Strips. + + Args: + lines: List of lines, where a line is a list of segments. + cell_length: Cell length of lines (must be same) or None if not known. + + Returns: + List of strips. + """ + return [cls(segments, cell_length) for segments in lines] + + @classmethod + def align( + cls, + strips: list[Strip], + style: Style, + width: int, + height: int | None, + horizontal: AlignHorizontal, + vertical: AlignVertical, + ) -> Iterable[Strip]: + """Align a list of strips on both axis. + + Args: + strips: A list of strips, such as from a render. + style: The Rich style of additional space. + width: Width of container. + height: Height of container. + horizontal: Horizontal alignment method. + vertical: Vertical alignment method. + + Returns: + An iterable of strips, with additional padding. + + """ + if not strips: + return + line_lengths = [strip.cell_length for strip in strips] + shape_width = max(line_lengths) + shape_height = len(line_lengths) + + def blank_lines(count: int) -> Iterable[Strip]: + """Create blank lines. + + Args: + count: Desired number of blank lines. + + Returns: + An iterable of blank lines. + """ + blank = cls([Segment(" " * width, style)], width) + for _ in range(count): + yield blank + + top_blank_lines = bottom_blank_lines = 0 + if height is not None: + vertical_excess_space = max(0, height - shape_height) + + if vertical == "top": + bottom_blank_lines = vertical_excess_space + elif vertical == "middle": + top_blank_lines = vertical_excess_space // 2 + bottom_blank_lines = vertical_excess_space - top_blank_lines + elif vertical == "bottom": + top_blank_lines = vertical_excess_space + + if top_blank_lines: + yield from blank_lines(top_blank_lines) + + if horizontal == "left": + for strip in strips: + if strip.cell_length == width: + yield strip + else: + yield Strip( + line_pad(strip._segments, 0, width - strip.cell_length, style), + width, + ) + elif horizontal == "center": + left_space = max(0, width - shape_width) // 2 + for strip in strips: + if strip.cell_length == width: + yield strip + else: + yield Strip( + line_pad( + strip._segments, + left_space, + width - strip.cell_length - left_space, + style, + ), + width, + ) + + elif horizontal == "right": + for strip in strips: + if strip.cell_length == width: + yield strip + else: + yield cls( + line_pad(strip._segments, width - strip.cell_length, 0, style), + width, + ) + + if bottom_blank_lines: + yield from blank_lines(bottom_blank_lines) + + def index_to_cell_position(self, index: int) -> int: + """Given a character index, return the cell position of that character. + This is the sum of the cell lengths of all the characters *before* the character + at `index`. + + Args: + index: The index to convert. + + Returns: + The cell position of the character at `index`. + """ + return index_to_cell_position(self._segments, index) + + @property + def cell_length(self) -> int: + """Get the number of cells required to render this object.""" + # Done on demand and cached, as this is an O(n) operation + if self._cell_length is None: + self._cell_length = get_line_length(self._segments) + return self._cell_length + + @classmethod + def join(cls, strips: Iterable[Strip | None]) -> Strip: + """Join a number of strips into one. + + Args: + strips: An iterable of Strips. + + Returns: + A new combined strip. + """ + join_strips = [ + strip for strip in strips if strip is not None and strip.cell_count + ] + segments = [segment for strip in join_strips for segment in strip._segments] + cell_length: int | None = None + if any([strip._cell_length is None for strip in join_strips]): + cell_length = None + else: + cell_length = sum([strip._cell_length or 0 for strip in join_strips]) + joined_strip = cls(segments, cell_length) + if all(strip._render_cache is not None for strip in join_strips): + joined_strip._render_cache = "".join( + [strip._render_cache for strip in join_strips] + ) + return joined_strip + + def __add__(self, other: Strip) -> Strip: + return Strip.join([self, other]) + + def __bool__(self) -> bool: + return not not self._segments # faster than bool(...) + + def __iter__(self) -> Iterator[Segment]: + return iter(self._segments) + + def __reversed__(self) -> Iterator[Segment]: + return reversed(self._segments) + + def __len__(self) -> int: + return len(self._segments) + + def __eq__(self, strip: object) -> bool: + return isinstance(strip, Strip) and (self._segments == strip._segments) + + def __getitem__(self, index: int | slice) -> Strip: + if isinstance(index, int): + index = slice(index, index + 1) + return self.crop( + index.start, self.cell_count if index.stop is None else index.stop + ) + + @property + def cell_count(self) -> int: + """Number of cells in the strip""" + if self._cell_count is None: + self._cell_count = sum(len(segment.text) for segment in self._segments) + return self._cell_count + + def extend_cell_length(self, cell_length: int, style: Style | None = None) -> Strip: + """Extend the cell length if it is less than the given value. + + Args: + cell_length: Required minimum cell length. + style: Style for padding if the cell length is extended. + + Returns: + A new Strip. + """ + if self.cell_length < cell_length: + missing_space = cell_length - self.cell_length + segments = self._segments + [Segment(" " * missing_space, style)] + return Strip(segments, cell_length) + else: + return self + + def adjust_cell_length(self, cell_length: int, style: Style | None = None) -> Strip: + """Adjust the cell length, possibly truncating or extending. + + Args: + cell_length: New desired cell length. + style: Style when extending, or `None`. + + Returns: + A new strip with the supplied cell length. + """ + + if self.cell_length == cell_length: + return self + + cache_key = (cell_length, style) + cached_strip = self._line_length_cache.get(cache_key) + if cached_strip is not None: + return cached_strip + + new_line: list[Segment] + line = self._segments + current_cell_length = self.cell_length + + _Segment = Segment + + if current_cell_length < cell_length: + # Cell length is larger, so pad with spaces. + new_line = line + [ + _Segment(" " * (cell_length - current_cell_length), style) + ] + strip = Strip(new_line, cell_length) + + elif current_cell_length > cell_length: + # Cell length is shorter so we need to truncate. + new_line = [] + append = new_line.append + line_length = 0 + for segment in line: + segment_length = segment.cell_length + if line_length + segment_length < cell_length: + append(segment) + line_length += segment_length + else: + text, segment_style, _ = segment + text = set_cell_size(text, cell_length - line_length) + append(_Segment(text, segment_style)) + break + strip = Strip(new_line, cell_length) + else: + # Strip is already the required cell length, so return self. + strip = self + + self._line_length_cache[cache_key] = strip + return strip + + def simplify(self) -> Strip: + """Simplify the segments (join segments with same style). + + Returns: + New strip. + """ + line = Strip( + Segment.simplify(self._segments), + self._cell_length, + ) + return line + + def discard_meta(self) -> Strip: + """Remove all meta from segments. + + Returns: + New strip. + """ + + def remove_meta_from_segment(segment: Segment) -> Segment: + """Build a Segment with no meta. + + Args: + segment: Segment. + + Returns: + Segment, sans meta. + """ + text, style, control = segment + if style is None: + return segment + style = style.copy() + style._meta = None + return Segment(text, style, control) + + return Strip( + [remove_meta_from_segment(segment) for segment in self._segments], + self._cell_length, + ) + + def apply_filter(self, filter: LineFilter, background: Color) -> Strip: + """Apply a filter to all segments in the strip. + + Args: + filter: A line filter object. + + Returns: + A new Strip. + """ + cached_strip = self._filter_cache.get((filter, background)) + if cached_strip is None: + cached_strip = Strip( + filter.apply(self._segments, background), self._cell_length + ) + self._filter_cache[(filter, background)] = cached_strip + return cached_strip + + def style_links(self, link_id: str, link_style: Style) -> Strip: + """Apply a style to Segments with the given link_id. + + Args: + link_id: A link id. + link_style: Style to apply. + + Returns: + New strip (or same Strip if no changes). + """ + + _Segment = Segment + if link_id not in self.link_ids: + return self + segments = [ + _Segment( + text, + ( + (style + link_style if style is not None else None) + if (style and not style._null and style._link_id == link_id) + else style + ), + control, + ) + for text, style, control in self._segments + ] + return Strip(segments, self._cell_length) + + def crop_extend(self, start: int, end: int, style: Style | None) -> Strip: + """Crop between two points, extending the length if required. + + Args: + start: Start offset of crop. + end: End offset of crop. + style: Style of additional padding. + + Returns: + New cropped Strip. + """ + cache_key = (start, end, style) + cached_result = self._crop_extend_cache.get(cache_key) + if cached_result is not None: + return cached_result + strip = self.extend_cell_length(end, style).crop(start, end) + self._crop_extend_cache[cache_key] = strip + return strip + + def crop(self, start: int, end: int | None = None) -> Strip: + """Crop a strip between two cell positions. + + Args: + start: The start cell position (inclusive). + end: The end cell position (exclusive). + + Returns: + A new Strip. + """ + + start = max(0, start) + end = self.cell_length if end is None else min(self.cell_length, end) + if start == 0 and end == self.cell_length: + return self + if end <= start: + return Strip([], 0) + cache_key = (start, end) + cached = self._crop_cache.get(cache_key) + if cached is not None: + return cached + _cell_len = cell_len + pos = 0 + output_segments: list[Segment] = [] + add_segment = output_segments.append + iter_segments = iter(self._segments) + segment: Segment | None = None + if start >= self.cell_length: + strip = Strip([], 0) + else: + for segment in iter_segments: + end_pos = pos + _cell_len(segment.text) + if end_pos > start: + segment = segment.split_cells(start - pos)[1] + break + pos = end_pos + + if end >= self.cell_length: + # The end crop is the end of the segments, so we can collect all remaining segments + if segment: + add_segment(segment) + output_segments.extend(iter_segments) + strip = Strip(output_segments, self.cell_length - start) + else: + pos = start + while segment is not None: + end_pos = pos + _cell_len(segment.text) + if end_pos < end: + add_segment(segment) + else: + add_segment(segment.split_cells(end - pos)[0]) + break + pos = end_pos + segment = next(iter_segments, None) + strip = Strip(output_segments, end - start) + self._crop_cache[cache_key] = strip + return strip + + def divide(self, cuts: Iterable[int]) -> Sequence[Strip]: + """Divide the strip into multiple smaller strips by cutting at given (cell) indices. + + Args: + cuts: An iterable of cell positions as ints. + + Returns: + A new list of strips. + """ + + pos = 0 + cell_length = self.cell_length + cuts = [cut for cut in cuts if cut <= cell_length] + cache_key = tuple(cuts) + if (cached := self._divide_cache.get(cache_key)) is not None: + return cached + + strips: list[Strip] + if cuts == [cell_length]: + strips = [self] + else: + strips = [] + add_strip = strips.append + for segments, cut in zip(Segment.divide(self._segments, cuts), cuts): + add_strip(Strip(segments, cut - pos)) + pos = cut + + self._divide_cache[cache_key] = strips + return strips + + def apply_style(self, style: Style) -> Strip: + """Apply a style to the Strip. + + Args: + style: A Rich style. + + Returns: + A new strip. + """ + cached = self._style_cache.get(style) + if cached is not None: + return cached + styled_strip = Strip( + Segment.apply_style(self._segments, style), self.cell_length + ) + self._style_cache[style] = styled_strip + return styled_strip + + def apply_meta(self, meta: dict[str, Any]) -> Strip: + """Apply meta to all segments. + + Args: + meta: A dict of meta information. + + Returns: + A new strip. + + """ + meta_style = Style.from_meta(meta) + return self.apply_style(meta_style) + + def _apply_link_style(self, link_style: Style) -> Strip: + segments = self._segments + _Segment = Segment + segments = [ + ( + _Segment( + text, + ( + style + if style._meta is None + else (style + link_style if "@click" in style.meta else style) + ), + control, + ) + if style + else _Segment(text) + ) + for text, style, control in segments + ] + return Strip(segments, self._cell_length) + + @classmethod + @lru_cache(maxsize=16384) + def render_ansi(cls, style: Style, color_system: ColorSystem) -> str: + """Render ANSI codes for a give style. + + Args: + style: A Rich style. + color_system: Color system enumeration. + + Returns: + A string of ANSI escape sequences to render the style. + """ + sgr: list[str] + if attributes := style._attributes & style._set_attributes: + _style_map = SGR_STYLES + sgr = [ + _style_map[bit_offset] + for bit_offset in range(attributes.bit_length()) + if attributes & (1 << bit_offset) + ] + else: + sgr = [] + if (color := style._color) is not None: + sgr.extend(color.downgrade(color_system).get_ansi_codes()) + if (bgcolor := style._bgcolor) is not None: + sgr.extend(bgcolor.downgrade(color_system).get_ansi_codes(False)) + ansi = style._ansi = ";".join(sgr) + return ansi + + @classmethod + def render_style(cls, style: Style, text: str, color_system: ColorSystem) -> str: + """Render a Rich style and text. + + Args: + style: Style to render. + text: Content string. + color_system: Color system enumeration. + + Returns: + Text with ANSI escape sequences. + """ + if (ansi := style._ansi) is None: + ansi = cls.render_ansi(style, color_system) + output = f"\x1b[{ansi}m{text}\x1b[0m" if ansi else text + if style._link: + output = ( + f"\x1b]8;id={style._link_id};{style._link}\x1b\\{output}\x1b]8;;\x1b\\" + ) + return output + + def render(self, console: Console) -> str: + """Render the strip into terminal sequences. + + Args: + console: Console instance. + + Returns: + Rendered sequences. + """ + if self._render_cache is None: + color_system = console._color_system or ColorSystem.TRUECOLOR + render = self.render_style + self._render_cache = "".join( + [ + ( + text + if style is None + else render(style, text, color_system=color_system) + ) + for text, style, _ in self._segments + ] + ) + + return self._render_cache + + def crop_pad(self, cell_length: int, left: int, right: int, style: Style) -> Strip: + """Crop the strip to `cell_length`, and add optional padding. + + Args: + cell_length: Cell length of strip prior to padding. + left: Additional padding on the left. + right: Additional padding on the right. + style: Style of any padding. + + Returns: + Cropped and padded strip. + """ + if cell_length != self.cell_length: + strip = self.adjust_cell_length(cell_length, style) + else: + strip = self + if not (left or right): + return strip + segments = strip._segments.copy() + if left: + segments.insert(0, Segment(" " * left, style)) + if right: + segments.append(Segment(" " * right, style)) + return Strip(segments, cell_length + left + right) + + def text_align(self, width: int, align: AlignHorizontal) -> Strip: + if align == "left": + if self.cell_length == width: + return self + else: + return Strip( + line_pad(self._segments, 0, width - self.cell_length, Style.null()), + width, + ) + elif align == "center": + left_space = max(0, width - self.cell_length) // 2 + + if self.cell_length == width: + return self + else: + return Strip( + line_pad( + self._segments, + left_space, + width - self.cell_length - left_space, + Style.null(), + ), + width, + ) + + elif align == "right": + if self.cell_length == width: + return self + else: + return Strip( + line_pad(self._segments, width - self.cell_length, 0, Style.null()), + width, + ) + + def apply_offsets(self, x: int, y: int) -> Strip: + """Apply offsets used in text selection. + + Args: + x: Offset on X axis (column). + y: Offset on Y axis (row). + + Returns: + New strip. + """ + cache_key = (x, y) + if (cached_strip := self._offsets_cache.get(cache_key)) is not None: + return cached_strip + segments = self._segments + strip_segments: list[Segment] = [] + for segment in segments: + text, style, _ = segment + offset_style = Style.from_meta({"offset": (x, y)}) + strip_segments.append( + Segment(text, style + offset_style if style else offset_style) + ) + x += len(segment.text) + strip = Strip(strip_segments, self._cell_length) + strip._render_cache = self._render_cache + self._offsets_cache[cache_key] = strip + return strip diff --git a/src/memray/_vendor/textual/style.py b/src/memray/_vendor/textual/style.py new file mode 100644 index 0000000000..4f6da291aa --- /dev/null +++ b/src/memray/_vendor/textual/style.py @@ -0,0 +1,537 @@ +""" +The Style class contains all the information needed to generate styled terminal output. + +You won't often need to create Style objects directly, if you are using [Content][textual.content.Content] for output. +But you might want to use styles for more customized widgets. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from functools import cached_property, lru_cache +from operator import attrgetter +from pickle import dumps, loads +from typing import TYPE_CHECKING, Any, Iterable, Mapping + +import rich.repr +from rich.style import Style as RichStyle +from rich.terminal_theme import TerminalTheme + +from memray._vendor.textual._context import active_app +from memray._vendor.textual.color import Color + +if TYPE_CHECKING: + from memray._vendor.textual.css.styles import StylesBase + + +_get_hash_attributes = attrgetter( + "background", + "foreground", + "bold", + "dim", + "italic", + "underline", + "underline2", + "reverse", + "strike", + "blink", + "link", + "auto_color", + "_meta", +) + + +_get_simple_attributes = attrgetter( + "background", + "foreground", + "bold", + "dim", + "italic", + "underline", + "underline2", + "reverse", + "strike", + "blink", + "link", + "_meta", +) + +_get_simple_attributes_sans_color = attrgetter( + "bold", + "dim", + "italic", + "underline", + "underline2", + "reverse", + "strike", + "blink", + "link", + "_meta", +) + + +_get_attributes = attrgetter( + "background", + "foreground", + "bold", + "dim", + "italic", + "underline", + "underline2", + "reverse", + "strike", + "blink", + "link", + "meta", + "_meta", +) + + +@rich.repr.auto() +@dataclass(frozen=True) +class Style: + """Represents a style in the Visual interface (color and other attributes). + + Styles may be added together, which combines their style attributes. + + """ + + background: Color | None = None + foreground: Color | None = None + bold: bool | None = None + dim: bool | None = None + italic: bool | None = None + underline: bool | None = None + underline2: bool | None = None + reverse: bool | None = None + strike: bool | None = None + blink: bool | None = None + link: str | None = None + _meta: bytes | None = None + auto_color: bool = False + + def __rich_repr__(self) -> rich.repr.Result: + yield "background", self.background, None + yield "foreground", self.foreground, None + yield "bold", self.bold, None + yield "dim", self.dim, None + yield "italic", self.italic, None + yield "underline", self.underline, None + yield "underline2", self.underline2, None + yield "reverse", self.reverse, None + yield "strike", self.strike, None + yield "blink", self.blink, None + yield "link", self.link, None + + if self._meta is not None: + yield "meta", self.meta + + @cached_property + def _is_null(self) -> bool: + return _get_simple_attributes(self) == ( + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + @cached_property + def hash(self) -> int: + """A hash of the style's attributes.""" + return hash(_get_hash_attributes(self)) + + def __hash__(self) -> int: + return self.hash + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Style): + return NotImplemented + return self.hash == other.hash + + def __bool__(self) -> bool: + return not self._is_null + + def __str__(self) -> str: + return self.style_definition + + @cached_property + def style_definition(self) -> str: + """Style encoded in a string (may be parsed from `Style.parse`).""" + output: list[str] = [] + output_append = output.append + if self.foreground is not None: + output_append(self.foreground.css) + if self.background is not None: + output_append(f"on {self.background.css}") + if self.bold is not None: + output_append("bold" if self.bold else "not bold") + if self.dim is not None: + output_append("dim" if self.dim else "not dim") + if self.italic is not None: + output_append("italic" if self.italic else "not italic") + if self.underline is not None: + output_append("underline" if self.underline else "not underline") + if self.underline2 is not None: + output_append("underline2" if self.underline2 else "not underline2") + if self.strike is not None: + output_append("strike" if self.strike else "not strike") + if self.blink is not None: + output_append("blink" if self.blink else "not blink") + if self.link is not None: + if "'" not in self.link: + output_append(f"link='{self.link}'") + elif '"' not in self.link: + output_append(f'link="{self.link}"') + if self._meta is not None: + for key, value in self.meta.items(): + if isinstance(value, str): + if "'" not in key: + output_append(f"{key}='{value}'") + elif '"' not in key: + output_append(f'{key}="{value}"') + else: + output_append(f"{key}={value!r}") + else: + output_append(f"{key}={value!r}") + + return " ".join(output) + + @cached_property + def markup_tag(self) -> str: + """Identifier used to close tags in markup.""" + output: list[str] = [] + output_append = output.append + if self.foreground is not None: + output_append(self.foreground.css) + if self.background is not None: + output_append(f"on {self.background.css}") + if self.bold is not None: + output_append("bold" if self.bold else "not bold") + if self.dim is not None: + output_append("dim" if self.dim else "not dim") + if self.italic is not None: + output_append("italic" if self.italic else "not italic") + if self.underline is not None: + output_append("underline" if self.underline else "not underline") + if self.underline2 is not None: + output_append("underline2" if self.underline2 else "not underline2") + if self.strike is not None: + output_append("strike" if self.strike else "not strike") + if self.blink is not None: + output_append("blink" if self.blink else "not blink") + if self.link is not None: + output_append("link") + if self._meta is not None: + for key, value in self.meta.items(): + if isinstance(value, str): + output_append(f"{key}=") + + return " ".join(output) + + @lru_cache(maxsize=1024 * 4) + def __add__(self, other: object | None) -> Style: + if isinstance(other, Style): + if self._is_null: + return other + if other._is_null: + return self + ( + background, + foreground, + bold, + dim, + italic, + underline, + underline2, + reverse, + strike, + blink, + link, + meta, + _meta, + ) = _get_attributes(self) + + ( + other_background, + other_foreground, + other_bold, + other_dim, + other_italic, + other_underline, + other_underline2, + other_reverse, + other_strike, + other_blink, + other_link, + other_meta, + other__meta, + ) = _get_attributes(other) + + new_style = Style( + ( + other_background + if (background is None or background.a == 0) + else background + other_background + ), + ( + foreground + if (other_foreground is None or other_foreground.a == 0) + else other_foreground + ), + bold if other_bold is None else other_bold, + dim if other_dim is None else other_dim, + italic if other_italic is None else other_italic, + underline if other_underline is None else other_underline, + underline2 if other_underline2 is None else other_underline2, + reverse if other_reverse is None else other_reverse, + strike if other_strike is None else other_strike, + blink if other_blink is None else other_blink, + link if other_link is None else other_link, + ( + dumps({**meta, **other_meta}) + if _meta is not None and other__meta is not None + else (_meta if other__meta is None else other__meta) + ), + ) + return new_style + elif other is None: + return self + else: + return NotImplemented + + __radd__ = __add__ + + @classmethod + def null(cls) -> Style: + """Get a null (no color or style) style.""" + return NULL_STYLE + + @classmethod + def parse(cls, text_style: str, variables: dict[str, str] | None = None) -> Style: + """Parse a style from text. + + Args: + text_style: A style encoded in a string. + variables: Optional mapping of CSS variables. `None` to get variables from the app. + + Returns: + New style. + """ + from memray._vendor.textual.markup import parse_style + + try: + app = active_app.get() + except LookupError: + return parse_style(text_style, variables) + return app.stylesheet.parse_style(text_style) + + @classmethod + def _normalize_markup_tag(cls, text_style: str) -> str: + """Produces a normalized from of a style, used to match closing tags with opening tags. + + Args: + text_style: Style to normalize. + + Returns: + Normalized markup tag. + """ + try: + style = cls.parse(text_style) + except Exception: + return text_style.strip() + return style.markup_tag + + @classmethod + def from_rich_style( + cls, rich_style: RichStyle, theme: TerminalTheme | None = None + ) -> Style: + """Build a Style from a (Rich) Style. + + Args: + rich_style: A Rich Style object. + theme: Optional Rich [terminal theme][rich.terminal_theme.TerminalTheme]. + + Returns: + New Style. + """ + + return Style( + ( + None + if rich_style.bgcolor is None + else Color.from_rich_color(rich_style.bgcolor, theme) + ), + ( + None + if rich_style.color is None + else Color.from_rich_color(rich_style.color, theme) + ), + bold=rich_style.bold, + dim=rich_style.dim, + italic=rich_style.italic, + underline=rich_style.underline, + underline2=rich_style.underline2, + reverse=rich_style.reverse, + strike=rich_style.strike, + blink=rich_style.blink, + link=rich_style.link, + _meta=rich_style._meta, + ) + + @classmethod + def from_styles(cls, styles: StylesBase) -> Style: + """Create a Visual Style from a Textual styles object. + + Args: + styles: A Styles object, such as `my_widget.styles`. + + """ + text_style = styles.text_style + return Style( + styles.background, + ( + Color(0, 0, 0, styles.color.a, auto=True) + if styles.auto_color + else styles.color + ), + bold=text_style.bold, + dim=text_style.italic, + italic=text_style.italic, + underline=text_style.underline, + underline2=text_style.underline2, + reverse=text_style.reverse, + strike=text_style.strike, + blink=text_style.blink, + auto_color=styles.auto_color, + ) + + @classmethod + def from_meta(cls, meta: Mapping[str, Any]) -> Style: + """Create a Visual Style containing meta information. + + Args: + meta: A dictionary of meta information. + + Returns: + A new Style. + """ + return Style(_meta=dumps({**meta})) + + @cached_property + def rich_style(self) -> RichStyle: + """Convert this Styles into a Rich style. + + Returns: + A Rich style object. + """ + + ( + background, + foreground, + bold, + dim, + italic, + underline, + underline2, + reverse, + strike, + blink, + link, + _meta, + ) = _get_simple_attributes(self) + + color = None if foreground is None else background + foreground + + return RichStyle( + color=None if color is None else color.rich_color, + bgcolor=None if background is None else background.rich_color, + bold=bold, + dim=dim, + italic=italic, + underline=underline, + underline2=underline2, + reverse=reverse, + strike=strike, + blink=blink, + link=link, + meta=None if _meta is None else self.meta, + ) + + def rich_style_with_offset(self, x: int, y: int) -> RichStyle: + """Get a Rich style with the given offset included in meta. + + This is used in text selection. + + Args: + x: X coordinate. + y: Y coordinate. + + Returns: + A Rich Style object. + """ + ( + background, + foreground, + bold, + dim, + italic, + underline, + underline2, + reverse, + strike, + blink, + link, + _meta, + ) = _get_simple_attributes(self) + color = None if foreground is None else background + foreground + return RichStyle( + color=None if color is None else color.rich_color, + bgcolor=None if background is None else background.rich_color, + bold=bold, + dim=dim, + italic=italic, + underline=underline, + underline2=underline2, + reverse=reverse, + strike=strike, + blink=blink, + link=link, + meta={**self.meta, "offset": (x, y)}, + ) + + @cached_property + def without_color(self) -> Style: + """The style without any colors.""" + return Style(None, None, *_get_simple_attributes_sans_color(self)) + + @cached_property + def background_style(self) -> Style: + """Just the background color, with no other attributes.""" + return Style(self.background, _meta=self._meta) + + @property + def has_transparent_foreground(self) -> bool: + """Is the foreground transparent (or not set)?""" + return self.foreground is None or self.foreground.a == 0 + + @classmethod + def combine(cls, styles: Iterable[Style]) -> Style: + """Add a number of styles and get the result.""" + iter_styles = iter(styles) + return sum(iter_styles, next(iter_styles)) + + @cached_property + def meta(self) -> Mapping[str, Any]: + """Get meta information (can not be changed after construction).""" + return {} if self._meta is None else loads(self._meta) + + +NULL_STYLE = Style() diff --git a/src/memray/_vendor/textual/suggester.py b/src/memray/_vendor/textual/suggester.py new file mode 100644 index 0000000000..fc97f0c73d --- /dev/null +++ b/src/memray/_vendor/textual/suggester.py @@ -0,0 +1,144 @@ +""" + +Contains the `Suggester` class, used by the [Input](/widgets/input) widget. + +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Iterable + +from memray._vendor.textual.cache import LRUCache +from memray._vendor.textual.dom import DOMNode +from memray._vendor.textual.message import Message + + +@dataclass +class SuggestionReady(Message): + """Sent when a completion suggestion is ready.""" + + value: str + """The value to which the suggestion is for.""" + suggestion: str + """The string suggestion.""" + + +class Suggester(ABC): + """Defines how widgets generate completion suggestions. + + To define a custom suggester, subclass `Suggester` and implement the async method + `get_suggestion`. + See [`SuggestFromList`][textual.suggester.SuggestFromList] for an example. + """ + + cache: LRUCache[str, str | None] | None + """Suggestion cache, if used.""" + + def __init__(self, *, use_cache: bool = True, case_sensitive: bool = False) -> None: + """Create a suggester object. + + Args: + use_cache: Whether to cache suggestion results. + case_sensitive: Whether suggestions are case sensitive or not. + If they are not, incoming values are casefolded before generating + the suggestion. + """ + self.cache = LRUCache(1024) if use_cache else None + self.case_sensitive = case_sensitive + + async def _get_suggestion(self, requester: DOMNode, value: str) -> None: + """Used by widgets to get completion suggestions. + + Note: + When implementing custom suggesters, this method does not need to be + overridden. + + Args: + requester: The message target that requested a suggestion. + value: The current value to complete. + """ + + normalized_value = value if self.case_sensitive else value.casefold() + if self.cache is None or normalized_value not in self.cache: + suggestion = await self.get_suggestion(normalized_value) + if self.cache is not None: + self.cache[normalized_value] = suggestion + else: + suggestion = self.cache[normalized_value] + + if suggestion is None: + return + requester.post_message(SuggestionReady(value, suggestion)) + + @abstractmethod + async def get_suggestion(self, value: str) -> str | None: + """Try to get a completion suggestion for the given input value. + + Custom suggesters should implement this method. + + Note: + The value argument will be casefolded if `self.case_sensitive` is `False`. + + Note: + If your implementation is not deterministic, you may need to disable caching. + + Args: + value: The current value of the requester widget. + + Returns: + A valid suggestion or `None`. + """ + pass + + +class SuggestFromList(Suggester): + """Give completion suggestions based on a fixed list of options. + + Example: + ```py + countries = ["England", "Scotland", "Portugal", "Spain", "France"] + + class MyApp(App[None]): + def compose(self) -> ComposeResult: + yield Input(suggester=SuggestFromList(countries, case_sensitive=False)) + ``` + + If the user types ++p++ inside the input widget, a completion suggestion + for `"Portugal"` appears. + """ + + def __init__( + self, suggestions: Iterable[str], *, case_sensitive: bool = True + ) -> None: + """Creates a suggester based off of a given iterable of possibilities. + + Args: + suggestions: Valid suggestions sorted by decreasing priority. + case_sensitive: Whether suggestions are computed in a case sensitive manner + or not. The values provided in the argument `suggestions` represent the + canonical representation of the completions and they will be suggested + with that same casing. + """ + super().__init__(case_sensitive=case_sensitive) + self._suggestions = list(suggestions) + self._for_comparison = ( + self._suggestions + if self.case_sensitive + else [suggestion.casefold() for suggestion in self._suggestions] + ) + + async def get_suggestion(self, value: str) -> str | None: + """Gets a completion from the given possibilities. + + Args: + value: The current value. + + Returns: + A valid completion suggestion or `None`. + """ + for idx, suggestion in enumerate(self._for_comparison): + if suggestion.startswith(value): + return self._suggestions[idx] + return None diff --git a/src/memray/_vendor/textual/suggestions.py b/src/memray/_vendor/textual/suggestions.py new file mode 100644 index 0000000000..78bb29c5ab --- /dev/null +++ b/src/memray/_vendor/textual/suggestions.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from difflib import get_close_matches +from typing import Sequence + + +def get_suggestion(word: str, possible_words: Sequence[str]) -> str | None: + """ + Returns a close match of `word` amongst `possible_words`. + + Args: + word: The word we want to find a close match for + possible_words: The words amongst which we want to find a close match + + Returns: + The closest match amongst the `possible_words`. Returns `None` if no close matches could be found. + + Example: returns "red" for word "redu" and possible words ("yellow", "red") + """ + possible_matches = get_close_matches(word, possible_words, n=1) + return None if not possible_matches else possible_matches[0] + + +def get_suggestions(word: str, possible_words: Sequence[str], count: int) -> list[str]: + """ + Returns a list of up to `count` matches of `word` amongst `possible_words`. + + Args: + word: The word we want to find a close match for + possible_words: The words amongst which we want to find close matches + + Returns: + The closest matches amongst the `possible_words`, from the closest to the least close. + Returns an empty list if no close matches could be found. + + Example: returns ["yellow", "ellow"] for word "yllow" and possible words ("yellow", "red", "ellow") + """ + return get_close_matches(word, possible_words, n=count) diff --git a/src/memray/_vendor/textual/system_commands.py b/src/memray/_vendor/textual/system_commands.py new file mode 100644 index 0000000000..118cd842b6 --- /dev/null +++ b/src/memray/_vendor/textual/system_commands.py @@ -0,0 +1,64 @@ +""" + +This module contains `SystemCommands`, a command palette command provider for Textual system commands. + +This is a simple command provider that makes the most obvious application +actions available via the [command palette][textual.command.CommandPalette]. + +!!! note + + The App base class installs this automatically. + +""" + +from __future__ import annotations + +from memray._vendor.textual.command import DiscoveryHit, Hit, Hits, Provider + + +class SystemCommandsProvider(Provider): + """A [source][textual.command.Provider] of command palette commands that run app-wide tasks. + + Used by default in [`App.COMMANDS`][textual.app.App.COMMANDS]. + """ + + async def discover(self) -> Hits: + """Handle a request for the discovery commands for this provider. + + Yields: + Commands that can be discovered. + """ + commands = sorted( + self.app.get_system_commands(self.screen), key=lambda command: command[0] + ) + for name, help_text, callback, discover in commands: + if discover: + yield DiscoveryHit( + name, + callback, + help=help_text, + ) + + async def search(self, query: str) -> Hits: + """Handle a request to search for system commands that match the query. + + Args: + query: The user input to be matched. + + Yields: + Command hits for use in the command palette. + """ + # We're going to use Textual's builtin fuzzy matcher to find + # matching commands. + matcher = self.matcher(query) + + # Loop over all applicable commands, find those that match and offer + # them up to the command palette. + for name, help_text, callback, *_ in self.app.get_system_commands(self.screen): + if (match := matcher.match(name)) > 0: + yield Hit( + match, + matcher.highlight(name), + callback, + help=help_text, + ) diff --git a/src/memray/_vendor/textual/theme.py b/src/memray/_vendor/textual/theme.py new file mode 100644 index 0000000000..43ccfbb777 --- /dev/null +++ b/src/memray/_vendor/textual/theme.py @@ -0,0 +1,496 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from functools import partial +from operator import attrgetter +from typing import Callable + +from memray._vendor.textual.command import DiscoveryHit, Hit, Hits, Provider +from memray._vendor.textual.design import ColorSystem + + +@dataclass +class Theme: + """Defines a theme for the application.""" + + name: str + """The name of the theme. + + After registering a theme with `App.register_theme`, you can set the theme with + `App.theme = theme_name`. This will immediately apply the theme's colors to your + application. + """ + + primary: str + secondary: str | None = None + warning: str | None = None + error: str | None = None + success: str | None = None + accent: str | None = None + foreground: str | None = None + background: str | None = None + surface: str | None = None + panel: str | None = None + boost: str | None = None + dark: bool = True + luminosity_spread: float = 0.15 + text_alpha: float = 0.95 + variables: dict[str, str] = field(default_factory=dict) + + def to_color_system(self) -> ColorSystem: + """ + Create a ColorSystem instance from this Theme. + + Returns: + A ColorSystem instance with attributes copied from this Theme. + """ + return ColorSystem( + primary=self.primary, + secondary=self.secondary, + warning=self.warning, + error=self.error, + success=self.success, + accent=self.accent, + foreground=self.foreground, + background=self.background, + surface=self.surface, + panel=self.panel, + boost=self.boost, + dark=self.dark, + luminosity_spread=self.luminosity_spread, + text_alpha=self.text_alpha, + variables=self.variables, + ) + + +BUILTIN_THEMES: dict[str, Theme] = { + "textual-dark": Theme( + name="textual-dark", + primary="#0178D4", + secondary="#004578", + accent="#ffa62b", + warning="#ffa62b", + error="#ba3c5b", + success="#4EBF71", + foreground="#e0e0e0", + ), + "textual-light": Theme( + name="textual-light", + primary="#004578", + secondary="#0178D4", + accent="#ffa62b", + warning="#ffa62b", + error="#ba3c5b", + success="#4EBF71", + surface="#D8D8D8", + panel="#D0D0D0", + background="#E0E0E0", + dark=False, + variables={ + "footer-key-foreground": "#0178D4", + }, + ), + "nord": Theme( + name="nord", + primary="#88C0D0", + secondary="#81A1C1", + accent="#B48EAD", + foreground="#D8DEE9", + background="#2E3440", + success="#A3BE8C", + warning="#EBCB8B", + error="#BF616A", + surface="#3B4252", + panel="#434C5E", + variables={ + "block-cursor-background": "#88C0D0", + "block-cursor-foreground": "#2E3440", + "block-cursor-text-style": "none", + "footer-key-foreground": "#88C0D0", + "input-selection-background": "#81a1c1 35%", + "button-color-foreground": "#2E3440", + "button-focus-text-style": "reverse", + }, + ), + "gruvbox": Theme( + name="gruvbox", + primary="#85A598", + secondary="#A89A85", + warning="#fe8019", + error="#fb4934", + success="#b8bb26", + accent="#fabd2f", + foreground="#fbf1c7", + background="#282828", + surface="#3c3836", + panel="#504945", + variables={ + "block-cursor-foreground": "#fbf1c7", + "input-selection-background": "#689d6a40", + "button-color-foreground": "#282828", + }, + ), + "catppuccin-mocha": Theme( + name="catppuccin-mocha", + primary="#F5C2E7", + secondary="#cba6f7", + warning="#FAE3B0", + error="#F28FAD", + success="#ABE9B3", + accent="#fab387", + foreground="#cdd6f4", + background="#181825", + surface="#313244", + panel="#45475a", + variables={ + "input-cursor-foreground": "#11111b", + "input-cursor-background": "#f5e0dc", + "input-selection-background": "#9399b2 30%", + "border": "#b4befe", + "border-blurred": "#585b70", + "footer-background": "#45475a", + "block-cursor-foreground": "#1e1e2e", + "block-cursor-text-style": "none", + "button-color-foreground": "#181825", + }, + ), + "textual-ansi": Theme( + name="textual-ansi", + primary="ansi_blue", + secondary="ansi_cyan", + warning="ansi_yellow", + error="ansi_red", + success="ansi_green", + accent="ansi_bright_blue", + foreground="ansi_default", + background="ansi_default", + surface="ansi_default", + panel="ansi_default", + boost="ansi_default", + dark=False, + variables={ + "block-cursor-text-style": "b", + "block-cursor-blurred-text-style": "i", + "input-selection-background": "ansi_blue", + "input-cursor-text-style": "reverse", + "scrollbar": "ansi_blue", + "border-blurred": "ansi_blue", + "border": "ansi_bright_blue", + }, + ), + "dracula": Theme( + name="dracula", + primary="#BD93F9", + secondary="#6272A4", + warning="#FFB86C", + error="#FF5555", + success="#50FA7B", + accent="#FF79C6", + background="#282A36", + surface="#2B2E3B", + panel="#313442", + foreground="#F8F8F2", + variables={ + "button-color-foreground": "#282A36", + }, + ), + "tokyo-night": Theme( + name="tokyo-night", + primary="#BB9AF7", + secondary="#7AA2F7", + warning="#E0AF68", # Yellow + error="#F7768E", # Red + success="#9ECE6A", # Green + accent="#FF9E64", # Orange + foreground="#a9b1d6", + background="#1A1B26", # Background + surface="#24283B", # Surface + panel="#414868", # Panel + variables={ + "button-color-foreground": "#24283B", + }, + ), + "monokai": Theme( + name="monokai", + primary="#AE81FF", + secondary="#F92672", + accent="#66D9EF", + warning="#FD971F", + error="#F92672", + success="#A6E22E", + foreground="#d6d6d6", + background="#272822", + surface="#2e2e2e", + panel="#3E3D32", + variables={ + "foreground-muted": "#797979", + "input-selection-background": "#575b6190", + "button-color-foreground": "#272822", + }, + ), + "flexoki": Theme( + name="flexoki", + primary="#205EA6", # blue + secondary="#24837B", # cyan + warning="#AD8301", # yellow + error="#AF3029", # red + success="#66800B", # green + accent="#9B76C8", # purple light + background="#100F0F", # base.black + surface="#1C1B1A", # base.950 + panel="#282726", # base.900 + foreground="#FFFCF0", # base.paper + variables={ + "input-cursor-foreground": "#5E409D", + "input-cursor-background": "#FFFCF0", + "input-selection-background": "#6F6E69 35%", # base.600 with opacity + "button-color-foreground": "#FFFCF0", + }, + ), + "catppuccin-latte": Theme( + name="catppuccin-latte", + secondary="#DC8A78", + primary="#8839EF", + warning="#DF8E1D", + error="#D20F39", + success="#40A02B", + accent="#FE640B", + foreground="#4C4F69", + background="#EFF1F5", + surface="#E6E9EF", + panel="#CCD0DA", + dark=False, + variables={ + "button-color-foreground": "#EFF1F5", + }, + ), + "catppuccin-frappe": Theme( + name="catppuccin-frappe", + primary="#CA9EE6", + secondary="#EF9F76", + warning="#E5C890", + error="#E78284", + success="#A6D189", + accent="#F4B8E4", + foreground="#C6D0F5", + background="#303446", + surface="#414559", + panel="#51576D", + dark=True, + variables={ + "input-cursor-foreground": "#232634", + "input-cursor-background": "#F2D5CF", + "input-selection-background": "#949CBB 30%", + "border": "#BABBF1", + "border-blurred": "#838BA7", + "footer-background": "#51576D", + "block-cursor-foreground": "#292C3C", + "block-cursor-text-style": "none", + "button-color-foreground": "#303446", + }, + ), + "catppuccin-macchiato": Theme( + name="catppuccin-macchiato", + primary="#C6A0F6", + secondary="#F5A97F", + warning="#EED49F", + error="#ED8796", + success="#A6DA95", + accent="#F5BDE6", + foreground="#CAD3F5", + background="#24273A", + surface="#363A4F", + panel="#494D64", + dark=True, + variables={ + "input-cursor-foreground": "#181926", + "input-cursor-background": "#F4DBD6", + "input-selection-background": "#838BA7 30%", + "border": "#B7BDF8", + "border-blurred": "#737994", + "footer-background": "#494D64", + "block-cursor-foreground": "#1E2030", + "block-cursor-text-style": "none", + "button-color-foreground": "#24273A", + }, + ), + "solarized-light": Theme( + name="solarized-light", + primary="#268bd2", + secondary="#2aa198", + warning="#cb4b16", + error="#dc322f", + success="#859900", + accent="#6c71c4", + foreground="#586e75", + background="#fdf6e3", + surface="#eee8d5", + panel="#eee8d5", + dark=False, + variables={ + "button-color-foreground": "#fdf6e3", + "footer-background": "#268bd2", + "footer-key-foreground": "#fdf6e3", + "footer-description-foreground": "#fdf6e3", + }, + ), + "solarized-dark": Theme( + name="solarized-dark", + primary="#268bd2", + secondary="#2aa198", + warning="#cb4b16", + error="#dc322f", + success="#859900", + accent="#6c71c4", + background="#002b36", + surface="#073642", + panel="#073642", + foreground="#839496", + dark=True, + variables={ + "button-color-foreground": "#fdf6e3", + "footer-background": "#268bd2", + "footer-key-foreground": "#fdf6e3", + "footer-description-foreground": "#fdf6e3", + "input-selection-background": "#073642", # Base02 + }, + ), + "rose-pine": Theme( + name="rose-pine", + primary="#c4a7e7", + secondary="#31748f", + warning="#f6c177", + error="#eb6f92", + success="#9ccfd8", + accent="#ebbcba", + foreground="#e0def4", + background="#191724", + surface="#1f1d2e", + panel="#26233a", + dark=True, + variables={ + "input-cursor-background": "#f4ede8", + "input-selection-background": "#403d52", + "border": "#524f67", + "border-blurred": "#6e6a86", + "footer-background": "#26233a", + "block-cursor-foreground": "#191724", + "block-cursor-text-style": "none", + "block-cursor-background": "#c4a7e7", + }, + ), + "rose-pine-moon": Theme( + name="rose-pine-moon", + primary="#c4a7e7", + secondary="#3e8fb0", + warning="#f6c177", + error="#eb6f92", + success="#9ccfd8", + accent="#ea9a97", + foreground="#e0def4", + background="#232136", + surface="#2a273f", + panel="#393552", + dark=True, + variables={ + "input-cursor-background": "#f4ede8", + "input-selection-background": "#44415a", + "border": "#56526e", + "border-blurred": "#6e6a86", + "footer-background": "#393552", + "block-cursor-foreground": "#232136", + "block-cursor-text-style": "none", + "block-cursor-background": "#c4a7e7", + }, + ), + "rose-pine-dawn": Theme( + name="rose-pine-dawn", + primary="#907aa9", + secondary="#286983", + warning="#ea9d34", + error="#b4637a", + success="#56949f", + accent="#d7827e", + foreground="#575279", + background="#faf4ed", + surface="#fffaf3", + panel="#f2e9e1", + dark=False, + variables={ + "input-cursor-background": "#575279", + "input-selection-background": "#dfdad9", + "border": "#cecacd", + "border-blurred": "#9893a5", + "footer-background": "#f2e9e1", + "block-cursor-foreground": "#faf4ed", + "block-cursor-text-style": "none", + "block-cursor-background": "#575279", + }, + ), + "atom-one-dark": Theme( + name="atom-one-dark", + primary="#61AFEF", + secondary="#C678DD", + warning="#DEB25B", + error="#F06262", + success="#62F062", + accent="#A378C2", + foreground="#ABB2BF", + background="#282C34", + surface="#3B414D", + panel="#4F5666", + boost=None, + dark=True, + luminosity_spread=0.15, + text_alpha=0.95, + ), + "atom-one-light": Theme( + name="atom-one-light", + primary="#4078F2", + secondary="#A626A4", + warning="#D8D938", + error="#F23F3F", + success="#6CF23F", + accent="#bf9232", + foreground="#383A42", + background="#FAFAFA", + surface="#E0E0E0", + panel="#CCCCCC", + boost=None, + dark=False, + luminosity_spread=0.15, + text_alpha=0.95, + ), +} + + +class ThemeProvider(Provider): + """A provider for themes.""" + + @property + def commands(self) -> list[tuple[str, Callable[[], None]]]: + themes = self.app.available_themes + + def set_app_theme(name: str) -> None: + self.app.theme = name + + return [ + (theme.name, partial(set_app_theme, theme.name)) + for theme in sorted(themes.values(), key=attrgetter("name")) + if theme.name != "textual-ansi" + ] + + async def discover(self) -> Hits: + for command in self.commands: + yield DiscoveryHit(*command) + + async def search(self, query: str) -> Hits: + matcher = self.matcher(query) + + for name, callback in self.commands: + if (match := matcher.match(name)) > 0: + yield Hit( + match, + matcher.highlight(name), + callback, + ) diff --git a/src/memray/_vendor/textual/timer.py b/src/memray/_vendor/textual/timer.py new file mode 100644 index 0000000000..b4c701514b --- /dev/null +++ b/src/memray/_vendor/textual/timer.py @@ -0,0 +1,203 @@ +""" + +Contains the `Timer` class. +Timer objects are created by [set_interval][textual.message_pump.MessagePump.set_interval] or + [set_timer][textual.message_pump.MessagePump.set_timer]. +""" + +from __future__ import annotations + +import weakref +from asyncio import CancelledError, Event, Task, create_task, gather +from typing import Any, Awaitable, Callable, Iterable, Union + +from rich.repr import Result, rich_repr + +from memray._vendor.textual import _time, events +from memray._vendor.textual._callback import invoke +from memray._vendor.textual._compat import cached_property +from memray._vendor.textual._context import active_app +from memray._vendor.textual._time import sleep +from memray._vendor.textual._types import MessageTarget + +TimerCallback = Union[Callable[[], Awaitable[Any]], Callable[[], Any]] +"""Type of valid callbacks to be used with timers.""" + + +class EventTargetGone(Exception): + """Raised if the timer event target has been deleted prior to the timer event being sent.""" + + +@rich_repr +class Timer: + """A class to send timer-based events. + + Args: + event_target: The object which will receive the timer events. + interval: The time between timer events, in seconds. + name: A name to assign the event (for debugging). + callback: An optional callback to invoke when the event is handled. + repeat: The number of times to repeat the timer, or None to repeat forever. + skip: Enable skipping of scheduled events that couldn't be sent in time. + pause: Start the timer paused. + """ + + _timer_count: int = 1 + + def __init__( + self, + event_target: MessageTarget, + interval: float, + *, + name: str | None = None, + callback: TimerCallback | None = None, + repeat: int | None = None, + skip: bool = True, + pause: bool = False, + ) -> None: + self._target_repr = repr(event_target) + self._target = weakref.ref(event_target) + self._interval = interval + self.name = f"Timer#{self._timer_count}" if name is None else name + self._timer_count += 1 + self._callback = callback + self._repeat = repeat + self._skip = skip + self._task: Task | None = None + self._reset: bool = False + self._original_pause = pause + + @cached_property + def _active(self) -> Event: + event = Event() + if not self._original_pause: + event.set() + return event + + def __rich_repr__(self) -> Result: + yield self._interval + yield "name", self.name + yield "repeat", self._repeat, None + + @property + def target(self) -> MessageTarget: + target = self._target() + if target is None: + raise EventTargetGone() + return target + + def _start(self) -> None: + """Start the timer.""" + self._task = create_task(self._run_timer(), name=self.name) + + def stop(self) -> None: + """Stop the timer.""" + if self._task is None: + return + + self._active.set() + self._task.cancel() + self._task = None + + @classmethod + async def _stop_all(cls, timers: Iterable[Timer]) -> None: + """Stop a number of timers, and await their completion. + + Args: + timers: A number of timers. + """ + + async def stop_timer(timer: Timer) -> None: + """Stop a timer and wait for it to finish. + + Args: + timer: A Timer instance. + """ + if timer._task is not None: + timer._active.set() + timer._task.cancel() + try: + await timer._task + except CancelledError: + pass + timer._task = None + + await gather(*[stop_timer(timer) for timer in list(timers)]) + + def pause(self) -> None: + """Pause the timer. + + A paused timer will not send events until it is resumed. + """ + self._active.clear() + + def reset(self) -> None: + """Reset the timer, so it starts from the beginning.""" + self._active.set() + self._reset = True + + def resume(self) -> None: + """Resume a paused timer.""" + self._active.set() + + async def _run_timer(self) -> None: + """Run the timer task.""" + try: + await self._run() + except CancelledError: + pass + + async def _run(self) -> None: + """Run the timer.""" + count = 0 + _repeat = self._repeat + _interval = self._interval + self._active # Force instantiation in same thread + await self._active.wait() + start = _time.get_time() + + while _repeat is None or count <= _repeat: + next_timer = start + ((count + 1) * _interval) + now = _time.get_time() + if self._skip and next_timer < now: + count = int((now - start) / _interval + 1) + continue + now = _time.get_time() + wait_time = max(0, next_timer - now) + await sleep(wait_time) + count += 1 + await self._active.wait() + if self._reset: + start = _time.get_time() + count = 0 + self._reset = False + continue + try: + await self._tick(next_timer=next_timer, count=count) + except EventTargetGone: + break + + async def _tick(self, *, next_timer: float, count: int) -> None: + """Triggers the Timer's action: either call its callback, or sends an event to its target""" + + app = active_app.get() + if app._exit: + return + + if self._callback is not None: + try: + await invoke(self._callback) + except CancelledError: + # https://github.com/Textualize/textual/pull/2895 + # Re-raise CancelledErrors that would be caught by the following exception block in Python 3.7 + raise + except Exception as error: + app._handle_exception(error) + else: + event = events.Timer( + timer=self, + time=next_timer, + count=count, + callback=self._callback, + ) + self.target.post_message(event) diff --git a/src/memray/_vendor/textual/tree-sitter/highlights/bash.scm b/src/memray/_vendor/textual/tree-sitter/highlights/bash.scm new file mode 100644 index 0000000000..f33a7c2d3a --- /dev/null +++ b/src/memray/_vendor/textual/tree-sitter/highlights/bash.scm @@ -0,0 +1,56 @@ +[ + (string) + (raw_string) + (heredoc_body) + (heredoc_start) +] @string + +(command_name) @function + +(variable_name) @property + +[ + "case" + "do" + "done" + "elif" + "else" + "esac" + "export" + "fi" + "for" + "function" + "if" + "in" + "select" + "then" + "unset" + "until" + "while" +] @keyword + +(comment) @comment + +(function_definition name: (word) @function) + +(file_descriptor) @number + +[ + (command_substitution) + (process_substitution) + (expansion) +]@embedded + +[ + "$" + "&&" + ">" + ">>" + "<" + "|" +] @operator + +( + (command (_) @constant) + (#match? @constant "^-") +) diff --git a/src/memray/_vendor/textual/tree-sitter/highlights/css.scm b/src/memray/_vendor/textual/tree-sitter/highlights/css.scm new file mode 100644 index 0000000000..61ad9c412e --- /dev/null +++ b/src/memray/_vendor/textual/tree-sitter/highlights/css.scm @@ -0,0 +1,91 @@ +[ + "@media" + "@charset" + "@namespace" + "@supports" + "@keyframes" + (at_keyword) + (to) + (from) + ] @keyword + +"@import" @include + +(comment) @comment @spell + +[ + (tag_name) + (nesting_selector) + (universal_selector) + ] @type + +(function_name) @function + +[ + "~" + ">" + "+" + "-" + "*" + "/" + "=" + "^=" + "|=" + "~=" + "$=" + "*=" + "and" + "or" + "not" + "only" + ] @operator + +(important) @type.qualifier + +(attribute_selector (plain_value) @string) +(pseudo_element_selector "::" (tag_name) @property) +(pseudo_class_selector (class_name) @property) + +[ + (class_name) + (id_name) + (property_name) + (feature_name) + (attribute_name) + ] @css.property + +(namespace_name) @namespace + +((property_name) @type.definition + (#match? @type.definition "^[-][-]")) +((plain_value) @type + (#match? @type "^[-][-]")) + +[ + (string_value) + (color_value) + (unit) + ] @string + +[ + (integer_value) + (float_value) + ] @number + +[ + "#" + "," + "." + ":" + "::" + ";" + ] @punctuation.delimiter + +[ + "{" + ")" + "(" + "}" + ] @punctuation.bracket + +(ERROR) @error diff --git a/src/memray/_vendor/textual/tree-sitter/highlights/go.scm b/src/memray/_vendor/textual/tree-sitter/highlights/go.scm new file mode 100644 index 0000000000..7e1d625272 --- /dev/null +++ b/src/memray/_vendor/textual/tree-sitter/highlights/go.scm @@ -0,0 +1,123 @@ +; Function calls + +(call_expression + function: (identifier) @function.builtin + (.match? @function.builtin "^(append|cap|close|complex|copy|delete|imag|len|make|new|panic|print|println|real|recover)$")) + +(call_expression + function: (identifier) @function) + +(call_expression + function: (selector_expression + field: (field_identifier) @function.method)) + +; Function definitions + +(function_declaration + name: (identifier) @function) + +(method_declaration + name: (field_identifier) @function.method) + +; Identifiers + +(type_identifier) @type +(field_identifier) @property +(identifier) @variable + +; Operators + +[ + "--" + "-" + "-=" + ":=" + "!" + "!=" + "..." + "*" + "*" + "*=" + "/" + "/=" + "&" + "&&" + "&=" + "%" + "%=" + "^" + "^=" + "+" + "++" + "+=" + "<-" + "<" + "<<" + "<<=" + "<=" + "=" + "==" + ">" + ">=" + ">>" + ">>=" + "|" + "|=" + "||" + "~" +] @operator + +; Keywords + +[ + "break" + "case" + "chan" + "const" + "continue" + "default" + "defer" + "else" + "fallthrough" + "for" + "func" + "go" + "goto" + "if" + "import" + "interface" + "map" + "package" + "range" + "return" + "select" + "struct" + "switch" + "type" + "var" +] @keyword + +; Literals + +[ + (interpreted_string_literal) + (raw_string_literal) + (rune_literal) +] @string + +(escape_sequence) @escape + +[ + (int_literal) + (float_literal) + (imaginary_literal) +] @number + +[ + (true) + (false) + (nil) + (iota) +] @constant.builtin + +(comment) @comment diff --git a/src/memray/_vendor/textual/tree-sitter/highlights/html.scm b/src/memray/_vendor/textual/tree-sitter/highlights/html.scm new file mode 100644 index 0000000000..41c83ce0d8 --- /dev/null +++ b/src/memray/_vendor/textual/tree-sitter/highlights/html.scm @@ -0,0 +1,25 @@ +(tag_name) @tag +(erroneous_end_tag_name) @html.end_tag_error +(comment) @comment +(attribute_name) @tag.attribute +(attribute + (quoted_attribute_value) @string) +(text) @text @spell + +((attribute + (attribute_name) @_attr + (quoted_attribute_value (attribute_value) @text.uri)) + (#any-of? @_attr "href" "src")) + +[ + "<" + ">" + "" +] @tag.delimiter + +"=" @operator + +(doctype) @constant + +"" + ">" + ">=" + ">>" + ">>=" + ">>>" + ">>>=" + "~" + "^" + "&" + "|" + "^=" + "&=" + "|=" + "&&" + "||" + "??" + "&&=" + "||=" + "??=" +] @operator + +[ + "(" + ")" + "[" + "]" + "{" + "}" +] @punctuation.bracket + +(template_substitution + "${" @punctuation.special + "}" @punctuation.special) @embedded + +[ + "as" + "async" + "await" + "break" + "case" + "catch" + "class" + "const" + "continue" + "debugger" + "default" + "delete" + "do" + "else" + "export" + "extends" + "finally" + "for" + "from" + "function" + "get" + "if" + "import" + "in" + "instanceof" + "let" + "new" + "of" + "return" + "set" + "static" + "switch" + "target" + "throw" + "try" + "typeof" + "var" + "void" + "while" + "with" + "yield" +] @keyword diff --git a/src/memray/_vendor/textual/tree-sitter/highlights/json.scm b/src/memray/_vendor/textual/tree-sitter/highlights/json.scm new file mode 100644 index 0000000000..c23e7b3ce9 --- /dev/null +++ b/src/memray/_vendor/textual/tree-sitter/highlights/json.scm @@ -0,0 +1,32 @@ +[ + (true) + (false) +] @boolean + +(null) @json.null + +(number) @number + +(pair key: (string) @json.label) +(pair value: (string) @string) + +(array (string) @string) + +(string_content) @spell + +(ERROR) @json.error + +["," ":"] @punctuation.delimiter + +[ + "[" "]" + "{" "}" +] @punctuation.bracket + +(("\"" @conceal) + (#set! conceal "")) + +(escape_sequence) @string.escape +((escape_sequence) @conceal + (#eq? @conceal "\\\"") + (#set! conceal "\"")) diff --git a/src/memray/_vendor/textual/tree-sitter/highlights/markdown.scm b/src/memray/_vendor/textual/tree-sitter/highlights/markdown.scm new file mode 100644 index 0000000000..563c5138f5 --- /dev/null +++ b/src/memray/_vendor/textual/tree-sitter/highlights/markdown.scm @@ -0,0 +1,51 @@ +(atx_heading (inline) @heading) +(setext_heading (paragraph) @heading) + +[ + (atx_h1_marker) + (atx_h2_marker) + (atx_h3_marker) + (atx_h4_marker) + (atx_h5_marker) + (atx_h6_marker) + (setext_h1_underline) + (setext_h2_underline) +] @heading.marker + +[ + (link_title) + (indented_code_block) + (fenced_code_block) +] @text.literal + +[ + (fenced_code_block_delimiter) +] @punctuation.delimiter + +(code_fence_content) @none + +[ + (link_destination) +] @link.uri + +[ + (link_label) +] @link.label + +[ + (list_marker_plus) + (list_marker_minus) + (list_marker_star) + (list_marker_dot) + (list_marker_parenthesis) + (thematic_break) +] @list.marker + +[ + (block_continuation) + (block_quote_marker) +] @punctuation.special + +[ + (backslash_escape) +] @string.escape diff --git a/src/memray/_vendor/textual/tree-sitter/highlights/python.scm b/src/memray/_vendor/textual/tree-sitter/highlights/python.scm new file mode 100644 index 0000000000..3e7381b813 --- /dev/null +++ b/src/memray/_vendor/textual/tree-sitter/highlights/python.scm @@ -0,0 +1,313 @@ +;; From tree-sitter-python licensed under MIT License +; Copyright (c) 2016 Max Brunsfeld +; Adapted for Textual from: +; https://github.com/nvim-treesitter/nvim-treesitter/blob/f95ffd09ed35880c3a46ad2b968df361fa592a76/queries/python/highlights.scm + +; Variables +(identifier) @variable + +; Reset highlighting in f-string interpolations +(interpolation) @none + +;; Identifier naming conventions +((identifier) @type + (#match? @type "^[A-Z].*[a-z]")) +((identifier) @constant + (#match? @constant "^[A-Z][A-Z_0-9]*$")) + +((attribute + attribute: (identifier) @field) + (#match? @field "^([A-Z])@!.*$")) + +((identifier) @type.builtin + (#any-of? @type.builtin + ;; https://docs.python.org/3/library/exceptions.html + "BaseException" "Exception" "ArithmeticError" "BufferError" "LookupError" "AssertionError" "AttributeError" + "EOFError" "FloatingPointError" "GeneratorExit" "ImportError" "ModuleNotFoundError" "IndexError" "KeyError" + "KeyboardInterrupt" "MemoryError" "NameError" "NotImplementedError" "OSError" "OverflowError" "RecursionError" + "ReferenceError" "RuntimeError" "StopIteration" "StopAsyncIteration" "SyntaxError" "IndentationError" "TabError" + "SystemError" "SystemExit" "TypeError" "UnboundLocalError" "UnicodeError" "UnicodeEncodeError" "UnicodeDecodeError" + "UnicodeTranslateError" "ValueError" "ZeroDivisionError" "EnvironmentError" "IOError" "WindowsError" + "BlockingIOError" "ChildProcessError" "ConnectionError" "BrokenPipeError" "ConnectionAbortedError" + "ConnectionRefusedError" "ConnectionResetError" "FileExistsError" "FileNotFoundError" "InterruptedError" + "IsADirectoryError" "NotADirectoryError" "PermissionError" "ProcessLookupError" "TimeoutError" "Warning" + "UserWarning" "DeprecationWarning" "PendingDeprecationWarning" "SyntaxWarning" "RuntimeWarning" + "FutureWarning" "ImportWarning" "UnicodeWarning" "BytesWarning" "ResourceWarning" + ;; https://docs.python.org/3/library/stdtypes.html + "bool" "int" "float" "complex" "list" "tuple" "range" "str" + "bytes" "bytearray" "memoryview" "set" "frozenset" "dict" "type")) + +((assignment + left: (identifier) @type.definition + (type (identifier) @_annotation)) + (#eq? @_annotation "TypeAlias")) + +((assignment + left: (identifier) @type.definition + right: (call + function: (identifier) @_func)) + (#any-of? @_func "TypeVar" "NewType")) + +; Function calls + +(call + function: (identifier) @function.call) + +(call + function: (attribute + attribute: (identifier) @method.call)) + +((call + function: (identifier) @constructor) + (#match? @constructor "^[A-Z]")) + +((call + function: (attribute + attribute: (identifier) @constructor)) + (#match? @constructor "^[A-Z]")) + +;; Decorators + +((decorator "@" @attribute) + (#set! "priority" 101)) + +(decorator + (identifier) @attribute) +(decorator + (attribute + attribute: (identifier) @attribute)) +(decorator + (call (identifier) @attribute)) +(decorator + (call (attribute + attribute: (identifier) @attribute))) + +((decorator + (identifier) @attribute.builtin) + (#any-of? @attribute.builtin "classmethod" "property")) + +;; Builtin functions + +((call + function: (identifier) @function.builtin) + (#any-of? @function.builtin + "abs" "all" "any" "ascii" "bin" "bool" "breakpoint" "bytearray" "bytes" "callable" "chr" "classmethod" + "compile" "complex" "delattr" "dict" "dir" "divmod" "enumerate" "eval" "exec" "filter" "float" "format" + "frozenset" "getattr" "globals" "hasattr" "hash" "help" "hex" "id" "input" "int" "isinstance" "issubclass" + "iter" "len" "list" "locals" "map" "max" "memoryview" "min" "next" "object" "oct" "open" "ord" "pow" + "print" "property" "range" "repr" "reversed" "round" "set" "setattr" "slice" "sorted" "staticmethod" "str" + "sum" "super" "tuple" "type" "vars" "zip" "__import__")) + +;; Function definitions + +(function_definition + name: (identifier) @function) + +(type (identifier) @type) +(type + (subscript + (identifier) @type)) ; type subscript: Tuple[int] + +((call + function: (identifier) @_isinstance + arguments: (argument_list + (_) + (identifier) @type)) + (#eq? @_isinstance "isinstance")) + +;; Normal parameters +(parameters + (identifier) @parameter) +;; Lambda parameters +(lambda_parameters + (identifier) @parameter) +(lambda_parameters + (tuple_pattern + (identifier) @parameter)) +; Default parameters +(keyword_argument + name: (identifier) @parameter) +; Naming parameters on call-site +(default_parameter + name: (identifier) @parameter) +(typed_parameter + (identifier) @parameter) +(typed_default_parameter + (identifier) @parameter) +; Variadic parameters *args, **kwargs +(parameters + (list_splat_pattern ; *args + (identifier) @parameter)) +(parameters + (dictionary_splat_pattern ; **kwargs + (identifier) @parameter)) + + +;; Literals + +(none) @constant.builtin +[(true) (false)] @boolean +((identifier) @variable.builtin + (#eq? @variable.builtin "self")) +((identifier) @variable.builtin + (#eq? @variable.builtin "cls")) + +(integer) @number +(float) @float + +(comment) @comment @spell + +((module . (comment) @preproc) + (#match? @preproc "^#!/")) + +(string) @string +(escape_sequence) @string.escape + +; doc-strings +(expression_statement (string) @spell) + +; Tokens + +[ + "-" + "-=" + ":=" + "!=" + "*" + "**" + "**=" + "*=" + "/" + "//" + "//=" + "/=" + "&" + "&=" + "%" + "%=" + "^" + "^=" + "+" + "+=" + "<" + "<<" + "<<=" + "<=" + "<>" + "=" + "==" + ">" + ">=" + ">>" + ">>=" + "@" + "@=" + "|" + "|=" + "~" + "->" +] @operator + +; Keywords +[ + "and" + "in" + "is" + "not" + "or" + "del" +] @keyword.operator + +[ + "def" + "lambda" +] @keyword.function + +[ + "assert" + "async" + "await" + "class" + "exec" + "global" + "nonlocal" + "pass" + "print" + "with" + "as" +] @keyword + +[ + "return" + "yield" +] @keyword.return +(yield "from" @keyword.return) + +(future_import_statement + "from" @include + "__future__" @constant.builtin) +(import_from_statement "from" @include) +"import" @include + +(aliased_import "as" @include) + +["if" "elif" "else" "match" "case"] @conditional + +["for" "while" "break" "continue"] @repeat + +[ + "try" + "except" + "raise" + "finally" +] @exception + +(raise_statement "from" @exception) + +(try_statement + (else_clause + "else" @exception)) + +["(" ")" "[" "]" "{" "}"] @punctuation.bracket + +(interpolation + "{" @punctuation.special + "}" @punctuation.special) + +["," "." ":" ";" (ellipsis)] @punctuation.delimiter + +;; Class definitions + +(class_definition name: (identifier) @type.class) + +(class_definition + body: (block + (function_definition + name: (identifier) @method))) + +(class_definition + superclasses: (argument_list + (identifier) @type)) + +((class_definition + body: (block + (expression_statement + (assignment + left: (identifier) @field)))) + (#match? @field "^([A-Z])@!.*$")) +((class_definition + body: (block + (expression_statement + (assignment + left: (_ + (identifier) @field))))) + (#match? @field "^([A-Z])@!.*$")) + +((class_definition + (block + (function_definition + name: (identifier) @constructor))) + (#any-of? @constructor "__new__" "__init__")) + +;; Error +(ERROR) @error diff --git a/src/memray/_vendor/textual/tree-sitter/highlights/regex.scm b/src/memray/_vendor/textual/tree-sitter/highlights/regex.scm new file mode 100644 index 0000000000..8b653465b4 --- /dev/null +++ b/src/memray/_vendor/textual/tree-sitter/highlights/regex.scm @@ -0,0 +1,50 @@ +[ + "(" + ")" + "(?" + "(?:" + "(?<" + ">" + "[" + "]" + "{" + "}" +] @punctuation.bracket + +(group_name) @property + +[ + (identity_escape) + (control_letter_escape) + (character_class_escape) + (control_escape) + (start_assertion) + (end_assertion) + (boundary_assertion) + (non_boundary_assertion) +] @escape + +[ + "*" + "+" + "?" + "|" + "=" + "!" +] @operator + +(count_quantifier + [ + (decimal_digits) @number + "," @punctuation.delimiter + ]) + +(character_class + [ + "^" @operator + (class_range "-" @operator) + ]) + +(class_character) @constant.character + +(pattern_character) @string diff --git a/src/memray/_vendor/textual/tree-sitter/highlights/rust.scm b/src/memray/_vendor/textual/tree-sitter/highlights/rust.scm new file mode 100644 index 0000000000..c1556847b3 --- /dev/null +++ b/src/memray/_vendor/textual/tree-sitter/highlights/rust.scm @@ -0,0 +1,155 @@ +; Identifier conventions + +; Assume all-caps names are constants +((identifier) @constant + (#match? @constant "^[A-Z][A-Z\\d_]+$'")) + +; Assume that uppercase names in paths are types +((scoped_identifier + path: (identifier) @type) + (#match? @type "^[A-Z]")) +((scoped_identifier + path: (scoped_identifier + name: (identifier) @type)) + (#match? @type "^[A-Z]")) +((scoped_type_identifier + path: (identifier) @type) + (#match? @type "^[A-Z]")) +((scoped_type_identifier + path: (scoped_identifier + name: (identifier) @type)) + (#match? @type "^[A-Z]")) + +; Assume other uppercase names are enum constructors +((identifier) @constructor + (#match? @constructor "^[A-Z]")) + +; Assume all qualified names in struct patterns are enum constructors. (They're +; either that, or struct names; highlighting both as constructors seems to be +; the less glaring choice of error, visually.) +(struct_pattern + type: (scoped_type_identifier + name: (type_identifier) @constructor)) + +; Function calls + +(call_expression + function: (identifier) @function) +(call_expression + function: (field_expression + field: (field_identifier) @function.method)) +(call_expression + function: (scoped_identifier + "::" + name: (identifier) @function)) + +(generic_function + function: (identifier) @function) +(generic_function + function: (scoped_identifier + name: (identifier) @function)) +(generic_function + function: (field_expression + field: (field_identifier) @function.method)) + +(macro_invocation + macro: (identifier) @function.macro + "!" @function.macro) + +; Function definitions + +(function_item (identifier) @function) +(function_signature_item (identifier) @function) + +; Other identifiers + +(type_identifier) @type +(primitive_type) @type.builtin +(field_identifier) @property + +(line_comment) @comment +(block_comment) @comment + +"(" @punctuation.bracket +")" @punctuation.bracket +"[" @punctuation.bracket +"]" @punctuation.bracket +"{" @punctuation.bracket +"}" @punctuation.bracket + +(type_arguments + "<" @punctuation.bracket + ">" @punctuation.bracket) +(type_parameters + "<" @punctuation.bracket + ">" @punctuation.bracket) + +"::" @punctuation.delimiter +":" @punctuation.delimiter +"." @punctuation.delimiter +"," @punctuation.delimiter +";" @punctuation.delimiter + +(parameter (identifier) @variable.parameter) + +(lifetime (identifier) @label) + +"as" @keyword +"async" @keyword +"await" @keyword +"break" @keyword +"const" @keyword +"continue" @keyword +"default" @keyword +"dyn" @keyword +"else" @keyword +"enum" @keyword +"extern" @keyword +"fn" @keyword +"for" @keyword +"if" @keyword +"impl" @keyword +"in" @keyword +"let" @keyword +"loop" @keyword +"macro_rules!" @keyword +"match" @keyword +"mod" @keyword +"move" @keyword +"pub" @keyword +"ref" @keyword +"return" @keyword +"static" @keyword +"struct" @keyword +"trait" @keyword +"type" @keyword +"union" @keyword +"unsafe" @keyword +"use" @keyword +"where" @keyword +"while" @keyword +(crate) @keyword +(mutable_specifier) @keyword +(use_list (self) @keyword) +(scoped_use_list (self) @keyword) +(scoped_identifier (self) @keyword) +(super) @keyword + +(self) @variable.builtin + +(char_literal) @string +(string_literal) @string +(raw_string_literal) @string + +(boolean_literal) @constant.builtin +(integer_literal) @constant.builtin +(float_literal) @constant.builtin + +(escape_sequence) @escape + +(attribute_item) @attribute +(inner_attribute_item) @attribute + +"*" @operator +"&" @operator +"'" @operator diff --git a/src/memray/_vendor/textual/tree-sitter/highlights/sql.scm b/src/memray/_vendor/textual/tree-sitter/highlights/sql.scm new file mode 100644 index 0000000000..fe4913026c --- /dev/null +++ b/src/memray/_vendor/textual/tree-sitter/highlights/sql.scm @@ -0,0 +1,444 @@ +(object_reference + name: (identifier) @type) + +(invocation + (object_reference + name: (identifier) @function.call)) + +[ + (keyword_gist) + (keyword_btree) + (keyword_hash) + (keyword_spgist) + (keyword_gin) + (keyword_brin) + (keyword_array) +] @function.call + +(relation + alias: (identifier) @variable) + +(field + name: (identifier) @field) + +(term + alias: (identifier) @variable) + +((term + value: (cast + name: (keyword_cast) @function.call + parameter: [(literal)]?))) + +(comment) @comment @spell +(marginalia) @comment + +((literal) @number + (#match? @number "^[-+]?%d+$")) + +((literal) @float + (#match? @float "^[-+]?%d*\.%d*$")) + +(literal) @string + +(parameter) @parameter + +[ + (keyword_true) + (keyword_false) +] @boolean + +[ + (keyword_asc) + (keyword_desc) + (keyword_terminated) + (keyword_escaped) + (keyword_unsigned) + (keyword_nulls) + (keyword_last) + (keyword_delimited) + (keyword_replication) + (keyword_auto_increment) + (keyword_default) + (keyword_collate) + (keyword_concurrently) + (keyword_engine) + (keyword_always) + (keyword_generated) + (keyword_preceding) + (keyword_following) + (keyword_first) + (keyword_current_timestamp) + (keyword_immutable) + (keyword_atomic) + (keyword_parallel) + (keyword_leakproof) + (keyword_safe) + (keyword_cost) + (keyword_strict) +] @attribute + +[ + (keyword_materialized) + (keyword_recursive) + (keyword_temp) + (keyword_temporary) + (keyword_unlogged) + (keyword_external) + (keyword_parquet) + (keyword_csv) + (keyword_rcfile) + (keyword_textfile) + (keyword_orc) + (keyword_avro) + (keyword_jsonfile) + (keyword_sequencefile) + (keyword_volatile) +] @storageclass + +[ + (keyword_case) + (keyword_when) + (keyword_then) + (keyword_else) +] @conditional + +[ + (keyword_select) + (keyword_from) + (keyword_where) + (keyword_index) + (keyword_join) + (keyword_primary) + (keyword_delete) + (keyword_create) + (keyword_show) + (keyword_insert) + (keyword_merge) + (keyword_distinct) + (keyword_replace) + (keyword_update) + (keyword_into) + (keyword_overwrite) + (keyword_matched) + (keyword_values) + (keyword_value) + (keyword_attribute) + (keyword_set) + (keyword_left) + (keyword_right) + (keyword_outer) + (keyword_inner) + (keyword_full) + (keyword_order) + (keyword_partition) + (keyword_group) + (keyword_with) + (keyword_without) + (keyword_as) + (keyword_having) + (keyword_limit) + (keyword_offset) + (keyword_table) + (keyword_tables) + (keyword_key) + (keyword_references) + (keyword_foreign) + (keyword_constraint) + (keyword_force) + (keyword_use) + (keyword_for) + (keyword_if) + (keyword_exists) + (keyword_column) + (keyword_columns) + (keyword_cross) + (keyword_lateral) + (keyword_natural) + (keyword_alter) + (keyword_drop) + (keyword_add) + (keyword_view) + (keyword_end) + (keyword_is) + (keyword_using) + (keyword_between) + (keyword_window) + (keyword_no) + (keyword_data) + (keyword_type) + (keyword_rename) + (keyword_to) + (keyword_schema) + (keyword_owner) + (keyword_authorization) + (keyword_all) + (keyword_any) + (keyword_some) + (keyword_returning) + (keyword_begin) + (keyword_commit) + (keyword_rollback) + (keyword_transaction) + (keyword_only) + (keyword_like) + (keyword_similar) + (keyword_over) + (keyword_change) + (keyword_modify) + (keyword_after) + (keyword_before) + (keyword_range) + (keyword_rows) + (keyword_groups) + (keyword_exclude) + (keyword_current) + (keyword_ties) + (keyword_others) + (keyword_zerofill) + (keyword_format) + (keyword_fields) + (keyword_row) + (keyword_sort) + (keyword_compute) + (keyword_comment) + (keyword_location) + (keyword_cached) + (keyword_uncached) + (keyword_lines) + (keyword_stored) + (keyword_virtual) + (keyword_partitioned) + (keyword_analyze) + (keyword_explain) + (keyword_verbose) + (keyword_truncate) + (keyword_rewrite) + (keyword_optimize) + (keyword_vacuum) + (keyword_cache) + (keyword_language) + (keyword_called) + (keyword_conflict) + (keyword_declare) + (keyword_filter) + (keyword_function) + (keyword_input) + (keyword_name) + (keyword_oid) + (keyword_oids) + (keyword_precision) + (keyword_regclass) + (keyword_regnamespace) + (keyword_regproc) + (keyword_regtype) + (keyword_restricted) + (keyword_return) + (keyword_returns) + (keyword_separator) + (keyword_setof) + (keyword_stable) + (keyword_support) + (keyword_tblproperties) + (keyword_trigger) + (keyword_unsafe) + (keyword_admin) + (keyword_connection) + (keyword_cycle) + (keyword_database) + (keyword_encrypted) + (keyword_increment) + (keyword_logged) + (keyword_none) + (keyword_owned) + (keyword_password) + (keyword_reset) + (keyword_role) + (keyword_sequence) + (keyword_start) + (keyword_restart) + (keyword_tablespace) + (keyword_until) + (keyword_user) + (keyword_valid) + (keyword_action) + (keyword_definer) + (keyword_invoker) + (keyword_security) + (keyword_extension) + (keyword_version) + (keyword_out) + (keyword_inout) + (keyword_variadic) + (keyword_ordinality) + (keyword_session) + (keyword_isolation) + (keyword_level) + (keyword_serializable) + (keyword_repeatable) + (keyword_read) + (keyword_write) + (keyword_committed) + (keyword_uncommitted) + (keyword_deferrable) + (keyword_names) + (keyword_zone) + (keyword_immediate) + (keyword_deferred) + (keyword_constraints) + (keyword_snapshot) + (keyword_characteristics) + (keyword_off) + (keyword_follows) + (keyword_precedes) + (keyword_each) + (keyword_instead) + (keyword_of) + (keyword_initially) + (keyword_old) + (keyword_new) + (keyword_referencing) + (keyword_statement) + (keyword_execute) + (keyword_procedure) + (keyword_copy) + (keyword_delimiter) + (keyword_encoding) + (keyword_escape) + (keyword_force_not_null) + (keyword_force_null) + (keyword_force_quote) + (keyword_freeze) + (keyword_header) + (keyword_match) + (keyword_program) + (keyword_quote) + (keyword_stdin) + (keyword_extended) + (keyword_main) + (keyword_plain) + (keyword_storage) + (keyword_compression) + (keyword_duplicate) +] @keyword + +[ + (keyword_restrict) + (keyword_unbounded) + (keyword_unique) + (keyword_cascade) + (keyword_delayed) + (keyword_high_priority) + (keyword_low_priority) + (keyword_ignore) + (keyword_nothing) + (keyword_check) + (keyword_option) + (keyword_local) + (keyword_cascaded) + (keyword_wait) + (keyword_nowait) + (keyword_metadata) + (keyword_incremental) + (keyword_bin_pack) + (keyword_noscan) + (keyword_stats) + (keyword_statistics) + (keyword_maxvalue) + (keyword_minvalue) +] @type.qualifier + +[ + (keyword_int) + (keyword_null) + (keyword_boolean) + (keyword_binary) + (keyword_varbinary) + (keyword_image) + (keyword_bit) + (keyword_inet) + (keyword_character) + (keyword_smallserial) + (keyword_serial) + (keyword_bigserial) + (keyword_smallint) + (keyword_mediumint) + (keyword_bigint) + (keyword_tinyint) + (keyword_decimal) + (keyword_float) + (keyword_double) + (keyword_numeric) + (keyword_real) + (double) + (keyword_money) + (keyword_smallmoney) + (keyword_char) + (keyword_nchar) + (keyword_varchar) + (keyword_nvarchar) + (keyword_varying) + (keyword_text) + (keyword_string) + (keyword_uuid) + (keyword_json) + (keyword_jsonb) + (keyword_xml) + (keyword_bytea) + (keyword_enum) + (keyword_date) + (keyword_datetime) + (keyword_time) + (keyword_datetime2) + (keyword_datetimeoffset) + (keyword_smalldatetime) + (keyword_timestamp) + (keyword_timestamptz) + (keyword_geometry) + (keyword_geography) + (keyword_box2d) + (keyword_box3d) + (keyword_interval) +] @type.builtin + +[ + (keyword_in) + (keyword_and) + (keyword_or) + (keyword_not) + (keyword_by) + (keyword_on) + (keyword_do) + (keyword_union) + (keyword_except) + (keyword_intersect) +] @keyword.operator + +[ + "+" + "-" + "*" + "/" + "%" + "^" + ":=" + "=" + "<" + "<=" + "!=" + ">=" + ">" + "<>" + (op_other) + (op_unary_other) +] @operator + +[ + "(" + ")" +] @punctuation.bracket + +[ + ";" + "," + "." +] @punctuation.delimiter diff --git a/src/memray/_vendor/textual/tree-sitter/highlights/toml.scm b/src/memray/_vendor/textual/tree-sitter/highlights/toml.scm new file mode 100644 index 0000000000..58ba6eb2d3 --- /dev/null +++ b/src/memray/_vendor/textual/tree-sitter/highlights/toml.scm @@ -0,0 +1,36 @@ +; Properties +;----------- + +(bare_key) @toml.type +(quoted_key) @string +(pair (bare_key)) @property + +; Literals +;--------- + +(boolean) @boolean +(comment) @comment +(string) @string +(integer) @number +(float) @float +(offset_date_time) @toml.datetime +(local_date_time) @toml.datetime +(local_date) @toml.datetime +(local_time) @toml.datetime + +; Punctuation +;------------ + +"." @punctuation.delimiter +"," @punctuation.delimiter + +"=" @operator + +"[" @punctuation.bracket +"]" @punctuation.bracket +"[[" @punctuation.bracket +"]]" @punctuation.bracket +"{" @punctuation.bracket +"}" @punctuation.bracket + +(ERROR) @toml.error diff --git a/src/memray/_vendor/textual/tree-sitter/highlights/xml.scm b/src/memray/_vendor/textual/tree-sitter/highlights/xml.scm new file mode 100644 index 0000000000..9861eea178 --- /dev/null +++ b/src/memray/_vendor/textual/tree-sitter/highlights/xml.scm @@ -0,0 +1,168 @@ +;; XML declaration + +"xml" @keyword + +[ "version" "encoding" "standalone" ] @property + +(EncName) @string.special + +(VersionNum) @number + +[ "yes" "no" ] @boolean + +;; Processing instructions + +(PI) @embedded + +(PI (PITarget) @keyword) + +;; Element declaration + +(elementdecl + "ELEMENT" @keyword + (Name) @tag) + +(contentspec + (_ (Name) @property)) + +"#PCDATA" @type.builtin + +[ "EMPTY" "ANY" ] @string.special.symbol + +[ "*" "?" "+" ] @operator + +;; Entity declaration + +(GEDecl + "ENTITY" @keyword + (Name) @constant) + +(GEDecl (EntityValue) @string) + +(NDataDecl + "NDATA" @keyword + (Name) @label) + +;; Parsed entity declaration + +(PEDecl + "ENTITY" @keyword + "%" @operator + (Name) @constant) + +(PEDecl (EntityValue) @string) + +;; Notation declaration + +(NotationDecl + "NOTATION" @keyword + (Name) @constant) + +(NotationDecl + (ExternalID + (SystemLiteral (URI) @string.special))) + +;; Attlist declaration + +(AttlistDecl + "ATTLIST" @keyword + (Name) @tag) + +(AttDef (Name) @property) + +(AttDef (Enumeration (Nmtoken) @string)) + +(DefaultDecl (AttValue) @string) + +[ + (StringType) + (TokenizedType) +] @type.builtin + +(NotationType "NOTATION" @type.builtin) + +[ + "#REQUIRED" + "#IMPLIED" + "#FIXED" +] @attribute + +;; Entities + +(EntityRef) @constant + +((EntityRef) @constant.builtin + (#any-of? @constant.builtin + "&" "<" ">" """ "'")) + +(CharRef) @constant + +(PEReference) @constant + +;; External references + +[ "PUBLIC" "SYSTEM" ] @keyword + +(PubidLiteral) @string.special + +(SystemLiteral (URI) @markup.link) + +;; Processing instructions + +(XmlModelPI "xml-model" @keyword) + +(StyleSheetPI "xml-stylesheet" @keyword) + +(PseudoAtt (Name) @property) + +(PseudoAtt (PseudoAttValue) @string) + +;; Doctype declaration + +(doctypedecl "DOCTYPE" @keyword) + +(doctypedecl (Name) @type) + +;; Tags + +(STag (Name) @tag) + +(ETag (Name) @tag) + +(EmptyElemTag (Name) @tag) + +;; Attributes + +(Attribute (Name) @property) + +(Attribute (AttValue) @string) + +;; Delimiters & punctuation + +[ + "" + "" + "<" ">" + "" +] @punctuation.delimiter + +[ "(" ")" "[" "]" ] @punctuation.bracket + +[ "\"" "'" ] @punctuation.delimiter + +[ "," "|" "=" ] @operator + +;; Text + +(CharData) @markup + +(CDSect + (CDStart) @markup.heading + (CData) @markup.raw + "]]>" @markup.heading) + +;; Misc + +(Comment) @comment + +(ERROR) @error diff --git a/src/memray/_vendor/textual/tree-sitter/highlights/yaml.scm b/src/memray/_vendor/textual/tree-sitter/highlights/yaml.scm new file mode 100644 index 0000000000..a57f464dfc --- /dev/null +++ b/src/memray/_vendor/textual/tree-sitter/highlights/yaml.scm @@ -0,0 +1,53 @@ +(boolean_scalar) @boolean +(null_scalar) @constant.builtin +(double_quote_scalar) @string +(single_quote_scalar) @string +((block_scalar) @string (#set! "priority" 99)) +(string_scalar) @string +(escape_sequence) @string.escape +(integer_scalar) @number +(float_scalar) @number +(comment) @comment +(anchor_name) @type +(alias_name) @type +(tag) @type +(ERROR) @error + +[ + (yaml_directive) + (tag_directive) + (reserved_directive) +] @preproc + +(block_mapping_pair + key: (flow_node [(double_quote_scalar) (single_quote_scalar)] @yaml.field)) +(block_mapping_pair + key: (flow_node (plain_scalar (string_scalar) @yaml.field))) + +(flow_mapping + (_ key: (flow_node [(double_quote_scalar) (single_quote_scalar)] @yaml.field))) +(flow_mapping + (_ key: (flow_node (plain_scalar (string_scalar) @yaml.field)))) + +[ + "," + "-" + ":" + ">" + "?" + "|" +] @punctuation.delimiter + +[ + "[" + "]" + "{" + "}" +] @punctuation.bracket + +[ + "*" + "&" + "---" + "..." +] @punctuation.special diff --git a/src/memray/_vendor/textual/types.py b/src/memray/_vendor/textual/types.py new file mode 100644 index 0000000000..d6362624d0 --- /dev/null +++ b/src/memray/_vendor/textual/types.py @@ -0,0 +1,52 @@ +""" +Export some objects that are used by Textual and that help document other features. +""" + +from memray._vendor.textual._animator import Animatable, EasingFunction +from memray._vendor.textual._context import NoActiveAppError +from memray._vendor.textual._path import CSSPathError, CSSPathType +from memray._vendor.textual._types import ( + AnimationLevel, + CallbackType, + IgnoreReturnCallbackType, + MessageTarget, + UnusedParameter, + WatchCallbackType, +) +from memray._vendor.textual._widget_navigation import Direction +from memray._vendor.textual.actions import ActionParseResult +from memray._vendor.textual.css.styles import RenderStyles +from memray._vendor.textual.widgets._directory_tree import DirEntry +from memray._vendor.textual.widgets._input import InputValidationOn +from memray._vendor.textual.widgets._option_list import ( + DuplicateID, + OptionDoesNotExist, + OptionListContent, +) +from memray._vendor.textual.widgets._placeholder import PlaceholderVariant +from memray._vendor.textual.widgets._select import NoSelection, SelectType + +__all__ = [ + "ActionParseResult", + "Animatable", + "AnimationLevel", + "CallbackType", + "CSSPathError", + "CSSPathType", + "DirEntry", + "Direction", + "DuplicateID", + "EasingFunction", + "IgnoreReturnCallbackType", + "InputValidationOn", + "MessageTarget", + "NoActiveAppError", + "NoSelection", + "OptionDoesNotExist", + "OptionListContent", + "PlaceholderVariant", + "RenderStyles", + "SelectType", + "UnusedParameter", + "WatchCallbackType", +] diff --git a/src/memray/_vendor/textual/validation.py b/src/memray/_vendor/textual/validation.py new file mode 100644 index 0000000000..9ce79a9b83 --- /dev/null +++ b/src/memray/_vendor/textual/validation.py @@ -0,0 +1,519 @@ +""" + +This module provides a number of classes for validating input. + +See [Validating Input](/widgets/input/#validating-input) for details. + +""" + +from __future__ import annotations + +import math +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Callable, Pattern, Sequence +from urllib.parse import urlparse + +import rich.repr + + +@dataclass +class ValidationResult: + """The result of calling a `Validator.validate` method.""" + + failures: Sequence[Failure] = field(default_factory=list) + """A list of reasons why the value was invalid. Empty if valid=True""" + + @staticmethod + def merge(results: Sequence["ValidationResult"]) -> "ValidationResult": + """Merge multiple ValidationResult objects into one. + + Args: + results: List of ValidationResult objects to merge. + + Returns: + Merged ValidationResult object. + """ + is_valid = all(result.is_valid for result in results) + failures = [failure for result in results for failure in result.failures] + if is_valid: + return ValidationResult.success() + else: + return ValidationResult.failure(failures) + + @staticmethod + def success() -> ValidationResult: + """Construct a successful ValidationResult. + + Returns: + A successful ValidationResult. + """ + return ValidationResult() + + @staticmethod + def failure(failures: Sequence[Failure]) -> ValidationResult: + """Construct a failure ValidationResult. + + Args: + failures: The failures. + + Returns: + A failure ValidationResult. + """ + return ValidationResult(failures) + + @property + def failure_descriptions(self) -> list[str]: + """Utility for extracting failure descriptions as strings. + + Useful if you don't care about the additional metadata included in the `Failure` objects. + + Returns: + A list of the string descriptions explaining the failing validations. + """ + return [ + failure.description + for failure in self.failures + if failure.description is not None + ] + + @property + def is_valid(self) -> bool: + """True if the validation was successful.""" + return len(self.failures) == 0 + + +@dataclass +class Failure: + """Information about a validation failure.""" + + validator: Validator + """The Validator which produced the failure.""" + value: str | None = None + """The value which resulted in validation failing.""" + description: str | None = None + """An optional override for describing this failure. Takes precedence over any messages set in the Validator.""" + + def __post_init__(self) -> None: + # If a failure message isn't supplied, try to get it from the Validator. + if self.description is None: + if self.validator.failure_description is not None: + self.description = self.validator.failure_description + else: + self.description = self.validator.describe_failure(self) + + def __rich_repr__(self) -> rich.repr.Result: # pragma: no cover + yield self.value + yield self.validator + yield self.description + + +class Validator(ABC): + '''Base class for the validation of string values. + + Commonly used in conjunction with the `Input` widget, which accepts a + list of validators via its constructor. This validation framework can also be used to validate any 'stringly-typed' + values (for example raw command line input from `sys.args`). + + To implement your own `Validator`, subclass this class. + + Example: + ```python + def is_palindrome(value: str) -> bool: + """Check has string has the same code points left to right, as right to left.""" + return value == value[::-1] + + class Palindrome(Validator): + def validate(self, value: str) -> ValidationResult: + if is_palindrome(value): + return self.success() + else: + return self.failure("Not a palindrome!") + ``` + ''' + + def __init__(self, failure_description: str | None = None) -> None: + self.failure_description = failure_description + """A description of why the validation failed. + + The description (intended to be user-facing) to attached to the Failure if the validation fails. + This failure description is ultimately accessible at the time of validation failure via the `Input.Changed` + or `Input.Submitted` event, and you can access it on your message handler (a method called, for example, + `on_input_changed` or a method decorated with `@on(Input.Changed)`. + """ + + @abstractmethod + def validate(self, value: str) -> ValidationResult: + """Validate the value and return a ValidationResult describing the outcome of the validation. + + Implement this method when defining custom validators. + + Args: + value: The value to validate. + + Returns: + The result of the validation ([`self.success()`][textual.validation.Validator.success) or [`self.failure(...)`][textual.validation.Validator.failure]). + """ + + def describe_failure(self, failure: Failure) -> str | None: + """Return a string description of the Failure. + + Used to provide a more fine-grained description of the failure. A Validator could fail for multiple + reasons, so this method could be used to provide a different reason for different types of failure. + + !!! warning + + This method is only called if no other description has been supplied. If you supply a description + inside a call to `self.failure(description="...")`, or pass a description into the constructor of + the validator, those will take priority, and this method won't be called. + + Args: + failure: Information about why the validation failed. + + Returns: + A string description of the failure. + """ + return self.failure_description + + def success(self) -> ValidationResult: + """Shorthand for `ValidationResult(True)`. + + Return `self.success()` from [`validate()`][textual.validation.Validator.validate] to indicated that validation *succeeded*. + + Returns: + A ValidationResult indicating validation succeeded. + """ + return ValidationResult() + + def failure( + self, + description: str | None = None, + value: str | None = None, + failures: Failure | Sequence[Failure] | None = None, + ) -> ValidationResult: + """Shorthand for signaling validation failure. + + Return `self.failure(...)` from [`validate()`][textual.validation.Validator.validate] to indicated that validation *failed*. + + Args: + description: The failure description that will be used. When used in conjunction with the Input widget, + this is the description that will ultimately be available inside the handler for `Input.Changed`. If not + supplied, the `failure_description` from the `Validator` will be used. If that is not supplied either, + then the `describe_failure` method on `Validator` will be called. + value: The value that was considered invalid. This is optional, and only needs to be supplied if required + in your `Input.Changed` handler. + failures: The reasons the validator failed. If not supplied, a generic `Failure` will be included in the + ValidationResult returned from this function. + + Returns: + A ValidationResult representing failed validation, and containing the metadata supplied + to this function. + """ + if isinstance(failures, Failure): + failures = [failures] + + result = ValidationResult( + failures or [Failure(validator=self, value=value, description=description)], + ) + return result + + +class Regex(Validator): + """A validator that checks the value matches a regex (via `re.fullmatch`).""" + + def __init__( + self, + regex: str | Pattern[str], + flags: int | re.RegexFlag = 0, + failure_description: str | None = None, + ) -> None: + super().__init__(failure_description=failure_description) + self.regex = regex + """The regex which we'll validate is matched by the value.""" + self.flags = flags + """The flags to pass to `re.fullmatch`.""" + + class NoResults(Failure): + """Indicates validation failed because the regex could not be found within the value string.""" + + def validate(self, value: str) -> ValidationResult: + """Ensure that the value matches the regex. + + Args: + value: The value that should match the regex. + + Returns: + The result of the validation. + """ + regex = self.regex + has_match = re.fullmatch(regex, value, flags=self.flags) is not None + if not has_match: + failures = [Regex.NoResults(self, value)] + return self.failure(failures=failures) + return self.success() + + def describe_failure(self, failure: Failure) -> str | None: + """Describes why the validator failed. + + Args: + failure: Information about why the validation failed. + + Returns: + A string description of the failure. + """ + return f"Must match regular expression {self.regex!r} (flags={self.flags})." + + +class Number(Validator): + """Validator that ensures the value is a number, with an optional range check.""" + + def __init__( + self, + minimum: float | None = None, + maximum: float | None = None, + failure_description: str | None = None, + ) -> None: + super().__init__(failure_description=failure_description) + self.minimum = minimum + """The minimum value of the number, inclusive. If `None`, the minimum is unbounded.""" + self.maximum = maximum + """The maximum value of the number, inclusive. If `None`, the maximum is unbounded.""" + + class NotANumber(Failure): + """Indicates a failure due to the value not being a valid number (decimal/integer, inc. scientific notation)""" + + class NotInRange(Failure): + """Indicates a failure due to the number not being within the range [minimum, maximum].""" + + def validate(self, value: str) -> ValidationResult: + """Ensure that `value` is a valid number, optionally within a range. + + Args: + value: The value to validate. + + Returns: + The result of the validation. + """ + try: + float_value = float(value) + except ValueError: + return ValidationResult.failure([Number.NotANumber(self, value)]) + + if math.isnan(float_value) or math.isinf(float_value): + return ValidationResult.failure([Number.NotANumber(self, value)]) + + if not self._validate_range(float_value): + return ValidationResult.failure( + [Number.NotInRange(self, value)], + ) + return self.success() + + def _validate_range(self, value: float) -> bool: + """Return a boolean indicating whether the number is within the range specified in the attributes.""" + if self.minimum is not None and value < self.minimum: + return False + if self.maximum is not None and value > self.maximum: + return False + return True + + def describe_failure(self, failure: Failure) -> str | None: + """Describes why the validator failed. + + Args: + failure: Information about why the validation failed. + + Returns: + A string description of the failure. + """ + if isinstance(failure, Number.NotANumber): + return "Must be a valid number." + elif isinstance(failure, Number.NotInRange): + if self.minimum is None and self.maximum is not None: + return f"Must be less than or equal to {self.maximum}." + elif self.minimum is not None and self.maximum is None: + return f"Must be greater than or equal to {self.minimum}." + else: + return f"Must be between {self.minimum} and {self.maximum}." + else: + return None + + +class Integer(Number): + """Validator which ensures the value is an integer which falls within a range.""" + + class NotAnInteger(Failure): + """Indicates a failure due to the value not being a valid integer.""" + + def validate(self, value: str) -> ValidationResult: + """Ensure that `value` is an integer, optionally within a range. + + Args: + value: The value to validate. + + Returns: + The result of the validation. + """ + # First, check that we're dealing with a number in the range. + number_validation_result = super().validate(value) + if not number_validation_result.is_valid: + return number_validation_result + + # We know it's a number, but is that number an integer? + try: + int_value = int(value) + except ValueError: + return ValidationResult.failure([Integer.NotAnInteger(self, value)]) + return self.success() + + def describe_failure(self, failure: Failure) -> str | None: + """Describes why the validator failed. + + Args: + failure: Information about why the validation failed. + + Returns: + A string description of the failure. + """ + if isinstance(failure, (Integer.NotANumber, Integer.NotAnInteger)): + return "Must be a valid integer." + elif isinstance(failure, Integer.NotInRange): + if self.minimum is None and self.maximum is not None: + return f"Must be less than or equal to {self.maximum}." + elif self.minimum is not None and self.maximum is None: + return f"Must be greater than or equal to {self.minimum}." + else: + return f"Must be between {self.minimum} and {self.maximum}." + else: + return None + + +class Length(Validator): + """Validate that a string is within a range (inclusive).""" + + def __init__( + self, + minimum: int | None = None, + maximum: int | None = None, + failure_description: str | None = None, + ) -> None: + super().__init__(failure_description=failure_description) + self.minimum = minimum + """The inclusive minimum length of the value, or None if unbounded.""" + self.maximum = maximum + """The inclusive maximum length of the value, or None if unbounded.""" + + class Incorrect(Failure): + """Indicates a failure due to the length of the value being outside the range.""" + + def validate(self, value: str) -> ValidationResult: + """Ensure that value falls within the maximum and minimum length constraints. + + Args: + value: The value to validate. + + Returns: + The result of the validation. + """ + too_short = self.minimum is not None and len(value) < self.minimum + too_long = self.maximum is not None and len(value) > self.maximum + if too_short or too_long: + return ValidationResult.failure([Length.Incorrect(self, value)]) + return self.success() + + def describe_failure(self, failure: Failure) -> str | None: + """Describes why the validator failed. + + Args: + failure: Information about why the validation failed. + + Returns: + A string description of the failure. + """ + if isinstance(failure, Length.Incorrect): + if self.minimum is None and self.maximum is not None: + return f"Must be shorter than {self.maximum} characters." + elif self.minimum is not None and self.maximum is None: + return f"Must be longer than {self.minimum} characters." + else: + return f"Must be between {self.minimum} and {self.maximum} characters." + return None + + +class Function(Validator): + """A flexible validator which allows you to provide custom validation logic.""" + + def __init__( + self, + function: Callable[[str], bool], + failure_description: str | None = None, + ) -> None: + super().__init__(failure_description=failure_description) + self.function = function + """Function which takes the value to validate and returns True if valid, and False otherwise.""" + + class ReturnedFalse(Failure): + """Indicates validation failed because the supplied function returned False.""" + + def validate(self, value: str) -> ValidationResult: + """Validate that the supplied function returns True. + + Args: + value: The value to pass into the supplied function. + + Returns: + A ValidationResult indicating success if the function returned True, + and failure if the function return False. + """ + is_valid = self.function(value) + if is_valid: + return self.success() + return self.failure(failures=Function.ReturnedFalse(self, value)) + + def describe_failure(self, failure: Failure) -> str | None: + """Describes why the validator failed. + + Args: + failure: Information about why the validation failed. + + Returns: + A string description of the failure. + """ + return self.failure_description + + +class URL(Validator): + """Validator that checks if a URL is valid (ensuring a scheme is present).""" + + class InvalidURL(Failure): + """Indicates that the URL is not valid.""" + + def validate(self, value: str) -> ValidationResult: + """Validates that `value` is a valid URL (contains a scheme). + + Args: + value: The value to validate. + + Returns: + The result of the validation. + """ + invalid_url = ValidationResult.failure([URL.InvalidURL(self, value)]) + try: + parsed_url = urlparse(value) + if not all([parsed_url.scheme, parsed_url.netloc]): + return invalid_url + except ValueError: + return invalid_url + + return self.success() + + def describe_failure(self, failure: Failure) -> str | None: + """Describes why the validator failed. + + Args: + failure: Information about why the validation failed. + + Returns: + A string description of the failure. + """ + return "Must be a valid URL." diff --git a/src/memray/_vendor/textual/visual.py b/src/memray/_vendor/textual/visual.py new file mode 100644 index 0000000000..f8dc167914 --- /dev/null +++ b/src/memray/_vendor/textual/visual.py @@ -0,0 +1,431 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from itertools import islice +from typing import TYPE_CHECKING, Callable, Protocol + +import rich.repr +from rich.console import Console, ConsoleOptions, RenderableType +from rich.measure import Measurement +from rich.protocol import is_renderable, rich_cast +from rich.segment import Segment +from rich.style import NULL_STYLE as RICH_NULL_STYLE +from rich.style import Style as RichStyle +from rich.text import Text + +from memray._vendor.textual._context import active_app +from memray._vendor.textual.css.styles import RulesMap +from memray._vendor.textual.geometry import Spacing +from memray._vendor.textual.render import measure +from memray._vendor.textual.selection import Selection +from memray._vendor.textual.strip import Strip +from memray._vendor.textual.style import Style + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from memray._vendor.textual.widget import Widget + + +def is_visual(obj: object) -> bool: + """Check if the given object is a Visual or supports the Visual protocol.""" + return isinstance(obj, Visual) or hasattr(obj, "textualize") + + +@dataclass(frozen=True) +class RenderOptions: + """Additional options passed to `Visual.render_strips`.""" + + get_style: Callable[[str | Style], Style] + """Callable to get a style.""" + rules: RulesMap + """Mapping of style rules.""" + selection: Selection | None = None + """Text selection information.""" + selection_style: Style | None = None + """Style of text selection.""" + post_style: Style | None = None + """Optional style to apply post render.""" + + +# Note: not runtime checkable currently, as I've found that to be slow +class SupportsVisual(Protocol): + """An object that supports the textualize protocol.""" + + def visualize(self, widget: Widget, obj: object) -> Visual | None: + """Convert the result of a Widget.render() call into a Visual, using the Visual protocol. + + Args: + widget: The widget that generated the render. + obj: The result of the render. + + Returns: + A Visual instance, or `None` if it wasn't possible. + + """ + + +class VisualError(Exception): + """An error with the visual protocol.""" + + +VisualType: TypeAlias = "RenderableType | SupportsVisual | Visual" + + +def visualize(widget: Widget, obj: object, markup: bool = True) -> Visual: + """Get a visual instance from an object. + + If the object does not support the Visual protocol and is a Rich renderable, it + will be wrapped in a [RichVisual][textual.visual.RichVisual]. + + Args: + widget: The parent widget. + obj: An object. + markup: Enable markup. + + Raises: + VisualError: If there is no Visual could be returned to render `obj`. + + Returns: + A Visual instance to render the object, or `None` if there is no associated visual. + """ + _rich_traceback_omit = True + if isinstance(obj, Visual): + # Already a visual + return obj + # The visualize method should return a Visual if present. + visualize = getattr(obj, "visualize", None) + if visualize is None: + # Doesn't expose the textualize protocol + from memray._vendor.textual.content import Content + + if isinstance(obj, str): + return Content.from_markup(obj) if markup else Content(obj) + + if is_renderable(obj): + if isinstance(obj, Text): + return Content.from_rich_text(obj, console=widget.app.console) + + # If its is a Rich renderable, wrap it with a RichVisual + return RichVisual(widget, rich_cast(obj)) + else: + # We don't know how to make a visual from this object + raise VisualError( + f"unable to display {obj.__class__.__name__!r} type; must be a str, Rich renderable, or Textual Visual object" + ) + # Call the textualize method to create a visual + visual = visualize() + if not isinstance(visual, Visual) and is_renderable(visual): + return RichVisual(widget, visual) + return visual + + +class Visual(ABC): + """A Textual 'Visual' object. + + Analogous to a Rich renderable, but with support for transparency. + + """ + + @abstractmethod + def render_strips( + self, width: int, height: int | None, style: Style, options: RenderOptions + ) -> list[Strip]: + """Render the Visual into an iterable of strips. + + Args: + width: Width of desired render. + height: Height of desired render or `None` for any height. + style: The base style to render on top of. + options: Additional render options. + + Returns: + An list of Strips. + """ + + @abstractmethod + def get_optimal_width(self, rules: RulesMap, container_width: int) -> int: + """Get optimal width of the Visual to display its content. + + The exact definition of "optimal width" is dependant on the Visual, but + will typically be wide enough to display output without cropping or wrapping, + and without superfluous space. + + Args: + rules: A mapping of style rules, such as the Widgets `styles` object. + container_width: The width of the container, used by Rich Renderables. + May be ignored for Textual Visuals. + + Returns: + A width in cells. + + """ + + def get_minimal_width(self, rules: RulesMap) -> int: + """Get a minimal width (the smallest width before data loss occurs). + + Args: + rules: A mapping of style rules, such as the Widgets `styles` object. + container_width: The width of the container, used by Rich Renderables. + May be ignored for Textual Visuals. + + Returns: + A width in cells. + + """ + return 1 + + @abstractmethod + def get_height(self, rules: RulesMap, width: int) -> int: + """Get the height of the Visual if rendered at the given width. + + Args: + rules: A mapping of style rules, such as the Widgets `styles` object. + width: Width of visual in cells. + + Returns: + A height in lines. + """ + + @classmethod + def to_strips( + cls, + widget: Widget, + visual: Visual, + width: int, + height: int | None, + style: Style, + *, + apply_selection: bool = True, + pad: bool = False, + post_style: Style | None = None, + ) -> list[Strip]: + """High level function to render a visual to strips. + + Args: + widget: Widget that produced the visual. + visual: A Visual instance. + width: Desired width (in cells). + height: Desired height (in lines) or `None` for no limit. + style: A (Visual) Style instance. + apply_selection: Automatically apply selection styles? + pad: Pad to desired width? + post_style: Optional Style to apply to strips after rendering. + + Returns: + A list of Strips containing the render. + """ + + selection = widget.text_selection + if selection is not None: + selection_style: Style | None = Style.from_rich_style( + widget.screen.get_component_rich_style( + "screen--selection", default=RICH_NULL_STYLE + ) + ) + else: + selection_style = None + + strips = visual.render_strips( + width, + height, + style, + RenderOptions( + widget._get_style, + widget.styles, + selection if apply_selection else None, + selection_style, + ), + ) + if widget.auto_links and not widget.is_container: + # TODO: This is suprisingly expensive (why?) + link_style = widget.link_style + strips = [strip._apply_link_style(link_style) for strip in strips] + + if height is None: + height = len(strips) + rich_style = (style + Style(reverse=False)).rich_style + if pad: + strips = [strip.extend_cell_length(width, rich_style) for strip in strips] + content_align = widget.styles.content_align + if content_align != ("left", "top"): + align_horizontal, align_vertical = content_align + strips = list( + Strip.align( + strips, + rich_style, + width, + height, + align_horizontal, + align_vertical, + ) + ) + return strips + + +@rich.repr.auto +class RichVisual(Visual): + """A Visual to wrap a Rich renderable.""" + + def __init__(self, widget: Widget, renderable: RenderableType) -> None: + """ + + Args: + widget: The associated Widget. + renderable: A Rich renderable. + """ + self._widget = widget + self._renderable = renderable + self._measurement: Measurement | None = None + + def __rich_repr__(self) -> rich.repr.Result: + yield self._widget + yield self._renderable + + def _measure(self, console: Console, options: ConsoleOptions) -> Measurement: + if self._measurement is None: + self._measurement = Measurement.get( + console, + options, + self._widget.post_render(self._renderable, RichStyle.null()), + ) + return self._measurement + + def get_optimal_width(self, rules: RulesMap, container_width: int) -> int: + console = active_app.get().console + width = measure( + console, self._renderable, container_width, container_width=container_width + ) + return width + + def get_height(self, rules: RulesMap, width: int) -> int: + app = active_app.get() + console = app.console + renderable = self._renderable + if isinstance(renderable, Text): + height = len( + Text(renderable.plain).wrap( + console, + width, + no_wrap=renderable.no_wrap, + tab_size=renderable.tab_size or 8, + ) + ) + else: + console_options = app.console_options + options = console_options.update_width(width).update(highlight=False) + segments = console.render(renderable, options) + # Cheaper than counting the lines returned from render_lines! + height = sum([text.count("\n") for text, _, _ in segments]) + + return height + + def render_strips( + self, width: int, height: int | None, style: Style, options: RenderOptions + ) -> list[Strip]: + """Render the Visual into an iterable of strips. Part of the Visual protocol. + + Args: + width: Width of desired render. + height: Height of desired render or `None` for any height. + style: The base style to render on top of. + options: Additional render options. + + Returns: + An list of Strips. + """ + app = active_app.get() + console = app.console + console_options = app.console_options.update( + highlight=False, + width=width, + height=height, + ) + rich_style = style.rich_style + renderable = self._widget.post_render(self._renderable, rich_style) + segments = console.render(renderable, console_options.update_width(width)) + strips = [ + Strip(line) + for line in islice( + Segment.split_and_crop_lines( + segments, width, include_new_lines=False, pad=False + ), + None, + height, + ) + ] + + return strips + + +@rich.repr.auto +class Padding(Visual): + """A Visual to pad another visual.""" + + def __init__(self, visual: Visual, spacing: Spacing) -> None: + """ + + Args: + Visual: A Visual. + spacing: A Spacing object containing desired padding dimensions. + """ + self._visual = visual + self._spacing = spacing + + def __rich_repr__(self) -> rich.repr.Result: + yield self._visual + yield self._spacing + + def get_optimal_width(self, rules: RulesMap, container_width: int) -> int: + return ( + self._visual.get_optimal_width(rules, container_width) + self._spacing.width + ) + + def get_height(self, rules: RulesMap, width: int) -> int: + return ( + self._visual.get_height(rules, width - self._spacing.width) + + self._spacing.height + ) + + def render_strips( + self, width: int, height: int | None, style: Style, options: RenderOptions + ) -> list[Strip]: + """Render the Visual into an iterable of strips. Part of the Visual protocol. + + Args: + width: Width of desired render. + height: Height of desired render or `None` for any height. + style: The base style to render on top of. + options: Additional render options. + + Returns: + An list of Strips. + """ + padding = self._spacing + top, right, bottom, left = self._spacing + render_width = width - (left + right) + if render_width <= 0: + return [] + + strips = self._visual.render_strips( + render_width, + None if height is None else height - padding.height, + style, + options, + ) + + if padding: + rich_style = style.rich_style + top_padding = [Strip.blank(width, rich_style)] * top if top else [] + bottom_padding = [Strip.blank(width, rich_style)] * bottom if bottom else [] + strips = [ + *top_padding, + *[ + strip.crop_pad(render_width, left, right, rich_style) + for strip in strips + ], + *bottom_padding, + ] + + return strips diff --git a/src/memray/_vendor/textual/walk.py b/src/memray/_vendor/textual/walk.py new file mode 100644 index 0000000000..a392035c1e --- /dev/null +++ b/src/memray/_vendor/textual/walk.py @@ -0,0 +1,220 @@ +""" +Functions for *walking* the DOM. + +!!! note + + For most purposes you would be better off using [query][textual.dom.DOMNode.query], which uses these functions internally. +""" + +from __future__ import annotations + +from collections import deque +from operator import attrgetter +from typing import TYPE_CHECKING, Iterable, Iterator, TypeVar, overload + +from memray._vendor.textual.geometry import Shape + +if TYPE_CHECKING: + from memray._vendor.textual.dom import DOMNode + from memray._vendor.textual.widget import Widget + + WalkType = TypeVar("WalkType", bound=DOMNode) + + +if TYPE_CHECKING: + + @overload + def walk_depth_first( + root: DOMNode, + *, + with_root: bool = True, + ) -> Iterable[DOMNode]: ... + + @overload + def walk_depth_first( + root: WalkType, + filter_type: type[WalkType], + *, + with_root: bool = True, + ) -> Iterable[WalkType]: ... + + +def walk_depth_first( + root: DOMNode, + filter_type: type[WalkType] | None = None, + *, + with_root: bool = True, +) -> Iterable[DOMNode] | Iterable[WalkType]: + """Walk the tree depth first (parents first). + + !!! note + + Avoid changing the DOM (mounting, removing etc.) while iterating with this function. + Consider [walk_children][textual.dom.DOMNode.walk_children] which doesn't have this limitation. + + Args: + root: The root note (starting point). + filter_type: Optional DOMNode subclass to filter by, or `None` for no filter. + with_root: Include the root in the walk. + + Returns: + An iterable of DOMNodes, or the type specified in `filter_type`. + """ + stack: list[Iterator[DOMNode]] = [iter(root.children)] + pop = stack.pop + push = stack.append + + if filter_type is None: + if with_root: + yield root + while stack: + if (node := next(stack[-1], None)) is None: + pop() + else: + yield node + if children := node._nodes: + push(iter(children)) + else: + if with_root and isinstance(root, filter_type): + yield root + while stack: + if (node := next(stack[-1], None)) is None: + pop() + else: + if isinstance(node, filter_type): + yield node + if children := node._nodes: + push(iter(children)) + + +if TYPE_CHECKING: + + @overload + def walk_breadth_first( + root: DOMNode, + *, + with_root: bool = True, + ) -> Iterable[DOMNode]: ... + + @overload + def walk_breadth_first( + root: WalkType, + filter_type: type[WalkType], + *, + with_root: bool = True, + ) -> Iterable[WalkType]: ... + + +def walk_breadth_first( + root: DOMNode, + filter_type: type[WalkType] | None = None, + *, + with_root: bool = True, +) -> Iterable[DOMNode] | Iterable[WalkType]: + """Walk the tree breadth first (children first). + + !!! note + + Avoid changing the DOM (mounting, removing etc.) while iterating with this function. + Consider [walk_children][textual.dom.DOMNode.walk_children] which doesn't have this limitation. + + Args: + root: The root note (starting point). + filter_type: Optional DOMNode subclass to filter by, or `None` for no filter. + with_root: Include the root in the walk. + + Returns: + An iterable of DOMNodes, or the type specified in `filter_type`. + """ + from memray._vendor.textual.dom import DOMNode + + queue: deque[DOMNode] = deque() + popleft = queue.popleft + extend = queue.extend + check_type = filter_type or DOMNode + + if with_root and isinstance(root, check_type): + yield root + extend(root.children) + while queue: + node = popleft() + if isinstance(node, check_type): + yield node + extend(node._nodes) + + +def walk_breadth_search_id( + root: DOMNode, node_id: str, *, with_root: bool = True +) -> DOMNode | None: + """Special case to walk breadth first searching for a node with a given id. + + This is more efficient than [walk_breadth_first][textual.walk.walk_breadth_first] for this special case, as it can use an index. + + Args: + root: The root node (starting point). + node_id: Node id to search for. + with_root: Consider the root node? If the root has the node id, then return it. + + Returns: + A DOMNode if a node was found, otherwise `None`. + """ + + if with_root and root.id == node_id: + return root + + queue: deque[DOMNode] = deque() + queue.append(root) + + while queue: + node = queue.popleft() + if (found_node := node._nodes._get_by_id(node_id)) is not None: + return found_node + queue.extend(node._nodes) + return None + + +def walk_selectable_widgets( + root: DOMNode, bounds: Shape, bounded: set[DOMNode] +) -> Iterable[Widget]: + """Walk the tree depth first in select order (top to bottom, then left to right). + + Args: + root: The root note (starting point). + bounds: A Shape object that defines the selection bounds. + bounded: Container widgets that require a bounds check. + + Returns: + An iterable of DOMNodes. + """ + stack: list[Iterator[Widget]] = [iter(root.children)] + pop = stack.pop + push = stack.append + + get_selection_order = attrgetter("_selection_order") + + def get_children(node: DOMNode) -> list[Widget]: + """Get children, sorted in selection order, and potentially filtered by selection bounds. + + Args: + node: A root node. + + Returns: + A list of child widgets. + """ + children = sorted( + node.displayed_and_visible_children, + key=get_selection_order, + ) + if node in bounded: + children = [child for child in children if bounds.overlaps(child.region)] + return children + + children = get_children(root) + + while stack: + if (node := next(stack[-1], None)) is None: + pop() + elif node.allow_select: + yield node + if children := get_children(node): + push(iter(children)) diff --git a/src/memray/_vendor/textual/widget.py b/src/memray/_vendor/textual/widget.py new file mode 100644 index 0000000000..0083acad3a --- /dev/null +++ b/src/memray/_vendor/textual/widget.py @@ -0,0 +1,4954 @@ +""" +This module contains the `Widget` class, the base class for all widgets. + +""" + +from __future__ import annotations + +from asyncio import create_task, gather, wait +from collections import Counter +from contextlib import asynccontextmanager +from fractions import Fraction +from time import monotonic +from types import TracebackType +from typing import ( + TYPE_CHECKING, + AsyncGenerator, + Callable, + ClassVar, + Collection, + Generator, + Iterable, + Mapping, + NamedTuple, + Sequence, + TypeVar, + cast, + overload, +) + +import rich.repr +from rich.console import ( + Console, + ConsoleOptions, + ConsoleRenderable, + JustifyMethod, + RenderableType, +) +from rich.console import RenderResult as RichRenderResult +from rich.measure import Measurement +from rich.segment import Segment +from rich.style import Style +from rich.text import Text +from typing_extensions import Self + +from memray._vendor.textual.css.styles import StylesBase + +if TYPE_CHECKING: + from memray._vendor.textual.app import RenderResult + +from memray._vendor.textual import constants, errors, events, messages +from memray._vendor.textual._animator import DEFAULT_EASING, Animatable, BoundAnimator, EasingFunction +from memray._vendor.textual._arrange import DockArrangeResult, arrange +from memray._vendor.textual._context import NoActiveAppError +from memray._vendor.textual._debug import get_caller_file_and_line +from memray._vendor.textual._dispatch_key import dispatch_key +from memray._vendor.textual._easing import DEFAULT_SCROLL_EASING +from memray._vendor.textual._extrema import Extrema +from memray._vendor.textual._styles_cache import StylesCache +from memray._vendor.textual._types import AnimationLevel +from memray._vendor.textual.actions import SkipAction +from memray._vendor.textual.await_remove import AwaitRemove +from memray._vendor.textual.box_model import BoxModel +from memray._vendor.textual.cache import FIFOCache, LRUCache +from memray._vendor.textual.color import Color +from memray._vendor.textual.compose import compose +from memray._vendor.textual.content import Content, ContentType +from memray._vendor.textual.css.match import match +from memray._vendor.textual.css.parse import parse_selectors +from memray._vendor.textual.css.query import NoMatches, WrongType +from memray._vendor.textual.css.scalar import Scalar, ScalarOffset +from memray._vendor.textual.dom import DOMNode, NoScreen +from memray._vendor.textual.geometry import ( + NULL_REGION, + NULL_SIZE, + NULL_SPACING, + Offset, + Region, + Size, + Spacing, + clamp, +) +from memray._vendor.textual.layout import Layout, WidgetPlacement +from memray._vendor.textual.layouts.vertical import VerticalLayout +from memray._vendor.textual.message import Message +from memray._vendor.textual.messages import CallbackType, Prune +from memray._vendor.textual.notifications import SeverityLevel +from memray._vendor.textual.reactive import Reactive +from memray._vendor.textual.renderables.blank import Blank +from memray._vendor.textual.rlock import RLock +from memray._vendor.textual.selection import Selection +from memray._vendor.textual.strip import Strip +from memray._vendor.textual.style import Style as VisualStyle +from memray._vendor.textual.visual import Visual, VisualType, visualize + +if TYPE_CHECKING: + from memray._vendor.textual.app import App, ComposeResult + from memray._vendor.textual.css.query import QueryType + from memray._vendor.textual.filter import LineFilter + from memray._vendor.textual.message_pump import MessagePump + from memray._vendor.textual.scrollbar import ( + ScrollBar, + ScrollBarCorner, + ScrollDown, + ScrollLeft, + ScrollRight, + ScrollTo, + ScrollUp, + ) + +_JUSTIFY_MAP: dict[str, JustifyMethod] = { + "start": "left", + "end": "right", + "justify": "full", +} + + +_MOUSE_EVENTS_DISALLOW_IF_DISABLED = (events.MouseEvent, events.Enter, events.Leave) +_MOUSE_EVENTS_ALLOW_IF_DISABLED = ( + events.MouseScrollDown, + events.MouseScrollUp, + events.MouseScrollRight, + events.MouseScrollLeft, +) + + +@rich.repr.auto +class AwaitMount: + """An *optional* awaitable returned by [mount][textual.widget.Widget.mount] and [mount_all][textual.widget.Widget.mount_all]. + + Example: + ```python + await self.mount(Static("foo")) + ``` + """ + + def __init__(self, parent: Widget, widgets: Sequence[Widget]) -> None: + self._parent = parent + self._widgets = widgets + self._caller = get_caller_file_and_line() + + def __rich_repr__(self) -> rich.repr.Result: + yield "parent", self._parent + yield "widgets", self._widgets + yield "caller", self._caller, None + + async def __call__(self) -> None: + """Allows awaiting via a call operation.""" + await self + + def __await__(self) -> Generator[None, None, None]: + async def await_mount() -> None: + if self._widgets: + aws = [ + create_task(widget._mounted_event.wait(), name="await mount") + for widget in self._widgets + ] + if aws: + await wait(aws) + self._parent.refresh(layout=True) + try: + self._parent.app._update_mouse_over(self._parent.screen) + except NoScreen: + pass + + return await_mount().__await__() + + +class _Styled: + """Apply a style to a renderable. + + Args: + renderable: Any renderable. + style: A style to apply across the entire renderable. + """ + + def __init__( + self, renderable: "ConsoleRenderable", style: Style, link_style: Style | None + ) -> None: + self.renderable = renderable + self.style = style + self.link_style = link_style + + def __rich_console__( + self, console: "Console", options: "ConsoleOptions" + ) -> "RichRenderResult": + style = console.get_style(self.style) + result_segments = console.render(self.renderable, options) + + _Segment = Segment + if style: + apply = style.__add__ + result_segments = ( + _Segment(text, apply(_style), None) + for text, _style, control in result_segments + ) + link_style = self.link_style + if link_style: + result_segments = ( + _Segment( + text, + ( + style + if style._meta is None + else (style + link_style if "@click" in style.meta else style) + ), + control, + ) + for text, style, control in result_segments + if style is not None + ) + return result_segments + + def __rich_measure__( + self, console: "Console", options: "ConsoleOptions" + ) -> Measurement: + return Measurement.get(console, options, self.renderable) + + +class _RenderCache(NamedTuple): + """Stores results of a previous render.""" + + size: Size + """The size of the render.""" + lines: list[Strip] + """Contents of the render.""" + + +class WidgetError(Exception): + """Base widget error.""" + + +class MountError(WidgetError): + """Error raised when there was a problem with the mount request.""" + + +class PseudoClasses(NamedTuple): + """Used for render/render_line based widgets that use caching. This structure can be used as a + cache-key.""" + + enabled: bool + """Is 'enabled' applied?""" + focus: bool + """Is 'focus' applied?""" + hover: bool + """Is 'hover' applied?""" + + +class _BorderTitle: + """Descriptor to set border titles.""" + + def __set_name__(self, owner: Widget, name: str) -> None: + # The private name where we store the real data. + self._internal_name = f"_{name}" + + def __set__(self, obj: Widget, title: Text | ContentType | None) -> None: + """Setting a title accepts a str, Text, or None.""" + if isinstance(title, Text): + title = Content.from_rich_text(title) + if title is None: + setattr(obj, self._internal_name, None) + else: + # We store the title as Text + new_title = obj.render_str(title).expand_tabs(4) + new_title = new_title.split()[0] + setattr(obj, self._internal_name, new_title) + obj.refresh() + + def __get__(self, obj: Widget, objtype: type[Widget] | None = None) -> str | None: + """Getting a title will return None or a str as console markup.""" + title: Text | None = getattr(obj, self._internal_name, None) + if title is None: + return None + # If we have a title, convert from Text to console markup + return title.markup + + +class BadWidgetName(Exception): + """Raised when widget class names do not satisfy the required restrictions.""" + + +@rich.repr.auto +class Widget(DOMNode): + """ + A Widget is the base class for Textual widgets. + + See also [static][textual.widgets._static.Static] for starting point for your own widgets. + """ + + DEFAULT_CSS = """ + Widget{ + scrollbar-background: $scrollbar-background; + scrollbar-background-hover: $scrollbar-background-hover; + scrollbar-background-active: $scrollbar-background-active; + scrollbar-color: $scrollbar; + scrollbar-color-active: $scrollbar-active; + scrollbar-color-hover: $scrollbar-hover; + scrollbar-corner-color: $scrollbar-corner-color; + scrollbar-size-vertical: 2; + scrollbar-size-horizontal: 1; + link-background: $link-background; + link-color: $link-color; + link-style: $link-style; + link-background-hover: $link-background-hover; + link-color-hover: $link-color-hover; + link-style-hover: $link-style-hover; + background: transparent; + } + """ + COMPONENT_CLASSES: ClassVar[set[str]] = set() + """A set of component classes.""" + + BORDER_TITLE: ClassVar[str] = "" + """Initial value for border_title attribute.""" + + BORDER_SUBTITLE: ClassVar[str] = "" + """Initial value for border_subtitle attribute.""" + + ALLOW_MAXIMIZE: ClassVar[bool | None] = None + """Defines default logic to allow the widget to be maximized. + + - `None` Use default behavior (Focusable widgets may be maximized) + - `False` Do not allow widget to be maximized + - `True` Allow widget to be maximized + + """ + + ALLOW_SELECT: ClassVar[bool] = True + """Does this widget support automatic text selection? May be further refined with [Widget.allow_select][textual.widget.Widget.allow_select].""" + + FOCUS_ON_CLICK: ClassVar[bool] = True + """Should focusable widgets be automatically focused on click? Default return value of [Widget.focus_on_click][textual.widget.Widget.focus_on_click].""" + + BLANK: ClassVar[bool] = False + """Is this widget blank (no border, no content)? Enable for very large scrolling containers.""" + + can_focus: bool = False + """Widget may receive focus.""" + can_focus_children: bool = True + """Widget's children may receive focus.""" + expand: Reactive[bool] = Reactive(False) + """Rich renderable may expand beyond optimal size.""" + shrink: Reactive[bool] = Reactive(True) + """Rich renderable may shrink below optimal size.""" + auto_links: Reactive[bool] = Reactive(True) + """Widget will highlight links automatically.""" + disabled: Reactive[bool] = Reactive(False) + """Is the widget disabled? Disabled widgets can not be interacted with, and are typically styled to look dimmer.""" + + hover_style: Reactive[Style] = Reactive(Style, repaint=False) + """The current hover style (style under the mouse cursor). Read only.""" + highlight_link_id: Reactive[str] = Reactive("") + """The currently highlighted link id. Read only.""" + loading: Reactive[bool] = Reactive(False) + """If set to `True` this widget will temporarily be replaced with a loading indicator.""" + + virtual_size = Reactive(Size(0, 0), layout=True) + """The virtual (scrollable) [size][textual.geometry.Size] of the widget.""" + + has_focus: Reactive[bool] = Reactive(False, repaint=False) + """Does this widget have focus? Read only.""" + + mouse_hover: Reactive[bool] = Reactive(False, repaint=False) + """Is the mouse over this widget? Read only.""" + + scroll_x: Reactive[float] = Reactive(0.0, repaint=False, layout=False) + """The scroll position on the X axis.""" + + scroll_y: Reactive[float] = Reactive(0.0, repaint=False, layout=False) + """The scroll position on the Y axis.""" + + scroll_target_x = Reactive(0.0, repaint=False) + """Scroll target destination, X coord.""" + + scroll_target_y = Reactive(0.0, repaint=False) + """Scroll target destination, Y coord.""" + + show_vertical_scrollbar: Reactive[bool] = Reactive(False, layout=True) + """Show a vertical scrollbar?""" + + show_horizontal_scrollbar: Reactive[bool] = Reactive(False, layout=True) + """Show a horizontal scrollbar?""" + + border_title = _BorderTitle() # type: ignore + """A title to show in the top border (if there is one).""" + border_subtitle = _BorderTitle() + """A title to show in the bottom border (if there is one).""" + + # Default sort order, incremented by constructor + _sort_order: ClassVar[int] = 0 + + _PSEUDO_CLASSES: ClassVar[dict[str, Callable[[Widget], bool]]] = { + "hover": lambda widget: widget.mouse_hover, + "focus": lambda widget: widget.has_focus, + "blur": lambda widget: not widget.has_focus, + "can-focus": lambda widget: widget.allow_focus(), + "disabled": lambda widget: widget.is_disabled, + "enabled": lambda widget: not widget.is_disabled, + "dark": lambda widget: widget.app.current_theme.dark, + "light": lambda widget: not widget.app.current_theme.dark, + "focus-within": lambda widget: widget.has_focus_within, + "inline": lambda widget: widget.app.is_inline, + "ansi": lambda widget: widget.app.ansi_color, + "nocolor": lambda widget: widget.app.no_color, + "first-of-type": lambda widget: widget.first_of_type, + "last-of-type": lambda widget: widget.last_of_type, + "first-child": lambda widget: widget.first_child, + "last-child": lambda widget: widget.last_child, + "odd": lambda widget: widget.is_odd, + "even": lambda widget: widget.is_even, + "empty": lambda widget: widget.is_empty, + } # type: ignore[assignment] + + def __init__( + self, + *children: Widget, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + markup: bool = True, + ) -> None: + """Initialize a Widget. + + Args: + *children: Child widgets. + name: The name of the widget. + id: The ID of the widget in the DOM. + classes: The CSS classes for the widget. + disabled: Whether the widget is disabled or not. + markup: Enable content markup? + """ + self._render_markup = markup + _null_size = NULL_SIZE + self._size = _null_size + self._container_size = _null_size + self._layout_required = False + self._layout_updates = 0 + self._repaint_required = False + self._scroll_required = False + self._recompose_required = False + self._refresh_styles_required = False + self._default_layout = VerticalLayout() + self._animate: BoundAnimator | None = None + Widget._sort_order += 1 + self.sort_order = Widget._sort_order + self.highlight_style: Style | None = None + + self._vertical_scrollbar: ScrollBar | None = None + self._horizontal_scrollbar: ScrollBar | None = None + self._scrollbar_corner: ScrollBarCorner | None = None + + self._border_title: Content | None = None + self._border_subtitle: Content | None = None + + self._layout_cache: dict[str, object] = {} + """A dict that is refreshed when the widget is resized / refreshed.""" + + self._visual_style: VisualStyle | None = None + """Cached style of visual.""" + self._visual_style_cache_key: int = -1 + """Cache busting integer.""" + + self._render_cache = _RenderCache(_null_size, []) + # Regions which need to be updated (in Widget) + self._dirty_regions: set[Region] = set() + # Regions which need to be transferred from cache to screen + self._repaint_regions: set[Region] = set() + + self._box_model_cache: LRUCache[object, BoxModel] = LRUCache(16) + + # Cache the auto content dimensions + self._content_width_cache: tuple[object, int] = (None, 0) + self._content_height_cache: tuple[object, int] = (None, 0) + + self._arrangement_cache: FIFOCache[ + tuple[Size, int, bool], DockArrangeResult + ] = FIFOCache(4) + + self._styles_cache = StylesCache() + self._rich_style_cache: dict[tuple[str, ...], tuple[Style, Style]] = {} + self._visual_style_cache: dict[tuple[str, ...], VisualStyle] = {} + + self._tooltip: VisualType | None = None + """The tooltip content.""" + self.absolute_offset: Offset | None = None + """Force an absolute offset for the widget (used by tooltips).""" + + self._scrollbar_changes: set[tuple[bool, bool]] = set() + """Used to stabilize scrollbars.""" + super().__init__( + name=name, + id=id, + classes=self.DEFAULT_CLASSES if classes is None else classes, + ) + + if self in children: + raise WidgetError("A widget can't be its own parent") + + for child in children: + if not isinstance(child, Widget): + raise TypeError( + f"Widget positional arguments must be Widget subclasses; not {child!r}" + ) + self._pending_children = list(children) + self.set_reactive(Widget.disabled, disabled) + if self.BORDER_TITLE: + self.border_title = self.BORDER_TITLE + if self.BORDER_SUBTITLE: + self.border_subtitle = self.BORDER_SUBTITLE + + self.lock = RLock() + """`asyncio` lock to be used to synchronize the state of the widget. + + Two different tasks might call methods on a widget at the same time, which + might result in a race condition. + This can be fixed by adding `async with widget.lock:` around the method calls. + """ + self._anchored: bool = False + """Has this widget been anchored?""" + self._anchor_released: bool = False + """Has the anchor been released?""" + + """Flag to enable animation when scrolling anchored widgets.""" + self._cover_widget: Widget | None = None + """Widget to render over this widget (used by loading indicator).""" + + self._first_of_type: tuple[int, bool] = (-1, False) + """Used to cache :first-of-type pseudoclass state.""" + self._last_of_type: tuple[int, bool] = (-1, False) + """Used to cache :last-of-type pseudoclass state.""" + self._first_child: tuple[int, bool] = (-1, False) + """Used to cache :first-child pseudoclass state.""" + self._last_child: tuple[int, bool] = (-1, False) + """Used to cache :last-child pseudoclass state.""" + self._odd: tuple[int, bool] = (-1, False) + """Used to cache :odd pseudoclass state.""" + self._last_scroll_time = monotonic() + """Time of last scroll.""" + self._extrema = Extrema() + """Optional minimum and maximum values for width and height.""" + + @property + def is_mounted(self) -> bool: + """Check if this widget is mounted.""" + return self._is_mounted + + @property + def siblings(self) -> list[Widget]: + """Get the widget's siblings (self is removed from the return list). + + Returns: + A list of siblings. + """ + parent = self.parent + if parent is not None: + siblings = list(parent._nodes) + siblings.remove(self) + return siblings + else: + return [] + + @property + def visible_siblings(self) -> list[Widget]: + """A list of siblings which will be shown. + + Returns: + List of siblings. + """ + siblings = [ + widget for widget in self.siblings if widget.visible and widget.display + ] + return siblings + + @property + def allow_vertical_scroll(self) -> bool: + """Check if vertical scroll is permitted. + + May be overridden if you want different logic regarding allowing scrolling. + """ + if self._check_disabled(): + return False + return self.is_scrollable and self.show_vertical_scrollbar + + @property + def allow_horizontal_scroll(self) -> bool: + """Check if horizontal scroll is permitted. + + May be overridden if you want different logic regarding allowing scrolling. + """ + if self._check_disabled(): + return False + return self.is_scrollable and self.show_horizontal_scrollbar + + @property + def _allow_scroll(self) -> bool: + """Check if both axis may be scrolled. + + Returns: + True if horizontal and vertical scrolling is enabled. + """ + return self.is_scrollable and ( + self.allow_horizontal_scroll or self.allow_vertical_scroll + ) + + @property + def allow_maximize(self) -> bool: + """Check if the widget may be maximized. + + Returns: + `True` if the widget may be maximized, or `False` if it should not be maximized. + """ + return self.can_focus if self.ALLOW_MAXIMIZE is None else self.ALLOW_MAXIMIZE + + @property + def offset(self) -> Offset: + """Widget offset from origin. + + Returns: + Relative offset. + """ + return self.styles.offset.resolve(self.size, self.screen.size) + + @offset.setter + def offset(self, offset: tuple[int, int]) -> None: + self.styles.offset = ScalarOffset.from_offset(offset) + + @property + def opacity(self) -> float: + """Total opacity of widget.""" + opacity = 1.0 + for node in reversed(self.ancestors_with_self): + opacity *= node.styles.opacity + if not opacity: + break + return opacity + + @property + def is_anchored(self) -> bool: + """Is this widget anchored? + + See [anchor()][textual.widget.Widget.anchor] for an explanation of anchoring. + + """ + return self._anchored + + @property + def is_mouse_over(self) -> bool: + """Is the mouse currently over this widget? + + Note this will be `True` if the mouse pointer is within the widget's region, even if + the mouse pointer is not directly over the widget (there could be another widget between + the mouse pointer and self). + + """ + if not self.screen.is_active: + return False + for widget, _ in self.screen.get_widgets_at(*self.app.mouse_position): + if widget is self: + return True + return False + + @property + def is_maximized(self) -> bool: + """Is this widget maximized?""" + try: + return self.screen.maximized is self + except NoScreen: + return False + + @property + def is_in_maximized_view(self) -> bool: + """Is this widget, or a parent maximized?""" + maximized = self.screen.maximized + if not maximized: + return False + for node in self.ancestors_with_self: + if maximized is node: + return True + return False + + @property + def _render_widget(self) -> Widget: + """The widget the compositor should render.""" + # Will return the "cover widget" if one is set, otherwise self. + return self._cover_widget if self._cover_widget is not None else self + + @property + def text_selection(self) -> Selection | None: + """Text selection information, or `None` if no text is selected in this widget.""" + return self.screen.selections.get(self, None) + + @classmethod + def get_common_ancestor( + cls, widget1: Widget, widget2: Widget, *, default: Widget | None = None + ) -> Widget: + """Get a common ancestors to both widgets. + + Raises: + ValueError: If there is no common ancestor and `default` is not provided (will not occur if both widgets are attached to the same DOM). + + Args: + widget1: A Widget. + widget2: A second widgets. + default: A widget to return if no common ancestor is found. + + Returns: + A common ancestor widgets. + """ + ancestors1 = widget1.ancestors + ancestors2 = set(widget2.ancestors) + for node in ancestors1: + if node in ancestors2: + assert isinstance(node, Widget) + return node + if default is not None: + return default + raise ValueError("No common ancestor found") + + def focus_on_click(self) -> bool: + """Automatically focus the widget on click? + + Implement this if you want to change the default click to focus behavior. + The default will return the classvar `FOCUS_ON_CLICK`. + + Returns: + `True` if Textual should set focus automatically on a click, or `False` if it shouldn't. + """ + return self.FOCUS_ON_CLICK + + def get_line_filters(self) -> Sequence[LineFilter]: + """Get the line filters enabled for this widget. + + Returns: + A sequence of [LineFilter][textual.filters.LineFilter] instances. + """ + return self.app.get_line_filters() + + def preflight_checks(self) -> None: + """Called in debug mode to do preflight checks. + + This is used by Textual to log some common errors, but you could implement this + in custom widgets to perform additional checks. + + """ + + if hasattr(self, "CSS"): + from memray._vendor.textual.screen import Screen + + if not isinstance(self, Screen): + self.log.warning( + f"'{self.__class__.__name__}.CSS' will be ignored (use 'DEFAULT_CSS' class variable for widgets)" + ) + + def pre_render(self) -> None: + """Called prior to rendering. + + If you implement this in a subclass, be sure to call the base class method via super. + + """ + self._visual_style = None + + def _cover(self, widget: Widget) -> None: + """Set a widget used to replace the visuals of this widget (used for loading indicator). + + Args: + widget: A newly constructed, but unmounted widget. + """ + self._uncover() + self._cover_widget = widget + widget._parent = self + widget._start_messages() + widget._post_register(self.app) + self.app.stylesheet.apply(widget) + self.refresh(layout=True) + + def process_layout( + self, placements: list[WidgetPlacement] + ) -> list[WidgetPlacement]: + """A hook to allow for the manipulation of widget placements before rendering. + + You could use this as a way to modify the positions / margins of widgets if your requirement is + not supported in TCSS. In practice, this method is rarely needed! + + Args: + placements: A list of [`WidgetPlacement`][textual.layout.WidgetPlacement] objects. + + Returns: + A new list of placements. + """ + return placements + + def _uncover(self) -> None: + """Remove any widget, previously set via [`_cover`][textual.widget.Widget._cover].""" + if self._cover_widget is not None: + self._cover_widget.remove() + self._cover_widget = None + self.refresh(layout=True) + + def anchor(self, anchor: bool = True) -> None: + """Anchor a scrollable widget. + + An anchored widget will stay scrolled the bottom when new content is added, until + the user moves the scroll position. + + Args: + anchor: Anchor the widget if `True`, clear the anchor if `False`. + + """ + self._anchored = anchor + if anchor: + self.scroll_end(immediate=True, animate=False) + + def release_anchor(self) -> None: + """Release the [anchor][textual.widget.Widget]. + + If a widget is anchored, releasing the anchor will allow the user to scroll as normal. + + """ + self.scroll_target_y = self.scroll_y + self._anchor_released = True + + def _check_anchor(self) -> None: + """Check if the scroll position is near enough to the bottom to restore anchor.""" + if ( + self._anchored + and self._anchor_released + and self.scroll_y >= self.max_scroll_y + ): + self._anchor_released = False + + def _check_disabled(self) -> bool: + """Check if the widget is disabled either explicitly by setting `disabled`, + or implicitly by setting `loading`. + + Returns: + True if the widget should be disabled. + """ + return self.disabled or self.loading + + @property + def tooltip(self) -> VisualType | None: + """Tooltip for the widget, or `None` for no tooltip.""" + return self._tooltip + + @tooltip.setter + def tooltip(self, tooltip: VisualType | None): + self._tooltip = tooltip + try: + self.screen._update_tooltip(self) + except NoScreen: + pass + + def with_tooltip(self, tooltip: Visual | RenderableType | None) -> Self: + """Chainable method to set a tooltip. + + Example: + ```python + def compose(self) -> ComposeResult: + yield Label("Hello").with_tooltip("A greeting") + ``` + + Args: + tooltip: New tooltip, or `None` to clear the tooltip. + + Returns: + Self. + """ + self.tooltip = tooltip + return self + + def allow_focus(self) -> bool: + """Check if the widget is permitted to focus. + + The base class returns [`can_focus`][textual.widget.Widget.can_focus]. + This method may be overridden if additional logic is required. + + Returns: + `True` if the widget may be focused, or `False` if it may not be focused. + """ + return self.can_focus + + def allow_focus_children(self) -> bool: + """Check if a widget's children may be focused. + + The base class returns [`can_focus_children`][textual.widget.Widget.can_focus_children]. + This method may be overridden if additional logic is required. + + Returns: + `True` if the widget's children may be focused, or `False` if the widget's children may not be focused. + """ + return self.can_focus_children + + def compose_add_child(self, widget: Widget) -> None: + """Add a node to children. + + This is used by the compose process when it adds children. + There is no need to use it directly, but you may want to override it in a subclass + if you want children to be attached to a different node. + + Args: + widget: A Widget to add. + """ + _rich_traceback_omit = True + self._pending_children.append(widget) + + @property + def is_disabled(self) -> bool: + """Is the widget disabled either because `disabled=True` or an ancestor has `disabled=True`.""" + node: MessagePump | None = self + while isinstance(node, Widget): + if node.disabled: + return True + node = node._parent + return False + + @property + def has_focus_within(self) -> bool: + """Are any descendants focused?""" + try: + focused = self.screen.focused + except NoScreen: + return False + node = focused + while node is not None: + if node is self: + return True + node = node._parent + return False + + @property + def first_of_type(self) -> bool: + """Is this the first widget of its type in its siblings?""" + parent = self.parent + if parent is None: + return True + # This pseudo classes only changes when the parent's nodes._updates changes + if parent._nodes._updates == self._first_of_type[0]: + return self._first_of_type[1] + widget_type = type(self) + for node in parent._nodes.displayed: + if isinstance(node, widget_type): + self._first_of_type = (parent._nodes._updates, node is self) + return self._first_of_type[1] + return False + + @property + def last_of_type(self) -> bool: + """Is this the last widget of its type in its siblings?""" + parent = self.parent + if parent is None: + return True + # This pseudo classes only changes when the parent's nodes._updates changes + if parent._nodes._updates == self._last_of_type[0]: + return self._last_of_type[1] + widget_type = type(self) + for node in parent._nodes.displayed_reverse: + if isinstance(node, widget_type): + self._last_of_type = (parent._nodes._updates, node is self) + return self._last_of_type[1] + return False + + @property + def first_child(self) -> bool: + """Is this the first widget in its siblings?""" + parent = self.parent + if parent is None: + return True + # This pseudo class only changes when the parent's nodes._updates changes + if parent._nodes._updates == self._first_child[0]: + return self._first_child[1] + for node in parent._nodes.displayed: + self._first_child = (parent._nodes._updates, node is self) + return self._first_child[1] + return False + + @property + def last_child(self) -> bool: + """Is this the last widget in its siblings?""" + parent = self.parent + if parent is None: + return True + # This pseudo class only changes when the parent's nodes._updates changes + if parent._nodes._updates == self._last_child[0]: + return self._last_child[1] + for node in parent._nodes.displayed_reverse: + self._last_child = (parent._nodes._updates, node is self) + return self._last_child[1] + return False + + @property + def is_odd(self) -> bool: + """Is this widget at an oddly numbered position within its siblings?""" + parent = self.parent + if parent is None: + return True + # This pseudo classes only changes when the parent's nodes._updates changes + if parent._nodes._updates == self._odd[0]: + return self._odd[1] + try: + is_odd = parent._nodes.displayed_and_visible.index(self) % 2 == 0 + self._odd = (parent._nodes._updates, is_odd) + return is_odd + except ValueError: + return False + + @property + def is_even(self) -> bool: + """Is this widget at an evenly numbered position within its siblings?""" + return not self.is_odd + + def __enter__(self) -> Self: + """Use as context manager when composing.""" + self.app._compose_stacks[-1].append(self) + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exit compose context manager.""" + compose_stack = self.app._compose_stacks[-1] + composed = compose_stack.pop() + if compose_stack: + compose_stack[-1].compose_add_child(composed) + else: + self.app._composed[-1].append(composed) + + def clear_cached_dimensions(self) -> None: + """Clear cached results of `get_content_width` and `get_content_height`. + + Call if the widget's renderable changes size after the widget has been created. + + !!! note + + This is not required if you are extending [`Static`][textual.widgets.Static]. + + """ + self._content_width_cache = (None, 0) + self._content_height_cache = (None, 0) + + def get_loading_widget(self) -> Widget: + """Get a widget to display a loading indicator. + + The default implementation will defer to App.get_loading_widget. + + Returns: + A widget in place of this widget to indicate a loading. + """ + loading_widget = self.screen.get_loading_widget() + return loading_widget + + def set_loading(self, loading: bool) -> None: + """Set or reset the loading state of this widget. + + A widget in a loading state will display a `LoadingIndicator` or a custom widget + set through overriding the `get_loading_widget` method. + + Args: + loading: `True` to put the widget into a loading state, or `False` to reset the loading state. + """ + if loading: + loading_indicator = self.get_loading_widget() + loading_indicator.add_class("-textual-loading-indicator") + self._cover(loading_indicator) + else: + self._uncover() + self.screen.update_pointer_shape() + + def _watch_loading(self, loading: bool) -> None: + """Called when the 'loading' reactive is changed.""" + if not self.is_mounted: + self.call_later(self.set_loading, loading) + else: + self.set_loading(loading) + + ExpectType = TypeVar("ExpectType", bound="Widget") + + if TYPE_CHECKING: + + @overload + def get_child_by_id(self, id: str) -> Widget: ... + + @overload + def get_child_by_id( + self, id: str, expect_type: type[ExpectType] + ) -> ExpectType: ... + + def get_child_by_id( + self, id: str, expect_type: type[ExpectType] | None = None + ) -> ExpectType | Widget: + """Return the first child (immediate descendent) of this node with the given ID. + + Args: + id: The ID of the child. + expect_type: Require the object be of the supplied type, or None for any type. + + Returns: + The first child of this node with the ID. + + Raises: + NoMatches: if no children could be found for this ID + WrongType: if the wrong type was found. + """ + child = self._get_dom_base()._nodes._get_by_id(id) + if child is None: + raise NoMatches(f"No child found with id={id!r}") + if expect_type is None: + return child + if not isinstance(child, expect_type): + raise WrongType( + f"Child with id={id!r} is the wrong type; expected type {expect_type.__name__!r}, found {child}" + ) + return child + + if TYPE_CHECKING: + + @overload + def get_widget_by_id(self, id: str) -> Widget: ... + + @overload + def get_widget_by_id( + self, id: str, expect_type: type[ExpectType] + ) -> ExpectType: ... + + def get_widget_by_id( + self, id: str, expect_type: type[ExpectType] | None = None + ) -> ExpectType | Widget: + """Return the first descendant widget with the given ID. + + Performs a depth-first search rooted at this widget. + + Args: + id: The ID to search for in the subtree. + expect_type: Require the object be of the supplied type, or None for any type. + + Returns: + The first descendant encountered with this ID. + + Raises: + NoMatches: if no children could be found for this ID. + WrongType: if the wrong type was found. + """ + + widget = self.query_one(f"#{id}") + if expect_type is not None and not isinstance(widget, expect_type): + raise WrongType( + f"Descendant with id={id!r} is the wrong type; expected type {expect_type.__name__!r}, found {widget}" + ) + return widget + + def get_child_by_type(self, expect_type: type[ExpectType]) -> ExpectType: + """Get the first immediate child of a given type. + + Only returns exact matches, and so will not match subclasses of the given type. + + Args: + expect_type: The type of the child to search for. + + Raises: + NoMatches: If no matching child is found. + + Returns: + The first immediate child widget with the expected type. + """ + for child in self._nodes: + # We want the child with the exact type (not subclasses) + if type(child) is expect_type: + assert isinstance(child, expect_type) + return child + raise NoMatches(f"No immediate child of type {expect_type}; {self._nodes}") + + def get_component_rich_style( + self, *names: str, partial: bool = False, default: Style | None = None + ) -> Style: + """Get a *Rich* style for a component. + + Args: + names: Names of components. + partial: Return a partial style (not combined with parent). + default: A Style to return if any component style doesn't exist. + + Raises: + KeyError: If a component style doesn't exist, and no `default` is provided. + + Returns: + A Rich style object. + """ + + if names not in self._rich_style_cache: + if default is None: + component_styles = self.get_component_styles(*names) + else: + try: + component_styles = self.get_component_styles(*names) + except KeyError: + return default + + style = component_styles.rich_style + text_opacity = component_styles.text_opacity + if text_opacity < 1 and style.bgcolor is not None: + style += Style.from_color( + ( + Color.from_rich_color(style.bgcolor) + + component_styles.color.multiply_alpha(text_opacity) + ).rich_color + ) + partial_style = component_styles.partial_rich_style + self._rich_style_cache[names] = (style, partial_style) + + style, partial_style = self._rich_style_cache[names] + + return partial_style if partial else style + + def get_visual_style( + self, *component_classes: str, partial: bool = False + ) -> VisualStyle: + """Get the visual style for the widget, including any component styles. + + Args: + component_classes: Optional component styles. + partial: Return a partial style (not combined with parent). + + Returns: + A Visual style instance. + + """ + cache_key = (self._pseudo_classes_cache_key, component_classes, partial) + if (visual_style := self._visual_style_cache.get(cache_key, None)) is None: + background = Color(0, 0, 0, 0) + color = Color(255, 255, 255, 0) + + style = Style() + opacity = 1.0 + + def iter_styles() -> Iterable[StylesBase]: + """Iterate over the styles from the DOM and additional components styles.""" + if partial: + node = self + else: + for node in reversed(self.ancestors_with_self): + yield node.styles + for name in component_classes: + yield node.get_component_styles(name) + + for styles in iter_styles(): + has_rule = styles.has_rule + opacity *= styles.opacity + if has_rule("background"): + text_background = background + styles.background.tint( + styles.background_tint + ) + if partial: + background_tint = styles.background.tint(styles.background_tint) + background = background.blend( + background_tint, 1 - background_tint.a + ).multiply_alpha(opacity) + else: + background += ( + styles.background.tint(styles.background_tint) + ).multiply_alpha(opacity) + else: + text_background = background + if has_rule("color"): + color = styles.color.multiply_alpha(styles.text_opacity) + style += styles.text_style + if has_rule("auto_color") and styles.auto_color: + color = text_background.get_contrast_text(color.a) + + visual_style = VisualStyle( + background, + color, + bold=style.bold, + dim=style.dim, + italic=style.italic, + underline=style.underline, + strike=style.strike, + ) + self._visual_style_cache[cache_key] = visual_style + + return visual_style + + def _get_style(self, style: VisualStyle | str) -> VisualStyle: + """A get_style method for use in Content. + + Args: + style: A style prefixed with a dot. + + Returns: + A visual style if one is fund, otherwise `None`. + """ + if isinstance(style, VisualStyle): + return style + visual_style = VisualStyle.null() + if style.startswith("."): + for node in self.ancestors_with_self: + if not isinstance(node, Widget): + break + try: + visual_style = node.get_visual_style(style[1:], partial=True) + break + except KeyError: + continue + else: + raise KeyError(f"No matching component class found for '{style}'") + return visual_style + try: + visual_style = VisualStyle.parse(style) + except Exception: + pass + return visual_style + + @overload + def render_str(self, text_content: str) -> Content: ... + + @overload + def render_str(self, text_content: Content) -> Content: ... + + def render_str(self, text_content: str | Content) -> Content: + """Convert str into a [Content][textual.content.Content] instance. + + If you pass in an existing Content instance it will be returned unaltered. + + Args: + text_content: Content or str. + + Returns: + Content object. + """ + if isinstance(text_content, Content): + return text_content + return Content.from_markup(text_content) + + def arrange(self, size: Size, optimal: bool = False) -> DockArrangeResult: + """Arrange child widgets. + + This method is best left alone, unless you have a deep understanding of what it does. + + Args: + size: Size of container. + optimal: Whether fr units should expand the widget (`False`) or avoid expanding the widget (`True`). + + Returns: + Widget locations. + """ + cache_key = (size, self._nodes._updates, optimal) + cached_result = self._arrangement_cache.get(cache_key) + if cached_result is not None: + return cached_result + + arrangement = self._arrangement_cache[cache_key] = arrange( + self, self._nodes, size, self.screen.size, optimal=optimal + ) + + return arrangement + + def _clear_arrangement_cache(self) -> None: + """Clear arrangement cache, forcing a new arrange operation.""" + self._arrangement_cache.clear() + + def _get_virtual_dom(self) -> Iterable[Widget]: + """Get widgets not part of the DOM. + + Returns: + An iterable of Widgets. + """ + if self._horizontal_scrollbar is not None: + yield self._horizontal_scrollbar + if self._vertical_scrollbar is not None: + yield self._vertical_scrollbar + if self._scrollbar_corner is not None: + yield self._scrollbar_corner + + def _find_mount_point(self, spot: int | str | "Widget") -> tuple["Widget", int]: + """Attempt to locate the point where the caller wants to mount something. + + Args: + spot: The spot to find. + + Returns: + The parent and the location in its child list. + + Raises: + MountError: If there was an error finding where to mount a widget. + + The rules of this method are: + + - Given an ``int``, parent is ``self`` and location is the integer value. + - Given a ``Widget``, parent is the widget's parent and location is + where the widget is found in the parent's ``children``. If it + can't be found a ``MountError`` will be raised. + - Given a string, it is used to perform a ``query_one`` and then the + result is used as if a ``Widget`` had been given. + """ + + # A numeric location means at that point in our child list. + if isinstance(spot, int): + return self, spot + + # If we've got a string, that should be treated like a query that + # can be passed to query_one. So let's use that to get a widget to + # work on. + if isinstance(spot, str): + spot = self.query_exactly_one(spot, Widget) + + # At this point we should have a widget, either because we got given + # one, or because we pulled one out of the query. First off, does it + # have a parent? There's no way we can use it as a sibling to make + # mounting decisions if it doesn't have a parent. + if spot.parent is None: + raise MountError( + f"Unable to find relative location of {spot!r} because it has no parent" + ) + + # We've got a widget. It has a parent. It has (zero or more) + # children. We should be able to go looking for the widget's + # location amongst its parent's children. + try: + return cast("Widget", spot.parent), spot.parent._nodes.index(spot) + except ValueError: + raise MountError(f"{spot!r} is not a child of {self!r}") from None + + def mount( + self, + *widgets: Widget, + before: int | str | Widget | None = None, + after: int | str | Widget | None = None, + ) -> AwaitMount: + """Mount widgets below this widget (making this widget a container). + + Args: + *widgets: The widget(s) to mount. + before: Optional location to mount before. An `int` is the index + of the child to mount before, a `str` is a `query_one` query to + find the widget to mount before. + after: Optional location to mount after. An `int` is the index + of the child to mount after, a `str` is a `query_one` query to + find the widget to mount after. + + Returns: + An awaitable object that waits for widgets to be mounted. + + Raises: + MountError: If there is a problem with the mount request. + + Note: + Only one of ``before`` or ``after`` can be provided. If both are + provided a ``MountError`` will be raised. + """ + if self._closing or self._pruning: + return AwaitMount(self, []) + if not self.is_attached: + raise MountError(f"Can't mount widget(s) before {self!r} is mounted") + # Check for duplicate IDs in the incoming widgets + ids_to_mount = [ + widget_id for widget in widgets if (widget_id := widget.id) is not None + ] + if len(set(ids_to_mount)) != len(ids_to_mount): + counter = Counter(ids_to_mount) + for widget_id, count in counter.items(): + if count > 1: + raise MountError( + f"Tried to insert {count!r} widgets with the same ID {widget_id!r}. " + "Widget IDs must be unique." + ) + + # Saying you want to mount before *and* after something is an error. + if before is not None and after is not None: + raise MountError( + "Only one of `before` or `after` can be handled -- not both" + ) + + # Decide the final resting place depending on what we've been asked + # to do. + insert_before: int | None = None + insert_after: int | None = None + if before is not None: + parent, insert_before = self._find_mount_point(before) + elif after is not None: + parent, insert_after = self._find_mount_point(after) + else: + parent = self + + mounted = self.app._register( + parent, *widgets, before=insert_before, after=insert_after + ) + + def update_styles(children: list[DOMNode]) -> None: + """Update order related CSS""" + if before is not None or after is not None: + # If the new children aren't at the end. + # we need to update both odd/even, first-of-type/last-of-type and first-child/last-child + for child in children: + if child._has_order_style or child._has_odd_or_even: + child.update_node_styles() + else: + for child in children: + if child._has_order_style: + child.update_node_styles() + + self.call_later(update_styles, self.displayed_children) + await_mount = AwaitMount(self, mounted) + self.call_next(await_mount) + + return await_mount + + def _refresh_styles(self) -> None: + """Request refresh of styles on idle.""" + self._refresh_styles_required = True + self.check_idle() + + def mount_all( + self, + widgets: Iterable[Widget], + *, + before: int | str | Widget | None = None, + after: int | str | Widget | None = None, + ) -> AwaitMount: + """Mount widgets from an iterable. + + Args: + widgets: An iterable of widgets. + before: Optional location to mount before. An `int` is the index + of the child to mount before, a `str` is a `query_one` query to + find the widget to mount before. + after: Optional location to mount after. An `int` is the index + of the child to mount after, a `str` is a `query_one` query to + find the widget to mount after. + + Returns: + An awaitable object that waits for widgets to be mounted. + + Raises: + MountError: If there is a problem with the mount request. + + Note: + Only one of `before` or `after` can be provided. If both are + provided a `MountError` will be raised. + """ + if self.app._exit: + return AwaitMount(self, []) + await_mount = self.mount(*widgets, before=before, after=after) + return await_mount + + def mount_compose( + self, + compose_result: ComposeResult, + *, + before: int | str | Widget | None = None, + after: int | str | Widget | None = None, + ) -> AwaitMount: + """Mount widgets from the result of a compose method. + + Example: + ```python + def on_key(self, event:events.Key) -> None: + + def add_key(key:str) -> ComposeResult: + '''Compose key information widgets''' + with containers.HorizontalGroup(): + yield Label("You pressed:") + yield Label(key) + + self.mount_compose(add_key(event.key)) + + ``` + + Args: + compose_result: The result of a compose method. + before: Optional location to mount before. An `int` is the index + of the child to mount before, a `str` is a `query_one` query to + find the widget to mount before. + after: Optional location to mount after. An `int` is the index + of the child to mount after, a `str` is a `query_one` query to + find the widget to mount after. + + Returns: + An awaitable object that waits for widgets to be mounted. + + Raises: + MountError: If there is a problem with the mount request. + + Note: + Only one of `before` or `after` can be provided. If both are + provided a `MountError` will be raised. + """ + return self.mount_all(compose(self, compose_result), before=before, after=after) + + if TYPE_CHECKING: + + @overload + def move_child( + self, + child: int | Widget, + *, + before: int | Widget, + after: None = None, + ) -> None: ... + + @overload + def move_child( + self, + child: int | Widget, + *, + after: int | Widget, + before: None = None, + ) -> None: ... + + def move_child( + self, + child: int | Widget, + *, + before: int | Widget | None = None, + after: int | Widget | None = None, + ) -> None: + """Move a child widget within its parent's list of children. + + Args: + child: The child widget to move. + before: Child widget or location index to move before. + after: Child widget or location index to move after. + + Raises: + WidgetError: If there is a problem with the child or target. + + Note: + Only one of `before` or `after` can be provided. If neither + or both are provided a `WidgetError` will be raised. + """ + + # One or the other of before or after are required. Can't do + # neither, can't do both. + if before is None and after is None: + raise WidgetError("One of `before` or `after` is required.") + elif before is not None and after is not None: + raise WidgetError("Only one of `before` or `after` can be handled.") + + def _to_widget(child: int | Widget, called: str) -> Widget: + """Ensure a given child reference is a Widget.""" + if isinstance(child, int): + try: + child = self._nodes[child] + except IndexError: + raise WidgetError( + f"An index of {child} for the child to {called} is out of bounds" + ) from None + else: + # We got an actual widget, so let's be sure it really is one of + # our children. + try: + _ = self._nodes.index(child) + except ValueError: + raise WidgetError(f"{child!r} is not a child of {self!r}") from None + return child + + # Ensure the child and target are widgets. + child = _to_widget(child, "move") + target = _to_widget( + cast("int | Widget", before if after is None else after), "move towards" + ) + + if child is target: + return # Nothing to be done. + + # At this point we should know what we're moving, and it should be a + # child; where we're moving it to, which should be within the child + # list; and how we're supposed to move it. All that's left is doing + # the right thing. + self._nodes._remove(child) + if before is not None: + self._nodes._insert(self._nodes.index(target), child) + else: + self._nodes._insert(self._nodes.index(target) + 1, child) + + # Request a refresh. + self.refresh(layout=True) + + def compose(self) -> ComposeResult: + """Called by Textual to create child widgets. + + This method is called when a widget is mounted or by setting `recompose=True` when + calling [`refresh()`][textual.widget.Widget.refresh]. + + Note that you don't typically need to explicitly call this method. + + Example: + ```python + def compose(self) -> ComposeResult: + yield Header() + yield Label("Press the button below:") + yield Button() + yield Footer() + ``` + """ + yield from () + + async def _check_recompose(self) -> None: + """Check if a recompose is required.""" + if self._recompose_required: + self._recompose_required = False + await self.recompose() + + async def recompose(self) -> None: + """Recompose the widget. + + Recomposing will remove children and call `self.compose` again to remount. + """ + if not self.is_attached or self._pruning: + return + + async with self.batch(): + await self.query_children("*").exclude(".-textual-system").remove() + if self.is_attached: + compose_nodes = compose(self) + await self.mount_all(compose_nodes) + + def _post_register(self, app: App) -> None: + """Called when the instance is registered. + + Args: + app: App instance. + """ + # Parse the Widget's CSS + for read_from, css, tie_breaker, scope in self._get_default_css(): + self.app.stylesheet.add_source( + css, + read_from=read_from, + is_default_css=True, + tie_breaker=tie_breaker, + scope=scope, + ) + if app.debug: + app.call_next(self.preflight_checks) + + def _get_box_model( + self, + container: Size, + viewport: Size, + width_fraction: Fraction, + height_fraction: Fraction, + constrain_width: bool = False, + greedy: bool = True, + ) -> BoxModel: + """Process the box model for this widget. + + Args: + container: The size of the container widget (with a layout). + viewport: The viewport size. + width_fraction: A fraction used for 1 `fr` unit on the width dimension. + height_fraction: A fraction used for 1 `fr` unit on the height dimension. + constrain_width: Restrict the width to the container width. + + Returns: + The size and margin for this widget. + """ + cache_key = ( + container, + viewport, + width_fraction, + height_fraction, + constrain_width, + greedy, + self._layout_updates, + self.styles._cache_key, + ) + if cached_box_model := self._box_model_cache.get(cache_key): + return cached_box_model + + styles = self.styles + is_border_box = styles.box_sizing == "border-box" + gutter = styles.gutter # Padding plus border + margin = styles.margin + + styles_width = styles.width + if not greedy and styles_width is not None and styles_width.is_fraction: + styles_width = Scalar.parse("auto") + is_auto_width = styles_width and styles_width.is_auto + is_auto_height = styles.height and styles.height.is_auto + + # Container minus padding and border + content_container = container - gutter.totals + + extrema = self._extrema = self._resolve_extrema( + container, viewport, width_fraction, height_fraction + ) + min_width, max_width, min_height, max_height = extrema + + if styles_width is None: + # No width specified, fill available space + content_width = Fraction(content_container.width - margin.width) + elif is_auto_width: + # When width is auto, we want enough space to always fit the content + content_width = Fraction( + self.get_content_width(content_container - margin.totals, viewport) + ) + if ( + styles.overflow_x == "auto" and styles.scrollbar_gutter == "stable" + ) or self.show_vertical_scrollbar: + content_width += styles.scrollbar_size_vertical + if ( + content_width < content_container.width + and self._has_relative_children_width + ): + content_width = Fraction(content_container.width) + else: + # An explicit width + content_width = styles_width.resolve( + container - margin.totals, viewport, width_fraction + ) + if is_border_box: + content_width -= gutter.width + + if min_width is not None: + # Restrict to minimum width, if set + content_width = max(content_width, min_width, Fraction(0)) + + if max_width is not None and not ( + container.width == 0 + and not styles.max_width.is_cells + and self._parent is not None + and self._parent.styles.is_auto_width + ): + # Restrict to maximum width, if set + content_width = min(content_width, max_width) + + content_width = max(Fraction(0), content_width) + + if constrain_width: + content_width = min(Fraction(container.width - gutter.width), content_width) + + if styles.height is None: + # No height specified, fill the available space + content_height = Fraction(content_container.height - margin.height) + elif is_auto_height: + # Calculate dimensions based on content + content_height = Fraction( + self.get_content_height( + content_container - margin.totals, + viewport, + int(content_width), + ) + ) + if ( + styles.overflow_y == "auto" and styles.scrollbar_gutter == "stable" + ) or self.show_horizontal_scrollbar: + content_height += styles.scrollbar_size_horizontal + if ( + content_height < content_container.height + and self._has_relative_children_height + ): + content_height = Fraction(content_container.height) + else: + styles_height = styles.height + # Explicit height set + content_height = styles_height.resolve( + container - margin.totals, viewport, height_fraction + ) + if is_border_box: + content_height -= gutter.height + + if min_height is not None: + # Restrict to minimum height, if set + content_height = max(content_height, min_height, Fraction(0)) + + if max_height is not None and not ( + container.height == 0 + and not styles.max_height.is_cells + and self._parent is not None + and self._parent.styles.is_auto_height + ): + content_height = min(content_height, max_height) + + content_height = max(Fraction(0), content_height) + model = BoxModel( + content_width + gutter.width, content_height + gutter.height, margin + ) + self._box_model_cache[cache_key] = model + return model + + def get_content_width(self, container: Size, viewport: Size) -> int: + """Called by textual to get the width of the content area. May be overridden in a subclass. + + Args: + container: Size of the container (immediate parent) widget. + viewport: Size of the viewport. + + Returns: + The optimal width of the content. + """ + + if self.is_container: + width = self.layout.get_content_width(self, container, viewport) + return width + + cache_key = container.width + if self._content_width_cache[0] == cache_key: + return self._content_width_cache[1] + + visual = self._render() + width = visual.get_optimal_width(self.styles, container.width) + + if self.expand: + width = max(container.width, width) + if self.shrink: + width = min(width, container.width) + + self._content_width_cache = (cache_key, width) + + return width + + def get_content_height(self, container: Size, viewport: Size, width: int) -> int: + """Called by Textual to get the height of the content area. May be overridden in a subclass. + + Args: + container: Size of the container (immediate parent) widget. + viewport: Size of the viewport. + width: Width of renderable. + + Returns: + The height of the content. + """ + if not width: + return 0 + if self.is_container: + assert self.layout is not None + height = self.layout.get_content_height( + self, + container, + viewport, + width, + ) + else: + cache_key = width + + if self._content_height_cache[0] == cache_key: + return self._content_height_cache[1] + + visual = self._render() + height = visual.get_height(self.styles, width) + self._content_height_cache = (cache_key, height) + + return height + + def watch_hover_style( + self, previous_hover_style: Style, hover_style: Style + ) -> None: + # TODO: This will cause the widget to refresh, even when there are no links + # Can we avoid this? + if self.auto_links and not self.app.mouse_captured: + self.highlight_link_id = hover_style.link_id + + def watch_scroll_x(self, old_value: float, new_value: float) -> None: + if self.show_horizontal_scrollbar: + self.horizontal_scrollbar.position = new_value + if round(old_value) != round(new_value): + self._refresh_scroll() + + def watch_scroll_y(self, old_value: float, new_value: float) -> None: + if self.show_vertical_scrollbar: + self.vertical_scrollbar.position = new_value + if self._anchored and self._anchor_released: + self._check_anchor() + if round(old_value) != round(new_value): + self._refresh_scroll() + + def validate_scroll_x(self, value: float) -> float: + return clamp(value, 0, self.max_scroll_x) + + def validate_scroll_target_x(self, value: float) -> float: + return round(clamp(value, 0, self.max_scroll_x)) + + def validate_scroll_y(self, value: float) -> float: + return clamp(value, 0, self.max_scroll_y) + + def validate_scroll_target_y(self, value: float) -> float: + return round(clamp(value, 0, self.max_scroll_y)) + + @property + def max_scroll_x(self) -> int: + """The maximum value of `scroll_x`.""" + return max( + 0, + self.virtual_size.width + - (self.container_size.width - self.scrollbar_size_vertical), + ) + + @property + def max_scroll_y(self) -> int: + """The maximum value of `scroll_y`.""" + return max( + 0, + self.virtual_size.height + - (self.container_size.height - self.scrollbar_size_horizontal), + ) + + @property + def is_vertical_scroll_end(self) -> bool: + """Is the vertical scroll position at the maximum?""" + return self.scroll_offset.y == self.max_scroll_y or not self.size + + @property + def is_horizontal_scroll_end(self) -> bool: + """Is the horizontal scroll position at the maximum?""" + return self.scroll_offset.x == self.max_scroll_x or not self.size + + @property + def is_vertical_scrollbar_grabbed(self) -> bool: + """Is the user dragging the vertical scrollbar?""" + return bool(self._vertical_scrollbar and self._vertical_scrollbar.grabbed) + + @property + def is_horizontal_scrollbar_grabbed(self) -> bool: + """Is the user dragging the vertical scrollbar?""" + return bool(self._horizontal_scrollbar and self._horizontal_scrollbar.grabbed) + + @property + def scrollbar_corner(self) -> ScrollBarCorner: + """The scrollbar corner. + + Note: + This will *create* a scrollbar corner if one doesn't exist. + + Returns: + ScrollBarCorner Widget. + """ + from memray._vendor.textual.scrollbar import ScrollBarCorner + + if self._scrollbar_corner is not None: + return self._scrollbar_corner + self._scrollbar_corner = ScrollBarCorner() + self.app._start_widget(self, self._scrollbar_corner) + return self._scrollbar_corner + + @property + def vertical_scrollbar(self) -> ScrollBar: + """The vertical scrollbar (create if necessary). + + Note: + This will *create* a scrollbar if one doesn't exist. + + Returns: + ScrollBar Widget. + """ + from memray._vendor.textual.scrollbar import ScrollBar + + if self._vertical_scrollbar is not None: + return self._vertical_scrollbar + self._vertical_scrollbar = scroll_bar = ScrollBar( + vertical=True, name="vertical", thickness=self.scrollbar_size_vertical + ) + self._vertical_scrollbar.display = False + self.app._start_widget(self, scroll_bar) + return scroll_bar + + @property + def horizontal_scrollbar(self) -> ScrollBar: + """The horizontal scrollbar. + + Note: + This will *create* a scrollbar if one doesn't exist. + + Returns: + ScrollBar Widget. + """ + + from memray._vendor.textual.scrollbar import ScrollBar + + if self._horizontal_scrollbar is not None: + return self._horizontal_scrollbar + self._horizontal_scrollbar = scroll_bar = ScrollBar( + vertical=False, name="horizontal", thickness=self.scrollbar_size_horizontal + ) + self._horizontal_scrollbar.display = False + self.app._start_widget(self, scroll_bar) + return scroll_bar + + def _refresh_scrollbars(self) -> None: + """Refresh scrollbar visibility.""" + if not self.is_scrollable or not self.container_size: + return + + styles = self.styles + overflow_x = styles.overflow_x + overflow_y = styles.overflow_y + + width, height = self._container_size + + show_horizontal = False + if overflow_x == "hidden": + show_horizontal = False + elif overflow_x == "scroll": + show_horizontal = True + elif overflow_x == "auto": + show_horizontal = self.virtual_size.width > width + + show_vertical = False + if overflow_y == "hidden": + show_vertical = False + elif overflow_y == "scroll": + show_vertical = True + elif overflow_y == "auto": + show_vertical = self.virtual_size.height > height + + _show_horizontal = show_horizontal + _show_vertical = show_vertical + + if not ( + overflow_x == "auto" + and overflow_y == "auto" + and (show_horizontal, show_vertical) in self._scrollbar_changes + ): + # When a single scrollbar is shown, the other dimension changes, so we need to recalculate. + if overflow_x == "auto" and show_vertical and not show_horizontal: + show_horizontal = self.virtual_size.width > ( + width - styles.scrollbar_size_vertical + ) + if overflow_y == "auto" and show_horizontal and not show_vertical: + show_vertical = self.virtual_size.height > ( + height - styles.scrollbar_size_horizontal + ) + + if ( + self.show_horizontal_scrollbar != show_horizontal + or self.show_vertical_scrollbar != show_vertical + ): + self._scrollbar_changes.add((_show_horizontal, _show_vertical)) + else: + self._scrollbar_changes.clear() + + self.show_horizontal_scrollbar = show_horizontal + self.show_vertical_scrollbar = show_vertical + + if self._horizontal_scrollbar is not None or show_horizontal: + self.horizontal_scrollbar.display = show_horizontal + if self._vertical_scrollbar is not None or show_vertical: + self.vertical_scrollbar.display = show_vertical + + @property + def scrollbars_enabled(self) -> tuple[bool, bool]: + """A tuple of booleans that indicate if scrollbars are enabled. + + Returns: + A tuple of (, ) + """ + if not self.is_scrollable: + return False, False + + return (self.show_vertical_scrollbar, self.show_horizontal_scrollbar) + + @property + def scrollbars_space(self) -> tuple[int, int]: + """The number of cells occupied by scrollbars for width and height""" + return (self.scrollbar_size_vertical, self.scrollbar_size_horizontal) + + @property + def scrollbar_size_vertical(self) -> int: + """Get the width used by the *vertical* scrollbar. + + Returns: + Number of columns in the vertical scrollbar. + """ + styles = self.styles + if styles.scrollbar_gutter == "stable" and styles.overflow_y == "auto": + return styles.scrollbar_size_vertical + return styles.scrollbar_size_vertical if self.show_vertical_scrollbar else 0 + + @property + def scrollbar_size_horizontal(self) -> int: + """Get the height used by the *horizontal* scrollbar. + + Returns: + Number of rows in the horizontal scrollbar. + """ + styles = self.styles + return styles.scrollbar_size_horizontal if self.show_horizontal_scrollbar else 0 + + @property + def scrollbar_gutter(self) -> Spacing: + """Spacing required to fit scrollbar(s). + + Returns: + Scrollbar gutter spacing. + """ + return Spacing( + 0, self.scrollbar_size_vertical, self.scrollbar_size_horizontal, 0 + ) + + @property + def gutter(self) -> Spacing: + """Spacing for padding / border / scrollbars. + + Returns: + Additional spacing around content area. + """ + return self.styles.gutter + self.scrollbar_gutter + + @property + def size(self) -> Size: + """The size of the content area. + + Returns: + Content area size. + """ + return self.content_region.size + + @property + def scrollable_size(self) -> Size: + """The size of the scrollable content. + + Returns: + Scrollable content size. + """ + return self.scrollable_content_region.size + + @property + def outer_size(self) -> Size: + """The size of the widget (including padding and border). + + Returns: + Outer size. + """ + return self._size + + @property + def container_size(self) -> Size: + """The size of the container (parent widget). + + Returns: + Container size. + """ + return self._container_size + + @property + def content_region(self) -> Region: + """Gets an absolute region containing the content (minus padding and border). + + Returns: + Screen region that contains a widget's content. + """ + content_region = self.region.shrink(self.styles.gutter) + return content_region + + @property + def scrollable_content_region(self) -> Region: + """Gets an absolute region containing the scrollable content (minus padding, border, and scrollbars). + + Returns: + Screen region that contains a widget's content. + """ + content_region = self.region.shrink(self.styles.gutter).shrink( + self.scrollbar_gutter + ) + return content_region + + @property + def content_offset(self) -> Offset: + """An offset from the Widget origin where the content begins. + + Returns: + Offset from widget's origin. + """ + x, y = self.gutter.top_left + return Offset(x, y) + + @property + def content_size(self) -> Size: + """The size of the content area. + + Returns: + Content area size. + """ + return self.region.shrink(self.styles.gutter).size + + @property + def _selection_order(self) -> tuple[int, int]: + """A tuple of integers used to sort widgets in selection order.""" + try: + x, y, _width, _height = self.screen.find_widget(self).region + except (NoScreen, errors.NoWidget): + return (0, 0) + return y, x + + @property + def region(self) -> Region: + """The region occupied by this widget, relative to the Screen. + + Returns: + Region within screen occupied by widget. + """ + try: + return self.screen.find_widget(self).region + except (NoScreen, errors.NoWidget): + return NULL_REGION + + @property + def dock_gutter(self) -> Spacing: + """Space allocated to docks in the parent. + + Returns: + Space to be subtracted from scrollable area. + """ + try: + return self.screen.find_widget(self).dock_gutter + except (NoScreen, errors.NoWidget): + return NULL_SPACING + + @property + def container_viewport(self) -> Region: + """The viewport region (parent window). + + Returns: + The region that contains this widget. + """ + if self.parent is None: + return self.size.region + assert isinstance(self.parent, Widget) + return self.parent.region + + @property + def virtual_region(self) -> Region: + """The widget region relative to its container (which may not be visible, + depending on scroll offset). + + + Returns: + The virtual region. + """ + try: + return self.screen.find_widget(self).virtual_region + except NoScreen: + return Region() + except errors.NoWidget: + return Region() + + @property + def window_region(self) -> Region: + """The region within the scrollable area that is currently visible. + + Returns: + New region. + """ + window_region = self.region.at_offset(self.scroll_offset) + return window_region + + @property + def virtual_region_with_margin(self) -> Region: + """The widget region relative to its container (*including margin*), which may not be visible, + depending on the scroll offset. + + Returns: + The virtual region of the Widget, inclusive of its margin. + """ + return self.virtual_region.grow(self.styles.margin) + + @property + def _self_or_ancestors_disabled(self) -> bool: + """Is this widget or any of its ancestors disabled?""" + + node: Widget | None = self + while isinstance(node, Widget) and not node.is_dom_root: + if node.disabled: + return True + node = node._parent # type:ignore[assignment] + return False + + @property + def focusable(self) -> bool: + """Can this widget currently be focused?""" + return ( + not self.loading + and self.allow_focus() + and self.visible + and not self._self_or_ancestors_disabled + ) + + @property + def _focus_sort_key(self) -> tuple[int, int]: + """Key function to sort widgets into focus order.""" + x, y, _, _ = self.virtual_region + top, _, _, left = self.styles.margin + return y - top, x - left + + @property + def scroll_offset(self) -> Offset: + """Get the current scroll offset. + + Returns: + Offset a container has been scrolled by. + """ + return Offset(round(self.scroll_x), round(self.scroll_y)) + + @property + def container_scroll_offset(self) -> Offset: + """The scroll offset the nearest container ancestor.""" + for node in self.ancestors: + if isinstance(node, Widget) and node.is_scrollable: + return node.scroll_offset + return Offset() + + @property + def _console(self) -> Console: + """Get the current console. + + Returns: + A Rich console object. + """ + return self.app.console + + @property + def _has_relative_children_width(self) -> bool: + """Do any children (or progeny) have a relative width?""" + if not self.is_container: + return False + for child in self.children: + if child.styles.expand == "optimal": + continue + styles = child.styles + if styles.display == "none": + continue + width = styles.width + if width is None: + continue + if styles.is_relative_width or ( + width.is_auto and child._has_relative_children_width + ): + return True + return False + + @property + def _has_relative_children_height(self) -> bool: + """Do any children (or progeny) have a relative height?""" + + if not self.is_container: + return False + for child in self.children: + styles = child.styles + if styles.display == "none": + continue + height = styles.height + if height is None: + continue + if styles.is_relative_height or ( + height.is_auto and child._has_relative_children_height + ): + return True + return False + + @property + def is_on_screen(self) -> bool: + """Check if the node was displayed in the last screen update.""" + try: + self.screen.find_widget(self) + except (NoScreen, errors.NoWidget): + return False + return True + + def _resolve_extrema( + self, + container: Size, + viewport: Size, + width_fraction: Fraction, + height_fraction: Fraction, + ) -> Extrema: + """Resolve minimum and maximum values for width and height. + + Args: + container: Size of outer widget. + viewport: Viewport size. + width_fraction: Size of 1fr width. + height_fraction: Size of 1fr height. + + Returns: + Extrema object. + """ + + min_width: Fraction | None = None + max_width: Fraction | None = None + min_height: Fraction | None = None + max_height: Fraction | None = None + + styles = self.styles + container -= styles.margin.totals + if styles.box_sizing == "border-box": + gutter_width, gutter_height = styles.gutter.totals + else: + gutter_width = gutter_height = 0 + + if styles.min_width is not None: + min_width = ( + styles.min_width.resolve(container, viewport, width_fraction) + - gutter_width + ) + + if styles.max_width is not None: + max_width = ( + styles.max_width.resolve(container, viewport, width_fraction) + - gutter_width + ) + if styles.min_height is not None: + min_height = ( + styles.min_height.resolve(container, viewport, height_fraction) + - gutter_height + ) + + if styles.max_height is not None: + max_height = ( + styles.max_height.resolve(container, viewport, height_fraction) + - gutter_height + ) + + extrema = Extrema(min_width, max_width, min_height, max_height) + return extrema + + def animate( + self, + attribute: str, + value: float | Animatable, + *, + final_value: object = ..., + duration: float | None = None, + speed: float | None = None, + delay: float = 0.0, + easing: EasingFunction | str = DEFAULT_EASING, + on_complete: CallbackType | None = None, + level: AnimationLevel = "full", + ) -> None: + """Animate an attribute. + + Args: + attribute: Name of the attribute to animate. + value: The value to animate to. + final_value: The final value of the animation. Defaults to `value` if not set. + duration: The duration (in seconds) of the animation. + speed: The speed of the animation. + delay: A delay (in seconds) before the animation starts. + easing: An easing method. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + """ + if self._animate is None: + self._animate = self.app.animator.bind(self) + assert self._animate is not None + self._animate( + attribute, + value, + final_value=final_value, + duration=duration, + speed=speed, + delay=delay, + easing=easing, + on_complete=on_complete, + level=level, + ) + + async def stop_animation(self, attribute: str, complete: bool = True) -> None: + """Stop an animation on an attribute. + + Args: + attribute: Name of the attribute whose animation should be stopped. + complete: Should the animation be set to its final value? + + Note: + If there is no animation scheduled or running, this is a no-op. + """ + await self.app.animator.stop_animation(self, attribute, complete) + + @property + def layout(self) -> Layout: + """Get the layout object if set in styles, or a default layout. + + Returns: + A layout object. + """ + return self.styles.layout or self._default_layout + + @property + def is_container(self) -> bool: + """Is this widget a container (contains other widgets)?""" + return self.styles.layout is not None or bool(self._nodes) + + @property + def is_scrollable(self) -> bool: + """Can this widget be scrolled?""" + return self.styles.layout is not None or bool(self._nodes) + + @property + def is_scrolling(self) -> bool: + """Is this widget currently scrolling?""" + current_time = monotonic() + for node in self.ancestors: + if not isinstance(node, Widget): + break + if ( + node.scroll_x != node.scroll_target_x + or node.scroll_y != node.scroll_target_y + ): + return True + if current_time - node._last_scroll_time < 0.1: + # Scroll ended very recently + return True + return False + + @property + def layer(self) -> str: + """Get the name of this widgets layer. + + Returns: + Name of layer. + """ + return self.styles.layer or "default" + + @property + def layers(self) -> tuple[str, ...]: + """Layers of from parent. + + Returns: + Tuple of layer names. + """ + layers: tuple[str, ...] = ("default",) + for node in self.ancestors_with_self: + if not isinstance(node, Widget): + break + if node.styles.has_rule("layers"): + layers = node.styles.layers + return layers + + @property + def link_style(self) -> Style: + """Style of links. + + Returns: + Rich style. + """ + styles = self.styles + _, background = self.background_colors + link_background = background + styles.link_background + link_color = link_background + ( + link_background.get_contrast_text(styles.link_color.a) + if styles.auto_link_color + else styles.link_color + ) + style = styles.link_style + Style.from_color( + link_color.rich_color, + link_background.rich_color if styles.link_background.a else None, + ) + return style + + @property + def link_style_hover(self) -> Style: + """Style of links underneath the mouse cursor. + + Returns: + Rich Style. + """ + styles = self.styles + _, background = self.background_colors + hover_background = background + styles.link_background_hover + hover_color = hover_background + ( + hover_background.get_contrast_text(styles.link_color_hover.a) + if styles.auto_link_color_hover + else styles.link_color_hover + ) + style = styles.link_style_hover + Style.from_color( + hover_color.rich_color, + hover_background.rich_color, + ) + return style + + @property + def select_container(self) -> Widget: + """The widget's container used when selecting text.. + + Returns: + A widget which contains this widget. + """ + container: Widget = self + for widget in self.ancestors: + if isinstance(widget, Widget) and widget.is_scrollable: + return widget + return container + + def _set_dirty(self, *regions: Region) -> None: + """Set the Widget as 'dirty' (requiring re-paint). + + Regions should be specified as positional args. If no regions are added, then + the entire widget will be considered dirty. + + Args: + *regions: Regions which require a repaint. + """ + if regions: + content_offset = self.content_offset + widget_regions = [region.translate(content_offset) for region in regions] + self._dirty_regions.update(widget_regions) + self._repaint_regions.update(widget_regions) + self._styles_cache.set_dirty(*widget_regions) + else: + self._dirty_regions.clear() + self._repaint_regions.clear() + self._styles_cache.clear() + self._styles_cache.set_dirty(self.size.region) + outer_size = self.outer_size + self._dirty_regions.add(outer_size.region) + if outer_size: + self._repaint_regions.add(outer_size.region) + + def _exchange_repaint_regions(self) -> Collection[Region]: + """Get a copy of the regions which need a repaint, and clear internal cache. + + Returns: + Regions to repaint. + """ + regions = self._repaint_regions.copy() + self._repaint_regions.clear() + return regions + + def _scroll_to( + self, + x: float | None = None, + y: float | None = None, + *, + animate: bool = True, + speed: float | None = None, + duration: float | None = None, + easing: EasingFunction | str | None = None, + force: bool = False, + on_complete: CallbackType | None = None, + level: AnimationLevel = "basic", + release_anchor: bool = True, + ) -> bool: + """Scroll to a given (absolute) coordinate, optionally animating. + + Args: + x: X coordinate (column) to scroll to, or `None` for no change. + y: Y coordinate (row) to scroll to, or `None` for no change. + animate: Animate to new scroll position. + speed: Speed of scroll if `animate` is `True`. Or `None` to use duration. + duration: Duration of animation, if `animate` is `True` and speed is `None`. + easing: An easing method for the scrolling animation. + force: Force scrolling even when prohibited by overflow styling. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + release_anchor: If `True` call `release_anchor`. + + Returns: + `True` if the scroll position changed, otherwise `False`. + """ + if release_anchor: + self.release_anchor() + maybe_scroll_x = x is not None and (self.allow_horizontal_scroll or force) + maybe_scroll_y = y is not None and (self.allow_vertical_scroll or force) + scrolled_x = scrolled_y = False + + animator = self.app.animator + animator.force_stop_animation(self, "scroll_x") + animator.force_stop_animation(self, "scroll_y") + + def _animate_on_complete() -> None: + """set last scroll time, and invoke callback.""" + self.app._realtime_animation_complete() + self._last_scroll_time = monotonic() + if on_complete is not None: + self.call_next(on_complete) + + if animate: + # TODO: configure animation speed + if duration is None and speed is None: + speed = 50 + + if easing is None: + easing = DEFAULT_SCROLL_EASING + + if maybe_scroll_x: + assert x is not None + self.scroll_target_x = x + if x != self.scroll_x: + self.app._realtime_animation_begin() + self.animate( + "scroll_x", + self.scroll_target_x, + speed=speed, + duration=duration, + easing=easing, + on_complete=_animate_on_complete, + level=level, + ) + scrolled_x = True + if maybe_scroll_y: + assert y is not None + self.scroll_target_y = y + if y != self.scroll_y: + self.app._realtime_animation_begin() + self.animate( + "scroll_y", + self.scroll_target_y, + speed=speed, + duration=duration, + easing=easing, + on_complete=_animate_on_complete, + level=level, + ) + scrolled_y = True + + else: + if maybe_scroll_x: + assert x is not None + scroll_x = self.scroll_x + self.scroll_target_x = self.scroll_x = x + scrolled_x = scroll_x != self.scroll_x + if maybe_scroll_y: + assert y is not None + scroll_y = self.scroll_y + self.scroll_target_y = self.scroll_y = y + scrolled_y = scroll_y != self.scroll_y + + self._last_scroll_time = monotonic() + if on_complete is not None: + self.call_after_refresh(on_complete) + + return scrolled_x or scrolled_y + + @property + def allow_select(self) -> bool: + """Check if this widget permits text selection. + + Returns: + `True` if the widget supports text selection, otherwise `False`. + """ + return self.ALLOW_SELECT + + def pre_layout(self, layout: Layout) -> None: + """This method id called prior to a layout operation. + + Implement this method if you want to make updates that should impact + the layout. + + Args: + layout: The [Layout][textual.layout.Layout] instance that will be used to arrange this widget's children. + + """ + + def set_scroll(self, x: float | None, y: float | None) -> None: + """Set the scroll position without any validation. + + This is a low-level method for when you want to see the scroll position in the next frame. + For a more fully featured method, see [`scroll_to`][textual.widget.Widget.scroll_to]. + + Args: + x: Desired `X` coordinate. + y: Desired `Y` coordinate. + """ + if x is not None: + self.set_reactive(Widget.scroll_x, round(x)) + if y is not None: + self.set_reactive(Widget.scroll_y, round(y)) + + def scroll_to( + self, + x: float | None = None, + y: float | None = None, + *, + animate: bool = True, + speed: float | None = None, + duration: float | None = None, + easing: EasingFunction | str | None = None, + force: bool = False, + on_complete: CallbackType | None = None, + level: AnimationLevel = "basic", + immediate: bool = False, + release_anchor: bool = True, + ) -> None: + """Scroll to a given (absolute) coordinate, optionally animating. + + Args: + x: X coordinate (column) to scroll to, or `None` for no change. + y: Y coordinate (row) to scroll to, or `None` for no change. + animate: Animate to new scroll position. + speed: Speed of scroll if `animate` is `True`; or `None` to use `duration`. + duration: Duration of animation, if `animate` is `True` and `speed` is `None`. + easing: An easing method for the scrolling animation. + force: Force scrolling even when prohibited by overflow styling. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + immediate: If `False` the scroll will be deferred until after a screen refresh, + set to `True` to scroll immediately. + release_anchor: If `True` call `release_anchor`. + + Note: + The call to scroll is made after the next refresh. + """ + if release_anchor: + self.release_anchor() + animator = self.app.animator + if x is not None: + animator.force_stop_animation(self, "scroll_x") + if y is not None: + animator.force_stop_animation(self, "scroll_y") + if immediate: + self._scroll_to( + x, + y, + animate=animate, + speed=speed, + duration=duration, + easing=easing, + force=force, + on_complete=on_complete, + level=level, + ) + else: + self.call_after_refresh( + self._scroll_to, + x, + y, + animate=animate, + speed=speed, + duration=duration, + easing=easing, + force=force, + on_complete=on_complete, + level=level, + ) + + def scroll_relative( + self, + x: float | None = None, + y: float | None = None, + *, + animate: bool = True, + speed: float | None = None, + duration: float | None = None, + easing: EasingFunction | str | None = None, + force: bool = False, + on_complete: CallbackType | None = None, + level: AnimationLevel = "basic", + immediate: bool = False, + ) -> None: + """Scroll relative to current position. + + Args: + x: X distance (columns) to scroll, or ``None`` for no change. + y: Y distance (rows) to scroll, or ``None`` for no change. + animate: Animate to new scroll position. + speed: Speed of scroll if `animate` is `True`. Or `None` to use `duration`. + duration: Duration of animation, if animate is `True` and speed is `None`. + easing: An easing method for the scrolling animation. + force: Force scrolling even when prohibited by overflow styling. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + immediate: If `False` the scroll will be deferred until after a screen refresh, + set to `True` to scroll immediately. + """ + self.scroll_to( + None if x is None else (self.scroll_x + x), + None if y is None else (self.scroll_y + y), + animate=animate, + speed=speed, + duration=duration, + easing=easing, + force=force, + on_complete=on_complete, + level=level, + immediate=immediate, + ) + + def scroll_home( + self, + *, + animate: bool = True, + speed: float | None = None, + duration: float | None = None, + easing: EasingFunction | str | None = None, + force: bool = False, + on_complete: CallbackType | None = None, + level: AnimationLevel = "basic", + immediate: bool = False, + x_axis: bool = True, + y_axis: bool = True, + ) -> None: + """Scroll to home position. + + Args: + animate: Animate scroll. + speed: Speed of scroll if animate is `True`; or `None` to use duration. + duration: Duration of animation, if `animate` is `True` and `speed` is `None`. + easing: An easing method for the scrolling animation. + force: Force scrolling even when prohibited by overflow styling. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + immediate: If `False` the scroll will be deferred until after a screen refresh, + set to `True` to scroll immediately. + x_axis: Allow scrolling on X axis? + y_axis: Allow scrolling on Y axis? + """ + if speed is None and duration is None: + duration = 1.0 + self.scroll_to( + 0 if x_axis else None, + 0 if y_axis else None, + animate=animate, + speed=speed, + duration=duration, + easing=easing, + force=force, + on_complete=on_complete, + level=level, + immediate=immediate, + ) + + def scroll_end( + self, + *, + animate: bool = True, + speed: float | None = None, + duration: float | None = None, + easing: EasingFunction | str | None = None, + force: bool = False, + on_complete: CallbackType | None = None, + level: AnimationLevel = "basic", + immediate: bool = False, + x_axis: bool = True, + y_axis: bool = True, + ) -> None: + """Scroll to the end of the container. + + Args: + animate: Animate scroll. + speed: Speed of scroll if `animate` is `True`; or `None` to use `duration`. + duration: Duration of animation, if `animate` is `True` and `speed` is `None`. + easing: An easing method for the scrolling animation. + force: Force scrolling even when prohibited by overflow styling. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + immediate: If `False` the scroll will be deferred until after a screen refresh, + set to `True` to scroll immediately. + x_axis: Allow scrolling on X axis? + y_axis: Allow scrolling on Y axis? + + """ + + if speed is None and duration is None: + duration = 1.0 + + async def scroll_end_on_complete() -> None: + """It's possible new content was added before we reached the end.""" + if on_complete is not None: + self.call_next(on_complete) + + # In most cases we'd call self.scroll_to and let it handle the call + # to do things after a refresh, but here we need the refresh to + # happen first so that we can get the new self.max_scroll_y (that + # is, we need the layout to work out and then figure out how big + # things are). Because of this we'll create a closure over the call + # here and make our own call to call_after_refresh. + def _lazily_scroll_end() -> None: + """Scroll to the end of the widget.""" + self._scroll_to( + 0 if x_axis else None, + self.max_scroll_y if y_axis else None, + animate=animate, + speed=speed, + duration=duration, + easing=easing, + force=force, + on_complete=scroll_end_on_complete, + level=level, + release_anchor=False, + ) + + if self._anchored and self._anchor_released: + self._anchor_released = False + + if immediate: + _lazily_scroll_end() + else: + self.call_after_refresh(_lazily_scroll_end) + + def scroll_left( + self, + *, + animate: bool = True, + speed: float | None = None, + duration: float | None = None, + easing: EasingFunction | str | None = None, + force: bool = False, + on_complete: CallbackType | None = None, + level: AnimationLevel = "basic", + immediate: bool = False, + ) -> None: + """Scroll one cell left. + + Args: + animate: Animate scroll. + speed: Speed of scroll if `animate` is `True`; or `None` to use `duration`. + duration: Duration of animation, if `animate` is `True` and `speed` is `None`. + easing: An easing method for the scrolling animation. + force: Force scrolling even when prohibited by overflow styling. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + immediate: If `False` the scroll will be deferred until after a screen refresh, + set to `True` to scroll immediately. + """ + self.scroll_to( + x=self.scroll_target_x - 1, + animate=animate, + speed=speed, + duration=duration, + easing=easing, + force=force, + on_complete=on_complete, + level=level, + immediate=immediate, + ) + + def _scroll_left_for_pointer( + self, + *, + animate: bool = True, + speed: float | None = None, + duration: float | None = None, + easing: EasingFunction | str | None = None, + force: bool = False, + on_complete: CallbackType | None = None, + level: AnimationLevel = "basic", + ) -> bool: + """Scroll left one position, taking scroll sensitivity into account. + + Args: + animate: Animate scroll. + speed: Speed of scroll if `animate` is `True`; or `None` to use `duration`. + duration: Duration of animation, if `animate` is `True` and `speed` is `None`. + easing: An easing method for the scrolling animation. + force: Force scrolling even when prohibited by overflow styling. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + + Returns: + `True` if any scrolling was done. + + Note: + How much is scrolled is controlled by + [App.scroll_sensitivity_x][textual.app.App.scroll_sensitivity_x]. + """ + return self._scroll_to( + x=self.scroll_target_x - self.app.scroll_sensitivity_x, + animate=animate, + speed=speed, + duration=duration, + easing=easing, + force=force, + on_complete=on_complete, + level=level, + ) + + def scroll_right( + self, + *, + animate: bool = True, + speed: float | None = None, + duration: float | None = None, + easing: EasingFunction | str | None = None, + force: bool = False, + on_complete: CallbackType | None = None, + level: AnimationLevel = "basic", + immediate: bool = False, + ) -> None: + """Scroll one cell right. + + Args: + animate: Animate scroll. + speed: Speed of scroll if animate is `True`; or `None` to use `duration`. + duration: Duration of animation, if `animate` is `True` and `speed` is `None`. + easing: An easing method for the scrolling animation. + force: Force scrolling even when prohibited by overflow styling. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + immediate: If `False` the scroll will be deferred until after a screen refresh, + set to `True` to scroll immediately. + """ + self.scroll_to( + x=self.scroll_target_x + 1, + animate=animate, + speed=speed, + duration=duration, + easing=easing, + force=force, + on_complete=on_complete, + level=level, + immediate=immediate, + ) + + def _scroll_right_for_pointer( + self, + *, + animate: bool = True, + speed: float | None = None, + duration: float | None = None, + easing: EasingFunction | str | None = None, + force: bool = False, + on_complete: CallbackType | None = None, + level: AnimationLevel = "basic", + ) -> bool: + """Scroll right one position, taking scroll sensitivity into account. + + Args: + animate: Animate scroll. + speed: Speed of scroll if animate is `True`; or `None` to use `duration`. + duration: Duration of animation, if `animate` is `True` and `speed` is `None`. + easing: An easing method for the scrolling animation. + force: Force scrolling even when prohibited by overflow styling. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + + Returns: + `True` if any scrolling was done. + + Note: + How much is scrolled is controlled by + [App.scroll_sensitivity_x][textual.app.App.scroll_sensitivity_x]. + """ + return self._scroll_to( + x=self.scroll_target_x + self.app.scroll_sensitivity_x, + animate=animate, + speed=speed, + duration=duration, + easing=easing, + force=force, + on_complete=on_complete, + level=level, + ) + + def scroll_down( + self, + *, + animate: bool = True, + speed: float | None = None, + duration: float | None = None, + easing: EasingFunction | str | None = None, + force: bool = False, + on_complete: CallbackType | None = None, + level: AnimationLevel = "basic", + immediate: bool = False, + ) -> None: + """Scroll one line down. + + Args: + animate: Animate scroll. + speed: Speed of scroll if `animate` is `True`; or `None` to use `duration`. + duration: Duration of animation, if `animate` is `True` and `speed` is `None`. + easing: An easing method for the scrolling animation. + force: Force scrolling even when prohibited by overflow styling. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + immediate: If `False` the scroll will be deferred until after a screen refresh, + set to `True` to scroll immediately. + """ + self.scroll_to( + y=self.scroll_target_y + 1, + animate=animate, + speed=speed, + duration=duration, + easing=easing, + force=force, + on_complete=on_complete, + level=level, + immediate=immediate, + ) + + def _scroll_down_for_pointer( + self, + *, + animate: bool = True, + speed: float | None = None, + duration: float | None = None, + easing: EasingFunction | str | None = None, + force: bool = False, + on_complete: CallbackType | None = None, + level: AnimationLevel = "basic", + ) -> bool: + """Scroll down one position, taking scroll sensitivity into account. + + Args: + animate: Animate scroll. + speed: Speed of scroll if `animate` is `True`; or `None` to use `duration`. + duration: Duration of animation, if `animate` is `True` and `speed` is `None`. + easing: An easing method for the scrolling animation. + force: Force scrolling even when prohibited by overflow styling. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + + Returns: + `True` if any scrolling was done. + + Note: + How much is scrolled is controlled by + [App.scroll_sensitivity_y][textual.app.App.scroll_sensitivity_y]. + """ + return self._scroll_to( + y=self.scroll_target_y + self.app.scroll_sensitivity_y, + animate=animate, + speed=speed, + duration=duration, + easing=easing, + force=force, + on_complete=on_complete, + level=level, + ) + + def scroll_up( + self, + *, + animate: bool = True, + speed: float | None = None, + duration: float | None = None, + easing: EasingFunction | str | None = None, + force: bool = False, + on_complete: CallbackType | None = None, + level: AnimationLevel = "basic", + immediate: bool = False, + ) -> None: + """Scroll one line up. + + Args: + animate: Animate scroll. + speed: Speed of scroll if `animate` is `True`; or `None` to use `duration`. + duration: Duration of animation, if `animate` is `True` and speed is `None`. + easing: An easing method for the scrolling animation. + force: Force scrolling even when prohibited by overflow styling. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + immediate: If `False` the scroll will be deferred until after a screen refresh, + set to `True` to scroll immediately. + """ + self.scroll_to( + y=self.scroll_target_y - 1, + animate=animate, + speed=speed, + duration=duration, + easing=easing, + force=force, + on_complete=on_complete, + level=level, + immediate=immediate, + ) + + def _scroll_up_for_pointer( + self, + *, + animate: bool = True, + speed: float | None = None, + duration: float | None = None, + easing: EasingFunction | str | None = None, + force: bool = False, + on_complete: CallbackType | None = None, + level: AnimationLevel = "basic", + ) -> bool: + """Scroll up one position, taking scroll sensitivity into account. + + Args: + animate: Animate scroll. + speed: Speed of scroll if `animate` is `True`; or `None` to use `duration`. + duration: Duration of animation, if `animate` is `True` and speed is `None`. + easing: An easing method for the scrolling animation. + force: Force scrolling even when prohibited by overflow styling. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + + Returns: + `True` if any scrolling was done. + + Note: + How much is scrolled is controlled by + [App.scroll_sensitivity_y][textual.app.App.scroll_sensitivity_y]. + """ + return self._scroll_to( + y=self.scroll_target_y - self.app.scroll_sensitivity_y, + animate=animate, + speed=speed, + duration=duration, + easing=easing, + force=force, + on_complete=on_complete, + level=level, + ) + + def scroll_page_up( + self, + *, + animate: bool = True, + speed: float | None = None, + duration: float | None = None, + easing: EasingFunction | str | None = None, + force: bool = False, + on_complete: CallbackType | None = None, + level: AnimationLevel = "basic", + ) -> None: + """Scroll one page up. + + Args: + animate: Animate scroll. + speed: Speed of scroll if animate is `True`; or `None` to use `duration`. + duration: Duration of animation, if `animate` is `True` and `speed` is `None`. + easing: An easing method for the scrolling animation. + force: Force scrolling even when prohibited by overflow styling. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + """ + self.scroll_to( + y=self.scroll_y - self.scrollable_content_region.height, + animate=animate, + speed=speed, + duration=duration, + easing=easing, + force=force, + on_complete=on_complete, + level=level, + ) + + def scroll_page_down( + self, + *, + animate: bool = True, + speed: float | None = None, + duration: float | None = None, + easing: EasingFunction | str | None = None, + force: bool = False, + on_complete: CallbackType | None = None, + level: AnimationLevel = "basic", + ) -> None: + """Scroll one page down. + + Args: + animate: Animate scroll. + speed: Speed of scroll if animate is `True`; or `None` to use `duration`. + duration: Duration of animation, if `animate` is `True` and `speed` is `None`. + easing: An easing method for the scrolling animation. + force: Force scrolling even when prohibited by overflow styling. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + """ + self.scroll_to( + y=self.scroll_y + self.scrollable_content_region.height, + animate=animate, + speed=speed, + duration=duration, + easing=easing, + force=force, + on_complete=on_complete, + level=level, + ) + + def scroll_page_left( + self, + *, + animate: bool = True, + speed: float | None = None, + duration: float | None = None, + easing: EasingFunction | str | None = None, + force: bool = False, + on_complete: CallbackType | None = None, + level: AnimationLevel = "basic", + ) -> None: + """Scroll one page left. + + Args: + animate: Animate scroll. + speed: Speed of scroll if animate is `True`; or `None` to use `duration`. + duration: Duration of animation, if `animate` is `True` and `speed` is `None`. + easing: An easing method for the scrolling animation. + force: Force scrolling even when prohibited by overflow styling. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + """ + if speed is None and duration is None: + duration = 0.3 + self.scroll_to( + x=self.scroll_x - self.scrollable_content_region.width, + animate=animate, + speed=speed, + duration=duration, + easing=easing, + force=force, + on_complete=on_complete, + level=level, + ) + + def scroll_page_right( + self, + *, + animate: bool = True, + speed: float | None = None, + duration: float | None = None, + easing: EasingFunction | str | None = None, + force: bool = False, + on_complete: CallbackType | None = None, + level: AnimationLevel = "basic", + ) -> None: + """Scroll one page right. + + Args: + animate: Animate scroll. + speed: Speed of scroll if animate is `True`; or `None` to use `duration`. + duration: Duration of animation, if `animate` is `True` and `speed` is `None`. + easing: An easing method for the scrolling animation. + force: Force scrolling even when prohibited by overflow styling. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + """ + if speed is None and duration is None: + duration = 0.3 + self.scroll_to( + x=self.scroll_x + self.scrollable_content_region.width, + animate=animate, + speed=speed, + duration=duration, + easing=easing, + force=force, + on_complete=on_complete, + level=level, + ) + + def scroll_to_widget( + self, + widget: Widget, + *, + animate: bool = True, + speed: float | None = None, + duration: float | None = None, + easing: EasingFunction | str | None = None, + center: bool = False, + top: bool = False, + origin_visible: bool = True, + force: bool = False, + on_complete: CallbackType | None = None, + level: AnimationLevel = "basic", + immediate: bool = False, + ) -> bool: + """Scroll scrolling to bring a widget into view. + + Args: + widget: A descendant widget. + animate: `True` to animate, or `False` to jump. + speed: Speed of scroll if `animate` is `True`; or `None` to use `duration`. + duration: Duration of animation, if `animate` is `True` and `speed` is `None`. + easing: An easing method for the scrolling animation. + top: Scroll widget to top of container. + origin_visible: Ensure that the top left of the widget is within the window. + force: Force scrolling even when prohibited by overflow styling. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + immediate: If `False` the scroll will be deferred until after a screen refresh, + set to `True` to scroll immediately. + + Returns: + `True` if any scrolling has occurred in any descendant, otherwise `False`. + """ + # Grow the region by the margin so to keep the margin in view. + region = widget.virtual_region_with_margin + scrolled = False + + if not region.size: + if on_complete is not None: + self.call_after_refresh(on_complete) + return False + + while isinstance(widget.parent, Widget) and widget is not self: + if not region: + break + + container = widget.parent + + if widget.styles.dock != "none": + scroll_offset = Offset(0, 0) + else: + scroll_offset = container.scroll_to_region( + region, + spacing=widget.dock_gutter, + animate=animate, + speed=speed, + duration=duration, + center=center, + top=top, + easing=easing, + origin_visible=origin_visible, + force=force, + on_complete=on_complete, + level=level, + immediate=immediate, + ) + if scroll_offset: + scrolled = True + + # Adjust the region by the amount we just scrolled it, and convert to + # its parent's virtual coordinate system. + region = ( + ( + region.translate(-scroll_offset) + .translate(container.styles.margin.top_left) + .translate(container.styles.border.spacing.top_left) + .translate(container.virtual_region_with_margin.offset) + ) + .grow(container.styles.margin) + .intersection(container.virtual_region_with_margin) + ) + + widget = container + return scrolled + + def scroll_to_region( + self, + region: Region, + *, + spacing: Spacing | None = None, + animate: bool = True, + speed: float | None = None, + duration: float | None = None, + easing: EasingFunction | str | None = None, + center: bool = False, + top: bool = False, + origin_visible: bool = True, + force: bool = False, + on_complete: CallbackType | None = None, + level: AnimationLevel = "basic", + x_axis: bool = True, + y_axis: bool = True, + immediate: bool = False, + ) -> Offset: + """Scrolls a given region into view, if required. + + This method will scroll the least distance required to move `region` fully within + the scrollable area. + + Args: + region: A region that should be visible. + spacing: Optional spacing around the region. + animate: `True` to animate, or `False` to jump. + speed: Speed of scroll if `animate` is `True`; or `None` to use `duration`. + duration: Duration of animation, if `animate` is `True` and `speed` is `None`. + easing: An easing method for the scrolling animation. + top: Scroll `region` to top of container. + origin_visible: Ensure that the top left of the widget is within the window. + force: Force scrolling even when prohibited by overflow styling. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + x_axis: Allow scrolling on X axis? + y_axis: Allow scrolling on Y axis? + immediate: If `False` the scroll will be deferred until after a screen refresh, + set to `True` to scroll immediately. + + Returns: + The distance that was scrolled. + """ + window = self.scrollable_content_region.at_offset(self.scroll_offset) + if spacing is not None: + window = window.shrink(spacing) + + if window in region and not (top or center): + if on_complete is not None: + self.call_after_refresh(on_complete) + return Offset() + + def clamp_delta(delta: Offset) -> Offset: + """Clamp the delta to avoid scrolling out of range.""" + scroll_x, scroll_y = self.scroll_offset + delta = Offset( + clamp(scroll_x + delta.x, 0, self.max_scroll_x) - scroll_x, + clamp(scroll_y + delta.y, 0, self.max_scroll_y) - scroll_y, + ) + return delta + + if center: + region_center_x, region_center_y = region.center + window_center_x, window_center_y = window.center + + delta = clamp_delta( + Offset( + int(region_center_x - window_center_x + 0.5), + int(region_center_y - window_center_y + 0.5), + ) + ) + if origin_visible and (region.offset not in window.translate(delta)): + delta = clamp_delta( + Region.get_scroll_to_visible(window, region, top=True) + ) + else: + delta = clamp_delta( + Region.get_scroll_to_visible(window, region, top=top), + ) + + if not self.allow_horizontal_scroll and not force: + delta = Offset(0, delta.y) + + if not self.allow_vertical_scroll and not force: + delta = Offset(delta.x, 0) + + if delta: + delta_x = delta.x if x_axis else 0 + delta_y = delta.y if y_axis else 0 + if speed is None and duration is None: + duration = 0.2 + self.scroll_relative( + delta_x or None, + delta_y or None, + animate=animate, + speed=speed, + duration=duration, + easing=easing, + force=force, + on_complete=on_complete, + level=level, + immediate=immediate, + ) + else: + if on_complete is not None: + self.call_after_refresh(on_complete) + return delta + + def scroll_visible( + self, + animate: bool = True, + *, + speed: float | None = None, + duration: float | None = None, + top: bool = False, + easing: EasingFunction | str | None = None, + force: bool = False, + on_complete: CallbackType | None = None, + level: AnimationLevel = "basic", + immediate: bool = False, + ) -> None: + """Scroll the container to make this widget visible. + + Args: + animate: Animate scroll. + speed: Speed of scroll if animate is `True`; or `None` to use `duration`. + duration: Duration of animation, if `animate` is `True` and `speed` is `None`. + top: Scroll to top of container. + easing: An easing method for the scrolling animation. + force: Force scrolling even when prohibited by overflow styling. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + immediate: If `False` the scroll will be deferred until after a screen refresh, + set to `True` to scroll immediately. + """ + parent = self.parent + if isinstance(parent, Widget): + if self._size: + self.screen.scroll_to_widget( + self, + animate=animate, + speed=speed, + duration=duration, + top=top, + easing=easing, + force=force, + on_complete=on_complete, + level=level, + immediate=immediate, + ) + else: + # self.region is falsy which may indicate the widget hasn't been through a layout operation + # We can potentially make it do the right thing by postponing the scroll to after a refresh + parent.call_after_refresh( + self.screen.scroll_to_widget, + self, + animate=animate, + speed=speed, + duration=duration, + top=top, + easing=easing, + force=force, + on_complete=on_complete, + level=level, + immediate=immediate, + ) + + def scroll_to_center( + self, + widget: Widget, + animate: bool = True, + *, + speed: float | None = None, + duration: float | None = None, + easing: EasingFunction | str | None = None, + force: bool = False, + origin_visible: bool = True, + on_complete: CallbackType | None = None, + level: AnimationLevel = "basic", + immediate: bool = False, + ) -> None: + """Scroll this widget to the center of self. + + The center of the widget will be scrolled to the center of the container. + + Args: + widget: The widget to scroll to the center of self. + animate: Whether to animate the scroll. + speed: Speed of scroll if animate is `True`; or `None` to use `duration`. + duration: Duration of animation, if `animate` is `True` and `speed` is `None`. + easing: An easing method for the scrolling animation. + force: Force scrolling even when prohibited by overflow styling. + origin_visible: Ensure that the top left corner of the widget remains visible after the scroll. + on_complete: A callable to invoke when the animation is finished. + level: Minimum level required for the animation to take place (inclusive). + immediate: If `False` the scroll will be deferred until after a screen refresh, + set to `True` to scroll immediately. + """ + + self.scroll_to_widget( + widget=widget, + animate=animate, + speed=speed, + duration=duration, + easing=easing, + force=force, + center=True, + origin_visible=origin_visible, + on_complete=on_complete, + level=level, + immediate=immediate, + ) + + def can_view_entire(self, widget: Widget) -> bool: + """Check if a given widget is *fully* within the current view (scrollable area). + + Note: This doesn't necessarily equate to a widget being visible. + There are other reasons why a widget may not be visible. + + Args: + widget: A widget that is a descendant of self. + + Returns: + `True` if the entire widget is in view, `False` if it is partially visible or not in view. + """ + if widget is self: + return True + + if widget not in self.screen._compositor.visible_widgets: + return False + + region = widget.region + node: Widget = widget + + while isinstance(node.parent, Widget) and node is not self: + if region not in node.parent.scrollable_content_region: + return False + node = node.parent + return True + + def can_view_partial(self, widget: Widget) -> bool: + """Check if a given widget at least partially visible within the current view (scrollable area). + + Args: + widget: A widget that is a descendant of self. + + Returns: + `True` if any part of the widget is visible, `False` if it is outside of the viewable area. + """ + if widget is self: + return True + + if widget not in self.screen._compositor.visible_widgets or not widget.display: + return False + + region = widget.region + node: Widget = widget + + while isinstance(node.parent, Widget) and node is not self: + if not region.overlaps(node.parent.scrollable_content_region): + return False + node = node.parent + return True + + def __init_subclass__( + cls, + can_focus: bool | None = None, + can_focus_children: bool | None = None, + inherit_css: bool = True, + inherit_bindings: bool = True, + ) -> None: + name = cls.__name__ + if not name[0].isupper() and not name.startswith("_"): + raise BadWidgetName( + f"Widget subclass {name!r} should be capitalized or start with '_'." + ) + + super().__init_subclass__( + inherit_css=inherit_css, + inherit_bindings=inherit_bindings, + ) + base = cls.__mro__[0] + if issubclass(base, Widget): + cls.can_focus = base.can_focus if can_focus is None else can_focus + cls.can_focus_children = ( + base.can_focus_children + if can_focus_children is None + else can_focus_children + ) + + def __rich_repr__(self) -> rich.repr.Result: + try: + yield "id", self.id, None + if self.name: + yield "name", self.name + if self.classes: + yield "classes", " ".join(self.classes) + except AttributeError: + pass + + def _get_scrollable_region(self, region: Region) -> Region: + """Adjusts the Widget region to accommodate scrollbars. + + Args: + region: A region for the widget. + + Returns: + The widget region minus scrollbars. + """ + show_vertical_scrollbar, show_horizontal_scrollbar = self.scrollbars_enabled + + styles = self.styles + scrollbar_size_horizontal = styles.scrollbar_size_horizontal + scrollbar_size_vertical = styles.scrollbar_size_vertical + + show_vertical_scrollbar = bool( + show_vertical_scrollbar and scrollbar_size_vertical + ) + show_horizontal_scrollbar = bool( + show_horizontal_scrollbar and scrollbar_size_horizontal + ) + + if styles.scrollbar_gutter == "stable": + # Let's _always_ reserve some space, whether the scrollbar is actually displayed or not: + show_vertical_scrollbar = True + scrollbar_size_vertical = styles.scrollbar_size_vertical + + if show_horizontal_scrollbar and show_vertical_scrollbar: + (region, _, _, _) = region.split( + -scrollbar_size_vertical, + -scrollbar_size_horizontal, + ) + elif show_vertical_scrollbar: + region, _ = region.split_vertical(-scrollbar_size_vertical) + elif show_horizontal_scrollbar: + region, _ = region.split_horizontal(-scrollbar_size_horizontal) + return region + + def _arrange_scrollbars(self, region: Region) -> Iterable[tuple[Widget, Region]]: + """Arrange the 'chrome' widgets (typically scrollbars) for a layout element. + + Args: + region: The containing region. + + Returns: + Tuples of scrollbar Widget and region. + """ + show_vertical_scrollbar, show_horizontal_scrollbar = self.scrollbars_enabled + + scrollbar_size_horizontal = self.scrollbar_size_horizontal + scrollbar_size_vertical = self.scrollbar_size_vertical + + show_vertical_scrollbar = bool( + show_vertical_scrollbar and scrollbar_size_vertical + ) + show_horizontal_scrollbar = bool( + show_horizontal_scrollbar and scrollbar_size_horizontal + ) + + if show_horizontal_scrollbar and show_vertical_scrollbar: + ( + window_region, + vertical_scrollbar_region, + horizontal_scrollbar_region, + scrollbar_corner_gap, + ) = region.split( + region.width - scrollbar_size_vertical, + region.height - scrollbar_size_horizontal, + ) + if scrollbar_corner_gap: + yield self.scrollbar_corner, scrollbar_corner_gap + if vertical_scrollbar_region: + scrollbar = self.vertical_scrollbar + scrollbar.window_virtual_size = self.virtual_size.height + scrollbar.window_size = window_region.height + yield scrollbar, vertical_scrollbar_region + if horizontal_scrollbar_region: + scrollbar = self.horizontal_scrollbar + scrollbar.window_virtual_size = self.virtual_size.width + scrollbar.window_size = window_region.width + yield scrollbar, horizontal_scrollbar_region + + elif show_vertical_scrollbar: + window_region, scrollbar_region = region.split_vertical( + region.width - scrollbar_size_vertical + ) + if scrollbar_region: + scrollbar = self.vertical_scrollbar + scrollbar.window_virtual_size = self.virtual_size.height + scrollbar.window_size = window_region.height + yield scrollbar, scrollbar_region + elif show_horizontal_scrollbar: + window_region, scrollbar_region = region.split_horizontal( + region.height - scrollbar_size_horizontal + ) + if scrollbar_region: + scrollbar = self.horizontal_scrollbar + scrollbar.window_virtual_size = self.virtual_size.width + scrollbar.window_size = window_region.width + yield scrollbar, scrollbar_region + + def get_pseudo_class_state(self) -> PseudoClasses: + """Get an object describing whether each pseudo class is present on this object or not. + + Returns: + A PseudoClasses object describing the pseudo classes that are present. + """ + node: MessagePump | None = self + disabled = False + while isinstance(node, Widget): + if node.disabled: + disabled = True + break + node = node._parent + + pseudo_classes = PseudoClasses( + enabled=not disabled, + hover=self.mouse_hover, + focus=self.has_focus, + ) + return pseudo_classes + + @property + def _pseudo_classes_cache_key(self) -> tuple[int, ...]: + """A cache key that changes when the pseudo-classes change.""" + return ( + self.mouse_hover, + self.has_focus, + self.is_disabled, + ) + + def _get_justify_method(self) -> JustifyMethod | None: + """Get the justify method that may be passed to a Rich renderable.""" + text_justify: JustifyMethod | None = None + + if self.styles.has_rule("text_align"): + text_align: JustifyMethod = cast(JustifyMethod, self.styles.text_align) + text_justify = _JUSTIFY_MAP.get(text_align, text_align) + return text_justify + + def post_render( + self, renderable: RenderableType, base_style: Style + ) -> ConsoleRenderable: + """Applies style attributes to the default renderable. + + This method is called by Textual itself. + It is unlikely you will need to call or implement this method. + + Returns: + A new renderable. + """ + + text_justify = self._get_justify_method() + + if isinstance(renderable, str): + renderable = Text.from_markup(renderable, justify=text_justify) + + if ( + isinstance(renderable, Text) + and text_justify is not None + and renderable.justify != text_justify + ): + renderable = renderable.copy() + renderable.justify = text_justify + + renderable = _Styled( + cast(ConsoleRenderable, renderable), + base_style, + self.link_style if self.auto_links else None, + ) + + return renderable + + def watch_has_focus(self, _has_focus: bool) -> None: + """Update from CSS if has focus state changes.""" + self.update_node_styles() + + def watch_disabled(self, disabled: bool) -> None: + """Update the styles of the widget and its children when disabled is toggled.""" + from memray._vendor.textual.app import ScreenStackError + + if disabled and self.mouse_hover and self.app.mouse_over is not None: + # Ensure widget gets a Leave if it is disabled while hovered + self._message_queue.put_nowait(events.Leave(self.app.mouse_over)) + try: + screen = self.screen + if ( + disabled + and screen.focused is not None + and self in screen.focused.ancestors_with_self + ): + screen.focused.blur() + except (ScreenStackError, NoActiveAppError, NoScreen): + pass + + self.update_node_styles() + + def _size_updated( + self, size: Size, virtual_size: Size, container_size: Size, layout: bool = True + ) -> bool: + """Called when the widget's size is updated. + + Args: + size: Screen size. + virtual_size: Virtual (scrollable) size. + container_size: Container size (size of parent). + layout: Perform layout if required. + + Returns: + True if a resize event should be sent, otherwise False. + """ + + self._layout_cache.clear() + if ( + self._size != size + or self.virtual_size != virtual_size + or self._container_size != container_size + ): + if self._size != size: + self._set_dirty() + self._size = size + if layout: + self.virtual_size = virtual_size + else: + self.set_reactive(Widget.virtual_size, virtual_size) + self._container_size = container_size + if self.is_scrollable: + self._scroll_update(virtual_size) + return True + else: + return False + + def _scroll_update(self, virtual_size: Size) -> None: + """Update scrollbars visibility and dimensions. + + Args: + virtual_size: Virtual size. + """ + self._refresh_scrollbars() + width, height = self.container_size + + if self.show_vertical_scrollbar and self.styles.scrollbar_size_vertical: + self.vertical_scrollbar.window_virtual_size = virtual_size.height + self.vertical_scrollbar.window_size = ( + height - self.scrollbar_size_horizontal + ) + self.vertical_scrollbar.refresh() + if self.show_horizontal_scrollbar and self.styles.scrollbar_size_horizontal: + self.horizontal_scrollbar.window_virtual_size = virtual_size.width + self.horizontal_scrollbar.window_size = width - self.scrollbar_size_vertical + self.horizontal_scrollbar.refresh() + + self.scroll_x = self.validate_scroll_x(self.scroll_x) + self.scroll_y = self.validate_scroll_y(self.scroll_y) + + @property + def visual_style(self) -> VisualStyle: + """The widget's current style.""" + if ( + self._visual_style is None + or self._visual_style_cache_key != self.styles._cache_key + ): + self._visual_style_cache_key = self.styles._cache_key + background = Color(0, 0, 0, 0) + color = Color(255, 255, 255, 0) + + style = Style() + opacity = 1.0 + + for node in reversed(self.ancestors_with_self): + styles = node.styles + has_rule = styles.has_rule + opacity *= styles.opacity + if has_rule("background"): + text_background = background + styles.background.tint( + styles.background_tint + ) + background += ( + styles.background.tint(styles.background_tint) + ).multiply_alpha(opacity) + else: + text_background = background + if has_rule("color"): + color = styles.color + style += styles.text_style + if has_rule("auto_color") and styles.auto_color: + color = text_background.get_contrast_text(color.a) + + self._visual_style = VisualStyle( + background, + color, + bold=style.bold, + dim=style.dim, + italic=style.italic, + reverse=style.reverse, + underline=style.underline, + strike=style.strike, + ) + return self._visual_style + + def get_selection(self, selection: Selection) -> tuple[str, str] | None: + """Get the text under the selection. + + !!! note + Implement this method if are building custom widget. If you just want to get the currently + selected text, then see [`Screen.get_selected_text`](textual.screen.Screen.get_selected_text) + + + Args: + selection: Selection information. + + Returns: + Tuple of extracted text and ending (typically "\n" or " "), or `None` if no text could be extracted. + """ + visual = self._render() + if isinstance(visual, (Text, Content)): + text = str(visual) + else: + return None + return selection.extract(text), "\n" + + def selection_updated(self, selection: Selection | None) -> None: + """Called when the selection is updated. + + Args: + selection: Selection information or `None` if no selection. + """ + self.refresh() + + def _render_content(self) -> None: + """Render all lines.""" + width, height = self.size + visual = self._render() + strips = Visual.to_strips(self, visual, width, height, self.visual_style) + self._render_cache = _RenderCache(self.size, strips) + self._dirty_regions.clear() + + def render_line(self, y: int) -> Strip: + """Render a line of content. + + Args: + y: Y Coordinate of line. + + Returns: + A rendered line. + """ + if self.BLANK: + return Strip.blank(self.size.width, self.visual_style.rich_style) + + if self._dirty_regions: + self._render_content() + try: + line = self._render_cache.lines[y] + except IndexError: + line = Strip.blank(self.size.width, self.visual_style.rich_style) + + return line + + def render_lines(self, crop: Region) -> list[Strip]: + """Render the widget into lines. + + Args: + crop: Region within visible area to render. + + Returns: + A list of list of segments. + """ + if self.BLANK: + strips = [ + Strip.blank(crop.width, self.visual_style.rich_style) + ] * crop.height + else: + strips = self._styles_cache.render_widget(self, crop) + return strips + + def get_style_at(self, x: int, y: int) -> Style: + """Get the Rich style in a widget at a given relative offset. + + Args: + x: X coordinate relative to the widget. + y: Y coordinate relative to the widget. + + Returns: + A rich Style object. + """ + offset = Offset(x, y) + screen_offset = offset + self.region.offset + + widget, _ = self.screen.get_widget_at(*screen_offset) + if widget is not self: + return Style() + return self.screen.get_style_at(*screen_offset) + + def suppress_click(self) -> None: + """Suppress a click event. + + This will prevent a [Click][textual.events.Click] event being sent, + if called after a mouse down event and before the click itself. + + """ + self.app._mouse_down_widget = None + + def _forward_event(self, event: events.Event) -> None: + event._set_forwarded() + self.post_message(event) + + def _refresh_scroll(self) -> None: + """Refreshes the scroll position.""" + self._scroll_required = True + self.check_idle() + + def refresh( + self, + *regions: Region, + repaint: bool = True, + layout: bool = False, + recompose: bool = False, + ) -> Self: + """Initiate a refresh of the widget. + + This method sets an internal flag to perform a refresh, which will be done on the + next idle event. Only one refresh will be done even if this method is called multiple times. + + By default this method will cause the content of the widget to refresh, but not change its size. You can also + set `layout=True` to perform a layout. + + !!! warning + + It is rarely necessary to call this method explicitly. Updating styles or reactive attributes will + do this automatically. + + Args: + *regions: Additional screen regions to mark as dirty. + repaint: Repaint the widget (will call render() again). + layout: Also layout widgets in the view. + recompose: Re-compose the widget (will remove and re-mount children). + + Returns: + The `Widget` instance. + """ + + if layout and not self._layout_required: + self._layout_required = True + self._layout_updates += 1 + + if recompose: + self._recompose_required = True + self.call_next(self._check_recompose) + return self + + if not self._is_mounted: + self._repaint_required = True + self.check_idle() + return self + + self._layout_cache.clear() + if repaint: + self._set_dirty(*regions) + self.clear_cached_dimensions() + self._rich_style_cache.clear() + self._repaint_required = True + + self.check_idle() + return self + + def remove(self) -> AwaitRemove: + """Remove the Widget from the DOM (effectively deleting it). + + Returns: + An awaitable object that waits for the widget to be removed. + """ + await_remove = self.app._prune(self, parent=self._parent) + return await_remove + + def remove_children( + self, selector: str | type[QueryType] | Iterable[Widget] = "*" + ) -> AwaitRemove: + """Remove the immediate children of this Widget from the DOM. + + Args: + selector: A CSS selector or iterable of widgets to remove. + + Returns: + An awaitable object that waits for the direct children to be removed. + """ + + if callable(selector) and issubclass(selector, Widget): + selector = selector.__name__ + + children_to_remove: Iterable[Widget] + + if isinstance(selector, str): + parsed_selectors = parse_selectors(selector) + children_to_remove = [ + child for child in self.children if match(parsed_selectors, child) + ] + else: + children_to_remove = selector + await_remove = self.app._prune(*children_to_remove, parent=self) + return await_remove + + @asynccontextmanager + async def batch(self) -> AsyncGenerator[None, None]: + """Async context manager that combines widget locking and update batching. + + Use this async context manager whenever you want to acquire the widget lock and + batch app updates at the same time. + + Example: + ```py + async with container.batch(): + await container.remove_children(Button) + await container.mount(Label("All buttons are gone.")) + ``` + """ + async with self.lock: + with self.app.batch_update(): + yield + + def render(self) -> RenderResult: + """Get [content](/guide/content) for the widget. + + Implement this method in a subclass for custom widgets. + + This method should return [markup](/guide/content#markup), a [Content][textual.content.Content] object, or a [Rich](https://github.com/Textualize/rich) renderable. + + Example: + ```python + from memray._vendor.textual.app import RenderResult + from memray._vendor.textual.widget import Widget + + class CustomWidget(Widget): + def render(self) -> RenderResult: + return "Welcome to [bold red]Textual[/]!" + ``` + + Returns: + A string or object to render as the widget's content. + """ + + if self.is_container: + if self.styles.layout and self.styles.keyline[0] != "none": + return self.layout.render_keyline(self) + else: + return Blank(self.background_colors[1]) + return self.css_identifier_styled + + def _render(self) -> Visual: + """Get renderable, promoting str to text as required. + + Returns: + A Visual. + """ + cache_key = "_render.visual" + cached_visual = self._layout_cache.get(cache_key, None) + if cached_visual is not None: + assert isinstance(cached_visual, Visual) + return cached_visual + visual = visualize(self, self.render(), markup=self._render_markup) + self._layout_cache[cache_key] = visual + return visual + + async def run_action( + self, action: str, namespaces: Mapping[str, DOMNode] | None = None + ) -> None: + """Perform a given action, with this widget as the default namespace. + + Args: + action: Action encoded as a string. + namespaces: Mapping of namespaces. + """ + await self.app.run_action(action, self, namespaces) + + def post_message(self, message: Message) -> bool: + """Post a message to this widget. + + Args: + message: Message to post. + + Returns: + True if the message was posted, False if this widget was closed / closing. + """ + _rich_traceback_omit = True + # Catch a common error. + # This will error anyway, but at least we can offer a helpful message here. + if not hasattr(message, "_prevent"): + raise RuntimeError( + f"{type(message)!r} is missing expected attributes; did you forget to call super().__init__() in the constructor?" + ) + + if constants.DEBUG and not self.is_running and not message.no_dispatch: + try: + self.log.warning(self, f"IS NOT RUNNING, {message!r} not sent") + except NoActiveAppError: + pass + return super().post_message(message) + + async def on_prune(self, event: messages.Prune) -> None: + """Close message loop when asked to prune.""" + await self._close_messages(wait=False) + + async def _message_loop_exit(self) -> None: + """Clean up DOM tree.""" + parent = self._parent + # Post messages to children, asking them to prune + children = [*self.children, *self._get_virtual_dom()] + for node in children: + node.post_message(Prune()) + + # Wait for child nodes to exit + await gather(*[node._task for node in children if node._task is not None]) + # Send unmount event + await self._dispatch_message(events.Unmount()) + assert isinstance(parent, DOMNode) + # Finalize removal from DOM + parent._nodes._remove(self) + self.app._registry.discard(self) + self._detach() + self._arrangement_cache.clear() + self._nodes._clear() + self._render_cache = _RenderCache(NULL_SIZE, []) + self._component_styles.clear() + self._query_one_cache.clear() + + async def _on_idle(self, event: events.Idle) -> None: + """Called when there are no more events on the queue. + + Args: + event: Idle event. + """ + self._check_refresh() + + def _check_refresh(self) -> None: + """Check if a refresh was requested.""" + if self._parent is not None and not self._closing: + try: + screen = self.screen + except NoScreen: + pass + else: + if self._refresh_styles_required: + self._refresh_styles_required = False + self.call_later(self.update_node_styles) + if self._scroll_required: + self._scroll_required = False + if not self._layout_required: + if self.styles.keyline[0] != "none": + # TODO: Feels like a hack + # Perhaps there should be an explicit mechanism for backgrounds to refresh when scrolled? + self._set_dirty() + screen.post_message(messages.UpdateScroll()) + if self._repaint_required: + self._repaint_required = False + if self.display: + screen.post_message(messages.Update(self)) + if self._layout_required: + self._layout_required = False + for ancestor in self.ancestors: + if not isinstance(ancestor, Widget): + break + ancestor._clear_arrangement_cache() + ancestor._layout_updates += 1 + if not ancestor.styles.auto_dimensions: + break + screen.post_message(messages.Layout(self)) + + def focus(self, scroll_visible: bool = True) -> Self: + """Give focus to this widget. + + Args: + scroll_visible: Scroll parent to make this widget visible. + + Returns: + The `Widget` instance. + """ + + def set_focus(widget: Widget) -> None: + """Callback to set the focus.""" + try: + widget.screen.set_focus(self, scroll_visible=scroll_visible) + except NoScreen: + pass + + self.refresh() + self.app.call_later(set_focus, self) + return self + + def blur(self) -> Self: + """Blur (un-focus) the widget. + + Focus will be moved to the next available widget in the focus chain. + + Returns: + The `Widget` instance. + """ + try: + self.screen._reset_focus(self) + except NoScreen: + pass + return self + + def capture_mouse(self, capture: bool = True) -> None: + """Capture (or release) the mouse. + + When captured, mouse events will go to this widget even when the pointer is not directly over the widget. + + Args: + capture: True to capture or False to release. + """ + self.app.capture_mouse(self if capture else None) + + def release_mouse(self) -> None: + """Release the mouse. + + Mouse events will only be sent when the mouse is over the widget. + """ + if self.app.mouse_captured is self: + self.app.capture_mouse(None) + + def text_select_all(self) -> None: + """Select the entire widget.""" + self.screen._select_all_in_widget(self) + + def begin_capture_print(self, stdout: bool = True, stderr: bool = True) -> None: + """Capture text from print statements (or writes to stdout / stderr). + + If printing is captured, the widget will be sent an [`events.Print`][textual.events.Print] message. + + Call [`end_capture_print`][textual.widget.Widget.end_capture_print] to disable print capture. + + Args: + stdout: Whether to capture stdout. + stderr: Whether to capture stderr. + """ + self.app.begin_capture_print(self, stdout=stdout, stderr=stderr) + + def end_capture_print(self) -> None: + """End print capture (set with [`begin_capture_print`][textual.widget.Widget.begin_capture_print]).""" + self.app.end_capture_print(self) + + def check_message_enabled(self, message: Message) -> bool: + """Check if a given message is enabled (allowed to be sent). + + Args: + message: A message object + + Returns: + `True` if the message will be sent, or `False` if it is disabled. + """ + # Do the normal checking and get out if that fails. + if not super().check_message_enabled(message) or self._is_prevented( + type(message) + ): + return False + + # Mouse scroll events should always go through, this allows mouse + # wheel scrolling to pass through disabled widgets. + if isinstance(message, _MOUSE_EVENTS_ALLOW_IF_DISABLED): + return True + # Otherwise, if this is any other mouse event, the widget receiving + # the event must not be disabled at this moment. + return ( + not self._self_or_ancestors_disabled + if isinstance(message, _MOUSE_EVENTS_DISALLOW_IF_DISABLED) + else True + ) + + async def broker_event(self, event_name: str, event: events.Event) -> bool: + return await self.app._broker_event(event_name, event, default_namespace=self) + + def notify_style_update(self) -> None: + self._rich_style_cache.clear() + self._visual_style_cache.clear() + self._visual_style = None + super().notify_style_update() + + async def _on_mouse_down(self, event: events.MouseDown) -> None: + await self.broker_event("mouse.down", event) + + async def _on_mouse_up(self, event: events.MouseUp) -> None: + await self.broker_event("mouse.up", event) + + async def _on_click(self, event: events.Click) -> None: + if event.widget is self: + if self.allow_select and self.screen.allow_select and self.app.ALLOW_SELECT: + if event.chain == 2: + self.text_select_all() + elif event.chain == 3 and self.parent is not None: + self.select_container.text_select_all() + + await self.broker_event("click", event) + + async def _on_key(self, event: events.Key) -> None: + await self.handle_key(event) + + async def handle_key(self, event: events.Key) -> bool: + return await dispatch_key(self, event) + + async def _on_compose(self, event: events.Compose) -> None: + _rich_traceback_omit = True + event.prevent_default() + await self._compose() + + async def _compose(self) -> None: + try: + widgets = [*self._pending_children, *compose(self)] + self._pending_children.clear() + except TypeError as error: + raise TypeError( + f"{self!r} compose() method returned an invalid result; {error}" + ) from error + except Exception as error: + self.app._handle_exception(error) + else: + self._extend_compose(widgets) + await self.mount_composed_widgets(widgets) + + async def mount_composed_widgets(self, widgets: list[Widget]) -> None: + """Called by Textual to mount widgets after compose. + + There is generally no need to implement this method in your application. + See [Lazy][textual.lazy.Lazy] for a class which uses this method to implement + *lazy* mounting. + + Args: + widgets: A list of child widgets. + """ + if widgets: + await self.mount_all(widgets) + + def _extend_compose(self, widgets: list[Widget]) -> None: + """Hook to extend composed widgets. + + Args: + widgets: Widgets to be mounted. + """ + + def _on_mount(self, event: events.Mount) -> None: + if self.styles.overflow_y == "scroll": + self.show_vertical_scrollbar = True + if self.styles.overflow_x == "scroll": + self.show_horizontal_scrollbar = True + + def _on_leave(self, event: events.Leave) -> None: + if event.node is self: + self.mouse_hover = False + self.hover_style = Style() + + def _on_enter(self, event: events.Enter) -> None: + if event.node is self: + self.mouse_hover = True + + def _on_focus(self, event: events.Focus) -> None: + self.has_focus = True + self.refresh() + if self.parent is not None: + self.parent.post_message(events.DescendantFocus(self)) + + def _on_blur(self, event: events.Blur) -> None: + self.has_focus = False + self.refresh() + if self.parent is not None: + self.parent.post_message(events.DescendantBlur(self)) + + def _on_mouse_scroll_down(self, event: events.MouseScrollDown) -> None: + if event.ctrl or event.shift: + if self.allow_horizontal_scroll: + if self._scroll_right_for_pointer(animate=False): + event.stop() + else: + if self.allow_vertical_scroll: + if self._scroll_down_for_pointer(animate=False): + event.stop() + + def _on_mouse_scroll_up(self, event: events.MouseScrollUp) -> None: + if event.ctrl or event.shift: + if self.allow_horizontal_scroll: + if self._scroll_left_for_pointer(animate=False): + event.stop() + else: + if self.allow_vertical_scroll: + if self._scroll_up_for_pointer(animate=False): + event.stop() + + def _on_mouse_scroll_right(self, event: events.MouseScrollRight) -> None: + if self.allow_horizontal_scroll: + if self._scroll_right_for_pointer(): + event.stop() + + def _on_mouse_scroll_left(self, event: events.MouseScrollLeft) -> None: + if self.allow_horizontal_scroll: + if self._scroll_left_for_pointer(): + event.stop() + + def _on_scroll_to(self, message: ScrollTo) -> None: + if self._allow_scroll: + self.scroll_to(message.x, message.y, animate=message.animate, duration=0.1) + message.stop() + + def _on_scroll_up(self, event: ScrollUp) -> None: + if self.allow_vertical_scroll: + self.scroll_page_up() + event.stop() + + def _on_scroll_down(self, event: ScrollDown) -> None: + if self.allow_vertical_scroll: + self.scroll_page_down() + event.stop() + + def _on_scroll_left(self, event: ScrollLeft) -> None: + if self.allow_horizontal_scroll: + self.scroll_page_left() + event.stop() + + def _on_scroll_right(self, event: ScrollRight) -> None: + if self.allow_horizontal_scroll: + self.scroll_page_right() + event.stop() + + def _on_show(self, event: events.Show) -> None: + if self.show_horizontal_scrollbar: + self.horizontal_scrollbar.post_message(event) + if self.show_vertical_scrollbar: + self.vertical_scrollbar.post_message(event) + + def _on_hide(self, event: events.Hide) -> None: + if self.show_horizontal_scrollbar: + self.horizontal_scrollbar.post_message(event) + if self.show_vertical_scrollbar: + self.vertical_scrollbar.post_message(event) + if self.has_focus: + self.blur() + + def _on_scroll_to_region(self, message: messages.ScrollToRegion) -> None: + self.scroll_to_region(message.region, animate=True) + + def _on_unmount(self) -> None: + self._uncover() + self.workers.cancel_node(self) + + def action_scroll_home(self) -> None: + if not self._allow_scroll: + raise SkipAction() + self.scroll_home(x_axis=self.scroll_y == 0) + + def action_scroll_end(self) -> None: + if not self._allow_scroll: + raise SkipAction() + self.scroll_end(x_axis=self.scroll_y == self.is_vertical_scroll_end) + + def action_scroll_left(self) -> None: + if not self.allow_horizontal_scroll: + raise SkipAction() + self.scroll_left() + + def action_scroll_right(self) -> None: + if not self.allow_horizontal_scroll: + raise SkipAction() + self.scroll_right() + + def action_scroll_up(self) -> None: + if not self.allow_vertical_scroll: + raise SkipAction() + self.scroll_up() + + def action_scroll_down(self) -> None: + if not self.allow_vertical_scroll: + raise SkipAction() + self.scroll_down() + + def action_page_down(self) -> None: + if not self.allow_vertical_scroll: + raise SkipAction() + self.scroll_page_down() + + def action_page_up(self) -> None: + if not self.allow_vertical_scroll: + raise SkipAction() + self.scroll_page_up() + + def action_page_left(self) -> None: + if not self.allow_horizontal_scroll: + raise SkipAction() + self.scroll_page_left() + + def action_page_right(self) -> None: + if not self.allow_horizontal_scroll: + raise SkipAction() + self.scroll_page_right() + + def notify( + self, + message: str, + *, + title: str = "", + severity: SeverityLevel = "information", + timeout: float | None = None, + markup: bool = True, + ) -> None: + """Create a notification. + + !!! tip + + This method is thread-safe. + + Args: + message: The message for the notification. + title: The title for the notification. + severity: The severity of the notification. + timeout: The timeout (in seconds) for the notification, or `None` for default. + markup: Render the message as content markup? + + See [`App.notify`][textual.app.App.notify] for the full + documentation for this method. + """ + if timeout is None: + return self.app.notify( + message, + title=title, + severity=severity, + markup=markup, + ) + else: + return self.app.notify( + message, + title=title, + severity=severity, + timeout=timeout, + markup=markup, + ) + + def action_notify( + self, + message: str, + title: str = "", + severity: str = "information", + markup: bool = True, + ) -> None: + self.notify( + message, + title=title, + severity=severity, + markup=markup, + ) diff --git a/src/memray/_vendor/textual/widgets/__init__.py b/src/memray/_vendor/textual/widgets/__init__.py new file mode 100644 index 0000000000..5030a70077 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/__init__.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +import typing +from importlib import import_module + +from memray._vendor.textual.case import camel_to_snake + +# For any new built-in Widget we create, not only do we have to import them here and add them to `__all__`, +# but also to the `__init__.pyi` file in this same folder - otherwise text editors and type checkers won't +# be able to "see" them. +if typing.TYPE_CHECKING: + from memray._vendor.textual.widget import Widget + from memray._vendor.textual.widgets._button import Button + from memray._vendor.textual.widgets._checkbox import Checkbox + from memray._vendor.textual.widgets._collapsible import Collapsible + from memray._vendor.textual.widgets._content_switcher import ContentSwitcher + from memray._vendor.textual.widgets._data_table import DataTable + from memray._vendor.textual.widgets._digits import Digits + from memray._vendor.textual.widgets._directory_tree import DirectoryTree + from memray._vendor.textual.widgets._footer import Footer + from memray._vendor.textual.widgets._header import Header + from memray._vendor.textual.widgets._help_panel import HelpPanel + from memray._vendor.textual.widgets._input import Input + from memray._vendor.textual.widgets._key_panel import KeyPanel + from memray._vendor.textual.widgets._label import Label + from memray._vendor.textual.widgets._link import Link + from memray._vendor.textual.widgets._list_item import ListItem + from memray._vendor.textual.widgets._list_view import ListView + from memray._vendor.textual.widgets._loading_indicator import LoadingIndicator + from memray._vendor.textual.widgets._log import Log + from memray._vendor.textual.widgets._markdown import Markdown, MarkdownViewer + from memray._vendor.textual.widgets._masked_input import MaskedInput + from memray._vendor.textual.widgets._option_list import OptionList + from memray._vendor.textual.widgets._placeholder import Placeholder + from memray._vendor.textual.widgets._pretty import Pretty + from memray._vendor.textual.widgets._progress_bar import ProgressBar + from memray._vendor.textual.widgets._radio_button import RadioButton + from memray._vendor.textual.widgets._radio_set import RadioSet + from memray._vendor.textual.widgets._rich_log import RichLog + from memray._vendor.textual.widgets._rule import Rule + from memray._vendor.textual.widgets._select import Select + from memray._vendor.textual.widgets._selection_list import SelectionList + from memray._vendor.textual.widgets._sparkline import Sparkline + from memray._vendor.textual.widgets._static import Static + from memray._vendor.textual.widgets._switch import Switch + from memray._vendor.textual.widgets._tabbed_content import TabbedContent, TabPane + from memray._vendor.textual.widgets._tabs import Tab, Tabs + from memray._vendor.textual.widgets._text_area import TextArea + from memray._vendor.textual.widgets._tooltip import Tooltip + from memray._vendor.textual.widgets._tree import Tree + from memray._vendor.textual.widgets._welcome import Welcome + +__all__ = [ + "Button", + "Checkbox", + "Collapsible", + "ContentSwitcher", + "DataTable", + "Digits", + "DirectoryTree", + "Footer", + "Header", + "HelpPanel", + "Input", + "KeyPanel", + "Label", + "Link", + "ListItem", + "ListView", + "LoadingIndicator", + "Log", + "Markdown", + "MarkdownViewer", + "MaskedInput", + "OptionList", + "Placeholder", + "Pretty", + "ProgressBar", + "RadioButton", + "RadioSet", + "RichLog", + "Rule", + "Select", + "SelectionList", + "Sparkline", + "Static", + "Switch", + "Tab", + "TabbedContent", + "TabPane", + "Tabs", + "TextArea", + "Tooltip", + "Tree", + "Welcome", +] + +_WIDGETS_LAZY_LOADING_CACHE: dict[str, type[Widget]] = {} + + +# Let's decrease startup time by lazy loading our Widgets: +def __getattr__(widget_class: str) -> type[Widget]: + try: + return _WIDGETS_LAZY_LOADING_CACHE[widget_class] + except KeyError: + pass + + if widget_class not in __all__: + raise AttributeError( + f"Package 'memray._vendor.textual.widgets' has no class '{widget_class}'" + ) + + widget_module_path = f"._{camel_to_snake(widget_class)}" + module = import_module( + widget_module_path, package="memray._vendor.textual.widgets" + ) + class_ = getattr(module, widget_class) + + _WIDGETS_LAZY_LOADING_CACHE[widget_class] = class_ + return class_ diff --git a/src/memray/_vendor/textual/widgets/__init__.pyi b/src/memray/_vendor/textual/widgets/__init__.pyi new file mode 100644 index 0000000000..907ae843b8 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/__init__.pyi @@ -0,0 +1,43 @@ +# This stub file must re-export every classes exposed in the __init__.py's `__all__` list: +from ._button import Button as Button +from ._checkbox import Checkbox as Checkbox +from ._collapsible import Collapsible as Collapsible +from ._content_switcher import ContentSwitcher as ContentSwitcher +from ._data_table import DataTable as DataTable +from ._digits import Digits as Digits +from ._directory_tree import DirectoryTree as DirectoryTree +from ._footer import Footer as Footer +from ._header import Header as Header +from ._help_panel import HelpPanel as HelpPanel +from ._input import Input as Input +from ._key_panel import KeyPanel as KeyPanel +from ._label import Label as Label +from ._link import Link as Link +from ._list_item import ListItem as ListItem +from ._list_view import ListView as ListView +from ._loading_indicator import LoadingIndicator as LoadingIndicator +from ._log import Log as Log +from ._markdown import Markdown as Markdown +from ._markdown import MarkdownViewer as MarkdownViewer +from ._masked_input import MaskedInput as MaskedInput +from ._option_list import OptionList as OptionList +from ._placeholder import Placeholder as Placeholder +from ._pretty import Pretty as Pretty +from ._progress_bar import ProgressBar as ProgressBar +from ._radio_button import RadioButton as RadioButton +from ._radio_set import RadioSet as RadioSet +from ._rich_log import RichLog as RichLog +from ._rule import Rule as Rule +from ._select import Select as Select +from ._selection_list import SelectionList as SelectionList +from ._sparkline import Sparkline as Sparkline +from ._static import Static as Static +from ._switch import Switch as Switch +from ._tabbed_content import TabbedContent as TabbedContent +from ._tabbed_content import TabPane as TabPane +from ._tabs import Tab as Tab +from ._tabs import Tabs as Tabs +from ._text_area import TextArea as TextArea +from ._tooltip import Tooltip as Tooltip +from ._tree import Tree as Tree +from ._welcome import Welcome as Welcome diff --git a/src/memray/_vendor/textual/widgets/_button.py b/src/memray/_vendor/textual/widgets/_button.py new file mode 100644 index 0000000000..9923bc0e1a --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_button.py @@ -0,0 +1,492 @@ +from __future__ import annotations + +from functools import partial +from typing import TYPE_CHECKING, cast + +import rich.repr +from rich.cells import cell_len +from rich.console import ConsoleRenderable, RenderableType +from typing_extensions import Literal, Self + +from memray._vendor.textual import events + +if TYPE_CHECKING: + from memray._vendor.textual.app import RenderResult + +from rich.style import Style + +from memray._vendor.textual.binding import Binding +from memray._vendor.textual.content import Content, ContentText +from memray._vendor.textual.css._error_tools import friendly_list +from memray._vendor.textual.geometry import Size +from memray._vendor.textual.message import Message +from memray._vendor.textual.reactive import reactive +from memray._vendor.textual.widget import Widget + +ButtonVariant = Literal["default", "primary", "success", "warning", "error"] +"""The names of the valid button variants. + +These are the variants that can be used with a [`Button`][textual.widgets.Button]. +""" + +_VALID_BUTTON_VARIANTS = {"default", "primary", "success", "warning", "error"} + + +class InvalidButtonVariant(Exception): + """Exception raised if an invalid button variant is used.""" + + +class Button(Widget, can_focus=True): + """A simple clickable button. + + Clicking the button will send a [Button.Pressed][textual.widgets.Button.Pressed] message, + unless the `action` parameter is provided. + + """ + + ALLOW_SELECT = False + + DEFAULT_CSS = """ + Button { + width: auto; + min-width: 16; + height:auto; + line-pad: 1; + text-align: center; + content-align: center middle; + pointer: pointer; + + &.-style-flat { + text-style: bold; + color: auto 90%; + background: $surface; + border: block $surface; + &:hover { + background: $primary; + border: block $primary; + } + &:focus { + text-style: $button-focus-text-style; + } + &.-active { + background: $surface; + border: block $surface; + tint: $background 30%; + } + &:disabled { + color: auto 50%; + pointer: not-allowed; + } + + &.-primary { + background: $primary-muted; + border: block $primary-muted; + color: $text-primary; + &:hover { + color: $text; + background: $primary; + border: block $primary; + } + } + &.-success { + background: $success-muted; + border: block $success-muted; + color: $text-success; + &:hover { + color: $text; + background: $success; + border: block $success; + } + } + &.-warning { + background: $warning-muted; + border: block $warning-muted; + color: $text-warning; + &:hover { + color: $text; + background: $warning; + border: block $warning; + } + } + &.-error { + background: $error-muted; + border: block $error-muted; + color: $text-error; + &:hover { + color: $text; + background: $error; + border: block $error; + } + } + } + &.-style-default { + text-style: bold; + color: $button-foreground; + background: $surface; + border: none; + border-top: tall $surface-lighten-1; + border-bottom: tall $surface-darken-1; + + + &.-textual-compact { + border: none !important; + } + + &:disabled { + text-opacity: 0.6; + pointer: not-allowed; + } + + &:focus { + text-style: $button-focus-text-style; + background-tint: $foreground 5%; + } + &:hover { + border-top: tall $surface; + background: $surface-darken-1; + } + + &.-active { + background: $surface; + border-bottom: tall $surface-lighten-1; + border-top: tall $surface-darken-1; + tint: $background 30%; + } + + &.-primary { + color: $button-color-foreground; + background: $primary; + border-top: tall $primary-lighten-3; + border-bottom: tall $primary-darken-3; + + &:hover { + background: $primary-darken-2; + border-top: tall $primary; + } + + &.-active { + background: $primary; + border-bottom: tall $primary-lighten-3; + border-top: tall $primary-darken-3; + } + } + + &.-success { + color: $button-color-foreground; + background: $success; + border-top: tall $success-lighten-2; + border-bottom: tall $success-darken-3; + + &:hover { + background: $success-darken-2; + border-top: tall $success; + } + + &.-active { + background: $success; + border-bottom: tall $success-lighten-2; + border-top: tall $success-darken-2; + } + } + + &.-warning{ + color: $button-color-foreground; + background: $warning; + border-top: tall $warning-lighten-2; + border-bottom: tall $warning-darken-3; + + &:hover { + background: $warning-darken-2; + border-top: tall $warning; + } + + &.-active { + background: $warning; + border-bottom: tall $warning-lighten-2; + border-top: tall $warning-darken-2; + } + } + + &.-error { + color: $button-color-foreground; + background: $error; + border-top: tall $error-lighten-2; + border-bottom: tall $error-darken-3; + + &:hover { + background: $error-darken-1; + border-top: tall $error; + } + + &.-active { + background: $error; + border-bottom: tall $error-lighten-2; + border-top: tall $error-darken-2; + } + } + } + } + """ + + BINDINGS = [Binding("enter", "press", "Press button", show=False)] + + label: reactive[ContentText] = reactive[ContentText](Content.empty()) + """The text label that appears within the button.""" + + variant = reactive("default", init=False) + """The variant name for the button.""" + + compact = reactive(False, toggle_class="-textual-compact") + """Make the button compact (without borders).""" + + flat = reactive(False) + """Enable alternative flat button style.""" + + class Pressed(Message): + """Event sent when a `Button` is pressed and there is no Button action. + + Can be handled using `on_button_pressed` in a subclass of + [`Button`][textual.widgets.Button] or in a parent widget in the DOM. + """ + + def __init__(self, button: Button) -> None: + self.button: Button = button + """The button that was pressed.""" + super().__init__() + + @property + def control(self) -> Button: + """An alias for [Pressed.button][textual.widgets.Button.Pressed.button]. + + This will be the same value as [Pressed.button][textual.widgets.Button.Pressed.button]. + """ + return self.button + + def __init__( + self, + label: ContentText | None = None, + variant: ButtonVariant = "default", + *, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + tooltip: RenderableType | None = None, + action: str | None = None, + compact: bool = False, + flat: bool = False, + ): + """Create a Button widget. + + Args: + label: The text that appears within the button. + variant: The variant of the button. + name: The name of the button. + id: The ID of the button in the DOM. + classes: The CSS classes of the button. + disabled: Whether the button is disabled or not. + tooltip: Optional tooltip. + action: Optional action to run when clicked. + compact: Enable compact button style. + flat: Enable alternative flat look buttons. + """ + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + + if label is None: + label = self.css_identifier_styled + + self.variant = variant + self.flat = flat + self.compact = compact + self.set_reactive(Button.label, Content.from_text(label)) + + self.action = action + self.active_effect_duration = 0.2 + """Amount of time in seconds the button 'press' animation lasts.""" + + if tooltip is not None: + self.tooltip = tooltip + + def get_content_width(self, container: Size, viewport: Size) -> int: + assert isinstance(self.label, Content) + try: + return max([cell_len(line) for line in self.label.plain.splitlines()]) + 2 + except ValueError: + # Empty string label + return 2 + + def __rich_repr__(self) -> rich.repr.Result: + yield from super().__rich_repr__() + yield "variant", self.variant, "default" + + def validate_variant(self, variant: str) -> str: + if variant not in _VALID_BUTTON_VARIANTS: + raise InvalidButtonVariant( + f"Valid button variants are {friendly_list(_VALID_BUTTON_VARIANTS)}" + ) + return variant + + def watch_variant(self, old_variant: str, variant: str): + self.remove_class(f"-{old_variant}") + self.add_class(f"-{variant}") + + def watch_flat(self, flat: bool) -> None: + self.set_class(flat, "-style-flat") + self.set_class(not flat, "-style-default") + + def validate_label(self, label: ContentText) -> Content: + """Parse markup for self.label""" + return Content.from_text(label) + + def render(self) -> RenderResult: + assert isinstance(self.label, Content) + return self.label + + def post_render( + self, renderable: RenderableType, base_style: Style + ) -> ConsoleRenderable: + return cast(ConsoleRenderable, renderable) + + async def _on_click(self, event: events.Click) -> None: + event.stop() + if not self.has_class("-active"): + self.press() + + def press(self) -> Self: + """Animate the button and send the [Pressed][textual.widgets.Button.Pressed] message. + + Can be used to simulate the button being pressed by a user. + + Returns: + The button instance. + """ + if self.disabled or not self.display: + return self + # Manage the "active" effect: + self._start_active_affect() + # ...and let other components know that we've just been clicked: + if self.action is None: + self.post_message(Button.Pressed(self)) + else: + self.call_later( + self.app.run_action, self.action, default_namespace=self._parent + ) + return self + + def _start_active_affect(self) -> None: + """Start a small animation to show the button was clicked.""" + if self.active_effect_duration > 0: + self.add_class("-active") + self.set_timer( + self.active_effect_duration, partial(self.remove_class, "-active") + ) + + def action_press(self) -> None: + """Activate a press of the button.""" + if not self.has_class("-active"): + self.press() + + @classmethod + def success( + cls, + label: ContentText | None = None, + *, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + flat: bool = False, + ) -> Button: + """Utility constructor for creating a success Button variant. + + Args: + label: The text that appears within the button. + name: The name of the button. + id: The ID of the button in the DOM. + classes: The CSS classes of the button. + disabled: Whether the button is disabled or not. + flat: Enable alternative flat look buttons. + + Returns: + A [`Button`][textual.widgets.Button] widget of the 'success' + [variant][textual.widgets.button.ButtonVariant]. + """ + return Button( + label=label, + variant="success", + name=name, + id=id, + classes=classes, + disabled=disabled, + flat=flat, + ) + + @classmethod + def warning( + cls, + label: ContentText | None = None, + *, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + flat: bool = False, + ) -> Button: + """Utility constructor for creating a warning Button variant. + + Args: + label: The text that appears within the button. + name: The name of the button. + id: The ID of the button in the DOM. + classes: The CSS classes of the button. + disabled: Whether the button is disabled or not. + flat: Enable alternative flat look buttons. + + Returns: + A [`Button`][textual.widgets.Button] widget of the 'warning' + [variant][textual.widgets.button.ButtonVariant]. + """ + return Button( + label=label, + variant="warning", + name=name, + id=id, + classes=classes, + disabled=disabled, + flat=flat, + ) + + @classmethod + def error( + cls, + label: ContentText | None = None, + *, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + flat: bool = False, + ) -> Button: + """Utility constructor for creating an error Button variant. + + Args: + label: The text that appears within the button. + name: The name of the button. + id: The ID of the button in the DOM. + classes: The CSS classes of the button. + disabled: Whether the button is disabled or not. + flat: Enable alternative flat look buttons. + + Returns: + A [`Button`][textual.widgets.Button] widget of the 'error' + [variant][textual.widgets.button.ButtonVariant]. + """ + return Button( + label=label, + variant="error", + name=name, + id=id, + classes=classes, + disabled=disabled, + flat=flat, + ) diff --git a/src/memray/_vendor/textual/widgets/_checkbox.py b/src/memray/_vendor/textual/widgets/_checkbox.py new file mode 100644 index 0000000000..8e7573ac35 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_checkbox.py @@ -0,0 +1,26 @@ +"""Provides a check box widget.""" + +from __future__ import annotations + +from memray._vendor.textual.widgets._toggle_button import ToggleButton + + +class Checkbox(ToggleButton): + """A check box widget that represents a boolean value.""" + + class Changed(ToggleButton.Changed): + """Posted when the value of the checkbox changes. + + This message can be handled using an `on_checkbox_changed` method. + """ + + @property + def checkbox(self) -> Checkbox: + """The checkbox that was changed.""" + assert isinstance(self._toggle_button, Checkbox) + return self._toggle_button + + @property + def control(self) -> Checkbox: + """An alias for [Changed.checkbox][textual.widgets.Checkbox.Changed.checkbox].""" + return self.checkbox diff --git a/src/memray/_vendor/textual/widgets/_collapsible.py b/src/memray/_vendor/textual/widgets/_collapsible.py new file mode 100644 index 0000000000..679849fcd2 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_collapsible.py @@ -0,0 +1,246 @@ +from __future__ import annotations + +from memray._vendor.textual import events +from memray._vendor.textual.app import ComposeResult +from memray._vendor.textual.binding import Binding +from memray._vendor.textual.containers import Container +from memray._vendor.textual.content import Content, ContentText +from memray._vendor.textual.css.query import NoMatches +from memray._vendor.textual.message import Message +from memray._vendor.textual.reactive import reactive +from memray._vendor.textual.widget import Widget +from memray._vendor.textual.widgets import Static + +__all__ = ["Collapsible", "CollapsibleTitle"] + + +class CollapsibleTitle(Static, can_focus=True): + """Title and symbol for the Collapsible.""" + + BINDING_GROUP_TITLE = "Collapsible" + + ALLOW_SELECT = False + DEFAULT_CSS = """ + CollapsibleTitle { + width: auto; + height: auto; + padding: 0 1; + text-style: $block-cursor-blurred-text-style; + color: $block-cursor-blurred-foreground; + pointer: pointer; + + &:hover { + background: $block-hover-background; + color: $foreground; + } + &:focus { + text-style: $block-cursor-text-style; + background: $block-cursor-background; + color: $block-cursor-foreground; + } + } + """ + + BINDINGS = [ + Binding("enter", "toggle_collapsible", "Toggle collapsible", show=False) + ] + """ + | Key(s) | Description | + | :- | :- | + | enter | Toggle the collapsible. | + """ + + collapsed = reactive(True) + label: reactive[ContentText] = reactive(Content("Toggle")) + + def __init__( + self, + *, + label: ContentText, + collapsed_symbol: str, + expanded_symbol: str, + collapsed: bool, + ) -> None: + super().__init__() + self.collapsed_symbol = collapsed_symbol + self.expanded_symbol = expanded_symbol + self.label = Content.from_text(label) + self.collapsed = collapsed + + class Toggle(Message): + """Request toggle.""" + + async def _on_click(self, event: events.Click) -> None: + """Inform ancestor we want to toggle.""" + event.stop() + self.post_message(self.Toggle()) + + def action_toggle_collapsible(self) -> None: + """Toggle the state of the parent collapsible.""" + self.post_message(self.Toggle()) + + def validate_label(self, label: ContentText) -> Content: + return Content.from_text(label) + + def _update_label(self) -> None: + assert isinstance(self.label, Content) + if self.collapsed: + self.update(Content.assemble(self.collapsed_symbol, " ", self.label)) + else: + self.update(Content.assemble(self.expanded_symbol, " ", self.label)) + + def _watch_label(self) -> None: + self._update_label() + + def _watch_collapsed(self, collapsed: bool) -> None: + self._update_label() + + +class Collapsible(Widget): + """A collapsible container.""" + + ALLOW_MAXIMIZE = True + collapsed = reactive(True, init=False) + title = reactive("Toggle") + + DEFAULT_CSS = """ + Collapsible { + width: 1fr; + height: auto; + background: $surface; + border-top: hkey $background; + padding-bottom: 1; + padding-left: 1; + + &:focus-within { + background-tint: $foreground 5%; + } + + &.-collapsed > Contents { + display: none; + } + } + """ + + class Toggled(Message): + """Parent class subclassed by `Collapsible` messages. + + Can be handled with `on(Collapsible.Toggled)` if you want to handle expansions + and collapsed in the same way, or you can handle the specific events individually. + """ + + def __init__(self, collapsible: Collapsible) -> None: + """Create an instance of the message. + + Args: + collapsible: The `Collapsible` widget that was toggled. + """ + self.collapsible: Collapsible = collapsible + """The collapsible that was toggled.""" + super().__init__() + + @property + def control(self) -> Collapsible: + """An alias for [Toggled.collapsible][textual.widgets.Collapsible.Toggled.collapsible].""" + return self.collapsible + + class Expanded(Toggled): + """Event sent when the `Collapsible` widget is expanded. + + Can be handled using `on_collapsible_expanded` in a subclass of + [`Collapsible`][textual.widgets.Collapsible] or in a parent widget in the DOM. + """ + + class Collapsed(Toggled): + """Event sent when the `Collapsible` widget is collapsed. + + Can be handled using `on_collapsible_collapsed` in a subclass of + [`Collapsible`][textual.widgets.Collapsible] or in a parent widget in the DOM. + """ + + class Contents(Container): + DEFAULT_CSS = """ + Contents { + width: 100%; + height: auto; + padding: 1 0 0 3; + } + """ + + def __init__( + self, + *children: Widget, + title: str = "Toggle", + collapsed: bool = True, + collapsed_symbol: str = "▶", + expanded_symbol: str = "▼", + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + ) -> None: + """Initialize a Collapsible widget. + + Args: + *children: Contents that will be collapsed/expanded. + title: Title of the collapsed/expanded contents. + collapsed: Default status of the contents. + collapsed_symbol: Collapsed symbol before the title. + expanded_symbol: Expanded symbol before the title. + name: The name of the collapsible. + id: The ID of the collapsible in the DOM. + classes: The CSS classes of the collapsible. + disabled: Whether the collapsible is disabled or not. + """ + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + self._title = CollapsibleTitle( + label=title, + collapsed_symbol=collapsed_symbol, + expanded_symbol=expanded_symbol, + collapsed=collapsed, + ) + self.title = title + self._contents_list: list[Widget] = list(children) + self.collapsed = collapsed + + def _on_collapsible_title_toggle(self, event: CollapsibleTitle.Toggle) -> None: + event.stop() + self.collapsed = not self.collapsed + + def _watch_collapsed(self, collapsed: bool) -> None: + """Update collapsed state when reactive is changed.""" + self._update_collapsed(collapsed) + if self.collapsed: + self.post_message(self.Collapsed(self)) + else: + self.post_message(self.Expanded(self)) + if self.is_mounted: + self.call_after_refresh(self.scroll_visible) + + def _update_collapsed(self, collapsed: bool) -> None: + """Update children to match collapsed state.""" + try: + self._title.collapsed = collapsed + self.set_class(collapsed, "-collapsed") + except NoMatches: + pass + + def _on_mount(self, event: events.Mount) -> None: + """Initialise collapsed state.""" + self._update_collapsed(self.collapsed) + + def compose(self) -> ComposeResult: + yield self._title + with self.Contents(): + yield from self._contents_list + + def compose_add_child(self, widget: Widget) -> None: + """When using the context manager compose syntax, we want to attach nodes to the contents. + + Args: + widget: A Widget to add. + """ + self._contents_list.append(widget) + + def _watch_title(self, title: str) -> None: + self._title.label = title diff --git a/src/memray/_vendor/textual/widgets/_content_switcher.py b/src/memray/_vendor/textual/widgets/_content_switcher.py new file mode 100644 index 0000000000..a0fca10b51 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_content_switcher.py @@ -0,0 +1,132 @@ +"""Provides a widget for switching between the display of its immediate children.""" + +from __future__ import annotations + +from typing import Optional + +from memray._vendor.textual.await_complete import AwaitComplete +from memray._vendor.textual.containers import Container +from memray._vendor.textual.css.query import NoMatches +from memray._vendor.textual.events import Mount +from memray._vendor.textual.reactive import reactive +from memray._vendor.textual.widget import Widget + + +class ContentSwitcher(Container): + """A widget for switching between different children. + + Note: + All child widgets that are to be switched between need a unique ID. + Children that have no ID will be hidden and ignored. + """ + + DEFAULT_CSS = """ + ContentSwitcher { + height: auto; + } + + """ + + current: reactive[str | None] = reactive[Optional[str]](None, init=False) + """The ID of the currently-displayed widget. + + If set to `None` then no widget is visible. + + Note: + If set to an unknown ID, this will result in + [`NoMatches`][textual.css.query.NoMatches] being raised. + """ + + def __init__( + self, + *children: Widget, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + initial: str | None = None, + ) -> None: + """Initialise the content switching widget. + + Args: + *children: The widgets to switch between. + name: The name of the content switcher. + id: The ID of the content switcher in the DOM. + classes: The CSS classes of the content switcher. + disabled: Whether the content switcher is disabled or not. + initial: The ID of the initial widget to show, ``None`` or empty string for the first tab. + + Note: + If `initial` is not supplied no children will be shown to start with. + """ + super().__init__( + *children, + name=name, + id=id, + classes=classes, + disabled=disabled, + ) + self._initial = initial + + def _on_mount(self, _: Mount) -> None: + """Perform the initial setup of the widget once the DOM is ready.""" + initial = self._initial + with self.app.batch_update(): + for child in self.children: + child.display = bool(initial) and child.id == initial + self._reactive_current = initial + + @property + def visible_content(self) -> Widget | None: + """A reference to the currently-visible widget. + + `None` if nothing is visible. + """ + return self.get_child_by_id(self.current) if self.current is not None else None + + def watch_current(self, old: str | None, new: str | None) -> None: + """React to the current visible child choice being changed. + + Args: + old: The old widget ID (or `None` if there was no widget). + new: The new widget ID (or `None` if nothing should be shown). + """ + with self.app.batch_update(): + if old: + try: + self.get_child_by_id(old).display = False + except NoMatches: + pass + if new: + self.get_child_by_id(new).display = True + + def add_content( + self, widget: Widget, *, id: str | None = None, set_current: bool = False + ) -> AwaitComplete: + """Add new content to the `ContentSwitcher`. + + Args: + widget: A Widget to add. + id: ID for the widget, or `None` if the widget already has an ID. + set_current: Set the new widget as current (which will cause it to display). + + Returns: + An awaitable to wait for the new content to be mounted. + """ + if id is not None and widget.id != id: + widget.id = id + + if not widget.id: + raise ValueError( + "Widget must have an ID (or set id parameter when calling add_content)" + ) + + async def _add_content() -> None: + """Add new widget and potentially change the current widget.""" + widget.display = False + with self.app.batch_update(): + await self.mount(widget) + if set_current: + self.current = widget.id + + return AwaitComplete(_add_content()) diff --git a/src/memray/_vendor/textual/widgets/_data_table.py b/src/memray/_vendor/textual/widgets/_data_table.py new file mode 100644 index 0000000000..345f178691 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_data_table.py @@ -0,0 +1,2864 @@ +from __future__ import annotations + +import functools +from dataclasses import dataclass +from itertools import chain, zip_longest +from operator import itemgetter +from typing import ( + Any, + Callable, + ClassVar, + Generic, + Iterable, + NamedTuple, + TypeVar, + Union, +) + +import rich.repr +from rich.console import RenderableType +from rich.padding import Padding +from rich.protocol import is_renderable +from rich.segment import Segment +from rich.style import Style +from rich.text import Text, TextType +from typing_extensions import Literal, Self, TypeAlias + +from memray._vendor.textual import events +from memray._vendor.textual._segment_tools import line_crop +from memray._vendor.textual._two_way_dict import TwoWayDict +from memray._vendor.textual._types import SegmentLines +from memray._vendor.textual.binding import Binding, BindingType +from memray._vendor.textual.cache import LRUCache +from memray._vendor.textual.color import Color +from memray._vendor.textual.coordinate import Coordinate +from memray._vendor.textual.geometry import Region, Size, Spacing, clamp +from memray._vendor.textual.message import Message +from memray._vendor.textual.reactive import Reactive +from memray._vendor.textual.render import measure +from memray._vendor.textual.renderables.styled import Styled +from memray._vendor.textual.scroll_view import ScrollView +from memray._vendor.textual.strip import Strip +from memray._vendor.textual.widget import PseudoClasses + +CellCacheKey: TypeAlias = ( + "tuple[RowKey, ColumnKey, Style, bool, bool, bool, int, PseudoClasses]" +) +LineCacheKey: TypeAlias = ( + "tuple[int, int, int, int, Coordinate, Coordinate, Style, CursorType, bool, int, PseudoClasses]" +) +RowCacheKey: TypeAlias = ( + "tuple[RowKey, int, Style, Coordinate, Coordinate, CursorType, bool, bool, int, PseudoClasses]" +) +CursorType = Literal["cell", "row", "column", "none"] +"""The valid types of cursors for [`DataTable.cursor_type`][textual.widgets.DataTable.cursor_type].""" +CellType = TypeVar("CellType") +"""Type used for cells in the DataTable.""" + +_EMPTY_TEXT = Text(no_wrap=True, end="") + + +class CellDoesNotExist(Exception): + """The cell key/index was invalid. + + Raised when the coordinates or cell key provided does not exist + in the DataTable (e.g. out of bounds index, invalid key)""" + + +class RowDoesNotExist(Exception): + """Raised when the row index or row key provided does not exist + in the DataTable (e.g. out of bounds index, invalid key)""" + + +class ColumnDoesNotExist(Exception): + """Raised when the column index or column key provided does not exist + in the DataTable (e.g. out of bounds index, invalid key)""" + + +class DuplicateKey(Exception): + """The key supplied already exists. + + Raised when the RowKey or ColumnKey provided already refers to + an existing row or column in the DataTable. Keys must be unique.""" + + +@functools.total_ordering +class StringKey: + """An object used as a key in a mapping. + + It can optionally wrap a string, + and lookups into a map using the object behave the same as lookups using + the string itself.""" + + value: str | None + + def __init__(self, value: str | None = None): + self.value = value + + def __hash__(self): + # If a string is supplied, we use the hash of the string. If no string was + # supplied, we use the default hash to ensure uniqueness amongst instances. + return hash(self.value) if self.value is not None else id(self) + + def __eq__(self, other: object) -> bool: + # Strings will match Keys containing the same string value. + # Otherwise, you'll need to supply the exact same key object. + if isinstance(other, str): + return self.value == other + elif isinstance(other, StringKey): + if self.value is not None and other.value is not None: + return self.value == other.value + else: + return hash(self) == hash(other) + else: + return NotImplemented + + def __lt__(self, other): + if isinstance(other, str): + return self.value < other + elif isinstance(other, StringKey): + return self.value < other.value + else: + return NotImplemented + + def __rich_repr__(self): + yield "value", self.value + + +class RowKey(StringKey): + """Uniquely identifies a row in the DataTable. + + Even if the visual location + of the row changes due to sorting or other modifications, a key will always + refer to the same row.""" + + +class ColumnKey(StringKey): + """Uniquely identifies a column in the DataTable. + + Even if the visual location + of the column changes due to sorting or other modifications, a key will always + refer to the same column.""" + + +class CellKey(NamedTuple): + """A unique identifier for a cell in the DataTable. + + A cell key is a `(row_key, column_key)` tuple. + + Even if the cell changes + visual location (i.e. moves to a different coordinate in the table), this key + can still be used to retrieve it, regardless of where it currently is.""" + + row_key: RowKey + """The key of this cell's row.""" + column_key: ColumnKey + """The key of this cell's column.""" + + def __rich_repr__(self): + yield "row_key", self.row_key + yield "column_key", self.column_key + + +def _find_newline(string: str, number: int) -> int: + """Find newline number n (the nth newline) in a string. + + Args: + string: The string to search. + number: The nth newline character to find. + + Returns: + The index of the nth newline character, or -1 if not found. + """ + if not string or number < 1: + return -1 + + pos = -1 + for _ in range(number): + pos = string.find("\n", pos + 1) + if pos == -1: + break + return pos + + +def default_cell_formatter( + obj: object, wrap: bool = True, height: int = 0 +) -> RenderableType: + """Convert a cell into a Rich renderable for display. + + Args: + obj: Data for a cell. + wrap: Enable or disable wrapping inside the cell. + height: The height of the cell, or `None` to render the entire cell. + This can be used to short-circuit rendering. e.g. If we know the cell + has a height of 1, we can render the cell as a single line of text + without any wrapping. + + Returns: + A renderable to be displayed which represents the data. + """ + # Get the string which will be displayed in the cell. + possible_markup = False + if isinstance(obj, str): + possible_markup = True + content = obj + elif isinstance(obj, float): + content = f"{obj:.2f}" + elif not is_renderable(obj): + content = str(obj) + else: + return obj + + if height: + # Let's throw away lines which definitely won't appear in the cell + # after wrapping using the height constraint. A cell can only grow + # vertically after wrapping occurs, so this is a safe operation. + trim_position = _find_newline(content, height) + if trim_position != -1 and trim_position != len(content) - 1: + content = content[:trim_position] + + if possible_markup: + text = Text.from_markup(content, end="") + text.no_wrap = not wrap + return text + return Text(content, no_wrap=not wrap, end="") + + +@dataclass +class Column: + """Metadata for a column in the DataTable.""" + + key: ColumnKey + label: Text + width: int = 0 + content_width: int = 0 + auto_width: bool = False + + def get_render_width(self, data_table: DataTable[Any]) -> int: + """Width, in cells, required to render the column with padding included. + + Args: + data_table: The data table where the column will be rendered. + + Returns: + The width, in cells, required to render the column with padding included. + """ + return 2 * data_table.cell_padding + ( + self.content_width if self.auto_width else self.width + ) + + +@dataclass +class Row: + """Metadata for a row in the DataTable.""" + + key: RowKey + height: int + label: Text | None = None + auto_height: bool = False + + +class RowRenderables(NamedTuple): + """Container for a row, which contains an optional label and some data cells.""" + + label: RenderableType | None + cells: list[RenderableType] + + +class DataTable(ScrollView, Generic[CellType], can_focus=True): + """A tabular widget that contains data.""" + + ALLOW_SELECT = False + + BINDINGS: ClassVar[list[BindingType]] = [ + Binding("enter", "select_cursor", "Select", show=False), + Binding("up", "cursor_up", "Cursor up", show=False), + Binding("down", "cursor_down", "Cursor down", show=False), + Binding("right", "cursor_right", "Cursor right", show=False), + Binding("left", "cursor_left", "Cursor left", show=False), + Binding("pageup", "page_up", "Page up", show=False), + Binding("pagedown", "page_down", "Page down", show=False), + Binding("ctrl+home", "scroll_top", "Top", show=False), + Binding("ctrl+end", "scroll_bottom", "Bottom", show=False), + Binding("home", "scroll_home", "Home", show=False), + Binding("end", "scroll_end", "End", show=False), + ] + """ + | Key(s) | Description | + | :- | :- | + | enter | Select cells under the cursor. | + | up | Move the cursor up. | + | down | Move the cursor down. | + | right | Move the cursor right. | + | left | Move the cursor left. | + | pageup | Move one page up. | + | pagedown | Move one page down. | + | ctrl+home | Move to the top. | + | ctrl+end | Move to the bottom. | + | home | Move to the home position (leftmost column). | + | end | Move to the end position (rightmost column). | + """ + + COMPONENT_CLASSES: ClassVar[set[str]] = { + "datatable--cursor", + "datatable--hover", + "datatable--fixed", + "datatable--fixed-cursor", + "datatable--header", + "datatable--header-cursor", + "datatable--header-hover", + "datatable--odd-row", + "datatable--even-row", + } + """ + | Class | Description | + | :- | :- | + | `datatable--cursor` | Target the cursor. | + | `datatable--hover` | Target the cells under the hover cursor. | + | `datatable--fixed` | Target fixed columns and fixed rows. | + | `datatable--fixed-cursor` | Target highlighted and fixed columns or header. | + | `datatable--header` | Target the header of the data table. | + | `datatable--header-cursor` | Target cells highlighted by the cursor. | + | `datatable--header-hover` | Target hovered header or row label cells. | + | `datatable--even-row` | Target even rows (row indices start at 0) if zebra_stripes. | + | `datatable--odd-row` | Target odd rows (row indices start at 0) if zebra_stripes. | + """ + + DEFAULT_CSS = """ + DataTable { + background: $surface; + color: $foreground; + height: auto; + max-height: 100%; + &.datatable--fixed-cursor { + background: $block-cursor-blurred-background; + } + + &:focus { + background-tint: $foreground 5%; + & > .datatable--cursor { + background: $block-cursor-background; + color: $block-cursor-foreground; + text-style: $block-cursor-text-style; + } + + & > .datatable--header { + background-tint: $foreground 5%; + } + + & > .datatable--fixed-cursor { + color: $block-cursor-foreground; + background: $block-cursor-background; + } + } + + &:dark { + & > .datatable--even-row { + background: $surface-darken-1 40%; + } + } + + & > .datatable--header { + text-style: bold; + background: $panel; + color: $foreground; + } + &:ansi > .datatable--header { + background: ansi_bright_blue; + color: ansi_default; + } + + & > .datatable--fixed { + background: $secondary-muted; + color: $foreground; + } + + & > .datatable--odd-row { + + } + + & > .datatable--even-row { + background: $surface-lighten-1 50%; + } + + & > .datatable--cursor { + background: $block-cursor-blurred-background; + color: $block-cursor-blurred-foreground; + text-style: $block-cursor-blurred-text-style; + } + + & > .datatable--fixed-cursor { + background: $block-cursor-blurred-background; + color: $foreground; + } + + & > .datatable--header-cursor { + background: $accent-darken-1; + color: $foreground; + } + + & > .datatable--header-hover { + background: $accent 30%; + } + + & > .datatable--hover { + background: $block-hover-background; + } + } + """ + + show_header = Reactive(True) + show_row_labels = Reactive(True) + fixed_rows = Reactive(0) + fixed_columns = Reactive(0) + zebra_stripes = Reactive(False) + header_height = Reactive(1) + show_cursor = Reactive(True) + cursor_type: Reactive[CursorType] = Reactive[CursorType]("cell") + """The type of the cursor of the `DataTable`.""" + cell_padding = Reactive(1) + """Horizontal padding between cells, applied on each side of each cell.""" + + cursor_coordinate: Reactive[Coordinate] = Reactive( + Coordinate(0, 0), repaint=False, always_update=True + ) + """Current cursor [`Coordinate`][textual.coordinate.Coordinate]. + + This can be set programmatically or changed via the method + [`move_cursor`][textual.widgets.DataTable.move_cursor]. + """ + hover_coordinate: Reactive[Coordinate] = Reactive( + Coordinate(0, 0), repaint=False, always_update=True + ) + """The coordinate of the `DataTable` that is being hovered.""" + + class CellHighlighted(Message): + """Posted when the cursor moves to highlight a new cell. + + This is only relevant when the `cursor_type` is `"cell"`. + It's also posted when the cell cursor is + re-enabled (by setting `show_cursor=True`), and when the cursor type is + changed to `"cell"`. Can be handled using `on_data_table_cell_highlighted` in + a subclass of `DataTable` or in a parent widget in the DOM. + """ + + def __init__( + self, + data_table: DataTable, + value: CellType, + coordinate: Coordinate, + cell_key: CellKey, + ) -> None: + self.data_table = data_table + """The data table.""" + self.value = value + """The value in the highlighted cell.""" + self.coordinate: Coordinate = coordinate + """The coordinate of the highlighted cell.""" + self.cell_key: CellKey = cell_key + """The key for the highlighted cell.""" + super().__init__() + + def __rich_repr__(self) -> rich.repr.Result: + yield "value", self.value + yield "coordinate", self.coordinate + yield "cell_key", self.cell_key + + @property + def control(self) -> DataTable: + """Alias for the data table.""" + return self.data_table + + class CellSelected(Message): + """Posted by the `DataTable` widget when a cell is selected. + + This is only relevant when the `cursor_type` is `"cell"`. Can be handled using + `on_data_table_cell_selected` in a subclass of `DataTable` or in a parent + widget in the DOM. + """ + + def __init__( + self, + data_table: DataTable, + value: CellType, + coordinate: Coordinate, + cell_key: CellKey, + ) -> None: + self.data_table = data_table + """The data table.""" + self.value: CellType = value + """The value in the cell that was selected.""" + self.coordinate: Coordinate = coordinate + """The coordinate of the cell that was selected.""" + self.cell_key: CellKey = cell_key + """The key for the selected cell.""" + super().__init__() + + def __rich_repr__(self) -> rich.repr.Result: + yield "value", self.value + yield "coordinate", self.coordinate + yield "cell_key", self.cell_key + + @property + def control(self) -> DataTable: + """Alias for the data table.""" + return self.data_table + + class RowHighlighted(Message): + """Posted when a row is highlighted. + + This message is only posted when the + `cursor_type` is set to `"row"`. Can be handled using + `on_data_table_row_highlighted` in a subclass of `DataTable` or in a parent + widget in the DOM. + """ + + def __init__( + self, data_table: DataTable, cursor_row: int, row_key: RowKey + ) -> None: + self.data_table = data_table + """The data table.""" + self.cursor_row: int = cursor_row + """The y-coordinate of the cursor that highlighted the row.""" + self.row_key: RowKey = row_key + """The key of the row that was highlighted.""" + super().__init__() + + def __rich_repr__(self) -> rich.repr.Result: + yield "cursor_row", self.cursor_row + yield "row_key", self.row_key + + @property + def control(self) -> DataTable: + """Alias for the data table.""" + return self.data_table + + class RowSelected(Message): + """Posted when a row is selected. + + This message is only posted when the + `cursor_type` is set to `"row"`. Can be handled using + `on_data_table_row_selected` in a subclass of `DataTable` or in a parent + widget in the DOM. + """ + + def __init__( + self, data_table: DataTable, cursor_row: int, row_key: RowKey + ) -> None: + self.data_table = data_table + """The data table.""" + self.cursor_row: int = cursor_row + """The y-coordinate of the cursor that made the selection.""" + self.row_key: RowKey = row_key + """The key of the row that was selected.""" + super().__init__() + + def __rich_repr__(self) -> rich.repr.Result: + yield "cursor_row", self.cursor_row + yield "row_key", self.row_key + + @property + def control(self) -> DataTable: + """Alias for the data table.""" + return self.data_table + + class ColumnHighlighted(Message): + """Posted when a column is highlighted. + + This message is only posted when the + `cursor_type` is set to `"column"`. Can be handled using + `on_data_table_column_highlighted` in a subclass of `DataTable` or in a parent + widget in the DOM. + """ + + def __init__( + self, data_table: DataTable, cursor_column: int, column_key: ColumnKey + ) -> None: + self.data_table = data_table + """The data table.""" + self.cursor_column: int = cursor_column + """The x-coordinate of the column that was highlighted.""" + self.column_key = column_key + """The key of the column that was highlighted.""" + super().__init__() + + def __rich_repr__(self) -> rich.repr.Result: + yield "cursor_column", self.cursor_column + yield "column_key", self.column_key + + @property + def control(self) -> DataTable: + """Alias for the data table.""" + return self.data_table + + class ColumnSelected(Message): + """Posted when a column is selected. + + This message is only posted when the + `cursor_type` is set to `"column"`. Can be handled using + `on_data_table_column_selected` in a subclass of `DataTable` or in a parent + widget in the DOM. + """ + + def __init__( + self, data_table: DataTable, cursor_column: int, column_key: ColumnKey + ) -> None: + self.data_table = data_table + """The data table.""" + self.cursor_column: int = cursor_column + """The x-coordinate of the column that was selected.""" + self.column_key = column_key + """The key of the column that was selected.""" + super().__init__() + + def __rich_repr__(self) -> rich.repr.Result: + yield "cursor_column", self.cursor_column + yield "column_key", self.column_key + + @property + def control(self) -> DataTable: + """Alias for the data table.""" + return self.data_table + + class HeaderSelected(Message): + """Posted when a column header/label is clicked.""" + + def __init__( + self, + data_table: DataTable, + column_key: ColumnKey, + column_index: int, + label: Text, + ): + self.data_table = data_table + """The data table.""" + self.column_key = column_key + """The key for the column.""" + self.column_index = column_index + """The index for the column.""" + self.label = label + """The text of the label.""" + super().__init__() + + def __rich_repr__(self) -> rich.repr.Result: + yield "column_key", self.column_key + yield "column_index", self.column_index + yield "label", self.label.plain + + @property + def control(self) -> DataTable: + """Alias for the data table.""" + return self.data_table + + class RowLabelSelected(Message): + """Posted when a row label is clicked.""" + + def __init__( + self, + data_table: DataTable, + row_key: RowKey, + row_index: int, + label: Text, + ): + self.data_table = data_table + """The data table.""" + self.row_key = row_key + """The key for the column.""" + self.row_index = row_index + """The index for the column.""" + self.label = label + """The text of the label.""" + super().__init__() + + def __rich_repr__(self) -> rich.repr.Result: + yield "row_key", self.row_key + yield "row_index", self.row_index + yield "label", self.label.plain + + @property + def control(self) -> DataTable: + """Alias for the data table.""" + return self.data_table + + def __init__( + self, + *, + show_header: bool = True, + show_row_labels: bool = True, + fixed_rows: int = 0, + fixed_columns: int = 0, + zebra_stripes: bool = False, + header_height: int = 1, + show_cursor: bool = True, + cursor_foreground_priority: Literal["renderable", "css"] = "css", + cursor_background_priority: Literal["renderable", "css"] = "renderable", + cursor_type: CursorType = "cell", + cell_padding: int = 1, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + ) -> None: + """Initializes a widget to display tabular data. + + Args: + show_header: Whether the table header should be visible or not. + show_row_labels: Whether the row labels should be shown or not. + fixed_rows: The number of rows, counting from the top, that should be fixed + and still visible when the user scrolls down. + fixed_columns: The number of columns, counting from the left, that should be + fixed and still visible when the user scrolls right. + zebra_stripes: Enables or disables a zebra effect applied to the background + color of the rows of the table, where alternate colors are styled + differently to improve the readability of the table. + header_height: The height, in number of cells, of the data table header. + show_cursor: Whether the cursor should be visible when navigating the data + table or not. + cursor_foreground_priority: If the data associated with a cell is an + arbitrary renderable with a set foreground color, this determines whether + that color is prioritized over the cursor component class or not. + cursor_background_priority: If the data associated with a cell is an + arbitrary renderable with a set background color, this determines whether + that color is prioritized over the cursor component class or not. + cursor_type: The type of cursor to be used when navigating the data table + with the keyboard. + cell_padding: The number of cells added on each side of each column. Setting + this value to zero will likely make your table very hard to read. + name: The name of the widget. + id: The ID of the widget in the DOM. + classes: The CSS classes for the widget. + disabled: Whether the widget is disabled or not. + """ + + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + self._data: dict[RowKey, dict[ColumnKey, CellType]] = {} + """Contains the cells of the table, indexed by row key and column key. + The final positioning of a cell on screen cannot be determined solely by this + structure. Instead, we must check _row_locations and _column_locations to find + where each cell currently resides in space.""" + + self.columns: dict[ColumnKey, Column] = {} + """Metadata about the columns of the table, indexed by their key.""" + self.rows: dict[RowKey, Row] = {} + """Metadata about the rows of the table, indexed by their key.""" + + # Keep tracking of key -> index for rows/cols. These allow us to retrieve, + # given a row or column key, the index that row or column is currently + # present at, and mean that rows and columns are location independent - they + # can move around without requiring us to modify the underlying data. + self._row_locations: TwoWayDict[RowKey, int] = TwoWayDict({}) + """Maps row keys to row indices which represent row order.""" + self._column_locations: TwoWayDict[ColumnKey, int] = TwoWayDict({}) + """Maps column keys to column indices which represent column order.""" + + self._row_render_cache: LRUCache[ + RowCacheKey, tuple[SegmentLines, SegmentLines] + ] = LRUCache(1000) + """For each row (a row can have a height of multiple lines), we maintain a + cache of the fixed and scrollable lines within that row to minimize how often + we need to re-render it. """ + self._cell_render_cache: LRUCache[CellCacheKey, SegmentLines] = LRUCache(10000) + """Cache for individual cells.""" + self._row_renderable_cache: LRUCache[tuple[int, int], RowRenderables] = ( + LRUCache(1000) + ) + """Caches row renderables - key is (update_count, row_index)""" + self._line_cache: LRUCache[LineCacheKey, Strip] = LRUCache(1000) + """Cache for lines within rows.""" + self._offset_cache: LRUCache[int, list[tuple[RowKey, int]]] = LRUCache(1) + """Cached y_offset - key is update_count - see y_offsets property for more + information """ + self._ordered_row_cache: LRUCache[tuple[int, int], list[Row]] = LRUCache(1) + """Caches row ordering - key is (num_rows, update_count).""" + + self._pseudo_class_state = PseudoClasses(False, False, False) + """The pseudo-class state is used as part of cache keys to ensure that, for example, + when we lose focus on the DataTable, rules which apply to :focus are invalidated + and we prevent lingering styles.""" + + self._require_update_dimensions: bool = False + """Set to re-calculate dimensions on idle.""" + self._new_rows: set[RowKey] = set() + """Tracking newly added rows to be used in calculation of dimensions on idle.""" + self._updated_cells: set[CellKey] = set() + """Track which cells were updated, so that we can refresh them once on idle.""" + + self._show_hover_cursor = False + """Used to hide the mouse hover cursor when the user uses the keyboard.""" + self._update_count = 0 + """Number of update (INCLUDING SORT) operations so far. Used for cache invalidation.""" + self._header_row_key = RowKey() + """The header is a special row - not part of the data. Retrieve via this key.""" + self._label_column_key = ColumnKey() + """The column containing row labels is not part of the data. This key identifies it.""" + self._labelled_row_exists = False + """Whether or not the user has supplied any rows with labels.""" + self._label_column = Column(self._label_column_key, Text(), auto_width=True) + """The largest content width out of all row labels in the table.""" + + self.show_header = show_header + """Show/hide the header row (the row of column labels).""" + self.show_row_labels = show_row_labels + """Show/hide the column containing the labels of rows.""" + self.header_height = header_height + """The height of the header row (the row of column labels).""" + self.fixed_rows = fixed_rows + """The number of rows to fix (prevented from scrolling).""" + self.fixed_columns = fixed_columns + """The number of columns to fix (prevented from scrolling).""" + self.zebra_stripes = zebra_stripes + """Apply alternating styles, datatable--even-row and datatable-odd-row, to create a zebra effect, e.g., + alternating light and dark backgrounds.""" + self.show_cursor = show_cursor + """Show/hide both the keyboard and hover cursor.""" + self.cursor_foreground_priority = cursor_foreground_priority + """Should we prioritize the cursor component class CSS foreground or the renderable foreground + in the event where a cell contains a renderable with a foreground color.""" + self.cursor_background_priority = cursor_background_priority + """Should we prioritize the cursor component class CSS background or the renderable background + in the event where a cell contains a renderable with a background color.""" + self.cursor_type = cursor_type + """The type of cursor of the `DataTable`.""" + self.cell_padding = cell_padding + """Horizontal padding between cells, applied on each side of each cell.""" + + @property + def hover_row(self) -> int: + """The index of the row that the mouse cursor is currently hovering above.""" + return self.hover_coordinate.row + + @property + def hover_column(self) -> int: + """The index of the column that the mouse cursor is currently hovering above.""" + return self.hover_coordinate.column + + @property + def cursor_row(self) -> int: + """The index of the row that the DataTable cursor is currently on.""" + return self.cursor_coordinate.row + + @property + def cursor_column(self) -> int: + """The index of the column that the DataTable cursor is currently on.""" + return self.cursor_coordinate.column + + @property + def row_count(self) -> int: + """The number of rows currently present in the DataTable.""" + return len(self.rows) + + @property + def _y_offsets(self) -> list[tuple[RowKey, int]]: + """Contains a 2-tuple for each line (not row!) of the DataTable. Given a + y-coordinate, we can index into this list to find which row that y-coordinate + lands on, and the y-offset *within* that row. The length of the returned list + is therefore the total height of all rows within the DataTable.""" + y_offsets: list[tuple[RowKey, int]] = [] + if self._update_count in self._offset_cache: + y_offsets = self._offset_cache[self._update_count] + else: + for row in self.ordered_rows: + y_offsets += [(row.key, y) for y in range(row.height)] + self._offset_cache[self._update_count] = y_offsets + + return y_offsets + + @property + def _total_row_height(self) -> int: + """The total height of all rows within the DataTable""" + return len(self._y_offsets) + + def update_cell( + self, + row_key: RowKey | str, + column_key: ColumnKey | str, + value: CellType, + *, + update_width: bool = False, + ) -> None: + """Update the cell identified by the specified row key and column key. + + Args: + row_key: The key identifying the row. + column_key: The key identifying the column. + value: The new value to put inside the cell. + update_width: Whether to resize the column width to accommodate + for the new cell content. + + Raises: + CellDoesNotExist: When the supplied `row_key` and `column_key` + cannot be found in the table. + """ + if isinstance(row_key, str): + row_key = RowKey(row_key) + if isinstance(column_key, str): + column_key = ColumnKey(column_key) + + if ( + row_key not in self._row_locations + or column_key not in self._column_locations + ): + raise CellDoesNotExist( + f"No cell exists for row_key={row_key!r}, column_key={column_key!r}." + ) + + self._data[row_key][column_key] = value + self._update_count += 1 + + # Recalculate widths if necessary + if update_width: + self._updated_cells.add(CellKey(row_key, column_key)) + self._require_update_dimensions = True + + self.refresh() + + def update_cell_at( + self, coordinate: Coordinate, value: CellType, *, update_width: bool = False + ) -> None: + """Update the content inside the cell currently occupying the given coordinate. + + Args: + coordinate: The coordinate to update the cell at. + value: The new value to place inside the cell. + update_width: Whether to resize the column width to accommodate + for the new cell content. + """ + if not self.is_valid_coordinate(coordinate): + raise CellDoesNotExist(f"Coordinate {coordinate!r} is invalid.") + + row_key, column_key = self.coordinate_to_cell_key(coordinate) + self.update_cell(row_key, column_key, value, update_width=update_width) + + def get_cell(self, row_key: RowKey | str, column_key: ColumnKey | str) -> CellType: + """Given a row key and column key, return the value of the corresponding cell. + + Args: + row_key: The row key of the cell. + column_key: The column key of the cell. + + Returns: + The value of the cell identified by the row and column keys. + """ + try: + cell_value = self._data[row_key][column_key] + except KeyError: + raise CellDoesNotExist( + f"No cell exists for row_key={row_key!r}, column_key={column_key!r}." + ) + return cell_value + + def get_cell_at(self, coordinate: Coordinate) -> CellType: + """Get the value from the cell occupying the given coordinate. + + Args: + coordinate: The coordinate to retrieve the value from. + + Returns: + The value of the cell at the coordinate. + + Raises: + CellDoesNotExist: If there is no cell with the given coordinate. + """ + row_key, column_key = self.coordinate_to_cell_key(coordinate) + return self.get_cell(row_key, column_key) + + def get_cell_coordinate( + self, row_key: RowKey | str, column_key: ColumnKey | str + ) -> Coordinate: + """Given a row key and column key, return the corresponding cell coordinate. + + Args: + row_key: The row key of the cell. + column_key: The column key of the cell. + + Returns: + The current coordinate of the cell identified by the row and column keys. + + Raises: + CellDoesNotExist: If the specified cell does not exist. + """ + if ( + row_key not in self._row_locations + or column_key not in self._column_locations + ): + raise CellDoesNotExist( + f"No cell exists for row_key={row_key!r}, column_key={column_key!r}." + ) + row_index = self._row_locations.get(row_key) + column_index = self._column_locations.get(column_key) + return Coordinate(row_index, column_index) + + def get_row(self, row_key: RowKey | str) -> list[CellType]: + """Get the values from the row identified by the given row key. + + Args: + row_key: The key of the row. + + Returns: + A list of the values contained within the row. + + Raises: + RowDoesNotExist: When there is no row corresponding to the key. + """ + if row_key not in self._row_locations: + raise RowDoesNotExist(f"Row key {row_key!r} is not valid.") + cell_mapping: dict[ColumnKey, CellType] = self._data.get(row_key, {}) + ordered_row: list[CellType] = [ + cell_mapping[column.key] for column in self.ordered_columns + ] + return ordered_row + + def get_row_at(self, row_index: int) -> list[CellType]: + """Get the values from the cells in a row at a given index. This will + return the values from a row based on the rows _current position_ in + the table. + + Args: + row_index: The index of the row. + + Returns: + A list of the values contained in the row. + + Raises: + RowDoesNotExist: If there is no row with the given index. + """ + if not self.is_valid_row_index(row_index): + raise RowDoesNotExist(f"Row index {row_index!r} is not valid.") + row_key = self._row_locations.get_key(row_index) + return self.get_row(row_key) + + def get_row_index(self, row_key: RowKey | str) -> int: + """Return the current index for the row identified by row_key. + + Args: + row_key: The row key to find the current index of. + + Returns: + The current index of the specified row key. + + Raises: + RowDoesNotExist: If the row key does not exist. + """ + if row_key not in self._row_locations: + raise RowDoesNotExist(f"No row exists for row_key={row_key!r}") + return self._row_locations.get(row_key) + + def get_column(self, column_key: ColumnKey | str) -> Iterable[CellType]: + """Get the values from the column identified by the given column key. + + Args: + column_key: The key of the column. + + Returns: + A generator which yields the cells in the column. + + Raises: + ColumnDoesNotExist: If there is no column corresponding to the key. + """ + if column_key not in self._column_locations: + raise ColumnDoesNotExist(f"Column key {column_key!r} is not valid.") + + data = self._data + for row_metadata in self.ordered_rows: + row_key = row_metadata.key + yield data[row_key][column_key] + + def get_column_at(self, column_index: int) -> Iterable[CellType]: + """Get the values from the column at a given index. + + Args: + column_index: The index of the column. + + Returns: + A generator which yields the cells in the column. + + Raises: + ColumnDoesNotExist: If there is no column with the given index. + """ + if not self.is_valid_column_index(column_index): + raise ColumnDoesNotExist(f"Column index {column_index!r} is not valid.") + + column_key = self._column_locations.get_key(column_index) + yield from self.get_column(column_key) + + def get_column_index(self, column_key: ColumnKey | str) -> int: + """Return the current index for the column identified by column_key. + + Args: + column_key: The column key to find the current index of. + + Returns: + The current index of the specified column key. + + Raises: + ColumnDoesNotExist: If the column key does not exist. + """ + if column_key not in self._column_locations: + raise ColumnDoesNotExist(f"No column exists for column_key={column_key!r}") + return self._column_locations.get(column_key) + + def _clear_caches(self) -> None: + self._row_render_cache.clear() + self._cell_render_cache.clear() + self._row_renderable_cache.clear() + self._line_cache.clear() + self._styles_cache.clear() + self._offset_cache.clear() + self._ordered_row_cache.clear() + self._get_styles_to_render_cell.cache_clear() + + def get_row_height(self, row_key: RowKey) -> int: + """Given a row key, return the height of that row in terminal cells. + + Args: + row_key: The key of the row. + + Returns: + The height of the row, measured in terminal character cells. + """ + if row_key is self._header_row_key: + return self.header_height + return self.rows[row_key].height + + def notify_style_update(self) -> None: + super().notify_style_update() + self._row_render_cache.clear() + self._cell_render_cache.clear() + self._row_renderable_cache.clear() + self._line_cache.clear() + self._styles_cache.clear() + self._get_styles_to_render_cell.cache_clear() + self.refresh() + + def _on_resize(self, _: events.Resize) -> None: + self._update_count += 1 + + def watch_show_cursor(self, show_cursor: bool) -> None: + self._clear_caches() + if show_cursor and self.cursor_type != "none": + # When we re-enable the cursor, apply highlighting and + # post the appropriate [Row|Column|Cell]Highlighted event. + self._scroll_cursor_into_view(animate=False) + if self.cursor_type == "cell": + self._highlight_coordinate(self.cursor_coordinate) + elif self.cursor_type == "row": + self._highlight_row(self.cursor_row) + elif self.cursor_type == "column": + self._highlight_column(self.cursor_column) + + def watch_show_header(self, show: bool) -> None: + width, height = self.virtual_size + height_change = self.header_height if show else -self.header_height + self.virtual_size = Size(width, height + height_change) + self._scroll_cursor_into_view() + self._clear_caches() + + def watch_show_row_labels(self, show: bool) -> None: + width, height = self.virtual_size + column_width = self._label_column.get_render_width(self) + width_change = column_width if show else -column_width + self.virtual_size = Size(width + width_change, height) + self._scroll_cursor_into_view() + self._clear_caches() + + def watch_fixed_rows(self) -> None: + self._clear_caches() + + def watch_fixed_columns(self) -> None: + self._clear_caches() + + def watch_zebra_stripes(self) -> None: + self._clear_caches() + + def watch_header_height(self) -> None: + self._clear_caches() + + def validate_cell_padding(self, cell_padding: int) -> int: + return max(cell_padding, 0) + + def watch_cell_padding(self, old_padding: int, new_padding: int) -> None: + # A single side of a single cell will have its width changed by (new - old), + # so the total width change is double that per column, times the number of + # columns for the whole data table. + width_change = 2 * (new_padding - old_padding) * len(self.columns) + width, height = self.virtual_size + self.virtual_size = Size(width + width_change, height) + self._scroll_cursor_into_view() + self._clear_caches() + + def watch_hover_coordinate(self, old: Coordinate, value: Coordinate) -> None: + self.refresh_coordinate(old) + self.refresh_coordinate(value) + + def watch_cursor_coordinate( + self, old_coordinate: Coordinate, new_coordinate: Coordinate + ) -> None: + if old_coordinate != new_coordinate: + # Refresh the old and the new cell, and post the appropriate + # message to tell users of the newly highlighted row/cell/column. + if self.cursor_type == "cell": + self.refresh_coordinate(old_coordinate) + self._highlight_coordinate(new_coordinate) + elif self.cursor_type == "row": + self.refresh_row(old_coordinate.row) + self._highlight_row(new_coordinate.row) + elif self.cursor_type == "column": + self.refresh_column(old_coordinate.column) + self._highlight_column(new_coordinate.column) + + if self._require_update_dimensions: + self.call_after_refresh(self._scroll_cursor_into_view) + else: + self._scroll_cursor_into_view() + + def move_cursor( + self, + *, + row: int | None = None, + column: int | None = None, + animate: bool = False, + scroll: bool = True, + ) -> None: + """Move the cursor to the given position. + + Example: + ```py + datatable = app.query_one(DataTable) + datatable.move_cursor(row=4, column=6) + # datatable.cursor_coordinate == Coordinate(4, 6) + datatable.move_cursor(row=3) + # datatable.cursor_coordinate == Coordinate(3, 6) + ``` + + Args: + row: The new row to move the cursor to. + column: The new column to move the cursor to. + animate: Whether to animate the change of coordinates. + scroll: Scroll the cursor into view after moving. + """ + + cursor_row, cursor_column = self.cursor_coordinate + if row is not None: + cursor_row = row + if column is not None: + cursor_column = column + destination = Coordinate(cursor_row, cursor_column) + + # Scroll the cursor after refresh to ensure the virtual height + # (calculated in on_idle) has settled. If we tried to scroll before + # the virtual size has been set, then it might fail if we added a bunch + # of rows then tried to immediately move the cursor. + # We do this before setting `cursor_coordinate` because its watcher will also + # schedule a call to `_scroll_cursor_into_view` without optionally animating. + if scroll: + if self._require_update_dimensions: + self.call_after_refresh(self._scroll_cursor_into_view, animate=animate) + else: + self._scroll_cursor_into_view(animate=animate) + + self.cursor_coordinate = destination + + def _highlight_coordinate(self, coordinate: Coordinate) -> None: + """Apply highlighting to the cell at the coordinate, and post event.""" + self.refresh_coordinate(coordinate) + try: + cell_value = self.get_cell_at(coordinate) + except CellDoesNotExist: + # The cell may not exist e.g. when the table is cleared. + # In that case, there's nothing for us to do here. + return + else: + cell_key = self.coordinate_to_cell_key(coordinate) + self.post_message( + DataTable.CellHighlighted( + self, cell_value, coordinate=coordinate, cell_key=cell_key + ) + ) + + def coordinate_to_cell_key(self, coordinate: Coordinate) -> CellKey: + """Return the key for the cell currently occupying this coordinate. + + Args: + coordinate: The coordinate to exam the current cell key of. + + Returns: + The key of the cell currently occupying this coordinate. + + Raises: + CellDoesNotExist: If the coordinate is not valid. + """ + if not self.is_valid_coordinate(coordinate): + raise CellDoesNotExist(f"No cell exists at {coordinate!r}.") + row_index, column_index = coordinate + row_key = self._row_locations.get_key(row_index) + column_key = self._column_locations.get_key(column_index) + return CellKey(row_key, column_key) + + def _highlight_row(self, row_index: int) -> None: + """Apply highlighting to the row at the given index, and post event.""" + self.refresh_row(row_index) + is_valid_row = row_index < len(self._data) + if is_valid_row: + row_key = self._row_locations.get_key(row_index) + self.post_message(DataTable.RowHighlighted(self, row_index, row_key)) + + def _highlight_column(self, column_index: int) -> None: + """Apply highlighting to the column at the given index, and post event.""" + self.refresh_column(column_index) + if column_index < len(self.columns): + column_key = self._column_locations.get_key(column_index) + self.post_message( + DataTable.ColumnHighlighted(self, column_index, column_key) + ) + + def validate_cursor_coordinate(self, value: Coordinate) -> Coordinate: + return self._clamp_cursor_coordinate(value) + + def _clamp_cursor_coordinate(self, coordinate: Coordinate) -> Coordinate: + """Clamp a coordinate such that it falls within the boundaries of the table.""" + row, column = coordinate + row = clamp(row, 0, self.row_count - 1) + column = clamp(column, 0, len(self.columns) - 1) + return Coordinate(row, column) + + def watch_cursor_type(self, old: str, new: str) -> None: + self._set_hover_cursor(False) + if self.show_cursor: + self._highlight_cursor() + + # Refresh cells that were previously impacted by the cursor + # but may no longer be. + if old == "cell": + self.refresh_coordinate(self.cursor_coordinate) + elif old == "row": + row_index, _ = self.cursor_coordinate + self.refresh_row(row_index) + elif old == "column": + _, column_index = self.cursor_coordinate + self.refresh_column(column_index) + + self._scroll_cursor_into_view() + + def _highlight_cursor(self) -> None: + """Applies the appropriate highlighting and raises the appropriate + [Row|Column|Cell]Highlighted event for the given cursor coordinate + and cursor type.""" + row_index, column_index = self.cursor_coordinate + cursor_type = self.cursor_type + # Apply the highlighting to the newly relevant cells + if cursor_type == "cell": + self._highlight_coordinate(self.cursor_coordinate) + elif cursor_type == "row": + self._highlight_row(row_index) + elif cursor_type == "column": + self._highlight_column(column_index) + + @property + def _row_label_column_width(self) -> int: + """The render width of the column containing row labels""" + return ( + self._label_column.get_render_width(self) + if self._should_render_row_labels + else 0 + ) + + def _update_column_widths(self, updated_cells: set[CellKey]) -> None: + """Update the widths of the columns based on the newly updated cell widths.""" + for row_key, column_key in updated_cells: + column = self.columns.get(column_key) + row = self.rows.get(row_key) + if column is None or row is None: + continue + console = self.app.console + label_width = measure(console, column.label, 1) + content_width = column.content_width + cell_value = self._data[row_key][column_key] + + render_height = row.height + new_content_width = measure( + console, + default_cell_formatter( + cell_value, + wrap=row.height != 1, + height=render_height, + ), + 1, + ) + + if new_content_width < content_width: + cells_in_column = self.get_column(column_key) + cell_widths = [ + measure( + console, + default_cell_formatter( + cell, + wrap=row.height != 1, + height=render_height, + ), + 1, + ) + for cell in cells_in_column + ] + column.content_width = max([*cell_widths, label_width]) + else: + column.content_width = max(new_content_width, label_width) + + self._require_update_dimensions = True + + def _update_dimensions(self, new_rows: Iterable[RowKey]) -> None: + """Called to recalculate the virtual (scrollable) size. + + This recomputes column widths and then checks if any of the new rows need + to have their height computed. + + Args: + new_rows: The new rows that will affect the `DataTable` dimensions. + """ + console = self.app.console + auto_height_rows: list[tuple[int, Row, list[RenderableType]]] = [] + for row_key in new_rows: + row_index = self._row_locations.get(row_key) + + # The row could have been removed before on_idle was called, so we + # need to be quite defensive here and don't assume that the row exists. + if row_index is None: + continue + + row = self.rows.get(row_key) + assert row is not None + + if row.label is not None: + self._labelled_row_exists = True + + row_label, cells_in_row = self._get_row_renderables(row_index) + label_content_width = measure(console, row_label, 1) if row_label else 0 + self._label_column.content_width = max( + self._label_column.content_width, label_content_width + ) + + for column, renderable in zip(self.ordered_columns, cells_in_row): + content_width = measure(console, renderable, 1) + column.content_width = max(column.content_width, content_width) + + if row.auto_height: + auto_height_rows.append((row_index, row, cells_in_row)) + + # If there are rows that need to have their height computed, render them correctly + # so that we can cache this rendering for later. + if auto_height_rows: + self._offset_cache.clear() + render_cell = self._render_cell # This method renders & caches. + should_highlight = self._should_highlight + cursor_type = self.cursor_type + cursor_location = self.cursor_coordinate + hover_location = self.hover_coordinate + base_style = self.rich_style + fixed_style = self.get_component_styles( + "datatable--fixed" + ).rich_style + Style.from_meta({"fixed": True}) + ordered_columns = self.ordered_columns + fixed_columns = self.fixed_columns + + for row_index, row, cells_in_row in auto_height_rows: + height = 0 + row_style = self._get_row_style(row_index, base_style) + + # As we go through the cells, save their rendering, height, and + # column width. After we compute the height of the row, go over the cells + # that were rendered with the wrong height and append the missing padding. + rendered_cells: list[tuple[SegmentLines, int, int]] = [] + for column_index, column in enumerate(ordered_columns): + style = fixed_style if column_index < fixed_columns else row_style + cell_location = Coordinate(row_index, column_index) + rendered_cell = render_cell( + row_index, + column_index, + style, + column.get_render_width(self), + cursor=should_highlight( + cursor_location, cell_location, cursor_type + ), + hover=should_highlight( + hover_location, cell_location, cursor_type + ), + ) + cell_height = len(rendered_cell) + rendered_cells.append( + (rendered_cell, cell_height, column.get_render_width(self)) + ) + height = max(height, cell_height) + + row.height = height + # Do surgery on the cache for cells that were rendered with the incorrect + # height during the first pass. + for cell_renderable, cell_height, column_width in rendered_cells: + if cell_height < height: + first_line_space_style = cell_renderable[0][0].style + cell_renderable.extend( + [ + [Segment(" " * column_width, first_line_space_style)] + for _ in range(height - cell_height) + ] + ) + + self._line_cache.clear() + self._styles_cache.clear() + + data_cells_width = sum( + column.get_render_width(self) for column in self.columns.values() + ) + total_width = data_cells_width + self._row_label_column_width + header_height = self.header_height if self.show_header else 0 + self.virtual_size = Size( + total_width, + self._total_row_height + header_height, + ) + + def _get_cell_region(self, coordinate: Coordinate) -> Region: + """Get the region of the cell at the given spatial coordinate.""" + if not self.is_valid_coordinate(coordinate): + return Region(0, 0, 0, 0) + + row_index, column_index = coordinate + row_key = self._row_locations.get_key(row_index) + row = self.rows[row_key] + + # The x-coordinate of a cell is the sum of widths of the data cells to the left + # plus the width of the render width of the longest row label. + x = ( + sum( + column.get_render_width(self) + for column in self.ordered_columns[:column_index] + ) + + self._row_label_column_width + ) + column_key = self._column_locations.get_key(column_index) + width = self.columns[column_key].get_render_width(self) + height = row.height + y = sum(ordered_row.height for ordered_row in self.ordered_rows[:row_index]) + if self.show_header: + y += self.header_height + cell_region = Region(x, y, width, height) + return cell_region + + def _get_row_region(self, row_index: int) -> Region: + """Get the region of the row at the given index.""" + if not self.is_valid_row_index(row_index): + return Region(0, 0, 0, 0) + + rows = self.rows + row_key = self._row_locations.get_key(row_index) + row = rows[row_key] + row_width = ( + sum(column.get_render_width(self) for column in self.columns.values()) + + self._row_label_column_width + ) + y = sum(ordered_row.height for ordered_row in self.ordered_rows[:row_index]) + if self.show_header: + y += self.header_height + row_region = Region(0, y, max(self.size.width, row_width), row.height) + return row_region + + def _get_column_region(self, column_index: int) -> Region: + """Get the region of the column at the given index.""" + if not self.is_valid_column_index(column_index): + return Region(0, 0, 0, 0) + + columns = self.columns + x = ( + sum( + column.get_render_width(self) + for column in self.ordered_columns[:column_index] + ) + + self._row_label_column_width + ) + column_key = self._column_locations.get_key(column_index) + width = columns[column_key].get_render_width(self) + header_height = self.header_height if self.show_header else 0 + height = self._total_row_height + header_height + full_column_region = Region(x, 0, width, height) + return full_column_region + + def clear(self, columns: bool = False) -> Self: + """Clear the table. + + Args: + columns: Also clear the columns. + + Returns: + The `DataTable` instance. + """ + self._clear_caches() + self._y_offsets.clear() + self._data.clear() + self.rows.clear() + self._row_locations = TwoWayDict({}) + if columns: + self.columns.clear() + self._column_locations = TwoWayDict({}) + self._require_update_dimensions = True + self.cursor_coordinate = Coordinate(0, 0) + self.hover_coordinate = Coordinate(0, 0) + self._label_column = Column(self._label_column_key, Text(), auto_width=True) + self._labelled_row_exists = False + self.refresh() + self.scroll_x = 0 + self.scroll_y = 0 + self.scroll_target_x = 0 + self.scroll_target_y = 0 + return self + + def add_column( + self, + label: TextType, + *, + width: int | None = None, + key: str | None = None, + default: CellType | None = None, + ) -> ColumnKey: + """Add a column to the table. + + Args: + label: A str or Text object containing the label (shown top of column). + width: Width of the column in cells or None to fit content. + key: A key which uniquely identifies this column. + If None, it will be generated for you. + default: The value to insert into pre-existing rows. + + Returns: + Uniquely identifies this column. Can be used to retrieve this column + regardless of its current location in the DataTable (it could have moved + after being added due to sorting/insertion/deletion of other columns). + """ + column_key = ColumnKey(key) + if column_key in self._column_locations: + raise DuplicateKey(f"The column key {key!r} already exists.") + column_index = len(self.columns) + label = Text.from_markup(label) if isinstance(label, str) else label + content_width = measure(self.app.console, label, 1) + if width is None: + column = Column( + column_key, + label, + content_width, + content_width=content_width, + auto_width=True, + ) + else: + column = Column( + column_key, + label, + width, + content_width=content_width, + ) + + self.columns[column_key] = column + self._column_locations[column_key] = column_index + + # Update pre-existing rows to account for the new column. + for row_key in self.rows.keys(): + self._data[row_key][column_key] = default + self._updated_cells.add(CellKey(row_key, column_key)) + + self._require_update_dimensions = True + self._update_count += 1 + self.check_idle() + + return column_key + + def add_row( + self, + *cells: CellType, + height: int | None = 1, + key: str | None = None, + label: TextType | None = None, + ) -> RowKey: + """Add a row at the bottom of the DataTable. + + Args: + *cells: Positional arguments should contain cell data. + height: The height of a row (in lines). Use `None` to auto-detect the optimal + height. + key: A key which uniquely identifies this row. If None, it will be generated + for you and returned. + label: The label for the row. Will be displayed to the left if supplied. + + Returns: + Unique identifier for this row. Can be used to retrieve this row regardless + of its current location in the DataTable (it could have moved after + being added due to sorting or insertion/deletion of other rows). + """ + row_key = RowKey(key) + if row_key in self._row_locations: + raise DuplicateKey(f"The row key {row_key!r} already exists.") + + # TODO: If there are no columns: do we generate them here? + # If we don't do this, users will be required to call add_column(s) + # Before they call add_row. + + if len(cells) > len(self.ordered_columns): + raise ValueError("More values provided than there are columns.") + + row_index = self.row_count + # Map the key of this row to its current index + self._row_locations[row_key] = row_index + self._data[row_key] = { + column.key: cell + for column, cell in zip_longest(self.ordered_columns, cells) + } + + label = Text.from_markup(label, end="") if isinstance(label, str) else label + + # Rows with auto-height get a height of 0 because 1) we need an integer height + # to do some intermediate computations and 2) because 0 doesn't impact the data + # table while we don't figure out how tall this row is. + self.rows[row_key] = Row( + row_key, + height or 0, + label, + height is None, + ) + self._new_rows.add(row_key) + self._require_update_dimensions = True + self.cursor_coordinate = self.cursor_coordinate + + # If a position has opened for the cursor to appear, where it previously + # could not (e.g. when there's no data in the table), then a highlighted + # event is posted, since there's now a highlighted cell when there wasn't + # before. + cell_now_available = self.row_count == 1 and len(self.columns) > 0 + visible_cursor = self.show_cursor and self.cursor_type != "none" + if cell_now_available and visible_cursor: + self._highlight_cursor() + + self._update_count += 1 + self.check_idle() + return row_key + + def add_columns( + self, *columns: Union[TextType, tuple[TextType, str]] + ) -> list[ColumnKey]: + """Add multiple columns to the DataTable. + + Args: + *columns: Column specifications. Each can be either: + - A string or Text object (label only, auto-generated key) + - A tuple of (label, key) for manual key control + + Returns: + A list of the keys for the columns that were added. See + the `add_column` method docstring for more information on how + these keys are used. + + Examples: + ```python + # Add columns with auto-generated keys + keys = table.add_columns("Name", "Age", "City") + + # Add columns with manual keys + keys = table.add_columns( + ("Name", "name_col"), + ("Age", "age_col"), + "City" # Mixed with auto-generated key + ) + ``` + """ + column_keys = [] + for column in columns: + if isinstance(column, tuple): + label, key = column + column_key = self.add_column(label, width=None, key=key) + else: + column_key = self.add_column(column, width=None) + column_keys.append(column_key) + return column_keys + + def add_rows(self, rows: Iterable[Iterable[CellType]]) -> list[RowKey]: + """Add a number of rows at the bottom of the DataTable. + + Args: + rows: Iterable of rows. A row is an iterable of cells. + + Returns: + A list of the keys for the rows that were added. See + the `add_row` method docstring for more information on how + these keys are used. + """ + row_keys = [] + for row in rows: + row_key = self.add_row(*row) + row_keys.append(row_key) + return row_keys + + def remove_row(self, row_key: RowKey | str) -> None: + """Remove a row (identified by a key) from the DataTable. + + Args: + row_key: The key identifying the row to remove. + + Raises: + RowDoesNotExist: If the row key does not exist. + """ + if row_key not in self._row_locations: + raise RowDoesNotExist(f"Row key {row_key!r} is not valid.") + + self._require_update_dimensions = True + self.check_idle() + + index_to_delete = self._row_locations.get(row_key) + new_row_locations = TwoWayDict({}) + for row_location_key in self._row_locations: + row_index = self._row_locations.get(row_location_key) + if row_index > index_to_delete: + new_row_locations[row_location_key] = row_index - 1 + elif row_index < index_to_delete: + new_row_locations[row_location_key] = row_index + + self._row_locations = new_row_locations + + # Prevent the removed cells from triggering dimension updates + for column_key in self._data.get(row_key): + self._updated_cells.discard(CellKey(row_key, column_key)) + + del self.rows[row_key] + del self._data[row_key] + + self.cursor_coordinate = self.cursor_coordinate + self.hover_coordinate = self.hover_coordinate + + self._update_count += 1 + self.refresh(layout=True) + + def remove_column(self, column_key: ColumnKey | str) -> None: + """Remove a column (identified by a key) from the DataTable. + + Args: + column_key: The key identifying the column to remove. + + Raises: + ColumnDoesNotExist: If the column key does not exist. + """ + if column_key not in self._column_locations: + raise ColumnDoesNotExist(f"Column key {column_key!r} is not valid.") + + self._require_update_dimensions = True + self.check_idle() + + index_to_delete = self._column_locations.get(column_key) + new_column_locations = TwoWayDict({}) + for column_location_key in self._column_locations: + column_index = self._column_locations.get(column_location_key) + if column_index > index_to_delete: + new_column_locations[column_location_key] = column_index - 1 + elif column_index < index_to_delete: + new_column_locations[column_location_key] = column_index + + self._column_locations = new_column_locations + + del self.columns[column_key] + + for row_key in self._data: + self._updated_cells.discard(CellKey(row_key, column_key)) + del self._data[row_key][column_key] + + self.cursor_coordinate = self.cursor_coordinate + self.hover_coordinate = self.hover_coordinate + + self._update_count += 1 + self.refresh(layout=True) + + async def _on_idle(self, _: events.Idle) -> None: + """Runs when the message pump is empty. + + We use this for some expensive calculations like re-computing dimensions of the + whole DataTable and re-computing column widths after some cells + have been updated. This is more efficient in the case of high + frequency updates, ensuring we only do expensive computations once.""" + if self._updated_cells: + # Cell contents have already been updated at this point. + # Now we only need to worry about measuring column widths. + updated_cells = self._updated_cells.copy() + self._updated_cells.clear() + self._update_column_widths(updated_cells) + + if self._require_update_dimensions: + # Add the new rows *before* updating the column widths, since + # cells in a new row may influence the final width of a column. + # Only then can we compute optimal height of rows with "auto" height. + self._require_update_dimensions = False + new_rows = self._new_rows.copy() + self._new_rows.clear() + self._update_dimensions(new_rows) + + def refresh_coordinate(self, coordinate: Coordinate) -> Self: + """Refresh the cell at a coordinate. + + Args: + coordinate: The coordinate to refresh. + + Returns: + The `DataTable` instance. + """ + if not self.is_valid_coordinate(coordinate): + return self + region = self._get_cell_region(coordinate) + self._refresh_region(region) + return self + + def refresh_row(self, row_index: int) -> Self: + """Refresh the row at the given index. + + Args: + row_index: The index of the row to refresh. + + Returns: + The `DataTable` instance. + """ + if not self.is_valid_row_index(row_index): + return self + + region = self._get_row_region(row_index) + self._refresh_region(region) + return self + + def refresh_column(self, column_index: int) -> Self: + """Refresh the column at the given index. + + Args: + column_index: The index of the column to refresh. + + Returns: + The `DataTable` instance. + """ + if not self.is_valid_column_index(column_index): + return self + + region = self._get_column_region(column_index) + self._refresh_region(region) + return self + + def _refresh_region(self, region: Region) -> Self: + """Refresh a region of the DataTable, if it's visible within the window. + + This method will translate the region to account for scrolling. + + Returns: + The `DataTable` instance. + """ + if not self.window_region.overlaps(region): + return self + region = region.translate(-self.scroll_offset) + self.refresh(region) + return self + + def is_valid_row_index(self, row_index: int) -> bool: + """Return a boolean indicating whether the row_index is within table bounds. + + Args: + row_index: The row index to check. + + Returns: + True if the row index is within the bounds of the table. + """ + return 0 <= row_index < len(self.rows) + + def is_valid_column_index(self, column_index: int) -> bool: + """Return a boolean indicating whether the column_index is within table bounds. + + Args: + column_index: The column index to check. + + Returns: + True if the column index is within the bounds of the table. + """ + return 0 <= column_index < len(self.columns) + + def is_valid_coordinate(self, coordinate: Coordinate) -> bool: + """Return a boolean indicating whether the given coordinate is valid. + + Args: + coordinate: The coordinate to validate. + + Returns: + True if the coordinate is within the bounds of the table. + """ + row_index, column_index = coordinate + return self.is_valid_row_index(row_index) and self.is_valid_column_index( + column_index + ) + + @property + def ordered_columns(self) -> list[Column]: + """The list of Columns in the DataTable, ordered as they appear on screen.""" + column_indices = range(len(self.columns)) + column_keys = [ + self._column_locations.get_key(index) for index in column_indices + ] + ordered_columns = [self.columns[key] for key in column_keys] + return ordered_columns + + @property + def ordered_rows(self) -> list[Row]: + """The list of Rows in the DataTable, ordered as they appear on screen.""" + num_rows = self.row_count + update_count = self._update_count + cache_key = (num_rows, update_count) + if cache_key in self._ordered_row_cache: + ordered_rows = self._ordered_row_cache[cache_key] + else: + row_indices = range(num_rows) + ordered_rows = [] + for row_index in row_indices: + row_key = self._row_locations.get_key(row_index) + row = self.rows[row_key] + ordered_rows.append(row) + self._ordered_row_cache[cache_key] = ordered_rows + return ordered_rows + + @property + def _should_render_row_labels(self) -> bool: + """Whether row labels should be rendered or not.""" + return self._labelled_row_exists and self.show_row_labels + + def _get_row_renderables(self, row_index: int) -> RowRenderables: + """Get renderables for the row currently at the given row index. The renderables + returned here have already been passed through the default_cell_formatter. + + Args: + row_index: Index of the row. + + Returns: + A RowRenderables containing the optional label and the rendered cells. + """ + update_count = self._update_count + cache_key = (update_count, row_index) + if cache_key in self._row_renderable_cache: + row_renderables = self._row_renderable_cache[cache_key] + else: + row_renderables = self._compute_row_renderables(row_index) + self._row_renderable_cache[cache_key] = row_renderables + return row_renderables + + def _compute_row_renderables(self, row_index: int) -> RowRenderables: + """Actual computation for _get_row_renderables""" + ordered_columns = self.ordered_columns + if row_index == -1: + header_row: list[RenderableType] = [ + column.label for column in ordered_columns + ] + # This is the cell where header and row labels intersect + return RowRenderables(None, header_row) + + ordered_row = self.get_row_at(row_index) + row_key = self._row_locations.get_key(row_index) + if row_key is None: + return RowRenderables(None, []) + row_metadata = self.rows.get(row_key) + if row_metadata is None: + return RowRenderables(None, []) + + formatted_row_cells: list[RenderableType] = [ + ( + _EMPTY_TEXT + if datum is None + else default_cell_formatter( + datum, + wrap=row_metadata.height != 1, + height=row_metadata.height, + ) + or _EMPTY_TEXT + ) + for datum, _ in zip_longest(ordered_row, range(len(self.columns))) + ] + + label = None + if self._should_render_row_labels: + label = ( + default_cell_formatter( + row_metadata.label, + wrap=row_metadata.height != 1, + height=row_metadata.height, + ) + if row_metadata.label + else None + ) + return RowRenderables(label, formatted_row_cells) + + def _render_cell( + self, + row_index: int, + column_index: int, + base_style: Style, + width: int, + cursor: bool = False, + hover: bool = False, + ) -> SegmentLines: + """Render the given cell. + + Args: + row_index: Index of the row. + column_index: Index of the column. + base_style: Style to apply. + width: Width of the cell. + cursor: Is this cell affected by cursor highlighting? + hover: Is this cell affected by hover cursor highlighting? + + Returns: + A list of segments per line. + """ + is_header_cell = row_index == -1 + is_row_label_cell = column_index == -1 + + is_fixed_style_cell = ( + not is_header_cell + and not is_row_label_cell + and (row_index < self.fixed_rows or column_index < self.fixed_columns) + ) + + if is_header_cell: + row_key = self._header_row_key + else: + row_key = self._row_locations.get_key(row_index) + + column_key = self._column_locations.get_key(column_index) + cell_cache_key: CellCacheKey = ( + row_key, + column_key, + base_style, + cursor, + hover, + self._show_hover_cursor, + self._update_count, + self._pseudo_class_state, + ) + + if cell_cache_key not in self._cell_render_cache: + base_style += Style.from_meta({"row": row_index, "column": column_index}) + row_label, row_cells = self._get_row_renderables(row_index) + + if is_row_label_cell: + cell = row_label if row_label is not None else "" + else: + cell = row_cells[column_index] + + component_style, post_style = self._get_styles_to_render_cell( + is_header_cell, + is_row_label_cell, + is_fixed_style_cell, + hover, + cursor, + self.show_cursor, + self._show_hover_cursor, + self.cursor_foreground_priority == "css", + self.cursor_background_priority == "css", + ) + + if is_header_cell: + row_height = self.header_height + options = self.app.console_options.update_dimensions(width, row_height) + else: + # If an auto-height row hasn't had its height calculated, we don't fix + # the value for `height` so that we can measure the height of the cell. + row = self.rows[row_key] + if row.auto_height and row.height == 0: + row_height = 0 + options = self.app.console_options.update_width(width) + else: + row_height = row.height + options = self.app.console_options.update_dimensions( + width, row_height + ) + + # If the row height is explicitly set to 1, then we don't wrap. + if row_height == 1: + options = options.update(no_wrap=True) + + lines = self.app.console.render_lines( + Styled( + Padding(cell, (0, self.cell_padding)), + pre_style=base_style + component_style, + post_style=post_style, + ), + options, + ) + + self._cell_render_cache[cell_cache_key] = lines + + return self._cell_render_cache[cell_cache_key] + + @functools.lru_cache(maxsize=32) + def _get_styles_to_render_cell( + self, + is_header_cell: bool, + is_row_label_cell: bool, + is_fixed_style_cell: bool, + hover: bool, + cursor: bool, + show_cursor: bool, + show_hover_cursor: bool, + has_css_foreground_priority: bool, + has_css_background_priority: bool, + ) -> tuple[Style, Style]: + """Auxiliary method to compute styles used to render a given cell. + + Args: + is_header_cell: Is this a cell from a header? + is_row_label_cell: Is this the label of any given row? + is_fixed_style_cell: Should this cell be styled like a fixed cell? + hover: Does this cell have the hover pseudo class? + cursor: Is this cell covered by the cursor? + show_cursor: Do we want to show the cursor in the data table? + show_hover_cursor: Do we want to show the mouse hover when using the keyboard + to move the cursor? + has_css_foreground_priority: `self.cursor_foreground_priority == "css"`? + has_css_background_priority: `self.cursor_background_priority == "css"`? + """ + get_component = self.get_component_rich_style + component_style = Style() + + if hover and show_cursor and show_hover_cursor: + component_style += get_component("datatable--hover") + if is_header_cell or is_row_label_cell: + # Apply subtle variation in style for the header/label (blue background by + # default) rows and columns affected by the cursor, to ensure we can + # still differentiate between the labels and the data. + component_style += get_component("datatable--header-hover") + + if cursor and show_cursor: + cursor_style = get_component("datatable--cursor") + component_style += cursor_style + if is_header_cell or is_row_label_cell: + component_style += get_component("datatable--header-cursor") + elif is_fixed_style_cell: + component_style += get_component("datatable--fixed-cursor") + + post_foreground = ( + Style.from_color(color=component_style.color) + if has_css_foreground_priority + else Style.null() + ) + post_background = ( + Style.from_color(bgcolor=component_style.bgcolor) + if has_css_background_priority + else Style.null() + ) + + return component_style, post_foreground + post_background + + def _render_line_in_row( + self, + row_key: RowKey, + line_no: int, + base_style: Style, + cursor_location: Coordinate, + hover_location: Coordinate, + ) -> tuple[SegmentLines, SegmentLines]: + """Render a single line from a row in the DataTable. + + Args: + row_key: The identifying key for this row. + line_no: Line number (y-coordinate) within row. 0 is the first strip of + cells in the row, line_no=1 is the next line in the row, and so on... + base_style: Base style of row. + cursor_location: The location of the cursor in the DataTable. + hover_location: The location of the hover cursor in the DataTable. + + Returns: + Lines for fixed cells, and Lines for scrollable cells. + """ + cursor_type = self.cursor_type + show_cursor = self.show_cursor + + cache_key = ( + row_key, + line_no, + base_style, + cursor_location, + hover_location, + cursor_type, + show_cursor, + self._show_hover_cursor, + self._update_count, + self._pseudo_class_state, + ) + + if cache_key in self._row_render_cache: + return self._row_render_cache[cache_key] + + should_highlight = self._should_highlight + render_cell = self._render_cell + header_style = self.get_component_styles("datatable--header").rich_style + + if row_key in self._row_locations: + row_index = self._row_locations.get(row_key) + else: + row_index = -1 + + # If the row has a label, add it to fixed_row here with correct style. + fixed_row = [] + + if self._labelled_row_exists and self.show_row_labels: + # The width of the row label is updated again on idle + cell_location = Coordinate(row_index, -1) + label_cell_lines = render_cell( + row_index, + -1, + header_style, + width=self._row_label_column_width, + cursor=should_highlight(cursor_location, cell_location, cursor_type), + hover=should_highlight(hover_location, cell_location, cursor_type), + )[line_no] + fixed_row.append(label_cell_lines) + + if self.fixed_columns: + if row_key is self._header_row_key: + fixed_style = header_style # We use the header style either way. + else: + fixed_style = self.get_component_styles("datatable--fixed").rich_style + fixed_style += Style.from_meta({"fixed": True}) + for column_index, column in enumerate( + self.ordered_columns[: self.fixed_columns] + ): + cell_location = Coordinate(row_index, column_index) + fixed_cell_lines = render_cell( + row_index, + column_index, + fixed_style, + column.get_render_width(self), + cursor=should_highlight( + cursor_location, cell_location, cursor_type + ), + hover=should_highlight(hover_location, cell_location, cursor_type), + )[line_no] + fixed_row.append(fixed_cell_lines) + + row_style = self._get_row_style(row_index, base_style) + + scrollable_row = [] + for column_index, column in enumerate(self.ordered_columns): + cell_location = Coordinate(row_index, column_index) + cell_lines = render_cell( + row_index, + column_index, + row_style, + column.get_render_width(self), + cursor=should_highlight(cursor_location, cell_location, cursor_type), + hover=should_highlight(hover_location, cell_location, cursor_type), + )[line_no] + scrollable_row.append(cell_lines) + + # Extending the styling out horizontally to fill the container + widget_width = self.size.width + table_width = ( + sum( + column.get_render_width(self) + for column in self.ordered_columns[self.fixed_columns :] + ) + + self._row_label_column_width + ) + remaining_space = max(0, widget_width - table_width) + background_color = self.background_colors[1] + if self.cursor_type == "row": + extend_style, _ = self._get_styles_to_render_cell( + row_index == -1, + False, + False, + should_highlight( + hover_location, Coordinate(row_index or 0, 0), cursor_type + ), + row_index == cursor_location.row, + self.show_cursor, + self._show_hover_cursor, + False, + False, + ) + extend_style = row_style + extend_style + else: + if row_style.bgcolor is not None: + # TODO: This should really be in a component class + faded_color = Color.from_rich_color(row_style.bgcolor).blend( + background_color, factor=0.25 + ) + extend_style = Style.from_color( + color=row_style.color, bgcolor=faded_color.rich_color + ) + else: + extend_style = Style.from_color(row_style.color, row_style.bgcolor) + extend_style += Style.from_meta( + {"row": row_index, "column": 0, "out_of_bounds": True} + ) + scrollable_row.append([Segment(" " * remaining_space, extend_style)]) + + row_pair = (fixed_row, scrollable_row) + self._row_render_cache[cache_key] = row_pair + return row_pair + + def _get_offsets(self, y: int) -> tuple[RowKey, int]: + """Get row key and line offset for a given line. + + Args: + y: Y coordinate relative to DataTable top. + + Returns: + Row key and line (y) offset within cell. + """ + header_height = self.header_height + y_offsets = self._y_offsets + if self.show_header: + if y < header_height: + return self._header_row_key, y + y -= header_height + if y > len(y_offsets): + raise LookupError(f"Y coord {y!r} is greater than total height") + + return y_offsets[y] + + def _render_line(self, y: int, x1: int, x2: int, base_style: Style) -> Strip: + """Render a (possibly cropped) line into a Strip (a list of segments + representing a horizontal line). + + Args: + y: Y coordinate of line + x1: X start crop. + x2: X end crop (exclusive). + base_style: Style to apply to line. + + Returns: + The Strip which represents this cropped line. + """ + + width = self.size.width + + try: + row_key, y_offset_in_row = self._get_offsets(y) + except LookupError: + return Strip.blank(width, base_style) + + cache_key = ( + y, + x1, + x2, + width, + self.cursor_coordinate, + self.hover_coordinate, + base_style, + self.cursor_type, + self._show_hover_cursor, + self._update_count, + self._pseudo_class_state, + ) + if cache_key in self._line_cache: + return self._line_cache[cache_key] + + fixed, scrollable = self._render_line_in_row( + row_key, + y_offset_in_row, + base_style, + cursor_location=self.cursor_coordinate, + hover_location=self.hover_coordinate, + ) + fixed_width = sum( + column.get_render_width(self) + for column in self.ordered_columns[: self.fixed_columns] + ) + + fixed_line: list[Segment] = list(chain.from_iterable(fixed)) if fixed else [] + scrollable_line: list[Segment] = list(chain.from_iterable(scrollable)) + + segments = fixed_line + line_crop(scrollable_line, x1 + fixed_width, x2, width) + strip = Strip(segments).adjust_cell_length(width, base_style).simplify() + + self._line_cache[cache_key] = strip + return strip + + def render_lines(self, crop: Region) -> list[Strip]: + self._pseudo_class_state = self.get_pseudo_class_state() + return super().render_lines(crop) + + def render_line(self, y: int) -> Strip: + width, height = self.size + scroll_x, scroll_y = self.scroll_offset + + fixed_row_keys: list[RowKey] = [ + self._row_locations.get_key(row_index) + for row_index in range(self.fixed_rows) + ] + + fixed_rows_height = sum( + self.get_row_height(row_key) for row_key in fixed_row_keys + ) + if self.show_header: + fixed_rows_height += self.get_row_height(self._header_row_key) + + if y >= fixed_rows_height: + y += scroll_y + + return self._render_line(y, scroll_x, scroll_x + width, self.rich_style) + + def _should_highlight( + self, + cursor: Coordinate, + target_cell: Coordinate, + type_of_cursor: CursorType, + ) -> bool: + """Determine if the given cell should be highlighted because of the cursor. + + This auxiliary method takes the cursor position and type into account when + determining whether the cell should be highlighted. + + Args: + cursor: The current position of the cursor. + target_cell: The cell we're checking for the need to highlight. + type_of_cursor: The type of cursor that is currently active. + + Returns: + Whether or not the given cell should be highlighted. + """ + if type_of_cursor == "cell": + return cursor == target_cell + elif type_of_cursor == "row": + cursor_row, _ = cursor + cell_row, _ = target_cell + return cursor_row == cell_row + elif type_of_cursor == "column": + _, cursor_column = cursor + _, cell_column = target_cell + return cursor_column == cell_column + else: + return False + + def _get_row_style(self, row_index: int, base_style: Style) -> Style: + """Gets the Style that should be applied to the row at the given index. + + Args: + row_index: The index of the row to style. + base_style: The base style to use by default. + + Returns: + The appropriate style. + """ + + if row_index == -1: + row_style = self.get_component_styles("datatable--header").rich_style + elif row_index < self.fixed_rows: + row_style = self.get_component_styles("datatable--fixed").rich_style + else: + if self.zebra_stripes: + component_row_style = ( + "datatable--odd-row" if row_index % 2 else "datatable--even-row" + ) + row_style = self.get_component_styles(component_row_style).rich_style + else: + row_style = base_style + return row_style + + def _on_mouse_move(self, event: events.MouseMove): + """If the hover cursor is visible, display it by extracting the row + and column metadata from the segments present in the cells.""" + self._set_hover_cursor(True) + meta = event.style.meta + if not meta: + self._set_hover_cursor(False) + return + + if self.cursor_type != "row" and meta.get("out_of_bounds", False): + self._set_hover_cursor(False) + return + + if self.show_cursor and self.cursor_type != "none": + try: + self.hover_coordinate = Coordinate(meta["row"], meta["column"]) + except KeyError: + pass + + def _on_leave(self, _: events.Leave) -> None: + self._set_hover_cursor(False) + + def _get_fixed_offset(self) -> Spacing: + """Calculate the "fixed offset", that is the space to the top and left + that is occupied by fixed rows and columns respectively. Fixed rows and columns + are rows and columns that do not participate in scrolling.""" + top = self.header_height if self.show_header else 0 + top += sum(row.height for row in self.ordered_rows[: self.fixed_rows]) + left = ( + sum( + column.get_render_width(self) + for column in self.ordered_columns[: self.fixed_columns] + ) + + self._row_label_column_width + ) + return Spacing(top, 0, 0, left) + + def sort( + self, + *columns: ColumnKey | str, + key: Callable[[Any], Any] | None = None, + reverse: bool = False, + ) -> Self: + """Sort the rows in the `DataTable` by one or more column keys or a + key function (or other callable). If both columns and a key function + are specified, only data from those columns will sent to the key function. + + Args: + columns: One or more columns to sort by the values in. + key: A function (or other callable) that returns a key to + use for sorting purposes. + reverse: If True, the sort order will be reversed. + + Returns: + The `DataTable` instance. + """ + + def key_wrapper(row: tuple[RowKey, dict[ColumnKey | str, CellType]]) -> Any: + _, row_data = row + if columns: + result = itemgetter(*columns)(row_data) + else: + result = tuple(row_data.values()) + if key is not None: + return key(result) + return result + + ordered_rows = sorted( + self._data.items(), + key=key_wrapper, + reverse=reverse, + ) + self._row_locations = TwoWayDict( + {row_key: new_index for new_index, (row_key, _) in enumerate(ordered_rows)} + ) + self._update_count += 1 + self.refresh() + return self + + def _scroll_cursor_into_view(self, animate: bool = False) -> None: + """When the cursor is at a boundary of the DataTable and moves out + of view, this method handles scrolling to ensure it remains visible.""" + fixed_offset = self._get_fixed_offset() + top, _, _, left = fixed_offset + + if self.cursor_type == "row": + x, y, width, height = self._get_row_region(self.cursor_row) + region = Region(int(self.scroll_x) + left, y, width - left, height) + elif self.cursor_type == "column": + x, y, width, height = self._get_column_region(self.cursor_column) + region = Region(x, int(self.scroll_y) + top, width, height - top) + else: + region = self._get_cell_region(self.cursor_coordinate) + + self.scroll_to_region(region, animate=animate, spacing=fixed_offset, force=True) + + def _set_hover_cursor(self, active: bool) -> None: + """Set whether the hover cursor (the faint cursor you see when you + hover the mouse cursor over a cell) is visible or not. Typically, + when you interact with the keyboard, you want to switch the hover + cursor off. + + Args: + active: Display the hover cursor. + """ + self._show_hover_cursor = active + cursor_type = self.cursor_type + if cursor_type == "column": + self.refresh_column(self.hover_column) + elif cursor_type == "row": + self.refresh_row(self.hover_row) + elif cursor_type == "cell": + self.refresh_coordinate(self.hover_coordinate) + + async def _on_click(self, event: events.Click) -> None: + self._set_hover_cursor(True) + meta = event.style.meta + if "row" not in meta or "column" not in meta: + return + if self.cursor_type != "row" and meta.get("out_of_bounds", False): + return + + row_index = meta["row"] + column_index = meta["column"] + is_header_click = self.show_header and row_index == -1 + is_row_label_click = self.show_row_labels and column_index == -1 + if is_header_click: + # Header clicks work even if cursor is off, and doesn't move the cursor. + column = self.ordered_columns[column_index] + message = DataTable.HeaderSelected( + self, column.key, column_index, label=column.label + ) + self.post_message(message) + elif is_row_label_click: + row = self.ordered_rows[row_index] + message = DataTable.RowLabelSelected( + self, row.key, row_index, label=row.label + ) + self.post_message(message) + elif self.show_cursor and self.cursor_type != "none": + # Only post selection events if there is a visible row/col/cell cursor. + new_coordinate = Coordinate(row_index, column_index) + highlight_click = new_coordinate == self.cursor_coordinate + self.cursor_coordinate = new_coordinate + if highlight_click: + self._post_selected_message() + self._scroll_cursor_into_view(animate=True) + event.stop() + + def action_page_down(self) -> None: + """Move the cursor one page down.""" + self._set_hover_cursor(False) + if self.show_cursor and self.cursor_type in ("cell", "row"): + height = self.scrollable_content_region.height - ( + self.header_height if self.show_header else 0 + ) + + # Determine how many rows constitutes a "page" + offset = 0 + rows_to_scroll = 0 + row_index, _ = self.cursor_coordinate + for ordered_row in self.ordered_rows[row_index:]: + offset += ordered_row.height + rows_to_scroll += 1 + if offset > height: + break + + target_row = row_index + rows_to_scroll - 1 + self.scroll_relative(y=height, animate=False, force=True) + self.move_cursor(row=target_row, scroll=False) + else: + super().action_page_down() + + def action_page_up(self) -> None: + """Move the cursor one page up.""" + self._set_hover_cursor(False) + if self.show_cursor and self.cursor_type in ("cell", "row"): + height = self.scrollable_content_region.height - ( + self.header_height if self.show_header else 0 + ) + + # Determine how many rows constitutes a "page" + offset = 0 + rows_to_scroll = 0 + row_index, _ = self.cursor_coordinate + for ordered_row in self.ordered_rows[: row_index + 1]: + offset += ordered_row.height + rows_to_scroll += 1 + if offset > height: + break + + target_row = row_index - rows_to_scroll + 1 + self.scroll_relative(y=-height, animate=False) + self.move_cursor(row=target_row, scroll=False) + else: + super().action_page_up() + + def action_page_left(self) -> None: + """Move the cursor one page left.""" + self._set_hover_cursor(False) + super().scroll_page_left() + + def action_page_right(self) -> None: + """Move the cursor one page right.""" + self._set_hover_cursor(False) + super().scroll_page_right() + + def action_scroll_top(self) -> None: + """Move the cursor and scroll to the top.""" + self._set_hover_cursor(False) + cursor_type = self.cursor_type + if self.show_cursor and (cursor_type == "cell" or cursor_type == "row"): + _, column_index = self.cursor_coordinate + self.cursor_coordinate = Coordinate(0, column_index) + else: + super().action_scroll_home() + + def action_scroll_bottom(self) -> None: + """Move the cursor and scroll to the bottom.""" + self._set_hover_cursor(False) + cursor_type = self.cursor_type + if self.show_cursor and (cursor_type == "cell" or cursor_type == "row"): + _, column_index = self.cursor_coordinate + self.cursor_coordinate = Coordinate(self.row_count - 1, column_index) + else: + super().action_scroll_end() + + def action_scroll_home(self) -> None: + """Move the cursor and scroll to the leftmost column.""" + self._set_hover_cursor(False) + cursor_type = self.cursor_type + if self.show_cursor and (cursor_type == "cell" or cursor_type == "column"): + self.move_cursor(column=0) + else: + self.scroll_x = 0 + + def action_scroll_end(self) -> None: + """Move the cursor and scroll to the rightmost column.""" + self._set_hover_cursor(False) + cursor_type = self.cursor_type + if self.show_cursor and (cursor_type == "cell" or cursor_type == "column"): + self.move_cursor(column=len(self.columns) - 1) + else: + self.scroll_x = self.max_scroll_x + + def action_cursor_up(self) -> None: + self._set_hover_cursor(False) + cursor_type = self.cursor_type + if self.show_cursor and (cursor_type == "cell" or cursor_type == "row"): + self.cursor_coordinate = self.cursor_coordinate.up() + else: + # If the cursor doesn't move up (e.g. column cursor can't go up), + # then ensure that we instead scroll the DataTable. + super().action_scroll_up() + + def action_cursor_down(self) -> None: + self._set_hover_cursor(False) + cursor_type = self.cursor_type + if self.show_cursor and (cursor_type == "cell" or cursor_type == "row"): + self.cursor_coordinate = self.cursor_coordinate.down() + else: + super().action_scroll_down() + + def action_cursor_right(self) -> None: + self._set_hover_cursor(False) + cursor_type = self.cursor_type + if self.show_cursor and (cursor_type == "cell" or cursor_type == "column"): + self.cursor_coordinate = self.cursor_coordinate.right() + self._scroll_cursor_into_view(animate=True) + else: + super().action_scroll_right() + + def action_cursor_left(self) -> None: + self._set_hover_cursor(False) + cursor_type = self.cursor_type + if self.show_cursor and (cursor_type == "cell" or cursor_type == "column"): + self.cursor_coordinate = self.cursor_coordinate.left() + self._scroll_cursor_into_view(animate=True) + else: + super().action_scroll_left() + + def action_select_cursor(self) -> None: + self._set_hover_cursor(False) + if self.show_cursor and self.cursor_type != "none": + self._post_selected_message() + + def _post_selected_message(self): + """Post the appropriate message for a selection based on the `cursor_type`.""" + cursor_coordinate = self.cursor_coordinate + cursor_type = self.cursor_type + if len(self._data) == 0: + return + cell_key = self.coordinate_to_cell_key(cursor_coordinate) + if cursor_type == "cell": + self.post_message( + DataTable.CellSelected( + self, + self.get_cell_at(cursor_coordinate), + coordinate=cursor_coordinate, + cell_key=cell_key, + ) + ) + elif cursor_type == "row": + row_index, _ = cursor_coordinate + row_key, _ = cell_key + self.post_message(DataTable.RowSelected(self, row_index, row_key)) + elif cursor_type == "column": + _, column_index = cursor_coordinate + _, column_key = cell_key + self.post_message(DataTable.ColumnSelected(self, column_index, column_key)) diff --git a/src/memray/_vendor/textual/widgets/_digits.py b/src/memray/_vendor/textual/widgets/_digits.py new file mode 100644 index 0000000000..5b87a5b4c8 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_digits.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + +from rich.align import Align, AlignMethod + +if TYPE_CHECKING: + from memray._vendor.textual.app import RenderResult + +from memray._vendor.textual.geometry import Size +from memray._vendor.textual.renderables.digits import Digits as DigitsRenderable +from memray._vendor.textual.selection import Selection +from memray._vendor.textual.widget import Widget + + +class Digits(Widget): + """A widget to display numerical values using a 3x3 grid of unicode characters.""" + + DEFAULT_CSS = """ + Digits { + width: 1fr; + height: auto; + text-align: left; + box-sizing: border-box; + } + """ + + def __init__( + self, + value: str = "", + *, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + ) -> None: + """Initialize a Digits widget. + + Args: + value: Value to display in widget. + name: The name of the widget. + id: The ID of the widget in the DOM. + classes: The CSS classes of the widget. + disabled: Whether the widget is disabled or not. + + """ + if not isinstance(value, str): + raise TypeError("value must be a str") + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + self._value = value + + @property + def value(self) -> str: + """The current value displayed in the Digits.""" + return self._value + + def get_selection(self, selection: Selection) -> str | None: + return self._value + + def update(self, value: str) -> None: + """Update the Digits with a new value. + + Args: + value: New value to display. + + Raises: + TypeError: If the value isn't a `str`. + """ + if not isinstance(value, str): + raise TypeError("value must be a str") + layout_required = len(value) != len(self._value) or ( + DigitsRenderable.get_width(self._value) != DigitsRenderable.get_width(value) + ) + self._value = value + self.refresh(layout=layout_required) + + def render(self) -> RenderResult: + """Render digits.""" + rich_style = self.rich_style + if self.text_selection: + rich_style += self.selection_style + digits = DigitsRenderable(self._value, rich_style) + text_align = self.styles.text_align + align = "left" if text_align not in {"left", "center", "right"} else text_align + return Align(digits, cast(AlignMethod, align), rich_style) + + def get_content_width(self, container: Size, viewport: Size) -> int: + """Called by textual to get the width of the content area. + + Args: + container: Size of the container (immediate parent) widget. + viewport: Size of the viewport. + + Returns: + The optimal width of the content. + """ + return DigitsRenderable.get_width(self._value) + + def get_content_height(self, container: Size, viewport: Size, width: int) -> int: + """Called by Textual to get the height of the content area. + + Args: + container: Size of the container (immediate parent) widget. + viewport: Size of the viewport. + width: Width of renderable. + + Returns: + The height of the content. + """ + return 3 # Always 3 lines diff --git a/src/memray/_vendor/textual/widgets/_directory_tree.py b/src/memray/_vendor/textual/widgets/_directory_tree.py new file mode 100644 index 0000000000..68a25e0e76 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_directory_tree.py @@ -0,0 +1,585 @@ +from __future__ import annotations + +import asyncio +from asyncio import Queue +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Callable, ClassVar, Iterable, Iterator + +from rich.style import Style +from rich.text import Text, TextType + +from memray._vendor.textual import work +from memray._vendor.textual.await_complete import AwaitComplete +from memray._vendor.textual.message import Message +from memray._vendor.textual.reactive import var +from memray._vendor.textual.widgets._tree import TOGGLE_STYLE, Tree, TreeNode +from memray._vendor.textual.worker import Worker, WorkerCancelled, WorkerFailed, get_current_worker + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass +class DirEntry: + """Attaches directory information to a [`DirectoryTree`][textual.widgets.DirectoryTree] node.""" + + path: Path + """The path of the directory entry.""" + loaded: bool = False + """Has this been loaded?""" + + +class DirectoryTree(Tree[DirEntry]): + """A Tree widget that presents files and directories.""" + + ICON_NODE_EXPANDED = "📂 " + ICON_NODE = "📁 " + ICON_FILE = "📄 " + """Unicode 'icon' to represent a file.""" + + COMPONENT_CLASSES: ClassVar[set[str]] = { + "directory-tree--extension", + "directory-tree--file", + "directory-tree--folder", + "directory-tree--hidden", + } + """ + | Class | Description | + | :- | :- | + | `directory-tree--extension` | Target the extension of a file name. | + | `directory-tree--file` | Target files in the directory structure. | + | `directory-tree--folder` | Target folders in the directory structure. | + | `directory-tree--hidden` | Target hidden items in the directory structure. | + + See also the [component classes for `Tree`][textual.widgets.Tree.COMPONENT_CLASSES]. + """ + + DEFAULT_CSS = """ + DirectoryTree { + + & > .directory-tree--folder { + text-style: bold; + } + + & > .directory-tree--extension { + text-style: italic; + } + + & > .directory-tree--hidden { + text-style: dim; + } + + &:ansi { + + & > .tree--guides { + color: transparent; + } + + & > .directory-tree--folder { + text-style: bold; + } + + & > .directory-tree--extension { + text-style: italic; + } + + & > .directory-tree--hidden { + color: ansi_default; + text-style: dim; + } + } + + } + + """ + + PATH: Callable[[str | Path], Path] = Path + """Callable that returns a fresh path object.""" + + class FileSelected(Message): + """Posted when a file is selected. + + Can be handled using `on_directory_tree_file_selected` in a subclass of + `DirectoryTree` or in a parent widget in the DOM. + """ + + def __init__(self, node: TreeNode[DirEntry], path: Path) -> None: + """Initialise the FileSelected object. + + Args: + node: The tree node for the file that was selected. + path: The path of the file that was selected. + """ + super().__init__() + self.node: TreeNode[DirEntry] = node + """The tree node of the file that was selected.""" + self.path: Path = path + """The path of the file that was selected.""" + + @property + def control(self) -> Tree[DirEntry]: + """The `Tree` that had a file selected.""" + return self.node.tree + + class DirectorySelected(Message): + """Posted when a directory is selected. + + Can be handled using `on_directory_tree_directory_selected` in a + subclass of `DirectoryTree` or in a parent widget in the DOM. + """ + + def __init__(self, node: TreeNode[DirEntry], path: Path) -> None: + """Initialise the DirectorySelected object. + + Args: + node: The tree node for the directory that was selected. + path: The path of the directory that was selected. + """ + super().__init__() + self.node: TreeNode[DirEntry] = node + """The tree node of the directory that was selected.""" + self.path: Path = path + """The path of the directory that was selected.""" + + @property + def control(self) -> Tree[DirEntry]: + """The `Tree` that had a directory selected.""" + return self.node.tree + + path: var[str | Path] = var["str | Path"](PATH("."), init=False, always_update=True) + """The path that is the root of the directory tree. + + Note: + This can be set to either a `str` or a `pathlib.Path` object, but + the value will always be a `pathlib.Path` object. + """ + + def __init__( + self, + path: str | Path, + *, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + ) -> None: + """Initialise the directory tree. + + Args: + path: Path to directory. + name: The name of the widget, or None for no name. + id: The ID of the widget in the DOM, or None for no ID. + classes: A space-separated list of classes, or None for no classes. + disabled: Whether the directory tree is disabled or not. + """ + self._load_queue: Queue[TreeNode[DirEntry]] = Queue() + super().__init__( + str(path), + data=DirEntry(self.PATH(path)), + name=name, + id=id, + classes=classes, + disabled=disabled, + ) + self.path = path + + def _add_to_load_queue(self, node: TreeNode[DirEntry]) -> AwaitComplete: + """Add the given node to the load queue. + + The return value can optionally be awaited until the queue is empty. + + Args: + node: The node to add to the load queue. + + Returns: + An optionally awaitable object that can be awaited until the + load queue has finished processing. + """ + assert node.data is not None + if not node.data.loaded: + node.data.loaded = True + self._load_queue.put_nowait(node) + + return AwaitComplete(self._load_queue.join()) + + def reload(self) -> AwaitComplete: + """Reload the `DirectoryTree` contents. + + Returns: + An optionally awaitable that ensures the tree has finished reloading. + """ + # Orphan the old queue... + self._load_queue = Queue() + # ... reset the root node ... + processed = self.reload_node(self.root) + # ...and replace the old load with a new one. + self._loader() + return processed + + def clear_node(self, node: TreeNode[DirEntry]) -> Self: + """Clear all nodes under the given node. + + Returns: + The `Tree` instance. + """ + self._clear_line_cache() + node.remove_children() + self._updates += 1 + self.refresh() + return self + + def reset_node( + self, node: TreeNode[DirEntry], label: TextType, data: DirEntry | None = None + ) -> Self: + """Clear the subtree and reset the given node. + + Args: + node: The node to reset. + label: The label for the node. + data: Optional data for the node. + + Returns: + The `Tree` instance. + """ + self.clear_node(node) + node.label = label + node.data = data + return self + + async def _reload(self, node: TreeNode[DirEntry]) -> None: + """Reloads the subtree rooted at the given node while preserving state. + + After reloading the subtree, nodes that were expanded and still exist + will remain expanded and the highlighted node will be preserved, if it + still exists. If it doesn't, highlighting goes up to the first parent + directory that still exists. + + Args: + node: The root of the subtree to reload. + """ + async with self.lock: + # Track nodes that were expanded before reloading. + currently_open: set[Path] = set() + to_check: list[TreeNode[DirEntry]] = [node] + while to_check: + checking = to_check.pop() + if checking.allow_expand and checking.is_expanded: + if checking.data: + currently_open.add(checking.data.path) + to_check.extend(checking.children) + + # Track node that was highlighted before reloading. + highlighted_path: None | Path = None + if self.cursor_line > -1: + highlighted_node = self.get_node_at_line(self.cursor_line) + if highlighted_node is not None and highlighted_node.data is not None: + highlighted_path = highlighted_node.data.path + + if node.data is not None: + self.reset_node( + node, str(node.data.path.name), DirEntry(self.PATH(node.data.path)) + ) + + # Reopen nodes that were expanded and still exist. + to_reopen = [node] + while to_reopen: + reopening = to_reopen.pop() + if not reopening.data: + continue + if reopening.allow_expand and ( + reopening.data.path in currently_open or reopening == node + ): + try: + content = await self._load_directory(reopening).wait() + except (WorkerCancelled, WorkerFailed): + continue + reopening.data.loaded = True + self._populate_node(reopening, content) + to_reopen.extend(reopening.children) + reopening.expand() + + if highlighted_path is None: + return + + # Restore the highlighted path and consider the parents as fallbacks. + looking = [node] + highlight_candidates = set(highlighted_path.parents) + highlight_candidates.add(highlighted_path) + best_found: None | TreeNode[DirEntry] = None + while looking: + checking = looking.pop() + checking_path = ( + checking.data.path if checking.data is not None else None + ) + if checking_path in highlight_candidates: + best_found = checking + if checking_path == highlighted_path: + break + if ( + checking.allow_expand + and checking.is_expanded + and checking_path in highlighted_path.parents + ): + looking.extend(checking.children) + if best_found is not None: + # We need valid lines. Make sure the tree lines have been computed: + _ = self._tree_lines + self.cursor_line = best_found.line + + def reload_node(self, node: TreeNode[DirEntry]) -> AwaitComplete: + """Reload the given node's contents. + + The return value may be awaited to ensure the DirectoryTree has reached + a stable state and is no longer performing any node reloading (of this node + or any other nodes). + + Args: + node: The root of the subtree to reload. + + Returns: + An optionally awaitable that ensures the subtree has finished reloading. + """ + return AwaitComplete(self._reload(node)) + + def validate_path(self, path: str | Path) -> Path: + """Ensure that the path is of the `Path` type. + + Args: + path: The path to validate. + + Returns: + The validated Path value. + + Note: + The result will always be a Python `Path` object, regardless of + the value given. + """ + return self.PATH(path) + + async def watch_path(self) -> None: + """Watch for changes to the `path` of the directory tree. + + If the path is changed the directory tree will be repopulated using + the new value as the root. + """ + has_cursor = self.cursor_node is not None + self.reset_node(self.root, str(self.path), DirEntry(self.PATH(self.path))) + await self.reload() + if has_cursor: + self.cursor_line = 0 + self.scroll_to(0, 0, animate=False) + + def process_label(self, label: TextType) -> Text: + """Process a str or Text into a label. May be overridden in a subclass to modify how labels are rendered. + + Args: + label: Label. + + Returns: + A Rich Text object. + """ + if isinstance(label, str): + text_label = Text(label) + else: + text_label = label + first_line = text_label.split()[0] + return first_line + + def render_label( + self, node: TreeNode[DirEntry], base_style: Style, style: Style + ) -> Text: + """Render a label for the given node. + + Args: + node: A tree node. + base_style: The base style of the widget. + style: The additional style for the label. + + Returns: + A Rich Text object containing the label. + """ + node_label = node._label.copy() + node_label.stylize(style) + + # If the tree isn't mounted yet we can't use component classes to stylize + # the label fully, so we return early. + if not self.is_mounted: + return node_label + + if node._allow_expand: + prefix = ( + self.ICON_NODE_EXPANDED if node.is_expanded else self.ICON_NODE, + base_style + TOGGLE_STYLE, + ) + node_label.stylize_before( + self.get_component_rich_style("directory-tree--folder", partial=True) + ) + else: + prefix = ( + self.ICON_FILE, + base_style, + ) + node_label.stylize_before( + self.get_component_rich_style("directory-tree--file", partial=True), + ) + node_label.highlight_regex( + r"\..+$", + self.get_component_rich_style( + "directory-tree--extension", partial=True + ), + ) + + if node_label.plain.startswith("."): + node_label.stylize_before( + self.get_component_rich_style("directory-tree--hidden", partial=True) + ) + + text = Text.assemble(prefix, node_label) + return text + + def filter_paths(self, paths: Iterable[Path]) -> Iterable[Path]: + """Filter the paths before adding them to the tree. + + Args: + paths: The paths to be filtered. + + Returns: + The filtered paths. + + By default this method returns all of the paths provided. To create + a filtered `DirectoryTree` inherit from it and implement your own + version of this method. + """ + return paths + + @staticmethod + def _safe_is_dir(path: Path) -> bool: + """Safely check if a path is a directory. + + Args: + path: The path to check. + + Returns: + `True` if the path is for a directory, `False` if not. + """ + try: + return path.is_dir() + except OSError: + # We may or may not have been looking at a directory, but we + # don't have the rights or permissions to even know that. Best + # we can do, short of letting the error blow up, is assume it's + # not a directory. A possible improvement in here could be to + # have a third state which is "unknown", and reflect that in the + # tree. + return False + + def _populate_node(self, node: TreeNode[DirEntry], content: Iterable[Path]) -> None: + """Populate the given tree node with the given directory content. + + Args: + node: The Tree node to populate. + content: The collection of `Path` objects to populate the node with. + """ + node.remove_children() + for path in content: + node.add( + path.name, + data=DirEntry(path), + allow_expand=self._safe_is_dir(path), + ) + node.expand() + + def _directory_content(self, location: Path, worker: Worker) -> Iterator[Path]: + """Load the content of a given directory. + + Args: + location: The location to load from. + worker: The worker that the loading is taking place in. + + Yields: + Path: An entry within the location. + """ + try: + for entry in location.iterdir(): + if worker.is_cancelled: + break + yield entry + except OSError: + pass + + @work(thread=True, exit_on_error=False) + def _load_directory(self, node: TreeNode[DirEntry]) -> list[Path]: + """Load the directory contents for a given node. + + Args: + node: The node to load the directory contents for. + + Returns: + The list of entries within the directory associated with the node. + """ + assert node.data is not None + path = node.data.path + path = path.expanduser().resolve() + return sorted( + self.filter_paths(self._directory_content(path, get_current_worker())), + key=lambda path: (not self._safe_is_dir(path), path.name.lower()), + ) + + @work(exclusive=True, group="_loader") + async def _loader(self) -> None: + """Background loading queue processor.""" + worker = get_current_worker() + load_queue = self._load_queue + while not worker.is_cancelled: + # Get the next node that needs loading off the queue. Note that + # this blocks if the queue is empty. + node = await load_queue.get() + content: list[Path] = [] + async with self.lock: + cursor_node = self.cursor_node + try: + # Spin up a short-lived thread that will load the content of + # the directory associated with that node. + content = await self._load_directory(node).wait() + except WorkerCancelled: + # The worker was cancelled, that would suggest we're all + # done here and we should get out of the loader in general. + break + except WorkerFailed: + # This particular worker failed to start. We don't know the + # reason so let's no-op that (for now anyway). + pass + else: + # We're still here and we have directory content, get it into + # the tree. + if content: + self._populate_node(node, content) + if cursor_node is not None: + self.move_cursor(cursor_node, animate=False) + finally: + load_queue.task_done() + + async def _on_tree_node_expanded(self, event: Tree.NodeExpanded[DirEntry]) -> None: + event.stop() + dir_entry = event.node.data + if dir_entry is None: + return + if await asyncio.to_thread(self._safe_is_dir, dir_entry.path): + if event.node.data is not None: + await self._add_to_load_queue(event.node) + else: + if event.node.data is not None: + self.post_message(self.FileSelected(event.node, dir_entry.path)) + + async def _on_tree_node_selected(self, event: Tree.NodeSelected[DirEntry]) -> None: + event.stop() + dir_entry = event.node.data + if dir_entry is None: + return + if await asyncio.to_thread(self._safe_is_dir, dir_entry.path): + if event.node.data is not None: + self.post_message(self.DirectorySelected(event.node, dir_entry.path)) + else: + if event.node.data is not None: + self.post_message(self.FileSelected(event.node, dir_entry.path)) diff --git a/src/memray/_vendor/textual/widgets/_footer.py b/src/memray/_vendor/textual/widgets/_footer.py new file mode 100644 index 0000000000..6bf69c0b82 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_footer.py @@ -0,0 +1,355 @@ +from __future__ import annotations + +from collections import defaultdict +from itertools import groupby +from typing import TYPE_CHECKING + +import rich.repr +from rich.text import Text + +from memray._vendor.textual import events +from memray._vendor.textual.app import ComposeResult +from memray._vendor.textual.binding import Binding +from memray._vendor.textual.containers import HorizontalGroup, ScrollableContainer +from memray._vendor.textual.reactive import reactive +from memray._vendor.textual.widget import Widget +from memray._vendor.textual.widgets import Label + +if TYPE_CHECKING: + from memray._vendor.textual.screen import Screen + + +@rich.repr.auto +class KeyGroup(HorizontalGroup): + DEFAULT_CSS = """ + KeyGroup { + width: auto; + } + """ + + +@rich.repr.auto +class FooterKey(Widget): + ALLOW_SELECT = False + COMPONENT_CLASSES = { + "footer-key--key", + "footer-key--description", + } + + DEFAULT_CSS = """ + FooterKey { + width: auto; + height: 1; + text-wrap: nowrap; + background: $footer-item-background; + .footer-key--key { + color: $footer-key-foreground; + background: $footer-key-background; + text-style: bold; + padding: 0 1; + } + + .footer-key--description { + padding: 0 1 0 0; + color: $footer-description-foreground; + background: $footer-description-background; + } + + &:hover { + color: $footer-key-foreground; + background: $block-hover-background; + } + + &.-disabled { + text-style: dim; + } + + &.-compact { + .footer-key--key { + padding: 0; + } + .footer-key--description { + padding: 0 0 0 1; + } + } + } + """ + + compact = reactive(True) + """Display compact style.""" + + def __init__( + self, + key: str, + key_display: str, + description: str, + action: str, + disabled: bool = False, + tooltip: str = "", + classes="", + ) -> None: + self.key = key + self.key_display = key_display + self.description = description + self.action = action + self._disabled = disabled + if disabled: + classes += " -disabled" + super().__init__(classes=classes) + self.set_reactive(Widget.shrink, False) + if tooltip: + self.tooltip = tooltip + + def render(self) -> Text: + key_style = self.get_component_rich_style("footer-key--key") + description_style = self.get_component_rich_style("footer-key--description") + key_display = self.key_display + key_padding = self.get_component_styles("footer-key--key").padding + description_padding = self.get_component_styles( + "footer-key--description" + ).padding + + description = self.description + if description: + label_text = Text.assemble( + ( + " " * key_padding.left + key_display + " " * key_padding.right, + key_style, + ), + ( + " " * description_padding.left + + description + + " " * description_padding.right, + description_style, + ), + ) + else: + label_text = Text.assemble((key_display, key_style)) + + label_text.stylize_before(self.rich_style) + return label_text + + def on_mouse_down(self) -> None: + if self._disabled: + self.app.bell() + else: + self.app.simulate_key(self.key) + + def _watch_compact(self, compact: bool) -> None: + self.set_class(compact, "-compact") + + +class FooterLabel(Label): + """Text displayed in the footer (used by binding groups).""" + + +@rich.repr.auto +class Footer(ScrollableContainer, can_focus=False, can_focus_children=False): + ALLOW_SELECT = False + DEFAULT_CSS = """ + Footer { + layout: horizontal; + color: $footer-foreground; + background: $footer-background; + dock: bottom; + height: 1; + scrollbar-size: 0 0; + &.-compact { + FooterLabel { + margin: 0; + } + FooterKey { + margin-right: 1; + } + FooterKey.-grouped { + margin: 0 1; + } + FooterKey.-command-palette { + padding-right: 0; + } + } + FooterKey.-command-palette { + dock: right; + padding-right: 1; + border-left: vkey $foreground 20%; + } + HorizontalGroup.binding-group { + width: auto; + height: 1; + layout: horizontal; + } + KeyGroup.-compact { + FooterKey.-grouped { + margin: 0; + } + margin: 0 1 0 0; + padding-left: 1; + } + + FooterKey.-grouped { + margin: 0 1; + } + FooterLabel { + margin: 0 1 0 0; + color: $footer-description-foreground; + background: $footer-description-background; + } + + &:ansi { + background: ansi_default; + .footer-key--key { + background: ansi_default; + color: ansi_magenta; + } + .footer-key--description { + background: ansi_default; + color: ansi_default; + } + FooterKey:hover { + text-style: underline; + background: ansi_default; + color: ansi_default; + .footer-key--key { + background: ansi_default; + } + } + FooterKey.-command-palette { + background: ansi_default; + border-left: vkey ansi_black; + } + } + + } + """ + + compact = reactive(False, toggle_class="-compact") + """Display in compact style.""" + _bindings_ready = reactive(False, repaint=False) + """True if the bindings are ready to be displayed.""" + show_command_palette = reactive(True) + """Show the key to invoke the command palette.""" + combine_groups = reactive(True) + """Combine bindings in the same group?""" + + def __init__( + self, + *children: Widget, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + show_command_palette: bool = True, + compact: bool = False, + ) -> None: + """A footer to show key bindings. + + Args: + *children: Child widgets. + name: The name of the widget. + id: The ID of the widget in the DOM. + classes: The CSS classes for the widget. + disabled: Whether the widget is disabled or not. + show_command_palette: Show key binding to invoke the command palette, on the right of the footer. + compact: Display a compact style (less whitespace) footer. + """ + super().__init__( + *children, + name=name, + id=id, + classes=classes, + disabled=disabled, + ) + self.set_reactive(Footer.show_command_palette, show_command_palette) + self.set_reactive(Footer.compact, compact) + self.set_class(compact, "-compact", update=False) + + def compose(self) -> ComposeResult: + if not self._bindings_ready: + return + active_bindings = self.screen.active_bindings + bindings = [ + (binding, enabled, tooltip) + for (_, binding, enabled, tooltip) in active_bindings.values() + if binding.show + ] + action_to_bindings: defaultdict[str, list[tuple[Binding, bool, str]]] + action_to_bindings = defaultdict(list) + for binding, enabled, tooltip in bindings: + action_to_bindings[binding.action].append((binding, enabled, tooltip)) + + self.styles.grid_size_columns = len(action_to_bindings) + + for group, multi_bindings_iterable in groupby( + action_to_bindings.values(), + lambda multi_bindings_: multi_bindings_[0][0].group, + ): + multi_bindings = list(multi_bindings_iterable) + if group is not None and len(multi_bindings) > 1: + with KeyGroup(classes="-compact" if group.compact else ""): + for multi_bindings in multi_bindings: + binding, enabled, tooltip = multi_bindings[0] + yield FooterKey( + binding.key, + self.app.get_key_display(binding), + "", + binding.action, + disabled=not enabled, + tooltip=tooltip or binding.description, + classes="-grouped", + ).data_bind(compact=Footer.compact) + yield FooterLabel(group.description) + else: + for multi_bindings in multi_bindings: + binding, enabled, tooltip = multi_bindings[0] + yield FooterKey( + binding.key, + self.app.get_key_display(binding), + binding.description, + binding.action, + disabled=not enabled, + tooltip=tooltip, + ).data_bind(compact=Footer.compact) + if self.show_command_palette and self.app.ENABLE_COMMAND_PALETTE: + try: + _node, binding, enabled, tooltip = active_bindings[ + self.app.COMMAND_PALETTE_BINDING + ] + except KeyError: + pass + else: + yield FooterKey( + binding.key, + self.app.get_key_display(binding), + binding.description, + binding.action, + classes="-command-palette", + disabled=not enabled, + tooltip=binding.tooltip or binding.description, + ) + + def bindings_changed(self, screen: Screen) -> None: + self._bindings_ready = True + if not screen.app.app_focus: + return + if self.is_attached and screen is self.screen: + self.call_after_refresh(self.recompose) + + def _on_mouse_scroll_down(self, event: events.MouseScrollDown) -> None: + if self.allow_horizontal_scroll: + self.release_anchor() + if self._scroll_right_for_pointer(animate=True): + event.stop() + event.prevent_default() + + def _on_mouse_scroll_up(self, event: events.MouseScrollUp) -> None: + if self.allow_horizontal_scroll: + self.release_anchor() + if self._scroll_left_for_pointer(animate=True): + event.stop() + event.prevent_default() + + def on_mount(self) -> None: + self.screen.bindings_updated_signal.subscribe(self, self.bindings_changed) + + def on_unmount(self) -> None: + self.screen.bindings_updated_signal.unsubscribe(self) diff --git a/src/memray/_vendor/textual/widgets/_header.py b/src/memray/_vendor/textual/widgets/_header.py new file mode 100644 index 0000000000..5522e12b5b --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_header.py @@ -0,0 +1,228 @@ +"""Provides a Textual application header widget.""" + +from __future__ import annotations + +from datetime import datetime + +from rich.text import Text + +from memray._vendor.textual.app import ComposeResult, RenderResult +from memray._vendor.textual.content import Content +from memray._vendor.textual.dom import NoScreen +from memray._vendor.textual.events import Click, Mount +from memray._vendor.textual.reactive import Reactive +from memray._vendor.textual.widget import Widget +from memray._vendor.textual.widgets import Static + + +class HeaderIcon(Widget): + """Display an 'icon' on the left of the header.""" + + DEFAULT_CSS = """ + HeaderIcon { + dock: left; + padding: 0 1; + width: 8; + content-align: left middle; + } + + HeaderIcon:hover { + background: $foreground 10%; + } + """ + + icon = Reactive("⭘") + """The character to use as the icon within the header.""" + + def on_mount(self) -> None: + if self.app.ENABLE_COMMAND_PALETTE: + self.tooltip = "Open the command palette" + else: + self.disabled = True + + async def on_click(self, event: Click) -> None: + """Launch the command palette when icon is clicked.""" + event.stop() + await self.run_action("app.command_palette") + + def render(self) -> RenderResult: + """Render the header icon. + + Returns: + The rendered icon. + """ + return self.icon + + +class HeaderClockSpace(Widget): + """The space taken up by the clock on the right of the header.""" + + DEFAULT_CSS = """ + HeaderClockSpace { + dock: right; + width: 10; + padding: 0 1; + } + """ + + def render(self) -> RenderResult: + """Render the header clock space. + + Returns: + The rendered space. + """ + return "" + + +class HeaderClock(HeaderClockSpace): + """Display a clock on the right of the header.""" + + DEFAULT_CSS = """ + HeaderClock { + background: $foreground-darken-1 5%; + color: $foreground; + text-opacity: 85%; + content-align: center middle; + } + """ + + time_format: Reactive[str] = Reactive("%X") + + def _on_mount(self, _: Mount) -> None: + self.set_interval(1, callback=self.refresh, name="update header clock") + + def render(self) -> RenderResult: + """Render the header clock. + + Returns: + The rendered clock. + """ + return Text(datetime.now().time().strftime(self.time_format)) + + +class HeaderTitle(Static): + """Display the title / subtitle in the header.""" + + DEFAULT_CSS = """ + HeaderTitle { + text-wrap: nowrap; + text-overflow: ellipsis; + content-align: center middle; + width: 100%; + } + """ + + +class Header(Widget): + """A header widget with icon and clock.""" + + DEFAULT_CSS = """ + Header { + dock: top; + width: 100%; + background: $panel; + color: $foreground; + height: 1; + } + Header.-tall { + height: 3; + } + """ + + DEFAULT_CLASSES = "" + + tall: Reactive[bool] = Reactive(False) + """Set to `True` for a taller header or `False` for a single line header.""" + + icon: Reactive[str] = Reactive("⭘") + """A character for the icon at the top left.""" + + time_format: Reactive[str] = Reactive("%X") + """Time format of the clock.""" + + def __init__( + self, + show_clock: bool = False, + *, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + icon: str | None = None, + time_format: str | None = None, + ): + """Initialise the header widget. + + Args: + show_clock: ``True`` if the clock should be shown on the right of the header. + name: The name of the header widget. + id: The ID of the header widget in the DOM. + classes: The CSS classes of the header widget. + icon: Single character to use as an icon, or `None` for default. + time_format: Time format (used by strftime) for clock, or `None` for default. + """ + super().__init__(name=name, id=id, classes=classes) + self._show_clock = show_clock + if icon is not None: + self.icon = icon + if time_format is not None: + self.time_format = time_format + + def compose(self) -> ComposeResult: + yield HeaderIcon().data_bind(Header.icon) + yield HeaderTitle() + yield ( + HeaderClock().data_bind(Header.time_format) + if self._show_clock + else HeaderClockSpace() + ) + + def watch_tall(self, tall: bool) -> None: + self.set_class(tall, "-tall") + + def _on_click(self): + self.toggle_class("-tall") + + def format_title(self) -> Content: + """Format the title and subtitle. + + Defers to [App.format_title][textual.app.App.format_title] by default. + Override this method if you want to customize how the title is displayed in the header. + + Returns: + Content for title display. + """ + return self.app.format_title(self.screen_title, self.screen_sub_title) + + @property + def screen_title(self) -> str: + """The title that this header will display. + + This depends on [`Screen.title`][textual.screen.Screen.title] and [`App.title`][textual.app.App.title]. + """ + screen_title = self.screen.title + title = screen_title if screen_title is not None else self.app.title + return title + + @property + def screen_sub_title(self) -> str: + """The sub-title that this header will display. + + This depends on [`Screen.sub_title`][textual.screen.Screen.sub_title] and [`App.sub_title`][textual.app.App.sub_title]. + """ + screen_sub_title = self.screen.sub_title + sub_title = ( + screen_sub_title if screen_sub_title is not None else self.app.sub_title + ) + return sub_title + + def _on_mount(self, _: Mount) -> None: + async def set_title() -> None: + try: + self.query_one(HeaderTitle).update(self.format_title()) + except NoScreen: + pass + + self.watch(self.app, "title", set_title) + self.watch(self.app, "sub_title", set_title) + self.watch(self.screen, "title", set_title) + self.watch(self.screen, "sub_title", set_title) diff --git a/src/memray/_vendor/textual/widgets/_help_panel.py b/src/memray/_vendor/textual/widgets/_help_panel.py new file mode 100644 index 0000000000..8718f40af6 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_help_panel.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +from textwrap import dedent + +from memray._vendor.textual.app import ComposeResult +from memray._vendor.textual.css.query import NoMatches +from memray._vendor.textual.widget import Widget +from memray._vendor.textual.widgets import KeyPanel, Markdown + + +class HelpPanel(Widget): + """ + Shows context sensitive help for the currently focused widget. + """ + + DEFAULT_CSS = """ + + HelpPanel { + split: right; + width: 33%; + min-width: 30; + max-width: 60; + border-left: vkey $foreground 30%; + padding: 0 1; + height: 1fr; + padding-right: 1; + layout: vertical; + height: 100%; + + &:ansi { + background: ansi_default; + border-left: vkey ansi_black; + + Markdown, KeyPanel { + background: ansi_default; + } + .bindings-table--divide { + color: transparent; + } + } + + #widget-help { + height: auto; + max-height: 50%; + width: 1fr; + padding: 0; + margin: 0; + padding: 1 0; + margin-top: 1; + display: none; + background: $panel; + + &:ansi { + background: ansi_default; + } + + MarkdownBlock { + padding-left: 2; + padding-right: 2; + } + } + + &.-show-help #widget-help { + display: block; + } + + KeyPanel#keys-help { + width: 1fr; + height: 1fr; + min-width: initial; + split: initial; + border-left: none; + padding: 0; + } + } + + """ + + DEFAULT_CLASSES = "-textual-system" + + def on_mount(self): + def update_help(focused_widget: Widget | None): + self.update_help(focused_widget) + + self.watch(self.screen, "focused", update_help) + + def update_help(self, focused_widget: Widget | None) -> None: + """Update the help for the focused widget. + + Args: + focused_widget: The currently focused widget, or `None` if no widget was focused. + """ + if not self.app.app_focus: + return + if not self.screen.is_active: + return + self.set_class(focused_widget is not None, "-show-help") + if focused_widget is not None: + help: str = "" + for node in focused_widget.ancestors_with_self: + if isinstance(node, Widget) and node.HELP: + help = node.HELP + break + if not help: + self.remove_class("-show-help") + try: + self.query_one(Markdown).update(dedent(help.rstrip())) + except NoMatches: + pass + + def compose(self) -> ComposeResult: + yield Markdown(id="widget-help") + yield KeyPanel(id="keys-help") diff --git a/src/memray/_vendor/textual/widgets/_input.py b/src/memray/_vendor/textual/widgets/_input.py new file mode 100644 index 0000000000..4dc1253443 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_input.py @@ -0,0 +1,1123 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import TYPE_CHECKING, ClassVar, Iterable, NamedTuple + +from rich.cells import cell_len, get_character_cell_size +from rich.console import RenderableType +from rich.highlighter import Highlighter +from rich.text import Text +from typing_extensions import Literal + +from memray._vendor.textual import events +from memray._vendor.textual.actions import SkipAction +from memray._vendor.textual.expand_tabs import expand_tabs_inline +from memray._vendor.textual.screen import Screen +from memray._vendor.textual.scroll_view import ScrollView +from memray._vendor.textual.strip import Strip + +if TYPE_CHECKING: + pass + +from memray._vendor.textual.binding import Binding, BindingType +from memray._vendor.textual.css._error_tools import friendly_list +from memray._vendor.textual.events import Blur, Focus, Mount +from memray._vendor.textual.geometry import Offset, Region, Size, clamp +from memray._vendor.textual.message import Message +from memray._vendor.textual.reactive import Reactive, reactive, var +from memray._vendor.textual.suggester import Suggester, SuggestionReady +from memray._vendor.textual.timer import Timer +from memray._vendor.textual.validation import ValidationResult, Validator + +InputValidationOn = Literal["blur", "changed", "submitted"] +"""Possible messages that trigger input validation.""" +_POSSIBLE_VALIDATE_ON_VALUES = {"blur", "changed", "submitted"} +"""Set literal with the legal values for the type `InputValidationOn`.""" + +_RESTRICT_TYPES = { + "integer": r"[-+]?(?:\d*|\d+_)*", + "number": r"[-+]?(?:\d*|\d+_)*\.?(?:\d*|\d+_)*(?:\d[eE]?[-+]?(?:\d*|\d+_)*)?", + "text": None, +} +InputType = Literal["integer", "number", "text"] + + +class Selection(NamedTuple): + """A range of selected text within the Input. + + Text can be selected by clicking and dragging the mouse, or by pressing + shift+arrow keys. + + Attributes: + start: The start index of the selection. + end: The end index of the selection. + """ + + start: int + end: int + + @classmethod + def cursor(cls, cursor_position: int) -> Selection: + """Create a selection from a cursor position.""" + return cls(cursor_position, cursor_position) + + @property + def is_empty(self) -> bool: + """Return True if the selection is empty.""" + return self.start == self.end + + +class Input(ScrollView): + """A text input widget.""" + + BINDING_GROUP_TITLE = "Input" + BINDINGS: ClassVar[list[BindingType]] = [ + Binding("left", "cursor_left", "Move cursor left", show=False), + Binding( + "shift+left", + "cursor_left(True)", + "Move cursor left and select", + show=False, + ), + Binding("ctrl+left", "cursor_left_word", "Move cursor left a word", show=False), + Binding( + "ctrl+shift+left", + "cursor_left_word(True)", + "Move cursor left a word and select", + show=False, + ), + Binding( + "right", + "cursor_right", + "Move cursor right or accept the completion suggestion", + show=False, + ), + Binding( + "shift+right", + "cursor_right(True)", + "Move cursor right and select", + show=False, + ), + Binding( + "ctrl+right", + "cursor_right_word", + "Move cursor right a word", + show=False, + ), + Binding( + "ctrl+shift+right", + "cursor_right_word(True)", + "Move cursor right a word and select", + show=False, + ), + Binding("backspace", "delete_left", "Delete character left", show=False), + Binding("ctrl+shift+a", "select_all", "Select all", show=False), + Binding("home,ctrl+a", "home", "Go to start", show=False), + Binding("end,ctrl+e", "end", "Go to end", show=False), + Binding("shift+home", "home(True)", "Select line start", show=False), + Binding("shift+end", "end(True)", "Select line end", show=False), + Binding("delete,ctrl+d", "delete_right", "Delete character right", show=False), + Binding("enter", "submit", "Submit", show=False), + Binding( + "ctrl+w", "delete_left_word", "Delete left to start of word", show=False + ), + Binding("ctrl+u", "delete_left_all", "Delete all to the left", show=False), + Binding( + "ctrl+f", "delete_right_word", "Delete right to start of word", show=False + ), + Binding("ctrl+k", "delete_right_all", "Delete all to the right", show=False), + Binding("ctrl+x", "cut", "Cut selected text", show=False), + Binding("ctrl+c,super+c", "copy", "Copy selected text", show=False), + Binding("ctrl+v", "paste", "Paste text from the clipboard", show=False), + ] + """ + | Key(s) | Description | + | :- | :- | + | left | Move the cursor left. | + | shift+left | Move cursor left and select. | + | ctrl+left | Move the cursor one word to the left. | + | right | Move the cursor right or accept the completion suggestion. | + | ctrl+shift+left | Move cursor left a word and select. | + | shift+right | Move cursor right and select. | + | ctrl+right | Move the cursor one word to the right. | + | backspace | Delete the character to the left of the cursor. | + | ctrl+shift+right | Move cursor right a word and select. | + | ctrl+shift+a | Select all text in the input. | + | home,ctrl+a | Go to the beginning of the input. | + | end,ctrl+e | Go to the end of the input. | + | shift+home | Select up to the input start. | + | shift+end | Select up to the input end. | + | delete,ctrl+d | Delete the character to the right of the cursor. | + | enter | Submit the current value of the input. | + | ctrl+w | Delete the word to the left of the cursor. | + | ctrl+u | Delete everything to the left of the cursor. | + | ctrl+f | Delete the word to the right of the cursor. | + | ctrl+k | Delete everything to the right of the cursor. | + | ctrl+x | Cut selected text. | + | ctrl+c | Copy selected text. | + | ctrl+v | Paste text from the clipboard. | + """ + + COMPONENT_CLASSES: ClassVar[set[str]] = { + "input--cursor", + "input--placeholder", + "input--suggestion", + "input--selection", + } + """ + | Class | Description | + | :- | :- | + | `input--cursor` | Target the cursor. | + | `input--placeholder` | Target the placeholder text (when it exists). | + | `input--suggestion` | Target the auto-completion suggestion (when it exists). | + | `input--selection` | Target the selected text. | + """ + + DEFAULT_CSS = """ + Input { + background: $surface; + color: $foreground; + padding: 0 2; + border: tall $border-blurred; + width: 100%; + height: 3; + scrollbar-size-horizontal: 0; + pointer: text; + + &.-textual-compact { + border: none !important; + height: 1; + padding: 0; + &.-invalid { + background-tint: $error 20%; + } + } + + &:focus { + border: tall $border; + background-tint: $foreground 5%; + } + &>.input--cursor { + background: $input-cursor-background; + color: $input-cursor-foreground; + text-style: $input-cursor-text-style; + } + &>.input--selection { + background: $input-selection-background; + } + &>.input--placeholder, &>.input--suggestion { + color: $text-disabled; + } + &.-invalid { + border: tall $error 60%; + } + &.-invalid:focus { + border: tall $error; + } + + &:ansi { + background: ansi_default; + color: ansi_default; + &>.input--cursor { + background: ansi_white; + color: ansi_black; + } + &>.input--placeholder, &>.input--suggestion { + text-style: dim; + color: ansi_default; + } + &.-invalid { + border: tall ansi_red; + } + &.-invalid:focus { + border: tall ansi_red; + } + } + } + + """ + + cursor_blink = reactive(True, init=False) + # TODO - check with width: auto to see if layout=True is needed + value: Reactive[str] = reactive("", init=False) + + @property + def cursor_position(self) -> int: + """The current position of the cursor, corresponding to the end of the selection.""" + return self.selection.end + + @cursor_position.setter + def cursor_position(self, position: int) -> None: + """Set the current position of the cursor.""" + self.selection = Selection.cursor(position) + + selection: Reactive[Selection] = reactive(Selection.cursor(0)) + """The currently selected range of text.""" + + placeholder = reactive("") + _cursor_visible = reactive(True) + password = reactive(False) + suggester: Suggester | None + """The suggester used to provide completions as the user types.""" + _suggestion = reactive("") + """A completion suggestion for the current value in the input.""" + restrict = var["str | None"](None) + """A regular expression to limit changes in value.""" + type = var[InputType]("text") + """The type of the input.""" + max_length = var["int | None"](None) + """The maximum length of the input, in characters.""" + valid_empty = var(False) + """Empty values should pass validation.""" + compact = reactive(False, toggle_class="-textual-compact") + """Make the input compact (without borders).""" + + @dataclass + class Changed(Message): + """Posted when the value changes. + + Can be handled using `on_input_changed` in a subclass of `Input` or in a parent + widget in the DOM. + """ + + input: Input + """The `Input` widget that was changed.""" + + value: str + """The value that the input was changed to.""" + + validation_result: ValidationResult | None = None + """The result of validating the value (formed by combining the results from each validator), or None + if validation was not performed (for example when no validators are specified in the `Input`s init)""" + + @property + def control(self) -> Input: + """Alias for self.input.""" + return self.input + + @dataclass + class Submitted(Message): + """Posted when the enter key is pressed within an `Input`. + + Can be handled using `on_input_submitted` in a subclass of `Input` or in a + parent widget in the DOM. + """ + + input: Input + """The `Input` widget that is being submitted.""" + value: str + """The value of the `Input` being submitted.""" + validation_result: ValidationResult | None = None + """The result of validating the value on submission, formed by combining the results for each validator. + This value will be None if no validation was performed, which will be the case if no validators are supplied + to the corresponding `Input` widget.""" + + @property + def control(self) -> Input: + """Alias for self.input.""" + return self.input + + @dataclass + class Blurred(Message): + """Posted when the widget is blurred (loses focus). + + Can be handled using `on_input_blurred` in a subclass of `Input` or in a parent + widget in the DOM. + """ + + input: Input + """The `Input` widget that was changed.""" + + value: str + """The value that the input was changed to.""" + + validation_result: ValidationResult | None = None + """The result of validating the value (formed by combining the results from each validator), or None + if validation was not performed (for example when no validators are specified in the `Input`s init)""" + + @property + def control(self) -> Input: + """Alias for self.input.""" + return self.input + + def __init__( + self, + value: str | None = None, + placeholder: str = "", + highlighter: Highlighter | None = None, + password: bool = False, + *, + restrict: str | None = None, + type: InputType = "text", + max_length: int = 0, + suggester: Suggester | None = None, + validators: Validator | Iterable[Validator] | None = None, + validate_on: Iterable[InputValidationOn] | None = None, + valid_empty: bool = False, + select_on_focus: bool = True, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + tooltip: RenderableType | None = None, + compact: bool = False, + ) -> None: + """Initialise the `Input` widget. + + Args: + value: An optional default value for the input. + placeholder: Optional placeholder text for the input. + highlighter: An optional highlighter for the input. + password: Flag to say if the field should obfuscate its content. + restrict: A regex to restrict character inputs. + type: The type of the input. + max_length: The maximum length of the input, or 0 for no maximum length. + suggester: [`Suggester`][textual.suggester.Suggester] associated with this + input instance. + validators: An iterable of validators that the Input value will be checked against. + validate_on: Zero or more of the values "blur", "changed", and "submitted", + which determine when to do input validation. The default is to do + validation for all messages. + valid_empty: Empty values are valid. + select_on_focus: Whether to select all text on focus. + name: Optional name for the input widget. + id: Optional ID for the widget. + classes: Optional initial classes for the widget. + disabled: Whether the input is disabled or not. + tooltip: Optional tooltip. + compact: Enable compact style (without borders). + """ + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + + self._blink_timer: Timer | None = None + """Timer controlling the blinking of the cursor, instantiated in `on_mount`.""" + + self.placeholder = placeholder + self.highlighter = highlighter + self.password = password + self.suggester = suggester + + # Ensure we always end up with an Iterable of validators + if isinstance(validators, Validator): + self.validators: list[Validator] = [validators] + elif validators is None: + self.validators = [] + else: + self.validators = list(validators) + + self.validate_on: set[str] = ( + (_POSSIBLE_VALIDATE_ON_VALUES & set(validate_on)) + if validate_on is not None + else _POSSIBLE_VALIDATE_ON_VALUES + ) + """Set with event names to do input validation on. + + Validation can only be performed on blur, on input changes and on input submission. + + Example: + This creates an `Input` widget that only gets validated when the value + is submitted explicitly: + + ```py + input = Input(validate_on=["submitted"]) + ``` + """ + self._reactive_valid_empty = valid_empty + self._valid = True + + self.restrict = restrict + if type not in _RESTRICT_TYPES: + raise ValueError( + f"Input type must be one of {friendly_list(_RESTRICT_TYPES.keys())}; not {type!r}" + ) + self.type = type + self.max_length = max_length + if not self.validators: + from memray._vendor.textual.validation import Integer, Number + + if self.type == "integer": + self.validators.append(Integer()) + elif self.type == "number": + self.validators.append(Number()) + + self._selecting = False + """True if the user is selecting text with the mouse.""" + + self._initial_value = True + """Indicates if the value has been set for the first time yet.""" + if value is not None: + self.value = value + + if tooltip is not None: + self.tooltip = tooltip + + self.compact = compact + + self.select_on_focus = select_on_focus + + def _position_to_cell(self, position: int) -> int: + """Convert an index within the value to cell position. + + Args: + position: The index within the value to convert. + + Returns: + The cell position corresponding to the index. + """ + return cell_len(expand_tabs_inline(self.value[:position], 4)) + + @property + def _cursor_offset(self) -> int: + """The cell offset of the cursor.""" + offset = self._position_to_cell(self.cursor_position) + if self.cursor_at_end: + offset += 1 + return offset + + @property + def cursor_at_start(self) -> bool: + """Flag to indicate if the cursor is at the start.""" + return self.cursor_position == 0 + + @property + def cursor_at_end(self) -> bool: + """Flag to indicate if the cursor is at the end.""" + return self.cursor_position == len(self.value) + + def check_consume_key(self, key: str, character: str | None) -> bool: + """Check if the widget may consume the given key. + + As an input we are expecting to capture printable keys. + + Args: + key: A key identifier. + character: A character associated with the key, or `None` if there isn't one. + + Returns: + `True` if the widget may capture the key in its `Key` message, or `False` if it won't. + """ + return character is not None and character.isprintable() + + def validate_selection(self, selection: Selection) -> Selection: + start, end = selection + value_length = len(self.value) + return Selection(clamp(start, 0, value_length), clamp(end, 0, value_length)) + + def _watch_selection(self, selection: Selection) -> None: + self.app.clear_selection() + self.app.cursor_position = self.cursor_screen_offset + if not self._initial_value: + self.scroll_to_region( + Region(self._cursor_offset, 0, width=1, height=1), + force=True, + animate=False, + ) + + def _watch_cursor_blink(self, blink: bool) -> None: + """Ensure we handle updating the cursor blink at runtime.""" + if self._blink_timer is not None: + if blink: + self._blink_timer.resume() + else: + self._pause_blink() + self._cursor_visible = True + + @property + def cursor_screen_offset(self) -> Offset: + """The offset of the cursor of this input in screen-space. (x, y)/(column, row).""" + x, y, _width, _height = self.content_region + scroll_x, _ = self.scroll_offset + return Offset(x + self._cursor_offset - scroll_x, y) + + def _watch_value(self, value: str) -> None: + """Update the virtual size and suggestion when the value changes.""" + self.virtual_size = Size(self.content_width, 1) + self._suggestion = "" + if self.suggester and value: + self.run_worker(self.suggester._get_suggestion(self, value)) + if self.styles.auto_dimensions: + self.refresh(layout=True) + + validation_result = ( + self.validate(value) if "changed" in self.validate_on else None + ) + self.post_message(self.Changed(self, value, validation_result)) + + # If this is the first time the value has been updated, set the cursor position to the end + if self._initial_value: + self.cursor_position = len(self.value) + self._initial_value = False + else: + # Force a re-validation of the selection to ensure it accounts for + # the length of the new value + self.selection = self.selection + + def _watch_valid_empty(self) -> None: + """Repeat validation when valid_empty changes.""" + self._watch_value(self.value) + + def validate(self, value: str) -> ValidationResult | None: + """Run all the validators associated with this Input on the supplied value. + + Runs all validators, combines the result into one. If any of the validators + failed, the combined result will be a failure. If no validators are present, + None will be returned. This also sets the `-invalid` CSS class on the Input + if the validation fails, and sets the `-valid` CSS class on the Input if + the validation succeeds. + + Returns: + A ValidationResult indicating whether *all* validators succeeded or not. + That is, if *any* validator fails, the result will be an unsuccessful + validation. + """ + + def set_classes() -> None: + """Set classes for valid flag.""" + valid = self._valid + self.set_class(not valid, "-invalid") + self.set_class(valid, "-valid") + + # If no validators are supplied, and therefore no validation occurs, we return None. + if not self.validators: + self._valid = True + set_classes() + return None + + if self.valid_empty and not value: + self._valid = True + set_classes() + return None + + validation_results: list[ValidationResult] = [ + validator.validate(value) for validator in self.validators + ] + combined_result = ValidationResult.merge(validation_results) + self._valid = combined_result.is_valid + set_classes() + + return combined_result + + @property + def is_valid(self) -> bool: + """Check if the value has passed validation.""" + return self._valid + + def render_line(self, y: int) -> Strip: + if y != 0: + return Strip.blank(self.size.width, self.rich_style) + + console = self.app.console + console_options = self.app.console_options + max_content_width = self.scrollable_content_region.width + + if not self.value: + placeholder = Text(self.placeholder, justify="left", end="") + placeholder.stylize(self.get_component_rich_style("input--placeholder")) + if self.has_focus: + cursor_style = self.get_component_rich_style("input--cursor") + if self._cursor_visible: + # If the placeholder is empty, there's no characters to stylise + # to make the cursor flash, so use a single space character + if len(placeholder) == 0: + placeholder = Text(" ", end="") + placeholder.stylize(cursor_style, 0, 1) + + strip = Strip( + console.render( + placeholder, console_options.update_width(max_content_width + 1) + ) + ) + else: + result = self._value + + # Add the completion with a faded style. + value = self.value + value_length = len(value) + suggestion = self._suggestion + show_suggestion = len(suggestion) > value_length and self.has_focus + if show_suggestion: + result += Text( + suggestion[value_length:], + self.get_component_rich_style("input--suggestion"), + end="", + ) + + if self.has_focus: + if not self.selection.is_empty: + start, end = self.selection + start, end = sorted((start, end)) + selection_style = self.get_component_rich_style("input--selection") + result.stylize_before(selection_style, start, end) + + if self._cursor_visible: + cursor_style = self.get_component_rich_style("input--cursor") + cursor = self.cursor_position + if not show_suggestion and self.cursor_at_end: + result.pad_right(1) + result.stylize(cursor_style, cursor, cursor + 1) + + segments = list( + console.render(result, console_options.update_width(self.content_width)) + ) + + strip = Strip(segments) + scroll_x, _ = self.scroll_offset + strip = strip.crop(scroll_x, scroll_x + max_content_width + 1) + strip = strip.extend_cell_length(max_content_width + 1) + + return strip.apply_style(self.rich_style) + + @property + def _value(self) -> Text: + """Value rendered as text.""" + if self.password: + return Text("•" * len(self.value), no_wrap=True, overflow="ignore", end="") + else: + text = Text(self.value, no_wrap=True, overflow="ignore", end="") + if self.highlighter is not None: + text = self.highlighter(text) + return text + + @property + def content_width(self) -> int: + """The width of the content.""" + if self.placeholder and not self.value: + return cell_len(self.placeholder) + + # Extra space for cursor at the end. + return self._value.cell_len + 1 + + def get_content_width(self, container: Size, viewport: Size) -> int: + """Get the widget of the content.""" + return self.content_width + + def get_content_height(self, container: Size, viewport: Size, width: int) -> int: + return 1 + + def _toggle_cursor(self) -> None: + """Toggle visibility of cursor.""" + if self.screen.is_active: + self._cursor_visible = not self._cursor_visible + else: + self._cursor_visible = True + + def _on_mount(self, event: Mount) -> None: + def text_selection_started(screen: Screen) -> None: + """Signal callback to unselect when arbitrary text selection starts.""" + self.selection = Selection.cursor(self.cursor_position) + + self.screen.text_selection_started_signal.subscribe( + self, text_selection_started, immediate=True + ) + self._blink_timer = self.set_interval( + 0.5, + self._toggle_cursor, + pause=not (self.cursor_blink and self.has_focus), + ) + + def _on_blur(self, event: Blur) -> None: + self._pause_blink() + validation_result = ( + self.validate(self.value) if "blur" in self.validate_on else None + ) + self.post_message(self.Blurred(self, self.value, validation_result)) + + def _on_focus(self, event: Focus) -> None: + self._restart_blink() + if self.select_on_focus and not event.from_app_focus: + self.selection = Selection(0, len(self.value)) + self.app.cursor_position = self.cursor_screen_offset + self._suggestion = "" + + async def _on_key(self, event: events.Key) -> None: + self._restart_blink() + + if event.is_printable: + event.stop() + assert event.character is not None + selection = self.selection + if selection.is_empty: + self.insert_text_at_cursor(event.character) + else: + self.replace(event.character, *selection) + event.prevent_default() + + def _on_paste(self, event: events.Paste) -> None: + if event.text: + line = event.text.splitlines()[0] + selection = self.selection + if selection.is_empty: + self.insert_text_at_cursor(line) + else: + self.replace(line, *selection) + event.stop() + + def _cell_offset_to_index(self, offset: int) -> int: + """Convert a cell offset to a character index, accounting for character width. + + Args: + offset: The cell offset to convert. + + Returns: + The character index corresponding to the cell offset. + """ + cell_offset = 0 + _cell_size = get_character_cell_size + scroll_x, _ = self.scroll_offset + offset += scroll_x + for index, char in enumerate(self.value): + cell_width = _cell_size(char) + if cell_offset <= offset < (cell_offset + cell_width): + return index + cell_offset += cell_width + return clamp(offset, 0, len(self.value)) + + async def _on_mouse_down(self, event: events.MouseDown) -> None: + self._pause_blink(visible=True) + offset_x, _ = event.get_content_offset_capture(self) + self.selection = Selection.cursor(self._cell_offset_to_index(offset_x)) + self._selecting = True + self.capture_mouse() + + def _end_selecting(self) -> None: + """End selecting if it is currently active.""" + if self._selecting: + self._selecting = False + self.release_mouse() + self._restart_blink() + + async def _on_mouse_release(self, _event: events.MouseRelease) -> None: + self._end_selecting() + + async def _on_mouse_up(self, _event: events.MouseUp) -> None: + self._end_selecting() + + async def _on_mouse_move(self, event: events.MouseMove) -> None: + if self._selecting: + # As we drag the mouse, we update the end position of the selection, + # keeping the start position fixed. + offset = event.get_content_offset_capture(self) + selection_start, _ = self.selection + self.selection = Selection( + selection_start, self._cell_offset_to_index(offset.x) + ) + + async def _on_suggestion_ready(self, event: SuggestionReady) -> None: + """Handle suggestion messages and set the suggestion when relevant.""" + if event.value == self.value: + self._suggestion = event.suggestion + + def _restart_blink(self) -> None: + """Restart the cursor blink cycle.""" + self._cursor_visible = True + if self.cursor_blink and self._blink_timer: + self._blink_timer.reset() + + def _pause_blink(self, visible: bool = False) -> None: + """Hide the blinking cursor and pause the blink cycle.""" + self._cursor_visible = visible + if self._blink_timer: + self._blink_timer.pause() + + def insert_text_at_cursor(self, text: str) -> None: + """Insert new text at the cursor, move the cursor to the end of the new text. + + Args: + text: New text to insert. + """ + self.insert(text, self.cursor_position) + + def restricted(self) -> None: + """Called when a character has been restricted. + + The default behavior is to play the system bell. + You may want to override this method if you want to disable the bell or do something else entirely. + """ + self.app.bell() + + def clear(self) -> None: + """Clear the input.""" + self.value = "" + + @property + def selected_text(self) -> str: + """The text between the start and end points of the current selection.""" + start, end = sorted(self.selection) + return self.value[start:end] + + def action_cursor_left(self, select: bool = False) -> None: + """Move the cursor one position to the left. + + Args: + select: If `True`, select the text to the left of the cursor. + """ + start, end = self.selection + if select: + self.selection = Selection(start, end - 1) + else: + if self.selection.is_empty: + self.cursor_position -= 1 + else: + self.cursor_position = min(start, end) + + def action_cursor_right(self, select: bool = False) -> None: + """Accept an auto-completion or move the cursor one position to the right. + + Args: + select: If `True`, select the text to the right of the cursor. + """ + start, end = self.selection + if select: + self.selection = Selection(start, end + 1) + else: + if self.cursor_at_end and self._suggestion: + self.value = self._suggestion + self.cursor_position = len(self.value) + else: + if self.selection.is_empty: + self.cursor_position += 1 + else: + self.cursor_position = max(start, end) + + def select_all(self) -> None: + """Select all of the text in the Input.""" + self.selection = Selection(0, len(self.value)) + self._suggestion = "" + + def action_select_all(self) -> None: + """Select all of the text in the Input.""" + self.select_all() + + def action_home(self, select: bool = False) -> None: + """Move the cursor to the start of the input. + + Args: + select: If `True`, select the text between the old and new cursor positions. + """ + if select: + self.selection = Selection(self.cursor_position, 0) + else: + self.cursor_position = 0 + + def action_end(self, select: bool = False) -> None: + """Move the cursor to the end of the input. + + Args: + select: If `True`, select the text between the old and new cursor positions. + """ + if select: + self.selection = Selection(self.cursor_position, len(self.value)) + else: + self.cursor_position = len(self.value) + + _WORD_START = re.compile(r"(?<=\W)\w") + + def action_cursor_left_word(self, select: bool = False) -> None: + """Move the cursor left to the start of a word. + + Args: + select: If `True`, select the text between the old and new cursor positions. + """ + if self.password: + # This is a password field so don't give any hints about word + # boundaries, even during movement. + self.action_home(select) + else: + start, _ = self.selection + try: + *_, hit = re.finditer( + self._WORD_START, self.value[: self.cursor_position] + ) + except ValueError: + target = 0 + else: + target = hit.start() + + if select: + self.selection = Selection(start, target) + else: + self.cursor_position = target + + def action_cursor_right_word(self, select: bool = False) -> None: + """Move the cursor right to the start of a word. + + Args: + select: If `True`, select the text between the old and new cursor positions. + """ + if self.password: + # This is a password field so don't give any hints about word + # boundaries, even during movement. + self.action_end(select) + else: + hit = re.search(self._WORD_START, self.value[self.cursor_position :]) + + start, end = self.selection + if hit is None: + target = len(self.value) + else: + target = end + hit.start() + + if select: + self.selection = Selection(start, target) + else: + self.cursor_position = target + + def replace(self, text: str, start: int, end: int) -> None: + """Replace the text between the start and end locations with the given text. + + Args: + text: Text to replace the existing text with. + start: Start index to replace (inclusive). + end: End index to replace (inclusive). + """ + + def check_allowed_value(value: str) -> bool: + """Check if new value is restricted.""" + + # Check max length + if self.max_length and len(value) > self.max_length: + return False + # Check explicit restrict + if self.restrict and re.fullmatch(self.restrict, value) is None: + return False + # Check type restrict + if self.type: + type_restrict = _RESTRICT_TYPES.get(self.type, None) + if ( + type_restrict is not None + and re.fullmatch(type_restrict, value) is None + ): + return False + # Character is allowed + return True + + value = self.value + start, end = sorted((max(0, start), min(len(value), end))) + new_value = f"{value[:start]}{text}{value[end:]}" + if check_allowed_value(new_value): + self.value = new_value + self.cursor_position = start + len(text) + else: + self.restricted() + + def insert(self, text: str, index: int) -> None: + """Insert text at the given index. + + Args: + text: Text to insert. + index: Index to insert the text at (inclusive). + """ + self.replace(text, index, index) + + def delete(self, start: int, end: int) -> None: + """Delete the text between the start and end locations. + + Args: + start: Start index to delete (inclusive). + end: End index to delete (inclusive). + """ + self.replace("", start, end) + + def delete_selection(self) -> None: + """Delete the current selection.""" + self.delete(*self.selection) + + def action_delete_right(self) -> None: + """Delete one character at the current cursor position.""" + if self.selection.is_empty: + self.delete(self.cursor_position, self.cursor_position + 1) + else: + self.delete_selection() + + def action_delete_right_word(self) -> None: + """Delete the current character and all rightward to the start of the next word.""" + if not self.selection.is_empty: + self.delete_selection() + return + + if self.password: + # This is a password field so don't give any hints about word + # boundaries, even during deletion. + self.action_delete_right_all() + else: + after = self.value[self.cursor_position :] + hit = re.search(self._WORD_START, after) + if hit is None: + self.action_delete_right_all() + else: + start = self.cursor_position + end = start + hit.end() - 1 + self.delete(start, end) + + def action_delete_right_all(self) -> None: + """Delete the current character and all characters to the right of the cursor position.""" + if self.selection.is_empty: + self.delete(self.cursor_position, len(self.value)) + else: + self.delete_selection() + + def action_delete_left(self) -> None: + """Delete one character to the left of the current cursor position.""" + if self.selection.is_empty: + self.delete(self.cursor_position - 1, self.cursor_position) + else: + self.delete_selection() + + def action_delete_left_word(self) -> None: + """Delete leftward of the cursor position to the start of a word.""" + if not self.selection.is_empty: + self.delete_selection() + return + + if self.password: + # This is a password field so don't give any hints about word + # boundaries, even during deletion. + self.action_delete_left_all() + else: + try: + *_, hit = re.finditer( + self._WORD_START, self.value[: self.cursor_position] + ) + except ValueError: + target = 0 + else: + target = hit.start() + + self.delete(target, self.cursor_position) + + def action_delete_left_all(self) -> None: + """Delete all characters to the left of the cursor position.""" + if self.selection.is_empty: + self.delete(0, self.cursor_position) + else: + self.delete_selection() + + async def action_submit(self) -> None: + """Handle a submit action. + + Normally triggered by the user pressing Enter. This may also run any validators. + """ + validation_result = ( + self.validate(self.value) if "submitted" in self.validate_on else None + ) + self.post_message(self.Submitted(self, self.value, validation_result)) + + def action_cut(self) -> None: + """Cut the current selection (copy to clipboard and remove from input).""" + self.app.copy_to_clipboard(self.selected_text) + self.delete_selection() + + def action_copy(self) -> None: + """Copy the current selection to the clipboard.""" + selected_text = self.selected_text + if selected_text: + self.app.copy_to_clipboard(selected_text) + else: + raise SkipAction() + + def action_paste(self) -> None: + """Paste from the local clipboard.""" + clipboard = self.app.clipboard + start, end = self.selection + self.replace(clipboard, start, end) diff --git a/src/memray/_vendor/textual/widgets/_key_panel.py b/src/memray/_vendor/textual/widgets/_key_panel.py new file mode 100644 index 0000000000..69f65a6f17 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_key_panel.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +from collections import defaultdict +from itertools import groupby +from operator import itemgetter +from typing import TYPE_CHECKING + +from rich import box +from rich.table import Table +from rich.text import Text + +from memray._vendor.textual.app import ComposeResult +from memray._vendor.textual.binding import Binding +from memray._vendor.textual.containers import VerticalScroll +from memray._vendor.textual.widgets import Static + +if TYPE_CHECKING: + from memray._vendor.textual.screen import Screen + + +class BindingsTable(Static): + """A widget to display bindings.""" + + COMPONENT_CLASSES = { + "bindings-table--key", + "bindings-table--description", + "bindings-table--divider", + "bindings-table--header", + } + + DEFAULT_CSS = """ + BindingsTable { + width: auto; + height: auto; + } + """ + + def render_bindings_table(self) -> Table: + """Render a table with all the key bindings. + + Returns: + A Rich Table. + """ + + bindings = self.screen.active_bindings.values() + + key_style = self.get_component_rich_style("bindings-table--key") + divider_transparent = ( + self.get_component_styles("bindings-table--divider").color.a == 0 + ) + table = Table( + padding=(0, 0), + show_header=False, + box=box.SIMPLE if divider_transparent else box.HORIZONTALS, + border_style=self.get_component_rich_style("bindings-table--divider"), + ) + table.add_column("", justify="right") + + header_style = self.get_component_rich_style("bindings-table--header") + previous_namespace: object = None + for namespace, _bindings in groupby(bindings, key=itemgetter(0)): + table_bindings = list(_bindings) + if not table_bindings: + continue + + if namespace.BINDING_GROUP_TITLE: + title = Text(namespace.BINDING_GROUP_TITLE, end="") + title.stylize(header_style) + table.add_row("", title) + + action_to_bindings: defaultdict[str, list[tuple[Binding, bool, str]]] + action_to_bindings = defaultdict(list) + for _, binding, enabled, tooltip in table_bindings: + if not binding.system: + action_to_bindings[binding.action].append( + (binding, enabled, tooltip) + ) + + description_style = self.get_component_rich_style( + "bindings-table--description" + ) + + def render_description(binding: Binding) -> Text: + """Render description text from a binding.""" + text = Text.from_markup( + binding.description, end="", style=description_style + ) + if binding.tooltip: + if binding.description: + text.append(" ") + text.append(binding.tooltip, "dim") + return text + + get_key_display = self.app.get_key_display + for multi_bindings in action_to_bindings.values(): + binding, enabled, tooltip = multi_bindings[0] + keys_display = " ".join( + dict.fromkeys( # Remove duplicates while preserving order + get_key_display(binding) for binding, _, _ in multi_bindings + ) + ) + table.add_row( + Text(keys_display, style=key_style), + render_description(binding), + ) + if namespace != previous_namespace: + table.add_section() + + previous_namespace = namespace + + return table + + def render(self) -> Table: + return self.render_bindings_table() + + +class KeyPanel(VerticalScroll, can_focus=False): + """ + Shows bindings for currently focused widget. + """ + + DEFAULT_CSS = """ + KeyPanel { + split: right; + width: 33%; + min-width: 30; + max-width: 60; + border-left: vkey $foreground 30%; + padding: 0 1; + height: 1fr; + padding-right: 1; + align: center top; + + &> BindingsTable > .bindings-table--key { + color: $text-accent; + text-style: bold; + padding: 0 1; + } + + &> BindingsTable > .bindings-table--description { + color: $foreground; + } + + &> BindingsTable > .bindings-table--divider { + color: transparent; + } + + &> BindingsTable > .bindings-table--header { + color: $text-primary; + text-style: underline; + } + + #bindings-table { + width: auto; + height: auto; + } + } + """ + + DEFAULT_CLASSES = "-textual-system" + + def compose(self) -> ComposeResult: + yield BindingsTable(shrink=True, expand=False) + + async def on_mount(self) -> None: + mount_screen = self.screen + + async def bindings_changed(screen: Screen) -> None: + """Update bindings.""" + if not screen.app.app_focus: + return + if self.is_attached and screen is mount_screen: + await self.recompose() + + def _bindings_changed(screen: Screen) -> None: + self.call_after_refresh(bindings_changed, screen) + + self.set_class(self.app.ansi_color, "-ansi-scrollbar") + self.screen.bindings_updated_signal.subscribe(self, _bindings_changed) + + def on_unmount(self) -> None: + self.screen.bindings_updated_signal.unsubscribe(self) diff --git a/src/memray/_vendor/textual/widgets/_label.py b/src/memray/_vendor/textual/widgets/_label.py new file mode 100644 index 0000000000..3ac9bcea33 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_label.py @@ -0,0 +1,73 @@ +"""Provides a simple Label widget.""" + +from __future__ import annotations + +from typing import Literal + +from memray._vendor.textual.visual import VisualType +from memray._vendor.textual.widgets._static import Static + +LabelVariant = Literal["success", "error", "warning", "primary", "secondary", "accent"] + + +class Label(Static): + """A simple label widget for displaying text-oriented renderables.""" + + DEFAULT_CSS = """ + Label { + width: auto; + height: auto; + min-height: 1; + + &.success { + color: $text-success; + background: $success-muted; + } + &.error { + color: $text-error; + background: $error-muted; + } + &.warning { + color: $text-warning; + background: $warning-muted; + } + &.primary { + color: $text-primary; + background: $primary-muted; + } + &.secondary { + color: $text-secondary; + background: $secondary-muted; + } + &.accent { + color: $text-accent; + background: $accent-muted; + } + } + """ + + def __init__( + self, + content: VisualType = "", + *, + variant: LabelVariant | None = None, + expand: bool = False, + shrink: bool = False, + markup: bool = True, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + ) -> None: + super().__init__( + content, + expand=expand, + shrink=shrink, + markup=markup, + name=name, + id=id, + classes=classes, + disabled=disabled, + ) + if variant: + self.add_class(variant) diff --git a/src/memray/_vendor/textual/widgets/_link.py b/src/memray/_vendor/textual/widgets/_link.py new file mode 100644 index 0000000000..24775452c1 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_link.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from memray._vendor.textual.binding import Binding +from memray._vendor.textual.reactive import reactive +from memray._vendor.textual.widgets import Static + + +class Link(Static, can_focus=True): + """A simple, clickable link that opens a URL.""" + + DEFAULT_CSS = """ + Link { + width: auto; + height: auto; + min-height: 1; + color: $text-accent; + text-style: underline; + &:hover { color: $accent; } + &:focus { text-style: bold reverse; } + pointer: pointer; + } + """ + + BINDINGS = [Binding("enter", "open_link", "Open link")] + """ + | Key(s) | Description | + | :- | :- | + | enter | Open the link in the browser. | + """ + + text: reactive[str] = reactive("", layout=True) + url: reactive[str] = reactive("") + + def __init__( + self, + text: str, + *, + url: str | None = None, + tooltip: str | None = None, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + ) -> None: + """A link widget. + + Args: + text: Text of the link. + url: A URL to open, when clicked. If `None`, the `text` parameter will also be used as the url. + tooltip: Optional tooltip. + name: Name of widget. + id: ID of Widget. + classes: Space separated list of class names. + disabled: Whether the static is disabled or not. + """ + super().__init__( + text, name=name, id=id, classes=classes, disabled=disabled, markup=False + ) + self.set_reactive(Link.text, text) + self.set_reactive(Link.url, text if url is None else url) + self.tooltip = tooltip + + def watch_text(self, text: str) -> None: + self.update(text) + + def on_click(self) -> None: + self.action_open_link() + + def action_open_link(self) -> None: + if self.url: + self.app.open_url(self.url) diff --git a/src/memray/_vendor/textual/widgets/_list_item.py b/src/memray/_vendor/textual/widgets/_list_item.py new file mode 100644 index 0000000000..331368ffcd --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_list_item.py @@ -0,0 +1,39 @@ +"""Provides a list item widget for use with `ListView`.""" + +from __future__ import annotations + +from memray._vendor.textual import events, on +from memray._vendor.textual.message import Message +from memray._vendor.textual.reactive import reactive +from memray._vendor.textual.widget import Widget + + +class ListItem(Widget, can_focus=False): + """A widget that is an item within a `ListView`. + + A `ListItem` is designed for use within a + [ListView][textual.widgets._list_view.ListView], please see `ListView`'s + documentation for more details on use. + """ + + highlighted = reactive(False) + """Is this item highlighted?""" + + class _ChildClicked(Message): + """For informing with the parent ListView that we were clicked""" + + def __init__(self, item: ListItem) -> None: + self.item = item + super().__init__() + + def _on_click(self, _: events.Click) -> None: + self.post_message(self._ChildClicked(self)) + + def watch_highlighted(self, value: bool) -> None: + self.set_class(value, "-highlight") + + @on(events.Enter) + @on(events.Leave) + def on_enter_or_leave(self, event: events.Enter | events.Leave) -> None: + event.stop() + self.set_class(self.is_mouse_over, "-hovered") diff --git a/src/memray/_vendor/textual/widgets/_list_view.py b/src/memray/_vendor/textual/widgets/_list_view.py new file mode 100644 index 0000000000..a205c81c22 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_list_view.py @@ -0,0 +1,396 @@ +from __future__ import annotations + +from typing import ClassVar, Iterable, Optional + +from typing_extensions import TypeGuard + +from memray._vendor.textual._loop import loop_from_index +from memray._vendor.textual.await_complete import AwaitComplete +from memray._vendor.textual.await_remove import AwaitRemove +from memray._vendor.textual.binding import Binding, BindingType +from memray._vendor.textual.containers import VerticalScroll +from memray._vendor.textual.events import Mount +from memray._vendor.textual.message import Message +from memray._vendor.textual.reactive import reactive +from memray._vendor.textual.widget import AwaitMount +from memray._vendor.textual.widgets._list_item import ListItem + + +class ListView(VerticalScroll, can_focus=True, can_focus_children=False): + """A vertical list view widget. + + Displays a vertical list of `ListItem`s which can be highlighted and + selected using the mouse or keyboard. + + Attributes: + index: The index in the list that's currently highlighted. + """ + + ALLOW_MAXIMIZE = True + + DEFAULT_CSS = """ + ListView { + background: $surface; + & > ListItem { + color: $foreground; + height: auto; + overflow: hidden hidden; + width: 1fr; + + &.-hovered { + background: $block-hover-background; + } + + &.-highlight { + color: $block-cursor-blurred-foreground; + background: $block-cursor-blurred-background; + text-style: $block-cursor-blurred-text-style; + } + } + + &:focus { + background-tint: $foreground 5%; + & > ListItem.-highlight { + color: $block-cursor-foreground; + background: $block-cursor-background; + text-style: $block-cursor-text-style; + } + } + + } + """ + + BINDINGS: ClassVar[list[BindingType]] = [ + Binding("enter", "select_cursor", "Select", show=False), + Binding("up", "cursor_up", "Cursor up", show=False), + Binding("down", "cursor_down", "Cursor down", show=False), + ] + """ + | Key(s) | Description | + | :- | :- | + | enter | Select the current item. | + | up | Move the cursor up. | + | down | Move the cursor down. | + """ + + index = reactive[Optional[int]](None, init=False) + """The index of the currently highlighted item.""" + + class Highlighted(Message): + """Posted when the highlighted item changes. + + Highlighted item is controlled using up/down keys. + Can be handled using `on_list_view_highlighted` in a subclass of `ListView` + or in a parent widget in the DOM. + """ + + ALLOW_SELECTOR_MATCH = {"item"} + """Additional message attributes that can be used with the [`on` decorator][textual.on].""" + + def __init__(self, list_view: ListView, item: ListItem | None) -> None: + super().__init__() + self.list_view: ListView = list_view + """The view that contains the item highlighted.""" + self.item: ListItem | None = item + """The highlighted item, if there is one highlighted.""" + + @property + def control(self) -> ListView: + """The view that contains the item highlighted. + + This is an alias for [`Highlighted.list_view`][textual.widgets.ListView.Highlighted.list_view] + and is used by the [`on`][textual.on] decorator. + """ + return self.list_view + + class Selected(Message): + """Posted when a list item is selected, e.g. when you press the enter key on it. + + Can be handled using `on_list_view_selected` in a subclass of `ListView` or in + a parent widget in the DOM. + """ + + ALLOW_SELECTOR_MATCH = {"item"} + """Additional message attributes that can be used with the [`on` decorator][textual.on].""" + + def __init__(self, list_view: ListView, item: ListItem, index: int) -> None: + super().__init__() + self.list_view: ListView = list_view + """The view that contains the item selected.""" + self.item: ListItem = item + """The selected item.""" + self.index = index + """Index of the selected item.""" + + @property + def control(self) -> ListView: + """The view that contains the item selected. + + This is an alias for [`Selected.list_view`][textual.widgets.ListView.Selected.list_view] + and is used by the [`on`][textual.on] decorator. + """ + return self.list_view + + def __init__( + self, + *children: ListItem, + initial_index: int | None = 0, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + ) -> None: + """ + Initialize a ListView. + + Args: + *children: The ListItems to display in the list. + initial_index: The index that should be highlighted when the list is first mounted. + name: The name of the widget. + id: The unique ID of the widget used in CSS/query selection. + classes: The CSS classes of the widget. + disabled: Whether the ListView is disabled or not. + """ + super().__init__( + *children, name=name, id=id, classes=classes, disabled=disabled + ) + self._initial_index = initial_index + + def _on_mount(self, _: Mount) -> None: + """Ensure the ListView is fully-settled after mounting.""" + + if self._initial_index is not None and self.children: + index = self._initial_index + if index >= len(self.children): + index = 0 + if self._nodes[index].disabled: + for index, node in loop_from_index(self._nodes, index, wrap=True): + if not node.disabled: + break + self.index = index + + @property + def highlighted_child(self) -> ListItem | None: + """The currently highlighted ListItem, or None if nothing is highlighted.""" + if self.index is not None and 0 <= self.index < len(self._nodes): + list_item = self._nodes[self.index] + assert isinstance(list_item, ListItem) + return list_item + else: + return None + + def validate_index(self, index: int | None) -> int | None: + """Clamp the index to the valid range, or set to None if there's nothing to highlight. + + Args: + index: The index to clamp. + + Returns: + The clamped index. + """ + if index is None or not self._nodes: + return None + elif index < 0: + return 0 + elif index >= len(self._nodes): + return len(self._nodes) - 1 + + return index + + def _is_valid_index(self, index: int | None) -> TypeGuard[int]: + """Determine whether the current index is valid into the list of children.""" + if index is None: + return False + return 0 <= index < len(self._nodes) + + def watch_index(self, old_index: int | None, new_index: int | None) -> None: + """Updates the highlighting when the index changes.""" + + if new_index is not None: + selected_widget = self._nodes[new_index] + if selected_widget.region: + self.scroll_to_widget(self._nodes[new_index], animate=False) + else: + # Call after refresh to permit a refresh operation + self.call_after_refresh( + self.scroll_to_widget, selected_widget, animate=False + ) + + if self._is_valid_index(old_index): + old_child = self._nodes[old_index] + assert isinstance(old_child, ListItem) + old_child.highlighted = False + + if ( + new_index is not None + and self._is_valid_index(new_index) + and not self._nodes[new_index].disabled + ): + new_child = self._nodes[new_index] + assert isinstance(new_child, ListItem) + new_child.highlighted = True + self.post_message(self.Highlighted(self, new_child)) + else: + self.post_message(self.Highlighted(self, None)) + + def extend(self, items: Iterable[ListItem]) -> AwaitMount: + """Append multiple new ListItems to the end of the ListView. + + Args: + items: The ListItems to append. + + Returns: + An awaitable that yields control to the event loop + until the DOM has been updated with the new child items. + """ + await_mount = self.mount(*items) + return await_mount + + def append(self, item: ListItem) -> AwaitMount: + """Append a new ListItem to the end of the ListView. + + Args: + item: The ListItem to append. + + Returns: + An awaitable that yields control to the event loop + until the DOM has been updated with the new child item. + """ + return self.extend([item]) + + def clear(self) -> AwaitRemove: + """Clear all items from the ListView. + + Returns: + An awaitable that yields control to the event loop until + the DOM has been updated to reflect all children being removed. + """ + await_remove = self.query("ListView > ListItem").remove() + self.index = None + return await_remove + + def insert(self, index: int, items: Iterable[ListItem]) -> AwaitMount: + """Insert new ListItem(s) to specified index. + + Args: + index: index to insert new ListItem. + items: The ListItems to insert. + + Returns: + An awaitable that yields control to the event loop + until the DOM has been updated with the new child item. + """ + await_mount = self.mount(*items, before=index) + return await_mount + + def pop(self, index: Optional[int] = None) -> AwaitComplete: + """Remove last ListItem from ListView or + Remove ListItem from ListView by index + + Args: + index: index of ListItem to remove from ListView + + Returns: + An awaitable that yields control to the event loop until + the DOM has been updated to reflect item being removed. + """ + if len(self) == 0: + raise IndexError("pop from empty list") + + index = index if index is not None else -1 + item_to_remove = self.query("ListItem")[index] + normalized_index = index if index >= 0 else index + len(self) + + async def do_pop() -> None: + """Remove the item and update the highlighted index.""" + await item_to_remove.remove() + if self.index is not None: + if normalized_index < self.index: + self.index -= 1 + elif normalized_index == self.index: + old_index = self.index + # Force a re-validation of the index + self.index = self.index + # If the index hasn't changed, the watcher won't be called + # but we need to update the highlighted item + if old_index == self.index: + self.watch_index(old_index, self.index) + + return AwaitComplete(do_pop()) + + def remove_items(self, indices: Iterable[int]) -> AwaitComplete: + """Remove ListItems from ListView by indices + + Args: + indices: index(s) of ListItems to remove from ListView + + Returns: + An awaitable object that waits for the direct children to be removed. + """ + items = self.query("ListItem") + items_to_remove = [items[index] for index in indices] + normalized_indices = set( + index if index >= 0 else index + len(self) for index in indices + ) + + async def do_remove_items() -> None: + """Remove the items and update the highlighted index.""" + await self.remove_children(items_to_remove) + if self.index is not None: + removed_before_highlighted = sum( + 1 for index in normalized_indices if index < self.index + ) + if removed_before_highlighted: + self.index -= removed_before_highlighted + elif self.index in normalized_indices: + old_index = self.index + # Force a re-validation of the index + self.index = self.index + # If the index hasn't changed, the watcher won't be called + # but we need to update the highlighted item + if old_index == self.index: + self.watch_index(old_index, self.index) + + return AwaitComplete(do_remove_items()) + + def action_select_cursor(self) -> None: + """Select the current item in the list.""" + selected_child = self.highlighted_child + if selected_child is None: + return + self.post_message(self.Selected(self, selected_child, self.index)) + + def action_cursor_down(self) -> None: + """Highlight the next item in the list.""" + if self.index is None: + if self._nodes: + self.index = 0 + else: + index = self.index + for index, item in loop_from_index(self._nodes, self.index, wrap=False): + if not item.disabled: + self.index = index + break + + def action_cursor_up(self) -> None: + """Highlight the previous item in the list.""" + if self.index is None: + if self._nodes: + self.index = len(self._nodes) - 1 + else: + for index, item in loop_from_index( + self._nodes, self.index, direction=-1, wrap=False + ): + if not item.disabled: + self.index = index + break + + def _on_list_item__child_clicked(self, event: ListItem._ChildClicked) -> None: + event.stop() + self.focus() + self.index = self._nodes.index(event.item) + self.post_message(self.Selected(self, event.item, self.index)) + + def __len__(self) -> int: + """Compute the length (in number of items) of the list view.""" + return len(self._nodes) diff --git a/src/memray/_vendor/textual/widgets/_loading_indicator.py b/src/memray/_vendor/textual/widgets/_loading_indicator.py new file mode 100644 index 0000000000..0f94e47c24 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_loading_indicator.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from time import time +from typing import TYPE_CHECKING + +from rich.style import Style +from rich.text import Text + +if TYPE_CHECKING: + from memray._vendor.textual.app import RenderResult + +from memray._vendor.textual import on +from memray._vendor.textual.color import Gradient +from memray._vendor.textual.events import InputEvent, Mount +from memray._vendor.textual.widget import Widget + + +class LoadingIndicator(Widget): + """Display an animated loading indicator.""" + + DEFAULT_CSS = """ + LoadingIndicator { + width: 100%; + height: 100%; + min-height: 1; + content-align: center middle; + color: $primary; + text-style: not reverse; + } + LoadingIndicator.-textual-loading-indicator { + layer: _loading; + background: $boost; + dock: top; + } + """ + + def __init__( + self, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + ): + """Initialize a loading indicator. + + Args: + name: The name of the widget. + id: The ID of the widget in the DOM. + classes: The CSS classes for the widget. + disabled: Whether the widget is disabled or not. + """ + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + + self._start_time: float = 0.0 + """The time the loading indicator was mounted (a Unix timestamp).""" + + def _on_mount(self, _: Mount) -> None: + self._start_time = time() + self.auto_refresh = 1 / 16 + + @on(InputEvent) + def on_input(self, event: InputEvent) -> None: + """Prevent all input events from bubbling, thus disabling widgets in a loading state.""" + event.stop() + event.prevent_default() + + def render(self) -> RenderResult: + if self.app.animation_level == "none": + return Text("Loading...") + + elapsed = time() - self._start_time + speed = 0.8 + dot = "\u25cf" + _, _, background, color = self.colors + + gradient = Gradient( + (0.0, background.blend(color, 0.1)), + (0.7, color), + (1.0, color.lighten(0.1)), + ) + + blends = [(elapsed * speed - dot_number / 8) % 1 for dot_number in range(5)] + + dots = [ + ( + f"{dot} ", + Style.from_color(gradient.get_color((1 - blend) ** 2).rich_color), + ) + for blend in blends + ] + indicator = Text.assemble(*dots) + indicator.rstrip() + return indicator diff --git a/src/memray/_vendor/textual/widgets/_log.py b/src/memray/_vendor/textual/widgets/_log.py new file mode 100644 index 0000000000..14f3d77de6 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_log.py @@ -0,0 +1,362 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Iterable, Optional, Sequence + +from rich.cells import cell_len +from rich.highlighter import Highlighter, ReprHighlighter +from rich.style import Style +from rich.text import Text + +from memray._vendor.textual import work +from memray._vendor.textual._line_split import line_split +from memray._vendor.textual.cache import LRUCache +from memray._vendor.textual.geometry import Size +from memray._vendor.textual.reactive import var +from memray._vendor.textual.scroll_view import ScrollView +from memray._vendor.textual.selection import Selection +from memray._vendor.textual.strip import Strip + +if TYPE_CHECKING: + from typing_extensions import Self + +_sub_escape = re.compile("[\u0000-\u0014]").sub + + +class Log(ScrollView, can_focus=True): + """A widget to log text.""" + + ALLOW_SELECT = True + DEFAULT_CSS = """ + Log { + background: $surface; + color: $text; + overflow: scroll; + &:focus { + background-tint: $foreground 5%; + } + } + """ + + max_lines: var[int | None] = var[Optional[int]](None) + """Maximum number of lines to show""" + + auto_scroll: var[bool] = var(True) + """Automatically scroll to new lines.""" + + def __init__( + self, + highlight: bool = False, + max_lines: int | None = None, + auto_scroll: bool = True, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + ) -> None: + """Create a Log widget. + + Args: + highlight: Enable highlighting. + max_lines: Maximum number of lines to display. + auto_scroll: Scroll to end on new lines. + name: The name of the text log. + id: The ID of the text log in the DOM. + classes: The CSS classes of the text log. + disabled: Whether the text log is disabled or not. + """ + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + self.highlight = highlight + """Enable highlighting.""" + self.max_lines = max_lines + self.auto_scroll = auto_scroll + self._lines: list[str] = [] + self._width = 0 + self._updates = 0 + self._render_line_cache: LRUCache[int, Strip] = LRUCache(1024) + self.highlighter: Highlighter = ReprHighlighter() + """The Rich Highlighter object to use, if `highlight=True`""" + self._clear_y = 0 + + @property + def allow_select(self) -> bool: + return True + + @property + def lines(self) -> Sequence[str]: + """The raw lines in the Log. + + Note that this attribute is read only. + Changing the lines will not update the Log's contents. + + """ + return self._lines + + def notify_style_update(self) -> None: + """Called by Textual when styles update.""" + super().notify_style_update() + self._render_line_cache.clear() + + def _update_maximum_width(self, updates: int, size: int) -> None: + """Update the virtual size width. + + Args: + updates: A counter of updates. + size: Maximum size of new lines. + """ + if updates == self._updates: + self._width = max(size, self._width) + self.virtual_size = Size(self._width, self.line_count) + + @property + def line_count(self) -> int: + """Number of lines of content.""" + if self._lines: + return len(self._lines) - (self._lines[-1] == "") + return 0 + + @classmethod + def _process_line(cls, line: str) -> str: + """Process a line before it is rendered to remove control codes. + + Args: + line: A string. + + Returns: + New string with no control codes. + """ + return _sub_escape("�", line.expandtabs()) + + @work(thread=True) + def _update_size(self, updates: int, lines: list[str]) -> None: + """A thread worker to update the width in the background. + + Args: + updates: The update index at the time of invocation. + lines: Lines that were added. + """ + if lines: + _process_line = self._process_line + max_length = max(cell_len(_process_line(line)) for line in lines) + self.app.call_from_thread(self._update_maximum_width, updates, max_length) + + def _prune_max_lines(self) -> None: + """Prune lines if there are more than the maximum.""" + if self.max_lines is None: + return + remove_lines = len(self._lines) - self.max_lines + if remove_lines > 0: + _cache = self._render_line_cache + # We've removed some lines, which means the y values in the cache are out of sync + # Calculated a new dict of cache values + updated_cache = { + y - remove_lines: _cache[y] for y in _cache.keys() if y > remove_lines + } + # Clear the cache + _cache.clear() + # Update the cache with previously calculated values + for y, line in updated_cache.items(): + _cache[y] = line + del self._lines[:remove_lines] + + def write( + self, + data: str, + scroll_end: bool | None = None, + ) -> Self: + """Write to the log. + + Args: + data: Data to write. + scroll_end: Scroll to the end after writing, or `None` to use `self.auto_scroll`. + + Returns: + The `Log` instance. + """ + is_vertical_scroll_end = self.is_vertical_scroll_end + if data: + if not self._lines: + self._lines.append("") + for line, ending in line_split(data): + self._lines[-1] += line + self._width = max( + self._width, cell_len(self._process_line(self._lines[-1])) + ) + self.refresh_lines(len(self._lines) - 1) + if ending: + self._lines.append("") + self.virtual_size = Size(self._width, self.line_count) + + if self.max_lines is not None and len(self._lines) > self.max_lines: + self._prune_max_lines() + + auto_scroll = self.auto_scroll if scroll_end is None else scroll_end + if auto_scroll: + self.scroll_end(animate=False, immediate=True, x_axis=False) + return self + + def write_line( + self, + line: str, + scroll_end: bool | None = None, + ) -> Self: + """Write content on a new line. + + Args: + line: String to write to the log. + scroll_end: Scroll to the end after writing, or `None` to use `self.auto_scroll`. + + Returns: + The `Log` instance. + """ + self.write_lines([line], scroll_end) + return self + + def write_lines( + self, + lines: Iterable[str], + scroll_end: bool | None = None, + ) -> Self: + """Write an iterable of lines. + + Args: + lines: An iterable of strings to write. + scroll_end: Scroll to the end after writing, or `None` to use `self.auto_scroll`. + + Returns: + The `Log` instance. + """ + is_vertical_scroll_end = self.is_vertical_scroll_end + auto_scroll = self.auto_scroll if scroll_end is None else scroll_end + new_lines = [] + for line in lines: + new_lines.extend(line.splitlines()) + start_line = len(self._lines) + self._lines.extend(new_lines) + if self.max_lines is not None and len(self._lines) > self.max_lines: + self._prune_max_lines() + self.virtual_size = Size(self._width, len(self._lines)) + self._update_size(self._updates, new_lines) + self.refresh_lines(start_line, len(new_lines)) + if ( + auto_scroll + and not self.is_vertical_scrollbar_grabbed + and is_vertical_scroll_end + ): + self.scroll_end(animate=False, immediate=True, x_axis=False) + else: + self.refresh() + return self + + def clear(self) -> Self: + """Clear the Log. + + Returns: + The `Log` instance. + """ + self._lines.clear() + self._width = 0 + self._render_line_cache.clear() + self._updates += 1 + self.virtual_size = Size(0, 0) + self._clear_y = 0 + return self + + def get_selection(self, selection: Selection) -> tuple[str, str] | None: + """Get the text under the selection. + + Args: + selection: Selection information. + + Returns: + Tuple of extracted text and ending (typically "\n" or " "), or `None` if no text could be extracted. + """ + text = "\n".join(self._lines) + return selection.extract(text), "\n" + + def selection_updated(self, selection: Selection | None) -> None: + self._render_line_cache.clear() + self.refresh() + + def render_line(self, y: int) -> Strip: + """Render a line of content. + + Args: + y: Y Coordinate of line. + + Returns: + A rendered line. + """ + scroll_x, scroll_y = self.scroll_offset + strip = self._render_line(scroll_y + y, scroll_x, self.size.width) + return strip + + def _render_line(self, y: int, scroll_x: int, width: int) -> Strip: + """Render a line into a cropped strip. + + Args: + y: Y offset of line. + scroll_x: Current horizontal scroll. + width: Width of the widget. + + Returns: + A Strip suitable for rendering. + """ + rich_style = self.rich_style + if y >= len(self._lines): + return Strip.blank(width, rich_style) + + line = self._render_line_strip(y, rich_style) + assert line._cell_length is not None + line = line.crop_extend(scroll_x, scroll_x + width, rich_style) + line = line.apply_offsets(scroll_x, y) + return line + + def _render_line_strip(self, y: int, rich_style: Style) -> Strip: + """Render a line into a Strip. + + Args: + y: Y offset of line. + rich_style: Rich style of line. + + Returns: + An uncropped Strip. + """ + selection = self.text_selection + if y in self._render_line_cache and selection is None: + return self._render_line_cache[y] + + _line = self._process_line(self._lines[y]) + + line_text = Text(_line, no_wrap=True) + line_text.stylize(rich_style) + + if self.highlight: + line_text = self.highlighter(line_text) + if selection is not None: + if (select_span := selection.get_span(y - self._clear_y)) is not None: + start, end = select_span + if end == -1: + end = len(line_text) + + selection_style = self.screen.get_component_rich_style( + "screen--selection" + ) + line_text.stylize(selection_style, start, end) + + line = Strip(line_text.render(self.app.console), cell_len(_line)) + + if selection is not None: + self._render_line_cache[y] = line + return line + + def refresh_lines(self, y_start: int, line_count: int = 1) -> None: + """Refresh one or more lines. + + Args: + y_start: First line to refresh. + line_count: Total number of lines to refresh. + """ + for y in range(y_start, y_start + line_count): + self._render_line_cache.discard(y) + super().refresh_lines(y_start, line_count=line_count) diff --git a/src/memray/_vendor/textual/widgets/_markdown.py b/src/memray/_vendor/textual/widgets/_markdown.py new file mode 100644 index 0000000000..7b0510eda2 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_markdown.py @@ -0,0 +1,1668 @@ +from __future__ import annotations + +import asyncio +import re +import weakref +from contextlib import suppress +from functools import partial +from pathlib import Path, PurePath +from typing import Callable, Iterable, Optional +from urllib.parse import unquote + +from markdown_it import MarkdownIt +from markdown_it.token import Token +from rich.text import Text +from typing_extensions import TypeAlias + +from memray._vendor.textual._slug import TrackedSlugs, slug_for_tcss_id +from memray._vendor.textual.app import ComposeResult +from memray._vendor.textual.await_complete import AwaitComplete +from memray._vendor.textual.containers import Horizontal, Vertical, VerticalScroll +from memray._vendor.textual.content import Content, Span +from memray._vendor.textual.css.query import NoMatches +from memray._vendor.textual.events import Mount +from memray._vendor.textual.highlight import highlight +from memray._vendor.textual.layout import Layout +from memray._vendor.textual.layouts.grid import GridLayout +from memray._vendor.textual.message import Message +from memray._vendor.textual.reactive import reactive, var +from memray._vendor.textual.style import Style +from memray._vendor.textual.widget import Widget +from memray._vendor.textual.widgets import Static, Tree +from memray._vendor.textual.widgets._label import Label + +TableOfContentsType: TypeAlias = "list[tuple[int, str, str | None]]" +"""Information about the table of contents of a markdown document. + +The triples encode the level, the label, and the optional block id of each heading. +""" + + +class MarkdownStream: + """An object to manage streaming markdown. + + This will accumulate markdown fragments if they can't be rendered fast enough. + + This object is typically created by the [Markdown.get_stream][textual.widgets.Markdown.get_stream] method. + + """ + + def __init__(self, markdown_widget: Markdown) -> None: + """ + Args: + markdown_widget: Markdown widget to update. + """ + self.markdown_widget = markdown_widget + self._task: asyncio.Task | None = None + self._new_markup = asyncio.Event() + self._pending: list[str] = [] + self._stopped = False + + def start(self) -> None: + """Start the updater running in the background. + + No need to call this, if the object was created by [Markdown.get_stream][textual.widgets.Markdown.get_stream]. + + """ + if self._task is None: + self._task = asyncio.create_task(self._run()) + + async def stop(self) -> None: + """Stop the stream and await its finish.""" + if self._task is not None: + self._task.cancel() + await self._task + self._task = None + self._stopped = True + + async def write(self, markdown_fragment: str) -> None: + """Append or enqueue a markdown fragment. + + Args: + markdown_fragment: A string to append at the end of the document. + """ + if self._stopped: + raise RuntimeError("Can't write to the stream after it has stopped.") + if not markdown_fragment: + # Nothing to do for empty strings. + return + # Append the new fragment, and set an event to tell the _run loop to wake up + self._pending.append(markdown_fragment) + self._new_markup.set() + # Allow the task to wake up and actually display the new markdown + await asyncio.sleep(0) + + async def _run(self) -> None: + """Run a task to append markdown fragments when available.""" + try: + while await self._new_markup.wait(): + new_markdown = "".join(self._pending) + self._pending.clear() + self._new_markup.clear() + await asyncio.shield(self.markdown_widget.append(new_markdown)) + except asyncio.CancelledError: + # Task has been cancelled, add any outstanding markdown + pass + + new_markdown = "".join(self._pending) + if new_markdown: + await self.markdown_widget.append(new_markdown) + + +class Navigator: + """Manages a stack of paths like a browser.""" + + def __init__(self) -> None: + self.stack: list[Path] = [] + self.index = 0 + + @property + def location(self) -> Path: + """The current location. + + Returns: + A path for the current document. + """ + if not self.stack: + return Path(".") + return self.stack[self.index] + + @property + def start(self) -> bool: + """Is the current location at the start of the stack?""" + return self.index == 0 + + @property + def end(self) -> bool: + """Is the current location at the end of the stack?""" + return self.index >= len(self.stack) - 1 + + def go(self, path: str | PurePath) -> Path: + """Go to a new document. + + Args: + path: Path to new document. + + Returns: + New location. + """ + location, anchor = Markdown.sanitize_location(str(path)) + if location == Path(".") and anchor: + current_file, _ = Markdown.sanitize_location(str(self.location)) + path = f"{current_file}#{anchor}" + new_path = self.location.parent / Path(path) + self.stack = self.stack[: self.index + 1] + new_path = new_path.absolute() + self.stack.append(new_path) + self.index = len(self.stack) - 1 + return new_path + + def back(self) -> bool: + """Go back in the stack. + + Returns: + True if the location changed, otherwise False. + """ + if self.index: + self.index -= 1 + return True + return False + + def forward(self) -> bool: + """Go forward in the stack. + + Returns: + True if the location changed, otherwise False. + """ + if self.index < len(self.stack) - 1: + self.index += 1 + return True + return False + + +class MarkdownBlock(Static): + """The base class for a Markdown Element.""" + + COMPONENT_CLASSES = {"em", "strong", "s", "code_inline"} + """ + These component classes target standard inline markdown styles. + Changing these will potentially break the standard markdown formatting. + + | Class | Description | + | :- | :- | + | `code_inline` | Target text that is styled as inline code. | + | `em` | Target text that is emphasized inline. | + | `s` | Target text that is styled inline with strikethrough. | + | `strong` | Target text that is styled inline with strong. | + """ + + DEFAULT_CSS = """ + MarkdownBlock { + width: 1fr; + height: auto; + } + """ + + def __init__( + self, + markdown: Markdown, + token: Token, + source_range: tuple[int, int] | None = None, + *args, + **kwargs, + ) -> None: + self._markdown_ref = weakref.ref(markdown) + """A reference to the Markdown document that contains this block.""" + self._content: Content = Content() + self._token: Token = token + self._blocks: list[MarkdownBlock] = [] + self._inline_token: Token | None = None + self.source_range: tuple[int, int] = source_range or ( + (token.map[0], token.map[1]) if token.map is not None else (0, 0) + ) + + super().__init__( + *args, name=token.type, classes=f"level-{token.level}", **kwargs + ) + + @property + def _markdown(self) -> Markdown: + """Resolve the weak ref to _markdown""" + markdown = self._markdown_ref() + assert markdown is not None + return markdown + + @property + def select_container(self) -> Widget: + return self.query_ancestor(Markdown) + + @property + def source(self) -> str | None: + """The source of this block if known, otherwise `None`.""" + if self.source_range is None: + return None + start, end = self.source_range + return "".join(self._markdown.source.splitlines(keepends=True)[start:end]) + + def _copy_context(self, block: MarkdownBlock) -> None: + """Copy the context from another block.""" + self._token = block._token + + def compose(self) -> ComposeResult: + yield from self._blocks + self._blocks.clear() + + def set_content(self, content: Content) -> None: + self._content = content + self.update(content) + + async def _update_from_block(self, block: MarkdownBlock) -> None: + await self.remove() + await self._markdown.mount(block) + + async def action_link(self, href: str) -> None: + """Called on link click.""" + self.post_message(Markdown.LinkClicked(self._markdown, href)) + + def build_from_token(self, token: Token) -> None: + """Build inline block content from its source token. + + Args: + token: The token from which this block is built. + """ + self._inline_token = token + content = self._token_to_content(token) + self.set_content(content) + + def _token_to_content(self, token: Token) -> Content: + """Convert an inline token to Textual Content. + + Args: + token: A markdown token. + + Returns: + Content instance. + """ + + if token.children is None: + return Content("") + + tokens: list[str] = [] + spans: list[Span] = [] + style_stack: list[tuple[Style | str, int]] = [] + position: int = 0 + + def add_content(text: str) -> None: + """Add text to the tokens list, and advance the position. + + Args: + text: Text to add. + + """ + nonlocal position + tokens.append(text) + position += len(text) + + def add_style(style: Style | str) -> None: + """Add a style to the stack. + + Args: + style: A style as Style instance or string. + """ + style_stack.append((style, position)) + + position = 0 + + def close_tag() -> None: + style, start = style_stack.pop() + spans.append(Span(start, position, style)) + + for child in token.children: + child_type = child.type + if child_type == "text": + add_content(re.sub(r"\s+", " ", child.content)) + if child_type == "hardbreak": + add_content("\n") + if child_type == "softbreak": + add_content(" ") + elif child_type == "code_inline": + add_style(".code_inline") + add_content(child.content) + close_tag() + elif child_type == "em_open": + add_style(".em") + elif child_type == "strong_open": + add_style(".strong") + elif child_type == "s_open": + add_style(".s") + elif child_type == "link_open": + href = child.attrs.get("href", "") + action = f"link({href!r})" + add_style(Style.from_meta({"@click": action})) + elif child_type == "image": + href = child.attrs.get("src", "") + alt = child.attrs.get("alt", "") + action = f"link({href!r})" + add_style(Style.from_meta({"@click": action})) + add_content("🖼 ") + if alt: + add_content(f"({alt})") + if child.children is not None: + for grandchild in child.children: + add_content(grandchild.content) + close_tag() + + elif child_type.endswith("_close"): + close_tag() + + content = Content("".join(tokens), spans=spans) + return content + + +class MarkdownHeader(MarkdownBlock): + """Base class for a Markdown header.""" + + LEVEL = 0 + + DEFAULT_CSS = """ + MarkdownHeader { + color: $text; + margin: 2 0 1 0; + + } + """ + + +class MarkdownH1(MarkdownHeader): + """An H1 Markdown header.""" + + LEVEL = 1 + + DEFAULT_CSS = """ + MarkdownH1 { + content-align: center middle; + color: $markdown-h1-color; + background: $markdown-h1-background; + text-style: $markdown-h1-text-style; + } + """ + + +class MarkdownH2(MarkdownHeader): + """An H2 Markdown header.""" + + LEVEL = 2 + + DEFAULT_CSS = """ + MarkdownH2 { + color: $markdown-h2-color; + background: $markdown-h2-background; + text-style: $markdown-h2-text-style; + } + """ + + +class MarkdownH3(MarkdownHeader): + """An H3 Markdown header.""" + + LEVEL = 3 + + DEFAULT_CSS = """ + MarkdownH3 { + color: $markdown-h3-color; + background: $markdown-h3-background; + text-style: $markdown-h3-text-style; + margin: 1 0; + width: auto; + } + """ + + +class MarkdownH4(MarkdownHeader): + """An H4 Markdown header.""" + + LEVEL = 4 + + DEFAULT_CSS = """ + MarkdownH4 { + color: $markdown-h4-color; + background: $markdown-h4-background; + text-style: $markdown-h4-text-style; + margin: 1 0; + } + """ + + +class MarkdownH5(MarkdownHeader): + """An H5 Markdown header.""" + + LEVEL = 5 + + DEFAULT_CSS = """ + MarkdownH5 { + color: $markdown-h5-color; + background: $markdown-h5-background; + text-style: $markdown-h5-text-style; + margin: 1 0; + } + """ + + +class MarkdownH6(MarkdownHeader): + """An H6 Markdown header.""" + + LEVEL = 6 + + DEFAULT_CSS = """ + MarkdownH6 { + color: $markdown-h6-color; + background: $markdown-h6-background; + text-style: $markdown-h6-text-style; + margin: 1 0; + } + """ + + +class MarkdownHorizontalRule(MarkdownBlock): + """A horizontal rule.""" + + DEFAULT_CSS = """ + MarkdownHorizontalRule { + border-bottom: solid $secondary; + height: 1; + padding-top: 1; + margin-bottom: 1; + } + """ + + +class MarkdownParagraph(MarkdownBlock): + """A paragraph Markdown block.""" + + SCOPED_CSS = False + DEFAULT_CSS = """ + Markdown > MarkdownParagraph { + margin: 0 0 1 0; + } + """ + + async def _update_from_block(self, block: MarkdownBlock): + if isinstance(block, MarkdownParagraph): + self.set_content(block._content) + self._copy_context(block) + else: + await super()._update_from_block(block) + + +class MarkdownBlockQuote(MarkdownBlock): + """A block quote Markdown block.""" + + DEFAULT_CSS = """ + MarkdownBlockQuote { + background: $boost; + border-left: outer $text-primary 50%; + margin: 1 0; + padding: 0 1; + } + MarkdownBlockQuote:light { + border-left: outer $text-secondary; + } + MarkdownBlockQuote > BlockQuote { + margin-left: 2; + margin-top: 1; + } + """ + + +class MarkdownList(MarkdownBlock): + DEFAULT_CSS = """ + + MarkdownList { + width: 1fr; + } + + MarkdownList MarkdownList { + margin: 0; + padding-top: 0; + } + """ + + +class MarkdownBulletList(MarkdownList): + """A Bullet list Markdown block.""" + + DEFAULT_CSS = """ + MarkdownBulletList { + margin: 0 0 1 0; + padding: 0 0; + } + + MarkdownBulletList Horizontal { + height: auto; + width: 1fr; + } + + MarkdownBulletList Vertical { + height: auto; + width: 1fr; + } + """ + + def compose(self) -> ComposeResult: + for block in self._blocks: + if isinstance(block, MarkdownListItem): + bullet = MarkdownBullet() + bullet.symbol = block.bullet + yield Horizontal(bullet, Vertical(*block._blocks)) + self._blocks.clear() + + +class MarkdownOrderedList(MarkdownList): + """An ordered list Markdown block.""" + + DEFAULT_CSS = """ + MarkdownOrderedList { + margin: 0 0 1 0; + padding: 0 0; + } + + MarkdownOrderedList Horizontal { + height: auto; + width: 1fr; + } + + MarkdownOrderedList Vertical { + height: auto; + width: 1fr; + } + """ + + def compose(self) -> ComposeResult: + suffix = ". " + start = 1 + if self._blocks and isinstance(self._blocks[0], MarkdownOrderedListItem): + try: + start = int(self._blocks[0].bullet) + except ValueError: + pass + symbol_size = max( + len(f"{number}{suffix}") + for number, block in enumerate(self._blocks, start) + if isinstance(block, MarkdownListItem) + ) + for number, block in enumerate(self._blocks, start): + if isinstance(block, MarkdownListItem): + bullet = MarkdownBullet() + bullet.symbol = f"{number}{suffix}".rjust(symbol_size + 1) + yield Horizontal(bullet, Vertical(*block._blocks)) + + self._blocks.clear() + + +class MarkdownTableCellContents(Static): + """Widget for table cells. + + A shim over a Static which responds to links. + """ + + async def action_link(self, href: str) -> None: + """Pass a link action on to the MarkdownTable parent.""" + self.post_message(Markdown.LinkClicked(self.query_ancestor(Markdown), href)) + + +class MarkdownTableContent(Widget): + """Renders a Markdown table.""" + + DEFAULT_CSS = """ + MarkdownTableContent { + width: 1fr; + height: auto; + layout: grid; + grid-columns: auto; + grid-rows: auto; + grid-gutter: 1 1; + + & > .cell { + margin: 0 0; + height: auto; + padding: 0 1; + text-overflow: ellipsis; + } + & > .header { + height: auto; + margin: 0 0; + padding: 0 1; + color: $primary; + text-overflow: ellipsis; + content-align: left bottom; + } + keyline: thin $foreground 20%; + } + MarkdownTableContent > .markdown-table--header { + text-style: bold; + } + """ + + COMPONENT_CLASSES = {"markdown-table--header", "markdown-table--lines"} + + def __init__(self, headers: list[Content], rows: list[list[Content]]): + self.headers = headers.copy() + """List of header text.""" + self.rows = rows.copy() + """The row contents.""" + super().__init__() + self.shrink = True + self.last_row = 0 + + def pre_layout(self, layout: Layout) -> None: + assert isinstance(layout, GridLayout) + layout.auto_minimum = True + layout.expand = not self.query_ancestor(MarkdownTable).styles.is_auto_width + layout.shrink = True + layout.stretch_height = True + + def compose(self) -> ComposeResult: + for header in self.headers: + yield MarkdownTableCellContents(header, classes="header").with_tooltip( + header + ) + for row_index, row in enumerate(self.rows, 1): + for cell in row: + yield MarkdownTableCellContents( + cell, classes=f"row{row_index} cell" + ).with_tooltip(cell.plain) + self.last_row = row_index + + def _update_content(self, headers: list[Content], rows: list[list[Content]]): + """Update cell contents.""" + self.headers = headers + self.rows = rows + cells: list[Content] = [ + *self.headers, + *[cell for row in self.rows for cell in row], + ] + for child, updated_cell in zip(self.query(MarkdownTableCellContents), cells): + child.update(updated_cell, layout=False) + + async def _update_rows(self, updated_rows: list[list[Content]]) -> None: + self.styles.grid_size_columns = len(self.headers) + await self.query_children(f".cell.row{self.last_row}").remove() + new_cells: list[Static] = [] + for row_index, row in enumerate(updated_rows, self.last_row): + for cell in row: + new_cells.append( + Static(cell, classes=f"row{row_index} cell").with_tooltip(cell) + ) + self.last_row = row_index + await self.mount_all(new_cells) + + def on_mount(self) -> None: + self.styles.grid_size_columns = len(self.headers) + + async def action_link(self, href: str) -> None: + """Pass a link action on to the MarkdownTable parent.""" + if isinstance(self.parent, MarkdownTable): + await self.parent.action_link(href) + + +class MarkdownTable(MarkdownBlock): + """A Table markdown Block.""" + + DEFAULT_CSS = """ + MarkdownTable { + width: 1fr; + margin-bottom: 1; + &:light { + background: white 30%; + } + } + """ + + def __init__(self, markdown: Markdown, token: Token, *args, **kwargs) -> None: + super().__init__(markdown, token, *args, **kwargs) + self._headers: list[Content] = [] + self._rows: list[list[Content]] = [] + + def compose(self) -> ComposeResult: + headers, rows = self._get_headers_and_rows() + self._headers = headers + self._rows = rows + yield MarkdownTableContent(headers, rows) + + def _get_headers_and_rows(self) -> tuple[list[Content], list[list[Content]]]: + """Get list of headers, and list of rows. + + Returns: + A tuple containing a list of headers, and a list of rows. + """ + + def flatten(block: MarkdownBlock) -> Iterable[MarkdownBlock]: + for block in block._blocks: + if block._blocks: + yield from flatten(block) + yield block + + headers: list[Content] = [] + rows: list[list[Content]] = [] + for block in flatten(self): + if isinstance(block, MarkdownTH): + headers.append(block._content) + elif isinstance(block, MarkdownTR): + rows.append([]) + elif isinstance(block, MarkdownTD): + rows[-1].append(block._content) + if rows and not rows[-1]: + rows.pop() + return headers, rows + + async def _update_from_block(self, block: MarkdownBlock) -> None: + """Special case to update a Markdown table. + + Args: + block: Existing markdown block. + """ + if isinstance(block, MarkdownTable): + try: + table_content = self.query_one(MarkdownTableContent) + except NoMatches: + pass + else: + if table_content.rows: + current_rows = self._rows + _new_headers, new_rows = block._get_headers_and_rows() + updated_rows = new_rows[len(current_rows) - 1 :] + self._rows = new_rows + await table_content._update_rows(updated_rows) + return + await super()._update_from_block(block) + + +class MarkdownTBody(MarkdownBlock): + """A table body Markdown block.""" + + +class MarkdownTHead(MarkdownBlock): + """A table head Markdown block.""" + + +class MarkdownTR(MarkdownBlock): + """A table row Markdown block.""" + + +class MarkdownTH(MarkdownBlock): + """A table header Markdown block.""" + + +class MarkdownTD(MarkdownBlock): + """A table data Markdown block.""" + + +class MarkdownBullet(Widget): + """A bullet widget.""" + + DEFAULT_CSS = """ + MarkdownBullet { + width: auto; + color: $text-primary; + &:light { + color: $text-secondary; + } + } + """ + + symbol = reactive("\u25cf") + """The symbol for the bullet.""" + + def get_selection(self, _selection) -> tuple[str, str] | None: + return self.symbol, " " + + def render(self) -> Content: + return Content(self.symbol) + + +class MarkdownListItem(MarkdownBlock): + """A list item Markdown block.""" + + DEFAULT_CSS = """ + MarkdownListItem { + layout: horizontal; + margin-right: 1; + height: auto; + } + + MarkdownListItem > Vertical { + width: 1fr; + height: auto; + } + """ + + def __init__(self, markdown: Markdown, token: Token, bullet: str) -> None: + self.bullet = bullet + super().__init__(markdown, token) + + +class MarkdownOrderedListItem(MarkdownListItem): + pass + + +class MarkdownUnorderedListItem(MarkdownListItem): + pass + + +class MarkdownFence(MarkdownBlock): + """A fence Markdown block.""" + + DEFAULT_CSS = """ + MarkdownFence { + padding: 0; + margin: 1 0; + overflow: scroll hidden; + scrollbar-size-horizontal: 0; + scrollbar-size-vertical: 0; + width: 1fr; + height: auto; + color: rgb(210,210,210); + background: black 10%; + &:light { + background: white 30%; + } + & > Label { + padding: 1 2; + } + } + """ + + def __init__(self, markdown: Markdown, token: Token, code: str) -> None: + super().__init__(markdown, token) + self.code = code + self.lexer = token.info + self._highlighted_code = self.highlight(self.code, self.lexer) + + @property + def allow_horizontal_scroll(self) -> bool: + return True + + @classmethod + def highlight(cls, code: str, language: str) -> Content: + return highlight(code, language=language or None) + + def _copy_context(self, block: MarkdownBlock) -> None: + if isinstance(block, MarkdownFence): + self.lexer = block.lexer + self._token = block._token + + async def _update_from_block(self, block: MarkdownBlock): + if isinstance(block, MarkdownFence): + self.set_content(block._highlighted_code) + self._copy_context(block) + else: + await super()._update_from_block(block) + + def set_content(self, content: Content) -> None: + self._content = content + with suppress(NoMatches): + self.query_one("#code-content", Label).update(content) + + def compose(self) -> ComposeResult: + yield Label(self._highlighted_code, id="code-content") + + +NUMERALS = " ⅠⅡⅢⅣⅤⅥ" + + +class Markdown(Widget): + DEFAULT_CSS = """ + Markdown { + height: auto; + padding: 0 2 0 2; + layout: vertical; + color: $foreground; + overflow-y: hidden; + + MarkdownBlock { + &:dark > .code_inline { + background: $warning 10%; + color: $text-warning 95%; + } + &:light > .code_inline { + background: $error 5%; + color: $text-error 95%; + } + & > .em { + text-style: italic; + } + & > .strong { + text-style: bold; + } + & > .s { + text-style: strike; + } + } + } + """ + + BULLETS = ["• ", "▪ ", "‣ ", "⭑ ", "◦ "] + """Unicode bullets used for unordered lists.""" + + BLOCKS: dict[str, type[MarkdownBlock]] = { + "h1": MarkdownH1, + "h2": MarkdownH2, + "h3": MarkdownH3, + "h4": MarkdownH4, + "h5": MarkdownH5, + "h6": MarkdownH6, + "hr": MarkdownHorizontalRule, + "paragraph_open": MarkdownParagraph, + "blockquote_open": MarkdownBlockQuote, + "bullet_list_open": MarkdownBulletList, + "ordered_list_open": MarkdownOrderedList, + "list_item_ordered_open": MarkdownOrderedListItem, + "list_item_unordered_open": MarkdownUnorderedListItem, + "table_open": MarkdownTable, + "tbody_open": MarkdownTBody, + "thead_open": MarkdownTHead, + "tr_open": MarkdownTR, + "th_open": MarkdownTH, + "td_open": MarkdownTD, + "fence": MarkdownFence, + "code_block": MarkdownFence, + } + """Mapping of block names on to a widget class.""" + + def __init__( + self, + markdown: str | None = None, + *, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + parser_factory: Callable[[], MarkdownIt] | None = None, + open_links: bool = True, + ): + """A Markdown widget. + + Args: + markdown: String containing Markdown or None to leave blank for now. + name: The name of the widget. + id: The ID of the widget in the DOM. + classes: The CSS classes of the widget. + parser_factory: A factory function to return a configured MarkdownIt instance. If `None`, a "gfm-like" parser is used. + open_links: Open links automatically. If you set this to `False`, you can handle the [`LinkClicked`][textual.widgets.markdown.Markdown.LinkClicked] events. + """ + super().__init__(name=name, id=id, classes=classes) + self._initial_markdown: str | None = markdown + self._markdown = "" + self._parser_factory = parser_factory + self._table_of_contents: TableOfContentsType | None = None + self._open_links = open_links + self._last_parsed_line = 0 + self._theme = "" + + @property + def table_of_contents(self) -> TableOfContentsType: + """The document's table of contents.""" + if self._table_of_contents is None: + self._table_of_contents = [ + (header.LEVEL, header._content.plain, header.id) + for header in self.children + if isinstance(header, MarkdownHeader) + ] + return self._table_of_contents + + class TableOfContentsUpdated(Message): + """The table of contents was updated.""" + + def __init__( + self, markdown: Markdown, table_of_contents: TableOfContentsType + ) -> None: + super().__init__() + self.markdown: Markdown = markdown + """The `Markdown` widget associated with the table of contents.""" + self.table_of_contents: TableOfContentsType = table_of_contents + """Table of contents.""" + + @property + def control(self) -> Markdown: + """The `Markdown` widget associated with the table of contents. + + This is an alias for [`TableOfContentsUpdated.markdown`][textual.widgets.Markdown.TableOfContentsSelected.markdown] + and is used by the [`on`][textual.on] decorator. + """ + return self.markdown + + class TableOfContentsSelected(Message): + """An item in the TOC was selected.""" + + def __init__(self, markdown: Markdown, block_id: str) -> None: + super().__init__() + self.markdown: Markdown = markdown + """The `Markdown` widget where the selected item is.""" + self.block_id: str = block_id + """ID of the block that was selected.""" + + @property + def control(self) -> Markdown: + """The `Markdown` widget where the selected item is. + + This is an alias for [`TableOfContentsSelected.markdown`][textual.widgets.Markdown.TableOfContentsSelected.markdown] + and is used by the [`on`][textual.on] decorator. + """ + return self.markdown + + class LinkClicked(Message): + """A link in the document was clicked.""" + + def __init__(self, markdown: Markdown, href: str) -> None: + super().__init__() + self.markdown: Markdown = markdown + """The `Markdown` widget containing the link clicked.""" + self.href: str = unquote(href) + """The link that was selected.""" + + @property + def control(self) -> Markdown: + """The `Markdown` widget containing the link clicked. + + This is an alias for [`LinkClicked.markdown`][textual.widgets.Markdown.LinkClicked.markdown] + and is used by the [`on`][textual.on] decorator. + """ + return self.markdown + + @property + def source(self) -> str: + """The markdown source.""" + return self._markdown or "" + + def get_block_class(self, block_name: str) -> type[MarkdownBlock]: + """Get the block widget class. + + Args: + block_name: Name of the block. + + Returns: + A MarkdownBlock class + """ + return self.BLOCKS[block_name] + + async def _on_mount(self, _: Mount) -> None: + initial_markdown = self._initial_markdown + self._initial_markdown = None + await self.update(initial_markdown or "") + + if initial_markdown is None: + self.post_message( + Markdown.TableOfContentsUpdated( + self, self._table_of_contents + ).set_sender(self) + ) + + @classmethod + def get_stream(cls, markdown: Markdown) -> MarkdownStream: + """Get a [MarkdownStream][textual.widgets.markdown.MarkdownStream] instance to stream Markdown in the background. + + If you append to the Markdown document many times a second, it is possible the widget won't + be able to update as fast as you write (occurs around 20 appends per second). It will still + work, but the user will have to wait for the UI to catch up after the document has be retrieved. + + Using a [MarkdownStream][textual.widgets.markdown.MarkdownStream] will combine several updates in to one + as necessary to keep up with the incoming data. + + example: + ```python + # self.get_chunk is a hypothetical method that retrieves a + # markdown fragment from the network + @work + async def stream_markdown(self) -> None: + markdown_widget = self.query_one(Markdown) + container = self.query_one(VerticalScroll) + container.anchor() + + stream = Markdown.get_stream(markdown_widget) + try: + while (chunk:= await self.get_chunk()) is not None: + await stream.write(chunk) + finally: + await stream.stop() + ``` + + + Args: + markdown: A [Markdown][textual.widgets.Markdown] widget instance. + + Returns: + The Markdown stream object. + """ + updater = MarkdownStream(markdown) + updater.start() + return updater + + def on_markdown_link_clicked(self, event: LinkClicked) -> None: + if self._open_links: + self.app.open_url(event.href) + + @staticmethod + def sanitize_location(location: str) -> tuple[Path, str]: + """Given a location, break out the path and any anchor. + + Args: + location: The location to sanitize. + + Returns: + A tuple of the path to the location cleaned of any anchor, plus + the anchor (or an empty string if none was found). + """ + location, _, anchor = location.partition("#") + return Path(location), anchor + + def goto_anchor(self, anchor: str) -> bool: + """Try and find the given anchor in the current document. + + Args: + anchor: The anchor to try and find. + + Note: + The anchor is found by looking at all of the headings in the + document and finding the first one whose slug matches the + anchor. + + Note that the slugging method used is similar to that found on + GitHub. + + Returns: + True when the anchor was found in the current document, False otherwise. + """ + if not self._table_of_contents or not isinstance(self.parent, Widget): + return False + unique = TrackedSlugs() + for _, title, header_id in self._table_of_contents: + if unique.slug(title) == anchor: + self.query_one(f"#{header_id}").scroll_visible(top=True) + return True + return False + + async def load(self, path: Path) -> None: + """Load a new Markdown document. + + Args: + path: Path to the document. + + Raises: + OSError: If there was some form of error loading the document. + + Note: + The exceptions that can be raised by this method are all of + those that can be raised by calling [`Path.read_text`][pathlib.Path.read_text]. + """ + path, anchor = self.sanitize_location(str(path)) + data = await asyncio.get_running_loop().run_in_executor( + None, partial(path.read_text, encoding="utf-8") + ) + await self.update(data) + if anchor: + self.goto_anchor(anchor) + + def unhandled_token(self, token: Token) -> MarkdownBlock | None: + """Process an unhandled token. + + Args: + token: The MarkdownIt token to handle. + + Returns: + Either a widget to be added to the output, or `None`. + """ + return None + + def _parse_markdown(self, tokens: Iterable[Token]) -> Iterable[MarkdownBlock]: + """Create a stream of MarkdownBlock widgets from markdown. + + Args: + tokens: List of tokens. + + Yields: + Widgets for mounting. + """ + + stack: list[MarkdownBlock] = [] + stack_append = stack.append + + get_block_class = self.get_block_class + + for token in tokens: + token_type = token.type + if token_type == "heading_open": + stack_append(get_block_class(token.tag)(self, token)) + elif token_type == "hr": + yield get_block_class("hr")(self, token) + elif token_type == "paragraph_open": + stack_append(get_block_class("paragraph_open")(self, token)) + elif token_type == "blockquote_open": + stack_append(get_block_class("blockquote_open")(self, token)) + elif token_type == "bullet_list_open": + stack_append(get_block_class("bullet_list_open")(self, token)) + elif token_type == "ordered_list_open": + stack_append(get_block_class("ordered_list_open")(self, token)) + elif token_type == "list_item_open": + if token.info: + stack_append( + get_block_class("list_item_ordered_open")( + self, token, token.info + ) + ) + else: + item_count = sum( + 1 + for block in stack + if isinstance(block, MarkdownUnorderedListItem) + ) + stack_append( + get_block_class("list_item_unordered_open")( + self, + token, + self.BULLETS[item_count % len(self.BULLETS)], + ) + ) + elif token_type == "table_open": + stack_append(get_block_class("table_open")(self, token)) + elif token_type == "tbody_open": + stack_append(get_block_class("tbody_open")(self, token)) + elif token_type == "thead_open": + stack_append(get_block_class("thead_open")(self, token)) + elif token_type == "tr_open": + stack_append(get_block_class("tr_open")(self, token)) + elif token_type == "th_open": + stack_append(get_block_class("th_open")(self, token)) + elif token_type == "td_open": + stack_append(get_block_class("td_open")(self, token)) + elif token_type.endswith("_close"): + block = stack.pop() + if token.type == "heading_close": + block.id = ( + f"heading-{slug_for_tcss_id(block._content.plain)}-{id(block)}" + ) + if stack: + stack[-1]._blocks.append(block) + else: + yield block + elif token_type == "inline": + stack[-1].build_from_token(token) + elif token_type in ("fence", "code_block"): + fence_class = get_block_class(token_type) + assert issubclass(fence_class, MarkdownFence) + fence = fence_class(self, token, token.content.rstrip()) + if stack: + stack[-1]._blocks.append(fence) + else: + yield fence + else: + external = self.unhandled_token(token) + if external is not None: + if stack: + stack[-1]._blocks.append(external) + else: + yield external + + def _build_from_source(self, markdown: str) -> list[MarkdownBlock]: + """Build blocks from markdown source. + + Args: + markdown: A Markdown document, or partial document. + + Returns: + A list of MarkdownBlock instances. + """ + parser = ( + MarkdownIt("gfm-like") + if self._parser_factory is None + else self._parser_factory() + ) + tokens = parser.parse(markdown) + return list(self._parse_markdown(tokens)) + + def update(self, markdown: str) -> AwaitComplete: + """Update the document with new Markdown. + + Args: + markdown: A string containing Markdown. + + Returns: + An optionally awaitable object. Await this to ensure that all children have been mounted. + """ + self._theme = self.app.theme + parser = ( + MarkdownIt("gfm-like") + if self._parser_factory is None + else self._parser_factory() + ) + + markdown_block = self.query("MarkdownBlock") + self._markdown = markdown + self._table_of_contents = None + + async def await_update() -> None: + """Update in batches.""" + BATCH_SIZE = 200 + batch: list[MarkdownBlock] = [] + + # Lock so that you can't update with more than one document simultaneously + async with self.lock: + tokens = await asyncio.get_running_loop().run_in_executor( + None, parser.parse, markdown + ) + + # Remove existing blocks for the first batch only + removed: bool = False + + async def mount_batch(batch: list[MarkdownBlock]) -> None: + """Mount a single match of blocks. + + Args: + batch: A list of blocks to mount. + """ + nonlocal removed + if removed: + await self.mount_all(batch) + else: + with self.app.batch_update(): + await markdown_block.remove() + await self.mount_all(batch) + removed = True + + for block in self._parse_markdown(tokens): + batch.append(block) + if len(batch) == BATCH_SIZE: + await mount_batch(batch) + batch.clear() + if batch: + await mount_batch(batch) + if not removed: + await markdown_block.remove() + + lines = markdown.splitlines() + self._last_parsed_line = len(lines) - (1 if lines and lines[-1] else 0) + self.post_message( + Markdown.TableOfContentsUpdated( + self, self.table_of_contents + ).set_sender(self) + ) + + return AwaitComplete(await_update()) + + def append(self, markdown: str) -> AwaitComplete: + """Append to markdown. + + Args: + markdown: A fragment of markdown to be appended. + + Returns: + An optionally awaitable object. Await this to ensure that the markdown has been append by the next line. + """ + parser = ( + MarkdownIt("gfm-like") + if self._parser_factory is None + else self._parser_factory() + ) + + self._markdown = self.source + markdown + updated_source = "".join( + self._markdown.splitlines(keepends=True)[self._last_parsed_line :] + ) + + async def await_append() -> None: + """Append new markdown widgets.""" + async with self.lock: + tokens = parser.parse(updated_source) + existing_blocks = [ + child for child in self.children if isinstance(child, MarkdownBlock) + ] + start_line = self._last_parsed_line + for token in reversed(tokens): + if token.map is not None and token.level == 0: + self._last_parsed_line += token.map[0] + break + + new_blocks = list(self._parse_markdown(tokens)) + any_headers = any( + isinstance(block, MarkdownHeader) for block in new_blocks + ) + for block in new_blocks: + start, end = block.source_range + block.source_range = ( + start + start_line, + end + start_line, + ) + + with self.app.batch_update(): + if existing_blocks and new_blocks: + last_block = existing_blocks[-1] + last_block.source_range = new_blocks[0].source_range + try: + await last_block._update_from_block(new_blocks[0]) + except IndexError: + pass + else: + new_blocks = new_blocks[1:] + + if new_blocks: + await self.mount_all(new_blocks) + + if any_headers: + self._table_of_contents = None + self.post_message( + Markdown.TableOfContentsUpdated( + self, self.table_of_contents + ).set_sender(self) + ) + + return AwaitComplete(await_append()) + + +class MarkdownTableOfContents(Widget, can_focus_children=True): + """Displays a table of contents for a markdown document.""" + + DEFAULT_CSS = """ + MarkdownTableOfContents { + width: auto; + height: 1fr; + background: $panel; + &:focus-within { + background-tint: $foreground 5%; + } + } + MarkdownTableOfContents > Tree { + padding: 1; + width: auto; + height: 1fr; + background: $panel; + } + """ + + table_of_contents = reactive[Optional[TableOfContentsType]](None, init=False) + """Underlying data to populate the table of contents widget.""" + + def __init__( + self, + markdown: Markdown, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + ) -> None: + """Initialize a table of contents. + + Args: + markdown: The Markdown document associated with this table of contents. + name: The name of the widget. + id: The ID of the widget in the DOM. + classes: The CSS classes for the widget. + disabled: Whether the widget is disabled or not. + """ + self.markdown: Markdown = markdown + """The Markdown document associated with this table of contents.""" + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + + def compose(self) -> ComposeResult: + tree: Tree = Tree("TOC") + tree.show_root = False + tree.show_guides = True + tree.guide_depth = 4 + tree.auto_expand = False + yield tree + + def watch_table_of_contents(self, table_of_contents: TableOfContentsType) -> None: + """Triggered when the table of contents changes.""" + self.rebuild_table_of_contents(table_of_contents) + + def rebuild_table_of_contents(self, table_of_contents: TableOfContentsType) -> None: + """Rebuilds the tree representation of the table of contents data. + + Args: + table_of_contents: Table of contents. + """ + tree = self.query_one(Tree) + tree.clear() + root = tree.root + for level, name, block_id in table_of_contents: + node = root + for _ in range(level - 1): + if node._children: + node = node._children[-1] + node.expand() + node.allow_expand = True + else: + node = node.add(NUMERALS[level], expand=True) + node_label = Text.assemble((f"{NUMERALS[level]} ", "dim"), name) + node.add_leaf(node_label, {"block_id": block_id}) + + async def _on_tree_node_selected(self, message: Tree.NodeSelected) -> None: + node_data = message.node.data + if node_data is not None: + await self._post_message( + Markdown.TableOfContentsSelected(self.markdown, node_data["block_id"]) + ) + message.stop() + + +class MarkdownViewer(VerticalScroll, can_focus=False, can_focus_children=True): + """A Markdown viewer widget.""" + + SCOPED_CSS = False + + DEFAULT_CSS = """ + MarkdownViewer { + height: 1fr; + scrollbar-gutter: stable; + background: $surface; + & > MarkdownTableOfContents { + display: none; + dock:left; + } + } + + MarkdownViewer.-show-table-of-contents > MarkdownTableOfContents { + display: block; + } + """ + + show_table_of_contents = reactive(True) + """Show the table of contents?""" + top_block = reactive("") + + navigator: var[Navigator] = var(Navigator) + + class NavigatorUpdated(Message): + """Navigator has been changed (clicked link etc).""" + + def __init__( + self, + markdown: str | None = None, + *, + show_table_of_contents: bool = True, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + parser_factory: Callable[[], MarkdownIt] | None = None, + open_links: bool = True, + ): + """Create a Markdown Viewer object. + + Args: + markdown: String containing Markdown, or None to leave blank. + show_table_of_contents: Show a table of contents in a sidebar. + name: The name of the widget. + id: The ID of the widget in the DOM. + classes: The CSS classes of the widget. + parser_factory: A factory function to return a configured MarkdownIt instance. If `None`, a "gfm-like" parser is used. + open_links: Open links automatically. If you set this to `False`, you can handle the [`LinkClicked`][textual.widgets.markdown.Markdown.LinkClicked] events. + """ + super().__init__(name=name, id=id, classes=classes) + self.show_table_of_contents = show_table_of_contents + self._markdown = markdown + self._parser_factory = parser_factory + self._open_links = open_links + + @property + def document(self) -> Markdown: + """The [`Markdown`][textual.widgets.Markdown] document widget.""" + return self.query_one(Markdown) + + @property + def table_of_contents(self) -> MarkdownTableOfContents: + """The [table of contents][textual.widgets.markdown.MarkdownTableOfContents] widget.""" + return self.query_one(MarkdownTableOfContents) + + async def _on_mount(self, _: Mount) -> None: + await self.document.update(self._markdown or "") + + async def go(self, location: str | PurePath) -> None: + """Navigate to a new document path.""" + path, anchor = self.document.sanitize_location(str(location)) + if path == Path(".") and anchor: + # We've been asked to go to an anchor but with no file specified. + self.document.goto_anchor(anchor) + else: + # We've been asked to go to a file, optionally with an anchor. + await self.document.load(self.navigator.go(location)) + self.post_message(self.NavigatorUpdated()) + + async def back(self) -> None: + """Go back one level in the history.""" + if self.navigator.back(): + await self.document.load(self.navigator.location) + self.post_message(self.NavigatorUpdated()) + + async def forward(self) -> None: + """Go forward one level in the history.""" + if self.navigator.forward(): + await self.document.load(self.navigator.location) + self.post_message(self.NavigatorUpdated()) + + async def _on_markdown_link_clicked(self, message: Markdown.LinkClicked) -> None: + message.stop() + await self.go(message.href) + + def watch_show_table_of_contents(self, show_table_of_contents: bool) -> None: + self.set_class(show_table_of_contents, "-show-table-of-contents") + + def compose(self) -> ComposeResult: + markdown = Markdown( + parser_factory=self._parser_factory, open_links=self._open_links + ) + markdown.can_focus = True + yield markdown + yield MarkdownTableOfContents(markdown) + + def _on_markdown_table_of_contents_updated( + self, message: Markdown.TableOfContentsUpdated + ) -> None: + self.query_one(MarkdownTableOfContents).table_of_contents = ( + message.table_of_contents + ) + message.stop() + + def _on_markdown_table_of_contents_selected( + self, message: Markdown.TableOfContentsSelected + ) -> None: + block_selector = f"#{message.block_id}" + block = self.query_one(block_selector, MarkdownBlock) + self.scroll_to_widget(block, top=True) + message.stop() diff --git a/src/memray/_vendor/textual/widgets/_markdown_viewer.py b/src/memray/_vendor/textual/widgets/_markdown_viewer.py new file mode 100644 index 0000000000..2d32175867 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_markdown_viewer.py @@ -0,0 +1,3 @@ +from memray._vendor.textual.widgets._markdown import MarkdownViewer + +__all__ = ["MarkdownViewer"] diff --git a/src/memray/_vendor/textual/widgets/_masked_input.py b/src/memray/_vendor/textual/widgets/_masked_input.py new file mode 100644 index 0000000000..27cdf01b8e --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_masked_input.py @@ -0,0 +1,707 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass +from enum import Flag, auto +from typing import TYPE_CHECKING, Iterable, Pattern + +from rich.console import RenderableType +from rich.segment import Segment +from rich.text import Text +from typing_extensions import Literal + +from memray._vendor.textual import events +from memray._vendor.textual.strip import Strip + +if TYPE_CHECKING: + pass + +from memray._vendor.textual.reactive import Reactive, var +from memray._vendor.textual.validation import ValidationResult, Validator +from memray._vendor.textual.widgets._input import Input + +InputValidationOn = Literal["blur", "changed", "submitted"] +"""Possible messages that trigger input validation.""" + + +class _CharFlags(Flag): + """Misc flags for a single template character definition""" + + NONE = 0 + """Empty flags value""" + + REQUIRED = auto() + """Is this character required for validation?""" + + SEPARATOR = auto() + """Is this character a separator?""" + + UPPERCASE = auto() + """Char is forced to be uppercase""" + + LOWERCASE = auto() + """Char is forced to be lowercase""" + + +_TEMPLATE_CHARACTERS = { + "A": (r"[A-Za-z]", _CharFlags.REQUIRED), + "a": (r"[A-Za-z]", None), + "N": (r"[A-Za-z0-9]", _CharFlags.REQUIRED), + "n": (r"[A-Za-z0-9]", None), + "X": (r"[^ ]", _CharFlags.REQUIRED), + "x": (r"[^ ]", None), + "9": (r"[0-9]", _CharFlags.REQUIRED), + "0": (r"[0-9]", None), + "D": (r"[1-9]", _CharFlags.REQUIRED), + "d": (r"[1-9]", None), + "#": (r"[0-9+\-]", None), + "H": (r"[A-Fa-f0-9]", _CharFlags.REQUIRED), + "h": (r"[A-Fa-f0-9]", None), + "B": (r"[0-1]", _CharFlags.REQUIRED), + "b": (r"[0-1]", None), +} + + +class _Template(Validator): + """Template mask enforcer.""" + + @dataclass + class CharDefinition: + """Holds data for a single char of the template mask.""" + + pattern: Pattern[str] + """Compiled regular expression to check for matches.""" + + flags: _CharFlags = _CharFlags.NONE + """Flags defining special behaviors""" + + char: str = "" + """Mask character (separator or blank or placeholder)""" + + def __init__(self, input: Input, template_str: str) -> None: + """Initialise the mask enforcer, which is also a subclass of `Validator`. + + Args: + input: The `MaskedInput` that owns this object. + template_str: Template string controlling masked input behavior. + """ + self.input = input + self.template: list[_Template.CharDefinition] = [] + self.blank: str = " " + escaped = False + flags = _CharFlags.NONE + template_chars: list[str] = list(template_str) + + while template_chars: + char = template_chars.pop(0) + if escaped: + char_definition = self.CharDefinition( + re.compile(re.escape(char)), _CharFlags.SEPARATOR, char + ) + escaped = False + else: + if char == "\\": + escaped = True + continue + elif char == ";": + break + + new_flags = { + ">": _CharFlags.UPPERCASE, + "<": _CharFlags.LOWERCASE, + "!": _CharFlags.NONE, + }.get(char, None) + if new_flags is not None: + flags = new_flags + continue + + pattern, required_flag = _TEMPLATE_CHARACTERS.get(char, (None, None)) + if pattern: + char_flags = ( + _CharFlags.REQUIRED if required_flag else _CharFlags.NONE + ) + char_definition = self.CharDefinition( + re.compile(pattern), char_flags + ) + else: + char_definition = self.CharDefinition( + re.compile(re.escape(char)), _CharFlags.SEPARATOR, char + ) + + char_definition.flags |= flags + self.template.append(char_definition) + + if template_chars: + self.blank = template_chars[0] + + if all( + (_CharFlags.SEPARATOR in char_definition.flags) + for char_definition in self.template + ): + raise ValueError( + "Template must contain at least one non-separator character" + ) + + self.update_mask(input.placeholder) + + def validate(self, value: str) -> ValidationResult: + """Checks if `value` matches this template, always returning a ValidationResult. + + Args: + value: The string value to be validated. + + Returns: + A ValidationResult with the validation outcome. + + """ + if self.check(value.ljust(len(self.template), chr(0)), False): + return self.success() + else: + return self.failure("Value does not match template!", value) + + def check(self, value: str, allow_space: bool) -> bool: + """Checks if `value matches this template, but returns result as a bool. + + Args: + value: The string value to be validated. + allow_space: Consider space character in `value` as valid. + + Returns: + True if `value` is valid for this template, False otherwise. + """ + for char, char_definition in zip(value, self.template): + if ( + (_CharFlags.REQUIRED in char_definition.flags) + and (not char_definition.pattern.match(char)) + and ((char != " ") or not allow_space) + ): + return False + return True + + def insert_separators(self, value: str, cursor_position: int) -> tuple[str, int]: + """Automatically inserts separators in `value` at `cursor_position` if expected, eventually advancing + the current cursor position. + + Args: + value: Current control value entered by user. + cursor_position: Where to start inserting separators (if any). + + Returns: + A tuple in the form `(value, cursor_position)` with new value and possibly advanced cursor position. + """ + while cursor_position < len(self.template) and ( + _CharFlags.SEPARATOR in self.template[cursor_position].flags + ): + value = ( + value[:cursor_position] + + self.template[cursor_position].char + + value[cursor_position + 1 :] + ) + cursor_position += 1 + return value, cursor_position + + def insert_text_at_cursor(self, text: str) -> str | None: + """Inserts `text` at current cursor position. If not present in `text`, any expected separator is automatically + inserted at the correct position. + + Args: + text: The text to be inserted. + + Returns: + A tuple in the form `(value, cursor_position)` with the new control value and current cursor position if + `text` matches the template, None otherwise. + """ + value = self.input.value + cursor_position = self.input.cursor_position + separators = set( + [ + char_definition.char + for char_definition in self.template + if _CharFlags.SEPARATOR in char_definition.flags + ] + ) + for char in text: + if char in separators: + if char == self.next_separator(cursor_position): + prev_position = self.prev_separator_position(cursor_position) + if (cursor_position > 0) and (prev_position != cursor_position - 1): + next_position = self.next_separator_position(cursor_position) + while cursor_position < next_position + 1: + if ( + _CharFlags.SEPARATOR + in self.template[cursor_position].flags + ): + char = self.template[cursor_position].char + else: + char = " " + value = ( + value[:cursor_position] + + char + + value[cursor_position + 1 :] + ) + cursor_position += 1 + continue + if cursor_position >= len(self.template): + break + char_definition = self.template[cursor_position] + assert _CharFlags.SEPARATOR not in char_definition.flags + if not char_definition.pattern.match(char): + return None + if _CharFlags.LOWERCASE in char_definition.flags: + char = char.lower() + elif _CharFlags.UPPERCASE in char_definition.flags: + char = char.upper() + value = value[:cursor_position] + char + value[cursor_position + 1 :] + cursor_position += 1 + value, cursor_position = self.insert_separators(value, cursor_position) + return value, cursor_position + + def move_cursor(self, delta: int) -> None: + """Moves the cursor position by `delta` characters, skipping separators if + running over them. + + Args: + delta: The number of characters to move; positive moves right, negative + moves left. + """ + cursor_position = self.input.cursor_position + if delta < 0 and all( + [ + (_CharFlags.SEPARATOR in char_definition.flags) + for char_definition in self.template[:cursor_position] + ] + ): + return + cursor_position += delta + while ( + (cursor_position >= 0) + and (cursor_position < len(self.template)) + and (_CharFlags.SEPARATOR in self.template[cursor_position].flags) + ): + cursor_position += delta + self.input.cursor_position = cursor_position + + def delete_at_position(self, position: int | None = None) -> None: + """Deletes character at `position`. + + Args: + position: Position within the control value where to delete a character; + if None the current cursor position is used. + """ + value = self.input.value + if position is None: + position = self.input.cursor_position + cursor_position = position + if cursor_position < len(self.template): + assert _CharFlags.SEPARATOR not in self.template[cursor_position].flags + if cursor_position == len(value) - 1: + value = value[:cursor_position] + else: + value = value[:cursor_position] + " " + value[cursor_position + 1 :] + pos = len(value) + while pos > 0: + char_definition = self.template[pos - 1] + if (_CharFlags.SEPARATOR not in char_definition.flags) and ( + value[pos - 1] != " " + ): + break + pos -= 1 + value = value[:pos] + if cursor_position > len(value): + cursor_position = len(value) + value, cursor_position = self.insert_separators(value, cursor_position) + self.input.cursor_position = cursor_position + self.input.value = value + + def at_separator(self, position: int | None = None) -> bool: + """Checks if character at `position` is a separator. + + Args: + position: Position within the control value where to check; + if None the current cursor position is used. + + Returns: + True if character is a separator, False otherwise. + """ + if position is None: + position = self.input.cursor_position + if (position >= 0) and (position < len(self.template)): + return _CharFlags.SEPARATOR in self.template[position].flags + else: + return False + + def prev_separator_position(self, position: int | None = None) -> int | None: + """Obtains the position of the previous separator character starting from + `position` within the template string. + + Args: + position: Starting position from which to search previous separator. + If None, current cursor position is used. + + Returns: + The position of the previous separator, or None if no previous + separator is found. + """ + if position is None: + position = self.input.cursor_position + for index in range(position - 1, 0, -1): + if _CharFlags.SEPARATOR in self.template[index].flags: + return index + else: + return None + + def next_separator_position(self, position: int | None = None) -> int | None: + """Obtains the position of the next separator character starting from + `position` within the template string. + + Args: + position: Starting position from which to search next separator. + If None, current cursor position is used. + + Returns: + The position of the next separator, or None if no next + separator is found. + """ + if position is None: + position = self.input.cursor_position + for index in range(position + 1, len(self.template)): + if _CharFlags.SEPARATOR in self.template[index].flags: + return index + else: + return None + + def next_separator(self, position: int | None = None) -> str | None: + """Obtains the next separator character starting from `position` + within the template string. + + Args: + position: Starting position from which to search next separator. + If None, current cursor position is used. + + Returns: + The next separator character, or None if no next + separator is found. + """ + position = self.next_separator_position(position) + if position is None: + return None + else: + return self.template[position].char + + def display(self, value: str) -> str: + """Returns `value` ready for display, with spaces replaced by + placeholder characters. + + Args: + value: String value to display. + + Returns: + New string value with spaces replaced by placeholders. + """ + result = [] + for char, char_definition in zip(value, self.template): + if char == " ": + char = char_definition.char + result.append(char) + return "".join(result) + + def update_mask(self, placeholder: str) -> None: + """Updates template placeholder characters from `placeholder`. If + given string is smaller than template string, template blank character + is used to fill remaining template placeholder characters. + + Args: + placeholder: New placeholder string. + """ + for index, char_definition in enumerate(self.template): + if _CharFlags.SEPARATOR not in char_definition.flags: + if index < len(placeholder): + char_definition.char = placeholder[index] + else: + char_definition.char = self.blank + + @property + def mask(self) -> str: + """Property returning the template placeholder mask.""" + return "".join([char_definition.char for char_definition in self.template]) + + @property + def empty_mask(self) -> str: + """Property returning the template placeholder mask with all non-separators replaced by space.""" + return "".join( + [ + ( + " " + if (_CharFlags.SEPARATOR not in char_definition.flags) + else char_definition.char + ) + for char_definition in self.template + ] + ) + + +class MaskedInput(Input, can_focus=True): + """A masked text input widget.""" + + template: Reactive[str] = var("") + """Input template mask currently in use.""" + + def __init__( + self, + template: str, + value: str | None = None, + placeholder: str = "", + *, + validators: Validator | Iterable[Validator] | None = None, + validate_on: Iterable[InputValidationOn] | None = None, + valid_empty: bool = False, + select_on_focus: bool = True, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + tooltip: RenderableType | None = None, + compact: bool = False, + ) -> None: + """Initialise the `MaskedInput` widget. + + Args: + template: Template string. + value: An optional default value for the input. + placeholder: Optional placeholder text for the input. + validators: An iterable of validators that the MaskedInput value will be checked against. + validate_on: Zero or more of the values "blur", "changed", and "submitted", + which determine when to do input validation. The default is to do + validation for all messages. + valid_empty: Empty values are valid. + name: Optional name for the masked input widget. + id: Optional ID for the widget. + classes: Optional initial classes for the widget. + disabled: Whether the input is disabled or not. + tooltip: Optional tooltip. + compact: Enable compact style (without borders). + """ + self._template: _Template = None + super().__init__( + placeholder=placeholder, + validators=validators, + validate_on=validate_on, + valid_empty=valid_empty, + select_on_focus=select_on_focus, + name=name, + id=id, + classes=classes, + disabled=disabled, + compact=compact, + ) + + self._template = _Template(self, template) + self.template = template + + value, _ = self._template.insert_separators(value or "", 0) + self.value = value + if tooltip is not None: + self.tooltip = tooltip + + def validate_value(self, value: str) -> str: + """Validates value against template.""" + if self._template is None: + return value + if not self._template.check(value, True): + raise ValueError("Value does not match template!") + return value[: len(self._template.mask)] + + def _watch_template(self, template: str) -> None: + """Revalidate when template changes.""" + self._template = _Template(self, template) if template else None + if self.is_mounted: + self._watch_value(self.value) + + def _watch_placeholder(self, placeholder: str) -> None: + """Update template display mask when placeholder changes.""" + if self._template is not None: + self._template.update_mask(placeholder) + self.refresh() + + def validate(self, value: str) -> ValidationResult | None: + """Run all the validators associated with this MaskedInput on the supplied value. + + Same as `Input.validate()` but also validates against template which acts as an + additional implicit validator. + + Returns: + A ValidationResult indicating whether *all* validators succeeded or not. + That is, if *any* validator fails, the result will be an unsuccessful + validation. + """ + + def set_classes() -> None: + """Set classes for valid flag.""" + valid = self._valid + self.set_class(not valid, "-invalid") + self.set_class(valid, "-valid") + + result = super().validate(value) + validation_results: list[ValidationResult] = [self._template.validate(value)] + if result is not None: + validation_results.append(result) + combined_result = ValidationResult.merge(validation_results) + self._valid = combined_result.is_valid + set_classes() + + return combined_result + + def render_line(self, y: int) -> Strip: + if y != 0: + return Strip.blank(self.size.width, self.rich_style) + + result = self._value + width = self.content_size.width + + # Add the completion with a faded style. + value = self.value + value_length = len(value) + template = self._template + style = self.get_component_rich_style("input--placeholder") + result += Text( + template.mask[value_length:], + style, + end="", + ) + for index, (char, char_definition) in enumerate(zip(value, template.template)): + if char == " ": + result.stylize(style, index, index + 1) + + if self._cursor_visible and self.has_focus: + if self.cursor_at_end: + result.pad_right(1) + cursor_style = self.get_component_rich_style("input--cursor") + cursor = self.cursor_position + result.stylize(cursor_style, cursor, cursor + 1) + + segments = list(result.render(self.app.console)) + line_length = Segment.get_line_length(segments) + if line_length < width: + segments = Segment.adjust_line_length(segments, width) + line_length = width + + strip = Strip(segments).crop(self.scroll_offset.x, self.scroll_offset.x + width) + return strip.apply_style(self.rich_style) + + @property + def _value(self) -> Text: + """Value rendered as text.""" + value = self._template.display(self.value) + return Text(value, no_wrap=True, overflow="ignore", end="") + + async def _on_click(self, event: events.Click) -> None: + """Ensure clicking on value does not leave cursor on a separator.""" + await super()._on_click(event) + if self._template.at_separator(): + self._template.move_cursor(1) + + def insert_text_at_cursor(self, text: str) -> None: + """Insert new text at the cursor, move the cursor to the end of the new text. + + Args: + text: New text to insert. + """ + + new_value = self._template.insert_text_at_cursor(text) + if new_value is not None: + self.value, self.cursor_position = new_value + else: + self.restricted() + + def clear(self) -> None: + """Clear the masked input.""" + self.value, self.cursor_position = self._template.insert_separators("", 0) + + def action_cursor_left(self) -> None: + """Move the cursor one position to the left; separators are skipped.""" + self._template.move_cursor(-1) + + def action_cursor_right(self) -> None: + """Move the cursor one position to the right; separators are skipped.""" + self._template.move_cursor(1) + + def action_home(self) -> None: + """Move the cursor to the start of the input.""" + self._template.move_cursor(-len(self.template)) + + def action_cursor_left_word(self) -> None: + """Move the cursor left next to the previous separator. If no previous + separator is found, moves the cursor to the start of the input.""" + if self._template.at_separator(self.cursor_position - 1): + position = self._template.prev_separator_position(self.cursor_position - 1) + else: + position = self._template.prev_separator_position() + if position: + position += 1 + self.cursor_position = position or 0 + + def action_cursor_right_word(self) -> None: + """Move the cursor right next to the next separator. If no next + separator is found, moves the cursor to the end of the input.""" + position = self._template.next_separator_position() + if position is None: + self.cursor_position = len(self._template.mask) + else: + self.cursor_position = position + 1 + + def action_delete_right(self) -> None: + """Delete one character at the current cursor position.""" + self._template.delete_at_position() + + def action_delete_right_word(self) -> None: + """Delete the current character and all rightward to next separator or + the end of the input.""" + position = self._template.next_separator_position() + if position is not None: + position += 1 + else: + position = len(self.value) + for index in range(self.cursor_position, position): + self.cursor_position = index + if not self._template.at_separator(): + self._template.delete_at_position() + + def action_delete_left(self) -> None: + """Delete one character to the left of the current cursor position.""" + if self.cursor_position <= 0: + # Cursor at the start, so nothing to delete + return + self._template.move_cursor(-1) + self._template.delete_at_position() + + def action_delete_left_word(self) -> None: + """Delete leftward of the cursor position to the previous separator or + the start of the input.""" + if self.cursor_position <= 0: + return + if self._template.at_separator(self.cursor_position - 1): + position = self._template.prev_separator_position(self.cursor_position - 1) + else: + position = self._template.prev_separator_position() + if position: + position += 1 + else: + position = 0 + for index in range(position, self.cursor_position): + self.cursor_position = index + if not self._template.at_separator(): + self._template.delete_at_position() + self.cursor_position = position + + def action_delete_left_all(self) -> None: + """Delete all characters to the left of the cursor position.""" + if self.cursor_position > 0: + cursor_position = self.cursor_position + if cursor_position >= len(self.value): + self.value = "" + else: + self.value = ( + self._template.empty_mask[:cursor_position] + + self.value[cursor_position:] + ) + self.cursor_position = 0 diff --git a/src/memray/_vendor/textual/widgets/_option_list.py b/src/memray/_vendor/textual/widgets/_option_list.py new file mode 100644 index 0000000000..e345632c5f --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_option_list.py @@ -0,0 +1,1039 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, ClassVar, Iterable, Sequence, cast + +import rich.repr +from rich.segment import Segment + +from memray._vendor.textual import _widget_navigation, events +from memray._vendor.textual._loop import loop_last +from memray._vendor.textual.binding import Binding, BindingType +from memray._vendor.textual.cache import LRUCache +from memray._vendor.textual.css.styles import RulesMap +from memray._vendor.textual.geometry import Region, Size, Spacing, clamp +from memray._vendor.textual.message import Message +from memray._vendor.textual.reactive import reactive +from memray._vendor.textual.scroll_view import ScrollView +from memray._vendor.textual.strip import Strip +from memray._vendor.textual.style import Style +from memray._vendor.textual.visual import Padding, Visual, VisualType, visualize + +if TYPE_CHECKING: + from typing_extensions import Self, TypeAlias + + +OptionListContent: TypeAlias = "Option | VisualType | None" +"""Types accepted in OptionList constructor and [add_options()][textual.widgets.OptionList.ads_options].""" + + +class OptionListError(Exception): + """An error occurred in the option list.""" + + +class DuplicateID(OptionListError): + """Raised if a duplicate ID is used when adding options to an option list.""" + + +class OptionDoesNotExist(OptionListError): + """Raised when a request has been made for an option that doesn't exist.""" + + +@rich.repr.auto +class Option: + """This class holds details of options in the list.""" + + def __init__( + self, prompt: VisualType, id: str | None = None, disabled: bool = False + ) -> None: + """Initialise the option. + + Args: + prompt: The prompt (text displayed) for the option. + id: An option ID for the option. + disabled: Disable the option (will be shown grayed out, and will not be selectable). + + """ + self._prompt = prompt + self._visual: Visual | None = None + self._id = id + self.disabled = disabled + self._divider = False + + @property + def prompt(self) -> VisualType: + """The original prompt.""" + return self._prompt + + @property + def id(self) -> str | None: + """Optional ID for the option.""" + return self._id + + def _set_prompt(self, prompt: VisualType) -> None: + """Update the prompt. + + Args: + prompt: New prompt. + + """ + self._prompt = prompt + self._visual = None + + def __hash__(self) -> int: + return id(self) + + def __rich_repr__(self) -> rich.repr.Result: + yield self._prompt + yield "id", self._id, None + yield "disabled", self.disabled, False + yield "_divider", self._divider, False + + +@dataclass +class _LineCache: + """Cached line information.""" + + lines: list[tuple[int, int]] = field(default_factory=list) + heights: dict[int, int] = field(default_factory=dict) + index_to_line: dict[int, int] = field(default_factory=dict) + + def clear(self) -> None: + """Reset all caches.""" + self.lines.clear() + self.heights.clear() + self.index_to_line.clear() + + +class OptionList(ScrollView, can_focus=True): + """A navigable list of options.""" + + ALLOW_SELECT = False + BINDINGS: ClassVar[list[BindingType]] = [ + Binding("down", "cursor_down", "Down", show=False), + Binding("end", "last", "Last", show=False), + Binding("enter", "select", "Select", show=False), + Binding("home", "first", "First", show=False), + Binding("pagedown", "page_down", "Page Down", show=False), + Binding("pageup", "page_up", "Page Up", show=False), + Binding("up", "cursor_up", "Up", show=False), + ] + """ + | Key(s) | Description | + | :- | :- | + | down | Move the highlight down. | + | end | Move the highlight to the last option. | + | enter | Select the current option. | + | home | Move the highlight to the first option. | + | pagedown | Move the highlight down a page of options. | + | pageup | Move the highlight up a page of options. | + | up | Move the highlight up. | + """ + + DEFAULT_CSS = """ + OptionList { + height: auto; + max-height: 100%; + color: $foreground; + overflow-x: hidden; + border: tall $border-blurred; + padding: 0 1; + background: $surface; + &.-textual-compact { + border: none !important; + padding: 0; + & > .option-list--option { + padding: 0; + } + } + & > .option-list--option-highlighted { + color: $block-cursor-blurred-foreground; + background: $block-cursor-blurred-background; + text-style: $block-cursor-blurred-text-style; + } + &:focus { + border: tall $border; + background-tint: $foreground 5%; + & > .option-list--option-highlighted { + color: $block-cursor-foreground; + background: $block-cursor-background; + text-style: $block-cursor-text-style; + } + } + & > .option-list--separator { + color: $foreground 15%; + } + & > .option-list--option-highlighted { + color: $foreground; + background: $block-cursor-blurred-background; + } + & > .option-list--option-disabled { + color: $text-disabled; + } + & > .option-list--option-hover { + background: $block-hover-background; + } + } + """ + + COMPONENT_CLASSES: ClassVar[set[str]] = { + "option-list--option", + "option-list--option-disabled", + "option-list--option-highlighted", + "option-list--option-hover", + "option-list--separator", + } + """ + | Class | Description | + | :- | :- | + | `option-list--option` | Target options that are not disabled, highlighted or have the mouse over them. | + | `option-list--option-disabled` | Target disabled options. | + | `option-list--option-highlighted` | Target the highlighted option. | + | `option-list--option-hover` | Target an option that has the mouse over it. | + | `option-list--separator` | Target the separators. | + """ + + highlighted: reactive[int | None] = reactive(None) + """The index of the currently-highlighted option, or `None` if no option is highlighted.""" + + _mouse_hovering_over: reactive[int | None] = reactive(None) + """The index of the option under the mouse or `None`.""" + + compact: reactive[bool] = reactive(False, toggle_class="-textual-compact") + """Enable compact display?""" + + class OptionMessage(Message): + """Base class for all option messages.""" + + def __init__(self, option_list: OptionList, option: Option, index: int) -> None: + """Initialise the option message. + + Args: + option_list: The option list that owns the option. + index: The index of the option that the message relates to. + """ + super().__init__() + self.option_list: OptionList = option_list + """The option list that sent the message.""" + self.option: Option = option + """The highlighted option.""" + self.option_id: str | None = option.id + """The ID of the option that the message relates to.""" + self.option_index: int = index + """The index of the option that the message relates to.""" + + @property + def control(self) -> OptionList: + """The option list that sent the message. + + This is an alias for [`OptionMessage.option_list`][textual.widgets.OptionList.OptionMessage.option_list] + and is used by the [`on`][textual.on] decorator. + """ + return self.option_list + + def __rich_repr__(self) -> rich.repr.Result: + try: + yield "option_list", self.option_list + yield "option", self.option + yield "option_id", self.option_id + yield "option_index", self.option_index + except AttributeError: + return + + class OptionHighlighted(OptionMessage): + """Message sent when an option is highlighted. + + Can be handled using `on_option_list_option_highlighted` in a subclass of + `OptionList` or in a parent node in the DOM. + """ + + class OptionSelected(OptionMessage): + """Message sent when an option is selected. + + Can be handled using `on_option_list_option_selected` in a subclass of + `OptionList` or in a parent node in the DOM. + """ + + def __init__( + self, + *content: OptionListContent, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + markup: bool = True, + compact: bool = False, + ): + """Initialize an OptionList. + + Args: + *content: Positional arguments become the options. + name: Name of the OptionList. + id: The ID of the OptionList in the DOM. + classes: Initial CSS classes. + disabled: Disable the widget? + markup: Strips should be rendered as content markup if `True`, or plain text if `False`. + compact: Enable compact style? + """ + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + self._markup = markup + self.compact = compact + self._options: list[Option] = [] + """List of options.""" + self._id_to_option: dict[str, Option] = {} + """Maps an Options's ID on to the option itself.""" + self._option_to_index: dict[Option, int] = {} + """Maps an Option to its index in self._options.""" + + self._option_render_cache: LRUCache[tuple[Option, Style, Spacing], list[Strip]] + self._option_render_cache = LRUCache(maxsize=1024 * 2) + """Caches rendered options.""" + + self._line_cache = _LineCache() + """Used to cache additional information that can be recomputed.""" + + self.add_options(content) + if self._options: + # TODO: Inherited from previous version. Do we always want this? + self.action_first() + + @property + def options(self) -> Sequence[Option]: + """Sequence of options in the OptionList. + + !!! note "This is read-only" + + """ + return self._options + + @property + def option_count(self) -> int: + """The number of options.""" + return len(self._options) + + @property + def highlighted_option(self) -> Option | None: + """The currently highlighted option, or `None` if no option is highlighted. + + Returns: + An Option, or `None`. + """ + if self.highlighted is not None: + return self.options[self.highlighted] + else: + return None + + def clear_options(self) -> Self: + """Clear the content of the option list. + + Returns: + The `OptionList` instance. + """ + self._options.clear() + self._line_cache.clear() + self._option_render_cache.clear() + self._id_to_option.clear() + self._option_to_index.clear() + self.highlighted = None + self.refresh() + self.scroll_y = 0 + self._update_lines() + return self + + def set_options(self, options: Iterable[OptionListContent]) -> Self: + """Set options, potentially clearing existing options. + + Args: + options: Options to set. + + Returns: + The `OptionList` instance. + """ + self._options.clear() + self._line_cache.clear() + self._option_render_cache.clear() + self._id_to_option.clear() + self._option_to_index.clear() + self.highlighted = None + self.scroll_y = 0 + self.add_options(options) + return self + + def add_options(self, new_options: Iterable[OptionListContent]) -> Self: + """Add new options. + + Args: + new_options: Content of new options. + + Returns: + The `OptionList` instance. + """ + + new_options = list(new_options) + + option_ids = [ + option._id + for option in new_options + if isinstance(option, Option) and option._id is not None + ] + if len(option_ids) != len(set(option_ids)): + raise DuplicateID( + "New options contain duplicated IDs; Ensure that the IDs are unique." + ) + + if not new_options: + return self + if new_options[0] is None: + # Handle the case where the first new option is None, + # which would update the previous option. + # This is sub-optimal, but hopefully not a common occurrence + self._clear_caches() + options = self._options + add_option = self._options.append + + for prompt in new_options: + if isinstance(prompt, Option): + option = prompt + elif prompt is None: + if options: + options[-1]._divider = True + continue + else: + option = Option(prompt) + self._option_to_index[option] = len(options) + if option._id is not None: + if option._id in self._id_to_option: + raise DuplicateID(f"Unable to add {option!r} due to duplicate ID") + self._id_to_option[option._id] = option + add_option(option) + if self.is_mounted: + self.refresh(layout=self.styles.auto_dimensions) + self._update_lines() + return self + + def add_option(self, option: Option | VisualType | None = None) -> Self: + """Add a new option to the end of the option list. + + Args: + option: New option to add, or `None` for a separator. + + Returns: + The `OptionList` instance. + + Raises: + DuplicateID: If there is an attempt to use a duplicate ID. + """ + self.add_options([option]) + return self + + def get_option(self, option_id: str) -> Option: + """Get the option with the given ID. + + Args: + option_id: The ID of the option to get. + + Returns: + The option with the ID. + + Raises: + OptionDoesNotExist: If no option has the given ID. + """ + try: + return self._id_to_option[option_id] + except KeyError: + raise OptionDoesNotExist( + f"There is no option with an ID of {option_id!r}" + ) from None + + def get_option_index(self, option_id: str) -> int: + """Get the index (offset in `self.options`) of the option with the given ID. + + Args: + option_id: The ID of the option to get the index of. + + Returns: + The index of the item with the given ID. + + Raises: + OptionDoesNotExist: If no option has the given ID. + """ + option = self.get_option(option_id) + return self._option_to_index[option] + + def get_option_at_index(self, index: int) -> Option: + """Get the option at the given index. + + Args: + index: The index of the option to get. + + Returns: + The option at that index. + + Raises: + OptionDoesNotExist: If there is no option with the given index. + """ + try: + return self._options[index] + except IndexError: + raise OptionDoesNotExist( + f"There is no option with an index of {index}" + ) from None + + def _set_option_disabled(self, index: int, disabled: bool) -> Self: + """Set the disabled state of an option in the list. + + Args: + index: The index of the option to set the disabled state of. + disabled: The disabled state to set. + + Returns: + The `OptionList` instance. + """ + self._options[index].disabled = disabled + if index == self.highlighted: + self.highlighted = _widget_navigation.find_next_enabled( + self._options, anchor=index, direction=1 + ) + # TODO: Refresh only if the affected option is visible. + self.refresh() + return self + + def enable_option_at_index(self, index: int) -> Self: + """Enable the option at the given index. + + Returns: + The `OptionList` instance. + + Raises: + OptionDoesNotExist: If there is no option with the given index. + """ + try: + return self._set_option_disabled(index, False) + except IndexError: + raise OptionDoesNotExist( + f"There is no option with an index of {index}" + ) from None + + def disable_option_at_index(self, index: int) -> Self: + """Disable the option at the given index. + + Returns: + The `OptionList` instance. + + Raises: + OptionDoesNotExist: If there is no option with the given index. + """ + try: + return self._set_option_disabled(index, True) + except IndexError: + raise OptionDoesNotExist( + f"There is no option with an index of {index}" + ) from None + + def enable_option(self, option_id: str) -> Self: + """Enable the option with the given ID. + + Args: + option_id: The ID of the option to enable. + + Returns: + The `OptionList` instance. + + Raises: + OptionDoesNotExist: If no option has the given ID. + """ + return self.enable_option_at_index(self.get_option_index(option_id)) + + def disable_option(self, option_id: str) -> Self: + """Disable the option with the given ID. + + Args: + option_id: The ID of the option to disable. + + Returns: + The `OptionList` instance. + + Raises: + OptionDoesNotExist: If no option has the given ID. + """ + return self.disable_option_at_index(self.get_option_index(option_id)) + + def _remove_option(self, option: Option) -> Self: + """Remove the option with the given ID. + + Args: + option: The Option to return. + + Returns: + The `OptionList` instance. + + Raises: + OptionDoesNotExist: If no option has the given ID. + """ + + index = self._option_to_index[option] + self._mouse_hovering_over = None + self._pre_remove_option(option, index) + for option in self.options[index + 1 :]: + current_index = self._option_to_index[option] + self._option_to_index[option] = current_index - 1 + + option = self._options[index] + del self._options[index] + if option._id is not None: + del self._id_to_option[option._id] + del self._option_to_index[option] + self.highlighted = self.highlighted + self._clear_caches() + return self + + def _pre_remove_option(self, option: Option, index: int) -> None: + """Hook called prior to removing an option. + + Args: + option: Option being removed. + index: Index of option being removed. + + """ + + def remove_option(self, option_id: str) -> Self: + """Remove the option with the given ID. + + Args: + option_id: The ID of the option to remove. + + Returns: + The `OptionList` instance. + + Raises: + OptionDoesNotExist: If no option has the given ID. + """ + option = self.get_option(option_id) + return self._remove_option(option) + + def remove_option_at_index(self, index: int) -> Self: + """Remove the option at the given index. + + Args: + index: The index of the option to remove. + + Returns: + The `OptionList` instance. + + Raises: + OptionDoesNotExist: If there is no option with the given index. + """ + try: + option = self._options[index] + except IndexError: + raise OptionDoesNotExist( + f"Unable to remove; there is no option at index {index}" + ) from None + return self._remove_option(option) + + def _replace_option_prompt(self, index: int, prompt: VisualType) -> None: + """Replace the prompt of an option in the list. + + Args: + index: The index of the option to replace the prompt of. + prompt: The new prompt for the option. + + Raises: + OptionDoesNotExist: If there is no option with the given index. + """ + self.get_option_at_index(index)._set_prompt(prompt) + self._clear_caches() + + def replace_option_prompt(self, option_id: str, prompt: VisualType) -> Self: + """Replace the prompt of the option with the given ID. + + Args: + option_id: The ID of the option to replace the prompt of. + prompt: The new prompt for the option. + + Returns: + The `OptionList` instance. + + Raises: + OptionDoesNotExist: If no option has the given ID. + """ + self._replace_option_prompt(self.get_option_index(option_id), prompt) + return self + + def replace_option_prompt_at_index(self, index: int, prompt: VisualType) -> Self: + """Replace the prompt of the option at the given index. + + Args: + index: The index of the option to replace the prompt of. + prompt: The new prompt for the option. + + Returns: + The `OptionList` instance. + + Raises: + OptionDoesNotExist: If there is no option with the given index. + """ + self._replace_option_prompt(index, prompt) + return self + + @property + def _lines(self) -> Sequence[tuple[int, int]]: + """A sequence of pairs of ints for each line, used internally. + + The first int is the index of the option, and second is the line offset. + + !!! note "This is read-only" + + Returns: + A sequence of tuples. + """ + self._update_lines() + return self._line_cache.lines + + @property + def _heights(self) -> dict[int, int]: + self._update_lines() + return self._line_cache.heights + + @property + def _index_to_line(self) -> dict[int, int]: + self._update_lines() + return self._line_cache.index_to_line + + def _clear_caches(self) -> None: + self._option_render_cache.clear() + self._line_cache.clear() + self.refresh() + + def notify_style_update(self) -> None: + self.refresh() + super().notify_style_update() + + def _on_resize(self): + self._clear_caches() + + def on_show(self) -> None: + self.scroll_to_highlight() + + def on_mount(self) -> None: + self._update_lines() + + async def _on_click(self, event: events.Click) -> None: + """React to the mouse being clicked on an item. + + Args: + event: The click event. + """ + clicked_option: int | None = event.style.meta.get("option") + if clicked_option is not None and not self._options[clicked_option].disabled: + self.highlighted = clicked_option + self.action_select() + + def _get_left_gutter_width(self) -> int: + """Returns the size of any left gutter that should be taken into account. + + Returns: + The width of the left gutter. + """ + return 0 + + def _on_mouse_move(self, event: events.MouseMove) -> None: + """React to the mouse moving. + + Args: + event: The mouse movement event. + """ + self._mouse_hovering_over = event.style.meta.get("option") + + def _on_leave(self, _: events.Leave) -> None: + """React to the mouse leaving the widget.""" + self._mouse_hovering_over = None + + def _get_visual(self, option: Option) -> Visual: + """Get a visual for the given option. + + Args: + option: An option. + + Returns: + A Visual. + + """ + if (visual := option._visual) is None: + visual = visualize(self, option.prompt, markup=self._markup) + option._visual = visual + return visual + + def _get_visual_from_index(self, index: int) -> Visual: + """Get a visual from the given index. + + Args: + index: An index (offset in self.options). + + Returns: + A Visual. + """ + option = self.get_option_at_index(index) + return self._get_visual(option) + + def _get_option_render(self, option: Option, style: Style) -> list[Strip]: + """Get rendered option with a given style. + + Args: + option: An option. + style: Style of render. + + Returns: + A list of strips. + """ + padding = self.get_component_styles("option-list--option").padding + render_width = self.scrollable_content_region.width + width = render_width - self._get_left_gutter_width() + cache_key = (option, style, padding) + if (strips := self._option_render_cache.get(cache_key)) is None: + visual = self._get_visual(option) + if padding: + visual = Padding(visual, padding) + strips = visual.to_strips(self, visual, width, None, style) + meta = {"option": self._option_to_index[option]} + strips = [ + strip.extend_cell_length(width, style.rich_style).apply_meta(meta) + for strip in strips + ] + if option._divider: + style = self.get_visual_style("option-list--separator") + rule_segments = [Segment("─" * width, style.rich_style)] + strips.append(Strip(rule_segments, width)) + self._option_render_cache[cache_key] = strips + return strips + + def _update_lines(self) -> None: + """Update internal structures when new lines are added.""" + if not self.scrollable_content_region: + return + + line_cache = self._line_cache + lines = line_cache.lines + next_index = lines[-1][0] + 1 if lines else 0 + get_visual = self._get_visual + width = self.scrollable_content_region.width - self._get_left_gutter_width() + + if next_index < len(self.options): + padding = self.get_component_styles("option-list--option").padding + for index, option in enumerate(self.options[next_index:], next_index): + line_cache.index_to_line[index] = len(line_cache.lines) + line_count = ( + get_visual(option).get_height(self.styles, width - padding.width) + + option._divider + ) + line_cache.heights[index] = line_count + line_cache.lines.extend( + [(index, line_no) for line_no in range(0, line_count)] + ) + + last_divider = self.options and self.options[-1]._divider + virtual_size = Size(width, len(lines) - (1 if last_divider else 0)) + if virtual_size != self.virtual_size: + self.virtual_size = virtual_size + self._scroll_update(virtual_size) + + def get_content_width(self, container: Size, viewport: Size) -> int: + """Get maximum width of options.""" + if not self.options: + return 0 + styles = self.styles + get_visual_from_index = self._get_visual_from_index + padding = self.get_component_styles("option-list--option").padding + gutter_width = self._get_left_gutter_width() + container_width = container.width + width = ( + max( + get_visual_from_index(index).get_optimal_width(styles, container_width) + for index in range(len(self.options)) + ) + + padding.width + + gutter_width + ) + return width + + def get_content_height(self, container: Size, viewport: Size, width: int) -> int: + """Get height for the given width.""" + styles = self.styles + rules = cast(RulesMap, styles) + padding_width = self.get_component_styles("option-list--option").padding.width + get_visual = self._get_visual + height = sum( + ( + get_visual(option).get_height(rules, width - padding_width) + + (1 if option._divider and not last else 0) + ) + for last, option in loop_last(self.options) + ) + return height + + def _get_line(self, style: Style, y: int) -> Strip: + index, line_offset = self._lines[y] + option = self.get_option_at_index(index) + strips = self._get_option_render(option, style) + return strips[line_offset] + + def render_lines(self, crop: Region) -> list[Strip]: + self._update_lines() + return super().render_lines(crop) + + def render_line(self, y: int) -> Strip: + line_number = self.scroll_offset.y + y + try: + option_index, line_offset = self._lines[line_number] + option = self.options[option_index] + except IndexError: + return Strip.blank( + self.scrollable_content_region.width, + self.get_visual_style("option-list--option").rich_style, + ) + + mouse_over = self._mouse_hovering_over == option_index + component_class = "" + if option.disabled: + component_class = "option-list--option-disabled" + elif self.highlighted == option_index: + component_class = "option-list--option-highlighted" + elif mouse_over: + component_class = "option-list--option-hover" + + if component_class: + style = self.get_visual_style("option-list--option", component_class) + else: + style = self.get_visual_style("option-list--option") + + strips = self._get_option_render(option, style) + try: + strip = strips[line_offset] + except IndexError: + return Strip.blank( + self.scrollable_content_region.width, + self.get_visual_style("option-list--option").rich_style, + ) + return strip + + def validate_highlighted(self, highlighted: int | None) -> int | None: + """Validate the `highlighted` property value on access.""" + if highlighted is None or not self.options: + return None + elif highlighted < 0: + return 0 + elif highlighted >= len(self.options): + return len(self.options) - 1 + return highlighted + + def watch_highlighted(self, highlighted: int | None) -> None: + """React to the highlighted option having changed.""" + if highlighted is None: + return + if not self._options[highlighted].disabled: + self.scroll_to_highlight() + self.post_message( + self.OptionHighlighted(self, self.options[highlighted], highlighted) + ) + + def scroll_to_highlight(self, top: bool = False) -> None: + """Scroll to the highlighted option. + + Args: + top: Ensure highlighted option is at the top of the widget. + """ + highlighted = self.highlighted + if highlighted is None or not self.is_mounted: + return + + self._update_lines() + + try: + y = self._index_to_line[highlighted] + except KeyError: + return + height = self._heights[highlighted] + + self.scroll_to_region( + Region(0, y, self.scrollable_content_region.width, height), + force=True, + animate=False, + top=top, + immediate=True, + ) + + def action_cursor_up(self) -> None: + """Move the highlight up to the previous enabled option.""" + self.highlighted = _widget_navigation.find_next_enabled( + self.options, + anchor=self.highlighted, + direction=-1, + ) + + def action_cursor_down(self) -> None: + """Move the highlight down to the next enabled option.""" + self.highlighted = _widget_navigation.find_next_enabled( + self.options, + anchor=self.highlighted, + direction=1, + ) + + def action_first(self) -> None: + """Move the highlight to the first enabled option.""" + self.highlighted = _widget_navigation.find_first_enabled(self.options) + + def action_last(self) -> None: + """Move the highlight to the last enabled option.""" + self.highlighted = _widget_navigation.find_last_enabled(self.options) + + def _move_page(self, direction: _widget_navigation.Direction) -> None: + """Move the height roughly by one page in the given direction. + + This method will attempt to avoid selecting a disabled option. + + Args: + direction: `-1` to move up a page, `1` to move down a page. + """ + if not self._options: + return + + height = self.scrollable_content_region.height + y = clamp( + self._index_to_line[self.highlighted or 0] + direction * height, + 0, + len(self._lines) - 1, + ) + option_index = self._lines[y][0] + self.highlighted = _widget_navigation.find_next_enabled_no_wrap( + candidates=self._options, + anchor=option_index, + direction=direction, + with_anchor=True, + ) + + def action_page_up(self): + """Move the highlight up one page.""" + if self.highlighted is None: + self.action_first() + else: + self._move_page(-1) + + def action_page_down(self): + """Move the highlight down one page.""" + if self.highlighted is None: + self.action_last() + else: + self._move_page(1) + + def action_select(self) -> None: + """Select the currently highlighted option. + + If an option is selected then a + [OptionList.OptionSelected][textual.widgets.OptionList.OptionSelected] will be posted. + """ + highlighted = self.highlighted + if highlighted is None: + return + option = self._options[highlighted] + if highlighted is not None and not option.disabled: + self.post_message(self.OptionSelected(self, option, highlighted)) diff --git a/src/memray/_vendor/textual/widgets/_placeholder.py b/src/memray/_vendor/textual/widgets/_placeholder.py new file mode 100644 index 0000000000..a67bef237c --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_placeholder.py @@ -0,0 +1,186 @@ +"""Provides a Textual placeholder widget; useful when designing an app's layout.""" + +from __future__ import annotations + +from itertools import cycle +from typing import TYPE_CHECKING +from weakref import WeakKeyDictionary + +from typing_extensions import Literal, Self + +from memray._vendor.textual import events + +if TYPE_CHECKING: + from memray._vendor.textual.app import RenderResult + +from memray._vendor.textual._context import NoActiveAppError +from memray._vendor.textual.css._error_tools import friendly_list +from memray._vendor.textual.reactive import Reactive, reactive +from memray._vendor.textual.widget import Widget + +if TYPE_CHECKING: + from memray._vendor.textual.app import App + +PlaceholderVariant = Literal["default", "size", "text"] +"""The different variants of placeholder.""" + +_VALID_PLACEHOLDER_VARIANTS_ORDERED: list[PlaceholderVariant] = [ + "default", + "size", + "text", +] +_VALID_PLACEHOLDER_VARIANTS: set[PlaceholderVariant] = set( + _VALID_PLACEHOLDER_VARIANTS_ORDERED +) +_PLACEHOLDER_BACKGROUND_COLORS = [ + "#881177", + "#aa3355", + "#cc6666", + "#ee9944", + "#eedd00", + "#99dd55", + "#44dd88", + "#22ccbb", + "#00bbcc", + "#0099cc", + "#3366bb", + "#663399", +] +_LOREM_IPSUM_PLACEHOLDER_TEXT = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Etiam feugiat ac elit sit amet accumsan. Suspendisse bibendum nec libero quis gravida. Phasellus id eleifend ligula. Nullam imperdiet sem tellus, sed vehicula nisl faucibus sit amet. Praesent iaculis tempor ultricies. Sed lacinia, tellus id rutrum lacinia, sapien sapien congue mauris, sit amet pellentesque quam quam vel nisl. Curabitur vulputate erat pellentesque mauris posuere, non dictum risus mattis." + + +class InvalidPlaceholderVariant(Exception): + """Raised when an invalid Placeholder variant is set.""" + + +class Placeholder(Widget): + """A simple placeholder widget to use before you build your custom widgets. + + This placeholder has a couple of variants that show different data. + Clicking the placeholder cycles through the available variants, but a placeholder + can also be initialised in a specific variant. + + The variants available are: + + | Variant | Placeholder shows | + |---------|------------------------------------------------| + | default | Identifier label or the ID of the placeholder. | + | size | Size of the placeholder. | + | text | Lorem Ipsum text. | + """ + + DEFAULT_CSS = """ + Placeholder { + content-align: center middle; + overflow: hidden; + color: $text; + + &:disabled { + opacity: 0.7; + } + } + Placeholder.-text { + padding: 1; + } + """ + + # Consecutive placeholders get assigned consecutive colors. + _COLORS: WeakKeyDictionary[App, int] = WeakKeyDictionary() + _SIZE_RENDER_TEMPLATE = "[b]{} x {}[/b]" + + variant: Reactive[PlaceholderVariant] = reactive[PlaceholderVariant]("default") + + _renderables: dict[PlaceholderVariant, str] + + def __init__( + self, + label: str | None = None, + variant: PlaceholderVariant = "default", + *, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + ) -> None: + """Create a Placeholder widget. + + Args: + label: The label to identify the placeholder. + If no label is present, uses the placeholder ID instead. + variant: The variant of the placeholder. + name: The name of the placeholder. + id: The ID of the placeholder in the DOM. + classes: A space separated string with the CSS classes + of the placeholder, if any. + disabled: Whether the placeholder is disabled or not. + """ + # Create and cache renderables for all the variants. + self._renderables = { + "default": label if label else f"#{id}" if id else "Placeholder", + "size": "", + "text": "\n\n".join(_LOREM_IPSUM_PLACEHOLDER_TEXT for _ in range(5)), + } + + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + + self.variant = self.validate_variant(variant) + """The current variant of the placeholder.""" + + try: + self._COLORS[self.app] = self._COLORS.setdefault(self.app, -1) + 1 + self._color_offset = self._COLORS[self.app] + except NoActiveAppError: + self._color_offset = 0 + + # Set a cycle through the variants with the correct starting point. + self._variants_cycle = cycle(_VALID_PLACEHOLDER_VARIANTS_ORDERED) + while next(self._variants_cycle) != self.variant: + pass + + async def _on_compose(self, event: events.Compose) -> None: + """Set the color for this placeholder.""" + color_count = len(_PLACEHOLDER_BACKGROUND_COLORS) + color = _PLACEHOLDER_BACKGROUND_COLORS[self._color_offset % color_count] + self.styles.background = f"{color} 50%" + + def render(self) -> RenderResult: + """Render the placeholder. + + Returns: + The value to render. + """ + return self._renderables[self.variant] + + def cycle_variant(self) -> Self: + """Get the next variant in the cycle. + + Returns: + The `Placeholder` instance. + """ + self.variant = next(self._variants_cycle) + return self + + def watch_variant( + self, old_variant: PlaceholderVariant, variant: PlaceholderVariant + ) -> None: + self.remove_class(f"-{old_variant}") + self.add_class(f"-{variant}") + + def validate_variant(self, variant: PlaceholderVariant) -> PlaceholderVariant: + """Validate the variant to which the placeholder was set.""" + if variant not in _VALID_PLACEHOLDER_VARIANTS: + raise InvalidPlaceholderVariant( + "Valid placeholder variants are " + + f"{friendly_list(_VALID_PLACEHOLDER_VARIANTS)}" + ) + return variant + + async def _on_click(self, _: events.Click) -> None: + """Click handler to cycle through the placeholder variants.""" + self.cycle_variant() + + def _on_resize(self, event: events.Resize) -> None: + """Update the placeholder "size" variant with the new placeholder size.""" + self._renderables["size"] = self._SIZE_RENDER_TEMPLATE.format(*event.size) + if self.variant == "size": + self.refresh() diff --git a/src/memray/_vendor/textual/widgets/_pretty.py b/src/memray/_vendor/textual/widgets/_pretty.py new file mode 100644 index 0000000000..4fba2eadad --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_pretty.py @@ -0,0 +1,56 @@ +"""Provides a pretty-printing widget.""" + +from __future__ import annotations + +from typing import Any + +from rich.pretty import Pretty as PrettyRenderable + +from memray._vendor.textual.app import RenderResult +from memray._vendor.textual.widget import Widget + + +class Pretty(Widget): + """A pretty-printing widget. + + Used to pretty-print any object. + """ + + DEFAULT_CSS = """ + Pretty { + height: auto; + } + """ + + def __init__( + self, + object: Any, + *, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + ) -> None: + """Initialise the `Pretty` widget. + + Args: + object: The object to pretty-print. + name: The name of the pretty widget. + id: The ID of the pretty in the DOM. + classes: The CSS classes of the pretty. + """ + super().__init__(name=name, id=id, classes=classes) + self.shrink = False + self._pretty_renderable = PrettyRenderable(object) + + def render(self) -> RenderResult: + return self._pretty_renderable + + def update(self, object: object) -> None: + """Update the content of the pretty widget. + + Args: + object: The object to pretty-print. + """ + self._pretty_renderable = PrettyRenderable(object) + self.clear_cached_dimensions() + self.refresh(layout=True) diff --git a/src/memray/_vendor/textual/widgets/_progress_bar.py b/src/memray/_vendor/textual/widgets/_progress_bar.py new file mode 100644 index 0000000000..5ee8e72833 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_progress_bar.py @@ -0,0 +1,385 @@ +"""Implements a progress bar widget.""" + +from __future__ import annotations + +from typing import Optional, Type + +from rich.style import Style + +from memray._vendor.textual._types import UnusedParameter +from memray._vendor.textual.app import ComposeResult, RenderResult +from memray._vendor.textual.clock import Clock +from memray._vendor.textual.color import Gradient +from memray._vendor.textual.eta import ETA +from memray._vendor.textual.geometry import clamp +from memray._vendor.textual.reactive import reactive +from memray._vendor.textual.renderables.bar import Bar as BarRenderable +from memray._vendor.textual.widget import Widget +from memray._vendor.textual.widgets import Label + +UNUSED = UnusedParameter() +"""Sentinel for method signatures.""" + + +class Bar(Widget, can_focus=False): + """The bar portion of the progress bar.""" + + COMPONENT_CLASSES = {"bar--bar", "bar--complete", "bar--indeterminate"} + """ + The bar sub-widget provides the component classes that follow. + + These component classes let you modify the foreground and background color of the + bar in its different states. + + | Class | Description | + | :- | :- | + | `bar--bar` | Style of the bar (may be used to change the color). | + | `bar--complete` | Style of the bar when it's complete. | + | `bar--indeterminate` | Style of the bar when it's in an indeterminate state. | + """ + + DEFAULT_CSS = """ + Bar { + width: 32; + height: 1; + + &> .bar--bar { + color: $primary; + background: $surface; + } + &> .bar--indeterminate { + color: $error; + background: $surface; + } + &> .bar--complete { + color: $success; + background: $surface; + } + } + """ + + percentage: reactive[float | None] = reactive[Optional[float]](None) + """The percentage of progress that has been completed.""" + + gradient: reactive[Gradient | None] = reactive(None) + """An optional gradient.""" + + def __init__( + self, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + clock: Clock | None = None, + gradient: Gradient | None = None, + bar_renderable: Type[BarRenderable] = BarRenderable, + ): + """Create a bar for a [`ProgressBar`][textual.widgets.ProgressBar].""" + self._clock = (clock or Clock()).clone() + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + self.set_reactive(Bar.gradient, gradient) + self.bar_renderable = bar_renderable + + def _validate_percentage(self, percentage: float | None) -> float | None: + """Avoid updating the bar, if the percentage increase is too small to render.""" + width = self.size.width * 2 + return ( + None + if percentage is None + else (int(percentage * width) / width if width else percentage) + ) + + def watch_percentage(self, percentage: float | None) -> None: + """Manage the timer that enables the indeterminate bar animation.""" + if percentage is not None: + self.auto_refresh = None + else: + self.auto_refresh = 1 / 15 + + def render(self) -> RenderResult: + """Render the bar with the correct portion filled.""" + if self.percentage is None: + return self.render_indeterminate() + else: + bar_style = ( + self.get_component_rich_style("bar--bar") + if self.percentage < 1 + else self.get_component_rich_style("bar--complete") + ) + return self.bar_renderable( + highlight_range=(0, self.size.width * self.percentage), + highlight_style=Style.from_color(bar_style.color), + background_style=Style.from_color(bar_style.bgcolor), + gradient=self.gradient, + ) + + def render_indeterminate(self) -> RenderResult: + """Render a frame of the indeterminate progress bar animation.""" + width = self.size.width + highlighted_bar_width = 0.25 * width + # Width used to enable the visual effect of the bar going into the corners. + total_imaginary_width = width + highlighted_bar_width + start: float + end: float + if self.app.animation_level == "none": + start = 0 + end = width + else: + speed = 30 # Cells per second. + # Compute the position of the bar. + start = ( + (speed * self._clock.time) % (2 * total_imaginary_width) + if total_imaginary_width + else 0 + ) + if start > total_imaginary_width: + # If the bar is to the right of its width, wrap it back from right to left. + start = 2 * total_imaginary_width - start # = (tiw - (start - tiw)) + start -= highlighted_bar_width + end = start + highlighted_bar_width + + bar_style = self.get_component_rich_style("bar--indeterminate") + return self.bar_renderable( + highlight_range=(max(0, start), min(end, width)), + highlight_style=Style.from_color(bar_style.color), + background_style=Style.from_color(bar_style.bgcolor), + ) + + +class PercentageStatus(Label): + """A label to display the percentage status of the progress bar.""" + + DEFAULT_CSS = """ + PercentageStatus { + width: 5; + content-align-horizontal: right; + } + """ + + percentage: reactive[int | None] = reactive[Optional[int]](None) + """The percentage of progress that has been completed.""" + + def _validate_percentage(self, percentage: float | None) -> int | None: + return None if percentage is None else round(percentage * 100) + + def render(self) -> RenderResult: + return "--%" if self.percentage is None else f"{self.percentage}%" + + +class ETAStatus(Label): + """A label to display the estimated time until completion of the progress bar.""" + + DEFAULT_CSS = """ + ETAStatus { + width: 9; + content-align-horizontal: right; + } + """ + eta: reactive[float | None] = reactive[Optional[float]](None) + """Estimated number of seconds till completion, or `None` if no estimate is available.""" + + def render(self) -> RenderResult: + """Render the ETA display.""" + eta = self.eta + if eta is None: + return "--:--:--" + else: + minutes, seconds = divmod(round(eta), 60) + hours, minutes = divmod(minutes, 60) + if hours > 999999: + return "+999999h" + elif hours > 99: + return f"{hours}h" + else: + return f"{hours:02}:{minutes:02}:{seconds:02}" + + +class ProgressBar(Widget, can_focus=False): + """A progress bar widget.""" + + DEFAULT_CSS = """ + ProgressBar { + width: auto; + height: 1; + layout: horizontal; + } + """ + + progress: reactive[float] = reactive(0.0) + """The progress so far, in number of steps.""" + total: reactive[float | None] = reactive[Optional[float]](None) + """The total number of steps associated with this progress bar, when known. + + The value `None` will render an indeterminate progress bar. + """ + percentage: reactive[float | None] = reactive[Optional[float]](None) + """The percentage of progress that has been completed. + + The percentage is a value between 0 and 1 and the returned value is only + `None` if the total progress of the bar hasn't been set yet. + + Example: + ```py + progress_bar = ProgressBar() + print(progress_bar.percentage) # None + progress_bar.update(total=100) + progress_bar.advance(50) + print(progress_bar.percentage) # 0.5 + ``` + """ + _display_eta: reactive[int | None] = reactive[Optional[int]](None) + + gradient: reactive[Gradient | None] = reactive(None) + """Optional gradient object (will replace CSS styling in bar).""" + + BAR_RENDERABLE: Type[BarRenderable] = BarRenderable + """BarRenderable to use for rendering the bar-part of the ProgressBar""" + + def __init__( + self, + total: float | None = None, + *, + show_bar: bool = True, + show_percentage: bool = True, + show_eta: bool = True, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + clock: Clock | None = None, + gradient: Gradient | None = None, + ): + """Create a Progress Bar widget. + + The progress bar uses "steps" as the measurement unit. + + Example: + ```py + class MyApp(App): + def compose(self): + yield ProgressBar(total=100) + + def key_space(self): + self.query_one(ProgressBar).advance(5) + ``` + + Args: + total: The total number of steps in the progress if known. + show_bar: Whether to show the bar portion of the progress bar. + show_percentage: Whether to show the percentage status of the bar. + show_eta: Whether to show the ETA countdown of the progress bar. + name: The name of the widget. + id: The ID of the widget in the DOM. + classes: The CSS classes for the widget. + disabled: Whether the widget is disabled or not. + clock: An optional clock object (leave as default unless testing). + gradient: An optional Gradient object (will replace CSS styles in the bar). + """ + self._clock = clock or Clock() + self._eta = ETA() + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + self.total = total + self.show_bar = show_bar + self.show_percentage = show_percentage + self.show_eta = show_eta + self.set_reactive(ProgressBar.gradient, gradient) + + def on_mount(self) -> None: + self.update() + self.set_interval(1, self.update) + self._clock.reset() + + def compose(self) -> ComposeResult: + if self.show_bar: + yield ( + Bar(id="bar", clock=self._clock, bar_renderable=self.BAR_RENDERABLE) + .data_bind(ProgressBar.percentage) + .data_bind(ProgressBar.gradient) + ) + if self.show_percentage: + yield PercentageStatus(id="percentage").data_bind(ProgressBar.percentage) + if self.show_eta: + yield ETAStatus(id="eta").data_bind(eta=ProgressBar._display_eta) + + def _validate_total(self, total: float | None) -> float | None: + """Ensure the total is not negative.""" + if total is None: + return total + return max(0, total) + + def _compute_percentage(self) -> float | None: + """Keep the percentage of progress updated automatically. + + This will report a percentage of `1` if the total is zero. + """ + if self.total: + return clamp(self.progress / self.total, 0.0, 1.0) + elif self.total == 0: + return 1.0 + return None + + def _watch_progress(self, progress: float) -> None: + """Perform update when progress is modified.""" + self.update(progress=progress) + + def _watch_total(self, total: float) -> None: + """Update when the total is modified.""" + self.update(total=total) + + def advance(self, advance: float = 1) -> None: + """Advance the progress of the progress bar by the given amount. + + Example: + ```py + progress_bar.advance(10) # Advance 10 steps. + ``` + + Args: + advance: Number of steps to advance progress by. + """ + self.update(advance=advance) + + def update( + self, + *, + total: None | float | UnusedParameter = UNUSED, + progress: float | UnusedParameter = UNUSED, + advance: float | UnusedParameter = UNUSED, + ) -> None: + """Update the progress bar with the given options. + + Example: + ```py + progress_bar.update( + total=200, # Set new total to 200 steps. + progress=50, # Set the progress to 50 (out of 200). + ) + ``` + + Args: + total: New total number of steps. + progress: Set the progress to the given number of steps. + advance: Advance the progress by this number of steps. + """ + current_time = self._clock.time + if not isinstance(total, UnusedParameter): + if total is None or total != self.total: + self._eta.reset() + self.total = total + + def add_sample() -> None: + """Add a new sample.""" + if self.progress is not None and self.total: + self._eta.add_sample(current_time, self.progress / self.total) + + if not isinstance(progress, UnusedParameter): + self.progress = progress + add_sample() + + if not isinstance(advance, UnusedParameter): + self.progress += advance + add_sample() + + self._display_eta = ( + None if self.total is None else self._eta.get_eta(current_time) + ) diff --git a/src/memray/_vendor/textual/widgets/_radio_button.py b/src/memray/_vendor/textual/widgets/_radio_button.py new file mode 100644 index 0000000000..20d1ead306 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_radio_button.py @@ -0,0 +1,33 @@ +"""Provides a radio button widget.""" + +from __future__ import annotations + +from memray._vendor.textual.widgets._toggle_button import ToggleButton + + +class RadioButton(ToggleButton): + """A radio button widget that represents a boolean value. + + Note: + A `RadioButton` is best used within a [RadioSet][textual.widgets.RadioSet]. + """ + + BUTTON_INNER = "\u25cf" + """The character used for the inside of the button.""" + + class Changed(ToggleButton.Changed): + """Posted when the value of the radio button changes. + + This message can be handled using an `on_radio_button_changed` method. + """ + + @property + def radio_button(self) -> RadioButton: + """The radio button that was changed.""" + assert isinstance(self._toggle_button, RadioButton) + return self._toggle_button + + @property + def control(self) -> RadioButton: + """Alias for [Changed.radio_button][textual.widgets.RadioButton.Changed.radio_button].""" + return self.radio_button diff --git a/src/memray/_vendor/textual/widgets/_radio_set.py b/src/memray/_vendor/textual/widgets/_radio_set.py new file mode 100644 index 0000000000..d20b3f77f5 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_radio_set.py @@ -0,0 +1,315 @@ +"""Provides a RadioSet widget, which groups radio buttons.""" + +from __future__ import annotations + +from typing import ClassVar, Optional + +import rich.repr +from rich.console import RenderableType + +from memray._vendor.textual import _widget_navigation +from memray._vendor.textual.binding import Binding, BindingType +from memray._vendor.textual.containers import VerticalScroll +from memray._vendor.textual.events import Click, Mount +from memray._vendor.textual.message import Message +from memray._vendor.textual.reactive import reactive, var +from memray._vendor.textual.widgets._radio_button import RadioButton + + +class RadioSet(VerticalScroll, can_focus=True, can_focus_children=False): + """Widget for grouping a collection of radio buttons into a set. + + When a collection of [`RadioButton`][textual.widgets.RadioButton]s are + grouped with this widget, they will be treated as a mutually-exclusive + grouping. If one button is turned on, the previously-on button will be + turned off. + """ + + ALLOW_SELECT = False + ALLOW_MAXIMIZE = True + + DEFAULT_CSS = """ + RadioSet { + border: tall $border-blurred; + background: $surface; + padding: 0 1; + height: auto; + width: 1fr; + pointer: pointer; + + &.-textual-compact { + border: none !important; + padding: 0; + } + + & > RadioButton { + background: transparent; + border: none; + padding: 0; + width: 1fr; + + & > .toggle--button { + color: $panel-darken-2; + background: $panel; + } + } + + & > RadioButton.-on .toggle--button { + color: $text-success; + } + + &:blur { + & > RadioButton.-selected { + & > .toggle--label { + background: $block-cursor-blurred-background; + } + } + } + + &:focus { + /* The following rules/styles mimic similar ToggleButton:focus rules in + * ToggleButton. If those styles ever get updated, these should be too. + */ + border: tall $border; + background-tint: $foreground 5%; + & > RadioButton.-selected { + + & > .toggle--label { + background: $block-cursor-background; + color: $block-cursor-foreground; + text-style: $block-cursor-text-style; + } + } + + } + } + """ + + BINDINGS: ClassVar[list[BindingType]] = [ + Binding("down,right", "next_button", "Next option", show=False), + Binding("enter,space", "toggle_button", "Toggle", show=False), + Binding("up,left", "previous_button", "Previous option", show=False), + ] + """ + | Key(s) | Description | + | :- | :- | + | enter, space | Toggle the currently-selected button. | + | left, up | Select the previous radio button in the set. | + | right, down | Select the next radio button in the set. | + """ + + _selected: var[int | None] = var[Optional[int]](None) + """The index of the currently-selected radio button.""" + + compact: reactive[bool] = reactive(False, toggle_class="-textual-compact") + """Enable compact display?""" + + @rich.repr.auto + class Changed(Message): + """Posted when the pressed button in the set changes. + + This message can be handled using an `on_radio_set_changed` method. + """ + + ALLOW_SELECTOR_MATCH = {"pressed"} + """Additional message attributes that can be used with the [`on` decorator][textual.on].""" + + def __init__(self, radio_set: RadioSet, pressed: RadioButton) -> None: + """Initialise the message. + + Args: + pressed: The radio button that was pressed. + """ + super().__init__() + self.radio_set = radio_set + """A reference to the [`RadioSet`][textual.widgets.RadioSet] that was changed.""" + self.pressed = pressed + """The [`RadioButton`][textual.widgets.RadioButton] that was pressed to make the change.""" + self.index = radio_set.pressed_index + """The index of the [`RadioButton`][textual.widgets.RadioButton] that was pressed to make the change.""" + + @property + def control(self) -> RadioSet: + """A reference to the [`RadioSet`][textual.widgets.RadioSet] that was changed. + + This is an alias for [`Changed.radio_set`][textual.widgets.RadioSet.Changed.radio_set] + and is used by the [`on`][textual.on] decorator. + """ + return self.radio_set + + def __rich_repr__(self) -> rich.repr.Result: + yield "radio_set", self.radio_set + yield "pressed", self.pressed + yield "index", self.index + + def __init__( + self, + *buttons: str | RadioButton, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + tooltip: RenderableType | None = None, + compact: bool = False, + ) -> None: + """Initialise the radio set. + + Args: + buttons: The labels or [`RadioButton`][textual.widgets.RadioButton]s to group together. + name: The name of the radio set. + id: The ID of the radio set in the DOM. + classes: The CSS classes of the radio set. + disabled: Whether the radio set is disabled or not. + tooltip: Optional tooltip. + compact: Enable compact radio set style + + Note: + When a `str` label is provided, a + [RadioButton][textual.widgets.RadioButton] will be created from + it. + """ + self._pressed_button: RadioButton | None = None + """Holds the radio buttons we're responsible for.""" + super().__init__( + *[ + (button if isinstance(button, RadioButton) else RadioButton(button)) + for button in buttons + ], + name=name, + id=id, + classes=classes, + disabled=disabled, + ) + if tooltip is not None: + self.tooltip = tooltip + self.compact = compact + + def _on_mount(self, _: Mount) -> None: + """Perform some processing once mounted in the DOM.""" + + # If there are radio buttons, select the first available one. + self.action_next_button() + + # Get all the buttons within us; we'll be doing a couple of things + # with that list. + buttons = list(self.query(RadioButton)) + + # RadioButtons can have focus, by default. But we're going to take + # that over and handle movement between them. So here we tell them + # all they can't focus. + for button in buttons: + button.can_focus = False + + # It's possible for the user to pass in a collection of radio + # buttons, with more than one set to on; they shouldn't, but we + # can't stop them. So here we check for that and, for want of a + # better approach, we keep the first one on and turn all the others + # off. + switched_on = [button for button in buttons if button.value] + with self.prevent(RadioButton.Changed): + for button in switched_on[1:]: + button.value = False + + # Keep track of which button is initially pressed. + if switched_on: + self._pressed_button = switched_on[0] + + def watch__selected(self) -> None: + self.query(RadioButton).remove_class("-selected") + if self._selected is not None: + self._nodes[self._selected].add_class("-selected") + self._scroll_to_selected() + + def _on_radio_button_changed(self, event: RadioButton.Changed) -> None: + """Respond to the value of a button in the set being changed. + + Args: + event: The event. + """ + # We're going to consume the underlying radio button events, making + # it appear as if they don't emit their own, as far as the caller is + # concerned. As such, stop the event bubbling and also prohibit the + # same event being sent out if/when we make a value change in here. + event.stop() + with self.prevent(RadioButton.Changed): + # If the message pertains to a button being clicked to on... + if event.radio_button.value: + # If there's a button pressed right now and it's not really a + # case of the user mashing on the same button... + if ( + self._pressed_button is not None + and self._pressed_button != event.radio_button + ): + self._pressed_button.value = False + # Make the pressed button this new button. + self._pressed_button = event.radio_button + # Emit a message to say our state has changed. + self.post_message(self.Changed(self, event.radio_button)) + else: + # We're being clicked off, we don't want that. + event.radio_button.value = True + + def _on_radio_set_changed(self, event: RadioSet.Changed) -> None: + """Handle a change to which button in the set is pressed. + + This handler ensures that, when a button is pressed, it's also the + selected button. + """ + self._selected = event.index + + async def _on_click(self, _: Click) -> None: + """Handle a click on or within the radio set. + + This handler ensures that focus moves to the clicked radio set, even + if there's a click on one of the radio buttons it contains. + """ + self.focus() + + @property + def pressed_button(self) -> RadioButton | None: + """The currently-pressed [`RadioButton`][textual.widgets.RadioButton], or `None` if none are pressed.""" + return self._pressed_button + + @property + def pressed_index(self) -> int: + """The index of the currently-pressed [`RadioButton`][textual.widgets.RadioButton], or -1 if none are pressed.""" + return ( + self._nodes.index(self._pressed_button) + if self._pressed_button is not None + else -1 + ) + + def action_previous_button(self) -> None: + """Navigate to the previous button in the set. + + Note that this will wrap around to the end if at the start. + """ + self._selected = _widget_navigation.find_next_enabled( + self.children, + anchor=self._selected, + direction=-1, + ) + + def action_next_button(self) -> None: + """Navigate to the next button in the set. + + Note that this will wrap around to the start if at the end. + """ + self._selected = _widget_navigation.find_next_enabled( + self.children, + anchor=self._selected, + direction=1, + ) + + def action_toggle_button(self) -> None: + """Toggle the state of the currently-selected button.""" + if self._selected is not None: + button = self._nodes[self._selected] + assert isinstance(button, RadioButton) + button.toggle() + + def _scroll_to_selected(self) -> None: + """Ensure that the selected button is in view.""" + if self._selected is not None: + button = self._nodes[self._selected] + self.call_after_refresh(self.scroll_to_widget, button, animate=False) diff --git a/src/memray/_vendor/textual/widgets/_rich_log.py b/src/memray/_vendor/textual/widgets/_rich_log.py new file mode 100644 index 0000000000..b5b41d7479 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_rich_log.py @@ -0,0 +1,320 @@ +"""Provides a scrollable text-logging widget.""" + +from __future__ import annotations + +from collections import deque +from typing import TYPE_CHECKING, NamedTuple, Optional, cast + +from rich.console import RenderableType +from rich.highlighter import Highlighter, ReprHighlighter +from rich.measure import measure_renderables +from rich.pretty import Pretty +from rich.protocol import is_renderable +from rich.segment import Segment +from rich.text import Text + +from memray._vendor.textual.cache import LRUCache +from memray._vendor.textual.events import Resize +from memray._vendor.textual.geometry import Size +from memray._vendor.textual.reactive import var +from memray._vendor.textual.scroll_view import ScrollView +from memray._vendor.textual.strip import Strip + +if TYPE_CHECKING: + from typing_extensions import Self + + +class DeferredRender(NamedTuple): + """A renderable which is awaiting rendering. + This may happen if a `write` occurs before the width is known. + + The arguments are the same as for `RichLog.write`, as this just + represents a deferred call to that method. + """ + + content: RenderableType | object + """The content to render.""" + width: int | None = None + """The width to render or `None` to use optimal width.""" + expand: bool = False + """Enable expand to widget width, or `False` to use `width`.""" + shrink: bool = True + """Enable shrinking of content to fit width.""" + scroll_end: bool | None = None + """Enable automatic scroll to end, or `None` to use `self.auto_scroll`.""" + + +class RichLog(ScrollView, can_focus=True): + """A widget for logging Rich renderables and text.""" + + DEFAULT_CSS = """ + RichLog{ + background: $surface; + color: $foreground; + overflow-y: scroll; + &:focus { + background-tint: $foreground 5%; + } + } + """ + + max_lines: var[int | None] = var[Optional[int]](None) + min_width: var[int] = var(78) + wrap: var[bool] = var(False) + highlight: var[bool] = var(False) + markup: var[bool] = var(False) + auto_scroll: var[bool] = var(True) + + def __init__( + self, + *, + max_lines: int | None = None, + min_width: int = 78, + wrap: bool = False, + highlight: bool = False, + markup: bool = False, + auto_scroll: bool = True, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + ) -> None: + """Create a `RichLog` widget. + + Args: + max_lines: Maximum number of lines in the log or `None` for no maximum. + min_width: Width to use for calls to `write` with no specified `width`. + wrap: Enable word wrapping (default is off). + highlight: Automatically highlight content. By default, the `ReprHighlighter` is used. + To customize highlighting, set `highlight=True` and then set the `highlighter` + attribute to an instance of `Highlighter`. + markup: Apply Rich console markup. + auto_scroll: Enable automatic scrolling to end. + name: The name of the text log. + id: The ID of the text log in the DOM. + classes: The CSS classes of the text log. + disabled: Whether the text log is disabled or not. + """ + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + self.max_lines = max_lines + """Maximum number of lines in the log or `None` for no maximum.""" + self._start_line: int = 0 + self.lines: list[Strip] = [] + """The lines currently visible in the log.""" + self._line_cache: LRUCache[tuple[int, int, int, int], Strip] + self._line_cache = LRUCache(1024) + self._deferred_renders: deque[DeferredRender] = deque() + """Queue of deferred renderables to be rendered.""" + self.min_width = min_width + """Minimum width of renderables.""" + self.wrap = wrap + """Enable word wrapping.""" + self.highlight = highlight + """Automatically highlight content.""" + self.markup = markup + """Apply Rich console markup.""" + self.auto_scroll = auto_scroll + """Automatically scroll to the end on write.""" + self.highlighter: Highlighter = ReprHighlighter() + """Rich Highlighter used to highlight content when highlight is True""" + + self._widest_line_width = 0 + """The width of the widest line currently in the log.""" + + self._size_known = False + """Flag which is set to True when the size of the RichLog is known, + indicating we can proceed with rendering deferred writes.""" + + def notify_style_update(self) -> None: + super().notify_style_update() + self._line_cache.clear() + + def on_resize(self, event: Resize) -> None: + if event.size.width and not self._size_known: + # This size is known for the first time. + self._size_known = True + deferred_renders = self._deferred_renders + while deferred_renders: + deferred_render = deferred_renders.popleft() + self.write(*deferred_render) + + def get_content_width(self, container: Size, viewport: Size) -> int: + if self._size_known: + return self.virtual_size.width + else: + return container.width + + def _make_renderable(self, content: RenderableType | object) -> RenderableType: + """Make content renderable. + + Args: + content: Content to render. + + Returns: + A Rich renderable. + """ + renderable: RenderableType + if not is_renderable(content): + renderable = Pretty(content) + else: + if isinstance(content, str): + if self.markup: + renderable = Text.from_markup(content) + else: + renderable = Text(content) + if self.highlight: + renderable = self.highlighter(renderable) + else: + renderable = cast(RenderableType, content) + + if isinstance(renderable, Text): + renderable.expand_tabs() + + return renderable + + def write( + self, + content: RenderableType | object, + width: int | None = None, + expand: bool = False, + shrink: bool = True, + scroll_end: bool | None = None, + animate: bool = False, + ) -> Self: + """Write a string or a Rich renderable to the bottom of the log. + + Notes: + The rendering of content will be deferred until the size of the `RichLog` is known. + This means if you call `write` in `compose` or `on_mount`, the content will not be + rendered immediately. + + Args: + content: Rich renderable (or a string). + width: Width to render, or `None` to use `RichLog.min_width`. + If specified, `expand` and `shrink` will be ignored. + expand: Permit expanding of content to the width of the content region of the RichLog. + If `width` is specified, then `expand` will be ignored. + shrink: Permit shrinking of content to fit within the content region of the RichLog. + If `width` is specified, then `shrink` will be ignored. + scroll_end: Enable automatic scroll to end, or `None` to use `self.auto_scroll`. + animate: Enable animation if the log will scroll. + + Returns: + The `RichLog` instance. + """ + if not self._size_known: + # We don't know the size yet, so we'll need to render this later. + # We defer ALL writes until the size is known, to ensure ordering is preserved. + if isinstance(content, Text): + content = content.copy() + self._deferred_renders.append( + DeferredRender(content, width, expand, shrink, scroll_end) + ) + return self + + renderable = self._make_renderable(content) + auto_scroll = self.auto_scroll if scroll_end is None else scroll_end + + console = self.app.console + render_options = console.options + + if isinstance(renderable, Text) and not self.wrap: + render_options = render_options.update(overflow="ignore", no_wrap=True) + + if width is not None: + # Use the width specified by the caller. + # We ignore `expand` and `shrink` when a width is specified. + # This also overrides `min_width` set on the RichLog. + render_width = width + else: + # Compute the width based on available information. + renderable_width = measure_renderables( + console, render_options, [renderable] + ).maximum + + render_width = renderable_width + scrollable_content_width = self.scrollable_content_region.width + + if expand and renderable_width < scrollable_content_width: + # Expand the renderable to the width of the scrollable content region. + render_width = max(renderable_width, scrollable_content_width) + + if shrink and renderable_width > scrollable_content_width: + # Shrink the renderable down to fit within the scrollable content region. + render_width = min(renderable_width, scrollable_content_width) + + # The user has not supplied a width, so make sure min_width is respected. + render_width = max(render_width, self.min_width) + + render_options = render_options.update_width(render_width) + + # Render into (possibly) wrapped lines. + segments = self.app.console.render(renderable, render_options) + lines = list(Segment.split_lines(segments)) + + if not lines: + self._widest_line_width = max(render_width, self._widest_line_width) + self.lines.append(Strip.blank(render_width)) + else: + strips = Strip.from_lines(lines) + for strip in strips: + strip.adjust_cell_length(render_width) + self.lines.extend(strips) + + if self.max_lines is not None and len(self.lines) > self.max_lines: + self._start_line += len(self.lines) - self.max_lines + self.refresh() + self.lines = self.lines[-self.max_lines :] + + # Compute the width after wrapping and trimming + # TODO - this is wrong because if we trim a long line, the max width + # could decrease, but we don't look at which lines were trimmed here. + self._widest_line_width = max( + self._widest_line_width, + max(sum([segment.cell_length for segment in _line]) for _line in lines), + ) + + # Update the virtual size - the width may have changed after adding + # the new line(s), and the height will definitely have changed. + self.virtual_size = Size(self._widest_line_width, len(self.lines)) + + if auto_scroll: + self.scroll_end(animate=animate, immediate=False, x_axis=False) + + return self + + def clear(self) -> Self: + """Clear the text log. + + Returns: + The `RichLog` instance. + """ + self.lines.clear() + self._line_cache.clear() + self._start_line = 0 + self._widest_line_width = 0 + self._deferred_renders.clear() + self.virtual_size = Size(0, len(self.lines)) + self.refresh() + return self + + def render_line(self, y: int) -> Strip: + scroll_x, scroll_y = self.scroll_offset + line = self._render_line( + scroll_y + y, scroll_x, self.scrollable_content_region.width + ) + strip = line.apply_style(self.rich_style) + return strip + + def _render_line(self, y: int, scroll_x: int, width: int) -> Strip: + if y >= len(self.lines): + return Strip.blank(width, self.rich_style) + + key = (y + self._start_line, scroll_x, width, self._widest_line_width) + if key in self._line_cache: + return self._line_cache[key] + + line = self.lines[y].crop_extend(scroll_x, scroll_x + width, self.rich_style) + + self._line_cache[key] = line + return line diff --git a/src/memray/_vendor/textual/widgets/_rule.py b/src/memray/_vendor/textual/widgets/_rule.py new file mode 100644 index 0000000000..b086c3e2a5 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_rule.py @@ -0,0 +1,268 @@ +from __future__ import annotations + +from typing import Iterable + +from rich.console import Console, ConsoleOptions +from rich.segment import Segment +from rich.style import Style +from typing_extensions import Literal + +from memray._vendor.textual.app import RenderResult +from memray._vendor.textual.css._error_tools import friendly_list +from memray._vendor.textual.geometry import Size +from memray._vendor.textual.reactive import Reactive, reactive +from memray._vendor.textual.widget import Widget + +RuleOrientation = Literal["horizontal", "vertical"] +"""The valid orientations of the rule widget.""" + +LineStyle = Literal[ + "ascii", + "blank", + "dashed", + "double", + "heavy", + "hidden", + "none", + "solid", + "thick", +] +"""The valid line styles of the rule widget.""" + + +_VALID_RULE_ORIENTATIONS = {"horizontal", "vertical"} + +_VALID_LINE_STYLES = { + "ascii", + "blank", + "dashed", + "double", + "heavy", + "hidden", + "none", + "solid", + "thick", +} + +_HORIZONTAL_LINE_CHARS: dict[LineStyle, str] = { + "ascii": "-", + "blank": " ", + "dashed": "╍", + "double": "═", + "heavy": "━", + "hidden": " ", + "none": " ", + "solid": "─", + "thick": "█", +} + +_VERTICAL_LINE_CHARS: dict[LineStyle, str] = { + "ascii": "|", + "blank": " ", + "dashed": "╏", + "double": "║", + "heavy": "┃", + "hidden": " ", + "none": " ", + "solid": "│", + "thick": "█", +} + + +class InvalidRuleOrientation(Exception): + """Exception raised for an invalid rule orientation.""" + + +class InvalidLineStyle(Exception): + """Exception raised for an invalid rule line style.""" + + +class HorizontalRuleRenderable: + """Renders a horizontal rule.""" + + def __init__(self, character: str, style: Style, width: int): + self.character = character + self.style = style + self.width = width + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> Iterable[Segment]: + yield Segment(self.width * self.character, self.style) + + +class VerticalRuleRenderable: + """Renders a vertical rule.""" + + def __init__(self, character: str, style: Style, height: int): + self.character = character + self.style = style + self.height = height + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> Iterable[Segment]: + segment = Segment(self.character, self.style) + new_line = Segment.line() + return ([segment, new_line] * self.height)[:-1] + + +class Rule(Widget, can_focus=False): + """A rule widget to separate content, similar to a `
` HTML tag.""" + + DEFAULT_CSS = """ + Rule { + color: $secondary; + } + + Rule.-horizontal { + height: 1; + margin: 1 0; + width: 1fr; + } + + Rule.-vertical { + width: 1; + margin: 0 2; + height: 1fr; + } + """ + + orientation: Reactive[RuleOrientation] = reactive[RuleOrientation]("horizontal") + """The orientation of the rule.""" + + line_style: Reactive[LineStyle] = reactive[LineStyle]("solid") + """The line style of the rule.""" + + def __init__( + self, + orientation: RuleOrientation = "horizontal", + line_style: LineStyle = "solid", + *, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + ) -> None: + """Initialize a rule widget. + + Args: + orientation: The orientation of the rule. + line_style: The line style of the rule. + name: The name of the widget. + id: The ID of the widget in the DOM. + classes: The CSS classes of the widget. + disabled: Whether the widget is disabled or not. + """ + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + self.orientation = orientation + self.line_style = line_style + self.expand = True + + def render(self) -> RenderResult: + rule_character: str + style = self.rich_style + if self.orientation == "vertical": + rule_character = _VERTICAL_LINE_CHARS[self.line_style] + return VerticalRuleRenderable( + rule_character, style, self.content_size.height + ) + elif self.orientation == "horizontal": + rule_character = _HORIZONTAL_LINE_CHARS[self.line_style] + return HorizontalRuleRenderable( + rule_character, style, self.content_size.width + ) + else: + raise InvalidRuleOrientation( + f"Valid rule orientations are {friendly_list(_VALID_RULE_ORIENTATIONS)}" + ) + + def watch_orientation( + self, old_orientation: RuleOrientation, orientation: RuleOrientation + ) -> None: + self.remove_class(f"-{old_orientation}") + self.add_class(f"-{orientation}") + + def validate_orientation(self, orientation: RuleOrientation) -> RuleOrientation: + if orientation not in _VALID_RULE_ORIENTATIONS: + raise InvalidRuleOrientation( + f"Valid rule orientations are {friendly_list(_VALID_RULE_ORIENTATIONS)}" + ) + return orientation + + def validate_line_style(self, style: LineStyle) -> LineStyle: + if style not in _VALID_LINE_STYLES: + raise InvalidLineStyle( + f"Valid rule line styles are {friendly_list(_VALID_LINE_STYLES)}" + ) + return style + + def get_content_width(self, container: Size, viewport: Size) -> int: + if self.orientation == "horizontal": + return container.width + return 1 + + def get_content_height(self, container: Size, viewport: Size, width: int) -> int: + if self.orientation == "horizontal": + return 1 + return container.height + + @classmethod + def horizontal( + cls, + line_style: LineStyle = "solid", + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + ) -> Rule: + """Utility constructor for creating a horizontal rule. + + Args: + line_style: The line style of the rule. + name: The name of the widget. + id: The ID of the widget in the DOM. + classes: The CSS classes of the widget. + disabled: Whether the widget is disabled or not. + + Returns: + A rule widget with horizontal orientation. + """ + return Rule( + orientation="horizontal", + line_style=line_style, + name=name, + id=id, + classes=classes, + disabled=disabled, + ) + + @classmethod + def vertical( + cls, + line_style: LineStyle = "solid", + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + ) -> Rule: + """Utility constructor for creating a vertical rule. + + Args: + line_style: The line style of the rule. + name: The name of the widget. + id: The ID of the widget in the DOM. + classes: The CSS classes of the widget. + disabled: Whether the widget is disabled or not. + + Returns: + A rule widget with vertical orientation. + """ + return Rule( + orientation="vertical", + line_style=line_style, + name=name, + id=id, + classes=classes, + disabled=disabled, + ) diff --git a/src/memray/_vendor/textual/widgets/_select.py b/src/memray/_vendor/textual/widgets/_select.py new file mode 100644 index 0000000000..2b1f976ebe --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_select.py @@ -0,0 +1,716 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Generic, Hashable, Iterable, TypeVar + +import rich.repr +from rich.console import RenderableType +from rich.text import Text + +from memray._vendor.textual import events, on +from memray._vendor.textual.binding import Binding +from memray._vendor.textual.containers import Horizontal, Vertical +from memray._vendor.textual.css.query import NoMatches +from memray._vendor.textual.message import Message +from memray._vendor.textual.reactive import reactive, var +from memray._vendor.textual.timer import Timer +from memray._vendor.textual.widgets import Static +from memray._vendor.textual.widgets._option_list import Option, OptionList + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from memray._vendor.textual.app import ComposeResult + + +class NonSelectableStatic(Static): + ALLOW_SELECT = False + + +class NoSelection: + """Used by the `Select` widget to flag the unselected state. See [`Select.NULL`][textual.widgets.Select.NULL].""" + + def __repr__(self) -> str: + return "Select.NULL" + + +NULL = NoSelection() + + +class InvalidSelectValueError(Exception): + """Raised when setting a [`Select`][textual.widgets.Select] to an unknown option.""" + + +class EmptySelectError(Exception): + """Raised when a [`Select`][textual.widgets.Select] has no options and `allow_blank=False`.""" + + +class SelectOverlay(OptionList): + """The 'pop-up' overlay for the Select control.""" + + BINDINGS = [("escape", "dismiss", "Dismiss menu")] + + ALLOW_SELECT = False + + @dataclass + class Dismiss(Message): + """Inform ancestor the overlay should be dismissed.""" + + lost_focus: bool = False + """True if the overlay lost focus.""" + + @dataclass + class UpdateSelection(Message): + """Inform ancestor the selection was changed.""" + + option_index: int + """The index of the new selection.""" + + def __init__(self, type_to_search: bool = True) -> None: + super().__init__() + self._type_to_search = type_to_search + """If True (default), the user can type to search for a matching option and the cursor will jump to it.""" + + self._search_query: str = "" + """The current search query used to find a matching option and jump to it.""" + + self._search_reset_delay: float = 0.7 + """The number of seconds to wait after the most recent key press before resetting the search query.""" + + def on_mount(self) -> None: + def reset_query() -> None: + self._search_query = "" + + self._search_reset_timer = Timer( + self, self._search_reset_delay, callback=reset_query + ) + + def watch_has_focus(self, value: bool) -> None: + self._search_query = "" + if value: + self._search_reset_timer._start() + else: + self._search_reset_timer.reset() + self._search_reset_timer.stop() + super().watch_has_focus(value) + + async def _on_key(self, event: events.Key) -> None: + if not self._type_to_search: + return + + self._search_reset_timer.reset() + + if event.character is not None and event.is_printable: + event.time = 0 + event.stop() + event.prevent_default() + + # Update the search query and jump to the next option that matches. + self._search_query += event.character + index = self._find_search_match(self._search_query) + if index is not None: + self.select(index) + + def check_consume_key(self, key: str, character: str | None = None) -> bool: + """Check if the widget may consume the given key.""" + return ( + self._type_to_search and character is not None and character.isprintable() + ) + + def select(self, index: int | None) -> None: + """Move selection. + + Args: + index: Index of new selection. + """ + self.highlighted = index + self.scroll_to_highlight() + + def _find_search_match(self, query: str) -> int | None: + """A simple substring search which favors options containing the substring + earlier in the prompt. + + Args: + query: The substring to search for. + + Returns: + The index of the option that matches the query, or `None` if no match is found. + """ + best_match: int | None = None + minimum_index: int | None = None + + query = query.lower() + for index, option in enumerate(self._options): + prompt = option.prompt + if isinstance(prompt, Text): + lower_prompt = prompt.plain.lower() + elif isinstance(prompt, str): + lower_prompt = prompt.lower() + else: + continue + + match_index = lower_prompt.find(query) + if match_index != -1 and ( + minimum_index is None or match_index < minimum_index + ): + best_match = index + minimum_index = match_index + + return best_match + + def action_dismiss(self) -> None: + """Dismiss the overlay.""" + self.post_message(self.Dismiss()) + + def _on_blur(self, _event: events.Blur) -> None: + """On blur we want to dismiss the overlay.""" + self.post_message(self.Dismiss(lost_focus=True)) + self.suppress_click() + + def on_option_list_option_selected(self, event: OptionList.OptionSelected) -> None: + """Inform parent when an option is selected.""" + event.stop() + self.post_message(self.UpdateSelection(event.option_index)) + + def on_option_list_option_highlighted( + self, event: OptionList.OptionHighlighted + ) -> None: + """Stop option list highlighted messages leaking.""" + event.stop() + + +class SelectCurrent(Horizontal): + """Displays the currently selected option.""" + + DEFAULT_CSS = """ + SelectCurrent { + border: tall $border-blurred; + color: $foreground; + background: $surface; + width: 1fr; + height: auto; + padding: 0 2; + pointer: pointer; + + &.-textual-compact { + border: none !important; + } + + &:ansi { + border: tall ansi_blue; + color: ansi_default; + background: ansi_default; + } + + Static#label { + width: 1fr; + height: auto; + color: $foreground 50%; + background: transparent; + } + + &.-has-value Static#label { + color: $foreground; + } + + .arrow { + box-sizing: content-box; + width: 1; + height: 1; + padding: 0 0 0 1; + color: $foreground 50%; + background: transparent; + } + } + """ + + ALLOW_SELECT = False + + has_value: var[bool] = var(False) + """True if there is a current value, or False if it is None.""" + + class Toggle(Message): + """Request toggle overlay.""" + + def __init__(self, placeholder: str) -> None: + """Initialize the SelectCurrent. + + Args: + placeholder: A string to display when there is nothing selected. + """ + super().__init__() + self.placeholder = placeholder + self.label: RenderableType | NoSelection = Select.NULL + + def update(self, label: RenderableType | NoSelection) -> None: + """Update the content in the widget. + + Args: + label: A renderable to display, or `None` for the placeholder. + """ + self.label = label + self.has_value = label is not Select.NULL + self.query_one("#label", Static).update( + self.placeholder if isinstance(label, NoSelection) else label + ) + + def compose(self) -> ComposeResult: + """Compose label and down arrow.""" + yield NonSelectableStatic(self.placeholder, id="label") + yield NonSelectableStatic("▼", classes="arrow down-arrow") + yield NonSelectableStatic("▲", classes="arrow up-arrow") + + def _watch_has_value(self, has_value: bool) -> None: + """Toggle the class.""" + self.set_class(has_value, "-has-value") + + def _on_click(self, event: events.Click) -> None: + """Inform ancestor we want to toggle.""" + event.stop() + self.post_message(self.Toggle()) + + +SelectType = TypeVar("SelectType", bound=Hashable) +"""The type used for data in the Select.""" +SelectOption: TypeAlias = "tuple[str, SelectType]" +"""The type used for options in the Select.""" + + +class Select(Generic[SelectType], Vertical, can_focus=True): + """Widget to select from a list of possible options. + + A Select displays the current selection. + When activated with ++enter++ the widget displays an overlay with a list of all possible options. + """ + + NULL = NULL + """Constant to flag that the widget has no selection.""" + + BINDINGS = [ + Binding("enter,down,space,up", "show_overlay", "Show menu", show=False), + ] + """ + | Key(s) | Description | + | :- | :- | + | enter,down,space,up | Activate the overlay | + """ + + ALLOW_SELECT = False + + DEFAULT_CSS = """ + Select { + height: auto; + color: $foreground; + + &.-textual-compact { + & > SelectCurrent { + padding: 0 1 0 0; + border: none !important; + } + } + + .up-arrow { + display: none; + } + + &:focus > SelectCurrent { + border: tall $border; + background-tint: $foreground 5%; + } + + & > SelectOverlay { + width: 1fr; + display: none; + height: auto; + max-height: 12; + overlay: screen; + constrain: none inside; + color: $foreground; + border: tall $border-blurred; + background: $surface; + &:focus { + background-tint: $foreground 5%; + } + & > .option-list--option { + padding: 0 1; + } + } + + &.-expanded { + .down-arrow { + display: none; + } + .up-arrow { + display: block; + } + & > SelectOverlay { + display: block; + } + } + + } + + """ + + expanded: var[bool] = var(False, init=False) + """True to show the overlay, otherwise False.""" + prompt: var[str] = var[str]("Select") + """The prompt to show when no value is selected.""" + value: var[SelectType | NoSelection] = var(NULL, init=False) + """The value of the selection. + + If the widget has no selection, its value will be [`Select.NULL`][textual.widgets.Select.NULL]. + Setting this to an illegal value will raise a [`InvalidSelectValueError`][textual.widgets.select.InvalidSelectValueError] + exception. + """ + + compact = reactive(False, toggle_class="-textual-compact") + """Make the select compact (without borders).""" + + @rich.repr.auto + class Changed(Message): + """Posted when the select value was changed. + + This message can be handled using a `on_select_changed` method. + """ + + def __init__( + self, select: Select[SelectType], value: SelectType | NoSelection + ) -> None: + """ + Initialize the Changed message. + """ + super().__init__() + self.select = select + """The select widget.""" + self.value = value + """The value of the Select when it changed.""" + + def __rich_repr__(self) -> rich.repr.Result: + yield self.select + yield self.value + + @property + def control(self) -> Select[SelectType]: + """The Select that sent the message.""" + return self.select + + def __init__( + self, + options: Iterable[tuple[RenderableType, SelectType]], + *, + prompt: str = "Select", + allow_blank: bool = True, + value: SelectType | NoSelection = NULL, + type_to_search: bool = True, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + tooltip: RenderableType | None = None, + compact: bool = False, + ): + """Initialize the Select control. + + Args: + options: Options to select from. If no options are provided then + `allow_blank` must be set to `True`. + prompt: Text to show in the control when no option is selected. + allow_blank: Enables or disables the ability to have the widget in a state + with no selection made, in which case its value is set to the constant + [`Select.NULL`][textual.widgets.Select.NULL]. + value: Initial value selected. Should be one of the values in `options`. + If no initial value is set and `allow_blank` is `False`, the widget + will auto-select the first available option. + type_to_search: If `True`, typing will search for options. + name: The name of the select control. + id: The ID of the control in the DOM. + classes: The CSS classes of the control. + disabled: Whether the control is disabled or not. + tooltip: Optional tooltip. + compact: Enable compact select (without borders). + + Raises: + EmptySelectError: If no options are provided and `allow_blank` is `False`. + """ + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + self._allow_blank = allow_blank + self.prompt = prompt + self._value = value + self._setup_variables_for_options(options) + self._type_to_search = type_to_search + if tooltip is not None: + self.tooltip = tooltip + self.compact = compact + + @classmethod + def from_values( + cls, + values: Iterable[SelectType], + *, + prompt: str = "Select", + allow_blank: bool = True, + value: SelectType | NoSelection = NULL, + type_to_search: bool = True, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + compact: bool = False, + ) -> Select[SelectType]: + """Initialize the Select control with values specified by an arbitrary iterable + + The options shown in the control are computed by calling the built-in `str` + on each value. + + Args: + values: Values used to generate options to select from. + prompt: Text to show in the control when no option is selected. + allow_blank: Enables or disables the ability to have the widget in a state + with no selection made, in which case its value is set to the constant + [`Select.NULL`][textual.widgets.Select.NULL]. + value: Initial value selected. Should be one of the values in `values`. + If no initial value is set and `allow_blank` is `False`, the widget + will auto-select the first available value. + type_to_search: If `True`, typing will search for options. + name: The name of the select control. + id: The ID of the control in the DOM. + classes: The CSS classes of the control. + disabled: Whether the control is disabled or not. + compact: Enable compact style? + + Returns: + A new Select widget with the provided values as options. + """ + options_iterator = [(str(value), value) for value in values] + + return cls( + options_iterator, + prompt=prompt, + allow_blank=allow_blank, + value=value, + type_to_search=type_to_search, + name=name, + id=id, + classes=classes, + disabled=disabled, + compact=compact, + ) + + @property + def selection(self) -> SelectType | None: + """The currently selected item. + + Unlike [value][textual.widgets.Select.value], this will not return Blanks. + If nothing is selected, this will return `None`. + + """ + value = self.value + if isinstance(value, NoSelection): + return None + return value + + def _setup_variables_for_options( + self, + options: Iterable[tuple[RenderableType, SelectType]], + ) -> None: + """Setup function for the auxiliary variables related to options. + + This method sets up `self._options` and `self._legal_values`. + """ + self._options: list[tuple[RenderableType, SelectType | NoSelection]] = [] + if self._allow_blank: + self._options.append(("", self.NULL)) + self._options.extend(options) + + if not self._options: + raise EmptySelectError( + "Select options cannot be empty if selection can't be blank." + ) + + self._legal_values: set[SelectType | NoSelection] = { + value for _, value in self._options + } + + def _setup_options_renderables(self) -> None: + """Sets up the `Option` renderables associated with the `Select` options.""" + options: list[Option] = [ + ( + Option(Text(self.prompt, style="dim")) + if value == self.NULL + else Option(prompt) + ) + for prompt, value in self._options + ] + + option_list = self.query_one(SelectOverlay) + option_list.clear_options() + option_list.add_options(options) + + def _init_selected_option(self, hint: SelectType | NoSelection = NULL) -> None: + """Initialises the selected option for the `Select`.""" + if hint == self.NULL and not self._allow_blank: + hint = self._options[0][1] + self.value = hint + + def set_options(self, options: Iterable[tuple[RenderableType, SelectType]]) -> None: + """Set the options for the Select. + + This will reset the selection. The selection will be empty, if allowed, otherwise + the first valid option is picked. + + Args: + options: An iterable of tuples containing the renderable to display for each + option and the corresponding internal value. + + Raises: + EmptySelectError: If the options iterable is empty and `allow_blank` is + `False`. + """ + self._setup_variables_for_options(options) + self._setup_options_renderables() + self._init_selected_option() + + def _validate_value( + self, value: SelectType | NoSelection + ) -> SelectType | NoSelection: + """Ensure the new value is a valid option. + + If `allow_blank` is `True`, `None` is also a valid value and corresponds to no + selection. + + Raises: + InvalidSelectValueError: If the new value does not correspond to any known + value. + """ + if value not in self._legal_values: + # It would make sense to use `None` to flag that the Select has no selection, + # so we provide a helpful message to catch this mistake in case people didn't + # realise we use a special value to flag "no selection". + help_text = " Did you mean to use Select.clear()?" if value is None else "" + raise InvalidSelectValueError( + f"Illegal select value {value!r}." + help_text + ) + + return value + + def _watch_value(self, value: SelectType | NoSelection) -> None: + """Update the current value when it changes.""" + self._value = value + try: + select_current = self.query_one(SelectCurrent) + except NoMatches: + pass + else: + if value == self.NULL: + select_current.update(self.NULL) + else: + for index, (prompt, _value) in enumerate(self._options): + if _value == value: + select_overlay = self.query_one(SelectOverlay) + select_overlay.highlighted = index + select_current.update(prompt) + break + self.post_message(self.Changed(self, value)) + + def compose(self) -> ComposeResult: + """Compose Select with overlay and current value.""" + yield SelectCurrent(self.prompt) + yield SelectOverlay(type_to_search=self._type_to_search).data_bind( + compact=Select.compact + ) + + def _on_mount(self, _event: events.Mount) -> None: + """Set initial values.""" + self._setup_options_renderables() + self._init_selected_option(self._value) + + def _watch_expanded(self, expanded: bool) -> None: + """Display or hide overlay.""" + try: + overlay = self.query_one(SelectOverlay) + except NoMatches: + # The widget has likely been removed + return + self.set_class(expanded, "-expanded") + if expanded: + overlay.focus(scroll_visible=False) + if self.value is self.NULL: + overlay.select(None) + self.query_one(SelectCurrent).has_value = False + else: + value = self.value + for index, (_prompt, prompt_value) in enumerate(self._options): + if value == prompt_value: + overlay.select(index) + break + self.query_one(SelectCurrent).has_value = True + + @on(SelectCurrent.Toggle) + def _select_current_toggle(self, event: SelectCurrent.Toggle) -> None: + """Show the overlay when toggled.""" + event.stop() + self.expanded = not self.expanded + + @on(SelectOverlay.Dismiss) + def _select_overlay_dismiss(self, event: SelectOverlay.Dismiss) -> None: + """Dismiss the overlay.""" + event.stop() + self.expanded = False + if not event.lost_focus: + # If the overlay didn't lose focus, we want to re-focus the select. + self.focus() + + @on(SelectOverlay.UpdateSelection) + def _update_selection(self, event: SelectOverlay.UpdateSelection) -> None: + """Update the current selection.""" + event.stop() + value = self._options[event.option_index][1] + if value != self.value: + self.value = value + + self.focus() + self.expanded = False + + def action_show_overlay(self) -> None: + """Show the overlay.""" + select_current = self.query_one(SelectCurrent) + select_current.has_value = True + self.expanded = True + # If we haven't opened the overlay yet, highlight the first option. + select_overlay = self.query_one(SelectOverlay) + if select_overlay.highlighted is None: + select_overlay.action_first() + + def is_blank(self) -> bool: + """Indicates whether this `Select` is blank or not. + + Returns: + True if the selection is blank, False otherwise. + """ + return self.value == self.NULL + + def clear(self) -> None: + """Clear the selection if `allow_blank` is `True`. + + Raises: + InvalidSelectValueError: If `allow_blank` is set to `False`. + """ + try: + self.value = self.NULL + except InvalidSelectValueError: + raise InvalidSelectValueError( + "Can't clear selection if allow_blank is set to False." + ) from None + + def _watch_prompt(self, prompt: str) -> None: + if not self.is_mounted: + return + select_current = self.query_one(SelectCurrent) + select_current.placeholder = prompt + if not self._allow_blank: + return + if self.value == self.NULL: + select_current.update(self.NULL) + option_list = self.query_one(SelectOverlay) + option_list.replace_option_prompt_at_index(0, Text(prompt, style="dim")) diff --git a/src/memray/_vendor/textual/widgets/_selection_list.py b/src/memray/_vendor/textual/widgets/_selection_list.py new file mode 100644 index 0000000000..2e38d99ad9 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_selection_list.py @@ -0,0 +1,715 @@ +"""Provides a selection list widget, allowing one or more items to be selected.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, ClassVar, Generic, Iterable, TypeVar, cast + +from rich.repr import Result +from rich.segment import Segment +from rich.style import Style +from typing_extensions import Self + +from memray._vendor.textual import events +from memray._vendor.textual.binding import Binding +from memray._vendor.textual.content import Content, ContentText +from memray._vendor.textual.messages import Message +from memray._vendor.textual.strip import Strip +from memray._vendor.textual.widgets._option_list import ( + Option, + OptionDoesNotExist, + OptionList, + OptionListContent, +) +from memray._vendor.textual.widgets._toggle_button import ToggleButton + +SelectionType = TypeVar("SelectionType") +"""The type for the value of a [`Selection`][textual.widgets.selection_list.Selection] in a [`SelectionList`][textual.widgets.SelectionList]""" + +MessageSelectionType = TypeVar("MessageSelectionType") +"""The type for the value of a [`Selection`][textual.widgets.selection_list.Selection] in a [`SelectionList`][textual.widgets.SelectionList] message.""" + + +class SelectionError(TypeError): + """Type of an error raised if a selection is badly-formed.""" + + +class Selection(Generic[SelectionType], Option): + """A selection for a [`SelectionList`][textual.widgets.SelectionList].""" + + def __init__( + self, + prompt: ContentText, + value: SelectionType, + initial_state: bool = False, + id: str | None = None, + disabled: bool = False, + ): + """Initialise the selection. + + Args: + prompt: The prompt for the selection. + value: The value for the selection. + initial_state: The initial selected state of the selection. + id: The optional ID for the selection. + disabled: The initial enabled/disabled state. Enabled by default. + """ + + selection_prompt = Content.from_text(prompt) + super().__init__(selection_prompt.split()[0], id, disabled) + self._value: SelectionType = value + """The value associated with the selection.""" + self._initial_state: bool = initial_state + """The initial selected state for the selection.""" + + @property + def value(self) -> SelectionType: + """The value for this selection.""" + return self._value + + @property + def initial_state(self) -> bool: + """The initial selected state for the selection.""" + return self._initial_state + + +class SelectionList(Generic[SelectionType], OptionList): + """A vertical selection list that allows making multiple selections.""" + + BINDINGS = [Binding("space", "select", "Toggle option", show=False)] + """ + | Key(s) | Description | + | :- | :- | + | space | Toggle the state of the highlighted selection. | + """ + + COMPONENT_CLASSES: ClassVar[set[str]] = { + "selection-list--button", + "selection-list--button-selected", + "selection-list--button-highlighted", + "selection-list--button-selected-highlighted", + } + """ + | Class | Description | + | :- | :- | + | `selection-list--button` | Target the default button style. | + | `selection-list--button-selected` | Target a selected button style. | + | `selection-list--button-highlighted` | Target a highlighted button style. | + | `selection-list--button-selected-highlighted` | Target a highlighted selected button style. | + """ + + DEFAULT_CSS = """ + SelectionList { + height: auto; + text-wrap: nowrap; + text-overflow: ellipsis; + + & > .selection-list--button { + color: $panel-darken-2; + background: $panel; + } + + & > .selection-list--button-highlighted { + color: $panel-darken-2; + background: $panel; + } + + & > .selection-list--button-selected { + color: $text-success; + background: $panel; + } + + & > .selection-list--button-selected-highlighted { + color: $text-success; + background: $panel; + } + + } + """ + + class SelectionMessage(Generic[MessageSelectionType], Message): + """Base class for all selection messages.""" + + def __init__( + self, selection_list: SelectionList[MessageSelectionType], index: int + ) -> None: + """Initialise the selection message. + + Args: + selection_list: The selection list that owns the selection. + index: The index of the selection that the message relates to. + """ + super().__init__() + self.selection_list: SelectionList[MessageSelectionType] = selection_list + """The selection list that sent the message.""" + self.selection: Selection[MessageSelectionType] = ( + selection_list.get_option_at_index(index) + ) + """The highlighted selection.""" + self.selection_index: int = index + """The index of the selection that the message relates to.""" + + @property + def control(self) -> OptionList: + """The selection list that sent the message. + + This is an alias for + [`SelectionMessage.selection_list`][textual.widgets.SelectionList.SelectionMessage.selection_list] + and is used by the [`on`][textual.on] decorator. + """ + return self.selection_list + + def __rich_repr__(self) -> Result: + yield "selection_list", self.selection_list + yield "selection", self.selection + yield "selection_index", self.selection_index + + class SelectionHighlighted(SelectionMessage[MessageSelectionType]): + """Message sent when a selection is highlighted. + + Can be handled using `on_selection_list_selection_highlighted` in a subclass of + [`SelectionList`][textual.widgets.SelectionList] or in a parent node in the DOM. + """ + + class SelectionToggled(SelectionMessage[MessageSelectionType]): + """Message sent when a selection is toggled. + + This is only sent when the value is *explicitly* toggled e.g. + via `toggle` or `toggle_all`, or via user interaction. + If you programmatically set a value to be selected, this message will + not be sent, even if it happens to be the opposite of what was + originally selected (i.e. setting a True to a False or vice-versa). + + Since this message indicates a toggle occurring at a per-option level, + a message will be sent for each option that is toggled, even when a + bulk action is performed (e.g. via `toggle_all`). + + Can be handled using `on_selection_list_selection_toggled` in a subclass of + [`SelectionList`][textual.widgets.SelectionList] or in a parent node in the DOM. + """ + + @dataclass + class SelectedChanged(Generic[MessageSelectionType], Message): + """Message sent when the collection of selected values changes. + + This is sent regardless of whether the change occurred via user interaction + or programmatically via the `SelectionList` API. + + When a bulk change occurs, such as through `select_all` or `deselect_all`, + only a single `SelectedChanged` message will be sent (rather than one per + option). + + Can be handled using `on_selection_list_selected_changed` in a subclass of + [`SelectionList`][textual.widgets.SelectionList] or in a parent node in the DOM. + """ + + selection_list: SelectionList[MessageSelectionType] + """The `SelectionList` that sent the message.""" + + @property + def control(self) -> SelectionList[MessageSelectionType]: + """An alias for `selection_list`.""" + return self.selection_list + + def __init__( + self, + *selections: Selection[SelectionType] + | tuple[ContentText, SelectionType] + | tuple[ContentText, SelectionType, bool], + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + compact: bool = False, + ): + """Initialise the selection list. + + Args: + *selections: The content for the selection list. + name: The name of the selection list. + id: The ID of the selection list in the DOM. + classes: The CSS classes of the selection list. + disabled: Whether the selection list is disabled or not. + compact: Enable a compact style? + """ + + self._selected: dict[SelectionType, None] = {} + """Tracking of which values are selected.""" + self._send_messages = False + """Keep track of when we're ready to start sending messages.""" + options = [self._make_selection(selection) for selection in selections] + self._values: dict[SelectionType, int] = { + option.value: index for index, option in enumerate(options) + } + """Keeps track of which value relates to which option.""" + super().__init__(*options, name=name, id=id, classes=classes, disabled=disabled) + self.compact = compact + + @property + def selected(self) -> list[SelectionType]: + """The selected values. + + This is a list of all of the + [values][textual.widgets.selection_list.Selection.value] associated + with selections in the list that are currently in the selected + state. + """ + return list(self._selected.keys()) + + def _on_mount(self, _event: events.Mount) -> None: + """Configure the list once the DOM is ready.""" + self._send_messages = True + + def _message_changed(self) -> None: + """Post a message that the selected collection has changed, where appropriate. + + Note: + A message will only be sent if `_send_messages` is `True`. This + makes this safe to call before the widget is ready for posting + messages. + """ + if self._send_messages: + self.post_message(self.SelectedChanged(self).set_sender(self)) + + def _message_toggled(self, option_index: int) -> None: + """Post a message that an option was toggled, where appropriate. + + Note: + A message will only be sent if `_send_messages` is `True`. This + makes this safe to call before the widget is ready for posting + messages. + """ + if self._send_messages: + self.post_message( + self.SelectionToggled(self, option_index).set_sender(self) + ) + + def _apply_to_all(self, state_change: Callable[[SelectionType], bool]) -> Self: + """Apply a selection state change to all selection options in the list. + + Args: + state_change: The state change function to apply. + + Returns: + The [`SelectionList`][textual.widgets.SelectionList] instance. + + Note: + This method will post a single + [`SelectedChanged`][textual.widgets.OptionList.SelectedChanged] + message if a change is made in a call to this method. + """ + + # Keep track of if anything changed. + changed = False + + # Apply the state change function to all options. + # We don't send a SelectedChanged for each option, and instead + # send a single SelectedChanged afterwards if any values change. + with self.prevent(self.SelectedChanged): + for selection in self._options: + changed = ( + state_change(cast(Selection[SelectionType], selection).value) + or changed + ) + + # If the above did make a change, *then* send a message. + if changed: + self._message_changed() + + self.refresh() + return self + + def _select(self, value: SelectionType) -> bool: + """Mark the given value as selected. + + Args: + value: The value to mark as selected. + + Returns: + `True` if the value was selected, `False` if not. + """ + if value not in self._selected: + self._selected[value] = None + self._message_changed() + return True + return False + + def select(self, selection: Selection[SelectionType] | SelectionType) -> Self: + """Mark the given selection as selected. + + Args: + selection: The selection to mark as selected. + + Returns: + The [`SelectionList`][textual.widgets.SelectionList] instance. + """ + if self._select( + selection.value + if isinstance(selection, Selection) + else cast(SelectionType, selection) + ): + self.refresh() + return self + + def select_all(self) -> Self: + """Select all items. + + Returns: + The [`SelectionList`][textual.widgets.SelectionList] instance. + """ + return self._apply_to_all(self._select) + + def _deselect(self, value: SelectionType) -> bool: + """Mark the given selection as not selected. + + Args: + value: The value to mark as not selected. + + Returns: + `True` if the value was deselected, `False` if not. + """ + try: + del self._selected[value] + except KeyError: + return False + self._message_changed() + return True + + def deselect(self, selection: Selection[SelectionType] | SelectionType) -> Self: + """Mark the given selection as not selected. + + Args: + selection: The selection to mark as not selected. + + Returns: + The [`SelectionList`][textual.widgets.SelectionList] instance. + """ + if self._deselect( + selection.value + if isinstance(selection, Selection) + else cast(SelectionType, selection) + ): + self.refresh() + return self + + def deselect_all(self) -> Self: + """Deselect all items. + + Returns: + The [`SelectionList`][textual.widgets.SelectionList] instance. + """ + return self._apply_to_all(self._deselect) + + def _toggle(self, value: SelectionType) -> bool: + """Toggle the selection state of the given value. + + Args: + value: The value to toggle. + + Returns: + `True`. + """ + if value in self._selected: + self._deselect(value) + else: + self._select(value) + self._message_toggled(self._values[value]) + return True + + def toggle(self, selection: Selection[SelectionType] | SelectionType) -> Self: + """Toggle the selected state of the given selection. + + Args: + selection: The selection to toggle. + + Returns: + The [`SelectionList`][textual.widgets.SelectionList] instance. + """ + self._toggle( + selection.value + if isinstance(selection, Selection) + else cast(SelectionType, selection) + ) + self.refresh() + return self + + def toggle_all(self) -> Self: + """Toggle all items. + + Returns: + The [`SelectionList`][textual.widgets.SelectionList] instance. + """ + return self._apply_to_all(self._toggle) + + def _make_selection( + self, + selection: ( + Selection[SelectionType] + | tuple[ContentText, SelectionType] + | tuple[ContentText, SelectionType, bool] + ), + ) -> Selection[SelectionType]: + """Turn incoming selection data into a `Selection` instance. + + Args: + selection: The selection data. + + Returns: + An instance of a `Selection`. + + Raises: + SelectionError: If the selection was badly-formed. + """ + + # If we've been given a tuple of some sort, turn that into a proper + # Selection. + if isinstance(selection, tuple): + if len(selection) == 2: + selection = cast( + "tuple[ContentText, SelectionType, bool]", (*selection, False) + ) + elif len(selection) != 3: + raise SelectionError(f"Expected 2 or 3 values, got {len(selection)}") + selection = Selection[SelectionType](*selection) + + # At this point we should have a proper selection. + assert isinstance(selection, Selection) + + # If the initial state for this is that it's selected, add it to the + # selected collection. + if selection.initial_state: + self._select(selection.value) + + return selection + + def _toggle_highlighted_selection(self) -> None: + """Toggle the state of the highlighted selection. + + If nothing is selected in the list this is a non-operation. + """ + if self.highlighted is not None: + self.toggle(self.get_option_at_index(self.highlighted)) + + def _get_left_gutter_width(self) -> int: + """Returns the size of any left gutter that should be taken into account. + + Returns: + The width of the left gutter. + """ + return len( + ToggleButton.BUTTON_LEFT + + ToggleButton.BUTTON_INNER + + ToggleButton.BUTTON_RIGHT + + " " + ) + + def render_line(self, y: int) -> Strip: + """Render a line in the display. + + Args: + y: The line to render. + + Returns: + A [`Strip`][textual.strip.Strip] that is the line to render. + """ + + # TODO: This is rather crufty and hard to fathom. Candidate for a rewrite. + + # First off, get the underlying prompt from OptionList. + line = super().render_line(y) + + # We know the prompt we're going to display, what we're going to do + # is place a CheckBox-a-like button next to it. So to start with + # let's pull out the actual Selection we're looking at right now. + _, scroll_y = self.scroll_offset + selection_index = scroll_y + y + try: + selection = self.get_option_at_index(selection_index) + except OptionDoesNotExist: + return line + + # Figure out which component style is relevant for a checkbox on + # this particular line. + component_style = "selection-list--button" + if selection.value in self._selected: + component_style += "-selected" + if self.highlighted == selection_index: + component_style += "-highlighted" + + # # # Get the underlying style used for the prompt. + # TODO: This is not a reliable way of getting the base style + underlying_style = next(iter(line)).style or self.rich_style + assert underlying_style is not None + + # Get the style for the button. + button_style = self.get_component_rich_style(component_style) + + # Build the style for the side characters. Note that this is + # sensitive to the type of character used, so pay attention to + # BUTTON_LEFT and BUTTON_RIGHT. + side_style = Style.from_color(button_style.bgcolor, underlying_style.bgcolor) + + # Add the option index to the style. This is used to determine which + # option to select when the button is clicked or hovered. + side_style += Style(meta={"option": selection_index}) + button_style += Style(meta={"option": selection_index}) + + # At this point we should have everything we need to place a + # "button" before the option. + return Strip( + [ + Segment(ToggleButton.BUTTON_LEFT, style=side_style), + Segment(ToggleButton.BUTTON_INNER, style=button_style), + Segment(ToggleButton.BUTTON_RIGHT, style=side_style), + Segment(" ", style=underlying_style), + *line, + ] + ) + + def _on_option_list_option_highlighted( + self, event: OptionList.OptionHighlighted + ) -> None: + """Capture the `OptionList` highlight event and turn it into a [`SelectionList`][textual.widgets.SelectionList] event. + + Args: + event: The event to capture and recreate. + """ + event.stop() + self.post_message(self.SelectionHighlighted(self, event.option_index)) + + def _on_option_list_option_selected(self, event: OptionList.OptionSelected) -> None: + """Capture the `OptionList` selected event and turn it into a [`SelectionList`][textual.widgets.SelectionList] event. + + Args: + event: The event to capture and recreate. + """ + event.stop() + self._toggle_highlighted_selection() + + def get_option_at_index(self, index: int) -> Selection[SelectionType]: + """Get the selection option at the given index. + + Args: + index: The index of the selection option to get. + + Returns: + The selection option at that index. + + Raises: + OptionDoesNotExist: If there is no selection option with the index. + """ + return cast("Selection[SelectionType]", super().get_option_at_index(index)) + + def get_option(self, option_id: str) -> Selection[SelectionType]: + """Get the selection option with the given ID. + + Args: + option_id: The ID of the selection option to get. + + Returns: + The selection option with the ID. + + Raises: + OptionDoesNotExist: If no selection option has the given ID. + """ + return cast("Selection[SelectionType]", super().get_option(option_id)) + + def _pre_remove_option(self, option: Option, index: int) -> None: + """Hook called prior to removing an option.""" + assert isinstance(option, Selection) + self._deselect(option.value) + del self._values[option.value] + + # Decrement index of options after the one we just removed. + self._values = { + option_value: option_index - 1 if option_index > index else option_index + for option_value, option_index in self._values.items() + } + + def add_options( + self, + items: Iterable[ + OptionListContent + | Selection[SelectionType] + | tuple[ContentText, SelectionType] + | tuple[ContentText, SelectionType, bool] + ], + ) -> Self: + """Add new selection options to the end of the list. + + Args: + items: The new items to add. + + Returns: + The [`SelectionList`][textual.widgets.SelectionList] instance. + + Raises: + DuplicateID: If there is an attempt to use a duplicate ID. + SelectionError: If one of the selection options is of the wrong form. + """ + # This... is sort of sub-optimal, but a natural consequence of + # inheriting from and narrowing down OptionList. Here we don't want + # things like a separator, or a base Option, being passed in. So we + # extend the types of accepted items to keep mypy and friends happy, + # but then we runtime check that we've been given sensible types (in + # this case the supported tuple values). + cleaned_options: list[Selection[SelectionType]] = [] + for item in items: + if isinstance(item, tuple): + cleaned_options.append( + self._make_selection( + cast( + "tuple[ContentText, SelectionType] | tuple[ContentText, SelectionType, bool]", + item, + ) + ) + ) + elif isinstance(item, Selection): + cleaned_options.append(self._make_selection(item)) + else: + raise SelectionError( + "Only Selection or a prompt/value tuple is supported in SelectionList" + ) + + # Add the new items to the value mappings. + self._values.update( + { + option.value: index + for index, option in enumerate(cleaned_options, start=self.option_count) + } + ) + + return super().add_options(cleaned_options) + + def add_option( + self, + item: ( + OptionListContent + | Selection + | tuple[ContentText, SelectionType] + | tuple[ContentText, SelectionType, bool] + ) = None, + ) -> Self: + """Add a new selection option to the end of the list. + + Args: + item: The new item to add. + + Returns: + The [`SelectionList`][textual.widgets.SelectionList] instance. + + Raises: + DuplicateID: If there is an attempt to use a duplicate ID. + SelectionError: If the selection option is of the wrong form. + """ + return self.add_options([item]) + + def clear_options(self) -> Self: + """Clear the content of the selection list. + + Returns: + The `SelectionList` instance. + """ + self._selected.clear() + self._values.clear() + return super().clear_options() diff --git a/src/memray/_vendor/textual/widgets/_sparkline.py b/src/memray/_vendor/textual/widgets/_sparkline.py new file mode 100644 index 0000000000..c0391bbc61 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_sparkline.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from typing import Callable, ClassVar, Optional, Sequence + +from memray._vendor.textual.app import RenderResult +from memray._vendor.textual.color import Color +from memray._vendor.textual.reactive import reactive +from memray._vendor.textual.renderables.sparkline import Sparkline as SparklineRenderable +from memray._vendor.textual.widget import Widget + + +def _max_factory() -> Callable[[Sequence[float]], float]: + """Callable that returns the built-in max to initialise a reactive.""" + return max + + +class Sparkline(Widget): + """A sparkline widget to display numerical data.""" + + COMPONENT_CLASSES: ClassVar[set[str]] = { + "sparkline--max-color", + "sparkline--min-color", + } + """ + Use these component classes to define the two colors that the sparkline + interpolates to represent its numerical data. + + Note: + These two component classes are used exclusively for the _color_ of the + sparkline widget. Setting any style other than [`color`](/styles/color.md) + will have no effect. + + | Class | Description | + | :- | :- | + | `sparkline--max-color` | The color used for the larger values in the data. | + | `sparkline--min-color` | The color used for the smaller values in the data. | + """ + + DEFAULT_CSS = """ + Sparkline { + height: 1; + } + Sparkline > .sparkline--max-color { + color: $primary; + } + Sparkline > .sparkline--min-color { + color: $primary 30%; + } + """ + + data = reactive[Optional[Sequence[float]]](None) + """The data that populates the sparkline.""" + summary_function = reactive[Callable[[Sequence[float]], float]](_max_factory) + """The function that computes the value that represents each bar.""" + + def __init__( + self, + data: Sequence[float] | None = None, + *, + min_color: Color | str | None = None, + max_color: Color | str | None = None, + summary_function: Callable[[Sequence[float]], float] | None = None, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + ) -> None: + """Initialize a sparkline widget. + + Args: + data: The initial data to populate the sparkline with. + min_color: The color of the minimum value, or `None` to take from CSS. + max_color: the color of the maximum value, or `None` to take from CSS. + summary_function: Summarizes bar values into a single value used to + represent each bar. + name: The name of the widget. + id: The ID of the widget in the DOM. + classes: The CSS classes for the widget. + disabled: Whether the widget is disabled or not. + """ + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + self.min_color = None if min_color is None else Color.parse(min_color) + self.max_color = None if max_color is None else Color.parse(max_color) + self.data = data + if summary_function is not None: + self.summary_function = summary_function + + def render(self) -> RenderResult: + """Renders the sparkline when there is data available.""" + data = self.data or [] + _, base = self.background_colors + min_color = base + ( + self.get_component_styles("sparkline--min-color").color + if self.min_color is None + else self.min_color + ) + max_color = base + ( + self.get_component_styles("sparkline--max-color").color + if self.max_color is None + else self.max_color + ) + return SparklineRenderable( + data, + width=self.size.width, + height=self.size.height, + min_color=min_color.rich_color, + max_color=max_color.rich_color, + summary_function=self.summary_function, + ) diff --git a/src/memray/_vendor/textual/widgets/_static.py b/src/memray/_vendor/textual/widgets/_static.py new file mode 100644 index 0000000000..c764114b38 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_static.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from memray._vendor.textual.app import RenderResult + +from memray._vendor.textual.visual import Visual, VisualType, visualize +from memray._vendor.textual.widget import Widget + + +class Static(Widget, inherit_bindings=False): + """A widget to display simple static content, or use as a base class for more complex widgets. + + Args: + content: A Content object, Rich renderable, or string containing console markup. + expand: Expand content if required to fill container. + shrink: Shrink content if required to fill container. + markup: True if markup should be parsed and rendered. + name: Name of widget. + id: ID of Widget. + classes: Space separated list of class names. + disabled: Whether the static is disabled or not. + """ + + DEFAULT_CSS = """ + Static { + height: auto; + } + """ + + def __init__( + self, + content: VisualType = "", + *, + expand: bool = False, + shrink: bool = False, + markup: bool = True, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + ) -> None: + super().__init__( + name=name, id=id, classes=classes, disabled=disabled, markup=markup + ) + self.set_reactive(Widget.expand, expand) + self.set_reactive(Widget.shrink, shrink) + self.__content = content + self.__visual: Visual | None = None + + @property + def visual(self) -> Visual: + """The visual to be displayed. + + Note that the visual is what is ultimately rendered in the widget, but may not be the + same object set with the `update` method or `content` property. For instance, if you + update with a string, then the visual will be a [Content][textual.content.Content] instance. + + """ + if self.__visual is None: + self.__visual = visualize(self, self.__content, markup=self._render_markup) + return self.__visual + + @property + def content(self) -> VisualType: + """The original content set in the constructor.""" + return self.__content + + @content.setter + def content(self, content: VisualType) -> None: + self.__content = content + self.__visual = visualize(self, content, markup=self._render_markup) + self.clear_cached_dimensions() + self.refresh(layout=True) + + def render(self) -> RenderResult: + """Get a rich renderable for the widget's content. + + Returns: + A rich renderable. + """ + return self.visual + + def update(self, content: VisualType = "", *, layout: bool = True) -> None: + """Update the widget's content area with a string, a Visual (such as [Content][textual.content.Content]), or a [Rich renderable](https://rich.readthedocs.io/en/latest/protocol.html). + + Args: + content: New content. + layout: Also perform a layout operation (set to `False` if you are certain the size won't change). + """ + + self.__content = content + self.__visual = visualize(self, content, markup=self._render_markup) + self.refresh(layout=layout) diff --git a/src/memray/_vendor/textual/widgets/_switch.py b/src/memray/_vendor/textual/widgets/_switch.py new file mode 100644 index 0000000000..6072798076 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_switch.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar + +from rich.console import RenderableType + +if TYPE_CHECKING: + from memray._vendor.textual.app import RenderResult + from typing_extensions import Self + +from memray._vendor.textual.binding import Binding, BindingType +from memray._vendor.textual.events import Click +from memray._vendor.textual.geometry import Size +from memray._vendor.textual.message import Message +from memray._vendor.textual.reactive import reactive +from memray._vendor.textual.scrollbar import ScrollBarRender +from memray._vendor.textual.widget import Widget + + +class Switch(Widget, can_focus=True): + """A switch widget that represents a boolean value. + + Can be toggled by clicking on it or through its [bindings][textual.widgets.Switch.BINDINGS]. + + The switch widget also contains [component classes][textual.widgets.Switch.COMPONENT_CLASSES] + that enable more customization. + """ + + BINDINGS: ClassVar[list[BindingType]] = [ + Binding("enter,space", "toggle_switch", "Toggle", show=False), + ] + """ + | Key(s) | Description | + | :- | :- | + | enter,space | Toggle the switch state. | + """ + + COMPONENT_CLASSES: ClassVar[set[str]] = { + "switch--slider", + } + """ + | Class | Description | + | :- | :- | + | `switch--slider` | Targets the slider of the switch. | + """ + + ALLOW_SELECT = False + + DEFAULT_CSS = """ + Switch { + border: tall $border-blurred; + background: $surface; + height: auto; + width: auto; + pointer: pointer; + + padding: 0 2; + &.-on .switch--slider { + color: $success; + } + & .switch--slider { + color: $panel; + background: $panel-darken-2; + } + &:hover { + & > .switch--slider { + color: $panel-lighten-1 + } + &.-on > .switch--slider { + color: $success-lighten-1; + } + } + &:focus { + border: tall $border; + background-tint: $foreground 5%; + } + + &:light { + &.-on .switch--slider { + color: $success; + } + & .switch--slider { + color: $primary 15%; + background: $panel-darken-2; + } + &:hover { + & > .switch--slider { + color: $primary 25%; + } + &.-on > .switch--slider { + color: $success-lighten-1; + } + } + } + } + + """ + + value: reactive[bool] = reactive(False, init=False) + """The value of the switch; `True` for on and `False` for off.""" + + _slider_position = reactive(0.0) + """The position of the slider.""" + + class Changed(Message): + """Posted when the status of the switch changes. + + Can be handled using `on_switch_changed` in a subclass of `Switch` + or in a parent widget in the DOM. + + Attributes: + value: The value that the switch was changed to. + switch: The `Switch` widget that was changed. + """ + + def __init__(self, switch: Switch, value: bool) -> None: + super().__init__() + self.value: bool = value + self.switch: Switch = switch + + @property + def control(self) -> Switch: + """Alias for self.switch.""" + return self.switch + + def __init__( + self, + value: bool = False, + *, + animate: bool = True, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + tooltip: RenderableType | None = None, + ): + """Initialise the switch. + + Args: + value: The initial value of the switch. + animate: True if the switch should animate when toggled. + name: The name of the switch. + id: The ID of the switch in the DOM. + classes: The CSS classes of the switch. + disabled: Whether the switch is disabled or not. + tooltip: Optional tooltip. + """ + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + if value: + self._slider_position = 1.0 + self.set_reactive(Switch.value, value) + self._should_animate = animate + if tooltip is not None: + self.tooltip = tooltip + + def watch_value(self, value: bool) -> None: + target_slider_position = 1.0 if value else 0.0 + if self._should_animate: + self.animate( + "_slider_position", + target_slider_position, + duration=0.3, + level="basic", + ) + else: + self._slider_position = target_slider_position + self.post_message(self.Changed(self, self.value)) + + def watch__slider_position(self, slider_position: float) -> None: + self.set_class(slider_position == 1, "-on") + + def render(self) -> RenderResult: + style = self.get_component_rich_style("switch--slider") + return ScrollBarRender( + virtual_size=100, + window_size=50, + position=self._slider_position * 50, + style=style, + vertical=False, + ) + + def get_content_width(self, container: Size, viewport: Size) -> int: + return 4 + + def get_content_height(self, container: Size, viewport: Size, width: int) -> int: + return 1 + + async def _on_click(self, event: Click) -> None: + """Toggle the state of the switch.""" + event.stop() + self.toggle() + + def action_toggle_switch(self) -> None: + """Toggle the state of the switch.""" + self.toggle() + + def toggle(self) -> Self: + """Toggle the switch value. + + As a result of the value changing, a `Switch.Changed` message will + be posted. + + Returns: + The `Switch` instance. + """ + self.value = not self.value + return self diff --git a/src/memray/_vendor/textual/widgets/_tab.py b/src/memray/_vendor/textual/widgets/_tab.py new file mode 100644 index 0000000000..3928d48b31 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_tab.py @@ -0,0 +1,3 @@ +from memray._vendor.textual.widgets._tabs import Tab + +__all__ = ["Tab"] diff --git a/src/memray/_vendor/textual/widgets/_tab_pane.py b/src/memray/_vendor/textual/widgets/_tab_pane.py new file mode 100644 index 0000000000..a3c0c6f083 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_tab_pane.py @@ -0,0 +1,3 @@ +from memray._vendor.textual.widgets._tabbed_content import TabPane + +__all__ = ["TabPane"] diff --git a/src/memray/_vendor/textual/widgets/_tabbed_content.py b/src/memray/_vendor/textual/widgets/_tabbed_content.py new file mode 100644 index 0000000000..44722142c8 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_tabbed_content.py @@ -0,0 +1,716 @@ +from __future__ import annotations + +from asyncio import gather +from dataclasses import dataclass +from itertools import zip_longest +from typing import Awaitable + +from rich.repr import Result +from typing_extensions import Final + +from memray._vendor.textual import events +from memray._vendor.textual.app import ComposeResult +from memray._vendor.textual.await_complete import AwaitComplete +from memray._vendor.textual.content import ContentText, ContentType +from memray._vendor.textual.css.query import NoMatches +from memray._vendor.textual.message import Message +from memray._vendor.textual.reactive import reactive +from memray._vendor.textual.widget import Widget +from memray._vendor.textual.widgets._content_switcher import ContentSwitcher +from memray._vendor.textual.widgets._tabs import Tab, Tabs + +__all__ = [ + "ContentTab", + "TabbedContent", + "TabPane", +] + + +class ContentTab(Tab): + """A Tab with an associated content id.""" + + _PREFIX: Final[str] = "--content-tab-" + """The prefix given to the tab IDs.""" + + @classmethod + def add_prefix(cls, content_id: str) -> str: + """Add the prefix to the given ID. + + Args: + content_id: The ID to add the prefix to. + + Returns: + The ID with the prefix added. + """ + return f"{cls._PREFIX}{content_id}" if content_id else content_id + + @classmethod + def sans_prefix(cls, content_id: str) -> str: + """Remove the prefix from the given ID. + + Args: + content_id: The ID to remove the prefix from. + + Returns: + The ID with the prefix removed. + """ + return ( + content_id[len(cls._PREFIX) :] + if content_id.startswith(cls._PREFIX) + else content_id + ) + + def __init__( + self, label: ContentType, content_id: str, disabled: bool = False + ) -> None: + """Initialize a ContentTab. + + Args: + label: The label to be displayed within the tab. + content_id: The id of the content associated with the tab. + disabled: Is the tab disabled? + """ + super().__init__(label, id=self.add_prefix(content_id), disabled=disabled) + + +class ContentTabs(Tabs): + """A Tabs which is associated with a TabbedContent.""" + + def __init__( + self, + *tabs: Tab | ContentText, + active: str | None = None, + tabbed_content: TabbedContent, + ): + """Initialize a ContentTabs. + + Args: + *tabs: The child tabs. + active: ID of the tab which should be active on start. + tabbed_content: The associated TabbedContent instance. + """ + super().__init__( + *tabs, active=active if active is None else ContentTab.add_prefix(active) + ) + self.tabbed_content = tabbed_content + + def get_content_tab(self, tab_id: str) -> ContentTab: + """Get the `ContentTab` associated with the given `TabPane` ID. + + Args: + tab_id: The ID of the tab to get. + + Returns: + The tab associated with that ID. + """ + return self.query_one(f"#{ContentTab.add_prefix(tab_id)}", ContentTab) + + def disable(self, tab_id: str) -> Tab: + """Disable the indicated tab. + + Args: + tab_id: The ID of the [`Tab`][textual.widgets.Tab] to disable. + + Returns: + The [`Tab`][textual.widgets.Tab] that was targeted. + + Raises: + TabError: If there are any issues with the request. + """ + return super().disable(ContentTab.add_prefix(tab_id)) + + def enable(self, tab_id: str) -> Tab: + """Enable the indicated tab. + + Args: + tab_id: The ID of the [`Tab`][textual.widgets.Tab] to enable. + + Returns: + The [`Tab`][textual.widgets.Tab] that was targeted. + + Raises: + TabError: If there are any issues with the request. + """ + return super().enable(ContentTab.add_prefix(tab_id)) + + def hide(self, tab_id: str) -> Tab: + """Hide the indicated tab. + + Args: + tab_id: The ID of the [`Tab`][textual.widgets.Tab] to hide. + + Returns: + The [`Tab`][textual.widgets.Tab] that was targeted. + + Raises: + TabError: If there are any issues with the request. + """ + return super().hide(ContentTab.add_prefix(tab_id)) + + def show(self, tab_id: str) -> Tab: + """Show the indicated tab. + + Args: + tab_id: The ID of the [`Tab`][textual.widgets.Tab] to show. + + Returns: + The [`Tab`][textual.widgets.Tab] that was targeted. + + Raises: + TabError: If there are any issues with the request. + """ + return super().show(ContentTab.add_prefix(tab_id)) + + +class TabPane(Widget): + """A container for switchable content, with additional title. + + This widget is intended to be used with [TabbedContent][textual.widgets.TabbedContent]. + """ + + DEFAULT_CSS = """ + TabPane { + height: auto; + } + """ + + @dataclass + class TabPaneMessage(Message): + """Base class for `TabPane` messages.""" + + tab_pane: TabPane + """The `TabPane` that is he object of this message.""" + + @property + def control(self) -> TabPane: + """The tab pane that is the object of this message. + + This is an alias for the attribute `tab_pane` and is used by the + [`on`][textual.on] decorator. + """ + return self.tab_pane + + @dataclass + class Disabled(TabPaneMessage): + """Sent when a tab pane is disabled via its reactive `disabled`.""" + + @dataclass + class Enabled(TabPaneMessage): + """Sent when a tab pane is enabled via its reactive `disabled`.""" + + @dataclass + class Focused(TabPaneMessage): + """Sent when a child widget is focused.""" + + def __init__( + self, + title: ContentType, + *children: Widget, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + ): + """Initialize a TabPane. + + Args: + title: Title of the TabPane (will be displayed in a tab label). + *children: Widget to go inside the TabPane. + name: Optional name for the TabPane. + id: Optional ID for the TabPane. + classes: Optional initial classes for the widget. + disabled: Whether the TabPane is disabled or not. + """ + self._title = self.render_str(title) + super().__init__( + *children, name=name, id=id, classes=classes, disabled=disabled + ) + + def _watch_disabled(self, disabled: bool) -> None: + """Notify the parent `TabbedContent` that a tab pane was enabled/disabled.""" + self.post_message(self.Disabled(self) if disabled else self.Enabled(self)) + + def _on_descendant_focus(self, event: events.DescendantFocus): + """Tell TabbedContent parent something is focused in this pane.""" + self.post_message(self.Focused(self)) + + +class TabbedContent(Widget): + """A container with associated tabs to toggle content visibility.""" + + ALLOW_MAXIMIZE = True + DEFAULT_CSS = """ + TabbedContent { + height: auto; + &> ContentTabs { + dock: top; + } + } + """ + + active: reactive[str] = reactive("", init=False) + """The ID of the active tab, or empty string if none are active.""" + + class TabActivated(Message): + """Posted when the active tab changes.""" + + ALLOW_SELECTOR_MATCH = {"pane"} + """Additional message attributes that can be used with the [`on` decorator][textual.on].""" + + def __init__(self, tabbed_content: TabbedContent, tab: ContentTab) -> None: + """Initialize message. + + Args: + tabbed_content: The TabbedContent widget. + tab: The Tab widget that was selected (contains the tab label). + """ + self.tabbed_content = tabbed_content + """The `TabbedContent` widget that contains the tab activated.""" + self.tab = tab + """The `Tab` widget that was selected (contains the tab label).""" + self.pane = tabbed_content.get_pane(tab) + """The `TabPane` widget that was activated by selecting the tab.""" + super().__init__() + + @property + def control(self) -> TabbedContent: + """The `TabbedContent` widget that contains the tab activated. + + This is an alias for [`TabActivated.tabbed_content`][textual.widgets.TabbedContent.TabActivated.tabbed_content] + and is used by the [`on`][textual.on] decorator. + """ + return self.tabbed_content + + def __rich_repr__(self) -> Result: + yield self.tabbed_content + yield self.tab + yield self.pane + + class Cleared(Message): + """Posted when no tab pane is active. + + This can happen if all tab panes are removed or if the currently active tab + pane is unset. + """ + + def __init__(self, tabbed_content: TabbedContent) -> None: + """Initialize message. + + Args: + tabbed_content: The TabbedContent widget. + """ + self.tabbed_content = tabbed_content + """The `TabbedContent` widget that contains the tab activated.""" + super().__init__() + + @property + def control(self) -> TabbedContent: + """The `TabbedContent` widget that was cleared of all tab panes. + + This is an alias for [`Cleared.tabbed_content`][textual.widgets.TabbedContent.Cleared.tabbed_content] + and is used by the [`on`][textual.on] decorator. + """ + return self.tabbed_content + + def __init__( + self, + *titles: ContentType, + initial: str = "", + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + ): + """Initialize a TabbedContent widgets. + + Args: + *titles: Positional argument will be used as title. + initial: The id of the initial tab, or empty string to select the first tab. + name: The name of the tabbed content. + id: The ID of the tabbed content in the DOM. + classes: The CSS classes of the tabbed content. + disabled: Whether the tabbed content is disabled or not. + """ + self.titles = [self.render_str(title) for title in titles] + self._tab_content: list[Widget] = [] + self._initial = initial + self._tab_counter = 0 + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + + @property + def active_pane(self) -> TabPane | None: + """The currently active pane, or `None` if no pane is active.""" + active = self.active + if not active: + return None + return self.get_pane(self.active) + + @staticmethod + def _set_id(content: TabPane, new_id: int) -> TabPane: + """Set an id on the content, if not already present. + + Args: + content: a TabPane. + new_id: Numeric ID to make the pane ID from. + + Returns: + The same TabPane. + """ + if content.id is None: + content.id = f"tab-{new_id}" + return content + + def _generate_tab_id(self) -> int: + """Auto generate a new tab id. + + Returns: + An auto-incrementing integer. + """ + self._tab_counter += 1 + return self._tab_counter + + def compose(self) -> ComposeResult: + """Compose the tabbed content.""" + + # Wrap content in a `TabPane` if required. + pane_content = [ + self._set_id( + ( + content + if isinstance(content, TabPane) + else TabPane(title or self.render_str(f"Tab {index}"), content) + ), + self._generate_tab_id(), + ) + for index, (title, content) in enumerate( + zip_longest(self.titles, self._tab_content), 1 + ) + ] + # Get a tab for each pane + tabs = [ + ContentTab( + content._title, + content.id or "", + disabled=content.disabled, + ) + for content in pane_content + ] + + # Yield the tabs, and ensure they're linked to this TabbedContent. + # It's important to associate the Tabs with the TabbedContent, so that this + # TabbedContent can determine whether a message received from a Tabs instance + # has been sent from this Tabs, or from a Tabs that may exist as a descendant + # deeper in the DOM. + yield ContentTabs(*tabs, active=self._initial or None, tabbed_content=self) + + # Yield the content switcher and panes + with ContentSwitcher(initial=self._initial or None): + yield from pane_content + + def add_pane( + self, + pane: TabPane, + *, + before: TabPane | str | None = None, + after: TabPane | str | None = None, + ) -> AwaitComplete: + """Add a new pane to the tabbed content. + + Args: + pane: The pane to add. + before: Optional pane or pane ID to add the pane before. + after: Optional pane or pane ID to add the pane after. + + Returns: + An optionally awaitable object that waits for the pane to be added. + + Raises: + Tabs.TabError: If there is a problem with the addition request. + + Note: + Only one of `before` or `after` can be provided. If both are + provided an exception is raised. + """ + if isinstance(before, TabPane): + before = before.id + if isinstance(after, TabPane): + after = after.id + tabs = self.get_child_by_type(ContentTabs) + pane = self._set_id(pane, self._generate_tab_id()) + assert pane.id is not None + pane.display = False + return AwaitComplete( + tabs.add_tab( + ContentTab(pane._title, pane.id), + before=before if before is None else ContentTab.add_prefix(before), + after=after if after is None else ContentTab.add_prefix(after), + ), + self.get_child_by_type(ContentSwitcher).mount(pane), + ) + + def remove_pane(self, pane_id: str) -> AwaitComplete: + """Remove a given pane from the tabbed content. + + Args: + pane_id: The ID of the pane to remove. + + Returns: + An optionally awaitable object that waits for the pane to be removed + and the Cleared message to be posted. + """ + removal_awaitables: list[Awaitable] = [ + self.get_child_by_type(ContentTabs).remove_tab( + ContentTab.add_prefix(pane_id) + ) + ] + try: + removal_awaitables.append( + self.get_child_by_type(ContentSwitcher) + .get_child_by_id(pane_id) + .remove() + ) + except NoMatches: + # It's possible that the content itself may have gone away via + # other means; so allow that to be a no-op. + pass + + return AwaitComplete(*removal_awaitables) + + def clear_panes(self) -> AwaitComplete: + """Remove all the panes in the tabbed content. + + Returns: + An optionally awaitable object which waits for all panes to be removed + and the Cleared message to be posted. + """ + await_clear = gather( + self.get_child_by_type(ContentTabs).clear(), + self.get_child_by_type(ContentSwitcher).remove_children(), + ) + + async def _clear_content() -> None: + await await_clear + + return AwaitComplete(_clear_content()) + + def compose_add_child(self, widget: Widget) -> None: + """When using the context manager compose syntax, we want to attach nodes to the switcher. + + Args: + widget: A Widget to add. + """ + self._tab_content.append(widget) + + def _on_tabs_tab_activated(self, event: Tabs.TabActivated) -> None: + """User clicked a tab.""" + if self._is_associated_tabs(event.tabs): + # The message is relevant, so consume it and update state accordingly. + event.stop() + assert event.tab.id is not None + switcher = self.get_child_by_type(ContentSwitcher) + switcher.current = ContentTab.sans_prefix(event.tab.id) + with self.prevent(self.TabActivated): + # We prevent TabbedContent.TabActivated because it is also + # posted from the watcher for active, we're also about to + # post it below too, which is valid as here we're reacting + # to what the Tabs are doing. This ensures we don't get + # doubled-up messages. + self.active = ContentTab.sans_prefix(event.tab.id) + self.post_message( + TabbedContent.TabActivated( + tabbed_content=self, + tab=self.get_child_by_type(ContentTabs).get_content_tab( + self.active + ), + ) + ) + + def _on_tab_pane_focused(self, event: TabPane.Focused) -> None: + """One of the panes contains a widget that was programmatically focused.""" + event.stop() + if event.tab_pane.id is not None: + self.active = event.tab_pane.id + + def _on_tabs_cleared(self, event: Tabs.Cleared) -> None: + """Called when there are no active tabs. The tabs may have been cleared, + or they may all be hidden.""" + if self._is_associated_tabs(event.tabs): + event.stop() + self.get_child_by_type(ContentSwitcher).current = None + self.active = "" + + def _is_associated_tabs(self, tabs: Tabs) -> bool: + """Determine whether a tab is associated with this TabbedContent or not. + + A tab is "associated" with a `TabbedContent`, if it's one of the tabs that can + be used to control it. These have a special type: `ContentTab`, and are linked + back to this `TabbedContent` instance via a `tabbed_content` attribute. + + Args: + tabs: The Tabs instance to check. + + Returns: + True if the tab is associated with this `TabbedContent`. + """ + return isinstance(tabs, ContentTabs) and tabs.tabbed_content is self + + def _watch_active(self, active: str) -> None: + """Switch tabs when the active attributes changes.""" + with self.prevent(Tabs.TabActivated, Tabs.Cleared): + self.get_child_by_type(ContentTabs).active = ContentTab.add_prefix(active) + self.get_child_by_type(ContentSwitcher).current = active + if active: + self.post_message( + TabbedContent.TabActivated( + tabbed_content=self, + tab=self.get_child_by_type(ContentTabs).get_content_tab(active), + ) + ) + else: + self.post_message( + TabbedContent.Cleared(tabbed_content=self).set_sender(self) + ) + + @property + def tab_count(self) -> int: + """Total number of tabs.""" + return self.get_child_by_type(ContentTabs).tab_count + + def get_tab(self, pane_id: str | TabPane) -> Tab: + """Get the `Tab` associated with the given ID or `TabPane`. + + Args: + pane_id: The ID of the pane, or the pane itself. + + Returns: + The Tab associated with the ID. + + Raises: + ValueError: Raised if no ID was available. + """ + if target_id := (pane_id if isinstance(pane_id, str) else pane_id.id): + return self.get_child_by_type(ContentTabs).get_content_tab(target_id) + raise ValueError( + "'pane_id' must be a non-empty string or a TabPane with an id." + ) + + def get_pane(self, pane_id: str | ContentTab) -> TabPane: + """Get the `TabPane` associated with the given ID or tab. + + Args: + pane_id: The ID of the pane to get, or the Tab it is associated with. + + Returns: + The `TabPane` associated with the ID or the given tab. + + Raises: + ValueError: Raised if no ID was available. + """ + target_id: str | None = None + if isinstance(pane_id, ContentTab): + target_id = ( + pane_id.id if pane_id.id is None else ContentTab.sans_prefix(pane_id.id) + ) + else: + target_id = pane_id + if target_id: + pane = self.get_child_by_type(ContentSwitcher).get_child_by_id(target_id) + assert isinstance(pane, TabPane) + return pane + raise ValueError( + "'pane_id' must be a non-empty string or a ContentTab with an id." + ) + + def _on_tabs_tab_disabled(self, event: Tabs.TabDisabled) -> None: + """Disable the corresponding tab pane.""" + if event.tabs.parent is not self: + return + event.stop() + tab_id = event.tab.id or "" + try: + with self.prevent(TabPane.Disabled): + self.get_child_by_type(ContentSwitcher).get_child_by_id( + ContentTab.sans_prefix(tab_id), expect_type=TabPane + ).disabled = True + except NoMatches: + return + + def _on_tab_pane_disabled(self, event: TabPane.Disabled) -> None: + """Disable the corresponding tab.""" + event.stop() + try: + with self.prevent(Tab.Disabled): + self.get_tab(event.tab_pane).disabled = True + except NoMatches: + return + + def _on_tabs_tab_enabled(self, event: Tabs.TabEnabled) -> None: + """Enable the corresponding tab pane.""" + if event.tabs.parent is not self: + return + event.stop() + tab_id = event.tab.id or "" + try: + with self.prevent(TabPane.Enabled): + self.get_child_by_type(ContentSwitcher).get_child_by_id( + ContentTab.sans_prefix(tab_id), expect_type=TabPane + ).disabled = False + except NoMatches: + return + + def _on_tab_pane_enabled(self, event: TabPane.Enabled) -> None: + """Enable the corresponding tab.""" + event.stop() + try: + with self.prevent(Tab.Disabled): + self.get_tab(event.tab_pane).disabled = False + except NoMatches: + return + + def disable_tab(self, tab_id: str) -> None: + """Disables the tab with the given ID. + + Args: + tab_id: The ID of the [`TabPane`][textual.widgets.TabPane] to disable. + + Raises: + Tabs.TabError: If there are any issues with the request. + """ + + self.get_child_by_type(ContentTabs).disable(tab_id) + + def enable_tab(self, tab_id: str) -> None: + """Enables the tab with the given ID. + + Args: + tab_id: The ID of the [`TabPane`][textual.widgets.TabPane] to enable. + + Raises: + Tabs.TabError: If there are any issues with the request. + """ + + self.get_child_by_type(ContentTabs).enable(tab_id) + + def hide_tab(self, tab_id: str) -> None: + """Hides the tab with the given ID. + + Args: + tab_id: The ID of the [`TabPane`][textual.widgets.TabPane] to hide. + + Raises: + Tabs.TabError: If there are any issues with the request. + """ + + self.get_child_by_type(ContentTabs).hide(tab_id) + + def show_tab(self, tab_id: str) -> None: + """Shows the tab with the given ID. + + Args: + tab_id: The ID of the [`TabPane`][textual.widgets.TabPane] to show. + + Raises: + Tabs.TabError: If there are any issues with the request. + """ + + self.get_child_by_type(ContentTabs).show(tab_id) diff --git a/src/memray/_vendor/textual/widgets/_tabs.py b/src/memray/_vendor/textual/widgets/_tabs.py new file mode 100644 index 0000000000..04d22016e9 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_tabs.py @@ -0,0 +1,884 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import ClassVar + +import rich.repr +from rich.style import Style +from rich.text import Text + +from memray._vendor.textual import events +from memray._vendor.textual.app import ComposeResult, RenderResult +from memray._vendor.textual.await_complete import AwaitComplete +from memray._vendor.textual.binding import Binding, BindingType +from memray._vendor.textual.containers import Container, Horizontal, Vertical +from memray._vendor.textual.content import Content, ContentText +from memray._vendor.textual.css.query import NoMatches +from memray._vendor.textual.events import Mount +from memray._vendor.textual.geometry import Offset +from memray._vendor.textual.message import Message +from memray._vendor.textual.reactive import reactive +from memray._vendor.textual.renderables.bar import Bar +from memray._vendor.textual.visual import VisualType +from memray._vendor.textual.widget import Widget +from memray._vendor.textual.widgets import Static + + +class Underline(Widget): + """The animated underline beneath tabs.""" + + DEFAULT_CSS = """ + Underline { + width: 1fr; + height: 1; + & > .underline--bar { + color: $block-cursor-background; + background: $foreground 10%; + } + &:ansi { + text-style: dim; + } + } + """ + + COMPONENT_CLASSES = {"underline--bar"} + """ + | Class | Description | + | :- | :- | + | `underline--bar` | Style of the bar (may be used to change the color). | + """ + + highlight_start = reactive(0) + """First cell in highlight.""" + highlight_end = reactive(0) + """Last cell (inclusive) in highlight.""" + show_highlight: reactive[bool] = reactive(True) + """Flag to indicate if a highlight should be shown at all.""" + + class Clicked(Message): + """Inform ancestors the underline was clicked.""" + + offset: Offset + """The offset of the click, relative to the origin of the bar.""" + + def __init__(self, offset: Offset) -> None: + self.offset = offset + super().__init__() + + @property + def _highlight_range(self) -> tuple[int, int]: + """Highlighted range for underline bar.""" + return ( + (self.highlight_start, self.highlight_end) + if self.show_highlight + else (0, 0) + ) + + def render(self) -> RenderResult: + """Render the bar.""" + bar_style = self.get_component_rich_style("underline--bar") + return Bar( + highlight_range=self._highlight_range, + highlight_style=Style.from_color(bar_style.color), + background_style=Style.from_color(bar_style.bgcolor), + ) + + def _on_click(self, event: events.Click): + """Catch clicks, so that the underline can activate the tabs.""" + event.stop() + self.post_message(self.Clicked(event.screen_offset)) + + +class Tab(Static): + """A Widget to manage a single tab within a Tabs widget.""" + + DEFAULT_CSS = """ + Tab { + width: auto; + height: 1; + padding: 0 1; + text-align: center; + color: $foreground 50%; + pointer: pointer; + + &:hover { + color: $foreground; + } + &:disabled { + color: $foreground 25%; + } + + &.-active { + color: $foreground; + } + &.-hidden { + display: none; + } + } + """ + + ALLOW_SELECT = False + + @dataclass + class TabMessage(Message): + """Tab-related messages. + + These are mostly intended for internal use when interacting with `Tabs`. + """ + + tab: Tab + """The tab that is the object of this message.""" + + @property + def control(self) -> Tab: + """The tab that is the object of this message. + + This is an alias for the attribute `tab` and is used by the + [`on`][textual.on] decorator. + """ + return self.tab + + class Clicked(TabMessage): + """A tab was clicked.""" + + class Disabled(TabMessage): + """A tab was disabled.""" + + class Enabled(TabMessage): + """A tab was enabled.""" + + class Relabelled(TabMessage): + """A tab was relabelled.""" + + def __init__( + self, + label: ContentText, + *, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + ) -> None: + """Initialise a Tab. + + Args: + label: The label to use in the tab. + id: Optional ID for the widget. + classes: Space separated list of class names. + disabled: Whether the tab is disabled or not. + """ + super().__init__(id=id, classes=classes, disabled=disabled) + self._label: Content + # Setter takes Text or str + self.label = Content.from_text(label) + + @property + def label(self) -> Content: + """The label for the tab.""" + return self._label + + @label.setter + def label(self, label: ContentText) -> None: + self._label = Content.from_text(label) + self.update(self._label) + + def update(self, content: VisualType = "") -> None: + self.post_message(self.Relabelled(self)) + return super().update(content) + + @property + def label_text(self) -> str: + """Undecorated text of the label.""" + return self.label.plain + + def _on_click(self): + """Inform the message that the tab was clicked.""" + self.post_message(self.Clicked(self)) + + def _watch_disabled(self, disabled: bool) -> None: + """Notify the parent `Tabs` that a tab was enabled/disabled.""" + self.post_message(self.Disabled(self) if disabled else self.Enabled(self)) + + +class Tabs(Widget, can_focus=True): + """A row of tabs.""" + + DEFAULT_CSS = """ + Tabs { + width: 100%; + height: 2; + &:focus { + .underline--bar { + background: $foreground 30%; + } + & .-active { + text-style: $block-cursor-text-style; + color: $block-cursor-foreground; + background: $block-cursor-background; + } + } + + & > #tabs-scroll { + overflow: hidden; + } + + #tabs-list { + width: auto; + } + #tabs-list-bar, #tabs-list { + width: auto; + height: auto; + min-width: 100%; + overflow: hidden hidden; + } + &:ansi { + #tabs-list { + text-style: dim; + } + & #tabs-list > .-active { + text-style: not dim; + } + &:focus { + #tabs-list > .-active { + text-style: bold not dim; + } + } + & .underline--bar { + color: ansi_bright_blue; + background: ansi_default; + } + & .-active { + color: transparent; + background: transparent; + } + } + } + """ + + BINDINGS: ClassVar[list[BindingType]] = [ + Binding("left", "previous_tab", "Previous tab", show=False), + Binding("right", "next_tab", "Next tab", show=False), + ] + """ + | Key(s) | Description | + | :- | :- | + | left | Move to the previous tab. | + | right | Move to the next tab. | + """ + + class TabError(Exception): + """Exception raised when there is an error relating to tabs.""" + + class TabMessage(Message): + """Parent class for all messages that have to do with a specific tab.""" + + ALLOW_SELECTOR_MATCH = {"tab"} + """Additional message attributes that can be used with the [`on` decorator][textual.on].""" + + def __init__(self, tabs: Tabs, tab: Tab) -> None: + """Initialize event. + + Args: + tabs: The Tabs widget. + tab: The tab that is the object of this message. + """ + self.tabs: Tabs = tabs + """The tabs widget containing the tab.""" + self.tab: Tab = tab + """The tab that is the object of this message.""" + super().__init__() + + @property + def control(self) -> Tabs: + """The tabs widget containing the tab that is the object of this message. + + This is an alias for the attribute `tabs` and is used by the + [`on`][textual.on] decorator. + """ + return self.tabs + + def __rich_repr__(self) -> rich.repr.Result: + yield self.tabs + yield self.tab + + class TabActivated(TabMessage): + """Sent when a new tab is activated.""" + + class TabDisabled(TabMessage): + """Sent when a tab is disabled.""" + + class TabEnabled(TabMessage): + """Sent when a tab is enabled.""" + + class TabHidden(TabMessage): + """Sent when a tab is hidden.""" + + class TabShown(TabMessage): + """Sent when a tab is shown.""" + + class Cleared(Message): + """Sent when there are no active tabs. + + This can occur when Tabs are cleared, if all tabs are hidden, or if the + currently active tab is unset. + """ + + def __init__(self, tabs: Tabs) -> None: + """Initialize the event. + + Args: + tabs: The tabs widget. + """ + self.tabs: Tabs = tabs + """The tabs widget which was cleared.""" + super().__init__() + + @property + def control(self) -> Tabs: + """The tabs widget which was cleared. + + This is an alias for [`Cleared.tabs`][textual.widgets.Tabs.Cleared] which + is used by the [`on`][textual.on] decorator. + """ + return self.tabs + + def __rich_repr__(self) -> rich.repr.Result: + yield self.tabs + + active: reactive[str] = reactive("", init=False) + """The ID of the active tab, or empty string if none are active.""" + + def __init__( + self, + *tabs: Tab | ContentText, + active: str | None = None, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + ): + """Construct a Tabs widget. + + Args: + *tabs: Positional argument should be explicit Tab objects, or a str or Text. + active: ID of the tab which should be active on start. + name: Optional name for the tabs widget. + id: Optional ID for the widget. + classes: Optional initial classes for the widget. + disabled: Whether the widget is disabled or not. + """ + self._tabs_counter = 0 + + add_tabs = [ + ( + Tab(tab, id=f"tab-{self._new_tab_id}") + if isinstance(tab, (str, Content, Text)) + else self._auto_tab_id(tab) + ) + for tab in tabs + ] + super().__init__( + name=name, + id=id, + classes=classes, + disabled=disabled, + ) + self._tabs = add_tabs + self._first_active = active + + def _auto_tab_id(self, tab: Tab) -> Tab: + """Set an automatic ID if not supplied.""" + if tab.id is None: + tab.id = f"tab-{self._new_tab_id}" + return tab + + @property + def _new_tab_id(self) -> int: + """Get the next tab id in a sequence.""" + self._tabs_counter += 1 + return self._tabs_counter + + @property + def tab_count(self) -> int: + """Total number of tabs.""" + return len(self.query("#tabs-list > Tab")) + + @property + def _potentially_active_tabs(self) -> list[Tab]: + """List of all tabs that could be active. + + This list is comprised of all tabs that are shown and enabled, + plus the active tab in case it is disabled. + """ + return [ + tab + for tab in self.query("#tabs-list > Tab").results(Tab) + if ((not tab.disabled or tab is self.active_tab) and tab.display) + ] + + @property + def _next_active(self) -> Tab | None: + """Next tab to make active if the active tab is removed.""" + tabs = self._potentially_active_tabs + if self.active_tab is None: + return None + try: + active_index = tabs.index(self.active_tab) + except ValueError: + return None + del tabs[active_index] + try: + return tabs[active_index] + except IndexError: + try: + return tabs[active_index - 1] + except IndexError: + pass + return None + + def add_tab( + self, + tab: Tab | ContentText, + *, + before: Tab | str | None = None, + after: Tab | str | None = None, + ) -> AwaitComplete: + """Add a new tab to the end of the tab list. + + Args: + tab: A new tab object, or a label (str or Text). + before: Optional tab or tab ID to add the tab before. + after: Optional tab or tab ID to add the tab after. + + Returns: + An optionally awaitable object that waits for the tab to be mounted and + internal state to be fully updated to reflect the new tab. + + Raises: + Tabs.TabError: If there is a problem with the addition request. + + Note: + Only one of `before` or `after` can be provided. If both are + provided a `Tabs.TabError` will be raised. + """ + + if before and after: + raise self.TabError("Unable to add a tab both before and after a tab") + + if isinstance(before, str): + try: + before = self.query_one(f"#tabs-list > #{before}", Tab) + except NoMatches: + raise self.TabError( + f"There is no tab with ID '{before}' to mount before" + ) + elif isinstance(before, Tab) and self not in before.ancestors: + raise self.TabError( + "Request to add a tab before a tab that isn't part of this tab collection" + ) + + if isinstance(after, str): + try: + after = self.query_one(f"#tabs-list > #{after}", Tab) + except NoMatches: + raise self.TabError(f"There is no tab with ID '{after}' to mount after") + elif isinstance(after, Tab) and self not in after.ancestors: + raise self.TabError( + "Request to add a tab after a tab that isn't part of this tab collection" + ) + + from_empty = self.tab_count == 0 + tab_widget = ( + Tab(tab, id=f"tab-{self._new_tab_id}") + if isinstance(tab, (str, Content, Text)) + else self._auto_tab_id(tab) + ) + + mount_await = self.query_one("#tabs-list").mount( + tab_widget, before=before, after=after + ) + + if from_empty: + tab_widget.add_class("-active") + activated_message = self.TabActivated(self, tab_widget) + + async def refresh_active() -> None: + """Wait for things to be mounted before highlighting.""" + await mount_await + self.active = tab_widget.id or "" + self._highlight_active(animate=False) + self.post_message(activated_message) + + return AwaitComplete(refresh_active()) + elif before or after: + + async def refresh_active() -> None: + await mount_await + self._highlight_active(animate=False) + + return AwaitComplete(refresh_active()) + + return AwaitComplete(mount_await()) + + def clear(self) -> AwaitComplete: + """Clear all the tabs. + + Returns: + An awaitable object that waits for the tabs to be removed. + """ + underline = self.query_one(Underline) + underline.highlight_start = 0 + underline.highlight_end = 0 + self.post_message(self.Cleared(self)) + self.active = "" + return AwaitComplete(self.query("#tabs-list > Tab").remove()) + + def get_tab(self, tab_id: str) -> Tab | None: + """Get a tab from its ID. + + Args: + tab_id: The tab ID. + + Returns: + The Tab instance, or `None` if no tab with the given ID. + """ + try: + tab = self.query_one(f"#tabs-list > #{tab_id}", Tab) + except NoMatches: + return None + return tab + + def remove_tab(self, tab_or_id: Tab | str | None) -> AwaitComplete: + """Remove a tab. + + Args: + tab_or_id: The Tab to remove or its id. + + Returns: + An optionally awaitable object that waits for the tab to be removed. + """ + if not tab_or_id: + return AwaitComplete() + + if isinstance(tab_or_id, Tab): + remove_tab = tab_or_id + else: + try: + remove_tab = self.query_one(f"#tabs-list > #{tab_or_id}", Tab) + except NoMatches: + return AwaitComplete() + + if remove_tab.has_class("-active"): + next_tab = self._next_active + else: + next_tab = None + + async def do_remove() -> None: + """Perform the remove after refresh so the underline bar gets new positions.""" + await remove_tab.remove() + if not self.query("#tabs-list > Tab"): + self.active = "" + elif next_tab is not None: + self.active = next_tab.id or "" + else: + self._highlight_active(animate=False) + + return AwaitComplete(do_remove()) + + def validate_active(self, active: str) -> str: + """Check id assigned to active attribute is a valid tab.""" + if active and not self.query(f"#tabs-list > #{active}"): + raise ValueError(f"No Tab with id {active!r}") + return active + + @property + def active_tab(self) -> Tab | None: + """The currently active tab, or None if there are no active tabs.""" + try: + return self.query_one("#tabs-list Tab.-active", Tab) + except NoMatches: + return None + + def _on_mount(self, _: Mount) -> None: + """Make the first tab active.""" + if self._first_active is not None: + self.active = self._first_active + if not self.active: + try: + tab = self.query("#tabs-list > Tab").first(Tab) + except NoMatches: + # Tabs are empty! + return + self.active = tab.id or "" + + def compose(self) -> ComposeResult: + with Container(id="tabs-scroll"): + with Vertical(id="tabs-list-bar"): + with Horizontal(id="tabs-list"): + yield from self._tabs + yield Underline() + + def watch_active(self, previously_active: str, active: str) -> None: + """Handle a change to the active tab.""" + self.query("#tabs-list > Tab.-active").remove_class("-active") + if active: + try: + active_tab = self.query_one(f"#tabs-list > #{active}", Tab) + except NoMatches: + return + active_tab.add_class("-active") + + self._highlight_active(animate=previously_active != "") + + self._scroll_active_tab() + self.post_message(self.TabActivated(self, active_tab)) + else: + underline = self.query_one(Underline) + underline.highlight_start = 0 + underline.highlight_end = 0 + self.post_message(self.Cleared(self)) + + def _highlight_active( + self, + animate: bool = True, + ) -> None: + """Move the underline bar to under the active tab. + + Args: + animate: Should the bar animate? + """ + underline = self.query_one(Underline) + try: + _active_tab = self.query_one("#tabs-list > Tab.-active") + except NoMatches: + underline.show_highlight = False + underline.highlight_start = 0 + underline.highlight_end = 0 + else: + underline.show_highlight = True + + def move_underline(animate: bool) -> None: + """Move the tab underline. + + Args: + animate: animate the underline to its new position. + """ + try: + active_tab = self.query_one("#tabs-list > Tab.-active") + except NoMatches: + pass + else: + tab_region = active_tab.virtual_region.shrink( + active_tab.styles.gutter + ) + start, end = tab_region.column_span + if animate: + underline.animate( + "highlight_start", + start, + duration=0.3, + level="basic", + ) + underline.animate( + "highlight_end", + end, + duration=0.3, + level="basic", + ) + else: + underline.highlight_start = start + underline.highlight_end = end + + if animate and self.app.animation_level != "none": + self.set_timer( + 0.02, + lambda: self.call_after_refresh(move_underline, True), + ) + else: + self.call_after_refresh(move_underline, False) + + async def _on_tab_clicked(self, event: Tab.Clicked) -> None: + """Activate a tab that was clicked.""" + self.focus() + event.stop() + self._activate_tab(event.tab) + + def _activate_tab(self, tab: Tab) -> None: + """Activate a tab. + + Args: + tab: The Tab that was clicked. + """ + self.query("#tabs-list Tab.-active").remove_class("-active") + tab.add_class("-active") + self.active = tab.id or "" + + def _on_underline_clicked(self, event: Underline.Clicked) -> None: + """The underline was clicked. + + Activate the tab above to make a larger clickable area. + + Args: + event: The Underline.Clicked event. + """ + event.stop() + offset = event.offset + (0, -1) + self.focus() + for tab in self.query(Tab): + if offset in tab.region and not tab.disabled: + self._activate_tab(tab) + break + + def _scroll_active_tab(self) -> None: + """Scroll the active tab into view.""" + if self.active_tab: + try: + self.query_one("#tabs-scroll").scroll_to_center( + self.active_tab, force=True + ) + except NoMatches: + pass + + def _on_resize(self): + """Make the active tab visible on resize.""" + self._highlight_active(animate=False) + self._scroll_active_tab() + + def action_next_tab(self) -> None: + """Make the next tab active.""" + self._move_tab(+1) + + def action_previous_tab(self) -> None: + """Make the previous tab active.""" + self._move_tab(-1) + + def _move_tab(self, direction: int) -> None: + """Activate the next enabled tab in the given direction. + + Tab selection wraps around. If no tab is currently active, the "next" + tab is set to be the first and the "previous" tab is the last one. + + Args: + direction: +1 for the next tab, -1 for the previous. + """ + active_tab = self.active_tab + tabs = self._potentially_active_tabs + if not tabs: + return + if not active_tab: + self.active = tabs[0 if direction == 1 else -1].id or "" + return + tab_count = len(tabs) + new_tab_index = (tabs.index(active_tab) + direction) % tab_count + self.active = tabs[new_tab_index].id or "" + + def _on_tab_disabled(self, event: Tab.Disabled) -> None: + """Re-post the disabled message.""" + event.stop() + self.post_message(self.TabDisabled(self, event.tab)) + + def _on_tab_enabled(self, event: Tab.Enabled) -> None: + """Re-post the enabled message.""" + event.stop() + self.post_message(self.TabEnabled(self, event.tab)) + + def _on_tab_relabelled(self, event: Tab.Relabelled) -> None: + """Redraw the highlight when tab is relabelled.""" + event.stop() + self._highlight_active() + + def disable(self, tab_id: str) -> Tab: + """Disable the indicated tab. + + Args: + tab_id: The ID of the [`Tab`][textual.widgets.Tab] to disable. + + Returns: + The [`Tab`][textual.widgets.Tab] that was targeted. + + Raises: + TabError: If there are any issues with the request. + """ + + try: + tab_to_disable = self.query_one(f"#tabs-list > Tab#{tab_id}", Tab) + except NoMatches: + raise self.TabError( + f"There is no tab with ID {tab_id!r} to disable." + ) from None + + tab_to_disable.disabled = True + return tab_to_disable + + def enable(self, tab_id: str) -> Tab: + """Enable the indicated tab. + + Args: + tab_id: The ID of the [`Tab`][textual.widgets.Tab] to enable. + + Returns: + The [`Tab`][textual.widgets.Tab] that was targeted. + + Raises: + TabError: If there are any issues with the request. + """ + + try: + tab_to_enable = self.query_one(f"#tabs-list > Tab#{tab_id}", Tab) + except NoMatches: + raise self.TabError( + f"There is no tab with ID {tab_id!r} to enable." + ) from None + + tab_to_enable.disabled = False + return tab_to_enable + + def hide(self, tab_id: str) -> Tab: + """Hide the indicated tab. + + Args: + tab_id: The ID of the [`Tab`][textual.widgets.Tab] to hide. + + Returns: + The [`Tab`][textual.widgets.Tab] that was targeted. + + Raises: + TabError: If there are any issues with the request. + """ + + try: + tab_to_hide = self.query_one(f"#tabs-list > Tab#{tab_id}", Tab) + except NoMatches: + raise self.TabError(f"There is no tab with ID {tab_id!r} to hide.") + + if tab_to_hide.has_class("-active"): + next_tab = self._next_active + self.active = next_tab.id or "" if next_tab else "" + tab_to_hide.add_class("-hidden") + self.post_message(self.TabHidden(self, tab_to_hide).set_sender(self)) + self.call_after_refresh(self._highlight_active) + return tab_to_hide + + def show(self, tab_id: str) -> Tab: + """Show the indicated tab. + + Args: + tab_id: The ID of the [`Tab`][textual.widgets.Tab] to show. + + Returns: + The [`Tab`][textual.widgets.Tab] that was targeted. + + Raises: + TabError: If there are any issues with the request. + """ + + try: + tab_to_show = self.query_one(f"#tabs-list > Tab#{tab_id}", Tab) + except NoMatches: + raise self.TabError(f"There is no tab with ID {tab_id!r} to show.") + + tab_to_show.remove_class("-hidden") + self.post_message(self.TabShown(self, tab_to_show).set_sender(self)) + if not self.active: + self._activate_tab(tab_to_show) + self.call_after_refresh(self._highlight_active) + return tab_to_show diff --git a/src/memray/_vendor/textual/widgets/_text_area.py b/src/memray/_vendor/textual/widgets/_text_area.py new file mode 100644 index 0000000000..adb69b2c84 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_text_area.py @@ -0,0 +1,2655 @@ +from __future__ import annotations + +import dataclasses +import re +from collections import defaultdict +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path +from typing import TYPE_CHECKING, ClassVar, Iterable, Optional, Sequence, Tuple + +from rich.console import RenderableType +from rich.segment import Segment +from rich.style import Style +from rich.text import Text +from typing_extensions import Literal + +from memray._vendor.textual._text_area_theme import TextAreaTheme +from memray._vendor.textual._tree_sitter import TREE_SITTER, get_language +from memray._vendor.textual.actions import SkipAction +from memray._vendor.textual.cache import LRUCache +from memray._vendor.textual.color import Color +from memray._vendor.textual.content import Content +from memray._vendor.textual.document._document import ( + Document, + DocumentBase, + EditResult, + Location, + Selection, + _utf8_encode, +) +from memray._vendor.textual.document._document_navigator import DocumentNavigator +from memray._vendor.textual.document._edit import Edit +from memray._vendor.textual.document._history import EditHistory +from memray._vendor.textual.document._syntax_aware_document import ( + SyntaxAwareDocument, + SyntaxAwareDocumentError, +) +from memray._vendor.textual.document._wrapped_document import WrappedDocument +from memray._vendor.textual.expand_tabs import expand_tabs_inline, expand_text_tabs_from_widths +from memray._vendor.textual.screen import Screen +from memray._vendor.textual.style import Style as ContentStyle + +if TYPE_CHECKING: + from tree_sitter import Language, Query + +from memray._vendor.textual import events, log +from memray._vendor.textual._cells import cell_len, cell_width_to_column_index +from memray._vendor.textual.binding import Binding +from memray._vendor.textual.events import Message, MouseEvent +from memray._vendor.textual.geometry import Offset, Region, Size, Spacing, clamp +from memray._vendor.textual.reactive import Reactive, reactive +from memray._vendor.textual.scroll_view import ScrollView +from memray._vendor.textual.strip import Strip + +_OPENING_BRACKETS = {"{": "}", "[": "]", "(": ")"} +_CLOSING_BRACKETS = {v: k for k, v in _OPENING_BRACKETS.items()} +_TREE_SITTER_PATH = Path(__file__).parent / "../tree-sitter/" +_HIGHLIGHTS_PATH = _TREE_SITTER_PATH / "highlights/" + +StartColumn = int +EndColumn = Optional[int] +HighlightName = str +Highlight = Tuple[StartColumn, EndColumn, HighlightName] +"""A tuple representing a syntax highlight within one line.""" + +BUILTIN_LANGUAGES = [ + "python", + "markdown", + "json", + "toml", + "yaml", + "html", + "css", + "javascript", + "rust", + "go", + "regex", + "sql", + "java", + "bash", + "xml", +] +"""Languages that are included in the `syntax` extras.""" + + +class ThemeDoesNotExist(Exception): + """Raised when the user tries to use a theme which does not exist. + This means a theme which is not builtin, or has not been registered. + """ + + +class LanguageDoesNotExist(Exception): + """Raised when the user tries to use a language which does not exist. + This means a language which is not builtin, or has not been registered. + """ + + +@dataclass +class TextAreaLanguage: + """A container for a language which has been registered with the TextArea.""" + + name: str + """The name of the language""" + + language: "Language" | None + """The tree-sitter language object if that has been overridden, or None if it is a built-in language.""" + + highlight_query: str + """The tree-sitter highlight query to use for syntax highlighting.""" + + +class TextArea(ScrollView): + DEFAULT_CSS = """\ +TextArea { + width: 1fr; + height: 1fr; + border: tall $border-blurred; + padding: 0 1; + color: $foreground; + background: $surface; + pointer: text; + &.-textual-compact { + border: none !important; + } + & .text-area--cursor { + text-style: $input-cursor-text-style; + } + & .text-area--gutter { + color: $foreground 40%; + } + + & .text-area--cursor-gutter { + color: $foreground 60%; + background: $boost; + text-style: bold; + } + + & .text-area--cursor-line { + background: $boost; + } + + & .text-area--selection { + background: $input-selection-background; + } + + & .text-area--matching-bracket { + background: $foreground 30%; + } + + & .text-area--suggestion { + color: $text-muted; + } + + & .text-area--placeholder { + color: $text 40%; + } + + &:focus { + border: tall $border; + } + + &:ansi { + & .text-area--selection { + background: transparent; + text-style: reverse; + } + } + + &:dark { + .text-area--cursor { + color: $input-cursor-foreground; + background: $input-cursor-background; + } + &.-read-only .text-area--cursor { + background: $warning-darken-1; + } + } + + &:light { + .text-area--cursor { + color: $text 90%; + background: $foreground 70%; + } + &.-read-only .text-area--cursor { + background: $warning-darken-1; + } + } +} +""" + + COMPONENT_CLASSES: ClassVar[set[str]] = { + "text-area--cursor", + "text-area--gutter", + "text-area--cursor-gutter", + "text-area--cursor-line", + "text-area--selection", + "text-area--matching-bracket", + "text-area--suggestion", + "text-area--placeholder", + } + """ + `TextArea` offers some component classes which can be used to style aspects of the widget. + + Note that any attributes provided in the chosen `TextAreaTheme` will take priority here. + + | Class | Description | + | :- | :- | + | `text-area--cursor` | Target the cursor. | + | `text-area--gutter` | Target the gutter (line number column). | + | `text-area--cursor-gutter` | Target the gutter area of the line the cursor is on. | + | `text-area--cursor-line` | Target the line the cursor is on. | + | `text-area--selection` | Target the current selection. | + | `text-area--matching-bracket` | Target matching brackets. | + | `text-area--suggestion` | Target the text set in the `suggestion` reactive. | + | `text-area--placeholder` | Target the placeholder text. | + """ + + BINDINGS = [ + # Cursor movement + Binding("up", "cursor_up", "Cursor up", show=False), + Binding("down", "cursor_down", "Cursor down", show=False), + Binding("left", "cursor_left", "Cursor left", show=False), + Binding("right", "cursor_right", "Cursor right", show=False), + Binding("ctrl+left", "cursor_word_left", "Cursor word left", show=False), + Binding("ctrl+right", "cursor_word_right", "Cursor word right", show=False), + Binding("home,ctrl+a", "cursor_line_start", "Cursor line start", show=False), + Binding("end,ctrl+e", "cursor_line_end", "Cursor line end", show=False), + Binding("pageup", "cursor_page_up", "Cursor page up", show=False), + Binding("pagedown", "cursor_page_down", "Cursor page down", show=False), + # Making selections (generally holding the shift key and moving cursor) + Binding( + "ctrl+shift+left", + "cursor_word_left(True)", + "Cursor left word select", + show=False, + ), + Binding( + "ctrl+shift+right", + "cursor_word_right(True)", + "Cursor right word select", + show=False, + ), + Binding( + "shift+home", + "cursor_line_start(True)", + "Cursor line start select", + show=False, + ), + Binding( + "shift+end", "cursor_line_end(True)", "Cursor line end select", show=False + ), + Binding("shift+up", "cursor_up(True)", "Cursor up select", show=False), + Binding("shift+down", "cursor_down(True)", "Cursor down select", show=False), + Binding("shift+left", "cursor_left(True)", "Cursor left select", show=False), + Binding("shift+right", "cursor_right(True)", "Cursor right select", show=False), + # Shortcut ways of making selections + # Binding("f5", "select_word", "select word", show=False), + Binding("f6", "select_line", "Select line", show=False), + Binding("f7", "select_all", "Select all", show=False), + # Deletion + Binding("backspace", "delete_left", "Delete character left", show=False), + Binding( + "ctrl+w", "delete_word_left", "Delete left to start of word", show=False + ), + Binding("delete,ctrl+d", "delete_right", "Delete character right", show=False), + Binding( + "ctrl+f", "delete_word_right", "Delete right to start of word", show=False + ), + Binding("ctrl+x", "cut", "Cut", show=False), + Binding("ctrl+c,super+c", "copy", "Copy", show=False), + Binding("ctrl+v", "paste", "Paste", show=False), + Binding( + "ctrl+u", "delete_to_start_of_line", "Delete to line start", show=False + ), + Binding( + "ctrl+k", + "delete_to_end_of_line_or_delete_line", + "Delete to line end", + show=False, + ), + Binding( + "ctrl+shift+k", + "delete_line", + "Delete line", + show=False, + ), + Binding("ctrl+z", "undo", "Undo", show=False), + Binding("ctrl+y", "redo", "Redo", show=False), + ] + """ + | Key(s) | Description | + | :- | :- | + | up | Move the cursor up. | + | down | Move the cursor down. | + | left | Move the cursor left. | + | ctrl+left | Move the cursor to the start of the word. | + | ctrl+shift+left | Move the cursor to the start of the word and select. | + | right | Move the cursor right. | + | ctrl+right | Move the cursor to the end of the word. | + | ctrl+shift+right | Move the cursor to the end of the word and select. | + | home,ctrl+a | Move the cursor to the start of the line. | + | end,ctrl+e | Move the cursor to the end of the line. | + | shift+home | Move the cursor to the start of the line and select. | + | shift+end | Move the cursor to the end of the line and select. | + | pageup | Move the cursor one page up. | + | pagedown | Move the cursor one page down. | + | shift+up | Select while moving the cursor up. | + | shift+down | Select while moving the cursor down. | + | shift+left | Select while moving the cursor left. | + | shift+right | Select while moving the cursor right. | + | backspace | Delete character to the left of cursor. | + | ctrl+w | Delete from cursor to start of the word. | + | delete,ctrl+d | Delete character to the right of cursor. | + | ctrl+f | Delete from cursor to end of the word. | + | ctrl+shift+k | Delete the current line. | + | ctrl+u | Delete from cursor to the start of the line. | + | ctrl+k | Delete from cursor to the end of the line. | + | f6 | Select the current line. | + | f7 | Select all text in the document. | + | ctrl+z | Undo. | + | ctrl+y | Redo. | + | ctrl+x | Cut selection or line if no selection. | + | ctrl+c | Copy selection to clipboard. | + | ctrl+v | Paste from clipboard. | + """ + + language: Reactive[str | None] = reactive(None, always_update=True, init=False) + """The language to use. + + This must be set to a valid, non-None value for syntax highlighting to work. + + If the value is a string, a built-in language parser will be used if available. + + If you wish to use an unsupported language, you'll have to register + it first using [`TextArea.register_language`][textual.widgets._text_area.TextArea.register_language]. + """ + + theme: Reactive[str] = reactive("css", always_update=True, init=False) + """The name of the theme to use. + + Themes must be registered using [`TextArea.register_theme`][textual.widgets._text_area.TextArea.register_theme] before they can be used. + + Syntax highlighting is only possible when the `language` attribute is set. + """ + + selection: Reactive[Selection] = reactive( + Selection(), init=False, always_update=True + ) + """The selection start and end locations (zero-based line_index, offset). + + This represents the cursor location and the current selection. + + The `Selection.end` always refers to the cursor location. + + If no text is selected, then `Selection.end == Selection.start` is True. + + The text selected in the document is available via the `TextArea.selected_text` property. + """ + + show_line_numbers: Reactive[bool] = reactive(False, init=False) + """True to show the line number column on the left edge, otherwise False. + + Changing this value will immediately re-render the `TextArea`.""" + + line_number_start: Reactive[int] = reactive(1, init=False) + """The line number the first line should be.""" + + indent_width: Reactive[int] = reactive(4, init=False) + """The width of tabs or the multiple of spaces to align to on pressing the `tab` key. + + If the document currently open contains tabs that are currently visible on screen, + altering this value will immediately change the display width of the visible tabs. + """ + + match_cursor_bracket: Reactive[bool] = reactive(True, init=False) + """If the cursor is at a bracket, highlight the matching bracket (if found).""" + + cursor_blink: Reactive[bool] = reactive(True, init=False) + """True if the cursor should blink.""" + + soft_wrap: Reactive[bool] = reactive(True, init=False) + """True if text should soft wrap.""" + + read_only: Reactive[bool] = reactive(False) + """True if the content is read-only. + + Read-only means end users cannot insert, delete or replace content. + + The document can still be edited programmatically via the API. + """ + + show_cursor: Reactive[bool] = reactive(True) + """Show the cursor in read only mode? + + If `True`, the cursor will be visible when `read_only==True`. + If `False`, the cursor will be hidden when `read_only==True`, and the TextArea will + scroll like other containers. + + """ + + compact: reactive[bool] = reactive(False, toggle_class="-textual-compact") + """Enable compact display?""" + + highlight_cursor_line: reactive[bool] = reactive(True) + """Highlight the line under the cursor?""" + + _cursor_visible: Reactive[bool] = reactive(False, repaint=False, init=False) + """Indicates where the cursor is in the blink cycle. If it's currently + not visible due to blinking, this is False.""" + + suggestion: Reactive[str] = reactive("") + """A suggestion for auto-complete (pressing right will insert it).""" + + hide_suggestion_on_blur: Reactive[bool] = reactive(True) + """Hide suggestion when the TextArea does not have focus.""" + + placeholder: Reactive[str | Content] = reactive("") + """Text to show when the text area has no content.""" + + @dataclass + class Changed(Message): + """Posted when the content inside the TextArea changes. + + Handle this message using the `on` decorator - `@on(TextArea.Changed)` + or a method named `on_text_area_changed`. + """ + + text_area: TextArea + """The `text_area` that sent this message.""" + + @property + def control(self) -> TextArea: + """The `TextArea` that sent this message.""" + return self.text_area + + @dataclass + class SelectionChanged(Message): + """Posted when the selection changes. + + This includes when the cursor moves or when text is selected.""" + + selection: Selection + """The new selection.""" + text_area: TextArea + """The `text_area` that sent this message.""" + + @property + def control(self) -> TextArea: + return self.text_area + + def __init__( + self, + text: str = "", + *, + language: str | None = None, + theme: str = "css", + soft_wrap: bool = True, + tab_behavior: Literal["focus", "indent"] = "focus", + read_only: bool = False, + show_cursor: bool = True, + show_line_numbers: bool = False, + line_number_start: int = 1, + max_checkpoints: int = 50, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + tooltip: RenderableType | None = None, + compact: bool = False, + highlight_cursor_line: bool = True, + placeholder: str | Content = "", + ) -> None: + """Construct a new `TextArea`. + + Args: + text: The initial text to load into the TextArea. + language: The language to use. + theme: The theme to use. + soft_wrap: Enable soft wrapping. + tab_behavior: If 'focus', pressing tab will switch focus. If 'indent', pressing tab will insert a tab. + read_only: Enable read-only mode. This prevents edits using the keyboard. + show_cursor: Show the cursor in read only mode (no effect otherwise). + show_line_numbers: Show line numbers on the left edge. + line_number_start: What line number to start on. + max_checkpoints: The maximum number of undo history checkpoints to retain. + name: The name of the `TextArea` widget. + id: The ID of the widget, used to refer to it from Textual CSS. + classes: One or more Textual CSS compatible class names separated by spaces. + disabled: True if the widget is disabled. + tooltip: Optional tooltip. + compact: Enable compact style (without borders). + highlight_cursor_line: Highlight the line under the cursor. + placeholder: Text to display when there is not content. + """ + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + + self._languages: dict[str, TextAreaLanguage] = {} + """Maps language names to TextAreaLanguage. This is only used for languages + registered by end-users using `TextArea.register_language`. If a user attempts + to set `TextArea.language` to a language that is not registered here, we'll + attempt to get it from the environment. If that fails, we'll fall back to + plain text. + """ + + self._themes: dict[str, TextAreaTheme] = {} + """Maps theme names to TextAreaTheme.""" + + self.indent_type: Literal["tabs", "spaces"] = "spaces" + """Whether to indent using tabs or spaces.""" + + self._word_pattern = re.compile(r"(?<=\W)(?=\w)|(?<=\w)(?=\W)") + """Compiled regular expression for what we consider to be a 'word'.""" + + self.history: EditHistory = EditHistory( + max_checkpoints=max_checkpoints, + checkpoint_timer=2.0, + checkpoint_max_characters=100, + ) + """A stack (the end of the list is the top of the stack) for tracking edits.""" + + self._selecting = False + """True if we're currently selecting text using the mouse, otherwise False.""" + + self._matching_bracket_location: Location | None = None + """The location (row, column) of the bracket which matches the bracket the + cursor is currently at. If the cursor is at a bracket, or there's no matching + bracket, this will be `None`.""" + + self._highlights: dict[int, list[Highlight]] = defaultdict(list) + """Mapping line numbers to the set of highlights for that line.""" + + self._highlight_query: "Query | None" = None + """The query that's currently being used for highlighting.""" + + self.document: DocumentBase = Document(text) + """The document this widget is currently editing.""" + + self.wrapped_document: WrappedDocument = WrappedDocument(self.document) + """The wrapped view of the document.""" + + self.navigator: DocumentNavigator = DocumentNavigator(self.wrapped_document) + """Queried to determine where the cursor should move given a navigation + action, accounting for wrapping etc.""" + + self._cursor_offset = (0, 0) + """The virtual offset of the cursor (not screen-space offset).""" + + self.set_reactive(TextArea.soft_wrap, soft_wrap) + self.set_reactive(TextArea.read_only, read_only) + self.set_reactive(TextArea.show_cursor, show_cursor) + self.set_reactive(TextArea.show_line_numbers, show_line_numbers) + self.set_reactive(TextArea.line_number_start, line_number_start) + self.set_reactive(TextArea.highlight_cursor_line, highlight_cursor_line) + self.set_reactive(TextArea.placeholder, placeholder) + + self._line_cache: LRUCache[tuple, Strip] = LRUCache(1024) + + self._set_document(text, language) + + self.language = language + self.theme = theme + + self._theme: TextAreaTheme + """The `TextAreaTheme` corresponding to the set theme name. When the `theme` + reactive is set as a string, the watcher will update this attribute to the + corresponding `TextAreaTheme` object.""" + + self.tab_behavior = tab_behavior + + if tooltip is not None: + self.tooltip = tooltip + + self.compact = compact + + @classmethod + def code_editor( + cls, + text: str = "", + *, + language: str | None = None, + theme: str = "monokai", + soft_wrap: bool = False, + tab_behavior: Literal["focus", "indent"] = "indent", + read_only: bool = False, + show_cursor: bool = True, + show_line_numbers: bool = True, + line_number_start: int = 1, + max_checkpoints: int = 50, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + tooltip: RenderableType | None = None, + compact: bool = False, + highlight_cursor_line: bool = True, + placeholder: str | Content = "", + ) -> TextArea: + """Construct a new `TextArea` with sensible defaults for editing code. + + This instantiates a `TextArea` with line numbers enabled, soft wrapping + disabled, "indent" tab behavior, and the "monokai" theme. + + Args: + text: The initial text to load into the TextArea. + language: The language to use. + theme: The theme to use. + soft_wrap: Enable soft wrapping. + tab_behavior: If 'focus', pressing tab will switch focus. If 'indent', pressing tab will insert a tab. + read_only: Enable read-only mode. This prevents edits using the keyboard. + show_cursor: Show the cursor in read only mode (no effect otherwise). + show_line_numbers: Show line numbers on the left edge. + line_number_start: What line number to start on. + name: The name of the `TextArea` widget. + id: The ID of the widget, used to refer to it from Textual CSS. + classes: One or more Textual CSS compatible class names separated by spaces. + disabled: True if the widget is disabled. + tooltip: Optional tooltip + compact: Enable compact style (without borders). + highlight_cursor_line: Highlight the line under the cursor. + """ + return cls( + text, + language=language, + theme=theme, + soft_wrap=soft_wrap, + tab_behavior=tab_behavior, + read_only=read_only, + show_cursor=show_cursor, + show_line_numbers=show_line_numbers, + line_number_start=line_number_start, + max_checkpoints=max_checkpoints, + name=name, + id=id, + classes=classes, + disabled=disabled, + tooltip=tooltip, + compact=compact, + highlight_cursor_line=highlight_cursor_line, + placeholder=placeholder, + ) + + @staticmethod + def _get_builtin_highlight_query(language_name: str) -> str: + """Get the highlight query for a builtin language. + + Args: + language_name: The name of the builtin language. + + Returns: + The highlight query. + """ + try: + highlight_query_path = ( + Path(_HIGHLIGHTS_PATH.resolve()) / f"{language_name}.scm" + ) + highlight_query = highlight_query_path.read_text() + except OSError as error: + log.warning(f"Unable to load highlight query. {error}") + highlight_query = "" + + return highlight_query + + def notify_style_update(self) -> None: + self._line_cache.clear() + super().notify_style_update() + + def update_suggestion(self) -> None: + """A hook to update the [`suggestion`][textual.widgets.TextArea.suggestion] attribute.""" + + def check_consume_key(self, key: str, character: str | None = None) -> bool: + """Check if the widget may consume the given key. + + As a textarea we are expecting to capture printable keys. + + Args: + key: A key identifier. + character: A character associated with the key, or `None` if there isn't one. + + Returns: + `True` if the widget may capture the key in its `Key` message, or `False` if it won't. + """ + if self.read_only: + # In read only mode we don't consume any key events + return False + if self.tab_behavior == "indent" and key == "tab": + # If tab_behavior is indent, then we consume the tab + return True + # Otherwise we capture all printable keys + return character is not None and character.isprintable() + + def _build_highlight_map(self) -> None: + """Query the tree for ranges to highlights, and update the internal highlights mapping.""" + self._line_cache.clear() + highlights = self._highlights + highlights.clear() + if not self._highlight_query: + return + + captures = self.document.query_syntax_tree(self._highlight_query) + for highlight_name, nodes in captures.items(): + for node in nodes: + node_start_row, node_start_column = node.start_point + node_end_row, node_end_column = node.end_point + + if node_start_row == node_end_row: + highlight = (node_start_column, node_end_column, highlight_name) + highlights[node_start_row].append(highlight) + else: + # Add the first line of the node range + highlights[node_start_row].append( + (node_start_column, None, highlight_name) + ) + + # Add the middle lines - entire row of this node is highlighted + for node_row in range(node_start_row + 1, node_end_row): + highlights[node_row].append((0, None, highlight_name)) + + # Add the last line of the node range + highlights[node_end_row].append( + (0, node_end_column, highlight_name) + ) + + def _watch_has_focus(self, focus: bool) -> None: + self._cursor_visible = focus + if focus: + self._restart_blink() + self.app.cursor_position = self.cursor_screen_offset + self.history.checkpoint() + else: + self._pause_blink(visible=False) + + def _watch_selection( + self, previous_selection: Selection, selection: Selection + ) -> None: + """When the cursor moves, scroll it into view.""" + # Find the visual offset of the cursor in the document + + if not self.is_mounted: + return + + self.app.clear_selection() + + cursor_location = selection.end + + self.scroll_cursor_visible() + + cursor_row, cursor_column = cursor_location + + try: + character = self.document[cursor_row][cursor_column] + except IndexError: + character = "" + + # Record the location of a matching closing/opening bracket. + match_location = self.find_matching_bracket(character, cursor_location) + self._matching_bracket_location = match_location + if match_location is not None: + _, offset_y = self._cursor_offset + self.refresh_lines(offset_y) + + self.app.cursor_position = self.cursor_screen_offset + if previous_selection != selection: + self.post_message(self.SelectionChanged(selection, self)) + + def _watch_cursor_blink(self, blink: bool) -> None: + if not self.is_mounted: + return None + if blink and self.has_focus: + self._restart_blink() + else: + self._pause_blink(visible=self.has_focus) + + def _watch_read_only(self, read_only: bool) -> None: + self.set_class(read_only, "-read-only") + self._set_theme(self._theme.name) + + def _recompute_cursor_offset(self): + """Recompute the (x, y) coordinate of the cursor in the wrapped document.""" + self._cursor_offset = self.wrapped_document.location_to_offset( + self.cursor_location + ) + + def find_matching_bracket( + self, bracket: str, search_from: Location + ) -> Location | None: + """If the character is a bracket, find the matching bracket. + + Args: + bracket: The character we're searching for the matching bracket of. + search_from: The location to start the search. + + Returns: + The `Location` of the matching bracket, or `None` if it's not found. + If the character is not available for bracket matching, `None` is returned. + """ + match_location = None + bracket_stack: list[str] = [] + if bracket in _OPENING_BRACKETS: + # Search forwards for a closing bracket + for candidate, candidate_location in self._yield_character_locations( + search_from + ): + if candidate in _OPENING_BRACKETS: + bracket_stack.append(candidate) + elif candidate in _CLOSING_BRACKETS: + if ( + bracket_stack + and bracket_stack[-1] == _CLOSING_BRACKETS[candidate] + ): + bracket_stack.pop() + if not bracket_stack: + match_location = candidate_location + break + elif bracket in _CLOSING_BRACKETS: + # Search backwards for an opening bracket + for ( + candidate, + candidate_location, + ) in self._yield_character_locations_reverse(search_from): + if candidate in _CLOSING_BRACKETS: + bracket_stack.append(candidate) + elif candidate in _OPENING_BRACKETS: + if ( + bracket_stack + and bracket_stack[-1] == _OPENING_BRACKETS[candidate] + ): + bracket_stack.pop() + if not bracket_stack: + match_location = candidate_location + break + + return match_location + + def _validate_selection(self, selection: Selection) -> Selection: + """Clamp the selection to valid locations.""" + start, end = selection + clamp_visitable = self.clamp_visitable + return Selection(clamp_visitable(start), clamp_visitable(end)) + + def _watch_language(self, language: str | None) -> None: + """When the language is updated, update the type of document.""" + self._set_document(self.document.text, language) + + def _watch_show_line_numbers(self) -> None: + """The line number gutter contributes to virtual size, so recalculate.""" + self._rewrap_and_refresh_virtual_size() + self.scroll_cursor_visible() + + def _watch_line_number_start(self) -> None: + """The line number gutter max size might change and contributes to virtual size, so recalculate.""" + self._rewrap_and_refresh_virtual_size() + self.scroll_cursor_visible() + + def _watch_indent_width(self) -> None: + """Changing width of tabs will change the document display width.""" + self._rewrap_and_refresh_virtual_size() + self.scroll_cursor_visible() + + def _watch_show_vertical_scrollbar(self) -> None: + if self.wrap_width: + self._rewrap_and_refresh_virtual_size() + self.scroll_cursor_visible() + + def _watch_theme(self, theme: str) -> None: + """We set the styles on this widget when the theme changes, to ensure that + if padding is applied, the colors match.""" + self._set_theme(theme) + + def _app_theme_changed(self) -> None: + self._set_theme(self._theme.name) + + def _set_theme(self, theme: str) -> None: + theme_object: TextAreaTheme | None + + # If the user supplied a string theme name, find it and apply it. + try: + theme_object = self._themes[theme] + except KeyError: + theme_object = TextAreaTheme.get_builtin_theme(theme) + if theme_object is None: + raise ThemeDoesNotExist( + f"{theme!r} is not a builtin theme, or it has not been registered. " + f"To use a custom theme, register it first using `register_theme`, " + f"then switch to that theme by setting the `TextArea.theme` attribute." + ) from None + + self._theme = dataclasses.replace(theme_object) + if theme_object: + base_style = theme_object.base_style + if base_style: + color = base_style.color + background = base_style.bgcolor + if color: + self.styles.color = Color.from_rich_color(color) + if background: + self.styles.background = Color.from_rich_color(background) + else: + # When the theme doesn't define a base style (e.g. the `css` theme), + # the TextArea background/color should fallback to its CSS colors. + # + # Since these styles may have already been changed by another theme, + # we need to reset the background/color styles to the default values. + self.styles.color = None + self.styles.background = None + + @property + def available_themes(self) -> set[str]: + """A list of the names of the themes available to the `TextArea`. + + The values in this list can be assigned `theme` reactive attribute of + `TextArea`. + + You can retrieve the full specification for a theme by passing one of + the strings from this list into `TextAreaTheme.get_by_name(theme_name: str)`. + + Alternatively, you can directly retrieve a list of `TextAreaTheme` objects + (which contain the full theme specification) by calling + `TextAreaTheme.builtin_themes()`. + """ + return { + theme.name for theme in TextAreaTheme.builtin_themes() + } | self._themes.keys() + + def register_theme(self, theme: TextAreaTheme) -> None: + """Register a theme for use by the `TextArea`. + + After registering a theme, you can set themes by assigning the theme + name to the `TextArea.theme` reactive attribute. For example + `text_area.theme = "my_custom_theme"` where `"my_custom_theme"` is the + name of the theme you registered. + + If you supply a theme with a name that already exists that theme + will be overwritten. + """ + self._themes[theme.name] = theme + + @property + def available_languages(self) -> set[str]: + """A set of the names of languages available to the `TextArea`. + + The values in this set can be assigned to the `language` reactive attribute + of `TextArea`. + + The returned set contains the builtin languages installed with the syntax extras, + plus those registered via the `register_language` method. + """ + return set(BUILTIN_LANGUAGES) | self._languages.keys() + + def register_language( + self, + name: str, + language: "Language", + highlight_query: str, + ) -> None: + """Register a language and corresponding highlight query. + + Calling this method does not change the language of the `TextArea`. + On switching to this language (via the `language` reactive attribute), + syntax highlighting will be performed using the given highlight query. + + If a string `name` is supplied for a builtin supported language, then + this method will update the default highlight query for that language. + + Registering a language only registers it to this instance of `TextArea`. + + Args: + name: The name of the language. + language: A tree-sitter `Language` object. + highlight_query: The highlight query to use for syntax highlighting this language. + """ + self._languages[name] = TextAreaLanguage(name, language, highlight_query) + + def update_highlight_query(self, name: str, highlight_query: str) -> None: + """Update the highlight query for an already registered language. + + Args: + name: The name of the language. + highlight_query: The highlight query to use for syntax highlighting this language. + """ + if name not in self._languages: + self._languages[name] = TextAreaLanguage(name, None, highlight_query) + else: + self._languages[name].highlight_query = highlight_query + + # If this is the currently loaded language, reload the document because + # it could be a different highlight query for the same language. + if name == self.language: + self._set_document(self.text, name) + + def _set_document(self, text: str, language: str | None) -> None: + """Construct and return an appropriate document. + + Args: + text: The text of the document. + language: The name of the language to use. This must correspond to a tree-sitter + language available in the current environment (e.g. use `python` for `tree-sitter-python`). + If None, the document will be treated as plain text. + """ + self._highlight_query = None + if TREE_SITTER and language: + if language in self._languages: + # User-registered languages take priority. + highlight_query = self._languages[language].highlight_query + document_language = self._languages[language].language + if document_language is None: + document_language = get_language(language) + else: + # No user-registered language, so attempt to use a built-in language. + highlight_query = self._get_builtin_highlight_query(language) + document_language = get_language(language) + + # No built-in language, and no user-registered language: use plain text and warn. + if document_language is None: + raise LanguageDoesNotExist( + f"tree-sitter is available, but no built-in or user-registered language called {language!r}.\n" + f"Ensure the language is installed (e.g. `pip install tree-sitter-ruby`)\n" + f"Falling back to plain text." + ) + else: + document: DocumentBase + try: + document = SyntaxAwareDocument(text, document_language) + except SyntaxAwareDocumentError: + document = Document(text) + log.warning( + f"Parser not found for language {document_language!r}. Parsing disabled." + ) + else: + self._highlight_query = document.prepare_query(highlight_query) + elif language and not TREE_SITTER: + # User has supplied a language i.e. `TextArea(language="python")`, but they + # don't have tree-sitter available in the environment. We fallback to plain text. + log.warning( + "tree-sitter not available in this environment. Parsing disabled.\n" + "You may need to install the `syntax` extras alongside textual.\n" + "Try `pip install 'textual[syntax]'` or '`poetry add textual[syntax]' to get started quickly.\n\n" + "Alternatively, install tree-sitter manually (`pip install tree-sitter`) and then\n" + "install the required language (e.g. `pip install tree-sitter-ruby`), then register it.\n" + "and its highlight query using TextArea.register_language().\n\n" + "Falling back to plain text for now." + ) + document = Document(text) + else: + # tree-sitter is available, but the user has supplied None or "" for the language. + # Use a regular plain-text document. + document = Document(text) + + self.document = document + self.wrapped_document = WrappedDocument(document, tab_width=self.indent_width) + self.navigator = DocumentNavigator(self.wrapped_document) + self._build_highlight_map() + self.move_cursor((0, 0)) + self._rewrap_and_refresh_virtual_size() + + @property + def _visible_line_indices(self) -> tuple[int, int]: + """Return the visible line indices as a tuple (top, bottom). + + Returns: + A tuple (top, bottom) indicating the top and bottom visible line indices. + """ + _, scroll_offset_y = self.scroll_offset + return scroll_offset_y, scroll_offset_y + self.size.height + + def _watch_scroll_x(self) -> None: + self.app.cursor_position = self.cursor_screen_offset + + def _watch_scroll_y(self) -> None: + self.app.cursor_position = self.cursor_screen_offset + + def load_text(self, text: str) -> None: + """Load text into the TextArea. + + This will replace the text currently in the TextArea and clear the edit history. + + Args: + text: The text to load into the TextArea. + """ + self.history.clear() + self._set_document(text, self.language) + self.post_message(self.Changed(self).set_sender(self)) + self.update_suggestion() + + def _on_resize(self) -> None: + self._rewrap_and_refresh_virtual_size() + + def _watch_soft_wrap(self) -> None: + self._rewrap_and_refresh_virtual_size() + self.call_after_refresh(self.scroll_cursor_visible, center=True) + + @property + def wrap_width(self) -> int: + """The width which gets used when the document wraps. + + Accounts for gutter, scrollbars, etc. + """ + width, _ = self.scrollable_content_region.size + cursor_width = 1 + if self.soft_wrap: + return max(0, width - self.gutter_width - cursor_width) + return 0 + + def _rewrap_and_refresh_virtual_size(self) -> None: + self.wrapped_document.wrap(self.wrap_width, tab_width=self.indent_width) + self._line_cache.clear() + self._refresh_size() + + @property + def is_syntax_aware(self) -> bool: + """True if the TextArea is currently syntax aware - i.e. it's parsing document content.""" + return isinstance(self.document, SyntaxAwareDocument) + + def _yield_character_locations( + self, start: Location + ) -> Iterable[tuple[str, Location]]: + """Yields character locations starting from the given location. + + Does not yield location of line separator characters like `\\n`. + + Args: + start: The location to start yielding from. + + Returns: + Yields tuples of (character, (row, column)). + """ + row, column = start + document = self.document + line_count = document.line_count + + while 0 <= row < line_count: + line = document[row] + while column < len(line): + yield line[column], (row, column) + column += 1 + column = 0 + row += 1 + + def _yield_character_locations_reverse( + self, start: Location + ) -> Iterable[tuple[str, Location]]: + row, column = start + document = self.document + line_count = document.line_count + + while line_count > row >= 0: + line = document[row] + if column == -1: + column = len(line) - 1 + while column >= 0: + yield line[column], (row, column) + column -= 1 + row -= 1 + + def _refresh_size(self) -> None: + """Update the virtual size of the TextArea.""" + if self.soft_wrap: + self.virtual_size = Size(0, self.wrapped_document.height) + else: + # +1 width to make space for the cursor resting at the end of the line + width, height = self.document.get_size(self.indent_width) + self.virtual_size = Size(width + self.gutter_width + 1, height) + self._refresh_scrollbars() + + @property + def _draw_cursor(self) -> bool: + """Draw the cursor?""" + if self.read_only: + # If we are in read only mode, we don't want the cursor to blink + return self.show_cursor and self.has_focus + draw_cursor = ( + self.has_focus + and not self.cursor_blink + or (self.cursor_blink and self._cursor_visible) + ) + return draw_cursor + + @property + def _has_cursor(self) -> bool: + """Is there a usable cursor?""" + return not (self.read_only and not self.show_cursor) + + def get_line(self, line_index: int) -> Text: + """Retrieve the line at the given line index. + + You can stylize the Text object returned here to apply additional + styling to TextArea content. + + Args: + line_index: The index of the line. + + Returns: + A `rich.Text` object containing the requested line. + """ + line_string = self.document.get_line(line_index) + return Text(line_string, end="", no_wrap=True) + + def render_lines(self, crop: Region) -> list[Strip]: + theme = self._theme + if theme: + theme.apply_css(self) + return super().render_lines(crop) + + def render_line(self, y: int) -> Strip: + """Render a single line of the TextArea. Called by Textual. + + Args: + y: Y Coordinate of line relative to the widget region. + + Returns: + A rendered line. + """ + + if not self.text and self.placeholder: + placeholder_lines = Content.from_text(self.placeholder).wrap( + self.content_size.width + ) + if y < len(placeholder_lines): + style = self.get_visual_style("text-area--placeholder") + content = placeholder_lines[y].stylize(style) + if self._draw_cursor and y == 0: + theme = self._theme + cursor_style = theme.cursor_style if theme else None + if cursor_style: + content = content.stylize( + ContentStyle.from_rich_style(cursor_style), 0, 1 + ) + return Strip( + content.render_segments(self.visual_style), content.cell_length + ) + + scroll_x, scroll_y = self.scroll_offset + absolute_y = scroll_y + y + selection = self.selection + _, cursor_y = self._cursor_offset + cache_key = ( + self.size, + scroll_x, + absolute_y, + ( + selection + if selection.contains_line(absolute_y) or self.soft_wrap + else selection.end[0] == absolute_y + ), + ( + selection.end + if ( + self._cursor_visible + and self.cursor_blink + and absolute_y == cursor_y + ) + else None + ), + self.theme, + self._matching_bracket_location, + self.match_cursor_bracket, + self.soft_wrap, + self.show_line_numbers, + self.read_only, + self.show_cursor, + self.suggestion, + ) + if (cached_line := self._line_cache.get(cache_key)) is not None: + return cached_line + line = self._render_line(y) + self._line_cache[cache_key] = line + return line + + def _render_line(self, y: int) -> Strip: + """Render a single line of the TextArea. Called by Textual. + + Args: + y: Y Coordinate of line relative to the widget region. + + Returns: + A rendered line. + """ + theme = self._theme + base_style = ( + theme.base_style + if theme and theme.base_style is not None + else self.rich_style + ) + + wrapped_document = self.wrapped_document + scroll_x, scroll_y = self.scroll_offset + + # Account for how much the TextArea is scrolled. + y_offset = y + scroll_y + + # If we're beyond the height of the document, render blank lines + out_of_bounds = y_offset >= wrapped_document.height + + if out_of_bounds: + return Strip.blank(self.size.width, base_style) + + # Get the line corresponding to this offset + try: + line_info = wrapped_document._offset_to_line_info[y_offset] + except IndexError: + line_info = None + + if line_info is None: + return Strip.blank(self.size.width, base_style) + + line_index, section_offset = line_info + + line = self.get_line(line_index) + line_character_count = len(line) + line.tab_size = self.indent_width + line.set_length(line_character_count + 1) # space at end for cursor + virtual_width, _virtual_height = self.virtual_size + + selection = self.selection + start, end = selection + cursor_row, cursor_column = end + + selection_top, selection_bottom = sorted(selection) + selection_top_row, selection_top_column = selection_top + selection_bottom_row, selection_bottom_column = selection_bottom + + highlight_cursor_line = self.highlight_cursor_line and self._has_cursor + cursor_line_style = ( + theme.cursor_line_style if (theme and highlight_cursor_line) else None + ) + has_cursor = self._has_cursor + + if has_cursor and cursor_line_style and cursor_row == line_index: + line.stylize(cursor_line_style) + + # Selection styling + if start != end and selection_top_row <= line_index <= selection_bottom_row: + # If this row intersects with the selection range + selection_style = theme.selection_style if theme else None + cursor_row, _ = end + if selection_style: + if line_character_count == 0 and line_index != cursor_row: + # A simple highlight to show empty lines are included in the selection + line.plain = "▌" + line.stylize(Style(color=selection_style.bgcolor)) + else: + if line_index == selection_top_row == selection_bottom_row: + # Selection within a single line + line.stylize( + selection_style, + start=selection_top_column, + end=selection_bottom_column, + ) + else: + # Selection spanning multiple lines + if line_index == selection_top_row: + line.stylize( + selection_style, + start=selection_top_column, + end=line_character_count, + ) + elif line_index == selection_bottom_row: + line.stylize(selection_style, end=selection_bottom_column) + else: + line.stylize(selection_style, end=line_character_count) + + highlights = self._highlights + if highlights and theme: + line_bytes = _utf8_encode(line.plain) + byte_to_codepoint = build_byte_to_codepoint_dict(line_bytes) + get_highlight_from_theme = theme.syntax_styles.get + line_highlights = highlights[line_index] + for highlight_start, highlight_end, highlight_name in line_highlights: + node_style = get_highlight_from_theme(highlight_name) + if node_style is not None: + line.stylize( + node_style, + byte_to_codepoint.get(highlight_start, 0), + byte_to_codepoint.get(highlight_end) if highlight_end else None, + ) + + # Highlight the cursor + matching_bracket = self._matching_bracket_location + match_cursor_bracket = self.match_cursor_bracket + draw_matched_brackets = ( + has_cursor + and match_cursor_bracket + and matching_bracket is not None + and start == end + ) + + if cursor_row == line_index: + draw_cursor = self._draw_cursor + if draw_matched_brackets: + matching_bracket_style = theme.bracket_matching_style if theme else None + if matching_bracket_style: + line.stylize( + matching_bracket_style, + cursor_column, + cursor_column + 1, + ) + + if self.suggestion and (self.has_focus or not self.hide_suggestion_on_blur): + suggestion_style = self.get_component_rich_style( + "text-area--suggestion" + ) + line = Text.assemble( + line[:cursor_column], + (self.suggestion, suggestion_style), + line[cursor_column:], + ) + + if draw_cursor: + cursor_style = theme.cursor_style if theme else None + if cursor_style: + line.stylize(cursor_style, cursor_column, cursor_column + 1) + + # Highlight the partner opening/closing bracket. + if draw_matched_brackets: + # mypy doesn't know matching bracket is guaranteed to be non-None + assert matching_bracket is not None + bracket_match_row, bracket_match_column = matching_bracket + if theme and bracket_match_row == line_index: + matching_bracket_style = theme.bracket_matching_style + if matching_bracket_style: + line.stylize( + matching_bracket_style, + bracket_match_column, + bracket_match_column + 1, + ) + + # Build the gutter text for this line + gutter_width = self.gutter_width + if self.show_line_numbers: + if cursor_row == line_index and highlight_cursor_line: + gutter_style = theme.cursor_line_gutter_style + else: + gutter_style = theme.gutter_style + + gutter_width_no_margin = gutter_width - 2 + gutter_content = ( + str(line_index + self.line_number_start) if section_offset == 0 else "" + ) + gutter = [ + Segment(f"{gutter_content:>{gutter_width_no_margin}} ", gutter_style) + ] + else: + gutter = [] + + # TODO: Lets not apply the division each time through render_line. + # We should cache sections with the edit counts. + wrap_offsets = wrapped_document.get_offsets(line_index) + if wrap_offsets: + sections = line.divide(wrap_offsets) # TODO cache result with edit count + line = sections[section_offset] + line_tab_widths = wrapped_document.get_tab_widths(line_index) + line.end = "" + + # Get the widths of the tabs corresponding only to the section of the + # line that is currently being rendered. We don't care about tabs in + # other sections of the same line. + + # Count the tabs before this section. + tabs_before = 0 + for section_index in range(section_offset): + tabs_before += sections[section_index].plain.count("\t") + + # Count the tabs in this section. + tabs_within = line.plain.count("\t") + section_tab_widths = line_tab_widths[ + tabs_before : tabs_before + tabs_within + ] + line = expand_text_tabs_from_widths(line, section_tab_widths) + else: + line.expand_tabs(self.indent_width) + + base_width = ( + self.scrollable_content_region.size.width + if self.soft_wrap + else max(virtual_width, self.region.size.width) + ) + target_width = base_width - self.gutter_width + + # Crop the line to show only the visible part (some may be scrolled out of view) + console = self.app.console + text_strip = Strip(line.render(console), cell_length=line.cell_len) + if not self.soft_wrap: + text_strip = text_strip.crop(scroll_x, scroll_x + virtual_width) + + # Stylize the line the cursor is currently on. + if cursor_row == line_index and self.highlight_cursor_line: + line_style = cursor_line_style + else: + line_style = theme.base_style if theme else None + + text_strip = text_strip.extend_cell_length(target_width, line_style) + if gutter: + strip = Strip.join([Strip(gutter, cell_length=gutter_width), text_strip]) + else: + strip = text_strip + + return strip.apply_style(base_style) + + @property + def text(self) -> str: + """The entire text content of the document.""" + return self.document.text + + @text.setter + def text(self, value: str) -> None: + """Replace the text currently in the TextArea. This is an alias of `load_text`. + + Setting this value will clear the edit history. + + Args: + value: The text to load into the TextArea. + """ + self.load_text(value) + + @property + def selected_text(self) -> str: + """The text between the start and end points of the current selection.""" + start, end = self.selection + return self.get_text_range(start, end) + + @property + def matching_bracket_location(self) -> Location | None: + """The location of the matching bracket, if there is one.""" + return self._matching_bracket_location + + def get_text_range(self, start: Location, end: Location) -> str: + """Get the text between a start and end location. + + Args: + start: The start location. + end: The end location. + + Returns: + The text between start and end. + """ + start, end = sorted((start, end)) + return self.document.get_text_range(start, end) + + def edit(self, edit: Edit) -> EditResult: + """Perform an Edit. + + Args: + edit: The Edit to perform. + + Returns: + Data relating to the edit that may be useful. The data returned + may be different depending on the edit performed. + """ + if self.suggestion.startswith(edit.text): + self.suggestion = self.suggestion[len(edit.text) :] + else: + self.suggestion = "" + old_gutter_width = self.gutter_width + result = edit.do(self) + self.history.record(edit) + new_gutter_width = self.gutter_width + + if old_gutter_width != new_gutter_width: + self.wrapped_document.wrap(self.wrap_width, self.indent_width) + else: + self.wrapped_document.wrap_range( + edit.top, + edit.bottom, + result.end_location, + ) + + edit.after(self) + self._build_highlight_map() + self.post_message(self.Changed(self)) + self.update_suggestion() + self._refresh_size() + return result + + def undo(self) -> None: + """Undo the edits since the last checkpoint (the most recent batch of edits).""" + if edits := self.history._pop_undo(): + self._undo_batch(edits) + + def action_undo(self) -> None: + """Undo the edits since the last checkpoint (the most recent batch of edits).""" + self.undo() + + def redo(self) -> None: + """Redo the most recently undone batch of edits.""" + if edits := self.history._pop_redo(): + self._redo_batch(edits) + + def action_redo(self) -> None: + """Redo the most recently undone batch of edits.""" + self.redo() + + def _undo_batch(self, edits: Sequence[Edit]) -> None: + """Undo a batch of Edits. + + The sequence must be chronologically ordered by edit time. + + There must be no edits missing from the sequence, or the resulting content + will be incorrect. + + Args: + edits: The edits to undo, in the order they were originally performed. + """ + if not edits: + return + + old_gutter_width = self.gutter_width + minimum_top = edits[-1].top + maximum_old_bottom = (0, 0) + maximum_new_bottom = (0, 0) + for edit in reversed(edits): + edit.undo(self) + end_location = ( + edit._edit_result.end_location if edit._edit_result else (0, 0) + ) + if edit.top < minimum_top: + minimum_top = edit.top + if end_location > maximum_old_bottom: + maximum_old_bottom = end_location + if edit.bottom > maximum_new_bottom: + maximum_new_bottom = edit.bottom + + new_gutter_width = self.gutter_width + if old_gutter_width != new_gutter_width: + self.wrapped_document.wrap(self.wrap_width, self.indent_width) + else: + self.wrapped_document.wrap_range( + minimum_top, maximum_old_bottom, maximum_new_bottom + ) + + self._refresh_size() + for edit in reversed(edits): + edit.after(self) + self._build_highlight_map() + self.post_message(self.Changed(self)) + self.update_suggestion() + + def _redo_batch(self, edits: Sequence[Edit]) -> None: + """Redo a batch of Edits in order. + + The sequence must be chronologically ordered by edit time. + + Edits are applied from the start of the sequence to the end. + + There must be no edits missing from the sequence, or the resulting content + will be incorrect. + + Args: + edits: The edits to redo. + """ + if not edits: + return + + old_gutter_width = self.gutter_width + minimum_top = edits[0].top + maximum_old_bottom = (0, 0) + maximum_new_bottom = (0, 0) + for edit in edits: + edit.do(self, record_selection=False) + end_location = ( + edit._edit_result.end_location if edit._edit_result else (0, 0) + ) + if edit.top < minimum_top: + minimum_top = edit.top + if end_location > maximum_new_bottom: + maximum_new_bottom = end_location + if edit.bottom > maximum_old_bottom: + maximum_old_bottom = edit.bottom + + new_gutter_width = self.gutter_width + if old_gutter_width != new_gutter_width: + self.wrapped_document.wrap(self.wrap_width, self.indent_width) + else: + self.wrapped_document.wrap_range( + minimum_top, + maximum_old_bottom, + maximum_new_bottom, + ) + + self._refresh_size() + for edit in edits: + edit.after(self) + self._build_highlight_map() + self.post_message(self.Changed(self)) + self.update_suggestion() + + async def _on_key(self, event: events.Key) -> None: + """Handle key presses which correspond to document inserts.""" + + self._restart_blink() + + if self.read_only: + return + + key = event.key + insert_values = { + "enter": "\n", + } + if self.tab_behavior == "indent": + if key == "escape": + event.stop() + event.prevent_default() + self.screen.focus_next() + return + if self.indent_type == "tabs": + insert_values["tab"] = "\t" + else: + insert_values["tab"] = " " * self._find_columns_to_next_tab_stop() + + if event.is_printable or key in insert_values: + event.stop() + event.prevent_default() + insert = insert_values.get(key, event.character) + # `insert` is not None because event.character cannot be + # None because we've checked that it's printable. + assert insert is not None + start, end = self.selection + self._replace_via_keyboard(insert, start, end) + + def _find_columns_to_next_tab_stop(self) -> int: + """Get the location of the next tab stop after the cursors position on the current line. + + If the cursor is already at a tab stop, this returns the *next* tab stop location. + + Returns: + The number of cells to the next tab stop from the current cursor column. + """ + cursor_row, cursor_column = self.cursor_location + line_text = self.document[cursor_row] + indent_width = self.indent_width + if not line_text: + return indent_width + + width_before_cursor = self.get_column_width(cursor_row, cursor_column) + spaces_to_insert = indent_width - ( + (indent_width + width_before_cursor) % indent_width + ) + + return spaces_to_insert + + def get_target_document_location(self, event: MouseEvent) -> Location: + """Given a MouseEvent, return the row and column offset of the event in document-space. + + Args: + event: The MouseEvent. + + Returns: + The location of the mouse event within the document. + """ + scroll_x, scroll_y = self.scroll_offset + target_x = event.x - self.gutter_width + scroll_x - self.gutter.left + target_y = event.y + scroll_y - self.gutter.top + location = self.wrapped_document.offset_to_location(Offset(target_x, target_y)) + return location + + @property + def gutter_width(self) -> int: + """The width of the gutter (the left column containing line numbers). + + Returns: + The cell-width of the line number column. If `show_line_numbers` is `False` returns 0. + """ + # The longest number in the gutter plus two extra characters: `│ `. + gutter_margin = 2 + gutter_width = ( + len(str(self.document.line_count - 1 + self.line_number_start)) + + gutter_margin + if self.show_line_numbers + else 0 + ) + return gutter_width + + def _on_mount(self, event: events.Mount) -> None: + def text_selection_started(screen: Screen) -> None: + """Signal callback to unselect when arbitrary text selection starts.""" + self.selection = Selection(self.cursor_location, self.cursor_location) + + self.screen.text_selection_started_signal.subscribe( + self, text_selection_started, immediate=True + ) + + # When `app.theme` reactive is changed, reset the theme to clear cached styles. + self.watch(self.app, "theme", self._app_theme_changed, init=False) + self.blink_timer = self.set_interval( + 0.5, + self._toggle_cursor_blink_visible, + pause=not (self.cursor_blink and self.has_focus), + ) + + def _toggle_cursor_blink_visible(self) -> None: + """Toggle visibility of the cursor for the purposes of 'cursor blink'.""" + if not self.screen.is_active: + return + + self._cursor_visible = not self._cursor_visible + _, cursor_y = self._cursor_offset + self.refresh_lines(cursor_y) + + def _watch__cursor_visible(self) -> None: + """When the cursor visibility is toggled, ensure the row is refreshed.""" + _, cursor_y = self._cursor_offset + self.refresh_lines(cursor_y) + + def _restart_blink(self) -> None: + """Reset the cursor blink timer.""" + if self.cursor_blink: + self._cursor_visible = True + if self.is_mounted: + self.blink_timer.reset() + + def _pause_blink(self, visible: bool = True) -> None: + """Pause the cursor blinking but ensure it stays visible.""" + self._cursor_visible = visible + if self.is_mounted: + self.blink_timer.pause() + + async def _on_mouse_down(self, event: events.MouseDown) -> None: + """Update the cursor position, and begin a selection using the mouse.""" + target = self.get_target_document_location(event) + self.selection = Selection.cursor(target) + self._selecting = True + # Capture the mouse so that if the cursor moves outside the + # TextArea widget while selecting, the widget still scrolls. + self.capture_mouse() + self._pause_blink(visible=False) + self.history.checkpoint() + + async def _on_mouse_move(self, event: events.MouseMove) -> None: + """Handles click and drag to expand and contract the selection.""" + if self._selecting: + target = self.get_target_document_location(event) + selection_start, _ = self.selection + self.selection = Selection(selection_start, target) + + def _end_mouse_selection(self) -> None: + """Finalize the selection that has been made using the mouse.""" + if self._selecting: + self._selecting = False + self.release_mouse() + self.record_cursor_width() + self._restart_blink() + + async def _on_mouse_up(self, event: events.MouseUp) -> None: + """Finalize the selection that has been made using the mouse.""" + self._end_mouse_selection() + + async def _on_hide(self, event: events.Hide) -> None: + """Finalize the selection that has been made using the mouse when the widget is hidden.""" + self._end_mouse_selection() + + async def _on_paste(self, event: events.Paste) -> None: + """When a paste occurs, insert the text from the paste event into the document.""" + if self.read_only: + return + if result := self._replace_via_keyboard(event.text, *self.selection): + self.move_cursor(result.end_location) + self.focus() + + def cell_width_to_column_index(self, cell_width: int, row_index: int) -> int: + """Return the column that the cell width corresponds to on the given row. + + Args: + cell_width: The cell width to convert. + row_index: The index of the row to examine. + + Returns: + The column corresponding to the cell width on that row. + """ + line = self.document[row_index] + return cell_width_to_column_index(line, cell_width, self.indent_width) + + def clamp_visitable(self, location: Location) -> Location: + """Clamp the given location to the nearest visitable location. + + Args: + location: The location to clamp. + + Returns: + The nearest location that we could conceivably navigate to using the cursor. + """ + document = self.document + + row, column = location + try: + line_text = document[row] + except IndexError: + line_text = "" + + row = clamp(row, 0, document.line_count - 1) + column = clamp(column, 0, len(line_text)) + + return row, column + + # --- Cursor/selection utilities + def scroll_cursor_visible( + self, center: bool = False, animate: bool = False + ) -> Offset: + """Scroll the `TextArea` such that the cursor is visible on screen. + + Args: + center: True if the cursor should be scrolled to the center. + animate: True if we should animate while scrolling. + + Returns: + The offset that was scrolled to bring the cursor into view. + """ + if not self._has_cursor: + return Offset(0, 0) + self._recompute_cursor_offset() + + x, y = self._cursor_offset + scroll_offset = self.scroll_to_region( + Region(x, y, width=3, height=1), + spacing=Spacing(right=self.gutter_width), + animate=animate, + force=True, + center=center, + ) + return scroll_offset + + def move_cursor( + self, + location: Location, + select: bool = False, + center: bool = False, + record_width: bool = True, + ) -> None: + """Move the cursor to a location. + + Args: + location: The location to move the cursor to. + select: If True, select text between the old and new location. + center: If True, scroll such that the cursor is centered. + record_width: If True, record the cursor column cell width after navigating + so that we jump back to the same width the next time we move to a row + that is wide enough. + """ + if not self._has_cursor: + return + if select: + start, _end = self.selection + self.selection = Selection(start, location) + else: + self.selection = Selection.cursor(location) + + if record_width: + self.record_cursor_width() + + if center: + self.scroll_cursor_visible(center) + + self.history.checkpoint() + + def move_cursor_relative( + self, + rows: int = 0, + columns: int = 0, + select: bool = False, + center: bool = False, + record_width: bool = True, + ) -> None: + """Move the cursor relative to its current location in document-space. + + Args: + rows: The number of rows to move down by (negative to move up) + columns: The number of columns to move right by (negative to move left) + select: If True, select text between the old and new location. + center: If True, scroll such that the cursor is centered. + record_width: If True, record the cursor column cell width after navigating + so that we jump back to the same width the next time we move to a row + that is wide enough. + """ + clamp_visitable = self.clamp_visitable + _start, end = self.selection + current_row, current_column = end + target = clamp_visitable((current_row + rows, current_column + columns)) + self.move_cursor(target, select, center, record_width) + + def select_line(self, index: int) -> None: + """Select all the text in the specified line. + + Args: + index: The index of the line to select (starting from 0). + """ + try: + line = self.document[index] + except IndexError: + return + else: + self.selection = Selection((index, 0), (index, len(line))) + self.record_cursor_width() + + def action_select_line(self) -> None: + """Select all the text on the current line.""" + cursor_row, _ = self.cursor_location + self.select_line(cursor_row) + + def select_all(self) -> None: + """Select all of the text in the `TextArea`.""" + last_line = self.document.line_count - 1 + length_of_last_line = len(self.document[last_line]) + selection_start = (0, 0) + selection_end = (last_line, length_of_last_line) + self.selection = Selection(selection_start, selection_end) + self.record_cursor_width() + + def action_select_all(self) -> None: + """Select all the text in the document.""" + self.select_all() + + @property + def cursor_location(self) -> Location: + """The current location of the cursor in the document. + + This is a utility for accessing the `end` of `TextArea.selection`. + """ + return self.selection.end + + @cursor_location.setter + def cursor_location(self, location: Location) -> None: + """Set the cursor_location to a new location. + + If a selection is in progress, the anchor point will remain. + """ + self.move_cursor(location, select=not self.selection.is_empty) + + @property + def cursor_screen_offset(self) -> Offset: + """The offset of the cursor relative to the screen.""" + cursor_x, cursor_y = self._cursor_offset + scroll_x, scroll_y = self.scroll_offset + region_x, region_y, _width, _height = self.content_region + + offset_x = region_x + cursor_x - scroll_x + self.gutter_width + offset_y = region_y + cursor_y - scroll_y + + return Offset(offset_x, offset_y) + + @property + def cursor_at_first_line(self) -> bool: + """True if and only if the cursor is on the first line.""" + return self.selection.end[0] == 0 + + @property + def cursor_at_last_line(self) -> bool: + """True if and only if the cursor is on the last line.""" + return self.selection.end[0] == self.document.line_count - 1 + + @property + def cursor_at_start_of_line(self) -> bool: + """True if and only if the cursor is at column 0.""" + return self.selection.end[1] == 0 + + @property + def cursor_at_end_of_line(self) -> bool: + """True if and only if the cursor is at the end of a row.""" + cursor_row, cursor_column = self.selection.end + row_length = len(self.document[cursor_row]) + cursor_at_end = cursor_column == row_length + return cursor_at_end + + @property + def cursor_at_start_of_text(self) -> bool: + """True if and only if the cursor is at location (0, 0)""" + return self.selection.end == (0, 0) + + @property + def cursor_at_end_of_text(self) -> bool: + """True if and only if the cursor is at the very end of the document.""" + return self.cursor_at_last_line and self.cursor_at_end_of_line + + # ------ Cursor movement actions + def action_cursor_left(self, select: bool = False) -> None: + """Move the cursor one location to the left. + + If the cursor is at the left edge of the document, try to move it to + the end of the previous line. + + If text is selected, move the cursor to the start of the selection. + + Args: + select: If True, select the text while moving. + """ + if not self._has_cursor: + self.scroll_left() + return + target = ( + self.get_cursor_left_location() + if select or self.selection.is_empty + else min(*self.selection) + ) + self.move_cursor(target, select=select) + + def get_cursor_left_location(self) -> Location: + """Get the location the cursor will move to if it moves left. + + Returns: + The location of the cursor if it moves left. + """ + return self.navigator.get_location_left(self.cursor_location) + + def action_cursor_right(self, select: bool = False) -> None: + """Move the cursor one location to the right. + + If the cursor is at the end of a line, attempt to go to the start of the next line. + + If text is selected, move the cursor to the end of the selection. + + Args: + select: If True, select the text while moving. + """ + if not self._has_cursor: + self.scroll_right() + return + if self.suggestion: + self.insert(self.suggestion) + return + target = ( + self.get_cursor_right_location() + if select or self.selection.is_empty + else max(*self.selection) + ) + self.move_cursor(target, select=select) + + def get_cursor_right_location(self) -> Location: + """Get the location the cursor will move to if it moves right. + + Returns: + the location the cursor will move to if it moves right. + """ + return self.navigator.get_location_right(self.cursor_location) + + def action_cursor_down(self, select: bool = False) -> None: + """Move the cursor down one cell. + + Args: + select: If True, select the text while moving. + """ + if not self._has_cursor: + self.scroll_down() + return + target = self.get_cursor_down_location() + self.move_cursor(target, record_width=False, select=select) + + def get_cursor_down_location(self) -> Location: + """Get the location the cursor will move to if it moves down. + + Returns: + The location the cursor will move to if it moves down. + """ + return self.navigator.get_location_below(self.cursor_location) + + def action_cursor_up(self, select: bool = False) -> None: + """Move the cursor up one cell. + + Args: + select: If True, select the text while moving. + """ + if not self._has_cursor: + self.scroll_up() + return + target = self.get_cursor_up_location() + self.move_cursor(target, record_width=False, select=select) + + def get_cursor_up_location(self) -> Location: + """Get the location the cursor will move to if it moves up. + + Returns: + The location the cursor will move to if it moves up. + """ + return self.navigator.get_location_above(self.cursor_location) + + def action_cursor_line_end(self, select: bool = False) -> None: + """Move the cursor to the end of the line.""" + if not self._has_cursor: + self.scroll_end() + return + location = self.get_cursor_line_end_location() + self.move_cursor(location, select=select) + + def get_cursor_line_end_location(self) -> Location: + """Get the location of the end of the current line. + + Returns: + The (row, column) location of the end of the cursors current line. + """ + return self.navigator.get_location_end(self.cursor_location) + + def action_cursor_line_start(self, select: bool = False) -> None: + """Move the cursor to the start of the line.""" + if not self._has_cursor: + self.scroll_home() + return + target = self.get_cursor_line_start_location(smart_home=True) + self.move_cursor(target, select=select) + + def get_cursor_line_start_location(self, smart_home: bool = False) -> Location: + """Get the location of the start of the current line. + + Args: + smart_home: If True, use "smart home key" behavior - go to the first + non-whitespace character on the line, and if already there, go to + offset 0. Smart home only works when wrapping is disabled. + + Returns: + The (row, column) location of the start of the cursors current line. + """ + return self.navigator.get_location_home( + self.cursor_location, smart_home=smart_home + ) + + def action_cursor_word_left(self, select: bool = False) -> None: + """Move the cursor left by a single word, skipping trailing whitespace. + + Args: + select: Whether to select while moving the cursor. + """ + if not self.show_cursor: + return + if self.cursor_at_start_of_text: + return + target = self.get_cursor_word_left_location() + self.move_cursor(target, select=select) + + def get_cursor_word_left_location(self) -> Location: + """Get the location the cursor will jump to if it goes 1 word left. + + Returns: + The location the cursor will jump on "jump word left". + """ + cursor_row, cursor_column = self.cursor_location + if cursor_row > 0 and cursor_column == 0: + # Going to the previous row + return cursor_row - 1, len(self.document[cursor_row - 1]) + + # Staying on the same row + line = self.document[cursor_row][:cursor_column] + search_string = line.rstrip() + matches = list(re.finditer(self._word_pattern, search_string)) + cursor_column = matches[-1].start() if matches else 0 + return cursor_row, cursor_column + + def action_cursor_word_right(self, select: bool = False) -> None: + """Move the cursor right by a single word, skipping leading whitespace.""" + if not self.show_cursor: + return + if self.cursor_at_end_of_text: + return + + target = self.get_cursor_word_right_location() + self.move_cursor(target, select=select) + + def get_cursor_word_right_location(self) -> Location: + """Get the location the cursor will jump to if it goes 1 word right. + + Returns: + The location the cursor will jump on "jump word right". + """ + cursor_row, cursor_column = self.selection.end + line = self.document[cursor_row] + if cursor_row < self.document.line_count - 1 and cursor_column == len(line): + # Moving to the line below + return cursor_row + 1, 0 + + # Staying on the same line + search_string = line[cursor_column:] + pre_strip_length = len(search_string) + search_string = search_string.lstrip() + strip_offset = pre_strip_length - len(search_string) + + matches = list(re.finditer(self._word_pattern, search_string)) + if matches: + cursor_column += matches[0].start() + strip_offset + else: + cursor_column = len(line) + + return cursor_row, cursor_column + + def action_cursor_page_up(self) -> None: + """Move the cursor and scroll up one page.""" + if not self.show_cursor: + self.scroll_page_up() + return + height = self.content_size.height + _, cursor_location = self.selection + target = self.navigator.get_location_at_y_offset( + cursor_location, + -height, + ) + self.scroll_relative(y=-height, animate=False) + self.move_cursor(target) + + def action_cursor_page_down(self) -> None: + """Move the cursor and scroll down one page.""" + if not self.show_cursor: + self.scroll_page_down() + return + height = self.content_size.height + _, cursor_location = self.selection + target = self.navigator.get_location_at_y_offset( + cursor_location, + height, + ) + self.scroll_relative(y=height, animate=False) + self.move_cursor(target) + + def get_column_width(self, row: int, column: int) -> int: + """Get the cell offset of the column from the start of the row. + + Args: + row: The row index. + column: The column index (codepoint offset from start of row). + + Returns: + The cell width of the column relative to the start of the row. + """ + line = self.document[row] + return cell_len(expand_tabs_inline(line[:column], self.indent_width)) + + def record_cursor_width(self) -> None: + """Record the current cell width of the cursor. + + This is used where we navigate up and down through rows. + If we're in the middle of a row, and go down to a row with no + content, then we go down to another row, we want our cursor to + jump back to the same offset that we were originally at. + """ + cursor_x_offset, _ = self.wrapped_document.location_to_offset( + self.cursor_location + ) + self.navigator.last_x_offset = cursor_x_offset + + # --- Editor operations + def insert( + self, + text: str, + location: Location | None = None, + *, + maintain_selection_offset: bool = True, + ) -> EditResult: + """Insert text into the document. + + Args: + text: The text to insert. + location: The location to insert text, or None to use the cursor location. + maintain_selection_offset: If True, the active Selection will be updated + such that the same text is selected before and after the selection, + if possible. Otherwise, the cursor will jump to the end point of the + edit. + + Returns: + An `EditResult` containing information about the edit. + """ + if len(text) > 1: + self._restart_blink() + if location is None: + location = self.cursor_location + return self.edit(Edit(text, location, location, maintain_selection_offset)) + + def delete( + self, + start: Location, + end: Location, + *, + maintain_selection_offset: bool = True, + ) -> EditResult: + """Delete the text between two locations in the document. + + Args: + start: The start location. + end: The end location. + maintain_selection_offset: If True, the active Selection will be updated + such that the same text is selected before and after the selection, + if possible. Otherwise, the cursor will jump to the end point of the + edit. + + Returns: + An `EditResult` containing information about the edit. + """ + return self.edit(Edit("", start, end, maintain_selection_offset)) + + def replace( + self, + insert: str, + start: Location, + end: Location, + *, + maintain_selection_offset: bool = True, + ) -> EditResult: + """Replace text in the document with new text. + + Args: + insert: The text to insert. + start: The start location + end: The end location. + maintain_selection_offset: If True, the active Selection will be updated + such that the same text is selected before and after the selection, + if possible. Otherwise, the cursor will jump to the end point of the + edit. + + Returns: + An `EditResult` containing information about the edit. + """ + return self.edit(Edit(insert, start, end, maintain_selection_offset)) + + def clear(self) -> EditResult: + """Delete all text from the document. + + Returns: + An EditResult relating to the deletion of all content. + """ + return self.delete((0, 0), self.document.end, maintain_selection_offset=False) + + def _delete_via_keyboard( + self, + start: Location, + end: Location, + ) -> EditResult | None: + """Handle a deletion performed using a keyboard (as opposed to the API). + + Args: + start: The start location of the text to delete. + end: The end location of the text to delete. + + Returns: + An EditResult or None if no edit was performed (e.g. on read-only mode). + """ + if self.read_only: + return None + return self.delete(start, end, maintain_selection_offset=False) + + def _replace_via_keyboard( + self, + insert: str, + start: Location, + end: Location, + ) -> EditResult | None: + """Handle a replacement performed using a keyboard (as opposed to the API). + + Args: + insert: The text to insert into the document. + start: The start location of the text to replace. + end: The end location of the text to replace. + + Returns: + An EditResult or None if no edit was performed (e.g. on read-only mode). + """ + if self.read_only: + return None + return self.replace(insert, start, end, maintain_selection_offset=False) + + def action_delete_left(self) -> None: + """Deletes the character to the left of the cursor and updates the cursor location. + + If there's a selection, then the selected range is deleted.""" + + if self.read_only: + return + + selection = self.selection + start, end = selection + + if selection.is_empty: + end = self.get_cursor_left_location() + + self._delete_via_keyboard(start, end) + + def action_delete_right(self) -> None: + """Deletes the character to the right of the cursor and keeps the cursor at the same location. + + If there's a selection, then the selected range is deleted.""" + if self.read_only: + return + + selection = self.selection + start, end = selection + + if selection.is_empty: + end = self.get_cursor_right_location() + + self._delete_via_keyboard(start, end) + + def action_delete_line(self) -> None: + """Deletes the lines which intersect with the selection.""" + if self.read_only: + return + self._delete_cursor_line() + + def _delete_cursor_line(self) -> EditResult | None: + """Deletes the line (including the line terminator) that the cursor is on.""" + start, end = self.selection + start, end = sorted((start, end)) + start_row, _start_column = start + end_row, end_column = end + + # Generally editors will only delete line the end line of the + # selection if the cursor is not at column 0 of that line. + if start_row != end_row and end_column == 0 and end_row >= 0: + end_row -= 1 + + from_location = (start_row, 0) + to_location = (end_row + 1, 0) + + deletion = self._delete_via_keyboard(from_location, to_location) + if deletion is not None: + self.move_cursor_relative(columns=end_column, record_width=False) + return deletion + + def action_cut(self) -> None: + """Cut text (remove and copy to clipboard).""" + if self.read_only: + return + start, end = self.selection + if start == end: + edit_result = self._delete_cursor_line() + else: + edit_result = self._delete_via_keyboard(start, end) + + if edit_result is not None: + self.app.copy_to_clipboard(edit_result.replaced_text) + + def action_copy(self) -> None: + """Copy selection to clipboard.""" + selected_text = self.selected_text + if selected_text: + self.app.copy_to_clipboard(selected_text) + else: + raise SkipAction() + + def action_paste(self) -> None: + """Paste from local clipboard.""" + if self.read_only: + return + clipboard = self.app.clipboard + if result := self._replace_via_keyboard(clipboard, *self.selection): + self.move_cursor(result.end_location) + + def action_delete_to_start_of_line(self) -> None: + """Deletes from the cursor location to the start of the line.""" + from_location = self.selection.end + to_location = self.get_cursor_line_start_location() + self._delete_via_keyboard(from_location, to_location) + + def action_delete_to_end_of_line(self) -> None: + """Deletes from the cursor location to the end of the line.""" + from_location = self.selection.end + to_location = self.get_cursor_line_end_location() + self._delete_via_keyboard(from_location, to_location) + + async def action_delete_to_end_of_line_or_delete_line(self) -> None: + """Deletes from the cursor location to the end of the line, or deletes the line. + + The line will be deleted if the line is empty. + """ + # Assume we're just going to delete to the end of the line. + action = "delete_to_end_of_line" + if self.get_cursor_line_start_location() == self.get_cursor_line_end_location(): + # The line is empty, so we'll simply remove the line itself. + action = "delete_line" + elif ( + self.selection.start + == self.selection.end + == self.get_cursor_line_end_location() + ): + # We're at the end of the line, so the kill delete operation + # should join the next line to this. + action = "delete_right" + await self.run_action(action) + + def action_delete_word_left(self) -> None: + """Deletes the word to the left of the cursor and updates the cursor location.""" + if self.cursor_at_start_of_text: + return + + # If there's a non-zero selection, then "delete word left" typically only + # deletes the characters within the selection range, ignoring word boundaries. + start, end = self.selection + if start != end: + self._delete_via_keyboard(start, end) + return + + to_location = self.get_cursor_word_left_location() + self._delete_via_keyboard(self.selection.end, to_location) + + def action_delete_word_right(self) -> None: + """Deletes the word to the right of the cursor and keeps the cursor at the same location. + + Note that the location that we delete to using this action is not the same + as the location we move to when we move the cursor one word to the right. + This action does not skip leading whitespace, whereas cursor movement does. + """ + if self.cursor_at_end_of_text: + return + + start, end = self.selection + if start != end: + self._delete_via_keyboard(start, end) + return + + cursor_row, cursor_column = end + + # Check the current line for a word boundary + line = self.document[cursor_row][cursor_column:] + matches = list(re.finditer(self._word_pattern, line)) + + current_row_length = len(self.document[cursor_row]) + if matches: + to_location = (cursor_row, cursor_column + matches[0].end()) + elif ( + cursor_row < self.document.line_count - 1 + and cursor_column == current_row_length + ): + to_location = (cursor_row + 1, 0) + else: + to_location = (cursor_row, current_row_length) + + self._delete_via_keyboard(end, to_location) + + +@lru_cache(maxsize=128) +def build_byte_to_codepoint_dict(data: bytes) -> dict[int, int]: + """Build a mapping of utf-8 byte offsets to codepoint offsets for the given data. + + Args: + data: utf-8 bytes. + + Returns: + A `dict[int, int]` mapping byte indices to codepoint indices within `data`. + """ + byte_to_codepoint: dict[int, int] = {} + current_byte_offset = 0 + code_point_offset = 0 + + while current_byte_offset < len(data): + byte_to_codepoint[current_byte_offset] = code_point_offset + first_byte = data[current_byte_offset] + + # Single-byte character + if (first_byte & 0b10000000) == 0: + current_byte_offset += 1 + # 2-byte character + elif (first_byte & 0b11100000) == 0b11000000: + current_byte_offset += 2 + # 3-byte character + elif (first_byte & 0b11110000) == 0b11100000: + current_byte_offset += 3 + # 4-byte character + elif (first_byte & 0b11111000) == 0b11110000: + current_byte_offset += 4 + else: + raise ValueError(f"Invalid UTF-8 byte: {first_byte}") + + code_point_offset += 1 + + # Mapping for the end of the string + byte_to_codepoint[current_byte_offset] = code_point_offset + return byte_to_codepoint diff --git a/src/memray/_vendor/textual/widgets/_toast.py b/src/memray/_vendor/textual/widgets/_toast.py new file mode 100644 index 0000000000..41017b1494 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_toast.py @@ -0,0 +1,202 @@ +"""Widgets for showing notification messages in toasts.""" + +from __future__ import annotations + +from typing import ClassVar + +from memray._vendor.textual import on +from memray._vendor.textual.containers import Container +from memray._vendor.textual.content import Content +from memray._vendor.textual.css.query import NoMatches +from memray._vendor.textual.events import Click, Mount +from memray._vendor.textual.notifications import Notification, Notifications +from memray._vendor.textual.widgets._static import Static + + +class ToastHolder(Container, inherit_css=False): + """Container that holds a single toast. + + Used to control the alignment of each of the toasts in the main toast + container. + """ + + DEFAULT_CSS = """ + ToastHolder { + align-horizontal: right; + width: 1fr; + height: auto; + visibility: hidden; + } + """ + + +class Toast(Static, inherit_css=False): + """A widget for displaying short-lived notifications.""" + + DEFAULT_CSS = """ + Toast { + width: 60; + max-width: 50%; + height: auto; + margin-top: 1; + visibility: visible; + padding: 1 1; + background: $panel-lighten-1; + link-background: initial; + link-color: $foreground; + link-style: underline; + link-background-hover: $primary; + link-color-hover: $foreground; + link-style-hover: bold not underline; + } + + .toast--title { + text-style: bold; + color: $foreground; + } + + Toast.-information { + border-left: outer $success; + } + + Toast.-information .toast--title { + color: $text-success; + } + + Toast.-warning { + border-left: outer $warning; + } + + Toast.-warning .toast--title { + color: $text-warning; + } + + Toast.-error { + border-left: outer $error; + } + + Toast.-error .toast--title { + color: $text-error; + } + """ + + COMPONENT_CLASSES: ClassVar[set[str]] = {"toast--title"} + """ + | Class | Description | + | :- | :- | + | `toast--title` | Targets the title of the toast. | + """ + + DEFAULT_CLASSES = "-textual-system" + + def __init__(self, notification: Notification) -> None: + """Initialise the toast. + + Args: + notification: The notification to show in the toast. + """ + super().__init__(classes=f"-{notification.severity}") + self._notification = notification + self._timeout = notification.time_left + + def render(self) -> Content: + """Render the toast's content. + + Returns: + A Rich renderable for the title and content of the Toast. + """ + notification = self._notification + + message_content = ( + Content.from_markup(notification.message) + if notification.markup + else Content(notification.message) + ) + + if notification.title: + header_style = self.get_visual_style("toast--title") + message_content = Content.assemble( + (notification.title, header_style), "\n", message_content + ) + + return message_content + + def _on_mount(self, _: Mount) -> None: + """Set the time running once the toast is mounted.""" + self.set_timer(self._timeout, self._expire) + + @on(Click) + def _expire(self) -> None: + """Remove the toast once the timer has expired.""" + # Before we removed ourself, we also call on the app to forget about + # the notification that caused us to exist. Note that we tell the + # app to not bother refreshing the display on our account, we're + # about to handle that anyway. + self.app._unnotify(self._notification, refresh=False) + # Note that we attempt to remove our parent, because we're wrapped + # inside an alignment container. The testing that we are is as much + # to keep type checkers happy as anything else. + (self.parent if isinstance(self.parent, ToastHolder) else self).remove() + + +class ToastRack(Container, inherit_css=False): + """A container for holding toasts.""" + + DEFAULT_CSS = """ + ToastRack { + display: none; + layer: _toastrack; + width: 1fr; + height: auto; + dock: bottom; + align: right bottom; + visibility: hidden; + layout: vertical; + overflow-y: scroll; + margin-bottom: 1; + } + """ + DEFAULT_CLASSES = "-textual-system" + + @staticmethod + def _toast_id(notification: Notification) -> str: + """Create a Textual-DOM-internal ID for the given notification. + + Args: + notification: The notification to create the ID for. + + Returns: + An ID for the notification that can be used within the DOM. + """ + return f"--textual-toast-{notification.identity}" + + def show(self, notifications: Notifications) -> None: + """Show the notifications as toasts. + + Args: + notifications: The notifications to show. + """ + self.display = bool(notifications) + # Look for any stale toasts and remove them. + for toast in self.query(Toast): + if toast._notification not in notifications: + toast.remove() + + # Gather up all the notifications that we don't have toasts for yet. + new_toasts: list[Notification] = [] + for notification in notifications: + try: + # See if there's already a toast for that notification. + _ = self.get_child_by_id(self._toast_id(notification)) + except NoMatches: + if not notification.has_expired: + new_toasts.append(notification) + + # If we got any... + if new_toasts: + # ...mount them. + self.mount_all( + ToastHolder(Toast(toast), id=self._toast_id(toast)) + for toast in new_toasts + ) + self.call_later(self.scroll_end, animate=False, force=True) diff --git a/src/memray/_vendor/textual/widgets/_toggle_button.py b/src/memray/_vendor/textual/widgets/_toggle_button.py new file mode 100644 index 0000000000..e019004f86 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_toggle_button.py @@ -0,0 +1,271 @@ +"""Provides the base code and implementations of toggle widgets. + +In particular it provides `Checkbox`, `RadioButton` and `RadioSet`. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar + +from rich.console import RenderableType + +from memray._vendor.textual.binding import Binding, BindingType +from memray._vendor.textual.content import Content, ContentText +from memray._vendor.textual.events import Click +from memray._vendor.textual.geometry import Size +from memray._vendor.textual.message import Message +from memray._vendor.textual.reactive import reactive +from memray._vendor.textual.style import Style +from memray._vendor.textual.widgets._static import Static + +if TYPE_CHECKING: + from typing_extensions import Self + + +class ToggleButton(Static, can_focus=True): + """Base toggle button widget. + + Warning: + `ToggleButton` should be considered to be an internal class; it + exists to serve as the common core of [Checkbox][textual.widgets.Checkbox] and + [RadioButton][textual.widgets.RadioButton]. + """ + + ALLOW_SELECT = False + BINDINGS: ClassVar[list[BindingType]] = [ + Binding("enter,space", "toggle_button", "Toggle", show=False), + ] + """ + | Key(s) | Description | + | :- | :- | + | enter, space | Toggle the value. | + """ + + COMPONENT_CLASSES: ClassVar[set[str]] = { + "toggle--button", + "toggle--label", + } + """ + | Class | Description | + | :- | :- | + | `toggle--button` | Targets the toggle button itself. | + | `toggle--label` | Targets the text label of the toggle button. | + """ + + DEFAULT_CSS = """ + ToggleButton { + width: auto; + border: tall $border-blurred; + padding: 0 1; + background: $surface; + text-wrap: nowrap; + text-overflow: ellipsis; + pointer: pointer; + + &.-textual-compact { + border: none !important; + padding: 0; + &:focus { + border: tall $border; + background-tint: $foreground 5%; + & > .toggle--label { + color: $block-cursor-foreground; + background: $block-cursor-background; + text-style: $block-cursor-text-style; + } + } + } + + & > .toggle--button { + color: $panel-darken-2; + background: $panel; + } + + &.-on > .toggle--button { + color: $text-success; + background: $panel; + } + + &:focus { + border: tall $border; + background-tint: $foreground 5%; + + & > .toggle--label { + color: $block-cursor-foreground; + background: $block-cursor-background; + text-style: $block-cursor-text-style; + } + } + &:blur:hover { + & > .toggle--label { + background: $block-hover-background; + } + } + } + """ + + BUTTON_LEFT: str = "▐" + """The character used for the left side of the toggle button.""" + + BUTTON_INNER: str = "X" + """The character used for the inside of the button.""" + + BUTTON_RIGHT: str = "▌" + """The character used for the right side of the toggle button.""" + + value: reactive[bool] = reactive(False, init=False) + """The value of the button. `True` for on, `False` for off.""" + + compact: reactive[bool] = reactive(False, toggle_class="-textual-compact") + """Enable compact display?""" + + def __init__( + self, + label: ContentText = "", + value: bool = False, + button_first: bool = True, + *, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + tooltip: RenderableType | None = None, + compact: bool = False, + ) -> None: + """Initialise the toggle. + + Args: + label: The label for the toggle. + value: The initial value of the toggle. + button_first: Should the button come before the label, or after? + name: The name of the toggle. + id: The ID of the toggle in the DOM. + classes: The CSS classes of the toggle. + disabled: Whether the button is disabled or not. + tooltip: RenderableType | None = None, + compact: Show a compact button. + """ + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + self._button_first = button_first + # NOTE: Don't send a Changed message in response to the initial set. + with self.prevent(self.Changed): + self.value = value + self._label = self._make_label(label) + if tooltip is not None: + self.tooltip = tooltip + self.compact = compact + + def _make_label(self, label: ContentText) -> Content: + """Make label content. + + Args: + label: The source value for the label. + + Returns: + A `Content` rendering of the label for use in the button. + """ + label = Content.from_text(label).first_line.rstrip() + return label + + @property + def label(self) -> Content: + """The label associated with the button.""" + return self._label + + @label.setter + def label(self, label: ContentText) -> None: + self._label = self._make_label(label) + self.refresh(layout=True) + + @property + def _button(self) -> Content: + """The button, reflecting the current value.""" + + # Grab the button style. + button_style = self.get_visual_style("toggle--button") + + # Building the style for the side characters. Note that this is + # sensitive to the type of character used, so pay attention to + # BUTTON_LEFT and BUTTON_RIGHT. + side_style = Style( + foreground=button_style.background, + background=self.background_colors[1], + ) + + return Content.assemble( + (self.BUTTON_LEFT, side_style), + (self.BUTTON_INNER, button_style), + (self.BUTTON_RIGHT, side_style), + ) + + def render(self) -> Content: + """Render the content of the widget. + + Returns: + The content to render for the widget. + """ + button = self._button + label_style = self.get_visual_style("toggle--label") + label = self._label.pad(1, 1).stylize_before(label_style) + + if self._button_first: + content = Content.assemble(button, label) + else: + content = Content.assemble(label, button) + return content + + def get_content_width(self, container: Size, viewport: Size) -> int: + return ( + self._button.get_optimal_width(self.styles, 0) + + (2 if self._label else 0) + + self._label.get_optimal_width(self.styles, 0) + ) + + def get_content_height(self, container: Size, viewport: Size, width: int) -> int: + return 1 + + def toggle(self) -> Self: + """Toggle the value of the widget. + + Returns: + The `ToggleButton` instance. + """ + self.value = not self.value + return self + + def action_toggle_button(self) -> None: + """Toggle the value of the widget when called as an action. + + This would normally be used for a keyboard binding. + """ + self.toggle() + + async def _on_click(self, _: Click) -> None: + """Toggle the value of the widget when clicked with the mouse.""" + self.toggle() + + class Changed(Message): + """Posted when the value of the toggle button changes.""" + + def __init__(self, toggle_button: ToggleButton, value: bool) -> None: + """Initialise the message. + + Args: + toggle_button: The toggle button sending the message. + value: The value of the toggle button. + """ + super().__init__() + self._toggle_button = toggle_button + """A reference to the toggle button that was changed.""" + self.value = value + """The value of the toggle button after the change.""" + + def watch_value(self) -> None: + """React to the value being changed. + + When triggered, the CSS class `-on` is applied to the widget if + `value` has become `True`, or it is removed if it has become + `False`. Subsequently a related `Changed` event will be posted. + """ + self.set_class(self.value, "-on") + self.post_message(self.Changed(self, self.value)) diff --git a/src/memray/_vendor/textual/widgets/_tooltip.py b/src/memray/_vendor/textual/widgets/_tooltip.py new file mode 100644 index 0000000000..a74005be3c --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_tooltip.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from memray._vendor.textual.widgets import Static + + +class Tooltip(Static, inherit_css=False): + DEFAULT_CSS = """ + Tooltip { + layer: _tooltips; + margin: 1 0; + padding: 1 2; + background: $panel; + width: auto; + height: auto; + constrain: inside inflect; + max-width: 40; + display: none; + offset-x: -50%; + } + """ + DEFAULT_CLASSES = "-textual-system" diff --git a/src/memray/_vendor/textual/widgets/_tree.py b/src/memray/_vendor/textual/widgets/_tree.py new file mode 100644 index 0000000000..0531dac246 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_tree.py @@ -0,0 +1,1600 @@ +"""Provides a tree widget.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, ClassVar, Generic, Iterable, NewType, TypeVar, cast + +import rich.repr +from rich.style import NULL_STYLE, Style +from rich.text import Text, TextType + +from memray._vendor.textual import events, on +from memray._vendor.textual._immutable_sequence_view import ImmutableSequenceView +from memray._vendor.textual._loop import loop_last +from memray._vendor.textual._segment_tools import line_pad +from memray._vendor.textual.binding import Binding, BindingType +from memray._vendor.textual.cache import LRUCache +from memray._vendor.textual.geometry import Region, Size, clamp +from memray._vendor.textual.message import Message +from memray._vendor.textual.reactive import reactive, var +from memray._vendor.textual.scroll_view import ScrollView +from memray._vendor.textual.strip import Strip + +if TYPE_CHECKING: + from typing_extensions import Self, TypeAlias + +NodeID = NewType("NodeID", int) +"""The type of an ID applied to a [TreeNode][textual.widgets._tree.TreeNode].""" + +TreeDataType = TypeVar("TreeDataType") +"""The type of the data for a given instance of a [Tree][textual.widgets.Tree].""" + +EventTreeDataType = TypeVar("EventTreeDataType") +"""The type of the data for a given instance of a [Tree][textual.widgets.Tree]. + +Similar to [TreeDataType][textual.widgets._tree.TreeDataType] but used for +``Tree`` messages. +""" + +LineCacheKey: TypeAlias = "tuple[int | tuple, ...]" + +TOGGLE_STYLE = Style.from_meta({"toggle": True}) + + +class RemoveRootError(Exception): + """Exception raised when trying to remove the root of a [`TreeNode`][textual.widgets.tree.TreeNode].""" + + +class UnknownNodeID(Exception): + """Exception raised when referring to an unknown [`TreeNode`][textual.widgets.tree.TreeNode] ID.""" + + +class AddNodeError(Exception): + """Exception raised when there is an error with a request to add a node.""" + + +@dataclass +class _TreeLine(Generic[TreeDataType]): + path: list[TreeNode[TreeDataType]] + last: bool + + @property + def node(self) -> TreeNode[TreeDataType]: + """The node associated with this line.""" + return self.path[-1] + + def _get_guide_width(self, guide_depth: int, show_root: bool) -> int: + """Get the cell width of the line as rendered. + + Args: + guide_depth: The guide depth (cells in the indentation). + + Returns: + Width in cells. + """ + if show_root: + width = (max(0, len(self.path) - 1)) * guide_depth + else: + width = 0 + if len(self.path) > 1: + width += (len(self.path) - 1) * guide_depth + + return width + + +class TreeNodes(ImmutableSequenceView["TreeNode[TreeDataType]"]): + """An immutable collection of `TreeNode`.""" + + +@rich.repr.auto +class TreeNode(Generic[TreeDataType]): + """An object that represents a "node" in a tree control.""" + + def __init__( + self, + tree: Tree[TreeDataType], + parent: TreeNode[TreeDataType] | None, + id: NodeID, + label: Text, + data: TreeDataType | None = None, + *, + expanded: bool = True, + allow_expand: bool = True, + ) -> None: + """Initialise the node. + + Args: + tree: The tree that the node is being attached to. + parent: The parent node that this node is being attached to. + id: The ID of the node. + label: The label for the node. + data: Optional data to associate with the node. + expanded: Should the node be attached in an expanded state? + allow_expand: Should the node allow being expanded by the user? + """ + self._tree = tree + self._parent = parent + self._id = id + self._label = tree.process_label(label) + self.data = data + """Optional data associated with the tree node.""" + self._expanded = expanded + self._children: list[TreeNode[TreeDataType]] = [] + + self._hover_ = False + self._selected_ = False + self._allow_expand = allow_expand + self._updates: int = 0 + self._line: int = -1 + + def __rich_repr__(self) -> rich.repr.Result: + yield self._label.plain + yield self.data + + def _reset(self) -> None: + self._hover_ = False + self._selected_ = False + self._updates += 1 + + @property + def tree(self) -> Tree[TreeDataType]: + """The tree that this node is attached to.""" + return self._tree + + @property + def children(self) -> TreeNodes[TreeDataType]: + """The child nodes of a TreeNode.""" + return TreeNodes(self._children) + + @property + def siblings(self) -> TreeNodes[TreeDataType]: + """The siblings of this node (includes self).""" + if self.parent is None: + return TreeNodes([self]) + else: + return self.parent.children + + @property + def line(self) -> int: + """The line number for this node, or -1 if it is not displayed.""" + return self._line + + @property + def _hover(self) -> bool: + """Check if the mouse is over the node.""" + return self._hover_ + + @_hover.setter + def _hover(self, hover: bool) -> None: + self._updates += 1 + self._hover_ = hover + + @property + def _selected(self) -> bool: + """Check if the node is selected.""" + return self._selected_ + + @_selected.setter + def _selected(self, selected: bool) -> None: + self._updates += 1 + self._selected_ = selected + + @property + def id(self) -> NodeID: + """The ID of the node.""" + return self._id + + @property + def parent(self) -> TreeNode[TreeDataType] | None: + """The parent of the node.""" + return self._parent + + @property + def next_sibling(self) -> TreeNode[TreeDataType] | None: + """The next sibling below the node.""" + siblings = self.siblings + index = siblings.index(self) + 1 + try: + return siblings[index] + except IndexError: + return None + + @property + def previous_sibling(self) -> TreeNode[TreeDataType] | None: + """The previous sibling below the node.""" + siblings = self.siblings + index = siblings.index(self) - 1 + if index < 0: + return None + try: + return siblings[index] + except IndexError: + return None + + @property + def is_expanded(self) -> bool: + """Is the node expanded?""" + return self._expanded + + @property + def is_collapsed(self) -> bool: + """Is the node collapsed?""" + return not self._expanded + + @property + def is_last(self) -> bool: + """Is this the last child node of its parent?""" + if self._parent is None: + return True + return bool( + self._parent._children and self._parent._children[-1] == self, + ) + + @property + def is_root(self) -> bool: + """Is this node the root of the tree?""" + return self == self._tree.root + + @property + def allow_expand(self) -> bool: + """Is this node allowed to expand?""" + return self._allow_expand + + @allow_expand.setter + def allow_expand(self, allow_expand: bool) -> None: + self._allow_expand = allow_expand + self._updates += 1 + + def _expand(self, expand_all: bool) -> None: + """Mark the node as expanded (its children are shown). + + Args: + expand_all: If `True` expand all offspring at all depths. + """ + self._expanded = True + self._updates += 1 + self._tree.post_message(Tree.NodeExpanded(self).set_sender(self._tree)) + if expand_all: + for child in self.children: + child._expand(expand_all) + + def expand(self) -> Self: + """Expand the node (show its children). + + Returns: + The `TreeNode` instance. + """ + self._expand(False) + self._tree._invalidate() + return self + + def expand_all(self) -> Self: + """Expand the node (show its children) and all those below it. + + Returns: + The `TreeNode` instance. + """ + self._expand(True) + self._tree._invalidate() + return self + + def _collapse(self, collapse_all: bool) -> None: + """Mark the node as collapsed (its children are hidden). + + Args: + collapse_all: If `True` collapse all offspring at all depths. + """ + self._expanded = False + self._updates += 1 + self._tree.post_message(Tree.NodeCollapsed(self).set_sender(self._tree)) + if collapse_all: + for child in self.children: + child._collapse(collapse_all) + + def collapse(self) -> Self: + """Collapse the node (hide its children). + + Returns: + The `TreeNode` instance. + """ + self._collapse(False) + self._tree._invalidate() + return self + + def collapse_all(self) -> Self: + """Collapse the node (hide its children) and all those below it. + + Returns: + The `TreeNode` instance. + """ + self._collapse(True) + self._tree._invalidate() + return self + + def toggle(self) -> Self: + """Toggle the node's expanded state. + + Returns: + The `TreeNode` instance. + """ + if self._expanded: + self.collapse() + else: + self.expand() + return self + + def toggle_all(self) -> Self: + """Toggle the node's expanded state and make all those below it match. + + Returns: + The `TreeNode` instance. + """ + if self._expanded: + self.collapse_all() + else: + self.expand_all() + return self + + @property + def label(self) -> TextType: + """The label for the node.""" + return self._label + + @label.setter + def label(self, new_label: TextType) -> None: + self.set_label(new_label) + + def set_label(self, label: TextType) -> None: + """Set a new label for the node. + + Args: + label: A ``str`` or ``Text`` object with the new label. + """ + self._updates += 1 + text_label = self._tree.process_label(label) + self._label = text_label + self._tree.call_later(self._tree._refresh_node, self) + + def add( + self, + label: TextType, + data: TreeDataType | None = None, + *, + before: int | TreeNode[TreeDataType] | None = None, + after: int | TreeNode[TreeDataType] | None = None, + expand: bool = False, + allow_expand: bool = True, + ) -> TreeNode[TreeDataType]: + """Add a node to the sub-tree. + + Args: + label: The new node's label. + data: Data associated with the new node. + before: Optional index or `TreeNode` to add the node before. + after: Optional index or `TreeNode` to add the node after. + expand: Node should be expanded. + allow_expand: Allow user to expand the node via keyboard or mouse. + + Returns: + A new Tree node + + Raises: + AddNodeError: If there is a problem with the addition request. + + Note: + Only one of `before` or `after` can be provided. If both are + provided a `AddNodeError` will be raised. + """ + if before is not None and after is not None: + raise AddNodeError("Unable to add a node both before and after a node") + + insert_index: int = len(self.children) + + if before is not None: + if isinstance(before, int): + insert_index = before + elif isinstance(before, TreeNode): + try: + insert_index = self.children.index(before) + except ValueError: + raise AddNodeError( + "The node specified for `before` is not a child of this node" + ) + else: + raise TypeError( + "`before` argument must be an index or a TreeNode object to add before" + ) + + if after is not None: + if isinstance(after, int): + insert_index = after + 1 + if after < 0: + insert_index += len(self.children) + elif isinstance(after, TreeNode): + try: + insert_index = self.children.index(after) + 1 + except ValueError: + raise AddNodeError( + "The node specified for `after` is not a child of this node" + ) + else: + raise TypeError( + "`after` argument must be an index or a TreeNode object to add after" + ) + + text_label = self._tree.process_label(label) + node = self._tree._add_node(self, text_label, data) + node._expanded = expand + node._allow_expand = allow_expand + self._updates += 1 + self._children.insert(insert_index, node) + self._tree._invalidate() + + return node + + def add_leaf( + self, + label: TextType, + data: TreeDataType | None = None, + *, + before: int | TreeNode[TreeDataType] | None = None, + after: int | TreeNode[TreeDataType] | None = None, + ) -> TreeNode[TreeDataType]: + """Add a 'leaf' node (a node that can not expand). + + Args: + label: Label for the node. + data: Optional data. + before: Optional index or `TreeNode` to add the node before. + after: Optional index or `TreeNode` to add the node after. + + Returns: + New node. + + Raises: + AddNodeError: If there is a problem with the addition request. + + Note: + Only one of `before` or `after` can be provided. If both are + provided a `AddNodeError` will be raised. + """ + node = self.add( + label, + data, + before=before, + after=after, + expand=False, + allow_expand=False, + ) + return node + + def _remove_children(self) -> None: + """Remove child nodes of this node. + + Note: + This is the internal support method for `remove_children`. Call + `remove_children` to ensure the tree gets refreshed. + """ + for child in reversed(self._children): + child._remove() + + def _remove(self) -> None: + """Remove the current node and all its children. + + Note: + This is the internal support method for `remove`. Call `remove` + to ensure the tree gets refreshed. + """ + self._remove_children() + assert self._parent is not None + del self._parent._children[self._parent._children.index(self)] + del self._tree._tree_nodes[self.id] + + def remove(self) -> None: + """Remove this node from the tree. + + Raises: + RemoveRootError: If there is an attempt to remove the root. + """ + if self.is_root: + raise RemoveRootError("Attempt to remove the root node of a Tree.") + self._remove() + self._tree._invalidate() + + def remove_children(self) -> None: + """Remove any child nodes of this node.""" + self._remove_children() + self._tree._invalidate() + + def refresh(self) -> None: + """Initiate a refresh (repaint) of this node.""" + self._updates += 1 + self._tree._refresh_line(self._line) + + +class Tree(Generic[TreeDataType], ScrollView, can_focus=True): + """A widget for displaying and navigating data in a tree.""" + + ICON_NODE = "▶ " + """Unicode 'icon' to use for an expandable node.""" + ICON_NODE_EXPANDED = "▼ " + """Unicode 'icon' to use for an expanded node.""" + + BINDINGS: ClassVar[list[BindingType]] = [ + Binding("shift+left", "cursor_parent", "Cursor to parent", show=False), + Binding( + "shift+right", + "cursor_parent_next_sibling", + "Cursor to next ancestor", + show=False, + ), + Binding( + "shift+up", + "cursor_previous_sibling", + "Cursor to previous sibling", + show=False, + ), + Binding( + "shift+down", + "cursor_next_sibling", + "Cursor to next sibling", + show=False, + ), + Binding("enter", "select_cursor", "Select", show=False), + Binding("space", "toggle_node", "Toggle", show=False), + Binding( + "shift+space", "toggle_expand_all", "Expand or collapse all", show=False + ), + Binding("up", "cursor_up", "Cursor Up", show=False), + Binding("down", "cursor_down", "Cursor Down", show=False), + ] + """ + | Key(s) | Description | + | :- | :- | + | enter | Select the current item. | + | space | Toggle the expand/collapsed state of the current item. | + | up | Move the cursor up. | + | down | Move the cursor down. | + """ + + ALLOW_SELECT = False + + COMPONENT_CLASSES: ClassVar[set[str]] = { + "tree--cursor", + "tree--guides", + "tree--guides-hover", + "tree--guides-selected", + "tree--highlight", + "tree--highlight-line", + "tree--label", + } + """ + | Class | Description | + | :- | :- | + | `tree--cursor` | Targets the cursor. | + | `tree--guides` | Targets the indentation guides. | + | `tree--guides-hover` | Targets the indentation guides under the cursor. | + | `tree--guides-selected` | Targets the indentation guides that are selected. | + | `tree--highlight` | Targets the highlighted items. | + | `tree--highlight-line` | Targets the lines under the cursor. | + | `tree--label` | Targets the (text) labels of the items. | + """ + + DEFAULT_CSS = """ + Tree { + background: $surface; + color: $foreground; + + & > .tree--label {} + & > .tree--guides { + color: $surface-lighten-2; + } + & > .tree--guides-hover { + color: $surface-lighten-2; + } + & > .tree--guides-selected { + color: $block-cursor-blurred-background; + } + & > .tree--cursor { + text-style: $block-cursor-blurred-text-style; + background: $block-cursor-blurred-background; + } + & > .tree--highlight {} + & > .tree--highlight-line { + background: $block-hover-background; + } + + &:focus { + background-tint: $foreground 5%; + & > .tree--cursor { + color: $block-cursor-foreground; + background: $block-cursor-background; + text-style: $block-cursor-text-style; + } + & > .tree--guides { + color: $surface-lighten-3; + } + & > .tree--guides-hover { + color: $surface-lighten-3; + } + & > .tree--guides-selected { + color: $block-cursor-background; + } + } + + &:light { + /* In light mode the guides are darker*/ + & > .tree--guides { + color: $surface-darken-1; + } + & > .tree--guides-hover { + color: $block-cursor-background; + } + & > .tree--guides-selected { + color: $block-cursor-background; + } + } + + &:ansi { + color: ansi_default; + & > .tree--guides { + color: ansi_green; + } + &:nocolor > .tree--cursor{ + text-style: reverse; + } + } + } + + """ + + show_root = reactive(True) + """Show the root of the tree.""" + hover_line = var(-1) + """The line number under the mouse pointer, or -1 if not under the mouse pointer.""" + cursor_line = var(-1, always_update=True) + """The line with the cursor, or -1 if no cursor.""" + show_guides = reactive(True) + """Enable display of tree guide lines.""" + guide_depth = reactive(4, init=False) + """The indent depth of tree nodes.""" + auto_expand = var(True) + """Auto expand tree nodes when they are selected.""" + center_scroll = var(False) + """Keep selected node in the center of the control, where possible.""" + + LINES: dict[str, tuple[str, str, str, str]] = { + "default": ( + " ", + "│ ", + "└─", + "├─", + ), + "bold": ( + " ", + "┃ ", + "┗━", + "┣━", + ), + "double": ( + " ", + "║ ", + "╚═", + "╠═", + ), + } + + class NodeCollapsed(Generic[EventTreeDataType], Message): + """Event sent when a node is collapsed. + + Can be handled using `on_tree_node_collapsed` in a subclass of `Tree` or in a + parent node in the DOM. + """ + + def __init__(self, node: TreeNode[EventTreeDataType]) -> None: + self.node: TreeNode[EventTreeDataType] = node + """The node that was collapsed.""" + super().__init__() + + @property + def control(self) -> Tree[EventTreeDataType]: + """The tree that sent the message.""" + return self.node.tree + + class NodeExpanded(Generic[EventTreeDataType], Message): + """Event sent when a node is expanded. + + Can be handled using `on_tree_node_expanded` in a subclass of `Tree` or in a + parent node in the DOM. + """ + + def __init__(self, node: TreeNode[EventTreeDataType]) -> None: + self.node: TreeNode[EventTreeDataType] = node + """The node that was expanded.""" + super().__init__() + + @property + def control(self) -> Tree[EventTreeDataType]: + """The tree that sent the message.""" + return self.node.tree + + class NodeHighlighted(Generic[EventTreeDataType], Message): + """Event sent when a node is highlighted. + + Can be handled using `on_tree_node_highlighted` in a subclass of `Tree` or in a + parent node in the DOM. + """ + + def __init__(self, node: TreeNode[EventTreeDataType]) -> None: + self.node: TreeNode[EventTreeDataType] = node + """The node that was highlighted.""" + super().__init__() + + @property + def control(self) -> Tree[EventTreeDataType]: + """The tree that sent the message.""" + return self.node.tree + + class NodeSelected(Generic[EventTreeDataType], Message): + """Event sent when a node is selected. + + Can be handled using `on_tree_node_selected` in a subclass of `Tree` or in a + parent node in the DOM. + """ + + def __init__(self, node: TreeNode[EventTreeDataType]) -> None: + self.node: TreeNode[EventTreeDataType] = node + """The node that was selected.""" + super().__init__() + + @property + def control(self) -> Tree[EventTreeDataType]: + """The tree that sent the message.""" + return self.node.tree + + def __init__( + self, + label: TextType, + data: TreeDataType | None = None, + *, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + disabled: bool = False, + ) -> None: + """Initialise a Tree. + + Args: + label: The label of the root node of the tree. + data: The optional data to associate with the root node of the tree. + name: The name of the Tree. + id: The ID of the tree in the DOM. + classes: The CSS classes of the tree. + disabled: Whether the tree is disabled or not. + """ + + text_label = self.process_label(label) + + self._updates = 0 + self._tree_nodes: dict[NodeID, TreeNode[TreeDataType]] = {} + self._current_id = 0 + self.root = self._add_node(None, text_label, data) + """The root node of the tree.""" + self._line_cache: LRUCache[LineCacheKey, Strip] = LRUCache(1024) + self._tree_lines_cached: list[_TreeLine[TreeDataType]] | None = None + self._cursor_node: TreeNode[TreeDataType] | None = None + + super().__init__(name=name, id=id, classes=classes, disabled=disabled) + + def add_json(self, json_data: object, node: TreeNode | None = None) -> None: + """Adds JSON data to a node. + + Args: + json_data: An object decoded from JSON. + node: Node to add data to. + + """ + + if node is None: + node = self.root + + from rich.highlighter import ReprHighlighter + + highlighter = ReprHighlighter() + + def add_node(name: str, node: TreeNode, data: object) -> None: + """Adds a node to the tree. + + Args: + name: Name of the node. + node: Parent node. + data: Data associated with the node. + """ + if isinstance(data, dict): + node.set_label(Text(f"{{}} {name}")) + for key, value in data.items(): + new_node = node.add("") + add_node(key, new_node, value) + elif isinstance(data, list): + node.set_label(Text(f"[] {name}")) + for index, value in enumerate(data): + new_node = node.add("") + add_node(str(index), new_node, value) + else: + node.allow_expand = False + if name: + label = Text.assemble( + Text.from_markup(f"[b]{name}[/b]="), highlighter(repr(data)) + ) + else: + label = Text(repr(data)) + node.set_label(label) + + add_node("", node, json_data) + + @property + def cursor_node(self) -> TreeNode[TreeDataType] | None: + """The currently selected node, or ``None`` if no selection.""" + return self._cursor_node + + @property + def last_line(self) -> int: + """The index of the last line.""" + return len(self._tree_lines) - 1 + + def process_label(self, label: TextType) -> Text: + """Process a `str` or `Text` value into a label. + + May be overridden in a subclass to change how labels are rendered. + + Args: + label: Label. + + Returns: + A Rich Text object. + """ + if isinstance(label, str): + text_label = Text.from_markup(label) + else: + text_label = label + first_line = text_label.split()[0] + return first_line + + def _add_node( + self, + parent: TreeNode[TreeDataType] | None, + label: Text, + data: TreeDataType | None, + expand: bool = False, + ) -> TreeNode[TreeDataType]: + node = TreeNode(self, parent, self._new_id(), label, data, expanded=expand) + self._tree_nodes[node._id] = node + self._updates += 1 + return node + + def render_label( + self, node: TreeNode[TreeDataType], base_style: Style, style: Style + ) -> Text: + """Render a label for the given node. Override this to modify how labels are rendered. + + Args: + node: A tree node. + base_style: The base style of the widget. + style: The additional style for the label. + + Returns: + A Rich Text object containing the label. + """ + node_label = node._label.copy() + node_label.stylize(style) + + if node._allow_expand: + prefix = ( + self.ICON_NODE_EXPANDED if node.is_expanded else self.ICON_NODE, + base_style + TOGGLE_STYLE, + ) + else: + prefix = ("", base_style) + + text = Text.assemble(prefix, node_label) + return text + + def get_label_width(self, node: TreeNode[TreeDataType]) -> int: + """Get the width of the nodes label. + + The default behavior is to call `render_label` and return the cell length. This method may be + overridden in a sub-class if it can be done more efficiently. + + Args: + node: A node. + + Returns: + Width in cells. + """ + label = self.render_label(node, NULL_STYLE, NULL_STYLE) + return label.cell_len + + def _clear_line_cache(self) -> None: + """Clear line cache.""" + self._line_cache.clear() + self._tree_lines_cached = None + + def clear(self) -> Self: + """Clear all nodes under root. + + Returns: + The `Tree` instance. + """ + self._clear_line_cache() + self._current_id = 0 + root_label = self.root._label + root_data = self.root.data + root_expanded = self.root.is_expanded + self.root = TreeNode( + self, + None, + self._new_id(), + root_label, + root_data, + expanded=root_expanded, + ) + self._updates += 1 + self.refresh() + return self + + def reset(self, label: TextType, data: TreeDataType | None = None) -> Self: + """Clear the tree and reset the root node. + + Args: + label: The label for the root node. + data: Optional data for the root node. + + Returns: + The `Tree` instance. + """ + self.clear() + self.root.label = label + self.root.data = data + return self + + def move_cursor( + self, node: TreeNode[TreeDataType] | None, animate: bool = False + ) -> None: + """Move the cursor to the given node, or reset cursor. + + Args: + node: A tree node, or None to reset cursor. + animate: Enable animation + """ + previous_cursor_line = self.cursor_line + self.cursor_line = -1 if node is None else node._line + if node is not None and self.cursor_node is not None: + self.scroll_to_node( + self.cursor_node, + animate=animate and abs(self.cursor_line - previous_cursor_line) > 1, + ) + + def move_cursor_to_line(self, line: int, animate=False) -> None: + """Move the cursor to the given line. + + Args: + line: The line number (negative indexes are offsets from the last line). + animate: Enable scrolling animation. + + Raises: + IndexError: If the line doesn't exist. + """ + if self.cursor_line == line: + return + try: + node = self._tree_lines[line].node + except IndexError: + raise IndexError(f"No line no. {line} in the tree") + self.move_cursor(node, animate=animate) + + def select_node(self, node: TreeNode[TreeDataType] | None) -> None: + """Move the cursor to the given node and select it, or reset cursor. + + Args: + node: A tree node to move the cursor to and select, or None to reset cursor. + """ + self.move_cursor(node) + if node is not None: + self.post_message(Tree.NodeSelected(node)) + + def unselect(self) -> None: + """Hide and reset the cursor.""" + self.set_reactive(Tree.cursor_line, -1) + self._invalidate() + + @on(NodeSelected) + def _expand_node_on_select(self, event: NodeSelected[TreeDataType]) -> None: + """When the node is selected, expand the node if `auto_expand` is True.""" + node = event.node + if self.auto_expand: + self._toggle_node(node) + + def get_node_at_line(self, line_no: int) -> TreeNode[TreeDataType] | None: + """Get the node for a given line. + + Args: + line_no: A line number. + + Returns: + A tree node, or ``None`` if there is no node at that line. + """ + try: + line = self._tree_lines[line_no] + except IndexError: + return None + else: + return line.node + + def get_node_by_id(self, node_id: NodeID) -> TreeNode[TreeDataType]: + """Get a tree node by its ID. + + Args: + node_id: The ID of the node to get. + + Returns: + The node associated with that ID. + + Raises: + UnknownNodeID: Raised if the `TreeNode` ID is unknown. + """ + try: + return self._tree_nodes[node_id] + except KeyError: + raise UnknownNodeID(f"Unknown NodeID ({node_id}) in tree") from None + + def validate_cursor_line(self, value: int) -> int: + """Prevent cursor line from going outside of range. + + Args: + value: The value to test. + + Return: + A valid version of the given value. + """ + return clamp(value, 0, len(self._tree_lines) - 1) + + def validate_guide_depth(self, value: int) -> int: + """Restrict guide depth to reasonable range. + + Args: + value: The value to test. + + Return: + A valid version of the given value. + """ + return clamp(value, 2, 10) + + def _invalidate(self) -> None: + """Invalidate caches.""" + self._clear_line_cache() + self._updates += 1 + self.root._reset() + self.refresh(layout=True) + + def _on_mouse_move(self, event: events.MouseMove) -> None: + meta = event.style.meta + if meta and "line" in meta: + self.hover_line = meta["line"] + else: + self.hover_line = -1 + + def _on_leave(self, _: events.Leave) -> None: + # Ensure the hover effect doesn't linger after the mouse leaves. + self.hover_line = -1 + + def _new_id(self) -> NodeID: + """Create a new node ID. + + Returns: + A unique node ID. + """ + id = self._current_id + self._current_id += 1 + return NodeID(id) + + def _get_node(self, line: int) -> TreeNode[TreeDataType] | None: + if line == -1: + return None + try: + tree_line = self._tree_lines[line] + except IndexError: + return None + else: + return tree_line.node + + def _get_label_region(self, line: int) -> Region | None: + """Returns the region occupied by the label of the node at line `line`. + + This can be used, e.g., when scrolling to that line such that the label + is visible after the scroll. + + Args: + line: A line number. + + Returns: + The region occupied by the label, or `None` if the + line is not in the tree. + """ + try: + tree_line = self._tree_lines[line] + except IndexError: + return None + region_x = tree_line._get_guide_width(self.guide_depth, self.show_root) + region_width = self.get_label_width(tree_line.node) + return Region(region_x, line, region_width, 1) + + def watch_hover_line(self, previous_hover_line: int, hover_line: int) -> None: + previous_node = self._get_node(previous_hover_line) + if previous_node is not None: + self._refresh_node(previous_node) + previous_node._hover = False + + node = self._get_node(hover_line) + if node is not None: + self._refresh_node(node) + node._hover = True + + def watch_cursor_line(self, previous_line: int, line: int) -> None: + previous_node = self._get_node(previous_line) + node = self._get_node(line) + + if self.cursor_node is not None: + self.cursor_node._selected = False + + if previous_node is not None: + previous_node._selected = False + + if node is not None: + node._selected = True + self._cursor_node = node + else: + self._cursor_node = None + + if previous_line == line: + # No change, so no need for refresh + return + + # Refresh previous cursor node + if previous_node is not None: + self._refresh_node(previous_node) + + # Refresh new node + if node is not None: + self._refresh_node(node) + if previous_node != node: + self.post_message(self.NodeHighlighted(node)) + + def watch_guide_depth(self, guide_depth: int) -> None: + self._invalidate() + + def watch_show_root(self, show_root: bool) -> None: + self.cursor_line = -1 + self._invalidate() + + def scroll_to_line(self, line: int, animate: bool = True) -> None: + """Scroll to the given line. + + Args: + line: A line number. + animate: Enable animation. + """ + region = self._get_label_region(line) + if region is not None: + self.scroll_to_region( + region, + animate=animate, + force=True, + center=self.center_scroll, + origin_visible=False, + x_axis=False, # Scrolling the X axis is quite jarring, and rarely necessary + ) + + def scroll_to_node( + self, node: TreeNode[TreeDataType], animate: bool = True + ) -> None: + """Scroll to the given node. + + Args: + node: Node to scroll into view. + animate: Animate scrolling. + """ + line = node._line + if line != -1: + self.scroll_to_line(line, animate=animate) + + def _refresh_line(self, line: int) -> None: + """Refresh (repaint) a given line in the tree. + + Args: + line: Line number. + """ + region = Region(0, line - self.scroll_offset.y, self.size.width, 1) + self.refresh(region) + + def _refresh_node_line(self, line: int) -> None: + node = self._get_node(line) + if node is not None: + self._refresh_node(node) + + def _refresh_node(self, node: TreeNode[TreeDataType]) -> None: + """Refresh a node and all its children. + + Args: + node: A tree node. + """ + scroll_y = self.scroll_offset.y + height = self.size.height + visible_lines = self._tree_lines[scroll_y : scroll_y + height] + for line_no, line in enumerate(visible_lines, scroll_y): + if node in line.path: + self._refresh_line(line_no) + + @property + def _tree_lines(self) -> list[_TreeLine[TreeDataType]]: + if self._tree_lines_cached is None: + self._build() + assert self._tree_lines_cached is not None + return self._tree_lines_cached + + async def _on_idle(self, event: events.Idle) -> None: + """Check tree needs a rebuild on idle.""" + # Property calls build if required + async with self.lock: + self._tree_lines + + def _build(self) -> None: + """Builds the tree by traversing nodes, and creating tree lines.""" + + TreeLine = _TreeLine + lines: list[_TreeLine[TreeDataType]] = [] + add_line = lines.append + + root = self.root + + def add_node( + path: list[TreeNode[TreeDataType]], node: TreeNode[TreeDataType], last: bool + ) -> None: + child_path = [*path, node] + node._line = len(lines) + add_line(TreeLine(child_path, last)) + if node._expanded: + for last, child in loop_last(node._children): + add_node(child_path, child, last) + + if self.show_root: + add_node([], root, True) + else: + for node in self.root._children: + add_node([], node, True) + self._tree_lines_cached = lines + + guide_depth = self.guide_depth + show_root = self.show_root + get_label_width = self.get_label_width + + def get_line_width(line: _TreeLine[TreeDataType]) -> int: + return get_label_width(line.node) + line._get_guide_width( + guide_depth, show_root + ) + + if lines: + width = max([get_line_width(line) for line in lines]) + else: + width = self.size.width + + self.virtual_size = Size(width, len(lines)) + if self.cursor_line != -1: + if self.cursor_node is not None: + self.cursor_line = self.cursor_node._line + if self.cursor_line >= len(lines): + self.cursor_line = -1 + + def render_lines(self, crop: Region) -> list[Strip]: + self._pseudo_class_state = self.get_pseudo_class_state() + return super().render_lines(crop) + + def render_line(self, y: int) -> Strip: + width = self.size.width + scroll_x, scroll_y = self.scroll_offset + style = self.rich_style + return self._render_line( + y + scroll_y, + scroll_x, + scroll_x + width, + style, + ) + + def _render_line(self, y: int, x1: int, x2: int, base_style: Style) -> Strip: + tree_lines = self._tree_lines + width = self.size.width + + if y >= len(tree_lines): + return Strip.blank(width, base_style) + + line = tree_lines[y] + + is_hover = self.hover_line >= 0 and any(node._hover for node in line.path) + + cache_key = ( + y, + is_hover, + width, + self._updates, + self._pseudo_class_state, + tuple(node._updates for node in line.path), + ) + if cache_key in self._line_cache: + strip = self._line_cache[cache_key] + else: + # Allow tree guides to be explicitly disabled by setting color to transparent + base_hidden = self.get_component_styles("tree--guides").color.a == 0 + hover_hidden = self.get_component_styles("tree--guides-hover").color.a == 0 + selected_hidden = ( + self.get_component_styles("tree--guides-selected").color.a == 0 + ) + + base_guide_style = self.get_component_rich_style( + "tree--guides", partial=True + ) + guide_hover_style = base_guide_style + self.get_component_rich_style( + "tree--guides-hover", partial=True + ) + guide_selected_style = base_guide_style + self.get_component_rich_style( + "tree--guides-selected", partial=True + ) + + hover = line.path[0]._hover + selected = line.path[0]._selected and self.has_focus + + def get_guides(style: Style, hidden: bool) -> tuple[str, str, str, str]: + """Get the guide strings for a given style. + + Args: + style: A Style object. + hidden: Switch to hide guides (make them invisible). + + Returns: + Strings for space, vertical, terminator and cross. + """ + lines: tuple[Iterable[str], Iterable[str], Iterable[str], Iterable[str]] + if self.show_guides and not hidden: + lines = self.LINES["default"] + if style.bold: + lines = self.LINES["bold"] + elif style.underline2: + lines = self.LINES["double"] + else: + lines = (" ", " ", " ", " ") + + guide_depth = max(0, self.guide_depth - 2) + guide_lines = tuple( + f"{characters[0]}{characters[1] * guide_depth} " + for characters in lines + ) + return cast("tuple[str, str, str, str]", guide_lines) + + if is_hover: + line_style = self.get_component_rich_style("tree--highlight-line") + else: + line_style = base_style + + line_style += Style(meta={"line": y}) + + guides = Text(style=line_style) + guides_append = guides.append + + guide_style = base_guide_style + + hidden = True + for node in line.path[1:]: + hidden = base_hidden + if hover: + guide_style = guide_hover_style + hidden = hover_hidden + if selected: + guide_style = guide_selected_style + hidden = selected_hidden + + space, vertical, _, _ = get_guides(guide_style, hidden) + guide = space if node.is_last else vertical + if node != line.path[-1]: + guides_append(guide, style=guide_style) + hover = hover or node._hover + selected = (selected or node._selected) and self.has_focus + + if len(line.path) > 1: + _, _, terminator, cross = get_guides(guide_style, hidden) + if line.last: + guides.append(terminator, style=guide_style) + else: + guides.append(cross, style=guide_style) + + label_style = self.get_component_rich_style("tree--label", partial=True) + if self.hover_line == y: + label_style += self.get_component_rich_style( + "tree--highlight", partial=True + ) + if self.cursor_line == y: + label_style += self.get_component_rich_style( + "tree--cursor", partial=False + ) + + label = self.render_label(line.path[-1], line_style, label_style).copy() + label.stylize(Style(meta={"node": line.node._id})) + guides.append(label) + + segments = list(guides.render(self.app.console)) + pad_width = max(self.virtual_size.width, width) + segments = line_pad(segments, 0, pad_width - guides.cell_len, line_style) + strip = self._line_cache[cache_key] = Strip(segments) + + strip = strip.crop(x1, x2) + return strip + + def _on_resize(self, event: events.Resize) -> None: + self._line_cache.grow(event.size.height) + self._invalidate() + + def _toggle_node(self, node: TreeNode[TreeDataType]) -> None: + if not node.allow_expand: + return + if node.is_expanded: + node.collapse() + else: + node.expand() + + async def _on_click(self, event: events.Click) -> None: + async with self.lock: + meta = event.style.meta + if "line" in meta: + cursor_line = meta["line"] + if meta.get("toggle", False): + node = self.get_node_at_line(cursor_line) + if node is not None: + self._toggle_node(node) + + else: + self.cursor_line = cursor_line + await self.run_action("select_cursor") + + def notify_style_update(self) -> None: + super().notify_style_update() + self._invalidate() + + def action_cursor_up(self) -> None: + """Move the cursor up one node.""" + if self.cursor_line == -1: + self.cursor_line = self.last_line + else: + self.cursor_line -= 1 + self.scroll_to_line(self.cursor_line, animate=False) + + def action_cursor_down(self) -> None: + """Move the cursor down one node.""" + if self.cursor_line == -1: + self.cursor_line = 0 + else: + self.cursor_line += 1 + self.scroll_to_line(self.cursor_line, animate=False) + + def action_page_down(self) -> None: + """Move the cursor down a page's-worth of nodes.""" + if self.cursor_line == -1: + self.cursor_line = 0 + self.cursor_line += self.scrollable_content_region.height - 1 + self.scroll_to_line(self.cursor_line) + + def action_page_up(self) -> None: + """Move the cursor up a page's-worth of nodes.""" + if self.cursor_line == -1: + self.cursor_line = self.last_line + self.cursor_line -= self.scrollable_content_region.height - 1 + self.scroll_to_line(self.cursor_line) + + def action_scroll_home(self) -> None: + """Move the cursor to the top of the tree.""" + self.cursor_line = 0 + self.scroll_to_line(self.cursor_line) + + def action_scroll_end(self) -> None: + """Move the cursor to the bottom of the tree. + + Note: + Here bottom means vertically, not branch depth. + """ + self.cursor_line = self.last_line + self.scroll_to_line(self.cursor_line) + + def action_toggle_node(self) -> None: + """Toggle the expanded state of the target node.""" + try: + line = self._tree_lines[self.cursor_line] + except IndexError: + pass + else: + self._toggle_node(line.path[-1]) + + def action_select_cursor(self) -> None: + """Cause a select event for the target node. + + Note: + If `auto_expand` is `True` use of this action on a non-leaf node + will cause both an expand/collapse event to occur, as well as a + selected event. + """ + if self.cursor_line < 0: + return + try: + line = self._tree_lines[self.cursor_line] + except IndexError: + pass + else: + node = line.path[-1] + self.post_message(Tree.NodeSelected(node)) + + def action_cursor_parent(self) -> None: + """Move the cursor to the parent node.""" + cursor_node = self.cursor_node + if cursor_node is not None and cursor_node.parent is not None: + self.move_cursor(cursor_node.parent, animate=True) + + def action_cursor_parent_next_sibling(self) -> None: + """Move the cursor to the parent's next sibling.""" + cursor_node = self.cursor_node + if cursor_node is not None and cursor_node.parent is not None: + self.move_cursor(cursor_node.parent.next_sibling, animate=True) + + def action_cursor_previous_sibling(self) -> None: + """Move the cursor to previous sibling, or to the parent if there are no more siblings.""" + cursor_node = self.cursor_node + if cursor_node is not None: + previous_sibling = cursor_node.previous_sibling + if previous_sibling is None: + self.move_cursor(cursor_node.parent, animate=True) + else: + self.move_cursor(previous_sibling, animate=True) + + def action_cursor_next_sibling(self) -> None: + """Move the cursor to the next sibling, or to the paren't sibling if there are no more siblings.""" + cursor_node = self.cursor_node + if cursor_node is not None: + next_sibling = cursor_node.next_sibling + if next_sibling is None: + if cursor_node.parent is not None: + parent_sibling = cursor_node.parent.next_sibling + self.move_cursor(parent_sibling, animate=True) + else: + self.move_cursor(next_sibling, animate=True) + + def action_toggle_expand_all(self) -> None: + """Expand or collapse all siblings. + + If all the siblings are collapsed then they will be expanded. + Otherwise they will all be collapsed. + + """ + + if self.cursor_node is None or self.cursor_node.parent is None: + return + + siblings = self.cursor_node.siblings + cursor_node = self.cursor_node + + # If all siblings are collapsed we want to expand them all + if all(child.is_collapsed for child in siblings): + for child in siblings: + if child.allow_expand: + child.expand() + # Otherwise we want to collapse them all + else: + for child in siblings: + if child.allow_expand: + child.collapse() + + self.call_after_refresh(self.move_cursor, cursor_node, animate=False) diff --git a/src/memray/_vendor/textual/widgets/_welcome.py b/src/memray/_vendor/textual/widgets/_welcome.py new file mode 100644 index 0000000000..47eb022539 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/_welcome.py @@ -0,0 +1,59 @@ +"""Provides a Textual welcome widget.""" + +from rich.markdown import Markdown + +from memray._vendor.textual.app import ComposeResult +from memray._vendor.textual.containers import Container +from memray._vendor.textual.widgets._button import Button +from memray._vendor.textual.widgets._static import Static + +WELCOME_MD = """\ +# Welcome! + +Textual is a TUI, or *Text User Interface*, framework for Python inspired by modern web development. **We hope you enjoy using Textual!** + +## Dune quote + +> "I must not fear. +Fear is the mind-killer. +Fear is the little-death that brings total obliteration. +I will face my fear. +I will permit it to pass over me and through me. +And when it has gone past, I will turn the inner eye to see its path. +Where the fear has gone there will be nothing. Only I will remain." +""" + + +class Welcome(Static): + """A Textual welcome widget. + + This widget can be used as a form of placeholder within a Textual + application; although also see + [Placeholder][textual.widgets._placeholder.Placeholder]. + """ + + DEFAULT_CSS = """ + Welcome { + width: 100%; + height: 100%; + background: $surface; + } + + Welcome Container { + padding: 1; + color: $foreground; + } + + Welcome #text { + margin: 0 1; + } + + Welcome #close { + dock: bottom; + width: 100%; + } + """ + + def compose(self) -> ComposeResult: + yield Container(Static(Markdown(WELCOME_MD), id="text"), id="md") + yield Button("OK", id="close", variant="success") diff --git a/src/memray/_vendor/textual/widgets/button.py b/src/memray/_vendor/textual/widgets/button.py new file mode 100644 index 0000000000..a7d585a961 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/button.py @@ -0,0 +1,3 @@ +from memray._vendor.textual.widgets._button import ButtonVariant + +__all__ = ["ButtonVariant"] diff --git a/src/memray/_vendor/textual/widgets/collapsible.py b/src/memray/_vendor/textual/widgets/collapsible.py new file mode 100644 index 0000000000..ed5421bb52 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/collapsible.py @@ -0,0 +1,3 @@ +from memray._vendor.textual.widgets._collapsible import CollapsibleTitle + +__all__ = ["CollapsibleTitle"] diff --git a/src/memray/_vendor/textual/widgets/data_table.py b/src/memray/_vendor/textual/widgets/data_table.py new file mode 100644 index 0000000000..359e100409 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/data_table.py @@ -0,0 +1,29 @@ +from memray._vendor.textual.widgets._data_table import ( + CellDoesNotExist, + CellKey, + CellType, + Column, + ColumnDoesNotExist, + ColumnKey, + CursorType, + DuplicateKey, + Row, + RowDoesNotExist, + RowKey, + StringKey, +) + +__all__ = [ + "CellDoesNotExist", + "CellKey", + "CellType", + "Column", + "ColumnDoesNotExist", + "ColumnKey", + "CursorType", + "DuplicateKey", + "Row", + "RowDoesNotExist", + "RowKey", + "StringKey", +] diff --git a/src/memray/_vendor/textual/widgets/directory_tree.py b/src/memray/_vendor/textual/widgets/directory_tree.py new file mode 100644 index 0000000000..5bcf063aa6 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/directory_tree.py @@ -0,0 +1,3 @@ +from memray._vendor.textual.widgets._directory_tree import DirEntry + +__all__ = ["DirEntry"] diff --git a/src/memray/_vendor/textual/widgets/input.py b/src/memray/_vendor/textual/widgets/input.py new file mode 100644 index 0000000000..2e3fc621e3 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/input.py @@ -0,0 +1,3 @@ +from memray._vendor.textual.widgets._input import Selection + +__all__ = ["Selection"] diff --git a/src/memray/_vendor/textual/widgets/markdown.py b/src/memray/_vendor/textual/widgets/markdown.py new file mode 100644 index 0000000000..8de4ef424f --- /dev/null +++ b/src/memray/_vendor/textual/widgets/markdown.py @@ -0,0 +1,17 @@ +from memray._vendor.textual.widgets._markdown import ( + Markdown, + MarkdownBlock, + MarkdownFence, + MarkdownStream, + MarkdownTableOfContents, + TableOfContentsType, +) + +__all__ = [ + "Markdown", + "MarkdownBlock", + "MarkdownFence", + "MarkdownStream", + "MarkdownTableOfContents", + "TableOfContentsType", +] diff --git a/src/memray/_vendor/textual/widgets/option_list.py b/src/memray/_vendor/textual/widgets/option_list.py new file mode 100644 index 0000000000..e2168442f1 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/option_list.py @@ -0,0 +1,3 @@ +from memray._vendor.textual.widgets._option_list import DuplicateID, Option, OptionDoesNotExist + +__all__ = ["DuplicateID", "Option", "OptionDoesNotExist"] diff --git a/src/memray/_vendor/textual/widgets/rule.py b/src/memray/_vendor/textual/widgets/rule.py new file mode 100644 index 0000000000..a171495bab --- /dev/null +++ b/src/memray/_vendor/textual/widgets/rule.py @@ -0,0 +1,13 @@ +from memray._vendor.textual.widgets._rule import ( + InvalidLineStyle, + InvalidRuleOrientation, + LineStyle, + RuleOrientation, +) + +__all__ = [ + "InvalidLineStyle", + "InvalidRuleOrientation", + "LineStyle", + "RuleOrientation", +] diff --git a/src/memray/_vendor/textual/widgets/select.py b/src/memray/_vendor/textual/widgets/select.py new file mode 100644 index 0000000000..238780df86 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/select.py @@ -0,0 +1,8 @@ +from memray._vendor.textual.widgets._select import ( + NULL, + EmptySelectError, + InvalidSelectValueError, + NoSelection, +) + +__all__ = ["EmptySelectError", "InvalidSelectValueError", "NoSelection", "NULL"] diff --git a/src/memray/_vendor/textual/widgets/selection_list.py b/src/memray/_vendor/textual/widgets/selection_list.py new file mode 100644 index 0000000000..aaa90d50f5 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/selection_list.py @@ -0,0 +1,8 @@ +from memray._vendor.textual.widgets._selection_list import ( + MessageSelectionType, + Selection, + SelectionError, + SelectionType, +) + +__all__ = ["MessageSelectionType", "Selection", "SelectionError", "SelectionType"] diff --git a/src/memray/_vendor/textual/widgets/tabbed_content.py b/src/memray/_vendor/textual/widgets/tabbed_content.py new file mode 100644 index 0000000000..64f901cccd --- /dev/null +++ b/src/memray/_vendor/textual/widgets/tabbed_content.py @@ -0,0 +1,6 @@ +from memray._vendor.textual.widgets._tabbed_content import ContentTab, ContentTabs + +__all__ = [ + "ContentTab", + "ContentTabs", +] diff --git a/src/memray/_vendor/textual/widgets/text_area.py b/src/memray/_vendor/textual/widgets/text_area.py new file mode 100644 index 0000000000..58c9b5969d --- /dev/null +++ b/src/memray/_vendor/textual/widgets/text_area.py @@ -0,0 +1,43 @@ +from memray._vendor.textual._text_area_theme import TextAreaTheme +from memray._vendor.textual.document._document import ( + Document, + DocumentBase, + EditResult, + Location, + Selection, +) +from memray._vendor.textual.document._document_navigator import DocumentNavigator +from memray._vendor.textual.document._edit import Edit +from memray._vendor.textual.document._history import EditHistory +from memray._vendor.textual.document._syntax_aware_document import SyntaxAwareDocument +from memray._vendor.textual.document._wrapped_document import WrappedDocument +from memray._vendor.textual.widgets._text_area import ( + EndColumn, + Highlight, + HighlightName, + LanguageDoesNotExist, + StartColumn, + ThemeDoesNotExist, + BUILTIN_LANGUAGES, +) + +__all__ = [ + "BUILTIN_LANGUAGES", + "Document", + "DocumentBase", + "DocumentNavigator", + "Edit", + "EditResult", + "EditHistory", + "EndColumn", + "Highlight", + "HighlightName", + "LanguageDoesNotExist", + "Location", + "Selection", + "StartColumn", + "SyntaxAwareDocument", + "TextAreaTheme", + "ThemeDoesNotExist", + "WrappedDocument", +] diff --git a/src/memray/_vendor/textual/widgets/tree.py b/src/memray/_vendor/textual/widgets/tree.py new file mode 100644 index 0000000000..22019a14b1 --- /dev/null +++ b/src/memray/_vendor/textual/widgets/tree.py @@ -0,0 +1,21 @@ +"""Make non-widget Tree support classes available.""" + +from memray._vendor.textual.widgets._tree import ( + AddNodeError, + EventTreeDataType, + NodeID, + RemoveRootError, + TreeDataType, + TreeNode, + UnknownNodeID, +) + +__all__ = [ + "AddNodeError", + "EventTreeDataType", + "NodeID", + "RemoveRootError", + "TreeDataType", + "TreeNode", + "UnknownNodeID", +] diff --git a/src/memray/_vendor/textual/worker.py b/src/memray/_vendor/textual/worker.py new file mode 100644 index 0000000000..eab5cf3402 --- /dev/null +++ b/src/memray/_vendor/textual/worker.py @@ -0,0 +1,455 @@ +""" +This module contains the `Worker` class and related objects. + +See the guide for how to use [workers](/guide/workers). + +""" + +from __future__ import annotations + +import asyncio +import enum +import inspect +from contextvars import ContextVar +from threading import Event +from time import monotonic +from typing import ( + TYPE_CHECKING, + Awaitable, + Callable, + Coroutine, + Generic, + TypeVar, + Union, + cast, +) + +import rich.repr +from typing_extensions import TypeAlias + +from memray._vendor.textual.message import Message + +if TYPE_CHECKING: + from memray._vendor.textual.app import App + from memray._vendor.textual.dom import DOMNode + + +active_worker: ContextVar[Worker] = ContextVar("active_worker") +"""Currently active worker context var.""" + + +class NoActiveWorker(Exception): + """There is no active worker.""" + + +class WorkerError(Exception): + """A worker related error.""" + + +class WorkerFailed(WorkerError): + """The worker raised an exception and did not complete.""" + + def __init__(self, error: BaseException) -> None: + self.error = error + super().__init__(f"Worker raised exception: {error!r}") + + +class DeadlockError(WorkerError): + """The operation would result in a deadlock.""" + + +class WorkerCancelled(WorkerError): + """The worker was cancelled and did not complete.""" + + +def get_current_worker() -> Worker: + """Get the currently active worker. + + Raises: + NoActiveWorker: If there is no active worker. + + Returns: + A Worker instance. + """ + try: + return active_worker.get() + except LookupError: + raise NoActiveWorker( + "There is no active worker in this task or thread." + ) from None + + +class WorkerState(enum.Enum): + """A description of the worker's current state.""" + + PENDING = 1 + """Worker is initialized, but not running.""" + RUNNING = 2 + """Worker is running.""" + CANCELLED = 3 + """Worker is not running, and was cancelled.""" + ERROR = 4 + """Worker is not running, and exited with an error.""" + SUCCESS = 5 + """Worker is not running, and completed successfully.""" + + +ResultType = TypeVar("ResultType") + + +WorkType: TypeAlias = Union[ + Callable[[], Coroutine[None, None, ResultType]], + Callable[[], ResultType], + Awaitable[ResultType], +] +"""Type used for [workers](/guide/workers/).""" + + +class _ReprText: + """Shim to insert a word into the Worker's repr.""" + + def __init__(self, text: str) -> None: + self.text = text + + def __repr__(self) -> str: + return self.text + + +@rich.repr.auto(angular=True) +class Worker(Generic[ResultType]): + """A class to manage concurrent work (either a task or a thread).""" + + @rich.repr.auto + class StateChanged(Message, bubble=False, namespace="worker"): + """The worker state changed.""" + + def __init__(self, worker: Worker, state: WorkerState) -> None: + """Initialize the StateChanged message. + + Args: + worker: The worker object. + state: New state. + """ + self.worker = worker + self.state = state + super().__init__() + + def __rich_repr__(self) -> rich.repr.Result: + yield self.worker + yield self.state + + def __init__( + self, + node: DOMNode, + work: WorkType, + *, + name: str = "", + group: str = "default", + description: str = "", + exit_on_error: bool = True, + thread: bool = False, + ) -> None: + """Initialize a Worker. + + Args: + node: The widget, screen, or App that initiated the work. + work: A callable, coroutine, or other awaitable object to run in the worker. + name: Name of the worker (short string to help identify when debugging). + group: The worker group. + description: Description of the worker (longer string with more details). + exit_on_error: Exit the app if the worker raises an error. Set to `False` to suppress exceptions. + thread: Mark the worker as a thread worker. + """ + self._node = node + self._work = work + self.name = name + self.group = group + self.description = ( + description if len(description) <= 1000 else description[:1000] + "..." + ) + self.exit_on_error = exit_on_error + self.cancelled_event: Event = Event() + """A threading event set when the worker is cancelled.""" + self._thread_worker = thread + self._state = WorkerState.PENDING + self.state = self._state + self._error: BaseException | None = None + self._completed_steps: int = 0 + self._total_steps: int | None = None + self._cancelled: bool = False + self._created_time = monotonic() + self._result: ResultType | None = None + self._task: asyncio.Task | None = None + self._node.post_message(self.StateChanged(self, self._state)) + + def __rich_repr__(self) -> rich.repr.Result: + yield _ReprText(self.state.name) + yield "name", self.name, "" + yield "group", self.group, "default" + yield "description", self.description, "" + yield "progress", round(self.progress, 1), 0.0 + + @property + def node(self) -> DOMNode: + """The node where this worker was run from.""" + return self._node + + @property + def state(self) -> WorkerState: + """The current state of the worker.""" + return self._state + + @state.setter + def state(self, state: WorkerState) -> None: + """Set the state, and send a message.""" + changed = state != self._state + self._state = state + if changed: + self._node.post_message(self.StateChanged(self, state)) + + @property + def is_cancelled(self) -> bool: + """Has the work been cancelled? + + Note that cancelled work may still be running. + """ + return self._cancelled + + @property + def is_running(self) -> bool: + """Is the task running?""" + return self.state == WorkerState.RUNNING + + @property + def is_finished(self) -> bool: + """Has the task finished (cancelled, error, or success)?""" + return self.state in ( + WorkerState.CANCELLED, + WorkerState.ERROR, + WorkerState.SUCCESS, + ) + + @property + def completed_steps(self) -> int: + """The number of completed steps.""" + return self._completed_steps + + @property + def total_steps(self) -> int | None: + """The number of total steps, or None if indeterminate.""" + return self._total_steps + + @property + def progress(self) -> float: + """Progress as a percentage. + + If the total steps is None, then this will return 0. The percentage will be clamped between 0 and 100. + """ + if not self._total_steps: + return 0.0 + return max(0, min(100, (self._completed_steps / self._total_steps) * 100.0)) + + @property + def result(self) -> ResultType | None: + """The result of the worker, or `None` if there is no result.""" + return self._result + + @property + def error(self) -> BaseException | None: + """The exception raised by the worker, or `None` if there was no error.""" + return self._error + + def update( + self, completed_steps: int | None = None, total_steps: int | None = -1 + ) -> None: + """Update the number of completed steps. + + Args: + completed_steps: The number of completed seps, or `None` to not change. + total_steps: The total number of steps, `None` for indeterminate, or -1 to leave unchanged. + """ + if completed_steps is not None: + self._completed_steps += completed_steps + if total_steps != -1: + self._total_steps = None if total_steps is None else max(0, total_steps) + + def advance(self, steps: int = 1) -> None: + """Advance the number of completed steps. + + Args: + steps: Number of steps to advance. + """ + self._completed_steps += steps + + async def _run_threaded(self) -> ResultType: + """Run a threaded worker. + + Returns: + Return value of the work. + """ + + def run_awaitable(work: Awaitable[ResultType]) -> ResultType: + """Set the active worker and await the awaitable.""" + + async def do_work() -> ResultType: + active_worker.set(self) + return await work + + return asyncio.run(do_work()) + + def run_coroutine( + work: Callable[[], Coroutine[None, None, ResultType]], + ) -> ResultType: + """Set the active worker and await coroutine.""" + return run_awaitable(work()) + + def run_callable(work: Callable[[], ResultType]) -> ResultType: + """Set the active worker, and call the callable.""" + active_worker.set(self) + return work() + + if ( + inspect.iscoroutinefunction(self._work) + or hasattr(self._work, "func") + and inspect.iscoroutinefunction(self._work.func) + ): + runner = run_coroutine + elif inspect.isawaitable(self._work): + runner = run_awaitable + elif callable(self._work): + runner = run_callable + else: + raise WorkerError("Unsupported attempt to run a thread worker") + + loop = asyncio.get_running_loop() + assert loop is not None + return await loop.run_in_executor(None, runner, self._work) + + async def _run_async(self) -> ResultType: + """Run an async worker. + + Returns: + Return value of the work. + """ + if ( + inspect.iscoroutinefunction(self._work) + or hasattr(self._work, "func") + and inspect.iscoroutinefunction(self._work.func) + ): + return await self._work() + elif inspect.isawaitable(self._work): + return await self._work + elif callable(self._work): + raise WorkerError("Request to run a non-async function as an async worker") + raise WorkerError("Unsupported attempt to run an async worker") + + async def run(self) -> ResultType: + """Run the work. + + Implement this method in a subclass, or pass a callable to the constructor. + + Returns: + Return value of the work. + """ + return await ( + self._run_threaded() if self._thread_worker else self._run_async() + ) + + async def _run(self, app: App) -> None: + """Run the worker. + + Args: + app: App instance. + """ + with app._context(): + active_worker.set(self) + + self.state = WorkerState.RUNNING + app.log.worker(self) + try: + self._result = await self.run() + except asyncio.CancelledError as error: + self.state = WorkerState.CANCELLED + self._error = error + app.log.worker(self) + except Exception as error: + self.state = WorkerState.ERROR + self._error = error + app.log.worker(self, "failed", repr(error)) + from rich.traceback import Traceback + + app.log.worker(Traceback()) + if self.exit_on_error: + worker_failed = WorkerFailed(self._error) + app._handle_exception(worker_failed) + else: + self.state = WorkerState.SUCCESS + app.log.worker(self) + + def _start( + self, app: App, done_callback: Callable[[Worker], None] | None = None + ) -> None: + """Start the worker. + + Args: + app: An app instance. + done_callback: A callback to call when the task is done. + """ + if self._task is not None: + return + self.state = WorkerState.RUNNING + self._task = asyncio.create_task(self._run(app)) + + def task_done_callback(_task: asyncio.Task) -> None: + """Run the callback. + + Called by `Task.add_done_callback`. + + Args: + The worker's task. + """ + if done_callback is not None: + done_callback(self) + + self._task.add_done_callback(task_done_callback) + + def cancel(self) -> None: + """Cancel the task.""" + self._cancelled = True + if self._task is not None: + self._task.cancel() + self.cancelled_event.set() + + async def wait(self) -> ResultType: + """Wait for the work to complete. + + Raises: + WorkerFailed: If the Worker raised an exception. + WorkerCancelled: If the Worker was cancelled before it completed. + + Returns: + The return value of the work. + """ + try: + if active_worker.get() is self: + raise DeadlockError( + "Can't call worker.wait from within the worker function!" + ) + except LookupError: + # Not in a worker + pass + + if self.state == WorkerState.PENDING: + raise WorkerError("Worker must be started before calling this method.") + if self._task is not None: + try: + await self._task + except asyncio.CancelledError as error: + self.state = WorkerState.CANCELLED + self._error = error + if self.state == WorkerState.ERROR: + assert self._error is not None + raise WorkerFailed(self._error) + elif self.state == WorkerState.CANCELLED: + raise WorkerCancelled("Worker was cancelled, and did not complete.") + return cast("ResultType", self._result) diff --git a/src/memray/_vendor/textual/worker_manager.py b/src/memray/_vendor/textual/worker_manager.py new file mode 100644 index 0000000000..008b8595ad --- /dev/null +++ b/src/memray/_vendor/textual/worker_manager.py @@ -0,0 +1,181 @@ +""" +Contains `WorkerManager`, a class to manage [workers](/guide/workers) for an app. + +You access this object via [App.workers][textual.app.App.workers] or [Widget.workers][textual.dom.DOMNode.workers]. +""" + +from __future__ import annotations + +import asyncio +from collections import Counter +from operator import attrgetter +from typing import TYPE_CHECKING, Any, Iterable, Iterator + +import rich.repr + +from memray._vendor.textual.worker import Worker, WorkerState, WorkType + +if TYPE_CHECKING: + from memray._vendor.textual.app import App + from memray._vendor.textual.dom import DOMNode + + +@rich.repr.auto(angular=True) +class WorkerManager: + """An object to manage a number of workers. + + You will not have to construct this class manually, as widgets, screens, and apps + have a worker manager accessibly via a `workers` attribute. + """ + + def __init__(self, app: App) -> None: + """Initialize a worker manager. + + Args: + app: An App instance. + """ + self._app = app + """A reference to the app.""" + self._workers: set[Worker] = set() + """The workers being managed.""" + + def __rich_repr__(self) -> rich.repr.Result: + counter: Counter[WorkerState] = Counter() + counter.update(worker.state for worker in self._workers) + for state, count in sorted(counter.items()): + yield state.name, count + + def __iter__(self) -> Iterator[Worker[Any]]: + return iter(sorted(self._workers, key=attrgetter("_created_time"))) + + def __reversed__(self) -> Iterator[Worker[Any]]: + return iter( + sorted(self._workers, key=attrgetter("_created_time"), reverse=True) + ) + + def __bool__(self) -> bool: + return bool(self._workers) + + def __len__(self) -> int: + return len(self._workers) + + def __contains__(self, worker: object) -> bool: + return worker in self._workers + + def add_worker( + self, worker: Worker, start: bool = True, exclusive: bool = True + ) -> None: + """Add a new worker. + + Args: + worker: A Worker instance. + start: Start the worker if True, otherwise the worker must be started manually. + exclusive: Cancel all workers in the same group as `worker`. + """ + if exclusive and worker.group: + self.cancel_group(worker.node, worker.group) + self._workers.add(worker) + if start: + worker._start(self._app, self._remove_worker) + + def _new_worker( + self, + work: WorkType, + node: DOMNode, + *, + name: str | None = "", + group: str = "default", + description: str = "", + exit_on_error: bool = True, + start: bool = True, + exclusive: bool = False, + thread: bool = False, + ) -> Worker: + """Create a worker from a function, coroutine, or awaitable. + + Args: + work: A callable, a coroutine, or other awaitable. + name: A name to identify the worker. + group: The worker group. + description: A description of the worker. + exit_on_error: Exit the app if the worker raises an error. Set to `False` to suppress exceptions. + start: Automatically start the worker. + exclusive: Cancel all workers in the same group. + thread: Mark the worker as a thread worker. + + Returns: + A Worker instance. + """ + worker: Worker[Any] = Worker( + node, + work, + name=name or getattr(work, "__name__", "") or "", + group=group, + description=description or repr(work), + exit_on_error=exit_on_error, + thread=thread, + ) + self.add_worker(worker, start=start, exclusive=exclusive) + return worker + + def _remove_worker(self, worker: Worker) -> None: + """Remove a worker from the manager. + + Args: + worker: A Worker instance. + """ + self._workers.discard(worker) + + def start_all(self) -> None: + """Start all the workers.""" + for worker in self._workers: + worker._start(self._app, self._remove_worker) + + def cancel_all(self) -> None: + """Cancel all workers.""" + for worker in self._workers: + worker.cancel() + + def cancel_group(self, node: DOMNode, group: str) -> list[Worker]: + """Cancel a single group. + + Args: + node: Worker DOM node. + group: A group name. + + Returns: + A list of workers that were cancelled. + """ + workers = [ + worker + for worker in self._workers + if (worker.group == group and worker.node == node) + ] + for worker in workers: + worker.cancel() + return workers + + def cancel_node(self, node: DOMNode) -> list[Worker]: + """Cancel all workers associated with a given node + + Args: + node: A DOM node (widget, screen, or App). + + Returns: + List of cancelled workers. + """ + workers = [worker for worker in self._workers if worker.node == node] + for worker in workers: + worker.cancel() + return workers + + async def wait_for_complete(self, workers: Iterable[Worker] | None = None) -> None: + """Wait for workers to complete. + + Args: + workers: An iterable of workers or None to wait for all workers in the manager. + """ + try: + await asyncio.gather(*[worker.wait() for worker in (workers or self)]) + except asyncio.CancelledError: + pass diff --git a/src/memray/reporters/_textual_hacks.py b/src/memray/reporters/_textual_hacks.py deleted file mode 100644 index 2be6329124..0000000000 --- a/src/memray/reporters/_textual_hacks.py +++ /dev/null @@ -1,40 +0,0 @@ -import dataclasses -from typing import Any -from typing import Dict -from typing import Tuple -from typing import Union - -from textual import binding -from textual.app import App -from textual.binding import Binding -from textual.dom import DOMNode -from textual.widgets import Footer - -# In Textual 0.61, `App.namespace_bindings` was removed in favor of -# `Screen.active_bindings`. The two have a slightly different interface: -# a 2 item `tuple` was updated to a 3 item `namedtuple`. -# The `Bindings` type alias shows the two possible structures. -# The `update_key_description` implementation works for both, -# since we still support Textual versions below 0.61. - -Bindings = Union[Dict[str, "binding.ActiveBinding"], Dict[str, Tuple[DOMNode, Binding]]] - - -def update_key_description(bindings: Bindings, key: str, description: str) -> None: - val = bindings[key] - binding = dataclasses.replace(val[1], description=description) - if type(val) is tuple: - bindings[key] = val[:1] + (binding,) + val[2:] # type: ignore - else: - bindings[key] = val._replace(binding=binding) # type: ignore - - -def redraw_footer(app: App[Any]) -> None: - footer = app.screen.query_one(Footer) - if hasattr(footer, "recompose"): - # Added in Textual v0.53 - footer.refresh(recompose=True) - else: # pragma: no cover - # Hack: trick the Footer into redrawing itself - footer.highlight_key = "q" # type: ignore[attr-defined] - footer.highlight_key = None # type: ignore[attr-defined] diff --git a/src/memray/reporters/tree.py b/src/memray/reporters/tree.py index 570c4a5f7a..3f5d66f679 100644 --- a/src/memray/reporters/tree.py +++ b/src/memray/reporters/tree.py @@ -1,4 +1,5 @@ import asyncio +import dataclasses import functools import linecache import sys @@ -15,31 +16,27 @@ from rich.style import Style from rich.text import Text -from textual import binding -from textual import work -from textual.app import App -from textual.app import ComposeResult -from textual.binding import Binding -from textual.color import Color -from textual.color import Gradient -from textual.containers import Grid -from textual.containers import Horizontal -from textual.containers import Vertical -from textual.dom import DOMNode -from textual.reactive import reactive -from textual.screen import Screen -from textual.widget import Widget -from textual.widgets import Footer -from textual.widgets import Label -from textual.widgets import TextArea -from textual.widgets import Tree -from textual.widgets.tree import TreeNode from memray import AllocationRecord from memray._memray import size_fmt -from memray.reporters._textual_hacks import Bindings -from memray.reporters._textual_hacks import redraw_footer -from memray.reporters._textual_hacks import update_key_description +from memray._vendor.textual import work +from memray._vendor.textual.app import App +from memray._vendor.textual.app import ComposeResult +from memray._vendor.textual.binding import ActiveBinding +from memray._vendor.textual.binding import Binding +from memray._vendor.textual.color import Color +from memray._vendor.textual.color import Gradient +from memray._vendor.textual.containers import Grid +from memray._vendor.textual.containers import Horizontal +from memray._vendor.textual.containers import Vertical +from memray._vendor.textual.reactive import reactive +from memray._vendor.textual.screen import Screen +from memray._vendor.textual.widget import Widget +from memray._vendor.textual.widgets import Footer +from memray._vendor.textual.widgets import Label +from memray._vendor.textual.widgets import TextArea +from memray._vendor.textual.widgets import Tree +from memray._vendor.textual.widgets.tree import TreeNode from memray.reporters.common import format_thread_name from memray.reporters.frame_tools import is_cpython_internal from memray.reporters.frame_tools import is_frame_from_import_system @@ -53,6 +50,16 @@ ROOT_NODE = ("", "", 0) +def _ensure_event_loop() -> None: + # Vendored Textual may create asyncio.Lock during widget/screen init + # (via textual.rlock). App objects can be constructed before a current + # loop exists, so bootstrap one to avoid runtime failures. + try: + asyncio.get_event_loop() + except RuntimeError: + asyncio.set_event_loop(asyncio.new_event_loop()) + + @dataclass class Frame: """A frame in the tree""" @@ -369,7 +376,7 @@ def action_toggle_import_system(self) -> None: else: self.import_system_filter = None - redraw_footer(self.app) + self.app.screen.query_one(Footer).refresh(recompose=True) self.repopulate_tree(self.query_one(FrameTree)) def action_toggle_uninteresting(self) -> None: @@ -378,17 +385,27 @@ def action_toggle_uninteresting(self) -> None: else: self.uninteresting_filter = None - redraw_footer(self.app) + self.app.screen.query_one(Footer).refresh(recompose=True) self.repopulate_tree(self.query_one(FrameTree)) - def rewrite_bindings(self, bindings: Bindings) -> None: + def rewrite_bindings(self, bindings: Dict[str, ActiveBinding]) -> None: if self.import_system_filter is not None: - update_key_description(bindings, "i", "Show import system") + ab = bindings["i"] + bindings["i"] = ab._replace( + binding=dataclasses.replace( + ab.binding, description="Show import system" + ) + ) if self.uninteresting_filter is not None: - update_key_description(bindings, "u", "Show uninteresting") + ab = bindings["u"] + bindings["u"] = ab._replace( + binding=dataclasses.replace( + ab.binding, description="Show uninteresting" + ) + ) @property - def active_bindings(self) -> Dict[str, "binding.ActiveBinding"]: + def active_bindings(self) -> Dict[str, ActiveBinding]: bindings = super().active_bindings.copy() self.rewrite_bindings(bindings) return bindings @@ -400,20 +417,13 @@ def __init__( data: Frame, elided_locations: ElidedLocations, ): + _ensure_event_loop() super().__init__() self.tree_screen = TreeScreen(data, elided_locations) def on_mount(self) -> None: self.push_screen(self.tree_screen) - if hasattr(App, "namespace_bindings"): - # Removed in Textual 0.61 - @property - def namespace_bindings(self) -> Dict[str, Tuple[DOMNode, Binding]]: - bindings = super().namespace_bindings.copy() # type: ignore[misc] - self.tree_screen.rewrite_bindings(bindings) - return bindings # type: ignore[no-any-return] - @functools.lru_cache(maxsize=None) def _percentage_to_color(percentage: int) -> Color: diff --git a/src/memray/reporters/tui.py b/src/memray/reporters/tui.py index 3a51386359..0e1fd89f3e 100644 --- a/src/memray/reporters/tui.py +++ b/src/memray/reporters/tui.py @@ -1,4 +1,6 @@ +import asyncio import contextlib +import dataclasses import os import pathlib import sys @@ -16,44 +18,50 @@ from typing import List from typing import Optional from typing import Set -from typing import Tuple from typing import cast from rich.markup import escape from rich.segment import Segment from rich.style import Style from rich.text import Text -from textual import events -from textual import log -from textual.app import App -from textual.app import ComposeResult -from textual.binding import Binding -from textual.color import Color -from textual.color import Gradient -from textual.containers import Container -from textual.containers import HorizontalScroll -from textual.dom import DOMNode -from textual.message import Message -from textual.reactive import reactive -from textual.screen import Screen -from textual.strip import Strip -from textual.widget import Widget -from textual.widgets import DataTable -from textual.widgets import Footer -from textual.widgets import Label -from textual.widgets import Static -from textual.widgets.data_table import RowKey from memray import AllocationRecord from memray import SocketReader from memray._memray import size_fmt -from memray.reporters._textual_hacks import Bindings -from memray.reporters._textual_hacks import redraw_footer -from memray.reporters._textual_hacks import update_key_description +from memray._vendor.textual import events +from memray._vendor.textual import log +from memray._vendor.textual.app import App +from memray._vendor.textual.app import ComposeResult +from memray._vendor.textual.binding import ActiveBinding +from memray._vendor.textual.binding import Binding +from memray._vendor.textual.color import Color +from memray._vendor.textual.color import Gradient +from memray._vendor.textual.containers import Container +from memray._vendor.textual.containers import HorizontalScroll +from memray._vendor.textual.message import Message +from memray._vendor.textual.reactive import reactive +from memray._vendor.textual.screen import Screen +from memray._vendor.textual.strip import Strip +from memray._vendor.textual.widget import Widget +from memray._vendor.textual.widgets import DataTable +from memray._vendor.textual.widgets import Footer +from memray._vendor.textual.widgets import Label +from memray._vendor.textual.widgets import Static +from memray._vendor.textual.widgets.data_table import RowKey MAX_MEMORY_RATIO = 0.95 +def _ensure_event_loop() -> None: + # Vendored Textual may create asyncio.Lock during app/widget init. + # TUIApp can be constructed before a current loop exists, so bootstrap + # one up front to avoid runtime failures. + try: + asyncio.get_event_loop() + except RuntimeError: + asyncio.set_event_loop(asyncio.new_event_loop()) + + @dataclass(frozen=True) class Location: function: str @@ -585,7 +593,7 @@ def _populate_header_thread_labels(self, thread_idx: int) -> None: def action_toggle_merge_threads(self) -> None: """An action to toggle showing allocations from all threads together.""" self._merge_threads = not self._merge_threads - redraw_footer(self.app) + self.app.screen.query_one(Footer).refresh(recompose=True) self.app.screen.query_one(AllocationTable).merge_threads = self._merge_threads self._populate_header_thread_labels(self.thread_idx) @@ -593,7 +601,7 @@ def action_toggle_pause(self) -> None: """Toggle pause on keypress""" if self.paused or not self.disconnected: self.paused = not self.paused - redraw_footer(self.app) + self.app.screen.query_one(Footer).refresh(recompose=True) if not self.paused: self.display_snapshot() @@ -613,7 +621,7 @@ def watch_threads(self) -> None: def watch_disconnected(self) -> None: self.update_label() - redraw_footer(self.app) + self.app.screen.query_one(Footer).refresh(recompose=True) def watch_paused(self) -> None: self.update_label() @@ -686,20 +694,26 @@ def update_sort_key(self, col_number: int) -> None: body = self.query_one(AllocationTable) body.sort_column_id = col_number - def rewrite_bindings(self, bindings: Bindings) -> None: - if "space" in bindings and bindings["space"][1].description == "Pause": + def rewrite_bindings(self, bindings: Dict[str, ActiveBinding]) -> None: + if "space" in bindings and bindings["space"].binding.description == "Pause": if self.paused: - update_key_description(bindings, "space", "Unpause") + ab = bindings["space"] + bindings["space"] = ab._replace( + binding=dataclasses.replace(ab.binding, description="Unpause") + ) elif self.disconnected: del bindings["space"] if self._merge_threads: - bindings.pop("less_than_sign") - bindings.pop("greater_than_sign") - update_key_description(bindings, "m", "Unmerge Threads") + bindings.pop("less_than_sign", None) + bindings.pop("greater_than_sign", None) + ab = bindings["m"] + bindings["m"] = ab._replace( + binding=dataclasses.replace(ab.binding, description="Unmerge Threads") + ) @property - def active_bindings(self) -> Dict[str, Any]: + def active_bindings(self) -> Dict[str, ActiveBinding]: bindings = super().active_bindings.copy() self.rewrite_bindings(bindings) return bindings @@ -760,6 +774,7 @@ def __init__( cmdline_override: Optional[str] = None, poll_interval: float = 1.0, ) -> None: + _ensure_event_loop() self._reader = reader self._poll_interval = poll_interval self._cmdline_override = cmdline_override @@ -804,12 +819,3 @@ def on_snapshot_fetched(self, message: SnapshotFetched) -> None: def on_resize(self, event: events.Resize) -> None: self.set_class(0 <= event.size.width < 81, "narrow") - - if hasattr(App, "namespace_bindings"): - # Removed in Textual 0.61 - @property - def namespace_bindings(self) -> Dict[str, Tuple[DOMNode, Binding]]: - bindings = super().namespace_bindings.copy() # type: ignore[misc] - if self.tui: - self.tui.rewrite_bindings(bindings) - return bindings # type: ignore[no-any-return] diff --git a/tests/conftest.py b/tests/conftest.py index 827cfcbffd..f18b706be8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,14 +1,43 @@ +import importlib +import pkgutil import socket import sys import pytest +import pytest_textual_snapshot as _pytest_textual_snapshot from packaging import version +# Patch textual imports to avoid pytest-textual-snapshot loading it from the environment +import memray._vendor.textual as _vendored_textual +from memray._vendor.textual.app import App as _vendored_textual_app + +sys.modules["textual"] = _vendored_textual +_pytest_textual_snapshot.App = _vendored_textual_app + + +def _alias_vendored_textual_submodule(modname: str) -> None: + try: + mod = importlib.import_module(modname) + except Exception: + return + bare_name = modname.replace("memray._vendor.", "", 1) + sys.modules[bare_name] = mod + + +# Also inject commonly used submodules that pytest-textual-snapshot accesses +for _importer, _modname, _ispkg in pkgutil.walk_packages( + _vendored_textual.__path__, + prefix="memray._vendor.textual.", +): + _alias_vendored_textual_submodule(_modname) + + SNAPSHOT_MINIMUM_VERSIONS = { - "textual": "6.8.0", "pytest-textual-snapshot": "1.1.0", } +VENDORED_TEXTUAL_VERSION = _vendored_textual.__version__ + @pytest.fixture def free_port(): @@ -20,11 +49,10 @@ def free_port(): def _snapshot_skip_reason(): - if sys.version_info < (3, 8): - # Every version available for 3.7 is too old - return f"snapshot tests require textual>={SNAPSHOT_MINIMUM_VERSIONS['textual']}" + if sys.version_info < (3, 9): + return "snapshot tests require Python >= 3.9" - from importlib import metadata # Added in 3.8 + from importlib import metadata for lib, min_ver in SNAPSHOT_MINIMUM_VERSIONS.items(): try: @@ -40,7 +68,7 @@ def _snapshot_skip_reason(): def pytest_configure(config): if config.option.update_snapshots: - from importlib import metadata # Added in 3.8 + from importlib import metadata for lib, min_ver in SNAPSHOT_MINIMUM_VERSIONS.items(): ver = version.parse(metadata.version(lib)) diff --git a/tests/unit/test_tree_reporter.py b/tests/unit/test_tree_reporter.py index adbb0fe22d..a62a7a8436 100644 --- a/tests/unit/test_tree_reporter.py +++ b/tests/unit/test_tree_reporter.py @@ -13,12 +13,12 @@ from unittest.mock import patch import pytest -from textual.pilot import Pilot -from textual.widgets import Tree -from textual.widgets.tree import TreeNode from memray import AllocationRecord from memray import AllocatorType +from memray._vendor.textual.pilot import Pilot +from memray._vendor.textual.widgets import Tree +from memray._vendor.textual.widgets.tree import TreeNode from memray.reporters.tree import MAX_STACKS from memray.reporters.tree import Frame from memray.reporters.tree import TreeReporter diff --git a/tests/unit/test_tui_reporter.py b/tests/unit/test_tui_reporter.py index 382fdbc8c3..987fd4f6a3 100644 --- a/tests/unit/test_tui_reporter.py +++ b/tests/unit/test_tui_reporter.py @@ -12,16 +12,16 @@ import pytest from rich import print as rprint -from textual.app import App -from textual.coordinate import Coordinate -from textual.pilot import Pilot -from textual.widget import Widget -from textual.widgets import DataTable -from textual.widgets import Label import memray.reporters.tui from memray import AllocationRecord from memray import AllocatorType +from memray._vendor.textual.app import App +from memray._vendor.textual.coordinate import Coordinate +from memray._vendor.textual.pilot import Pilot +from memray._vendor.textual.widget import Widget +from memray._vendor.textual.widgets import DataTable +from memray._vendor.textual.widgets import Label from memray.reporters.tui import Location from memray.reporters.tui import MemoryGraph from memray.reporters.tui import Snapshot diff --git a/tools/check_vendor_versions.py b/tools/check_vendor_versions.py new file mode 100644 index 0000000000..fd7f632bed --- /dev/null +++ b/tools/check_vendor_versions.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import pathlib +import re +import sys + +ROOT = pathlib.Path(__file__).resolve().parent.parent +VENDOR_TXT = ROOT / "vendor.txt" +TEXTUAL_VERSION_PATCH = ( + ROOT / "tools" / "vendoring" / "patches" / "textual-version.patch" +) +SETUP_PY = ROOT / "setup.py" + + +def _read_vendor_txt_version() -> str: + match = re.search( + r"^textual==([^\s]+)$", + VENDOR_TXT.read_text(encoding="utf-8"), + re.MULTILINE, + ) + if match is None: + raise SystemExit(f"missing textual pin in {VENDOR_TXT}") + return match.group(1) + + +def _read_patch_version() -> str: + match = re.search( + r'^\+__version__ = "([^"]+)"$', + TEXTUAL_VERSION_PATCH.read_text(encoding="utf-8"), + re.MULTILINE, + ) + if match is None: + raise SystemExit( + f"missing textual __version__ patch in {TEXTUAL_VERSION_PATCH}" + ) + return match.group(1) + + +def _read_setup_py_test_pin() -> str: + match = re.search( + r'^\s*"textual==([^"]+)",\s*$', + SETUP_PY.read_text(encoding="utf-8"), + re.MULTILINE, + ) + if match is None: + raise SystemExit(f"missing textual test pin in {SETUP_PY}") + return match.group(1) + + +def main() -> int: + vendor_txt_version = _read_vendor_txt_version() + patch_version = _read_patch_version() + setup_py_version = _read_setup_py_test_pin() + if vendor_txt_version != patch_version: + print( + "textual version mismatch:" + f" vendor.txt has {vendor_txt_version}," + f" textual-version.patch has {patch_version}", + file=sys.stderr, + ) + return 1 + if vendor_txt_version != setup_py_version: + print( + "textual version mismatch:" + f" vendor.txt has {vendor_txt_version}," + f" setup.py has {setup_py_version}", + file=sys.stderr, + ) + return 1 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/vendoring/patches/textual-version.patch b/tools/vendoring/patches/textual-version.patch new file mode 100644 index 0000000000..170aba64ae --- /dev/null +++ b/tools/vendoring/patches/textual-version.patch @@ -0,0 +1,51 @@ +diff --git a/src/memray/_vendor/textual/__init__.py b/src/memray/_vendor/textual/__init__.py +--- a/src/memray/_vendor/textual/__init__.py ++++ b/src/memray/_vendor/textual/__init__.py +@@ -33,23 +33,10 @@ + LogCallable: TypeAlias = "Callable" + + +-if TYPE_CHECKING: +- from importlib.metadata import version +- +- from textual.app import App as _App +- +- __version__ = version("textual") +- """The version of Textual.""" +- +-else: +- +- def __getattr__(name: str) -> str: +- """Lazily get the version.""" +- if name == "__version__": +- from importlib.metadata import version +- +- return version("textual") +- raise AttributeError(f"module {__name__!r} has no attribute {name!r}") ++__version__ = "8.2.1" ++ ++if TYPE_CHECKING: ++ from textual.app import App as _App + + + class LoggerError(Exception): +diff --git a/src/memray/_vendor/textual/widgets/__init__.py b/src/memray/_vendor/textual/widgets/__init__.py +--- a/src/memray/_vendor/textual/widgets/__init__.py ++++ b/src/memray/_vendor/textual/widgets/__init__.py +@@ -106,10 +106,14 @@ def __getattr__(widget_class: str) -> type[Widget]: + pass + + if widget_class not in __all__: +- raise AttributeError(f"Package 'textual.widgets' has no class '{widget_class}'") ++ raise AttributeError( ++ f"Package 'memray._vendor.textual.widgets' has no class '{widget_class}'" ++ ) + + widget_module_path = f"._{camel_to_snake(widget_class)}" +- module = import_module(widget_module_path, package="textual.widgets") ++ module = import_module( ++ widget_module_path, package="memray._vendor.textual.widgets" ++ ) + class_ = getattr(module, widget_class) + + _WIDGETS_LAZY_LOADING_CACHE[widget_class] = class_ diff --git a/vendor.txt b/vendor.txt new file mode 100644 index 0000000000..8a7816574c --- /dev/null +++ b/vendor.txt @@ -0,0 +1 @@ +textual==8.2.1