Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 90 additions & 36 deletions src/aignostics/utils/_di.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import importlib
import pkgutil
from collections.abc import Callable
from functools import lru_cache
from importlib.metadata import entry_points
from inspect import isclass
Expand Down Expand Up @@ -41,11 +42,81 @@ def load_modules() -> None:
importlib.import_module(f"{__project_name__}.{name}")


def _scan_packages_deep(package_name: str, predicate: Callable[[object], bool]) -> list[Any]:
"""
Deep-scan a package by walking all top-level submodules via pkgutil.iter_modules.

Discovers objects by importing the package, iterating through all submodules,
and checking each module's members against the predicate. Used for the main
aignostics package to ensure all registered implementations are found.

Example:
>>> from inspect import isclass
>>> _scan_packages_deep("aignostics", lambda m: isclass(m))

Args:
package_name (str): Name of the package to deep-scan.
predicate (Callable[[object], bool]): Function to filter members.

Returns:
list[Any]: List of members matching the predicate.
"""
results: list[Any] = []
try:
package = importlib.import_module(package_name)
except ImportError:
return results

for _, name, _ in pkgutil.iter_modules(package.__path__):
try:
module = importlib.import_module(f"{package_name}.{name}")
for member_name in dir(module):
member = getattr(module, member_name)
if predicate(member):
results.append(member)
except ImportError:
continue

return results


def _scan_packages_shallow(package_names: tuple[str, ...], predicate: Callable[[object], bool]) -> list[Any]:
"""
Shallow-scan plugin packages by checking only top-level exports.

Discovers objects by importing each package and checking its top-level members
(i.e. what is exported from __init__.py via dir(package)) against the predicate.
Does NOT walk submodules via pkgutil.iter_modules. This prevents over-discovering
objects from plugin submodules that happen to be imported internally.

Args:
package_names (tuple[str, ...]): Names of the plugin packages to shallow-scan.
predicate (Callable[[object], bool]): Function to filter members.

Returns:
list[Any]: List of members matching the predicate.
"""
results: list[Any] = []
for package_name in package_names:
try:
package = importlib.import_module(package_name)
except ImportError:
continue

for member_name in dir(package):
member = getattr(package, member_name)
if predicate(member):
results.append(member)

return results


def locate_implementations(_class: type[Any]) -> list[Any]:
"""
Dynamically discover all instances of some class.

Searches in the main project and all plugins registered via entry points.
Searches plugin packages using a shallow scan and the main project package using
a deep scan.

Args:
_class (type[Any]): Class to search for.
Expand All @@ -56,29 +127,24 @@ def locate_implementations(_class: type[Any]) -> list[Any]:
if _class in _implementation_cache:
return _implementation_cache[_class]

plugin_packages = discover_plugin_packages()
def predicate(member: object) -> bool:
return isinstance(member, _class)

implementations = []
for package_name in [*plugin_packages, __project_name__]:
package = importlib.import_module(package_name)
results = [
*_scan_packages_shallow(discover_plugin_packages(), predicate),
*_scan_packages_deep(__project_name__, predicate),
]

for _, name, _ in pkgutil.iter_modules(package.__path__):
module = importlib.import_module(f"{package_name}.{name}")
# Check all members of the module
for member_name in dir(module):
member = getattr(module, member_name)
if isinstance(member, _class):
implementations.append(member)

_implementation_cache[_class] = implementations
return implementations
_implementation_cache[_class] = results
return results


def locate_subclasses(_class: type[Any]) -> list[Any]:
"""
Dynamically discover all classes that are subclasses of some type.

Searches in the main project and all plugins registered via entry points.
Searches plugin packages using a shallow scan and the main project package using
a deep scan.

Args:
_class (type[Any]): Parent class of subclasses to search for.
Expand All @@ -89,25 +155,13 @@ def locate_subclasses(_class: type[Any]) -> list[Any]:
if _class in _subclass_cache:
return _subclass_cache[_class]

plugin_packages = discover_plugin_packages()
def predicate(member: object) -> bool:
return isclass(member) and issubclass(member, _class) and member != _class

subclasses = []
for package_name in [*plugin_packages, __project_name__]:
try:
package = importlib.import_module(package_name)
except ImportError:
continue
results = [
*_scan_packages_shallow(discover_plugin_packages(), predicate),
*_scan_packages_deep(__project_name__, predicate),
]

for _, name, _ in pkgutil.iter_modules(package.__path__):
try:
module = importlib.import_module(f"{package_name}.{name}")
# Check all members of the module
for member_name in dir(module):
member = getattr(module, member_name)
if isclass(member) and issubclass(member, _class) and member != _class:
subclasses.append(member)
except ImportError:
continue

_subclass_cache[_class] = subclasses
return subclasses
_subclass_cache[_class] = results
return results
Loading
Loading