Skip to content

Commit fa4eeb3

Browse files
authored
fix: shallow plugin discovery (#462)
1 parent 8ff112d commit fa4eeb3

2 files changed

Lines changed: 537 additions & 80 deletions

File tree

src/aignostics/utils/_di.py

Lines changed: 90 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import importlib
44
import pkgutil
5+
from collections.abc import Callable
56
from functools import lru_cache
67
from importlib.metadata import entry_points
78
from inspect import isclass
@@ -41,11 +42,81 @@ def load_modules() -> None:
4142
importlib.import_module(f"{__project_name__}.{name}")
4243

4344

45+
def _scan_packages_deep(package_name: str, predicate: Callable[[object], bool]) -> list[Any]:
46+
"""
47+
Deep-scan a package by walking all top-level submodules via pkgutil.iter_modules.
48+
49+
Discovers objects by importing the package, iterating through all submodules,
50+
and checking each module's members against the predicate. Used for the main
51+
aignostics package to ensure all registered implementations are found.
52+
53+
Example:
54+
>>> from inspect import isclass
55+
>>> _scan_packages_deep("aignostics", lambda m: isclass(m))
56+
57+
Args:
58+
package_name (str): Name of the package to deep-scan.
59+
predicate (Callable[[object], bool]): Function to filter members.
60+
61+
Returns:
62+
list[Any]: List of members matching the predicate.
63+
"""
64+
results: list[Any] = []
65+
try:
66+
package = importlib.import_module(package_name)
67+
except ImportError:
68+
return results
69+
70+
for _, name, _ in pkgutil.iter_modules(package.__path__):
71+
try:
72+
module = importlib.import_module(f"{package_name}.{name}")
73+
for member_name in dir(module):
74+
member = getattr(module, member_name)
75+
if predicate(member):
76+
results.append(member)
77+
except ImportError:
78+
continue
79+
80+
return results
81+
82+
83+
def _scan_packages_shallow(package_names: tuple[str, ...], predicate: Callable[[object], bool]) -> list[Any]:
84+
"""
85+
Shallow-scan plugin packages by checking only top-level exports.
86+
87+
Discovers objects by importing each package and checking its top-level members
88+
(i.e. what is exported from __init__.py via dir(package)) against the predicate.
89+
Does NOT walk submodules via pkgutil.iter_modules. This prevents over-discovering
90+
objects from plugin submodules that happen to be imported internally.
91+
92+
Args:
93+
package_names (tuple[str, ...]): Names of the plugin packages to shallow-scan.
94+
predicate (Callable[[object], bool]): Function to filter members.
95+
96+
Returns:
97+
list[Any]: List of members matching the predicate.
98+
"""
99+
results: list[Any] = []
100+
for package_name in package_names:
101+
try:
102+
package = importlib.import_module(package_name)
103+
except ImportError:
104+
continue
105+
106+
for member_name in dir(package):
107+
member = getattr(package, member_name)
108+
if predicate(member):
109+
results.append(member)
110+
111+
return results
112+
113+
44114
def locate_implementations(_class: type[Any]) -> list[Any]:
45115
"""
46116
Dynamically discover all instances of some class.
47117
48-
Searches in the main project and all plugins registered via entry points.
118+
Searches plugin packages using a shallow scan and the main project package using
119+
a deep scan.
49120
50121
Args:
51122
_class (type[Any]): Class to search for.
@@ -56,29 +127,24 @@ def locate_implementations(_class: type[Any]) -> list[Any]:
56127
if _class in _implementation_cache:
57128
return _implementation_cache[_class]
58129

59-
plugin_packages = discover_plugin_packages()
130+
def predicate(member: object) -> bool:
131+
return isinstance(member, _class)
60132

61-
implementations = []
62-
for package_name in [*plugin_packages, __project_name__]:
63-
package = importlib.import_module(package_name)
133+
results = [
134+
*_scan_packages_shallow(discover_plugin_packages(), predicate),
135+
*_scan_packages_deep(__project_name__, predicate),
136+
]
64137

65-
for _, name, _ in pkgutil.iter_modules(package.__path__):
66-
module = importlib.import_module(f"{package_name}.{name}")
67-
# Check all members of the module
68-
for member_name in dir(module):
69-
member = getattr(module, member_name)
70-
if isinstance(member, _class):
71-
implementations.append(member)
72-
73-
_implementation_cache[_class] = implementations
74-
return implementations
138+
_implementation_cache[_class] = results
139+
return results
75140

76141

77142
def locate_subclasses(_class: type[Any]) -> list[Any]:
78143
"""
79144
Dynamically discover all classes that are subclasses of some type.
80145
81-
Searches in the main project and all plugins registered via entry points.
146+
Searches plugin packages using a shallow scan and the main project package using
147+
a deep scan.
82148
83149
Args:
84150
_class (type[Any]): Parent class of subclasses to search for.
@@ -89,25 +155,13 @@ def locate_subclasses(_class: type[Any]) -> list[Any]:
89155
if _class in _subclass_cache:
90156
return _subclass_cache[_class]
91157

92-
plugin_packages = discover_plugin_packages()
158+
def predicate(member: object) -> bool:
159+
return isclass(member) and issubclass(member, _class) and member != _class
93160

94-
subclasses = []
95-
for package_name in [*plugin_packages, __project_name__]:
96-
try:
97-
package = importlib.import_module(package_name)
98-
except ImportError:
99-
continue
161+
results = [
162+
*_scan_packages_shallow(discover_plugin_packages(), predicate),
163+
*_scan_packages_deep(__project_name__, predicate),
164+
]
100165

101-
for _, name, _ in pkgutil.iter_modules(package.__path__):
102-
try:
103-
module = importlib.import_module(f"{package_name}.{name}")
104-
# Check all members of the module
105-
for member_name in dir(module):
106-
member = getattr(module, member_name)
107-
if isclass(member) and issubclass(member, _class) and member != _class:
108-
subclasses.append(member)
109-
except ImportError:
110-
continue
111-
112-
_subclass_cache[_class] = subclasses
113-
return subclasses
166+
_subclass_cache[_class] = results
167+
return results

0 commit comments

Comments
 (0)