|
4 | 4 | """ |
5 | 5 |
|
6 | 6 | import ast |
| 7 | +import functools |
7 | 8 | import inspect |
| 9 | +import logging |
8 | 10 | import textwrap |
9 | 11 | import types |
| 12 | +from collections.abc import Callable |
10 | 13 | from typing import Any |
11 | 14 |
|
| 15 | +from pydantic import BaseModel, Field |
| 16 | + |
| 17 | +from designer_plugin.d3sdk.builtin_modules import SUPPORTED_MODULES |
| 18 | + |
| 19 | +logger = logging.getLogger(__name__) |
| 20 | + |
| 21 | + |
| 22 | +############################################################################### |
| 23 | +# Package info models |
class ImportAlias(BaseModel):
    """Represents a single imported name with an optional alias.

    Mirrors the structure of ast.alias for Pydantic compatibility.
    """

    # The name exactly as it is written in the import statement.
    name: str = Field(
        description="The imported name (e.g., 'Path' in 'from pathlib import Path')",
    )
    # Local binding introduced by an ``as`` clause; defaults to None (no alias).
    asname: str | None = Field(
        description="The alias (e.g., 'np' in 'import numpy as np')",
        default=None,
    )
| 37 | + |
| 38 | + |
class PackageInfo(BaseModel):
    """Structured representation of a Python import statement.

    Rendering rules (via to_import_statement using ast.unparse):
    - package only → import package
    - package + alias → import package as alias
    - package + methods → from package import method1, method2
    - package + methods w/alias → from package import method1 as alias1
    """

    # Dotted module path that appears after ``import`` / ``from``.
    package: str = Field(description="The module/package name to import")
    # Only consulted when ``methods`` is empty (plain ``import`` form).
    alias: str | None = Field(
        description="Alias for the package (e.g., 'np' in 'import numpy as np')",
        default=None,
    )
    # A non-empty list switches rendering to the ``from ... import`` form.
    methods: list[ImportAlias] = Field(
        description="Imported names for 'from X import ...' style imports",
        default_factory=list,
    )

    def to_import_statement(self) -> str:
        """Return this record rendered as one line of Python import source.

        With no ``methods`` the result is ``import package`` (plus ``as alias``
        when ``alias`` is set). Otherwise it is an absolute (``level=0``)
        ``from package import ...`` listing every entry of ``methods``; the
        package-level ``alias`` is not used in that form.
        """
        if not self.methods:
            plain_import = ast.Import(
                names=[ast.alias(name=self.package, asname=self.alias)]
            )
            return ast.unparse(plain_import)

        imported_names = [
            ast.alias(name=entry.name, asname=entry.asname) for entry in self.methods
        ]
        from_import = ast.ImportFrom(
            module=self.package,
            names=imported_names,
            level=0,
        )
        return ast.unparse(from_import)
| 71 | + |
12 | 72 |
|
13 | 73 | ############################################################################### |
14 | 74 | # Source code extraction utilities |
@@ -369,94 +429,157 @@ def validate_and_extract_args( |
369 | 429 |
|
370 | 430 |
|
371 | 431 | ############################################################################### |
372 | | -# Python package finder utility |
373 | | -def find_packages_in_current_file(caller_stack: int = 1) -> list[str]: |
374 | | - """Find all import statements in the caller's file by inspecting the call stack. |
| 432 | +# Function-scoped import extraction utility |
| 433 | +def _collect_used_names(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]: |
| 434 | + """Collect all identifier names used inside a function body. |
| 435 | +
|
| 436 | + Walks the function's AST body and extracts: |
| 437 | + - Simple names (ast.Name nodes, e.g., ``foo`` in ``foo()``) |
| 438 | + - Root names of attribute chains (e.g., ``np`` in ``np.array()``) |
| 439 | +
|
| 440 | + Args: |
| 441 | + func_node: The function AST node to analyse. |
| 442 | +
|
| 443 | + Returns: |
| 444 | + Set of identifier strings used in the function body. |
| 445 | + """ |
| 446 | + names: set[str] = set() |
| 447 | + for node in ast.walk(func_node): |
| 448 | + if isinstance(node, ast.Name): |
| 449 | + names.add(node.id) |
| 450 | + elif isinstance(node, ast.Attribute): |
| 451 | + # Walk down the attribute chain to find the root name |
| 452 | + root: ast.expr = node |
| 453 | + while isinstance(root, ast.Attribute): |
| 454 | + root = root.value |
| 455 | + if isinstance(root, ast.Name): |
| 456 | + names.add(root.id) |
| 457 | + return names |
| 458 | + |
| 459 | + |
def _is_type_checking_block(node: ast.If) -> bool:
    """Tell whether *node* is an ``if TYPE_CHECKING:`` guard.

    Two spellings of the condition are recognised: the bare name
    ``TYPE_CHECKING`` and the qualified ``typing.TYPE_CHECKING``. Any other
    test expression (including the attribute on a differently named object)
    yields False.
    """
    condition = node.test

    # Bare form: `if TYPE_CHECKING:`
    if isinstance(condition, ast.Name):
        return condition.id == "TYPE_CHECKING"

    if not isinstance(condition, ast.Attribute):
        return False

    # Qualified form: `if typing.TYPE_CHECKING:`
    owner = condition.value
    return (
        isinstance(owner, ast.Name)
        and owner.id == "typing"
        and condition.attr == "TYPE_CHECKING"
    )
| 472 | + |
| 473 | + |
def _is_supported_module(module_name: str) -> bool:
    """Return whether the first dotted component of *module_name* is supported.

    Only the part before the first ``.`` is looked up in ``SUPPORTED_MODULES``,
    so ``"logging.handlers"`` is judged by ``"logging"``.
    """
    root_package, _, _ = module_name.partition(".")
    return root_package in SUPPORTED_MODULES
| 478 | + |
375 | 479 |
|
376 | | - This function walks up the call stack to find the module where it was called from, |
377 | | - then parses that module's source code to extract all import statements that are |
378 | | - compatible with Python 2.7 and safe to send to Designer. |
| 480 | +@functools.lru_cache(maxsize=128) |
| 481 | +def _get_module_ast(module: types.ModuleType) -> ast.Module | None: |
| 482 | + """Return the parsed AST for *module*, cached by module identity.""" |
| 483 | + try: |
| 484 | + return ast.parse(inspect.getsource(module)) |
| 485 | + except (OSError, TypeError): |
| 486 | + return None |
| 487 | + |
| 488 | + |
def find_imports_for_function(func: Callable[..., Any]) -> list[PackageInfo]:
    """Extract import statements used by a function from its source file.

    Inspects the module containing *func*, parses its top-level imports, then
    filters them down to only those whose imported names are actually
    referenced inside the function.

    Args:
        func: The callable to analyse.

    Returns:
        List of :class:`PackageInfo` objects, sorted by their rendered import
        statement and free of duplicates. Empty when the module or function
        source cannot be obtained or parsed, or when *func*'s source does not
        start with a (possibly async) function definition.

    Filters applied:
        - Excludes imports inside ``if TYPE_CHECKING:`` blocks (only direct
          children of the module body are examined at all)
        - Excludes relative imports (``from . import x``, ``from .m import x``),
          which cannot be expressed as an absolute import statement
        - Only includes imports from Designer-supported builtin modules
          (see ``SUPPORTED_MODULES`` in ``builtin_modules.py``)
        - Only includes imports whose bound names are used in the function
    """
    # Not every callable carries __qualname__ (e.g. functools.partial objects),
    # so fall back to repr() for log messages.
    func_label = getattr(func, "__qualname__", repr(func))

    # --- 1. Get the function's module source ---
    module = inspect.getmodule(func)
    if not module:
        return []

    module_tree = _get_module_ast(module)
    if module_tree is None:
        logger.warning(
            "Cannot detect file-level imports for '%s': module source unavailable "
            "(e.g. Jupyter notebook). Place imports inside the function body instead.",
            func_label,
        )
        return []

    # --- 2. Collect names used inside the function ---
    # The function's own source can be unavailable (OSError/TypeError) or fail
    # to parse once cut out of its file (SyntaxError, e.g. when dedent cannot
    # remove the indentation); degrade the same way as for a missing module.
    try:
        func_tree = ast.parse(textwrap.dedent(inspect.getsource(func)))
    except (OSError, TypeError, SyntaxError):
        logger.warning(
            "Cannot detect file-level imports for '%s': function source "
            "unavailable or not parseable in isolation.",
            func_label,
        )
        return []

    if not func_tree.body:
        return []

    func_node = func_tree.body[0]
    if not isinstance(func_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
        return []

    used_names = _collect_used_names(func_node)

    # --- 3. Parse file-level imports and filter to used ones ---
    # Keyed by rendered statement so that an import repeated in the module
    # is reported only once.
    packages: dict[str, PackageInfo] = {}
    for node in module_tree.body:
        # Skip TYPE_CHECKING blocks
        if isinstance(node, ast.If) and _is_type_checking_block(node):
            continue

        if isinstance(node, ast.Import):
            for alias in node.names:
                if not _is_supported_module(alias.name):
                    continue

                # The name used in code is the alias if present, otherwise the
                # top-level package name (e.g. "import logging.handlers" binds
                # "logging", not "logging.handlers").
                effective_name = alias.asname or alias.name.split(".")[0]
                if effective_name not in used_names:
                    continue

                info = PackageInfo(package=alias.name, alias=alias.asname)
                packages.setdefault(info.to_import_statement(), info)

        elif isinstance(node, ast.ImportFrom):
            # level > 0 marks a relative import: its module name is relative to
            # the current package and must not be matched against (or rendered
            # as) an absolute module.
            if node.level or not node.module:
                continue
            if not _is_supported_module(node.module):
                continue

            # Filter to only methods actually used by the function
            matched_methods = [
                ImportAlias(name=alias.name, asname=alias.asname)
                for alias in node.names
                if (alias.asname or alias.name) in used_names
            ]
            if not matched_methods:
                continue

            info = PackageInfo(package=node.module, methods=matched_methods)
            packages.setdefault(info.to_import_statement(), info)

    # Sort by import statement string for deterministic output
    return [packages[statement] for statement in sorted(packages)]
0 commit comments