From d8f42a91f3091ddbc2584d0784dcb2e750dfe69a Mon Sep 17 00:00:00 2001 From: Samet Date: Tue, 19 May 2026 15:42:32 +0200 Subject: [PATCH 1/3] refactor(config): move shared primitives into runtime --- src/physicalai/config/__init__.py | 11 ++ src/physicalai/config/base.py | 84 +++++++++++ src/physicalai/config/component.py | 101 +++++++++++++ src/physicalai/config/instantiate.py | 134 ++++++++++++++++++ src/physicalai/config/mixin.py | 91 ++++++++++++ src/physicalai/config/serializable.py | 132 +++++++++++++++++ src/physicalai/inference/component_factory.py | 4 +- src/physicalai/inference/manifest.py | 107 +------------- src/physicalai/inference/model.py | 3 +- src/physicalai/inference/runners/factory.py | 2 +- tests/unit/inference/test_manifest.py | 7 +- 11 files changed, 567 insertions(+), 109 deletions(-) create mode 100644 src/physicalai/config/__init__.py create mode 100644 src/physicalai/config/base.py create mode 100644 src/physicalai/config/component.py create mode 100644 src/physicalai/config/instantiate.py create mode 100644 src/physicalai/config/mixin.py create mode 100644 src/physicalai/config/serializable.py diff --git a/src/physicalai/config/__init__.py b/src/physicalai/config/__init__.py new file mode 100644 index 0000000..e3a6831 --- /dev/null +++ b/src/physicalai/config/__init__.py @@ -0,0 +1,11 @@ +# Copyright (C) 2025-2026 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Configuration primitives shared by runtime and training packages.""" + +from physicalai.config.base import Config +from physicalai.config.component import ComponentSpec +from physicalai.config.instantiate import instantiate_obj +from physicalai.config.mixin import FromConfig, from_config + +__all__ = ["ComponentSpec", "Config", "FromConfig", "from_config", "instantiate_obj"] diff --git a/src/physicalai/config/base.py b/src/physicalai/config/base.py new file mode 100644 index 0000000..9a4a693 --- /dev/null +++ b/src/physicalai/config/base.py @@ -0,0 +1,84 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# ruff: noqa: DOC201, DOC501 + +"""Base configuration class for typed constructor configs.""" + +import dataclasses +from collections.abc import Mapping +from pathlib import Path +from typing import Any, Literal, Self + +from physicalai.config.serializable import dataclass_to_dict, dict_to_dataclass + +__all__ = ["Config"] + + +class Config: + """Base class for dataclass-backed configuration objects.""" + + def to_dict(self) -> dict[str, Any]: + """Convert this config to a plain dict for serialization.""" + if not dataclasses.is_dataclass(self): + msg = f"{self.__class__.__name__} must be a dataclass to use Config" + raise TypeError(msg) + + result = dataclass_to_dict(self) + if not isinstance(result, dict): + msg = f"Expected dict from dataclass_to_dict, got {type(result)}" + raise TypeError(msg) + return result + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> Self: + """Reconstruct this config from a dict.""" + if not dataclasses.is_dataclass(cls): + msg = f"{cls.__name__} must be a dataclass to use Config" + raise TypeError(msg) + return dict_to_dataclass(cls, data) + + def to_jsonargparse(self) -> dict[str, Any]: + """Convert config to ``class_path``/``init_args`` format.""" + return { + "class_path": f"{self.__class__.__module__}.{self.__class__.__qualname__}", + "init_args": self.to_dict(), + } + + def save( + self, + path: str | Path, + *, + format: Literal["jsonargparse", "dict"] = "jsonargparse", # noqa: A002 + ) -> None: + """Save config to a YAML file.""" + path = Path(path) + data = self.to_dict() if format == "dict" else self.to_jsonargparse() + + if path.suffix not in {".yaml", ".yml"}: + msg = f"Unsupported file extension: {path.suffix}. Use .yaml or .yml" + raise ValueError(msg) + + import yaml # noqa: PLC0415 + + with path.open("w") as f: + yaml.safe_dump(data, f, default_flow_style=False, sort_keys=False) + + @classmethod + def load(cls, path: str | Path) -> Self: + """Load config from a YAML file.""" + path = Path(path) + + if path.suffix not in {".yaml", ".yml"}: + msg = f"Unsupported file extension: {path.suffix}. Use .yaml or .yml" + raise ValueError(msg) + + import yaml # noqa: PLC0415 + + with path.open() as f: + data = yaml.safe_load(f) + + if "init_args" in data: + data = data["init_args"] + + return cls.from_dict(data) diff --git a/src/physicalai/config/component.py b/src/physicalai/config/component.py new file mode 100644 index 0000000..d16d4a0 --- /dev/null +++ b/src/physicalai/config/component.py @@ -0,0 +1,101 @@ +# Copyright (C) 2026 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# ruff: noqa: DOC201, DOC501 + +"""Generic component specifications for dynamic instantiation.""" + +from __future__ import annotations + +import inspect +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +# Alias builtin ``type`` so it remains accessible inside classes that define a +# Pydantic field with the same name (e.g. ``ComponentSpec.type``). +_type = type + + +class ComponentSpec(BaseModel): + """Dual-resolution component descriptor for dynamic instantiation. + + Supports two resolution modes: + + 1. **type + flat params** (LeRobot-compatible):: + + {"type": "single_pass"} + + 2. **class_path + init_args** (full-power PhysicalAI):: + + {"class_path": "physicalai.inference.runners.SinglePass", + "init_args": {}} + + When ``class_path`` is present it takes precedence. When only ``type`` is + present, a component registry can resolve it. + + Attributes: + type: Registered short name (e.g. ``"single_pass"``). + class_path: Fully-qualified class path for direct import. + init_args: Keyword arguments forwarded to the constructor + (used with ``class_path`` mode). + """ + + model_config = ConfigDict(frozen=True, extra="allow") + type: str = "" + class_path: str = "" + init_args: dict[str, Any] = Field(default_factory=dict) + + @model_validator(mode="after") + def _must_have_type_or_class_path(self) -> ComponentSpec: + if not self.type and not self.class_path: + msg = "ComponentSpec requires either 'type' or 'class_path'" + raise ValueError(msg) + return self + + @property + def flat_params(self) -> dict[str, Any]: + """Return extra fields as flat params for type-based resolution.""" + return dict(self.model_extra) if self.model_extra else {} + + @classmethod + def from_class(cls, target: _type, **overrides: Any) -> ComponentSpec: # noqa: ANN401 + """Build a spec by introspecting a class constructor. + + Parameters not present in *overrides* use their default values. Required + parameters without defaults must be provided in *overrides* or a + TypeError is raised. + """ + sig = inspect.signature(target) + init_args: dict[str, Any] = {} + missing: list[str] = [] + + for name, param in sig.parameters.items(): + if name == "self": + continue + if name in overrides: + value = overrides[name] + elif param.default is not param.empty: + value = param.default + else: + missing.append(name) + continue + + if isinstance(value, ComponentSpec): + value = value.model_dump() + init_args[name] = value + + if missing: + msg = ( + f"Missing required parameters for {target.__qualname__}: " + f"{', '.join(missing)}. Pass them as keyword arguments." + ) + raise TypeError(msg) + + return cls( + class_path=f"{target.__module__}.{target.__qualname__}", + init_args=init_args, + ) + + +__all__ = ["ComponentSpec"] diff --git a/src/physicalai/config/instantiate.py b/src/physicalai/config/instantiate.py new file mode 100644 index 0000000..17c106e --- /dev/null +++ b/src/physicalai/config/instantiate.py @@ -0,0 +1,134 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# ruff: noqa: DOC201, DOC501 + +"""Configuration instantiation helpers.""" + +import dataclasses +import importlib +from pathlib import Path +from typing import TYPE_CHECKING + +import yaml +from pydantic import BaseModel + +if TYPE_CHECKING: + from typing import Any + + +def _import_class(class_path: str) -> type: + """Import a class from a module path.""" + try: + module_path, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_path) # nosemgrep + return getattr(module, class_name) + except (ValueError, ImportError, AttributeError) as e: + msg = f"Cannot import '{class_path}': {e}" + raise ImportError(msg) from e + + +def _instantiate_recursive(value: "Any") -> "Any": # noqa: ANN401 + """Walk a value and instantiate nested ``{class_path, init_args}`` dicts.""" + if isinstance(value, dict): + if "class_path" in value: + return instantiate_obj_from_dict(value) + return {k: _instantiate_recursive(v) for k, v in value.items()} + if isinstance(value, list): + return [_instantiate_recursive(item) for item in value] + if isinstance(value, tuple): + return tuple(_instantiate_recursive(item) for item in value) + return value + + +def instantiate_obj_from_dict( + config: dict[str, "Any"], + *, + key: str | None = None, + target_cls: type | None = None, +) -> object: + """Instantiate an object from a configuration dictionary.""" + if key is not None: + if key not in config: + msg = f"Configuration must contain '{key}' key. Got keys: {list(config.keys())}" + raise ValueError(msg) + config = config[key] + + if "class_path" in config: + cls = _import_class(config["class_path"]) + init_args = config.get("init_args", {}) + elif target_cls is not None: + cls = target_cls + init_args = config + else: + msg = ( + "Configuration must contain 'class_path' for instantiation, " + f"or pass target_cls explicitly. Got keys: {list(config.keys())}" + ) + raise ValueError(msg) + + if not isinstance(init_args, dict): + return cls(init_args) + + instantiated_args = {k: _instantiate_recursive(v) for k, v in init_args.items()} + + if "args" in instantiated_args: + args = instantiated_args.pop("args") + return cls(*args, **instantiated_args) + return cls(**instantiated_args) + + +def instantiate_obj_from_pydantic( + config: BaseModel, + *, + key: str | None = None, + target_cls: type | None = None, +) -> object: + """Instantiate an object from a Pydantic model.""" + return instantiate_obj_from_dict(config.model_dump(), key=key, target_cls=target_cls) + + +def instantiate_obj_from_dataclass( + config: object, + *, + key: str | None = None, + target_cls: type | None = None, +) -> object: + """Instantiate an object from a dataclass instance.""" + if not dataclasses.is_dataclass(config) or isinstance(config, type): + msg = f"Expected dataclass instance, got {type(config)}" + raise TypeError(msg) + + return instantiate_obj_from_dict(dataclasses.asdict(config), key=key, target_cls=target_cls) + + +def instantiate_obj_from_file( + file_path: str | Path, + *, + key: str | None = None, + target_cls: type | None = None, +) -> object: + """Instantiate an object from a YAML/JSON configuration file.""" + with Path(file_path).open("r", encoding="utf-8") as f: + config = yaml.safe_load(f) + return instantiate_obj_from_dict(config, key=key, target_cls=target_cls) + + +def instantiate_obj( + config: dict[str, "Any"] | BaseModel | object | str | Path, + *, + key: str | None = None, + target_cls: type | None = None, +) -> object: + """Instantiate an object from dict, Pydantic, dataclass, or file config.""" + if isinstance(config, (str, Path)): + return instantiate_obj_from_file(config, key=key, target_cls=target_cls) + if isinstance(config, BaseModel): + return instantiate_obj_from_pydantic(config, key=key, target_cls=target_cls) + if dataclasses.is_dataclass(config) and not isinstance(config, type): + return instantiate_obj_from_dataclass(config, key=key, target_cls=target_cls) + if isinstance(config, dict): + return instantiate_obj_from_dict(config, key=key, target_cls=target_cls) + + msg = f"Unsupported configuration type: {type(config)}. Expected dict, file path, Pydantic model, or dataclass." + raise TypeError(msg) diff --git a/src/physicalai/config/mixin.py b/src/physicalai/config/mixin.py new file mode 100644 index 0000000..ed1d330 --- /dev/null +++ b/src/physicalai/config/mixin.py @@ -0,0 +1,91 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# ruff: noqa: DOC201, DOC501 + +"""Configuration mixins for adding from_config functionality.""" + +import dataclasses +from pathlib import Path +from typing import Any, Self, cast + +import yaml +from pydantic import BaseModel + +from physicalai.config.instantiate import instantiate_obj_from_dict +from physicalai.config.serializable import dataclass_to_dict + + +class FromConfig: + """Mixin that adds configuration-based construction helpers.""" + + @classmethod + def from_yaml(cls, file_path: str | Path, *, key: str | None = None) -> Self: + """Load configuration from a YAML file and instantiate the class.""" + with Path(file_path).open("r", encoding="utf-8") as f: + config = yaml.safe_load(f) + return cls.from_dict(config, key=key) + + @classmethod + def from_dict(cls, config: dict[str, Any], *, key: str | None = None) -> Self: + """Instantiate the class from a configuration dictionary.""" + return cast("Self", instantiate_obj_from_dict(config, key=key, target_cls=cls)) + + @classmethod + def from_pydantic( + cls, + config: BaseModel, + *, + key: str | None = None, + recursive: bool = False, + ) -> Self: + """Instantiate the class from a Pydantic model.""" + if recursive: + config_dict = config.model_dump() + else: + config_dict = {name: getattr(config, name) for name in config.__class__.model_fields} + return cls.from_dict(config_dict, key=key) + + @classmethod + def from_dataclass( + cls, + config: object, + *, + key: str | None = None, + recursive: bool = False, + ) -> Self: + """Instantiate the class from a dataclass instance.""" + if not dataclasses.is_dataclass(config) or isinstance(config, type): + msg = f"Expected dataclass instance, got {type(config)}" + raise TypeError(msg) + + config_dict = cast("dict[str, Any]", dataclass_to_dict(config, recursive=recursive)) + return cls.from_dict(config_dict, key=key) + + @classmethod + def from_config( + cls, + config: dict[str, Any] | BaseModel | object | str | Path, + *, + key: str | None = None, + recursive: bool = False, + ) -> Self: + """Generic entry point that dispatches on the type of ``config``.""" + if isinstance(config, (str, Path)): + return cls.from_yaml(config, key=key) + if isinstance(config, BaseModel): + return cls.from_pydantic(config, key=key, recursive=recursive) + if dataclasses.is_dataclass(config) and not isinstance(config, type): + return cls.from_dataclass(config, key=key, recursive=recursive) + if isinstance(config, dict): + return cls.from_dict(config, key=key) + + msg = f"Unsupported configuration type: {type(config)}. Expected dict, file path, Pydantic model, or dataclass." + raise TypeError(msg) + + +def from_config[T: type](cls: T) -> T: + """Decorate a class with the same config constructors as ``FromConfig``.""" + for name in ("from_yaml", "from_dict", "from_pydantic", "from_dataclass", "from_config"): + setattr(cls, name, FromConfig.__dict__[name]) + return cls diff --git a/src/physicalai/config/serializable.py b/src/physicalai/config/serializable.py new file mode 100644 index 0000000..2a8ea96 --- /dev/null +++ b/src/physicalai/config/serializable.py @@ -0,0 +1,132 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# ruff: noqa: DOC201, DOC501 + +"""Serialization utilities for dataclasses.""" + +from __future__ import annotations + +import dataclasses +import operator +import types +from enum import Enum +from functools import reduce +from itertools import starmap +from typing import TYPE_CHECKING, Union, get_args, get_origin, get_type_hints + +if TYPE_CHECKING: + from collections.abc import Mapping + +_MIN_DICT_TYPE_ARGS = 2 +_VAR_TUPLE_ARG_COUNT = 2 + +__all__ = ["dataclass_to_dict", "dict_to_dataclass"] + + +def dataclass_to_dict(obj: object, *, recursive: bool = True) -> object: # noqa: PLR0911 + """Convert a dataclass or nested structure to plain Python data.""" + if dataclasses.is_dataclass(obj) and not isinstance(obj, type): + if not recursive: + return {field.name: getattr(obj, field.name) for field in dataclasses.fields(obj)} + return {field.name: dataclass_to_dict(getattr(obj, field.name)) for field in dataclasses.fields(obj)} + + if not recursive: + return obj + + if isinstance(obj, dict): + return {(k.value if isinstance(k, Enum) else k): dataclass_to_dict(v) for k, v in obj.items()} + + if isinstance(obj, (list, tuple)): + return [dataclass_to_dict(item) for item in obj] + + if isinstance(obj, Enum): + return obj.value + + if hasattr(obj, "tolist") and hasattr(obj, "ndim"): + return obj.tolist() # type: ignore[union-attr] + + return obj + + +def dict_to_dataclass[T](cls: type[T], data: Mapping[str, object]) -> T: + """Reconstruct a dataclass from a dict using type hints.""" + if not dataclasses.is_dataclass(cls): + msg = f"Expected dataclass, got {cls}" + raise TypeError(msg) + + try: + hints = get_type_hints(cls) + except Exception: # noqa: BLE001 + hints = {} + + kwargs = {} + for field in dataclasses.fields(cls): + if field.name not in data: + continue + value = data[field.name] + field_type = hints.get(field.name, field.type) + kwargs[field.name] = _reconstruct_value(value, field_type) + + return cls(**kwargs) # type: ignore[return-value] + + +def _reconstruct_value(value: object, field_type: object) -> object: # noqa: PLR0911 + """Reconstruct a value based on its expected type.""" + if value is None: + return None + + origin = get_origin(field_type) + args = get_args(field_type) + + if origin is type(None): + return None + + if _is_optional_type(field_type): + return _reconstruct_value(value, _get_optional_inner_type(field_type)) + + if origin is dict and isinstance(value, dict): + if len(args) >= _MIN_DICT_TYPE_ARGS: + return {k: _reconstruct_value(v, args[1]) for k, v in value.items()} + return value + + if origin is list and isinstance(value, list): + if args: + return [_reconstruct_value(item, args[0]) for item in value] + return value + + if origin is tuple and isinstance(value, list): + if args: + if len(args) == _VAR_TUPLE_ARG_COUNT and args[1] is ...: + return tuple(_reconstruct_value(item, args[0]) for item in value) + return tuple(starmap(_reconstruct_value, zip(value, args, strict=False))) + return tuple(value) + + actual_type = origin or field_type + if isinstance(actual_type, type) and dataclasses.is_dataclass(actual_type) and isinstance(value, dict): + return dict_to_dataclass(actual_type, value) + + if isinstance(actual_type, type) and issubclass(actual_type, Enum) and not isinstance(value, Enum): + return actual_type(value) + + return value + + +def _is_optional_type(field_type: object) -> bool: + """Check if a type is Optional[X].""" + origin = get_origin(field_type) + if origin is None: + return False + if origin is types.UnionType: + return type(None) in get_args(field_type) + if origin is Union: + return type(None) in get_args(field_type) + return False + + +def _get_optional_inner_type(field_type: object) -> object: + """Get the non-None inner type from Optional[X].""" + non_none_args = [arg for arg in get_args(field_type) if arg is not type(None)] + if len(non_none_args) == 1: + return non_none_args[0] + return reduce(operator.or_, non_none_args) diff --git a/src/physicalai/inference/component_factory.py b/src/physicalai/inference/component_factory.py index 71136a3..9875ece 100644 --- a/src/physicalai/inference/component_factory.py +++ b/src/physicalai/inference/component_factory.py @@ -6,7 +6,7 @@ The :class:`ComponentRegistry` maps short names (e.g. ``"single_pass"``) to fully-qualified class paths so that manifests can use concise identifiers instead of full dotted paths. The :func:`instantiate_component` -factory resolves a :class:`~physicalai.inference.manifest.ComponentSpec` +factory resolves a :class:`~physicalai.config.ComponentSpec` to an object instance, supporting both ``type`` + flat params and ``class_path`` + ``init_args`` resolution modes. """ @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from physicalai.inference.manifest import ComponentSpec + from physicalai.config import ComponentSpec class ComponentRegistry: diff --git a/src/physicalai/inference/manifest.py b/src/physicalai/inference/manifest.py index 1e10785..b2c3dab 100644 --- a/src/physicalai/inference/manifest.py +++ b/src/physicalai/inference/manifest.py @@ -18,20 +18,17 @@ from __future__ import annotations -import inspect import json from pathlib import Path from typing import Any -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from physicalai.config import ComponentSpec MANIFEST_VERSION = "1.0" MANIFEST_FORMAT = "policy_package" -# Alias builtin ``type`` so it remains accessible inside classes that -# define a Pydantic field with the same name (e.g. ``ComponentSpec.type``). -_type = type - class TensorSpec(BaseModel): """Shape and dtype descriptor for one tensor. @@ -147,104 +144,6 @@ class PolicySpec(BaseModel): source: PolicySource = Field(default_factory=PolicySource) -class ComponentSpec(BaseModel): - """Dual-resolution component descriptor for dynamic instantiation. - - Supports two resolution modes: - - 1. **type + flat params** (LeRobot-compatible):: - - {"type": "single_pass"} - - 2. **class_path + init_args** (full-power PhysicalAI):: - - {"class_path": "physicalai.inference.runners.SinglePass", - "init_args": {}} - - When ``class_path`` is present it takes precedence. When only - ``type`` is present, the :class:`ComponentRegistry` resolves it. - - Attributes: - type: Registered short name (e.g. ``"single_pass"``). - class_path: Fully-qualified class path for direct import. - init_args: Keyword arguments forwarded to the constructor - (used with ``class_path`` mode). - """ - - model_config = ConfigDict(frozen=True, extra="allow") - type: str = "" - class_path: str = "" - init_args: dict[str, Any] = Field(default_factory=dict) - - @model_validator(mode="after") - def _must_have_type_or_class_path(self) -> ComponentSpec: - if not self.type and not self.class_path: - msg = "ComponentSpec requires either 'type' or 'class_path'" - raise ValueError(msg) - return self - - @property - def flat_params(self) -> dict[str, Any]: - """Return extra fields as flat params (type-based resolution). - - Returns all fields stored in ``model_extra`` — these are the - flat kwargs passed alongside ``type`` in LeRobot-style specs. - """ - return dict(self.model_extra) if self.model_extra else {} - - @classmethod - def from_class(cls, target: _type, **overrides: Any) -> ComponentSpec: # noqa: ANN401 - """Build a spec by introspecting a class constructor. - - Parameters not present in *overrides* use their default values. - Required parameters without defaults must be provided in *overrides* - or a TypeError is raised. - - For nested components, pass a ``ComponentSpec`` instance (e.g. from - another ``from_class`` call) — it will be serialized automatically. - - Args: - target: The class to build a spec for. - **overrides: Values that override or supply constructor args. - - Returns: - A ``ComponentSpec`` ready for serialisation or instantiation. - - Raises: - TypeError: If required parameters are missing from overrides. - """ - sig = inspect.signature(target) - init_args: dict[str, Any] = {} - missing: list[str] = [] - - for name, param in sig.parameters.items(): - if name == "self": - continue - if name in overrides: - value = overrides[name] - elif param.default is not param.empty: - value = param.default - else: - missing.append(name) - continue - - if isinstance(value, ComponentSpec): - value = value.model_dump() - init_args[name] = value - - if missing: - msg = ( - f"Missing required parameters for {target.__qualname__}: " - f"{', '.join(missing)}. Pass them as keyword arguments." - ) - raise TypeError(msg) - - return cls( - class_path=f"{target.__module__}.{target.__qualname__}", - init_args=init_args, - ) - - class ModelSpec(BaseModel): """Model inference specification. diff --git a/src/physicalai/inference/model.py b/src/physicalai/inference/model.py index 39f4fc4..582515a 100644 --- a/src/physicalai/inference/model.py +++ b/src/physicalai/inference/model.py @@ -15,13 +15,14 @@ from physicalai.inference.adapters import adapter_registry, get_adapter from physicalai.inference.component_factory import instantiate_component, resolve_artifact from physicalai.inference.constants import ACTION -from physicalai.inference.manifest import ComponentSpec, Manifest +from physicalai.inference.manifest import Manifest from physicalai.inference.runners import get_runner from physicalai.inference.utils import ActionCursor if TYPE_CHECKING: import numpy as np + from physicalai.config import ComponentSpec from physicalai.inference.adapters.base import RuntimeAdapter from physicalai.inference.callbacks.base import Callback from physicalai.inference.postprocessors.base import Postprocessor diff --git a/src/physicalai/inference/runners/factory.py b/src/physicalai/inference/runners/factory.py index fd479ae..9355606 100644 --- a/src/physicalai/inference/runners/factory.py +++ b/src/physicalai/inference/runners/factory.py @@ -49,8 +49,8 @@ def get_runner(source: Manifest | dict[str, Any]) -> InferenceRunner: runner_spec = _extract_runner_spec(source) if runner_spec is not None: + from physicalai.config import ComponentSpec # noqa: PLC0415 from physicalai.inference.component_factory import instantiate_component # noqa: PLC0415 - from physicalai.inference.manifest import ComponentSpec # noqa: PLC0415 runner = instantiate_component(ComponentSpec.model_validate(runner_spec)) if not isinstance(runner, InferenceRunner): diff --git a/tests/unit/inference/test_manifest.py b/tests/unit/inference/test_manifest.py index f228bc6..80626ed 100644 --- a/tests/unit/inference/test_manifest.py +++ b/tests/unit/inference/test_manifest.py @@ -10,6 +10,7 @@ import pytest from pydantic import ValidationError +from physicalai.config import ComponentSpec from physicalai.inference.component_factory import ( ComponentRegistry, component_registry, @@ -18,7 +19,6 @@ ) from physicalai.inference.manifest import ( CameraSpec, - ComponentSpec, HardwareSpec, Manifest, MetadataSpec, @@ -30,9 +30,14 @@ TensorSpec, _policy_name_from_class_path, ) +from physicalai.inference.manifest import ComponentSpec as ManifestComponentSpec from physicalai.inference.runners import SinglePass +def test_manifest_reexports_component_spec() -> None: + assert ManifestComponentSpec is ComponentSpec + + class TestTensorSpec: def test_from_dict_defaults(self) -> None: spec = TensorSpec.model_validate({"shape": [14]}) From a45ecb0d605ed27828f2a4aa15a3059009624b2a Mon Sep 17 00:00:00 2001 From: Samet Date: Tue, 19 May 2026 15:59:01 +0200 Subject: [PATCH 2/3] fix(config): validate yaml roots in shared loaders --- src/physicalai/config/base.py | 11 +++ src/physicalai/config/instantiate.py | 13 +++ src/physicalai/config/mixin.py | 6 ++ tests/unit/config/test_config.py | 118 +++++++++++++++++++++++++++ 4 files changed, 148 insertions(+) create mode 100644 tests/unit/config/test_config.py diff --git a/src/physicalai/config/base.py b/src/physicalai/config/base.py index 9a4a693..fb69b6e 100644 --- a/src/physicalai/config/base.py +++ b/src/physicalai/config/base.py @@ -78,7 +78,18 @@ def load(cls, path: str | Path) -> Self: with path.open() as f: data = yaml.safe_load(f) + if data is None: + data = {} + if not isinstance(data, Mapping): + msg = f"Expected YAML root to be a mapping, got {type(data).__name__}" + raise TypeError(msg) + if "init_args" in data: data = data["init_args"] + if data is None: + data = {} + if not isinstance(data, Mapping): + msg = f"Expected 'init_args' to be a mapping, got {type(data).__name__}" + raise TypeError(msg) return cls.from_dict(data) diff --git a/src/physicalai/config/instantiate.py b/src/physicalai/config/instantiate.py index 17c106e..fbcfe34 100644 --- a/src/physicalai/config/instantiate.py +++ b/src/physicalai/config/instantiate.py @@ -7,6 +7,7 @@ import dataclasses import importlib +from collections.abc import Mapping from pathlib import Path from typing import TYPE_CHECKING @@ -48,11 +49,18 @@ def instantiate_obj_from_dict( target_cls: type | None = None, ) -> object: """Instantiate an object from a configuration dictionary.""" + if not isinstance(config, Mapping): + msg = f"Expected configuration to be a mapping, got {type(config).__name__}" + raise TypeError(msg) + if key is not None: if key not in config: msg = f"Configuration must contain '{key}' key. Got keys: {list(config.keys())}" raise ValueError(msg) config = config[key] + if not isinstance(config, Mapping): + msg = f"Configuration at key '{key}' must be a mapping, got {type(config).__name__}" + raise TypeError(msg) if "class_path" in config: cls = _import_class(config["class_path"]) @@ -111,6 +119,11 @@ def instantiate_obj_from_file( """Instantiate an object from a YAML/JSON configuration file.""" with Path(file_path).open("r", encoding="utf-8") as f: config = yaml.safe_load(f) + if config is None: + config = {} + if not isinstance(config, Mapping): + msg = f"Expected YAML root to be a mapping, got {type(config).__name__}" + raise TypeError(msg) return instantiate_obj_from_dict(config, key=key, target_cls=target_cls) diff --git a/src/physicalai/config/mixin.py b/src/physicalai/config/mixin.py index ed1d330..349da8c 100644 --- a/src/physicalai/config/mixin.py +++ b/src/physicalai/config/mixin.py @@ -6,6 +6,7 @@ """Configuration mixins for adding from_config functionality.""" import dataclasses +from collections.abc import Mapping from pathlib import Path from typing import Any, Self, cast @@ -24,6 +25,11 @@ def from_yaml(cls, file_path: str | Path, *, key: str | None = None) -> Self: """Load configuration from a YAML file and instantiate the class.""" with Path(file_path).open("r", encoding="utf-8") as f: config = yaml.safe_load(f) + if config is None: + config = {} + if not isinstance(config, Mapping): + msg = f"Expected YAML root to be a mapping, got {type(config).__name__}" + raise TypeError(msg) return cls.from_dict(config, key=key) @classmethod diff --git a/tests/unit/config/test_config.py b/tests/unit/config/test_config.py new file mode 100644 index 0000000..21cb4f3 --- /dev/null +++ b/tests/unit/config/test_config.py @@ -0,0 +1,118 @@ +# Copyright (C) 2025-2026 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# ruff: noqa: S101 + +from __future__ import annotations + +from dataclasses import dataclass + +import pytest + +from physicalai.config import Config, instantiate_obj +from physicalai.config.mixin import FromConfig + + +class NestedComponent: + def __init__(self, value: int) -> None: + self.value = value + + +class SampleModel(FromConfig): + def __init__(self, hidden_size: int, component: NestedComponent | None = None) -> None: + self.hidden_size = hidden_size + self.component = component + + +@dataclass +class SampleConfig(Config): + hidden_size: int = 128 + + +class TestInstantiateObj: + def test_instantiates_nested_config(self) -> None: + model = instantiate_obj({ + "class_path": f"{SampleModel.__module__}.SampleModel", + "init_args": { + "hidden_size": 256, + "component": { + "class_path": f"{NestedComponent.__module__}.NestedComponent", + "init_args": {"value": 7}, + }, + }, + }) + + assert isinstance(model, SampleModel) + assert model.hidden_size == 256 + assert isinstance(model.component, NestedComponent) + assert model.component.value == 7 + + @pytest.mark.parametrize( + ("contents", "message"), + [ + ("", "class_path"), + ("- not\n- a\n- mapping\n", "Expected YAML root to be a mapping"), + ], + ) + def test_file_validation(self, tmp_path, contents: str, message: str) -> None: + path = tmp_path / "config.yaml" + path.write_text(contents) + + with pytest.raises((TypeError, ValueError), match=message): + instantiate_obj(path) + + def test_key_requires_mapping_value(self) -> None: + with pytest.raises(TypeError, match="Configuration at key 'model' must be a mapping"): + instantiate_obj({"model": 3}, key="model") + + +class TestFromConfig: + def test_from_yaml_loads_mapping(self, tmp_path) -> None: + path = tmp_path / "model.yaml" + path.write_text("hidden_size: 512\n") + + model = SampleModel.from_yaml(path) + + assert model.hidden_size == 512 + + def test_from_yaml_rejects_non_mapping_root(self, tmp_path) -> None: + path = tmp_path / "model.yaml" + path.write_text("- hidden_size\n") + + with pytest.raises(TypeError, match="Expected YAML root to be a mapping"): + SampleModel.from_yaml(path) + + +class TestConfigLoad: + def test_load_supports_dict_and_jsonargparse_formats(self, tmp_path) -> None: + dict_path = tmp_path / "dict.yaml" + dict_path.write_text("hidden_size: 256\n") + + jsonargparse_path = tmp_path / "jsonargparse.yaml" + jsonargparse_path.write_text( + "class_path: builtins.dict\n" + "init_args:\n" + " hidden_size: 512\n", + ) + + assert SampleConfig.load(dict_path).hidden_size == 256 + assert SampleConfig.load(jsonargparse_path).hidden_size == 512 + + def test_load_empty_yaml_uses_defaults(self, tmp_path) -> None: + path = tmp_path / "config.yaml" + path.write_text("") + + assert SampleConfig.load(path).hidden_size == 128 + + @pytest.mark.parametrize( + ("contents", "message"), + [ + ("- not\n- a\n- mapping\n", "Expected YAML root to be a mapping"), + ("class_path: builtins.dict\ninit_args: 3\n", "Expected 'init_args' to be a mapping"), + ], + ) + def test_load_validates_yaml_shape(self, tmp_path, contents: str, message: str) -> None: + path = tmp_path / "config.yaml" + path.write_text(contents) + + with pytest.raises(TypeError, match=message): + SampleConfig.load(path) From 8f037515f5d295b2147ebbc1342357eb5e95f723 Mon Sep 17 00:00:00 2001 From: Samet Date: Tue, 19 May 2026 16:01:37 +0200 Subject: [PATCH 3/3] fix(config): align mapping types for pyrefly --- src/physicalai/config/instantiate.py | 11 +++++------ src/physicalai/config/mixin.py | 4 ++-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/physicalai/config/instantiate.py b/src/physicalai/config/instantiate.py index fbcfe34..1216367 100644 --- a/src/physicalai/config/instantiate.py +++ b/src/physicalai/config/instantiate.py @@ -18,6 +18,9 @@ from typing import Any +ConfigMapping = Mapping[str, "Any"] + + def _import_class(class_path: str) -> type: """Import a class from a module path.""" try: @@ -43,16 +46,12 @@ def _instantiate_recursive(value: "Any") -> "Any": # noqa: ANN401 def instantiate_obj_from_dict( - config: dict[str, "Any"], + config: ConfigMapping, *, key: str | None = None, target_cls: type | None = None, ) -> object: """Instantiate an object from a configuration dictionary.""" - if not isinstance(config, Mapping): - msg = f"Expected configuration to be a mapping, got {type(config).__name__}" - raise TypeError(msg) - if key is not None: if key not in config: msg = f"Configuration must contain '{key}' key. Got keys: {list(config.keys())}" @@ -128,7 +127,7 @@ def instantiate_obj_from_file( def instantiate_obj( - config: dict[str, "Any"] | BaseModel | object | str | Path, + config: ConfigMapping | BaseModel | object | str | Path, *, key: str | None = None, target_cls: type | None = None, diff --git a/src/physicalai/config/mixin.py b/src/physicalai/config/mixin.py index 349da8c..93fc244 100644 --- a/src/physicalai/config/mixin.py +++ b/src/physicalai/config/mixin.py @@ -33,7 +33,7 @@ def from_yaml(cls, file_path: str | Path, *, key: str | None = None) -> Self: return cls.from_dict(config, key=key) @classmethod - def from_dict(cls, config: dict[str, Any], *, key: str | None = None) -> Self: + def from_dict(cls, config: Mapping[str, Any], *, key: str | None = None) -> Self: """Instantiate the class from a configuration dictionary.""" return cast("Self", instantiate_obj_from_dict(config, key=key, target_cls=cls)) @@ -71,7 +71,7 @@ def from_dataclass( @classmethod def from_config( cls, - config: dict[str, Any] | BaseModel | object | str | Path, + config: Mapping[str, Any] | BaseModel | object | str | Path, *, key: str | None = None, recursive: bool = False,