diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ee1559c..13a133d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,22 +2,24 @@ name: CI on: push: - branches: [main] + branches: [main, dev] pull_request: - branches: [main] + branches: [main, dev] + workflow_call: jobs: test: runs-on: ubuntu-latest strategy: + fail-fast: false matrix: - python-version: ["3.11"] + python-version: ["3.11", "3.12", "3.13"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Install uv - uses: astral-sh/setup-uv@v5 + uses: astral-sh/setup-uv@v7 with: enable-cache: true diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..9a83c6e --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,61 @@ +name: Publish to PyPI + +on: + workflow_dispatch: + +permissions: + contents: write + id-token: write + +jobs: + ci: + uses: ./.github/workflows/ci.yml + + publish: + needs: ci + runs-on: ubuntu-latest + if: github.ref == 'refs/heads/main' + environment: pypi + steps: + - uses: actions/checkout@v6 + with: + fetch-depth: 0 + + - name: Install uv + uses: astral-sh/setup-uv@v8.0.0 + with: + python-version: "3.12" + + - name: Get version + id: version + run: | + VERSION=$(uv run python -c "import tomllib; print(tomllib.load(open('pyproject.toml','rb'))['project']['version'])") + echo "version=$VERSION" >> $GITHUB_OUTPUT + + - name: Check version not already published + run: | + VERSION="${{ steps.version.outputs.version }}" + if uv pip index versions designer-plugin 2>/dev/null | grep -q "$VERSION"; then + echo "Version $VERSION already exists on PyPI. Aborting." + exit 1 + fi + + - name: Validate tag does not exist + run: | + if git rev-parse "v${{ steps.version.outputs.version }}" >/dev/null 2>&1; then + echo "Tag v${{ steps.version.outputs.version }} already exists. Aborting." + exit 1 + fi + + - name: Build package + run: uv build + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + + - name: Tag release + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + git tag "v${{ steps.version.outputs.version }}" + git push origin "v${{ steps.version.outputs.version }}" diff --git a/.github/workflows/test-publish.yml b/.github/workflows/test-publish.yml new file mode 100644 index 0000000..c9b9f84 --- /dev/null +++ b/.github/workflows/test-publish.yml @@ -0,0 +1,64 @@ +name: Publish to Test PyPI + +on: + workflow_dispatch: + +permissions: + contents: read + id-token: write # Required for trusted publishing + +jobs: + test-publish: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v6 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: '3.12' # tomllib requires >= 3.11 + + - name: Install uv + uses: astral-sh/setup-uv@v7 + + - name: Append .dev suffix for unique Test PyPI versions + run: | + python -c " + import tomllib, pathlib, re + path = pathlib.Path('pyproject.toml') + text = path.read_text() + data = tomllib.loads(text) + version = data['project']['version'] + dev_version = f'{version}.dev${{ github.run_number }}' + # Only replace the version inside the [project] section to avoid + # accidentally matching a version key in [tool.*] sections. + def replace_in_project_section(text, old_ver, new_ver): + project_match = re.search(r'^\[project\]', text, re.MULTILINE) + if not project_match: + raise RuntimeError('[project] section not found in pyproject.toml') + start = project_match.start() + # Find the next top-level section header or end of file + next_section = re.search(r'^\[(?!project[.\]])', text[start+1:], re.MULTILINE) + end = (start + 1 + next_section.start()) if next_section else len(text) + section = text[start:end] + section = re.sub( + r'(version\s*=\s*\")' + re.escape(old_ver) + r'\"', + r'\g<1>' + new_ver + '\"', + section, count=1, + ) + return text[:start] + section + text[end:] + text = replace_in_project_section(text, version, dev_version) + path.write_text(text) + print(f'Version set to {dev_version}') + " + + - name: Build package + run: uv build + + - name: Publish to Test PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + repository-url: https://test.pypi.org/legacy/ + skip-existing: true diff --git a/CHANGELOG.md b/CHANGELOG.md index b4df8d3..1ad91b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,26 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.3.0] - 2026-01-06 + +### Added +- **Lazy module registration**: `D3Session.execute()` and `D3AsyncSession.execute()` now automatically register a `@d3function` module on first use, eliminating the need to declare all modules in `context_modules` upfront. +- `registered_modules` tracking on session instances prevents duplicate registration calls. +- **Jupyter notebook support**: `@d3function` now automatically replaces a previously registered function when the same name is re-registered in the same module, with a warning log. This enables iterative workflows in Jupyter notebooks where cells are re-executed. +- **Automatic import detection**: `@d3function` now automatically discovers file-level imports used by the decorated function and includes them in the registered module. In Jupyter notebooks, place imports inside the function body instead. + +### Removed +- `add_packages_in_current_file()`: Removed. Imports are now detected automatically by `@d3function`. +- `find_packages_in_current_file()`: Removed. Replaced by `find_imports_for_function()`. + +### Changed +- `d3_api_plugin` has been renamed to `d3_api_execute`. +- `d3_api_aplugin` has been renamed to `d3_api_aexecute`. +- `context_modules` parameter type updated from `list[str]` to `set[str]` on `D3Session`, `D3AsyncSession`, and `D3SessionBase`. +- Updated documentation to reflect `pystub` proxy support. +- Bumped `actions/checkout` to v6 and `astral-sh/setup-uv` to v7 in CI. +- Added Test PyPI publish workflow (`test-publish.yml`) for dev version releases. + ## [1.2.0] - 2025-12-02 ### Added diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9f5e34d..fe4a450 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -31,14 +31,19 @@ Thank you for your interest in contributing to designer-plugin! This document pr ### Running Tests -Run the full test suite: +Run unit tests (default): ```bash uv run pytest ``` -Run tests with verbose output: +Run integration tests (requires a running d3 instance): ```bash -uv run pytest -v +uv run pytest -m integration +``` + +Run all tests: +```bash +uv run pytest -m "" ``` Run specific test file: diff --git a/README.md b/README.md index 3c754f9..1d39b3c 100644 --- a/README.md +++ b/README.md @@ -83,11 +83,11 @@ To enable IDE autocomplete and type checking for Designer's Python API, install pip install designer-plugin-pystub ``` -Once installed, import the stubs using the `TYPE_CHECKING` pattern. This provides type hints in your IDE without affecting runtime execution: +Once installed, import the stubs. +> **Important:** `pystub` provides type hints for Designer's API objects but not their implementations. These objects only exist in Designer's runtime and cannot be used in local Python code. They must only be referenced in code that will be executed remotely in Designer. + ```python -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from designer_plugin.pystub.d3 import * +from designer_plugin.pystub import * ``` This allows you to get autocomplete for Designer objects like `resourceManager`, `Screen2`, `Path`, etc., while writing your plugin code. @@ -100,9 +100,7 @@ The Client API allows you to define a class with methods that execute remotely o ```python from designer_plugin.d3sdk import D3PluginClient -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from designer_plugin.pystub.d3 import * +from designer_plugin.pystub import * # 1. Sync example ----------------------------------- class MySyncPlugin(D3PluginClient): @@ -169,7 +167,15 @@ The Functional API offers two decorators: `@d3pythonscript` and `@d3function`: - **`@d3function`**: - Must be registered on Designer before execution. - Functions decorated with the same `module_name` are grouped together and can call each other, enabling function chaining and code reuse. - - Registration is automatic when you pass module names to the session context manager (e.g., `D3AsyncSession('localhost', 80, ["mymodule"])`). If you don't provide module names, no registration occurs. + - Registration happens automatically on the first call to `execute()` or `rpc()` that references the module — no need to declare modules upfront. You can also pre-register specific modules by passing them to the session context manager (e.g., `D3AsyncSession('localhost', 80, {"mymodule"})`). + +> **Jupyter Notebook:** File-level imports (e.g., `import numpy as np` in a separate cell) cannot be automatically detected. In Jupyter, place any required imports inside the function body itself: +> ```python +> @d3function("mymodule") +> def my_fn(): +> import numpy as np +> return np.array([1, 2]) +> ``` ### Session API Methods @@ -186,9 +192,7 @@ Both `D3AsyncSession` and `D3Session` provide two methods for executing function ```python from designer_plugin.d3sdk import d3pythonscript, d3function, D3AsyncSession -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from designer_plugin.pystub.d3 import * +from designer_plugin.pystub import * # 1. @d3pythonscript - simple one-off execution @d3pythonscript @@ -213,11 +217,11 @@ def my_time() -> str: return str(datetime.datetime.now()) # Usage with async session -async with D3AsyncSession('localhost', 80, ["mymodule"]) as session: +async with D3AsyncSession('localhost', 80) as session: # d3pythonscript: no registration needed await session.rpc(rename_surface.payload("surface 1", "surface 2")) - # d3function: registered automatically via context manager + # d3function: module is registered automatically on first call time: str = await session.rpc( rename_surface_get_time.payload("surface 1", "surface 2")) @@ -230,7 +234,7 @@ async with D3AsyncSession('localhost', 80, ["mymodule"]) as session: # Sync usage from designer_plugin.d3sdk import D3Session -with D3Session('localhost', 80, ["mymodule"]) as session: +with D3Session('localhost', 80) as session: session.rpc(rename_surface.payload("surface 1", "surface 2")) ``` @@ -251,4 +255,3 @@ logging.getLogger('designer_plugin').setLevel(logging.DEBUG) # License This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details. - diff --git a/pyproject.toml b/pyproject.toml index db59363..b54eca7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,11 +4,11 @@ build-backend = "setuptools.build_meta" [project] name = "designer-plugin" -version = "1.2.1" +version = "1.3.0" description = "Python library for creating Disguise Designer plugins with DNS-SD discovery and remote Python execution" authors = [ - { name = "Tom Whittock", email = "tom.whittock@disguise.one" }, - { name = "Taegyun Ha", email = "taegyun.ha@disguise.one" } + { name = "Taegyun Ha", email = "taegyun.ha@disguise.one" }, + { name = "Tom Whittock", email = "tom.whittock@disguise.one" } ] dependencies = [ "aiohttp>=3.13.2", @@ -109,6 +109,11 @@ python_classes = ["Test*"] python_functions = ["test_*"] addopts = [ "-v", + "-m", "not integration", "--strict-markers", "--strict-config", ] +markers = [ + "integration: tests that require a running d3 instance", +] + diff --git a/src/designer_plugin/api.py b/src/designer_plugin/api.py index 672d014..65cd9d0 100644 --- a/src/designer_plugin/api.py +++ b/src/designer_plugin/api.py @@ -125,7 +125,7 @@ async def d3_api_arequest( ############################################################################### # API async interface -async def d3_api_aplugin( +async def d3_api_aexecute( hostname: str, port: int, payload: PluginPayload[RetType], @@ -219,7 +219,7 @@ async def d3_api_aregister_module( ############################################################################### # API sync interface -def d3_api_plugin( +def d3_api_execute( hostname: str, port: int, payload: PluginPayload[RetType], diff --git a/src/designer_plugin/d3sdk/__init__.py b/src/designer_plugin/d3sdk/__init__.py index 7f5989e..905c283 100644 --- a/src/designer_plugin/d3sdk/__init__.py +++ b/src/designer_plugin/d3sdk/__init__.py @@ -5,7 +5,7 @@ from .client import D3PluginClient from .function import ( - add_packages_in_current_file, + PackageInfo, d3function, d3pythonscript, get_all_d3functions, @@ -18,9 +18,9 @@ "D3AsyncSession", "D3PluginClient", "D3Session", + "PackageInfo", "d3pythonscript", "d3function", - "add_packages_in_current_file", "get_register_payload", "get_all_d3functions", "get_all_modules", diff --git a/src/designer_plugin/d3sdk/ast_utils.py b/src/designer_plugin/d3sdk/ast_utils.py index 696cad1..839b3d7 100644 --- a/src/designer_plugin/d3sdk/ast_utils.py +++ b/src/designer_plugin/d3sdk/ast_utils.py @@ -4,11 +4,71 @@ """ import ast +import functools import inspect +import logging import textwrap import types +from collections.abc import Callable from typing import Any +from pydantic import BaseModel, Field + +from designer_plugin.d3sdk.builtin_modules import SUPPORTED_MODULES + +logger = logging.getLogger(__name__) + + +############################################################################### +# Package info models +class ImportAlias(BaseModel): + """Represents a single imported name with an optional alias. + + Mirrors the structure of ast.alias for Pydantic compatibility. + """ + + name: str = Field( + description="The imported name (e.g., 'Path' in 'from pathlib import Path')" + ) + asname: str | None = Field( + default=None, + description="The alias (e.g., 'np' in 'import numpy as np')", + ) + + +class PackageInfo(BaseModel): + """Structured representation of a Python import statement. + + Rendering rules (via to_import_statement using ast.unparse): + - package only → import package + - package + alias → import package as alias + - package + methods → from package import method1, method2 + - package + methods w/alias → from package import method1 as alias1 + """ + + package: str = Field(description="The module/package name to import") + alias: str | None = Field( + default=None, + description="Alias for the package (e.g., 'np' in 'import numpy as np')", + ) + methods: list[ImportAlias] = Field( + default_factory=list, + description="Imported names for 'from X import ...' style imports", + ) + + def to_import_statement(self) -> str: + """Render back to a Python import statement using ast.unparse.""" + node: ast.stmt + if self.methods: + node = ast.ImportFrom( + module=self.package, + names=[ast.alias(name=m.name, asname=m.asname) for m in self.methods], + level=0, + ) + else: + node = ast.Import(names=[ast.alias(name=self.package, asname=self.alias)]) + return ast.unparse(node) + ############################################################################### # Source code extraction utilities @@ -369,94 +429,157 @@ def validate_and_extract_args( ############################################################################### -# Python package finder utility -def find_packages_in_current_file(caller_stack: int = 1) -> list[str]: - """Find all import statements in the caller's file by inspecting the call stack. +# Function-scoped import extraction utility +def _collect_used_names(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]: + """Collect all identifier names used inside a function body. + + Walks the function's AST body and extracts: + - Simple names (ast.Name nodes, e.g., ``foo`` in ``foo()``) + - Root names of attribute chains (e.g., ``np`` in ``np.array()``) + + Args: + func_node: The function AST node to analyse. + + Returns: + Set of identifier strings used in the function body. + """ + names: set[str] = set() + for node in ast.walk(func_node): + if isinstance(node, ast.Name): + names.add(node.id) + elif isinstance(node, ast.Attribute): + # Walk down the attribute chain to find the root name + root: ast.expr = node + while isinstance(root, ast.Attribute): + root = root.value + if isinstance(root, ast.Name): + names.add(root.id) + return names + + +def _is_type_checking_block(node: ast.If) -> bool: + """Check if an if statement is ``if TYPE_CHECKING:``.""" + if isinstance(node.test, ast.Name) and node.test.id == "TYPE_CHECKING": + return True + # Also match `if typing.TYPE_CHECKING:` + if isinstance(node.test, ast.Attribute): + return ( + node.test.attr == "TYPE_CHECKING" + and isinstance(node.test.value, ast.Name) + and node.test.value.id == "typing" + ) + return False + + +def _is_supported_module(module_name: str) -> bool: + """Check if a module (or its top-level parent) is Designer-supported.""" + top_level = module_name.split(".")[0] + return top_level in SUPPORTED_MODULES + - This function walks up the call stack to find the module where it was called from, - then parses that module's source code to extract all import statements that are - compatible with Python 2.7 and safe to send to Designer. +@functools.lru_cache(maxsize=128) +def _get_module_ast(module: types.ModuleType) -> ast.Module | None: + """Return the parsed AST for *module*, cached by module identity.""" + try: + return ast.parse(inspect.getsource(module)) + except (OSError, TypeError): + return None + + +def find_imports_for_function(func: Callable[..., Any]) -> list[PackageInfo]: + """Extract import statements used by a function from its source file. + + Inspects the module containing *func*, parses all top-level imports, then + filters them down to only those whose imported names are actually referenced + inside the function body. Args: - caller_stack: Number of frames to go up the call stack. Default is 1 (immediate caller). - Use higher values to inspect files further up the call chain. + func: The callable to analyse. Returns: - Sorted list of unique import statement strings (e.g., "import ast", "from pathlib import Path"). + Sorted list of :class:`PackageInfo` objects representing the imports + used by *func*. Filters applied: - - Excludes imports inside `if TYPE_CHECKING:` blocks (type checking only) - - Excludes imports from the 'd3blobgen' package (client-side only) - - Excludes imports from the 'typing' module (not supported in Python 2.7) - - Excludes imports of this function itself to avoid circular references + - Excludes imports inside ``if TYPE_CHECKING:`` blocks + - Only includes imports from Designer-supported builtin modules + (see ``SUPPORTED_MODULES`` in ``builtin_modules.py``) + - Only includes imports whose names are actually used in the function body """ - # Get the this file frame - current_frame: types.FrameType | None = inspect.currentframe() - if not current_frame: + # --- 1. Get the function's module source --- + module = inspect.getmodule(func) + if not module: return [] - # Get the caller's frame (file where this function is called) - caller_frame: types.FrameType | None = current_frame - for _ in range(caller_stack): - if not caller_frame or not caller_frame.f_back: - return [] - caller_frame = caller_frame.f_back - - if not caller_frame: + module_tree = _get_module_ast(module) + if module_tree is None: + logger.warning( + "Cannot detect file-level imports for '%s': module source unavailable " + "(e.g. Jupyter notebook). Place imports inside the function body instead.", + func.__qualname__, + ) return [] - modules: types.ModuleType | None = inspect.getmodule(caller_frame) - if not modules: + # --- 2. Collect names used inside the function body --- + func_source = textwrap.dedent(inspect.getsource(func)) + func_tree = ast.parse(func_source) + if not func_tree.body: return [] - source: str = inspect.getsource(modules) - - # Parse the source code - tree = ast.parse(source) - - # Get the name of this function to filter it out - # For example, we don't want `from core import find_packages_in_current_file` - function_name: str = current_frame.f_code.co_name - # Skip any package from d3blobgen - d3blobgen_package_name: str = "d3blobgen" - # typing not supported in python2.7 - typing_package_name: str = "typing" + func_node = func_tree.body[0] + if not isinstance(func_node, (ast.FunctionDef, ast.AsyncFunctionDef)): + return [] - def is_type_checking_block(node: ast.If) -> bool: - """Check if an if statement is 'if TYPE_CHECKING:'""" - return isinstance(node.test, ast.Name) and node.test.id == "TYPE_CHECKING" + used_names = _collect_used_names(func_node) - imports: list[str] = [] - for node in tree.body: - # Skip TYPE_CHECKING blocks entirely - if isinstance(node, ast.If) and is_type_checking_block(node): + # --- 3. Parse file-level imports and filter to used ones --- + packages: list[PackageInfo] = [] + for node in module_tree.body: + # Skip TYPE_CHECKING blocks + if isinstance(node, ast.If) and _is_type_checking_block(node): continue if isinstance(node, ast.Import): - imported_modules: list[str] = [alias.name for alias in node.names] - # Skip imports that include d3blobgen - if any(d3blobgen_package_name in module for module in imported_modules): - continue - if any(typing_package_name in module for module in imported_modules): - continue - import_text: str = f"import {', '.join(imported_modules)}" - imports.append(import_text) + for alias in node.names: + if not _is_supported_module(alias.name): + continue + + # The name used in code is the alias if present, otherwise the top-level + # package name (e.g. "import logging.handlers" binds "logging", not + # "logging.handlers"). + effective_name = ( + alias.asname if alias.asname else alias.name.split(".")[0] + ) + if effective_name in used_names: + packages.append( + PackageInfo( + package=alias.name, + alias=alias.asname, + ) + ) elif isinstance(node, ast.ImportFrom): - imported_module: str | None = node.module - imported_names: list[str] = [alias.name for alias in node.names] - if not imported_module: - continue - # Skip imports that include d3blobgen - if d3blobgen_package_name in imported_module: + if not node.module: continue - elif typing_package_name in imported_module: - continue - # Skip imports that include this function itself - if function_name in imported_names: + if not _is_supported_module(node.module): continue - line_text = f"from {imported_module} import {', '.join(imported_names)}" - imports.append(line_text) + # Filter to only methods actually used by the function + matched_methods: list[ImportAlias] = [] + for alias in node.names: + effective_name = alias.asname if alias.asname else alias.name + if effective_name in used_names: + matched_methods.append( + ImportAlias(name=alias.name, asname=alias.asname) + ) + + if matched_methods: + packages.append( + PackageInfo( + package=node.module, + methods=matched_methods, + ) + ) - return sorted(set(imports)) + # Sort by import statement string for deterministic output + return sorted(packages, key=lambda p: p.to_import_statement()) diff --git a/src/designer_plugin/d3sdk/builtin_modules.py b/src/designer_plugin/d3sdk/builtin_modules.py new file mode 100644 index 0000000..8b933a8 --- /dev/null +++ b/src/designer_plugin/d3sdk/builtin_modules.py @@ -0,0 +1,220 @@ +SUPPORTED_MODULES: frozenset[str] = frozenset( + [ + "Bastion", + "ConfigParser", + "Cookie", + "HTMLParser", + "SocketServer", + "StringIO", + "UserDict", + "UserList", + "_winreg", + "abc", + "aifc", + "anydbm", + "array", + "ast", + "atexit", + "audioop", + "base64", + "binascii", + "bisect", + "bz2", + "cPickle", + "cStringIO", + "chunk", + "cmath", + "cmd", + "codecs", + "codeop", + "collections", + "copy", + "copy_reg", + "csv", + "ctypes", + "datetime", + "difflib", + "dircache", + "dis", + "dumbdbm", + "dummy_thread", + "errno", + "filecmp", + "fnmatch", + "functools", + "future_builtins", + "gc", + "getopt", + "hashlib", + "heapq", + "hmac", + "htmlentitydefs", + "imghdr", + "imp", + "importlib", + "inspect", + "io", + "itertools", + "json", + "keyword", + "linecache", + "locale", + "logging", + "mailcap", + "marshal", + "math", + "mmap", + "msvcrt", + "mutex", + "netrc", + "new", + "nntplib", + "numbers", + "operator", + "os", + "parser", + "pkgutil", + "plistlib", + "poplib", + "pprint", + "quopri", + "random", + "re", + "repr", + "rfc822", + "rlcompleter", + "sched", + "select", + "sets", + "sgmllib", + "sha", + "shelve", + "shlex", + "shutil", + "signal", + "site", + "sndhdr", + "socket", + "sqlite3", + "stat", + "statvfs", + "string", + "stringprep", + "struct", + "subprocess", + "sunau", + "symbol", + "symtable", + "sysconfig", + "tempfile", + "textwrap", + "thread", + "time", + "token", + "tokenize", + "traceback", + "types", + "unicodedata", + "unittest", + "urlparse", + "uuid", + "warnings", + "weakref", + "webbrowser", + "winsound", + "wsgiref", + "xdrlib", + "xml", + "zipfile", + "zipimport", + "zlib", + ] +) + +NOT_SUPPORTED_MODULES: frozenset[str] = frozenset( + [ + "BaseHTTPServer", + "CGIHTTPServer", + "DocXMLRPCServer", + "Queue", + "ScrolledText", + "SimpleHTTPServer", + "SimpleXMLRPCServer", + "Tix", + "Tkinter", + "UserString", + "argparse", + "asynchat", + "asyncore", + "bdb", + "binhex", + "bsddb", + "calendar", + "cgi", + "cgitb", + "code", + "colorsys", + "compileall", + "compiler", + "contextlib", + "cookielib", + "dbhash", + "dbm", + "decimal", + "distutils", + "doctest", + "dummy_threading", + "email", + "ensurepip", + "fileinput", + "formatter", + "fractions", + "ftplib", + "getpass", + "gettext", + "glob", + "gzip", + "htmllib", + "httplib", + "imaplib", + "mailbox", + "mhlib", + "mimetools", + "mimetypes", + "mimify", + "modulefinder", + "msilib", + "multiprocessing", + "optparse", + "pdb", + "pickle", + "pickletools", + "platform", + "popen2", + "profile", + "py_compile", + "pyclbr", + "pydoc", + "robotparser", + "runpy", + "smtpd", + "smtplib", + "ssl", + "sys", + "tabnanny", + "tarfile", + "telnetlib", + "test", + "threading", + "timeit", + "trace", + "ttk", + "turtle", + "urllib", + "urllib2", + "uu", + "wave", + "whichdb", + "xmlrpclib", + ] +) diff --git a/src/designer_plugin/d3sdk/client.py b/src/designer_plugin/d3sdk/client.py index fcf8cc4..bdac3e7 100644 --- a/src/designer_plugin/d3sdk/client.py +++ b/src/designer_plugin/d3sdk/client.py @@ -13,9 +13,9 @@ from typing import Any, ParamSpec, TypeVar from designer_plugin.api import ( - d3_api_aplugin, + d3_api_aexecute, d3_api_aregister_module, - d3_api_plugin, + d3_api_execute, d3_api_register_module, ) from designer_plugin.d3sdk.ast_utils import ( @@ -85,7 +85,7 @@ def create_d3_plugin_method_wrapper( 2. Serializes the arguments using repr() 3. Builds a script string in the form: "return plugin.{method_name}({args})" 4. Creates a PluginPayload with the script and module information - 5. Sends it to Designer via d3_api_plugin or d3_api_aplugin + 5. Sends it to Designer via d3_api_execute or d3_api_aexecute 6. Returns the result from the remote execution Args: @@ -112,7 +112,7 @@ async def async_wrapper(self, *args, **kwargs): # type: ignore session_runtime_error_message(self.__class__.__name__) ) payload = build_payload(self, method_name, positional, keyword) - response: PluginResponse[T] = await d3_api_aplugin( + response: PluginResponse[T] = await d3_api_aexecute( self._hostname, self._port, payload ) return response.returnValue @@ -130,7 +130,7 @@ def sync_wrapper(self, *args, **kwargs): # type: ignore session_runtime_error_message(self.__class__.__name__) ) payload = build_payload(self, method_name, positional, keyword) - response: PluginResponse[T] = d3_api_plugin( + response: PluginResponse[T] = d3_api_execute( self._hostname, self._port, payload ) return response.returnValue diff --git a/src/designer_plugin/d3sdk/function.py b/src/designer_plugin/d3sdk/function.py index c417755..0418101 100644 --- a/src/designer_plugin/d3sdk/function.py +++ b/src/designer_plugin/d3sdk/function.py @@ -6,6 +6,7 @@ import ast import functools import inspect +import logging import textwrap from collections import defaultdict from collections.abc import Callable @@ -14,8 +15,9 @@ from pydantic import BaseModel, Field from designer_plugin.d3sdk.ast_utils import ( + PackageInfo, convert_function_to_py27, - find_packages_in_current_file, + find_imports_for_function, validate_and_bind_signature, validate_and_extract_args, ) @@ -24,6 +26,8 @@ RegisterPayload, ) +logger = logging.getLogger(__name__) + ############################################################################### # Plugin function related implementations @@ -48,6 +52,9 @@ class FunctionInfo(BaseModel): args: list[str] = Field( default=[], description="list of arguments from extracted function" ) + packages: list[PackageInfo] = Field( + default=[], description="list of packages/imports used by the function" + ) def extract_function_info(func: Callable[..., Any]) -> FunctionInfo: @@ -111,6 +118,8 @@ def extract_function_info(func: Callable[..., Any]) -> FunctionInfo: for stmt in body_nodes_py27: body_py27 += ast.unparse(stmt) + "\n" + packages = find_imports_for_function(func) + return FunctionInfo( source_code=source_code_py3, source_code_py27=source_code_py27, @@ -118,6 +127,7 @@ def extract_function_info(func: Callable[..., Any]) -> FunctionInfo: body=body.strip(), body_py27=body_py27.strip(), args=args, + packages=packages, ) @@ -253,8 +263,32 @@ def __init__(self, module_name: str, func: Callable[P, T]): super().__init__(func) + # Update the function in case the function was updated in the same session. + # For example, jupyter notebook server can be running, but function signature can + # change constantly. + is_replacement = self in D3Function._available_d3functions[module_name] + if is_replacement: + logger.debug( + "Function '%s' in module '%s' is being replaced.", + self.name, + module_name, + ) + D3Function._available_d3functions[module_name].discard(self) D3Function._available_d3functions[module_name].add(self) + if is_replacement: + # Full rebuild needed to evict stale imports from the replaced function. + D3Function._available_packages[module_name] = { + pkg.to_import_statement() + for f in D3Function._available_d3functions[module_name] + for pkg in f._function_info.packages + } + else: + # New function: incrementally add its packages. No stale imports to remove. + D3Function._available_packages[module_name].update( + pkg.to_import_statement() for pkg in self._function_info.packages + ) + def __eq__(self, other: object) -> bool: """Check equality based on function name for unique registration. @@ -298,7 +332,7 @@ def get_module_register_payload(module_name: str) -> RegisterPayload | None: return None contents_packages: str = "\n".join( - list(D3Function._available_packages[module_name]) + sorted(D3Function._available_packages[module_name]) ) contents_functions: str = "\n\n".join( [ @@ -447,35 +481,6 @@ def decorator(func: Callable[P, T]) -> D3Function[P, T]: return decorator -def add_packages_in_current_file(module_name: str) -> None: - """Add all import statements from the caller's file to a d3function module's package list. - - This function scans the calling file's import statements and registers them with - the specified module name, making those imports available when the module is - registered with Designer. This is useful for ensuring all dependencies are included - when deploying Python functions to Designer. - - Args: - module_name: The name of the d3function module to associate the packages with. - Must match the module_name used in @d3function decorator. - - Example: - ```python - import numpy as np - - @d3function("my_module") - def my_function(): - return np.array([1, 2, 3]) - - # Register all imports in the file (numpy) - add_packages_in_current_file("my_module") - ``` - """ - # caller_stack is 2, 1 for this, 1 for caller of this function. - packages: list[str] = find_packages_in_current_file(2) - D3Function._available_packages[module_name].update(packages) - - def get_register_payload(module_name: str) -> RegisterPayload | None: """Get the registration payload for a specific module. diff --git a/src/designer_plugin/d3sdk/session.py b/src/designer_plugin/d3sdk/session.py index 63992ab..b0eead5 100644 --- a/src/designer_plugin/d3sdk/session.py +++ b/src/designer_plugin/d3sdk/session.py @@ -9,10 +9,10 @@ from designer_plugin.api import ( Method, - d3_api_aplugin, + d3_api_aexecute, d3_api_aregister_module, d3_api_arequest, - d3_api_plugin, + d3_api_execute, d3_api_register_module, d3_api_request, ) @@ -29,17 +29,18 @@ class D3SessionBase: """Base class for Designer session management.""" - def __init__(self, hostname: str, port: int, context_modules: list[str]) -> None: + def __init__(self, hostname: str, port: int, context_modules: set[str]) -> None: """Initialize base session with connection details and module context. Args: hostname: The hostname of the Designer instance. port: The port number of the Designer instance. - context_modules: List of module names to register when entering session context. + context_modules: Set of module names to register when entering session context. """ self.hostname: str = hostname self.port: int = port - self.context_modules: list[str] = context_modules + self.context_modules: set[str] = context_modules + self.registered_modules: set[str] = set() class D3Session(D3SessionBase): @@ -53,16 +54,16 @@ def __init__( self, hostname: str, port: int = D3_PLUGIN_DEFAULT_PORT, - context_modules: list[str] | None = None, + context_modules: set[str] | None = None, ) -> None: """Initialize synchronous Designer session. Args: hostname: The hostname of the Designer instance. port: The port number of the Designer instance. - context_modules: Optional list of module names to register when entering session context. + context_modules: Optional set of module names to register when entering session context. """ - super().__init__(hostname, port, context_modules or []) + super().__init__(hostname, port, context_modules or set()) def __enter__(self) -> "D3Session": """Enter context manager and register all context modules. @@ -117,7 +118,10 @@ def execute( Raises: PluginException: If the plugin execution fails. """ - return d3_api_plugin(self.hostname, self.port, payload, timeout_sec) + if payload.moduleName and payload.moduleName not in self.registered_modules: + self.register_module(payload.moduleName) + + return d3_api_execute(self.hostname, self.port, payload, timeout_sec) def request(self, method: Method, url_endpoint: str, **kwargs: Any) -> Any: """Make a generic HTTP request to Designer API. @@ -152,6 +156,7 @@ def register_module( ) if payload: d3_api_register_module(self.hostname, self.port, payload, timeout_sec) + self.registered_modules.add(module_name) return True return False @@ -186,16 +191,16 @@ def __init__( self, hostname: str, port: int = D3_PLUGIN_DEFAULT_PORT, - context_modules: list[str] | None = None, + context_modules: set[str] | None = None, ) -> None: """Initialize asynchronous Designer session. Args: hostname: The hostname of the Designer instance. port: The port number of the Designer instance. - context_modules: Optional list of module names to register when entering session context. + context_modules: Optional set of module names to register when entering session context. """ - super().__init__(hostname, port, context_modules or []) + super().__init__(hostname, port, context_modules or set()) async def __aenter__(self) -> "D3AsyncSession": """Enter async context manager and register all context modules. @@ -270,7 +275,10 @@ async def execute( Raises: PluginException: If the plugin execution fails. """ - return await d3_api_aplugin(self.hostname, self.port, payload, timeout_sec) + if payload.moduleName and payload.moduleName not in self.registered_modules: + await self.register_module(payload.moduleName) + + return await d3_api_aexecute(self.hostname, self.port, payload, timeout_sec) async def register_module( self, module_name: str, timeout_sec: float | None = None @@ -294,6 +302,7 @@ async def register_module( await d3_api_aregister_module( self.hostname, self.port, payload, timeout_sec ) + self.registered_modules.add(module_name) return True return False diff --git a/tests/test_ast_utils.py b/tests/test_ast_utils.py index 599edf9..ee7306d 100644 --- a/tests/test_ast_utils.py +++ b/tests/test_ast_utils.py @@ -5,18 +5,22 @@ import ast import inspect +import logging.handlers import textwrap import types +from os.path import join as path_join import pytest from designer_plugin.d3sdk.ast_utils import ( ConvertToPython27, + ImportAlias, + PackageInfo, convert_class_to_py27, convert_function_to_py27, filter_base_classes, filter_init_args, - find_packages_in_current_file, + find_imports_for_function, get_class_node, get_source, ) @@ -890,64 +894,6 @@ def __init__(self): assert param_names == [] -class TestFindPackagesInCurrentFile: - """Tests for find_packages_in_current_file function.""" - - def test_finds_imports_from_current_file(self): - """Test that the function finds import statements from the calling file.""" - # This test file has imports at the top - they should be found - imports = find_packages_in_current_file() - - # Should find at least some of our imports - assert isinstance(imports, list) - assert len(imports) > 0 - - # Should be sorted - assert imports == sorted(imports) - - # Check for specific imports we know exist in this file - assert "import ast" in imports - assert "import pytest" in imports - assert "import textwrap" in imports - - def test_excludes_typing_imports(self): - """Test that typing module imports are excluded.""" - # Since this file doesn't import typing, we can't directly test exclusion here - # But we can verify the function doesn't crash and returns valid results - imports = find_packages_in_current_file() - - # Verify no typing imports are present - typing_imports = [imp for imp in imports if "typing" in imp] - assert len(typing_imports) == 0 - - def test_excludes_d3blobgen_imports(self): - """Test that d3blobgen package imports are excluded.""" - imports = find_packages_in_current_file() - - # Verify no d3blobgen imports are present - d3blobgen_imports = [imp for imp in imports if "d3blobgen" in imp] - assert len(d3blobgen_imports) == 0 - - def test_excludes_find_packages_function_itself(self): - """Test that the function itself is excluded from imports.""" - imports = find_packages_in_current_file() - - # Should not include import of find_packages_in_current_file itself - # even though we import it at the top of this file - function_imports = [imp for imp in imports if "find_packages_in_current_file" in imp] - assert len(function_imports) == 0 - - def test_returns_unique_sorted_imports(self): - """Test that returned imports are unique and sorted.""" - imports = find_packages_in_current_file() - - # Check uniqueness - assert len(imports) == len(set(imports)) - - # Check sorting - assert imports == sorted(imports) - - class TestDecoratorHandling: """Tests for handling decorators in AST transformations.""" @@ -1127,5 +1073,152 @@ def my_function(x, y): assert len(func.body) == 3 # Two assignments and one return +class TestPackageInfo: + """Tests for PackageInfo and ImportAlias models.""" + + def test_import_package_only(self): + """import numpy""" + pkg = PackageInfo(package="numpy") + assert pkg.to_import_statement() == "import numpy" + + def test_import_package_with_alias(self): + """import numpy as np""" + pkg = PackageInfo(package="numpy", alias="np") + assert pkg.to_import_statement() == "import numpy as np" + + def test_from_import_single_method(self): + """from pathlib import Path""" + pkg = PackageInfo( + package="pathlib", + methods=[ImportAlias(name="Path")], + ) + assert pkg.to_import_statement() == "from pathlib import Path" + + def test_from_import_multiple_methods(self): + """from os.path import join, exists""" + pkg = PackageInfo( + package="os.path", + methods=[ + ImportAlias(name="join"), + ImportAlias(name="exists"), + ], + ) + assert pkg.to_import_statement() == "from os.path import join, exists" + + def test_from_import_method_with_alias(self): + """from collections import defaultdict as dd""" + pkg = PackageInfo( + package="collections", + methods=[ImportAlias(name="defaultdict", asname="dd")], + ) + assert pkg.to_import_statement() == "from collections import defaultdict as dd" + + def test_from_import_mixed_aliases(self): + """from collections import OrderedDict, defaultdict as dd""" + pkg = PackageInfo( + package="collections", + methods=[ + ImportAlias(name="OrderedDict"), + ImportAlias(name="defaultdict", asname="dd"), + ], + ) + result = pkg.to_import_statement() + assert result == "from collections import OrderedDict, defaultdict as dd" + + +class TestFindImportsForFunction: + """Tests for find_imports_for_function.""" + + def test_finds_used_import(self): + """Function using ast should get 'import ast' extracted.""" + # This function uses ast.parse which is from 'import ast' at file top + def uses_ast(): + return ast.parse("x = 1") + + packages = find_imports_for_function(uses_ast) + statements = [p.to_import_statement() for p in packages] + assert "import ast" in statements + + def test_excludes_unused_import(self): + """Function not using a module should not include it.""" + def uses_nothing(): + return 42 + + packages = find_imports_for_function(uses_nothing) + statements = [p.to_import_statement() for p in packages] + # Should not include ast, textwrap, etc. since they're not used + assert "import types" not in statements + + def test_finds_from_import(self): + """Function using a 'from X import Y' name should include it.""" + def uses_textwrap(): + return textwrap.dedent(" hello") + + packages = find_imports_for_function(uses_textwrap) + statements = [p.to_import_statement() for p in packages] + assert "import textwrap" in statements + + def test_returns_package_info_objects(self): + """Return type should be list of PackageInfo.""" + def simple_func(): + return ast.dump(ast.parse("1")) + + packages = find_imports_for_function(simple_func) + assert all(isinstance(p, PackageInfo) for p in packages) + + def test_sorted_output(self): + """Output should be sorted by import statement.""" + def uses_multiple(): + _ = textwrap.dedent("x") + _ = ast.parse("y") + return inspect.getsource(uses_multiple) + + packages = find_imports_for_function(uses_multiple) + statements = [p.to_import_statement() for p in packages] + assert statements == sorted(statements) + + def test_excludes_typing_imports(self): + """Typing imports should be excluded.""" + # The 'Any' import from typing at the file top should never appear + def uses_nothing(): + return 1 + + packages = find_imports_for_function(uses_nothing) + statements = [p.to_import_statement() for p in packages] + typing_imports = [s for s in statements if "typing" in s] + assert len(typing_imports) == 0 + + def test_finds_submodule_import(self): + """from os.path import join (sub-module) should be detected.""" + + def uses_path_join(): + return path_join("a", "b") + + packages = find_imports_for_function(uses_path_join) + statements = [p.to_import_statement() for p in packages] + assert "from os.path import join as path_join" in statements + + def test_no_source_module_returns_empty(self): + """Function whose module source is unavailable should return empty list.""" + # Simulate a function from an unsourceable module (like Jupyter __main__) + def dummy(): + return 1 + + # Patch __module__ to a non-existent module + dummy.__module__ = "_nonexistent_module_for_test" + packages = find_imports_for_function(dummy) + assert packages == [] + + def test_dotted_import_effective_name(self): + """import logging.handlers binds 'logging' — should match usage of logging.handlers.""" + + def uses_logging_handlers(): + return logging.handlers.RotatingFileHandler("/tmp/x") + + packages = find_imports_for_function(uses_logging_handlers) + statements = [p.to_import_statement() for p in packages] + assert "import logging.handlers" in statements + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/test_client.py b/tests/test_client.py index b7c73f1..3e6442c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -71,7 +71,7 @@ def test_method_call_without_session_raises_error(self, plugin): def test_correct_arguments_sync(self, plugin, mock_response): """Test that correct arguments pass through successfully.""" - with patch('designer_plugin.d3sdk.client.d3_api_plugin', return_value=mock_response) as mock_api: + with patch('designer_plugin.d3sdk.client.d3_api_execute', return_value=mock_response) as mock_api: plugin._hostname = "localhost" plugin._port = 80 @@ -114,7 +114,7 @@ def test_unexpected_keyword_argument(self, plugin): def test_method_with_defaults_partial_args(self, plugin, mock_response): """Test method with default parameters using partial arguments.""" - with patch('designer_plugin.d3sdk.client.d3_api_plugin', return_value=mock_response): + with patch('designer_plugin.d3sdk.client.d3_api_execute', return_value=mock_response): plugin._hostname = "localhost" plugin._port = 80 @@ -124,7 +124,7 @@ def test_method_with_defaults_partial_args(self, plugin, mock_response): def test_method_with_defaults_override(self, plugin, mock_response): """Test method with default parameters overriding defaults.""" - with patch('designer_plugin.d3sdk.client.d3_api_plugin', return_value=mock_response): + with patch('designer_plugin.d3sdk.client.d3_api_execute', return_value=mock_response): plugin._hostname = "localhost" plugin._port = 80 @@ -134,7 +134,7 @@ def test_method_with_defaults_override(self, plugin, mock_response): def test_method_with_defaults_keyword(self, plugin, mock_response): """Test method with default parameters using keyword arguments.""" - with patch('designer_plugin.d3sdk.client.d3_api_plugin', return_value=mock_response): + with patch('designer_plugin.d3sdk.client.d3_api_execute', return_value=mock_response): plugin._hostname = "localhost" plugin._port = 80 @@ -144,7 +144,7 @@ def test_method_with_defaults_keyword(self, plugin, mock_response): def test_keyword_only_parameters(self, plugin, mock_response): """Test method with keyword-only parameters.""" - with patch('designer_plugin.d3sdk.client.d3_api_plugin', return_value=mock_response): + with patch('designer_plugin.d3sdk.client.d3_api_execute', return_value=mock_response): plugin._hostname = "localhost" plugin._port = 80 @@ -162,7 +162,7 @@ def test_keyword_only_parameters_as_positional_fails(self, plugin): def test_mixed_parameters(self, plugin, mock_response): """Test method with mixed parameter types.""" - with patch('designer_plugin.d3sdk.client.d3_api_plugin', return_value=mock_response): + with patch('designer_plugin.d3sdk.client.d3_api_execute', return_value=mock_response): plugin._hostname = "localhost" plugin._port = 80 diff --git a/tests/test_core.py b/tests/test_core.py index e71ea9d..2dc170f 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -3,6 +3,9 @@ Copyright (c) 2025 Disguise Technologies ltd """ +import logging +import math + import pytest from designer_plugin.d3sdk.function import ( @@ -329,6 +332,61 @@ def test_inequality_different_functions(self): +class TestD3FunctionReplacement: + """Test that re-registering a D3Function with the same name replaces the old one.""" + + def test_reregister_replaces_function(self): + """Re-registering a function with the same name should replace it in the set.""" + module = "test_replace_module" + D3Function._available_d3functions[module].clear() + + @d3function(module) + def my_func(a: int) -> int: + return a + + @d3function(module) + def my_func(a: int, b: int) -> int: # noqa: F811 + return a + b + + funcs = D3Function._available_d3functions[module] + matching = [f for f in funcs if f.name == "my_func"] + assert len(matching) == 1 + assert matching[0].function_info.args == ["a", "b"] + + def test_reregister_logs_debug(self, caplog): + """Re-registering should log a debug message.""" + module = "test_replace_warn_module" + D3Function._available_d3functions[module].clear() + + @d3function(module) + def warn_func() -> None: + pass + + with caplog.at_level(logging.DEBUG, logger="designer_plugin.d3sdk.function"): + @d3function(module) + def warn_func() -> int: # noqa: F811 + return 1 + + assert any("warn_func" in msg and "being replaced" in msg for msg in caplog.messages) + + def test_set_size_unchanged_after_replacement(self): + """The function set size should stay the same after replacement.""" + module = "test_replace_size_module" + D3Function._available_d3functions[module].clear() + + @d3function(module) + def size_func(x: int) -> int: + return x + + assert len(D3Function._available_d3functions[module]) == 1 + + @d3function(module) + def size_func(x: int, y: int) -> int: # noqa: F811 + return x + y + + assert len(D3Function._available_d3functions[module]) == 1 + + class TestD3PythonScript: def test_d3pythonscript_decorator(self): @d3pythonscript @@ -408,3 +466,92 @@ def test_func(a: int, b: int) -> int: with pytest.raises(TypeError, match="multiple values for argument"): test_func.payload(1, a=2) + + +class TestAutoPackageRegistration: + """Test that @d3function auto-registers imports used by the function.""" + + def test_extract_function_info_populates_packages(self): + """extract_function_info should populate the packages field.""" + def func_using_logging(): + return logging.getLogger("test") + + info = extract_function_info(func_using_logging) + statements = [p.to_import_statement() for p in info.packages] + assert "import logging" in statements + + def test_extract_function_info_packages_default_empty_for_no_imports(self): + """Function using no imports should have empty packages.""" + def func_no_imports(): + return 42 + + info = extract_function_info(func_no_imports) + assert info.packages == [] + + def test_d3function_auto_registers_packages(self): + """D3Function should auto-register packages.""" + module = "test_auto_pkg_module" + D3Function._available_d3functions[module].clear() + D3Function._available_packages[module].clear() + + @d3function(module) + def func_using_logging(): + return logging.getLogger("test") + + # Packages should be auto-registered + assert "import logging" in D3Function._available_packages[module] + + def test_d3function_register_payload_includes_auto_packages(self): + """get_register_payload should include auto-extracted imports.""" + module = "test_auto_payload_module" + D3Function._available_d3functions[module].clear() + D3Function._available_packages[module].clear() + + @d3function(module) + def func_using_logging(): + return logging.getLogger("test") + + payload = get_register_payload(module) + assert payload is not None + assert "import logging" in payload.contents + + def test_new_functions_accumulate_packages(self): + """Adding a second function should add its packages without losing the first's.""" + module = "test_accumulate_pkg_module" + D3Function._available_d3functions[module].clear() + D3Function._available_packages[module].clear() + + @d3function(module) + def func_a(): + return logging.getLogger("a") + + assert "import logging" in D3Function._available_packages[module] + + @d3function(module) + def func_b(): + return math.sqrt(4) + + # Both packages must be present after adding func_b + assert "import logging" in D3Function._available_packages[module] + assert "import math" in D3Function._available_packages[module] + + def test_replacement_removes_stale_packages(self): + """Replacing a function with one that uses fewer imports should evict stale packages.""" + module = "test_stale_pkg_module" + D3Function._available_d3functions[module].clear() + D3Function._available_packages[module].clear() + + @d3function(module) + def my_func(): # uses logging + return logging.getLogger("x") + + assert "import logging" in D3Function._available_packages[module] + + @d3function(module) + def my_func() -> int: # noqa: F811 # no longer uses logging + return 42 + + # Stale import from the old version must be gone + assert "import logging" not in D3Function._available_packages[module] + + diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 0000000..1715571 --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,172 @@ +""" +MIT License +Copyright (c) 2025 Disguise Technologies ltd +""" + +import asyncio +from unittest.mock import AsyncMock, patch + +from designer_plugin.d3sdk.function import d3function +from designer_plugin.d3sdk.session import D3AsyncSession, D3Session +from designer_plugin.models import PluginPayload, PluginResponse, PluginStatus + + +# Register a module so D3Function._available_d3functions knows about it. +@d3function("lazy_test_module") +def _lazy_test_fn() -> str: + return "hello world" + + +def _make_response() -> PluginResponse: + return PluginResponse( + status=PluginStatus(code=0, message="OK", details=[]), + returnValue=None, + ) + + +def _module_payload() -> PluginPayload: + """Payload that references a registered @d3function module.""" + return PluginPayload(moduleName="lazy_test_module", script="return _lazy_test_fn()") + + +def _script_payload() -> PluginPayload: + """Payload with no module (equivalent to @d3pythonscript).""" + return PluginPayload(moduleName=None, script="return 42") + + +class TestD3SessionLazyRegistration: + """Lazy registration behaviour for the synchronous D3Session.""" + + def test_registered_modules_starts_empty(self): + session = D3Session("localhost", 80) + assert session.registered_modules == set() + + def test_module_registered_on_first_execute(self): + session = D3Session("localhost", 80) + with ( + patch("designer_plugin.d3sdk.session.d3_api_register_module") as mock_reg, + patch("designer_plugin.d3sdk.session.d3_api_execute", return_value=_make_response()), + ): + session.execute(_module_payload()) + mock_reg.assert_called_once() + assert "lazy_test_module" in session.registered_modules + + def test_module_not_re_registered_on_second_execute(self): + session = D3Session("localhost", 80) + with ( + patch("designer_plugin.d3sdk.session.d3_api_register_module") as mock_reg, + patch("designer_plugin.d3sdk.session.d3_api_execute", return_value=_make_response()), + ): + session.execute(_module_payload()) + session.execute(_module_payload()) + mock_reg.assert_called_once() + + def test_no_registration_for_script_payload(self): + """Payloads without a moduleName must never trigger registration.""" + session = D3Session("localhost", 80) + with ( + patch("designer_plugin.d3sdk.session.d3_api_register_module") as mock_reg, + patch("designer_plugin.d3sdk.session.d3_api_execute", return_value=_make_response()), + ): + session.execute(_script_payload()) + mock_reg.assert_not_called() + + def test_context_module_not_re_registered_lazily(self): + """A module pre-registered via context_modules must not be registered again in execute().""" + with ( + patch("designer_plugin.d3sdk.session.d3_api_register_module") as mock_reg, + patch("designer_plugin.d3sdk.session.d3_api_execute", return_value=_make_response()), + ): + with D3Session("localhost", 80, {"lazy_test_module"}) as session: + assert "lazy_test_module" in session.registered_modules + session.execute(_module_payload()) + mock_reg.assert_called_once() # only from __enter__, not from execute() + + def test_registered_modules_updated_after_execute(self): + session = D3Session("localhost", 80) + assert "lazy_test_module" not in session.registered_modules + with ( + patch("designer_plugin.d3sdk.session.d3_api_register_module"), + patch("designer_plugin.d3sdk.session.d3_api_execute", return_value=_make_response()), + ): + session.execute(_module_payload()) + assert "lazy_test_module" in session.registered_modules + + +class TestD3AsyncSessionLazyRegistration: + """Lazy registration behaviour for the asynchronous D3AsyncSession.""" + + def test_registered_modules_starts_empty(self): + session = D3AsyncSession("localhost", 80) + assert session.registered_modules == set() + + def test_module_registered_on_first_execute(self): + async def run(): + session = D3AsyncSession("localhost", 80) + with ( + patch("designer_plugin.d3sdk.session.d3_api_aregister_module", new_callable=AsyncMock) as mock_reg, + patch("designer_plugin.d3sdk.session.d3_api_aexecute", new_callable=AsyncMock) as mock_exec, + ): + mock_exec.return_value = _make_response() + await session.execute(_module_payload()) + mock_reg.assert_called_once() + assert "lazy_test_module" in session.registered_modules + + asyncio.run(run()) + + def test_module_not_re_registered_on_second_execute(self): + async def run(): + session = D3AsyncSession("localhost", 80) + with ( + patch("designer_plugin.d3sdk.session.d3_api_aregister_module", new_callable=AsyncMock) as mock_reg, + patch("designer_plugin.d3sdk.session.d3_api_aexecute", new_callable=AsyncMock) as mock_exec, + ): + mock_exec.return_value = _make_response() + await session.execute(_module_payload()) + await session.execute(_module_payload()) + mock_reg.assert_called_once() + + asyncio.run(run()) + + def test_no_registration_for_script_payload(self): + """Payloads without a moduleName must never trigger registration.""" + async def run(): + session = D3AsyncSession("localhost", 80) + with ( + patch("designer_plugin.d3sdk.session.d3_api_aregister_module", new_callable=AsyncMock) as mock_reg, + patch("designer_plugin.d3sdk.session.d3_api_aexecute", new_callable=AsyncMock) as mock_exec, + ): + mock_exec.return_value = _make_response() + await session.execute(_script_payload()) + mock_reg.assert_not_called() + + asyncio.run(run()) + + def test_context_module_not_re_registered_lazily(self): + """A module pre-registered via context_modules must not be registered again in execute().""" + async def run(): + with ( + patch("designer_plugin.d3sdk.session.d3_api_aregister_module", new_callable=AsyncMock) as mock_reg, + patch("designer_plugin.d3sdk.session.d3_api_aexecute", new_callable=AsyncMock) as mock_exec, + ): + mock_exec.return_value = _make_response() + async with D3AsyncSession("localhost", 80, {"lazy_test_module"}) as session: + assert "lazy_test_module" in session.registered_modules + await session.execute(_module_payload()) + mock_reg.assert_called_once() + + asyncio.run(run()) + + def test_registered_modules_updated_after_execute(self): + async def run(): + session = D3AsyncSession("localhost", 80) + assert "lazy_test_module" not in session.registered_modules + with ( + patch("designer_plugin.d3sdk.session.d3_api_aregister_module", new_callable=AsyncMock), + patch("designer_plugin.d3sdk.session.d3_api_aexecute", new_callable=AsyncMock) as mock_exec, + ): + mock_exec.return_value = _make_response() + await session.execute(_module_payload()) + assert "lazy_test_module" in session.registered_modules + + asyncio.run(run()) diff --git a/tests/test_supported_modules.py b/tests/test_supported_modules.py new file mode 100644 index 0000000..aa73dab --- /dev/null +++ b/tests/test_supported_modules.py @@ -0,0 +1,56 @@ +import asyncio + +import pytest + +from designer_plugin.d3sdk import D3AsyncSession, d3function +from designer_plugin.d3sdk.builtin_modules import NOT_SUPPORTED_MODULES, SUPPORTED_MODULES + + +@d3function('test_supported_modules') +def check_import(module_str) -> bool: + try: + module = __import__(module_str) + return True + except ImportError as e: + return False + + +class TestSupportedModules: + """ + Test if supported and not supported modules are handled properly on Designer side. + This is integration test so Designer must be running to pass the test. + """ + + @pytest.mark.integration + def test_supported_modules(self): + """Test if all supported modules are able to be imported on Designer side.""" + + async def run(): + failed = [] + async with D3AsyncSession("localhost", 80) as session: + for module_str in SUPPORTED_MODULES: + import_success: bool = await session.rpc( + check_import.payload(module_str) + ) + if not import_success: + failed.append(module_str) + assert not failed, f"Failed to import: {failed}" + + asyncio.run(run()) + + @pytest.mark.integration + def test_not_supported_modules(self): + """Test if all not supported modules are not importable on Designer side.""" + + async def run(): + failed = [] + async with D3AsyncSession("localhost", 80) as session: + for module_str in NOT_SUPPORTED_MODULES: + import_success: bool = await session.rpc( + check_import.payload(module_str) + ) + if import_success: + failed.append(module_str) + assert not failed, f"Unexpectedly imported: {failed}" + + asyncio.run(run()) diff --git a/uv.lock b/uv.lock index c69dd25..be4c221 100644 --- a/uv.lock +++ b/uv.lock @@ -348,7 +348,7 @@ wheels = [ [[package]] name = "designer-plugin" -version = "1.2.1" +version = "1.3.0" source = { editable = "." } dependencies = [ { name = "aiohttp" },