diff --git a/src/aignostics/utils/_di.py b/src/aignostics/utils/_di.py index fd73fdece..000066dee 100644 --- a/src/aignostics/utils/_di.py +++ b/src/aignostics/utils/_di.py @@ -2,6 +2,7 @@ import importlib import pkgutil +from collections.abc import Callable from functools import lru_cache from importlib.metadata import entry_points from inspect import isclass @@ -41,11 +42,81 @@ def load_modules() -> None: importlib.import_module(f"{__project_name__}.{name}") +def _scan_packages_deep(package_name: str, predicate: Callable[[object], bool]) -> list[Any]: + """ + Deep-scan a package by walking all top-level submodules via pkgutil.iter_modules. + + Discovers objects by importing the package, iterating through all submodules, + and checking each module's members against the predicate. Used for the main + aignostics package to ensure all registered implementations are found. + + Example: + >>> from inspect import isclass + >>> _scan_packages_deep("aignostics", lambda m: isclass(m)) + + Args: + package_name (str): Name of the package to deep-scan. + predicate (Callable[[object], bool]): Function to filter members. + + Returns: + list[Any]: List of members matching the predicate. + """ + results: list[Any] = [] + try: + package = importlib.import_module(package_name) + except ImportError: + return results + + for _, name, _ in pkgutil.iter_modules(package.__path__): + try: + module = importlib.import_module(f"{package_name}.{name}") + for member_name in dir(module): + member = getattr(module, member_name) + if predicate(member): + results.append(member) + except ImportError: + continue + + return results + + +def _scan_packages_shallow(package_names: tuple[str, ...], predicate: Callable[[object], bool]) -> list[Any]: + """ + Shallow-scan plugin packages by checking only top-level exports. + + Discovers objects by importing each package and checking its top-level members + (i.e. what is exported from __init__.py via dir(package)) against the predicate. + Does NOT walk submodules via pkgutil.iter_modules. This prevents over-discovering + objects from plugin submodules that happen to be imported internally. + + Args: + package_names (tuple[str, ...]): Names of the plugin packages to shallow-scan. + predicate (Callable[[object], bool]): Function to filter members. + + Returns: + list[Any]: List of members matching the predicate. + """ + results: list[Any] = [] + for package_name in package_names: + try: + package = importlib.import_module(package_name) + except ImportError: + continue + + for member_name in dir(package): + member = getattr(package, member_name) + if predicate(member): + results.append(member) + + return results + + def locate_implementations(_class: type[Any]) -> list[Any]: """ Dynamically discover all instances of some class. - Searches in the main project and all plugins registered via entry points. + Searches plugin packages using a shallow scan and the main project package using + a deep scan. Args: _class (type[Any]): Class to search for. @@ -56,29 +127,24 @@ def locate_implementations(_class: type[Any]) -> list[Any]: if _class in _implementation_cache: return _implementation_cache[_class] - plugin_packages = discover_plugin_packages() + def predicate(member: object) -> bool: + return isinstance(member, _class) - implementations = [] - for package_name in [*plugin_packages, __project_name__]: - package = importlib.import_module(package_name) + results = [ + *_scan_packages_shallow(discover_plugin_packages(), predicate), + *_scan_packages_deep(__project_name__, predicate), + ] - for _, name, _ in pkgutil.iter_modules(package.__path__): - module = importlib.import_module(f"{package_name}.{name}") - # Check all members of the module - for member_name in dir(module): - member = getattr(module, member_name) - if isinstance(member, _class): - implementations.append(member) - - _implementation_cache[_class] = implementations - return implementations + _implementation_cache[_class] = results + return results def locate_subclasses(_class: type[Any]) -> list[Any]: """ Dynamically discover all classes that are subclasses of some type. - Searches in the main project and all plugins registered via entry points. + Searches plugin packages using a shallow scan and the main project package using + a deep scan. Args: _class (type[Any]): Parent class of subclasses to search for. @@ -89,25 +155,13 @@ def locate_subclasses(_class: type[Any]) -> list[Any]: if _class in _subclass_cache: return _subclass_cache[_class] - plugin_packages = discover_plugin_packages() + def predicate(member: object) -> bool: + return isclass(member) and issubclass(member, _class) and member != _class - subclasses = [] - for package_name in [*plugin_packages, __project_name__]: - try: - package = importlib.import_module(package_name) - except ImportError: - continue + results = [ + *_scan_packages_shallow(discover_plugin_packages(), predicate), + *_scan_packages_deep(__project_name__, predicate), + ] - for _, name, _ in pkgutil.iter_modules(package.__path__): - try: - module = importlib.import_module(f"{package_name}.{name}") - # Check all members of the module - for member_name in dir(module): - member = getattr(module, member_name) - if isclass(member) and issubclass(member, _class) and member != _class: - subclasses.append(member) - except ImportError: - continue - - _subclass_cache[_class] = subclasses - return subclasses + _subclass_cache[_class] = results + return results diff --git a/tests/aignostics/utils/di_test.py b/tests/aignostics/utils/di_test.py index 37fa32ec5..f7f8d3c7a 100644 --- a/tests/aignostics/utils/di_test.py +++ b/tests/aignostics/utils/di_test.py @@ -1,18 +1,21 @@ """Tests for the CLI utilities and dependency injection.""" import sys -from collections.abc import Generator +from collections.abc import Callable, Generator +from contextlib import contextmanager from types import ModuleType from unittest.mock import MagicMock, Mock, patch import pytest import typer +import aignostics.utils._di as di_module from aignostics.utils._cli import ( _add_epilog_recursively, _no_args_is_help_recursively, prepare_cli, ) +from aignostics.utils._constants import __project_name__ from aignostics.utils._di import ( PLUGIN_ENTRY_POINT_GROUP, _implementation_cache, @@ -25,6 +28,8 @@ # Constants to avoid duplication TEST_EPILOG = "Test epilog" SCRIPT_FILENAME = "script.py" +PLUGIN = "plugin" +MYMODULE = "mymodule" @pytest.mark.unit @@ -202,6 +207,11 @@ def test_no_args_is_help_recursively_calls_itself_on_nested_typers(record_proper assert subgroup.no_args_is_help is True +# --------------------------------------------------------------------------- +# Plugin discovery helpers +# --------------------------------------------------------------------------- + + class DummyBaseClass: """Base class for testing locate_subclasses.""" @@ -221,6 +231,141 @@ class AnotherDummySub(AnotherDummyBase): another_dummy_instance = AnotherDummyBase() +def _mock_package() -> MagicMock: + """Return a MagicMock that looks like an importable package (has __path__).""" + pkg = MagicMock() + pkg.__path__ = ["/fake/path"] + return pkg + + +def _make_import_side_effect( + mapping: dict[str, ModuleType | Exception], + default: MagicMock | None = None, +) -> Callable[[str], ModuleType]: + """Return an import side-effect callable driven by *mapping*. + + Args: + mapping: Maps module name to the module to return or an exception to raise. + default: Returned for any name not in *mapping*. Defaults to a package + with an empty ``__path__``. + + Returns: + A callable suitable for use as ``importlib.import_module``'s side effect. + """ + if default is None: + default = _mock_package() + default.__path__ = [] + + def _side_effect(name: str) -> ModuleType: + if name in mapping: + result = mapping[name] + if isinstance(result, BaseException): + raise result + return result # type: ignore[return-value] + return default # type: ignore[return-value] + + return _side_effect + + +@contextmanager +def _broken_plugin_package_patches( + main_pkg: MagicMock, + main_mod: ModuleType, +) -> Generator[None, None, None]: + """Yield patches where a plugin package itself raises ImportError. + + The plugin package raises ``ImportError`` on import. The main project + package and its ``MYMODULE`` submodule import normally. + + Args: + main_pkg: Mock main package (has ``__path__``). + main_mod: Module to return for the main ``MYMODULE`` import. + """ + with ( + patch.object(di_module, "discover_plugin_packages", return_value=(PLUGIN,)), + patch.object( + di_module.importlib, + "import_module", + side_effect=_make_import_side_effect({ + PLUGIN: ImportError("broken"), + __project_name__: main_pkg, + f"{__project_name__}.{MYMODULE}": main_mod, + }), + ), + patch.object(di_module.pkgutil, "iter_modules", return_value=[("", MYMODULE, False)]), + ): + yield + + +@contextmanager +def _no_match_plugin_patches( + plugin_pkg: MagicMock, + main_pkg: MagicMock, + main_mod: ModuleType, +) -> Generator[None, None, None]: + """Yield patches where a plugin imports successfully but has no matching top-level members. + + The plugin package is importable but its top-level namespace contains no + members that satisfy the discovery predicate. The main project package and + its ``MYMODULE`` submodule import normally and contain the expected member. + + Args: + plugin_pkg: Mock plugin package (importable, no matching members). + main_pkg: Mock main package (has ``__path__``). + main_mod: Module to return for the main ``MYMODULE`` import. + """ + with ( + patch.object(di_module, "discover_plugin_packages", return_value=(PLUGIN,)), + patch.object( + di_module.importlib, + "import_module", + side_effect=_make_import_side_effect({ + PLUGIN: plugin_pkg, + __project_name__: main_pkg, + f"{__project_name__}.{MYMODULE}": main_mod, + }), + ), + patch.object(di_module.pkgutil, "iter_modules", return_value=[("", MYMODULE, False)]), + ): + yield + + +@contextmanager +def _no_plugins_patches( + main_pkg: MagicMock, + main_mod: ModuleType, +) -> Generator[list[str], None, None]: + """Yield a tracking list of searched module names with no-plugin patches active. + + Patches ``discover_plugin_packages`` to return an empty tuple, + ``importlib.import_module`` with a call-tracking side-effect, and + ``pkgutil.iter_modules`` with a single-module result. + + Args: + main_pkg: Mock main package (has ``__path__``). + main_mod: Module to return for the main ``MYMODULE`` import. + + Yields: + A list of module names that were imported during the patched scope. + """ + searched: list[str] = [] + base_side_effect = _make_import_side_effect({ + __project_name__: main_pkg, + f"{__project_name__}.{MYMODULE}": main_mod, + }) + + def tracking_import(name: str) -> ModuleType: + searched.append(name) + return base_side_effect(name) + + with ( + patch.object(di_module, "discover_plugin_packages", return_value=()), + patch.object(di_module.importlib, "import_module", side_effect=tracking_import), + patch.object(di_module.pkgutil, "iter_modules", return_value=[("", MYMODULE, False)]), + ): + yield searched + + @pytest.fixture def clear_di_caches() -> Generator[None, None, None]: """Clear DI caches before and after each test. @@ -237,10 +382,15 @@ def clear_di_caches() -> Generator[None, None, None]: discover_plugin_packages.cache_clear() +# --------------------------------------------------------------------------- +# discover_plugin_packages +# --------------------------------------------------------------------------- + + @pytest.mark.unit def test_discover_plugin_packages_returns_tuple(clear_di_caches, record_property) -> None: """Test that discover_plugin_packages returns a tuple.""" - record_property("tested-item-id", "SPEC-UTILS-DI") + record_property("tested-item-id", "SPEC-UTILS-SERVICE") result = discover_plugin_packages() assert isinstance(result, tuple) @@ -248,7 +398,7 @@ def test_discover_plugin_packages_returns_tuple(clear_di_caches, record_property @pytest.mark.unit def test_discover_plugin_packages_uses_correct_entry_point_group(clear_di_caches, record_property) -> None: """Test that discover_plugin_packages uses the correct entry point group.""" - record_property("tested-item-id", "SPEC-UTILS-DI") + record_property("tested-item-id", "SPEC-UTILS-SERVICE") assert PLUGIN_ENTRY_POINT_GROUP == "aignostics.plugins" @@ -258,7 +408,7 @@ def test_discover_plugin_packages_extracts_values_from_entry_points( mock_entry_points: Mock, clear_di_caches, record_property ) -> None: """Test that discover_plugin_packages extracts values from entry points.""" - record_property("tested-item-id", "SPEC-UTILS-DI") + record_property("tested-item-id", "SPEC-UTILS-SERVICE") # Setup mock entry points mock_ep1 = MagicMock() mock_ep1.value = "plugin_one" @@ -280,7 +430,7 @@ def test_discover_plugin_packages_returns_empty_tuple_when_no_plugins( mock_entry_points: Mock, clear_di_caches, record_property ) -> None: """Test that discover_plugin_packages returns empty tuple when no plugins registered.""" - record_property("tested-item-id", "SPEC-UTILS-DI") + record_property("tested-item-id", "SPEC-UTILS-SERVICE") mock_entry_points.return_value = [] result = discover_plugin_packages() @@ -292,7 +442,7 @@ def test_discover_plugin_packages_returns_empty_tuple_when_no_plugins( @patch("aignostics.utils._di.entry_points") def test_discover_plugin_packages_is_cached(mock_entry_points: Mock, clear_di_caches, record_property) -> None: """Test that discover_plugin_packages caches results.""" - record_property("tested-item-id", "SPEC-UTILS-DI") + record_property("tested-item-id", "SPEC-UTILS-SERVICE") mock_ep = MagicMock() mock_ep.value = "cached_plugin" mock_entry_points.return_value = [mock_ep] @@ -306,23 +456,23 @@ def test_discover_plugin_packages_is_cached(mock_entry_points: Mock, clear_di_ca assert result1 == result2 == ("cached_plugin",) +# --------------------------------------------------------------------------- +# locate_implementations — plugin discovery +# --------------------------------------------------------------------------- + + @pytest.mark.unit def test_locate_implementations_searches_plugins(clear_di_caches, record_property) -> None: - """Test that locate_implementations searches plugin packages.""" - record_property("tested-item-id", "SPEC-UTILS-DI") - import aignostics.utils._di as di_module + """Test that locate_implementations shallow-scans plugin packages for top-level exports.""" + record_property("tested-item-id", "SPEC-UTILS-SERVICE") plugin_instance = AnotherDummyBase() - mock_plugin_package = MagicMock() - mock_plugin_package.__path__ = ["/fake/path"] - mock_plugin_module = ModuleType("test_plugin.submodule") - mock_plugin_module.plugin_instance = plugin_instance # type: ignore[attr-defined] + mock_plugin_package = ModuleType("test_plugin") + mock_plugin_package.plugin_instance = plugin_instance # type: ignore[attr-defined] def import_side_effect(name: str) -> ModuleType: if name == "test_plugin": return mock_plugin_package - if name == "test_plugin.submodule": - return mock_plugin_module mock_aig = MagicMock() mock_aig.__path__ = [] return mock_aig @@ -330,17 +480,122 @@ def import_side_effect(name: str) -> ModuleType: with ( patch.object(di_module, "discover_plugin_packages", return_value=("test_plugin",)), patch.object(di_module.importlib, "import_module", side_effect=import_side_effect), - patch.object(di_module.pkgutil, "iter_modules", side_effect=[[("", "submodule", False)], []]), + patch.object(di_module.pkgutil, "iter_modules", return_value=[]), ): result = locate_implementations(AnotherDummyBase) assert plugin_instance in result +@pytest.mark.unit +def test_locate_implementations_only_finds_plugin_top_level_exports(clear_di_caches, record_property) -> None: + """Plugin submodule instances are not discovered; only top-level __init__.py exports are found.""" + record_property("tested-item-id", "SPEC-UTILS-SERVICE") + + class _Base: + pass + + top_instance = _Base() + sub_instance = _Base() + + plugin_pkg = _mock_package() + plugin_pkg.top_instance = top_instance # type: ignore[attr-defined] + + plugin_submod = ModuleType(f"{PLUGIN}.submod") + plugin_submod.sub_instance = sub_instance # type: ignore[attr-defined] + + with ( + patch.object(di_module, "discover_plugin_packages", return_value=(PLUGIN,)), + patch.object( + di_module.importlib, + "import_module", + side_effect=_make_import_side_effect({ + PLUGIN: plugin_pkg, + f"{PLUGIN}.submod": plugin_submod, + }), + ), + patch.object(di_module.pkgutil, "iter_modules", return_value=[]), + ): + result = locate_implementations(_Base) + + assert top_instance in result + assert sub_instance not in result + + +@pytest.mark.unit +def test_locate_implementations_handles_broken_plugin_package(clear_di_caches, record_property) -> None: + """A plugin package raising ImportError on import is skipped; main package still searched.""" + record_property("tested-item-id", "SPEC-UTILS-SERVICE") + + class _Base: + pass + + main_instance = _Base() + main_pkg = _mock_package() + main_mod = ModuleType(f"{__project_name__}.{MYMODULE}") + main_mod.main_instance = main_instance # type: ignore[attr-defined] + + with _broken_plugin_package_patches(main_pkg, main_mod): + result = locate_implementations(_Base) + + assert main_instance in result + + +@pytest.mark.unit +def test_locate_implementations_handles_plugin_with_no_matching_top_level_members( + clear_di_caches, record_property +) -> None: + """A plugin with no matching top-level exports is skipped; main package still searched.""" + record_property("tested-item-id", "SPEC-UTILS-SERVICE") + + class _Base: + pass + + main_instance = _Base() + plugin_pkg = _mock_package() + main_pkg = _mock_package() + main_mod = ModuleType(f"{__project_name__}.{MYMODULE}") + main_mod.main_instance = main_instance # type: ignore[attr-defined] + + with _no_match_plugin_patches(plugin_pkg, main_pkg, main_mod): + result = locate_implementations(_Base) + + assert main_instance in result + + +@pytest.mark.unit +def test_locate_implementations_deep_scans_main_package(clear_di_caches, record_property) -> None: + """Main package submodule instances are found via deep scan even when a plugin is present.""" + record_property("tested-item-id", "SPEC-UTILS-SERVICE") + + class _Base: + pass + + main_instance = _Base() + main_pkg = _mock_package() + main_mod = ModuleType(f"{__project_name__}.{MYMODULE}") + main_mod.main_instance = main_instance # type: ignore[attr-defined] + + with ( + patch.object(di_module, "discover_plugin_packages", return_value=()), + patch.object( + di_module.importlib, + "import_module", + side_effect=_make_import_side_effect({ + __project_name__: main_pkg, + f"{__project_name__}.{MYMODULE}": main_mod, + }), + ), + patch.object(di_module.pkgutil, "iter_modules", return_value=[("", MYMODULE, False)]), + ): + result = locate_implementations(_Base) + + assert main_instance in result + + @pytest.mark.unit def test_locate_implementations_caches_results(clear_di_caches, record_property) -> None: """Test that locate_implementations caches results.""" - record_property("tested-item-id", "SPEC-UTILS-DI") - import aignostics.utils._di as di_module + record_property("tested-item-id", "SPEC-UTILS-SERVICE") mock_package = MagicMock() mock_package.__path__ = [] @@ -356,25 +611,45 @@ def test_locate_implementations_caches_results(clear_di_caches, record_property) assert AnotherDummyBase in _implementation_cache +@pytest.mark.unit +def test_locate_implementations_no_plugins_detects_main_package(clear_di_caches, record_property) -> None: + """With no plugins, locate_implementations only searches the main package.""" + record_property("tested-item-id", "SPEC-UTILS-SERVICE") + + class _Base: + pass + + instance = _Base() + main_pkg = _mock_package() + main_mod = ModuleType(f"{__project_name__}.{MYMODULE}") + main_mod.instance = instance # type: ignore[attr-defined] + + with _no_plugins_patches(main_pkg, main_mod) as searched: + result = locate_implementations(_Base) + + assert instance in result + assert not any(p != __project_name__ and not p.startswith(f"{__project_name__}.") for p in searched) + + +# --------------------------------------------------------------------------- +# locate_subclasses — plugin discovery +# --------------------------------------------------------------------------- + + @pytest.mark.unit def test_locate_subclasses_searches_plugins(clear_di_caches, record_property) -> None: - """Test that locate_subclasses searches plugin packages.""" - record_property("tested-item-id", "SPEC-UTILS-DI") - import aignostics.utils._di as di_module + """Test that locate_subclasses shallow-scans plugin packages for top-level exports.""" + record_property("tested-item-id", "SPEC-UTILS-SERVICE") class PluginSubClass(AnotherDummyBase): pass - mock_plugin_package = MagicMock() - mock_plugin_package.__path__ = ["/fake/path"] - mock_plugin_module = ModuleType("test_plugin.submodule") - mock_plugin_module.PluginSubClass = PluginSubClass # type: ignore[attr-defined] + mock_plugin_package = ModuleType("test_plugin") + mock_plugin_package.PluginSubClass = PluginSubClass # type: ignore[attr-defined] def import_side_effect(name: str) -> ModuleType: if name == "test_plugin": return mock_plugin_package - if name == "test_plugin.submodule": - return mock_plugin_module mock_aig = MagicMock() mock_aig.__path__ = [] return mock_aig @@ -382,20 +657,131 @@ def import_side_effect(name: str) -> ModuleType: with ( patch.object(di_module, "discover_plugin_packages", return_value=("test_plugin",)), patch.object(di_module.importlib, "import_module", side_effect=import_side_effect), - patch.object(di_module.pkgutil, "iter_modules", side_effect=[[("", "submodule", False)], []]), + patch.object(di_module.pkgutil, "iter_modules", return_value=[]), ): result = locate_subclasses(AnotherDummyBase) assert PluginSubClass in result +@pytest.mark.unit +def test_locate_subclasses_only_finds_plugin_top_level_exports(clear_di_caches, record_property) -> None: + """Plugin subclasses only in submodules are not discovered; only top-level exports are found.""" + record_property("tested-item-id", "SPEC-UTILS-SERVICE") + + class _Base: + pass + + class TopSub(_Base): + pass + + class SubSub(_Base): + pass + + plugin_pkg = _mock_package() + plugin_pkg.TopSub = TopSub # type: ignore[attr-defined] + + plugin_submod = ModuleType(f"{PLUGIN}.submod") + plugin_submod.SubSub = SubSub # type: ignore[attr-defined] + + with ( + patch.object(di_module, "discover_plugin_packages", return_value=(PLUGIN,)), + patch.object( + di_module.importlib, + "import_module", + side_effect=_make_import_side_effect({ + PLUGIN: plugin_pkg, + f"{PLUGIN}.submod": plugin_submod, + }), + ), + patch.object(di_module.pkgutil, "iter_modules", return_value=[]), + ): + result = locate_subclasses(_Base) + + assert TopSub in result + assert SubSub not in result + + +@pytest.mark.unit +def test_locate_subclasses_handles_broken_plugin_package(clear_di_caches, record_property) -> None: + """A plugin package raising ImportError on import is skipped; main package still searched.""" + record_property("tested-item-id", "SPEC-UTILS-SERVICE") + + class _Base: + pass + + class MainSub(_Base): + pass + + main_pkg = _mock_package() + main_mod = ModuleType(f"{__project_name__}.{MYMODULE}") + main_mod.MainSub = MainSub # type: ignore[attr-defined] + + with _broken_plugin_package_patches(main_pkg, main_mod): + result = locate_subclasses(_Base) + + assert MainSub in result + + +@pytest.mark.unit +def test_locate_subclasses_handles_plugin_with_no_matching_top_level_members(clear_di_caches, record_property) -> None: + """A plugin with no matching top-level exports is skipped; main package still searched.""" + record_property("tested-item-id", "SPEC-UTILS-SERVICE") + + class _Base: + pass + + class MainSub(_Base): + pass + + plugin_pkg = _mock_package() + main_pkg = _mock_package() + main_mod = ModuleType(f"{__project_name__}.{MYMODULE}") + main_mod.MainSub = MainSub # type: ignore[attr-defined] + + with _no_match_plugin_patches(plugin_pkg, main_pkg, main_mod): + result = locate_subclasses(_Base) + + assert MainSub in result + + +@pytest.mark.unit +def test_locate_subclasses_deep_scans_main_package(clear_di_caches, record_property) -> None: + """Main package subclasses in submodules are found via deep scan.""" + record_property("tested-item-id", "SPEC-UTILS-SERVICE") + + class _Base: + pass + + class MainSub(_Base): + pass + + main_pkg = _mock_package() + main_mod = ModuleType(f"{__project_name__}.{MYMODULE}") + main_mod.MainSub = MainSub # type: ignore[attr-defined] + + with ( + patch.object(di_module, "discover_plugin_packages", return_value=()), + patch.object( + di_module.importlib, + "import_module", + side_effect=_make_import_side_effect({ + __project_name__: main_pkg, + f"{__project_name__}.{MYMODULE}": main_mod, + }), + ), + patch.object(di_module.pkgutil, "iter_modules", return_value=[("", MYMODULE, False)]), + ): + result = locate_subclasses(_Base) + + assert MainSub in result + + @pytest.mark.unit def test_locate_subclasses_excludes_base_class(clear_di_caches, record_property) -> None: """Test that locate_subclasses excludes the base class itself.""" - record_property("tested-item-id", "SPEC-UTILS-DI") - import aignostics.utils._di as di_module + record_property("tested-item-id", "SPEC-UTILS-SERVICE") - mock_package = MagicMock() - mock_package.__path__ = ["/fake/path"] + mock_package = _mock_package() mock_module = ModuleType("aignostics.testmodule") mock_module.AnotherDummyBase = AnotherDummyBase # type: ignore[attr-defined] @@ -411,8 +797,7 @@ def test_locate_subclasses_excludes_base_class(clear_di_caches, record_property) @pytest.mark.unit def test_locate_subclasses_caches_results(clear_di_caches, record_property) -> None: """Test that locate_subclasses caches results.""" - record_property("tested-item-id", "SPEC-UTILS-DI") - import aignostics.utils._di as di_module + record_property("tested-item-id", "SPEC-UTILS-SERVICE") mock_package = MagicMock() mock_package.__path__ = [] @@ -431,8 +816,7 @@ def test_locate_subclasses_caches_results(clear_di_caches, record_property) -> N @pytest.mark.unit def test_locate_subclasses_handles_plugin_import_error(clear_di_caches, record_property) -> None: """Test that locate_subclasses handles ImportError for plugin packages gracefully.""" - record_property("tested-item-id", "SPEC-UTILS-DI") - import aignostics.utils._di as di_module + record_property("tested-item-id", "SPEC-UTILS-SERVICE") mock_package = MagicMock() mock_package.__path__ = [] @@ -453,11 +837,9 @@ def import_side_effect(name: str) -> ModuleType: @pytest.mark.unit def test_locate_subclasses_handles_module_import_error(clear_di_caches, record_property) -> None: """Test that locate_subclasses handles ImportError for individual modules gracefully.""" - record_property("tested-item-id", "SPEC-UTILS-DI") - import aignostics.utils._di as di_module + record_property("tested-item-id", "SPEC-UTILS-SERVICE") - mock_package = MagicMock() - mock_package.__path__ = ["/fake/path"] + mock_package = _mock_package() call_count = 0 def import_side_effect(name: str) -> ModuleType: @@ -476,14 +858,35 @@ def import_side_effect(name: str) -> ModuleType: assert isinstance(result, list) +@pytest.mark.unit +def test_locate_subclasses_no_plugins_detects_main_package(clear_di_caches, record_property) -> None: + """With no plugins, locate_subclasses only searches the main package.""" + record_property("tested-item-id", "SPEC-UTILS-SERVICE") + + class _Base: + pass + + class LocalSub(_Base): + pass + + main_pkg = _mock_package() + main_mod = ModuleType(f"{__project_name__}.{MYMODULE}") + main_mod.LocalSub = LocalSub # type: ignore[attr-defined] + + with _no_plugins_patches(main_pkg, main_mod) as searched: + result = locate_subclasses(_Base) + + assert LocalSub in result + assert not any(p != __project_name__ and not p.startswith(f"{__project_name__}.") for p in searched) + + @pytest.mark.unit def test_locate_implementations_and_subclasses_search_both_plugins_and_main_package( clear_di_caches, record_property, ) -> None: """Test that both functions search plugins first, then main package.""" - record_property("tested-item-id", "SPEC-UTILS-DI") - import aignostics.utils._di as di_module + record_property("tested-item-id", "SPEC-UTILS-SERVICE") import_order: list[str] = [] @@ -499,10 +902,10 @@ def track_imports(name: str) -> MagicMock: patch.object(di_module.pkgutil, "iter_modules", return_value=[]), ): locate_implementations(AnotherDummyBase) - assert import_order == ["plugin_a", "plugin_b", "aignostics"] + assert import_order == ["plugin_a", "plugin_b", __project_name__] _implementation_cache.clear() import_order.clear() locate_subclasses(AnotherDummySub) - assert import_order == ["plugin_a", "plugin_b", "aignostics"] + assert import_order == ["plugin_a", "plugin_b", __project_name__]