diff --git a/judo/app/dora/controller.py b/judo/app/dora/controller_node.py similarity index 72% rename from judo/app/dora/controller.py rename to judo/app/dora/controller_node.py index 63ec9b7c..872b066e 100644 --- a/judo/app/dora/controller.py +++ b/judo/app/dora/controller_node.py @@ -2,6 +2,7 @@ import time from threading import Lock +from typing import Callable import pyarrow as pa from dora_utils.dataclasses import from_event, to_arrow @@ -23,38 +24,60 @@ def __init__( max_workers: int | None = None, task_registration_cfg: DictConfig | None = None, optimizer_registration_cfg: DictConfig | None = None, + make_controller_fn: Callable | None = None, ) -> None: - """Initialize the controller node.""" + """Initialize the controller node. + + Args: + init_task: Name of the task to initialize. + init_optimizer: Name of the optimizer to initialize (e.g., "cem", "ps", "mppi"). + node_id: Identifier for this dora node. + max_workers: Maximum number of worker threads for dora (None = auto). + task_registration_cfg: Optional config for task registration overrides. + optimizer_registration_cfg: Optional config for optimizer registration overrides. + make_controller_fn: Optional factory function to create Controller instances. + Defaults to judo.controller.make_controller. Allows custom controller creation. + """ super().__init__(node_id=node_id, max_workers=max_workers) - self.controller = make_controller( - init_task=init_task, - init_optimizer=init_optimizer, - task_registration_cfg=task_registration_cfg, - optimizer_registration_cfg=optimizer_registration_cfg, - ) + self._make_controller_fn = make_controller_fn or make_controller + self._task_registration_cfg = task_registration_cfg + self._optimizer_registration_cfg = optimizer_registration_cfg + self.controller = self._build_controller(init_task, init_optimizer) self._paused = False self.write_controls() self.lock = Lock() + def _build_controller(self, task_name: str, optimizer_name: str) -> Controller: + """Build controller using the task's registered rollout backend.""" + return self._make_controller_fn( + init_task=task_name, + init_optimizer=optimizer_name, + task_registration_cfg=self._task_registration_cfg, + optimizer_registration_cfg=self._optimizer_registration_cfg, + ) + + def _current_optimizer_name(self) -> str: + """Look up the name of the current optimizer from the registry. + + Returns "cem" as a safe default if no registry entry matches the active optimizer instance. + """ + for name, (cls, _) in self.controller.available_optimizers.items(): + if isinstance(self.controller.optimizer, cls): + return name + return "cem" + @on_event("INPUT", "task") def update_task(self, event: dict) -> None: """Updates the task type.""" new_task = event["value"].to_numpy(zero_copy_only=False)[0] task_entry = self.controller.available_tasks.get(new_task) - if task_entry is not None: - task_cls, _ = task_entry - with self.lock: - task = task_cls() - optimizer = self.controller.optimizer_cls(self.controller.optimizer_config_cls(), task.nu) - self.controller = Controller( - controller_config=self.controller.controller_cfg, - task=task, - optimizer=optimizer, - ) - self.write_controls() - else: + if task_entry is None: raise ValueError(f"Task {new_task} not found in task registry.") + with self.lock: + self.controller = self._build_controller(new_task, self._current_optimizer_name()) + self.write_controls() + @on_event("INPUT", "task_reset") def reset_task(self, event: dict) -> None: """Resets the task.""" @@ -75,6 +98,7 @@ def update_optimizer(self, event: dict) -> None: if optimizer_entry is not None: optimizer_cls, optimizer_config_cls = optimizer_entry optimizer_config = optimizer_config_cls() + optimizer_config.set_override(self.controller.task.name) optimizer = optimizer_cls(optimizer_config, self.controller.task.nu) with self.lock: self.controller.optimizer = optimizer diff --git a/judo/app/dora/simulation.py b/judo/app/dora/simulation_node.py similarity index 65% rename from judo/app/dora/simulation.py rename to judo/app/dora/simulation_node.py index 054fc595..d3925847 100644 --- a/judo/app/dora/simulation.py +++ b/judo/app/dora/simulation_node.py @@ -9,7 +9,9 @@ from omegaconf import DictConfig from judo.app.structs import SplineData -from judo.simulation import get_simulation_backend +from judo.simulation import DEFAULT_SIMULATION_BACKEND_REGISTRY +from judo.simulation.base import Simulation +from judo.tasks import get_registered_tasks class SimulationNode(DoraNode): @@ -21,26 +23,40 @@ def __init__( init_task: str = "cylinder_push", max_workers: int | None = None, task_registration_cfg: DictConfig | None = None, - simulation_backend: str = "mujoco", + backend_registry: dict[str, type[Simulation]] | None = None, ) -> None: - """Initialize the simulation node.""" + """Initialize the simulation node. + + Args: + node_id: Identifier for this dora node. + init_task: Name of the task to initialize. + max_workers: Maximum number of worker threads for dora (None = auto). + task_registration_cfg: Optional config for task registration overrides. + backend_registry: Optional mapping of backend names → Simulation classes. Checked first before built-in registry. + """ super().__init__(node_id=node_id, max_workers=max_workers) - self._simulation_backend = simulation_backend self._task_registration_cfg = task_registration_cfg + self._backend_registry = dict(DEFAULT_SIMULATION_BACKEND_REGISTRY) + self._backend_registry.update(backend_registry or {}) self._init_sim(init_task) self.control_spline: Callable | None = None self.write_states() + def _resolve_backend(self, backend_name: str) -> type[Simulation]: + """Resolve a simulation backend class by name from merged registry.""" + backend_cls = self._backend_registry.get(backend_name) + if backend_cls is None: + raise KeyError(f"Unknown simulation backend: {backend_name!r}") + return backend_cls + def _init_sim(self, task_name: str) -> None: - """Initialize simulation, auto-upgrading to policy backend if needed.""" - backend = self._simulation_backend - _sim_backend = get_simulation_backend(backend) - self.sim = _sim_backend(init_task=task_name, task_registration_cfg=self._task_registration_cfg) + """Initialize simulation using the task's registered simulation backend.""" + task_entry = get_registered_tasks().get(task_name) + if task_entry is None: + raise ValueError(f"Task {task_name} not found in task registry.") - # Auto-upgrade to policy backend if task requires locomotion policy - if backend == "mujoco" and self.sim.task.uses_locomotion_policy: - _sim_backend = get_simulation_backend("mujoco_policy") - self.sim = _sim_backend(init_task=task_name, task_registration_cfg=self._task_registration_cfg) + sim_backend_cls = self._resolve_backend(task_entry.simulation_backend) + self.sim = sim_backend_cls(init_task=task_name, task_registration_cfg=self._task_registration_cfg) @on_event("INPUT", "task") def update_task(self, event: dict) -> None: @@ -85,6 +101,8 @@ def write_states(self) -> None: """Reads data from simulation and writes to output topic.""" arr, metadata = to_arrow(self.sim.sim_state) self.node.send_output("states", arr, metadata) + arr, metadata = to_arrow(self.sim.render_pose) + self.node.send_output("render_pose", arr, metadata) @on_event("INPUT", "sim_pause") def set_paused_status(self, event: dict) -> None: diff --git a/judo/app/dora/visualization.py b/judo/app/dora/visualization_node.py similarity index 80% rename from judo/app/dora/visualization.py rename to judo/app/dora/visualization_node.py index 96381f9e..462155ae 100644 --- a/judo/app/dora/visualization.py +++ b/judo/app/dora/visualization_node.py @@ -8,7 +8,8 @@ from omegaconf import DictConfig from viser import GuiFolderHandle, GuiImageHandle, GuiInputHandle, IcosphereHandle, MeshHandle -from judo.app.structs import MujocoState +from judo.app.structs import RenderPose +from judo.tasks import TaskRegistration from judo.visualizers.visualizer import Visualizer ElementType = GuiImageHandle | GuiInputHandle | GuiFolderHandle | MeshHandle | IcosphereHandle @@ -29,8 +30,24 @@ def __init__( optimizer_override_cfg: DictConfig | None = None, sim_pause_button: bool = True, geom_exclude_substring: str = "collision", + available_tasks: dict[str, TaskRegistration] | None = None, ) -> None: - """Initialize the visualization node.""" + """Initialize the visualization node (Viser web GUI for task/optimizer control). + + Args: + node_id: Identifier for this dora node. + max_workers: Maximum number of worker threads for dora (None = auto). + init_task: Name of the task to initialize. + init_optimizer: Name of the optimizer to initialize (e.g., "cem", "ps"). + task_registration_cfg: Optional config for task registration overrides. + optimizer_registration_cfg: Optional config for optimizer registration overrides. + controller_override_cfg: Optional config overrides for the controller. + optimizer_override_cfg: Optional config overrides for the optimizer. + sim_pause_button: Whether to display a simulation pause button in the GUI. + geom_exclude_substring: Geometry name substring to exclude from visualization (default "collision" hides collision shapes). + available_tasks: Optional pre-computed mapping of task names to TaskRegistration entries + for the task selector. If None, tasks are inferred from the task registry. + """ super().__init__(node_id=node_id, max_workers=max_workers) self.visualizer = Visualizer( init_task=init_task, @@ -41,6 +58,7 @@ def __init__( optimizer_override_cfg=optimizer_override_cfg, sim_pause_button=sim_pause_button, geom_exclude_substring=geom_exclude_substring, + available_tasks=available_tasks, ) def write_sim_pause(self) -> None: @@ -85,7 +103,7 @@ def write_task_config(self) -> None: self.node.send_output("task_config", *to_arrow(self.visualizer.task_config)) self.visualizer.task_config_updated.clear() - @on_event("INPUT", "states") + @on_event("INPUT", "render_pose") def update_states(self, event: dict) -> None: """Callback to update states on receiving a new state measurement.""" if self.visualizer.controller_config.spline_order == "cubic" and self.visualizer.optimizer_config.num_nodes < 4: @@ -96,11 +114,11 @@ def update_states(self, event: dict) -> None: break self.visualizer.optimizer_config_updated.set() - state_msg = from_arrow(event["value"], event["metadata"], MujocoState) + render_pose_msg = from_arrow(event["value"], event["metadata"], RenderPose) try: with self.visualizer.task_lock: - self.visualizer.data.xpos[:] = state_msg.xpos - self.visualizer.data.xquat[:] = state_msg.xquat + self.visualizer.data.xpos[:] = render_pose_msg.xpos + self.visualizer.data.xquat[:] = render_pose_msg.xquat self.visualizer.viser_model.set_data(self.visualizer.data) except ValueError: # we're switching tasks and the new task has a different number of xpos/xquat diff --git a/judo/app/structs.py b/judo/app/structs.py index 1a65eb41..3769efe4 100644 --- a/judo/app/structs.py +++ b/judo/app/structs.py @@ -34,8 +34,6 @@ class MujocoState: time: float qpos: np.ndarray qvel: np.ndarray - xpos: np.ndarray - xquat: np.ndarray mocap_pos: np.ndarray mocap_quat: np.ndarray sim_metadata: dict[str, Any] @@ -82,3 +80,11 @@ def spline(self) -> interp1d: fill_value=fill_value, # type: ignore bounds_error=not self.extrapolate, ) + + +@dataclass +class RenderPose: + """Struct for visualization poses used by the renderer.""" + + xpos: np.ndarray + xquat: np.ndarray diff --git a/judo/app/utils.py b/judo/app/utils.py index 09aa632f..340609da 100644 --- a/judo/app/utils.py +++ b/judo/app/utils.py @@ -17,21 +17,73 @@ def get_class_from_string(class_path: str) -> type: def register_tasks_from_cfg(task_registration_cfg: DictConfig) -> None: - """Register custom tasks.""" + """Register custom tasks. + + Args: + task_registration_cfg: Mapping keyed by task name. Each value must contain: + `task`: import path to the Task class + `config`: import path to the TaskConfig class + + Optional keys: + `rollout_backend`: rollout backend registry key for controllers + `simulation_backend`: simulation backend registry key for the simulation node + `locomotion_policy_path`: path to a low-level policy used by hierarchical tasks + + Example schema: + { + "cylinder_push": { + "task": "judo.tasks.cylinder_push.CylinderPush", + "config": "judo.tasks.cylinder_push.CylinderPushConfig", + "rollout_backend": "mujoco", + "simulation_backend": "mujoco", + } + } + """ for task_name in task_registration_cfg.keys(): task_dict = task_registration_cfg.get(task_name, {}) - assert set(task_dict.keys()) == {"task", "config"}, ( - "Task registration must be a dict with keys 'task' and 'config'." + allowed_keys = {"task", "config", "rollout_backend", "simulation_backend", "locomotion_policy_path"} + assert set(task_dict.keys()).issubset(allowed_keys) and {"task", "config"}.issubset(task_dict.keys()), ( + "Task registration must include 'task' and 'config', and may optionally include " + "'rollout_backend', 'simulation_backend', and 'locomotion_policy_path'." ) assert isinstance(task_dict["task"], str), "Task must be a string path to the task class." assert isinstance(task_dict["config"], str), "Task config must be a string path to the config class." task_cls = get_class_from_string(task_dict["task"]) task_config_cls = get_class_from_string(task_dict["config"]) - register_task(str(task_name), task_cls, task_config_cls) + rollout_backend = task_dict.get("rollout_backend", "mujoco") + simulation_backend = task_dict.get("simulation_backend", "mujoco") + locomotion_policy_path = task_dict.get("locomotion_policy_path", None) + assert isinstance(rollout_backend, str), "rollout_backend must be a string." + assert isinstance(simulation_backend, str), "simulation_backend must be a string." + assert locomotion_policy_path is None or isinstance(locomotion_policy_path, str), ( + "locomotion_policy_path must be a string if provided." + ) + register_task( + str(task_name), + task_cls, + task_config_cls, + rollout_backend=rollout_backend, + simulation_backend=simulation_backend, + locomotion_policy_path=locomotion_policy_path, + ) def register_optimizers_from_cfg(optimizer_registration_cfg: DictConfig) -> None: - """Register custom optimizers.""" + """Register custom optimizers. + + Args: + optimizer_registration_cfg: Mapping keyed by optimizer name. Each value must contain: + `optimizer`: import path to the Optimizer class + `config`: import path to the OptimizerConfig class + + Example schema: + { + "cem": { + "optimizer": "judo.optimizers.cem.CrossEntropyMethod", + "config": "judo.optimizers.cem.CrossEntropyMethodConfig", + } + } + """ for optimizer_name in optimizer_registration_cfg.keys(): optimizer_dict = optimizer_registration_cfg.get(optimizer_name, {}) assert set(optimizer_dict.keys()) == {"optimizer", "config"}, ( diff --git a/judo/cli.py b/judo/cli.py index f5cabf69..7a9a29d4 100644 --- a/judo/cli.py +++ b/judo/cli.py @@ -141,8 +141,19 @@ def _warm_caches() -> None: pass # non-Spot tasks don't need this +def _require_mujoco_extensions() -> None: + """Fail fast if mujoco_extensions is unavailable in the current environment.""" + try: + import mujoco_extensions # noqa: F401, PLC0415 + except Exception as e: # pragma: no cover - environment dependent + raise RuntimeError( + "mujoco_extensions is required but could not be imported. Build it with: pixi run build" + ) from e + + def app() -> None: """Entry point for the judo CLI.""" + _require_mujoco_extensions() _warm_caches() # we store judo_dora_default in the config store so that custom dora configs outside of judo can inherit from it cs = ConfigStore.instance() @@ -168,6 +179,7 @@ def main_benchmark(cfg: DictConfig) -> None: def benchmark() -> None: """Entry point for benchmarking.""" + _require_mujoco_extensions() # we store benchmark_default in the config store so that custom configs located outside of judo can inherit from it cs = ConfigStore.instance() with initialize_config_dir(config_dir=str(CONFIG_PATH), version_base="1.3"): diff --git a/judo/configs/judo_dora_default.yaml b/judo/configs/judo_dora_default.yaml index f667539d..0ee7f442 100644 --- a/judo/configs/judo_dora_default.yaml +++ b/judo/configs/judo_dora_default.yaml @@ -20,11 +20,12 @@ dataflow: queue_size: 1 outputs: - states + - render_pose - id: visualization path: dynamic inputs: - states: - source: simulation/states + render_pose: + source: simulation/render_pose queue_size: 1 traces: source: controller/traces @@ -74,14 +75,13 @@ dataflow: node_definitions: simulation: - _target_: judo.app.dora.simulation.SimulationNode + _target_: judo.app.dora.simulation_node.SimulationNode node_id: simulation max_workers: null init_task: ${task} task_registration_cfg: ${custom_tasks} - simulation_backend: ${simulation_backend} visualization: - _target_: judo.app.dora.visualization.VisualizationNode + _target_: judo.app.dora.visualization_node.VisualizationNode node_id: visualization max_workers: null init_task: ${task} @@ -92,7 +92,7 @@ node_definitions: optimizer_override_cfg: ${optimizer_config_overrides} sim_pause_button: true controller: - _target_: judo.app.dora.controller.ControllerNode + _target_: judo.app.dora.controller_node.ControllerNode node_id: controller max_workers: null init_task: ${task} @@ -104,4 +104,3 @@ custom_tasks: null custom_optimizers: null controller_config_overrides: null optimizer_config_overrides: null -simulation_backend: mujoco diff --git a/judo/controller/controller.py b/judo/controller/controller.py index fc25fb94..6f2a5a6c 100644 --- a/judo/controller/controller.py +++ b/judo/controller/controller.py @@ -3,7 +3,7 @@ import copy import warnings from dataclasses import dataclass -from typing import TYPE_CHECKING, Literal +from typing import Any, Callable, Literal import numpy as np from mujoco import MjData @@ -15,7 +15,9 @@ from judo.config import OverridableConfig from judo.gui import slider from judo.optimizers import Optimizer, OptimizerConfig, get_registered_optimizers -from judo.tasks import Task, TaskConfig, get_registered_tasks +from judo.tasks import Task, TaskConfig, get_registered_tasks, get_task_registration +from judo.tasks.spot.spot_constants import POLICY_OUTPUT_DIM +from judo.utils.hierarchical_mj_rollout_backend import HierarchicalMJRolloutBackend from judo.utils.mj_rollout_backend import MJRolloutBackend from judo.utils.normalization import ( IdentityNormalizer, @@ -24,14 +26,18 @@ make_normalizer, normalizer_registry, ) -from judo.utils.policy_mj_rollout_backend import PolicyMJRolloutBackend - -if TYPE_CHECKING: - from judo.utils.mjwarp_rollout_backend import MJWarpRolloutBackend -from judo.utils.rollout_backend import RolloutBackend +from judo.utils.rollout_backend import BatchedRolloutBackend, RolloutBackend from judo.utils.timer import Timer from judo.visualizers.utils import get_trace_sensors +RolloutBackendEntry = type[RolloutBackend] | Callable[..., RolloutBackend] + + +DEFAULT_ROLLOUT_BACKEND_REGISTRY: dict[str, RolloutBackendEntry] = { + "mujoco": MJRolloutBackend, + "mujoco_hierarchical": HierarchicalMJRolloutBackend, +} + @slider("horizon", 0.1, 10.0, bounded=True) @slider("control_freq", 0.25, 50.0) @@ -55,7 +61,9 @@ def __init__( controller_config: ControllerConfig, task: Task, optimizer: Optimizer, - rollout_backend: "Literal['mujoco'] | MJWarpRolloutBackend" = "mujoco", + rollout_backend: str = "mujoco", + rollout_backend_registry: dict[str, RolloutBackendEntry] | None = None, + rollout_backend_kwargs: dict[str, Any] | None = None, ) -> None: """Initialize the controller. @@ -63,37 +71,36 @@ def __init__( controller_config: The controller configuration. task: The task to use. optimizer: The optimizer to use. - rollout_backend: The backend to use for rollouts. Either "mujoco" to create - a new CPU backend, or an existing RolloutBackend instance (e.g. GPU warp). + rollout_backend: Name of the backend to use for rollouts (e.g., "mujoco", "mujoco_hierarchical"). + rollout_backend_registry: Optional mapping of backend names to backend classes. + Overrides entries in DEFAULT_ROLLOUT_BACKEND_REGISTRY. + rollout_backend_kwargs: Optional extra kwargs for rollout backend constructor. + For "mujoco_hierarchical" backend, 'physics_substeps' and 'policy_path' cannot be + specified here—they are sourced from the task and task registry respectively. + Raises ValueError if either is provided. To use different values, create or + update a task registry entry. """ self._controller_cfg = controller_config self.task = task self.optimizer = optimizer + self._rollout_backend_registry = dict(DEFAULT_ROLLOUT_BACKEND_REGISTRY) + self._rollout_backend_registry.update(rollout_backend_registry or {}) + self._rollout_backend_kwargs = rollout_backend_kwargs or {} self.available_optimizers = get_registered_optimizers() self.available_tasks = get_registered_tasks() self.model = self.task.model - # Initialize rollout backend (auto-select policy backend if task requires it) - if not isinstance(rollout_backend, str): - # MJWarpRolloutBackend instance (avoid isinstance check to skip mujoco_warp import) - self.rollout_backend = rollout_backend - elif self.task.uses_locomotion_policy: - assert self.task.locomotion_policy_path is not None - self.rollout_backend: RolloutBackend = PolicyMJRolloutBackend( - model=self.model, - num_threads=self.optimizer_cfg.num_rollouts, - policy_path=self.task.locomotion_policy_path, - physics_substeps=self.task.physics_substeps, - ) - else: - self.rollout_backend = MJRolloutBackend( - model=self.model, - num_threads=self.optimizer_cfg.num_rollouts, - ) - - self._last_policy_output = None + self.rollout_backend: RolloutBackend = self._make_rollout_backend( + rollout_backend, + backend_kwargs=self._rollout_backend_kwargs, + ) + self._last_policy_output = ( + np.zeros((self.optimizer_cfg.num_rollouts, POLICY_OUTPUT_DIM)) + if isinstance(self.rollout_backend, HierarchicalMJRolloutBackend) + else None + ) self.action_normalizer = self._init_action_normalizer() # a container for any metadata from the system that we want to pass to the task @@ -371,9 +378,9 @@ def reset(self) -> None: self.candidate_knots = np.tile(self.nominal_knots, (self.optimizer_cfg.num_rollouts, 1, 1)) self.times = self.task.data.time + self.spline_timesteps self.update_spline(self.times, self.nominal_knots) - # Reset policy output state for locomotion policy tasks - if self.task.uses_locomotion_policy: - self._last_policy_output = None + # Reset policy output state for policy rollout backends + if isinstance(self.rollout_backend, HierarchicalMJRolloutBackend): + self._last_policy_output = np.zeros((self.optimizer_cfg.num_rollouts, POLICY_OUTPUT_DIM)) def update_traces(self) -> None: """Update traces by extracting data from sensors readings. @@ -431,11 +438,67 @@ def _init_action_normalizer(self) -> Normalizer: action_normalizer_kwargs["max"] = self.task.actuator_ctrlrange[:, 1] elif self.action_normalizer_type == "running": action_normalizer_kwargs["init_std"] = 1.0 # TODO(yunhai): make this configurable - return make_normalizer(self.action_normalizer_type, self.model.nu, **action_normalizer_kwargs) + return make_normalizer(self.action_normalizer_type, self.task.nu, **action_normalizer_kwargs) + + def _make_rollout_backend( + self, + backend_name: str, + backend_kwargs: dict[str, Any] | None = None, + ) -> RolloutBackend: + """Instantiate a rollout backend from the merged backend registry.""" + backend_factory = self._rollout_backend_registry.get(backend_name) + if backend_factory is None: + raise ValueError( + f"Unknown rollout backend '{backend_name}'. " + "Provide it via rollout_backend_registry or choose a built-in backend." + ) + + final_kwargs = { + "model": self.model, + "num_threads": self.optimizer_cfg.num_rollouts, + } + final_kwargs.update(backend_kwargs or {}) + + if backend_name == "mujoco_hierarchical": + # For the built-in hierarchical backend, physics_substeps is task-owned + # and should not be overridden at controller construction time. + if "physics_substeps" in final_kwargs: + raise ValueError( + f"Cannot specify 'physics_substeps' in rollout_backend_kwargs. " + f"It is determined by the task configuration (task.physics_substeps). " + f"Current task '{self.task.name}' has physics_substeps={self.task.physics_substeps}." + ) + final_kwargs["physics_substeps"] = self.task.physics_substeps + + # policy_path is currently registry-owned for hierarchical backend wiring. + # This keeps policy/model assumptions centralized until hierarchical ONNX + # integration is generalized on the C++ side. + if "policy_path" in final_kwargs: + raise ValueError( + f"Cannot specify 'policy_path' in rollout_backend_kwargs. " + f"It must be defined in the task registry entry for '{self.task.name}'. " + f"To use a different policy path, create or update a task registry entry with the desired path." + ) + + task_policy_path = get_task_registration(self.task.name).locomotion_policy_path + if task_policy_path is None: + raise ValueError( + f"Backend '{backend_name}' requires 'policy_path'. " + f"Task '{self.task.name}' must have a locomotion_policy_path registered in the task registry." + ) + final_kwargs["policy_path"] = task_policy_path + + backend = backend_factory(**final_kwargs) + if not isinstance(backend, RolloutBackend): + raise TypeError( + f"Rollout backend factory for '{backend_name}' must return a RolloutBackend, " + f"got {type(backend).__name__}." + ) + return backend class BatchedControllers: - """Coordinates multiple controllers sharing a single RolloutBackend. + """Coordinates multiple controllers sharing a single BatchedRolloutBackend. This class manages batched rollouts across multiple controllers, executing a single GPU rollout for all controllers at each optimization iteration. @@ -444,7 +507,7 @@ class BatchedControllers: # Create shared backend with num_threads per problem and num_problems num_rollouts = 64 # rollouts per controller num_problems = 3 # number of controllers - backend = RolloutBackend(model, num_threads=num_rollouts, num_problems=num_problems) + backend = BatchedRolloutBackend(model, num_threads=num_rollouts, num_problems=num_problems) # Create batched controller coordinator batched = BatchedControllers(config, task, optimizer, backend) @@ -458,7 +521,7 @@ def __init__( controller_config: ControllerConfig, task: Task, optimizer: Optimizer, - rollout_backend: "MJWarpRolloutBackend", + rollout_backend: BatchedRolloutBackend, ) -> None: """Initialize the batched controllers. @@ -466,7 +529,7 @@ def __init__( controller_config: Configuration for all controllers. task: Template task instance (new instances created from its class and model_path). optimizer: Template optimizer instance (deep copied for each controller). - rollout_backend: The shared WarpRolloutBackend instance. Should be initialized with + rollout_backend: Shared batched rollout backend instance. Should be initialized with num_problems equal to len(controllers). """ self.num_problems = rollout_backend.num_problems @@ -474,11 +537,19 @@ def __init__( for _ in range(self.num_problems): new_task = task.__class__(model_path=task.model_path) new_task.config = copy.deepcopy(task.config) + + # Construct controllers through normal backend-name resolution, but route + # the shared backend instance via a per-controller registry override. + shared_backend_name = "__shared_batched_backend__" + shared_backend_registry: dict[str, RolloutBackendEntry] = { + shared_backend_name: (lambda **_: rollout_backend) + } controller = Controller( controller_config=controller_config, task=new_task, optimizer=copy.deepcopy(optimizer), - rollout_backend=rollout_backend, + rollout_backend=shared_backend_name, + rollout_backend_registry=shared_backend_registry, ) self.controllers.append(controller) self.rollout_backend = rollout_backend @@ -492,9 +563,9 @@ def __init__( ) # Validate num_problems matches number of controllers - if rollout_backend.num_problems != len(self.controllers): + if self.num_problems != len(self.controllers): raise ValueError( - f"RolloutBackend num_problems ({rollout_backend.num_problems}) does not match " + f"RolloutBackend num_problems ({self.num_problems}) does not match " f"number of controllers ({len(self.controllers)}). " f"Initialize backend with num_problems={len(self.controllers)}." ) @@ -593,8 +664,9 @@ def print_timer_stats(self) -> None: self.timer_rewards.print_stats() self.timer_update_iter.print_stats() self.timer_post_opt.print_stats() - if hasattr(self.rollout_backend, "print_timer_stats"): - self.rollout_backend.print_timer_stats() + backend_print_timer_stats = getattr(self.rollout_backend, "print_timer_stats", None) + if callable(backend_print_timer_stats): + backend_print_timer_stats() def reset_timers(self) -> None: """Reset all timers.""" @@ -603,8 +675,9 @@ def reset_timers(self) -> None: self.timer_rewards.reset() self.timer_update_iter.reset() self.timer_post_opt.reset() - if hasattr(self.rollout_backend, "reset_timers"): - self.rollout_backend.reset_timers() + backend_reset_timers = getattr(self.rollout_backend, "reset_timers", None) + if callable(backend_reset_timers): + backend_reset_timers() def update_states(self, state_msgs: list) -> None: """Update states for all controllers. @@ -630,7 +703,7 @@ def set_init_previous_actions(self, previous_actions_list: list[np.ndarray | Non else: pa_np = np.stack([pa for pa in previous_actions_list if pa is not None], axis=0) pa_broadcast = np.repeat(pa_np, self.rollout_backend.num_threads, axis=0) - import warp as wp # noqa: PLC0415 + import warp as wp # pyright: ignore[reportMissingImports] # noqa: PLC0415 self._last_policy_output = wp.array(pa_broadcast, dtype=wp.float32, device=self.rollout_backend.device) @@ -662,17 +735,24 @@ def make_controller( init_optimizer: str, task_registration_cfg: DictConfig | None = None, optimizer_registration_cfg: DictConfig | None = None, - rollout_backend: "Literal['mujoco'] | MJWarpRolloutBackend" = "mujoco", + controller_cls: type[Controller] | None = None, + **controller_kwargs: Any, ) -> Controller: """Make a controller. Args: init_task: The task name to use. init_optimizer: The optimizer name to use. - task_registration_cfg: Optional task registration config. - optimizer_registration_cfg: Optional optimizer registration config. - rollout_backend: Either a backend type string ("mujoco") to create a new backend, - or an existing WarpRolloutBackend instance to share with other controllers. + task_registration_cfg: Optional task registration overrides keyed by task name. + Each entry must contain `task` and `config` import paths, and may also define + `rollout_backend`, `simulation_backend`, and `locomotion_policy_path`. + See register_tasks_from_cfg for the exact supported schema. + optimizer_registration_cfg: Optional optimizer registration overrides keyed by + optimizer name. Each entry must contain `optimizer` and `config` import paths. + See register_optimizers_from_cfg for the exact supported schema. + controller_cls: Optional controller class to instantiate instead of Controller. + **controller_kwargs: Additional keyword arguments forwarded to the controller + constructor. Returns: The created Controller instance. @@ -689,10 +769,10 @@ def make_controller( assert task_entry is not None, f"Task {init_task} not found in task registry." assert optimizer_entry is not None, f"Optimizer {init_optimizer} not found in optimizer registry." + task_registration = get_task_registration(init_task) # instantiate the task/optimizer/controller - task_cls, _ = task_entry - task = task_cls() + task = task_entry.task_type() optimizer_cls, optimizer_config_cls = optimizer_entry optimizer_cfg = optimizer_config_cls() @@ -702,9 +782,11 @@ def make_controller( controller_cfg = ControllerConfig() controller_cfg.set_override(init_task) - return Controller( + cls = controller_cls or Controller + return cls( controller_config=controller_cfg, task=task, optimizer=optimizer, - rollout_backend=rollout_backend, + rollout_backend=task_registration.rollout_backend, + **controller_kwargs, ) diff --git a/judo/simulation/__init__.py b/judo/simulation/__init__.py index f58c801d..cb16104d 100644 --- a/judo/simulation/__init__.py +++ b/judo/simulation/__init__.py @@ -1,22 +1,33 @@ # Copyright (c) 2025 Robotics and AI Institute LLC. All rights reserved. from judo.simulation.base import Simulation +from judo.simulation.hierarchical_mj_simulation import HierarchicalMJSimulation from judo.simulation.mj_simulation import MJSimulation -from judo.simulation.policy_mj_simulation import PolicyMJSimulation -simulation_registry = { +DEFAULT_SIMULATION_BACKEND_REGISTRY: dict[str, type[Simulation]] = { "mujoco": MJSimulation, - "mujoco_policy": PolicyMJSimulation, + "mujoco_hierarchical": HierarchicalMJSimulation, } def get_simulation_backend(simulation_backend: str) -> type: - """Get the simulation class for a given backend.""" - return simulation_registry[simulation_backend] + """Get the simulation class for a given backend. + + Args: + simulation_backend: Name of the simulation backend to get. + + Returns: + The simulation class for the given backend. + """ + if simulation_backend not in DEFAULT_SIMULATION_BACKEND_REGISTRY: + raise KeyError(f"Unknown simulation backend: {simulation_backend!r}") + return DEFAULT_SIMULATION_BACKEND_REGISTRY[simulation_backend] __all__ = [ "Simulation", "MJSimulation", - "PolicyMJSimulation", + "HierarchicalMJSimulation", + "DEFAULT_SIMULATION_BACKEND_REGISTRY", + "get_simulation_backend", ] diff --git a/judo/simulation/base.py b/judo/simulation/base.py index a32aa20e..bf9cbc77 100644 --- a/judo/simulation/base.py +++ b/judo/simulation/base.py @@ -5,6 +5,7 @@ import numpy as np from omegaconf import DictConfig +from judo.app.structs import MujocoState, RenderPose from judo.app.utils import register_tasks_from_cfg from judo.tasks import get_registered_tasks from judo.tasks.base import Task @@ -37,8 +38,7 @@ def set_task(self, task_name: str) -> None: if task_entry is None: raise ValueError(f"Task {task_name} not found in task registry") - task_cls, _ = task_entry - self.task: Task = task_cls() + self.task: Task = task_entry.task_type() self.task.reset() @abstractmethod @@ -54,6 +54,26 @@ def pause(self) -> None: self.paused = not self.paused @property - @abstractmethod + def sim_state(self) -> MujocoState: + """Returns the current simulation state.""" + return MujocoState( + time=self.task.data.time, # type: ignore + qpos=self.task.data.qpos, # type: ignore + qvel=self.task.data.qvel, # type: ignore + mocap_pos=self.task.data.mocap_pos, # type: ignore + mocap_quat=self.task.data.mocap_quat, # type: ignore + sim_metadata=self.task.get_sim_metadata(), + ) + + @property + def render_pose(self) -> RenderPose: + """Returns the current pose data used for visualization.""" + return RenderPose( + xpos=self.task.data.xpos, # type: ignore + xquat=self.task.data.xquat, # type: ignore + ) + + @property def timestep(self) -> float: """Timestep the simulation expects to run at.""" + return self.task.dt diff --git a/judo/simulation/policy_mj_simulation.py b/judo/simulation/hierarchical_mj_simulation.py similarity index 63% rename from judo/simulation/policy_mj_simulation.py rename to judo/simulation/hierarchical_mj_simulation.py index ad60b39f..483e8f1f 100644 --- a/judo/simulation/policy_mj_simulation.py +++ b/judo/simulation/hierarchical_mj_simulation.py @@ -1,6 +1,6 @@ # Copyright (c) 2025 Robotics and AI Institute LLC. All rights reserved. -"""MuJoCo Simulation with locomotion policy support.""" +"""MuJoCo simulation with hierarchical low-level policy support.""" from pathlib import Path @@ -9,6 +9,7 @@ from omegaconf import DictConfig from judo.simulation.mj_simulation import MJSimulation +from judo.tasks import get_task_registration from judo.tasks.spot.spot_constants import DEFAULT_SPOT_ROLLOUT_CUTOFF_TIME, POLICY_OUTPUT_DIM try: @@ -21,14 +22,15 @@ ) from e -class PolicyMJSimulation(MJSimulation): - """MuJoCo simulation with locomotion policy support. +class HierarchicalMJSimulation(MJSimulation): + """MuJoCo simulation with a hierarchical low-level policy layer. - For tasks with locomotion_policy_path set, uses C++ mujoco_extensions - threaded_rollout to run the neural network policy at 50Hz. + For tasks with uses_locomotion_policy=True, this routes control through + mujoco_extensions threaded_rollout so a lower-level policy can refine the + high-level command before physics integration. - The simulation maintains internal state for the locomotion policy - (last_policy_output) to ensure smooth transitions between timesteps. + The simulation maintains internal policy state (last_policy_output) to + ensure smooth transitions between timesteps. """ def __init__( @@ -36,7 +38,7 @@ def __init__( init_task: str = "spot_base", task_registration_cfg: DictConfig | None = None, ) -> None: - """Initialize the policy simulation. + """Initialize the hierarchical simulation. Args: init_task: Name of the task to initialize. @@ -47,15 +49,20 @@ def __init__( self._systems = None self._last_policy_output = np.zeros(POLICY_OUTPUT_DIM) - # Initialize C++ systems if task uses locomotion policy - if self.task.locomotion_policy_path is not None: - self._init_cpp_systems(self.task.locomotion_policy_path) + # Initialize C++ systems if the task uses a hierarchical policy layer. + if self.task.uses_locomotion_policy: + policy_path = get_task_registration(self.task.name).locomotion_policy_path + if policy_path is None: + raise ValueError( + f"Task '{self.task.name}' uses locomotion policy but no locomotion_policy_path is registered." + ) + self._init_cpp_systems(policy_path) def _init_cpp_systems(self, policy_path: str | Path) -> None: """Initialize the C++ systems vector for threaded rollout. Args: - policy_path: Path to the ONNX locomotion policy file. + policy_path: Path to the ONNX low-level policy file. """ self._systems = create_systems_vector( self.task.model, # Pass the MjModel directly @@ -66,26 +73,27 @@ def _init_cpp_systems(self, policy_path: str | Path) -> None: def step(self, command: np.ndarray) -> None: """Step the simulation forward. - Routes to the C++ policy rollout if systems are initialized, + Routes to the C++ hierarchical rollout if systems are initialized, otherwise falls back to direct actuator control. Args: command: Control array in task format (task.nu dimensions). - For locomotion tasks, will be converted to policy command internally. + For hierarchical tasks, this is converted to the low-level + policy command internally. """ if self._systems is not None: if self.paused: return command = self.task.task_to_sim_ctrl(command) - self._step_with_locomotion_policy(command) + self._step_with_hierarchical_policy(command) else: super().step(command) - def _step_with_locomotion_policy(self, command: np.ndarray) -> None: - """Execute a single step using the C++ rollout backend. + def _step_with_hierarchical_policy(self, command: np.ndarray) -> None: + """Execute a single step using the hierarchical rollout backend. Args: - command: Command array for the locomotion policy. + command: Command array for the low-level policy. """ # Get current state state = np.concatenate([self.task.data.qpos, self.task.data.qvel]) @@ -124,7 +132,7 @@ def _step_with_locomotion_policy(self, command: np.ndarray) -> None: # Compute derived quantities (xpos, xquat, etc.) for visualization mj_forward(self.task.model, self.task.data) - # Update last policy output for continuity + # Update last policy output for continuity. self._last_policy_output = np.array(policy_outputs[0]) def reset_policy_state(self) -> None: @@ -139,14 +147,22 @@ def set_task(self, task_name: str) -> None: """ super().set_task(task_name) - # Reinitialize systems based on new task's policy - if self.task.locomotion_policy_path is not None: - self._init_cpp_systems(self.task.locomotion_policy_path) + # Reinitialize systems based on the new task's policy layer. + if self.task.uses_locomotion_policy: + policy_path = get_task_registration(self.task.name).locomotion_policy_path + if policy_path is None: + raise ValueError( + f"Task '{self.task.name}' uses locomotion policy but no locomotion_policy_path is registered." + ) + self._init_cpp_systems(policy_path) self._last_policy_output = np.zeros(POLICY_OUTPUT_DIM) else: - self._systems = None + raise ValueError( + f"Task '{self.task.name}' does not use a locomotion policy. " + "Use MJSimulation instead of HierarchicalMJSimulation for this task." + ) @property def last_policy_output(self) -> np.ndarray: - """Returns the last policy output (12-dim leg actions).""" + """Return the last low-level policy output.""" return self._last_policy_output.copy() diff --git a/judo/simulation/mj_simulation.py b/judo/simulation/mj_simulation.py index 6abeb46d..4d7164e0 100644 --- a/judo/simulation/mj_simulation.py +++ b/judo/simulation/mj_simulation.py @@ -6,7 +6,6 @@ from mujoco import mj_step from omegaconf import DictConfig -from judo.app.structs import MujocoState from judo.simulation.base import Simulation @@ -34,7 +33,7 @@ def step(self, command: np.ndarray) -> None: """Step the simulation forward. Args: - command: Control array in task format (task.nu dimensions). + command: Control command for this timestep. """ if self.paused: return @@ -44,30 +43,3 @@ def step(self, command: np.ndarray) -> None: self.task.pre_sim_step() mj_step(self.task.sim_model, self.task.data) self.task.post_sim_step() - - def set_task(self, task_name: str) -> None: - """Set the current task. - - Args: - task_name: Name of the task to set. - """ - super().set_task(task_name) - - @property - def sim_state(self) -> MujocoState: - """Returns the current simulation state.""" - return MujocoState( - time=self.task.data.time, - qpos=self.task.data.qpos, - qvel=self.task.data.qvel, - xpos=self.task.data.xpos, - xquat=self.task.data.xquat, - mocap_pos=self.task.data.mocap_pos, - mocap_quat=self.task.data.mocap_quat, - sim_metadata=self.task.get_sim_metadata(), - ) - - @property - def timestep(self) -> float: - """Returns the effective simulation timestep (accounting for substeps).""" - return self.task.dt diff --git a/judo/tasks/__init__.py b/judo/tasks/__init__.py index 38a077e2..30be02ad 100644 --- a/judo/tasks/__init__.py +++ b/judo/tasks/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) 2025 Robotics and AI Institute LLC. All rights reserved. -from typing import Dict, Tuple, Type +from dataclasses import dataclass +from typing import Dict, Type from judo.tasks.base import Task, TaskConfig from judo.tasks.caltech_leap_cube import CaltechLeapCube, CaltechLeapCubeConfig @@ -21,35 +22,101 @@ SpotTireUpright, SpotTireUprightConfig, ) +from judo.tasks.spot.spot_constants import SPOT_LOCOMOTION_POLICY_PATH -_registered_tasks: Dict[str, Tuple[Type[Task], Type[TaskConfig]]] = { - CylinderPush.name: (CylinderPush, CylinderPushConfig), - Cartpole.name: (Cartpole, CartpoleConfig), - FR3Pick.name: (FR3Pick, FR3PickConfig), - LeapCube.name: (LeapCube, LeapCubeConfig), - LeapCubeDown.name: (LeapCubeDown, LeapCubeDownConfig), - CaltechLeapCube.name: (CaltechLeapCube, CaltechLeapCubeConfig), - SpotBase.name: (SpotBase, SpotBaseConfig), - SpotBoxPush.name: (SpotBoxPush, SpotBoxPushConfig), - SpotNavigate.name: (SpotNavigate, SpotNavigateConfig), - SpotTireRoll.name: (SpotTireRoll, SpotTireRollConfig), - SpotTireUpright.name: (SpotTireUpright, SpotTireUprightConfig), + +@dataclass(frozen=True) +class TaskRegistration: + """Complete registration metadata for a task.""" + + task_type: Type[Task] + task_config_type: Type[TaskConfig] + rollout_backend: str = "mujoco" + simulation_backend: str = "mujoco" + locomotion_policy_path: str | None = None + + +_registered_tasks: Dict[str, TaskRegistration] = { + CylinderPush.name: TaskRegistration(CylinderPush, CylinderPushConfig), + Cartpole.name: TaskRegistration(Cartpole, CartpoleConfig), + FR3Pick.name: TaskRegistration(FR3Pick, FR3PickConfig), + LeapCube.name: TaskRegistration(LeapCube, LeapCubeConfig), + LeapCubeDown.name: TaskRegistration(LeapCubeDown, LeapCubeDownConfig), + CaltechLeapCube.name: TaskRegistration(CaltechLeapCube, CaltechLeapCubeConfig), + SpotBase.name: TaskRegistration( + SpotBase, + SpotBaseConfig, + rollout_backend="mujoco_hierarchical", + simulation_backend="mujoco_hierarchical", + locomotion_policy_path=str(SPOT_LOCOMOTION_POLICY_PATH), + ), + SpotBoxPush.name: TaskRegistration( + SpotBoxPush, + SpotBoxPushConfig, + rollout_backend="mujoco_hierarchical", + simulation_backend="mujoco_hierarchical", + locomotion_policy_path=str(SPOT_LOCOMOTION_POLICY_PATH), + ), + SpotNavigate.name: TaskRegistration( + SpotNavigate, + SpotNavigateConfig, + rollout_backend="mujoco_hierarchical", + simulation_backend="mujoco_hierarchical", + locomotion_policy_path=str(SPOT_LOCOMOTION_POLICY_PATH), + ), + SpotTireRoll.name: TaskRegistration( + SpotTireRoll, + SpotTireRollConfig, + rollout_backend="mujoco_hierarchical", + simulation_backend="mujoco_hierarchical", + locomotion_policy_path=str(SPOT_LOCOMOTION_POLICY_PATH), + ), + SpotTireUpright.name: TaskRegistration( + SpotTireUpright, + SpotTireUprightConfig, + rollout_backend="mujoco_hierarchical", + simulation_backend="mujoco_hierarchical", + locomotion_policy_path=str(SPOT_LOCOMOTION_POLICY_PATH), + ), } -def get_registered_tasks() -> Dict[str, Tuple[Type[Task], Type[TaskConfig]]]: +def get_registered_tasks() -> Dict[str, TaskRegistration]: """Returns a dictionary of registered tasks.""" return _registered_tasks -def register_task(name: str, task_type: Type[Task], task_config_type: Type[TaskConfig]) -> None: - """Registers a new task.""" - _registered_tasks[name] = (task_type, task_config_type) +def get_task_registration(task_name: str) -> TaskRegistration: + """Return full registration metadata for a task.""" + task_entry = _registered_tasks.get(task_name) + if task_entry is None: + raise ValueError(f"Task {task_name} not found in task registry.") + return task_entry + + +def register_task( + name: str, + task_type: Type[Task], + task_config_type: Type[TaskConfig], + rollout_backend: str = "mujoco", + simulation_backend: str = "mujoco", + locomotion_policy_path: str | None = None, +) -> None: + """Registers a new task and its default controller/simulation backends.""" + _registered_tasks[name] = TaskRegistration( + task_type=task_type, + task_config_type=task_config_type, + rollout_backend=rollout_backend, + simulation_backend=simulation_backend, + locomotion_policy_path=locomotion_policy_path, + ) __all__ = [ "get_registered_tasks", + "get_task_registration", "register_task", + "TaskRegistration", "Task", "TaskConfig", "CaltechLeapCube", diff --git a/judo/tasks/base.py b/judo/tasks/base.py index 016b5ebb..3ab5cf31 100644 --- a/judo/tasks/base.py +++ b/judo/tasks/base.py @@ -25,6 +25,7 @@ class Task(ABC, Generic[ConfigT]): """Task definition.""" config_t: type[ConfigT] + name: str def __init__(self, model_path: Path | str = "", sim_model_path: Path | str | None = None) -> None: """Initialize the Mujoco task.""" @@ -80,19 +81,10 @@ def nu(self) -> int: """Number of control inputs. The same as the MjModel for this task.""" return self.model.nu - @property - def locomotion_policy_path(self) -> str | None: - """Path to locomotion policy for this task, or None if not used. - - Override in tasks that use a learned locomotion policy - (e.g., Spot tasks that run an ONNX policy at 50Hz). - """ - return None - @property def uses_locomotion_policy(self) -> bool: """Whether this task uses a locomotion policy for simulation.""" - return self.locomotion_policy_path is not None + return False @property def actuator_ctrlrange(self) -> np.ndarray: diff --git a/judo/tasks/spot/spot_base.py b/judo/tasks/spot/spot_base.py index 1a4e83c3..c92fe5d8 100644 --- a/judo/tasks/spot/spot_base.py +++ b/judo/tasks/spot/spot_base.py @@ -6,7 +6,7 @@ but adapted for judo's standalone simulation framework. """ -from dataclasses import dataclass, field +from dataclasses import dataclass from pathlib import Path from typing import Any, Generic, TypeVar @@ -33,7 +33,6 @@ LEG_SOFT_UPPER_JOINT_LIMITS, LEGS_STANDING_POS, LEGS_STANDING_POS_RL, - SPOT_LOCOMOTION_POLICY_PATH, STANDING_HEIGHT, STANDING_HEIGHT_CMD, TORSO_CMD_INDS, @@ -44,15 +43,6 @@ XML_PATH = str(MODEL_PATH / "xml" / "spot_primitive" / "robot.xml") -@dataclass -class GoalPositions: - """Goal positions for Spot tasks.""" - - origin: np.ndarray = field(default_factory=lambda: np.array([0, 0, 0.0])) - blue_cross: np.ndarray = field(default_factory=lambda: np.array([2.77, 0.71, 0.3])) - black_cross: np.ndarray = field(default_factory=lambda: np.array([1.5, -1.5, 0.275])) - - @dataclass class SpotBaseConfig(TaskConfig): """Base configuration for Spot tasks. @@ -118,9 +108,9 @@ def physics_substeps(self) -> int: # type: ignore[override] return 2 @property - def locomotion_policy_path(self) -> str: - """Path to Spot locomotion policy.""" - return str(SPOT_LOCOMOTION_POLICY_PATH) + def uses_locomotion_policy(self) -> bool: # type: ignore[override] + """Spot tasks always use a locomotion policy backend.""" + return True def __init__( self, @@ -454,7 +444,7 @@ def reset(self) -> None: def get_action_components(self) -> list[str]: """Get names of each component in the action command vector. - Matches starfish/dexterity/tasks/spot_base.py. + Matches judo/tasks/spot/spot_base.py. """ action_components = ["spot/base.vx", "spot/base.vy", "spot/base.vtheta"] if self.use_arm: diff --git a/judo/utils/policy_mj_rollout_backend.py b/judo/utils/hierarchical_mj_rollout_backend.py similarity index 86% rename from judo/utils/policy_mj_rollout_backend.py rename to judo/utils/hierarchical_mj_rollout_backend.py index a0939743..fc216d72 100644 --- a/judo/utils/policy_mj_rollout_backend.py +++ b/judo/utils/hierarchical_mj_rollout_backend.py @@ -1,6 +1,6 @@ # Copyright (c) 2025 Robotics and AI Institute LLC. All rights reserved. -"""MuJoCo rollout backend with locomotion policy support.""" +"""MuJoCo rollout backend with hierarchical low-level policy support.""" from pathlib import Path @@ -11,8 +11,8 @@ from judo.utils.rollout_backend import RolloutBackend -class PolicyMJRolloutBackend(RolloutBackend): - """Rollout backend with C++ mujoco_extensions and ONNX locomotion policy inference. +class HierarchicalMJRolloutBackend(RolloutBackend): + """Rollout backend with C++ mujoco_extensions and ONNX low-level policy inference. For Spot tasks, the command format is a 25-dim vector: [base_vel(3), arm(7), legs(12), torso(3)] @@ -25,12 +25,12 @@ def __init__( policy_path: str | Path, physics_substeps: int = 2, ) -> None: - """Initialize the policy rollout backend. + """Initialize the hierarchical rollout backend. Args: model: MuJoCo model for the scene. num_threads: Number of parallel rollout threads. - policy_path: Path to ONNX locomotion policy. + policy_path: Path to the ONNX low-level policy. physics_substeps: Physics steps per control step. """ self.num_threads = num_threads @@ -41,7 +41,7 @@ def __init__( self._setup_mujoco_extensions(model, policy_path, num_threads) def _setup_mujoco_extensions(self, model: MjModel, policy_path: str | Path, num_threads: int) -> None: - """Setup the mujoco_extensions C++ rollout backend with ONNX policy.""" + """Setup the mujoco_extensions C++ rollout backend with ONNX low-level policy.""" try: from mujoco_extensions.policy_rollout import create_systems_vector, threaded_rollout # type: ignore # noqa: PLC0415, I001 except ImportError as e: @@ -60,7 +60,7 @@ def rollout( controls: np.ndarray, last_policy_output: np.ndarray | None = None, ) -> tuple[np.ndarray, np.ndarray, np.ndarray | None]: - """Conduct parallel rollouts with policy inference. + """Conduct parallel rollouts with hierarchical policy inference. Args: x0: Initial state, shape (nq+nv,). Will be tiled to num_threads internally. @@ -78,7 +78,7 @@ def rollout( x0 = np.tile(x0, (self.num_threads, 1)) if last_policy_output is None: - last_policy_output = np.zeros((x0.shape[0], 12)) + raise ValueError("last_policy_output is required for HierarchicalMJRolloutBackend") x0 = np.asarray(x0, dtype=np.float64) controls = np.asarray(controls, dtype=np.float64) diff --git a/judo/utils/mj_rollout_backend.py b/judo/utils/mj_rollout_backend.py index a4f42c49..e7b1ea8b 100644 --- a/judo/utils/mj_rollout_backend.py +++ b/judo/utils/mj_rollout_backend.py @@ -12,6 +12,13 @@ from judo.utils.rollout_backend import RolloutBackend +def make_model_data_pairs(model: MjModel, num_pairs: int) -> tuple[list[MjModel], list[MjData]]: + """Create model/data pairs for mujoco threaded rollout.""" + models = [deepcopy(model) for _ in range(num_pairs)] + datas = [MjData(m) for m in models] + return models, datas + + class MJRolloutBackend(RolloutBackend): """Backend for conducting multithreaded rollouts using standard MuJoCo. @@ -32,16 +39,9 @@ def __init__( self.num_threads = num_threads self.model = model - self._model_data_pairs = self._make_model_data_pairs(model, num_threads) + self._models, self._datas = make_model_data_pairs(model, num_threads) self._rollout_obj = Rollout(nthread=num_threads) - @staticmethod - def _make_model_data_pairs(model: MjModel, num_pairs: int) -> list[tuple[MjModel, MjData]]: - """Create model/data pairs for mujoco threaded rollout.""" - models = [deepcopy(model) for _ in range(num_pairs)] - datas = [MjData(m) for m in models] - return list(zip(models, datas, strict=True)) - def rollout( self, x0: np.ndarray, @@ -64,16 +64,12 @@ def rollout( if x0.ndim == 1: x0 = np.tile(x0, (self.num_threads, 1)) - ms, ds = zip(*self._model_data_pairs, strict=True) - ms = list(ms) - ds = list(ds) - - nq = ms[0].nq - nv = ms[0].nv - nu = ms[0].nu + nq = self._models[0].nq + nv = self._models[0].nv + nu = self._models[0].nu # Prepend time to batched x0 - full_states = np.concatenate([time.time() * np.ones((len(ms), 1)), x0], axis=-1) + full_states = np.concatenate([time.time() * np.ones((len(self._models), 1)), x0], axis=-1) assert full_states.shape[-1] == nq + nv + 1 assert full_states.ndim == 2 @@ -81,7 +77,7 @@ def rollout( assert controls.shape[-1] == nu assert controls.shape[0] == full_states.shape[0] - _states, _sensors = self._rollout_obj.rollout(ms, ds, full_states, controls) + _states, _sensors = self._rollout_obj.rollout(self._models, self._datas, full_states, controls) out_states = np.array(_states)[..., 1:] # Remove time from state out_sensors = np.array(_sensors) @@ -97,5 +93,5 @@ def update(self, num_threads: int) -> None: """ self.num_threads = num_threads self._rollout_obj.close() - self._model_data_pairs = self._make_model_data_pairs(self.model, num_threads) + self._models, self._datas = make_model_data_pairs(self.model, num_threads) self._rollout_obj = Rollout(nthread=num_threads) diff --git a/judo/utils/mjwarp_rollout_backend.py b/judo/utils/mjwarp_rollout_backend.py index 9642ff26..085f4bac 100644 --- a/judo/utils/mjwarp_rollout_backend.py +++ b/judo/utils/mjwarp_rollout_backend.py @@ -10,7 +10,7 @@ from mujoco import MjData, MjModel from judo.controller.batched_spot_locomotion import BatchedSpotLocomotion -from judo.utils.rollout_backend import RolloutBackend +from judo.utils.rollout_backend import BatchedRolloutBackend from judo.utils.timer import Timer logger = logging.getLogger(__name__) @@ -56,7 +56,7 @@ def target_frequency(self) -> float: return float("inf") -class MJWarpRolloutBackend(RolloutBackend): +class MJWarpRolloutBackend(BatchedRolloutBackend): """GPU-accelerated rollout backend using mujoco_warp. Supports two modes: diff --git a/judo/utils/rollout_backend.py b/judo/utils/rollout_backend.py index a3a1f144..4a59df2d 100644 --- a/judo/utils/rollout_backend.py +++ b/judo/utils/rollout_backend.py @@ -45,3 +45,10 @@ def update(self, num_threads: int) -> None: Args: num_threads: New number of parallel threads. """ + + +class BatchedRolloutBackend(RolloutBackend, ABC): + """Rollout backend base class for multi-problem batched execution.""" + + num_problems: int + device: str diff --git a/judo/visualizers/visualizer.py b/judo/visualizers/visualizer.py index 4562a9bf..d69e1c3a 100644 --- a/judo/visualizers/visualizer.py +++ b/judo/visualizers/visualizer.py @@ -15,7 +15,7 @@ from judo.controller import ControllerConfig from judo.gui import create_gui_elements from judo.optimizers import get_registered_optimizers -from judo.tasks import get_registered_tasks +from judo.tasks import TaskRegistration, get_registered_tasks from judo.visualizers.model import ViserMjModel ElementType = GuiImageHandle | GuiInputHandle | GuiFolderHandle | MeshHandle | IcosphereHandle @@ -39,6 +39,7 @@ def __init__( optimizer_override_cfg: DictConfig | None = None, sim_pause_button: bool = True, geom_exclude_substring: str = "collision", + available_tasks: dict[str, TaskRegistration] | None = None, ) -> None: """Initialize the visualization node.""" # handling custom task and optimizer registration @@ -49,7 +50,7 @@ def __init__( # starting the server self.server = viser.ViserServer() - self.available_tasks = get_registered_tasks() + self.available_tasks = available_tasks or get_registered_tasks() self.available_optimizers = get_registered_optimizers() self.geom_exclude_substring = geom_exclude_substring @@ -105,8 +106,7 @@ def set_task(self, task_name: str, optimizer_name: str) -> None: if task_entry is None: raise ValueError(f"Task {task_name} not found in the task registry.") - task_cls, _ = task_entry - self.task = task_cls() + self.task = task_entry.task_type() self.task_config = self.task.config self.data = mujoco.MjData(self.task.model) self.viser_model = ViserMjModel( diff --git a/run_mpc/mpc_batch.py b/run_mpc/mpc_batch.py index 1c9f367e..5cca2856 100644 --- a/run_mpc/mpc_batch.py +++ b/run_mpc/mpc_batch.py @@ -9,24 +9,24 @@ import numpy as np from tqdm import tqdm -from judo.app.structs import MujocoState +from judo.app.structs import RenderPose from judo.controller import BatchedControllers as JudoBatchedController from judo.controller import Controller as JudoController +from judo.simulation.hierarchical_mj_simulation import HierarchicalMJSimulation from judo.simulation.mj_simulation import MJSimulation -from judo.simulation.policy_mj_simulation import PolicyMJSimulation from judo.visualizers.visualizer import Visualizer from run_mpc.mpc_config import MPCTimers, PublicMPCConfig, SizeData def _get_previous_actions(sims: list[MJSimulation]) -> list[np.ndarray | None]: """Get previous actions from sims for hierarchical control sync.""" - return [sim.last_policy_output if isinstance(sim, PolicyMJSimulation) else None for sim in sims] + return [sim.last_policy_output if isinstance(sim, HierarchicalMJSimulation) else None for sim in sims] -def update_visualization(visualizer: Visualizer, sim_state: MujocoState, traces: np.ndarray) -> None: - """Update the viser visualization with current sim state and traces.""" - visualizer.data.xpos[:] = sim_state.xpos - visualizer.data.xquat[:] = sim_state.xquat +def update_visualization(visualizer: Visualizer, render_pose: RenderPose, traces: np.ndarray) -> None: + """Update the viser visualization with current render pose and traces.""" + visualizer.data.xpos[:] = render_pose.xpos + visualizer.data.xquat[:] = render_pose.xquat visualizer.viser_model.set_data(visualizer.data) sensor_rollout_size = traces.shape[1] num_trace_sensors = traces.shape[2] @@ -154,7 +154,7 @@ def run_mpc_batch( # Reset all simulations and controllers, sample new initial conditions + goals for sim, ctrl in zip(sims, controllers, strict=True): sim.task.reset() - if isinstance(sim, PolicyMJSimulation): + if isinstance(sim, HierarchicalMJSimulation): sim.reset_policy_state() ctrl.reset() ctrl.update_states(sim.sim_state) @@ -202,7 +202,7 @@ def run_mpc_batch( timers.sim_step.toc() if vis is not None and num_parallel == 1: - update_visualization(vis, sims[0].sim_state, controllers[0].traces) + update_visualization(vis, sims[0].render_pose, controllers[0].traces) time.sleep(sims[0].timestep) results = storage.package_results(config.max_num_task_steps) diff --git a/run_mpc/mpc_config.py b/run_mpc/mpc_config.py index e0f1649f..e06b57a1 100644 --- a/run_mpc/mpc_config.py +++ b/run_mpc/mpc_config.py @@ -119,7 +119,8 @@ def load_configs_from_json_data(json_data: Any) -> tuple[JudoTask, Optimizer, Co optimizer_entry = available_optimizers.get(json_data["optimizer"]) assert optimizer_entry is not None, f"Optimizer {json_data['optimizer']} is not registered!" - task_cls, task_config_cls = task_entry + task_cls = task_entry.task_type + task_config_cls = task_entry.task_config_type task_config: JudoTaskConfig = dacite.from_dict(task_config_cls, json_data["task_config"]) task: JudoTask = task_cls() task.config = task_config diff --git a/run_mpc/mpc_setup.py b/run_mpc/mpc_setup.py index b161ad8c..f3b5b856 100644 --- a/run_mpc/mpc_setup.py +++ b/run_mpc/mpc_setup.py @@ -14,8 +14,8 @@ from judo.controller import ControllerConfig from judo.controller.batched_spot_locomotion import BatchedSpotLocomotion from judo.optimizers import Optimizer +from judo.simulation.hierarchical_mj_simulation import HierarchicalMJSimulation from judo.simulation.mj_simulation import MJSimulation -from judo.simulation.policy_mj_simulation import PolicyMJSimulation from judo.tasks import Task as JudoTask from judo.utils.mjwarp_rollout_backend import MJWarpRolloutBackend from run_mpc.mpc_config import PublicMPCConfig, SizeData, make_size_data @@ -58,7 +58,7 @@ def setup_mpc( for _ in range(num_parallel): if use_spot: - sim = PolicyMJSimulation(init_task=json_configs["task"]) + sim = HierarchicalMJSimulation(init_task=json_configs["task"]) else: sim = MJSimulation(init_task=json_configs["task"]) sim.task.config = copy.deepcopy(task.config) diff --git a/run_mpc/visualize_trajectories.py b/run_mpc/visualize_trajectories.py index ac29971e..e1c933a0 100644 --- a/run_mpc/visualize_trajectories.py +++ b/run_mpc/visualize_trajectories.py @@ -56,8 +56,7 @@ def visualize_trajectory_batch( registered_tasks = get_registered_tasks() task_entry = registered_tasks.get(task) assert task_entry is not None, f"Task {task} is not registered!" - task_cls, _ = task_entry - task_instance = task_cls() + task_instance = task_entry.task_type() # Increase constraint buffers for contact-heavy scenes (e.g. Spot + tire). task_instance.spec.nconmax = 512 task_instance.spec.njmax = 2048 diff --git a/tests/test_controller/test_action_normalization.py b/tests/test_controller/test_action_normalization.py index e9ba2d03..4c662d7f 100644 --- a/tests/test_controller/test_action_normalization.py +++ b/tests/test_controller/test_action_normalization.py @@ -122,7 +122,6 @@ def test_normalizer_type_change() -> None: controller = make_controller( init_task="cylinder_push", init_optimizer="cem", - rollout_backend="mujoco", ) # Initially should be IdentityNormalizer @@ -147,7 +146,6 @@ def test_normalizer_in_update_action_loop() -> None: controller = make_controller( init_task="cylinder_push", init_optimizer="cem", - rollout_backend="mujoco", ) controller.controller_cfg = ControllerConfig(action_normalizer=normalizer_type) @@ -164,7 +162,6 @@ def test_min_max_normalizer_with_task_control_ranges() -> None: controller = make_controller( init_task="cylinder_push", init_optimizer="cem", - rollout_backend="mujoco", ) controller.controller_cfg = ControllerConfig(action_normalizer="min_max", max_opt_iters=1) @@ -196,7 +193,6 @@ def test_running_normalizer_updates_with_optimizer_data() -> None: controller = make_controller( init_task="cylinder_push", init_optimizer="cem", - rollout_backend="mujoco", ) controller.controller_cfg = ControllerConfig(action_normalizer="running", max_opt_iters=1) diff --git a/tests/test_controller/test_controller.py b/tests/test_controller/test_controller.py index 5bf0d7c6..b43765e6 100644 --- a/tests/test_controller/test_controller.py +++ b/tests/test_controller/test_controller.py @@ -49,7 +49,6 @@ def _setup_controller(max_opt_iters: int) -> tuple[MockOptimizerTrackNominalKnot controller = make_controller( init_task="cylinder_push", init_optimizer="cem", - rollout_backend="mujoco", ) controller.controller_cfg = ControllerConfig(max_opt_iters=max_opt_iters) controller.optimizer = opt @@ -87,7 +86,6 @@ def _setup_controller(opt_cls: type[Optimizer], opt_cfg: OptimizerConfig) -> Con controller = make_controller( init_task="cylinder_push", init_optimizer="cem", - rollout_backend="mujoco", ) controller.optimizer = opt return controller diff --git a/tests/test_simulation/test_simulation.py b/tests/test_simulation/test_simulation.py index eb9d4a25..5254de27 100644 --- a/tests/test_simulation/test_simulation.py +++ b/tests/test_simulation/test_simulation.py @@ -4,7 +4,8 @@ import numpy as np -from judo.simulation import MJSimulation, PolicyMJSimulation +from judo.simulation import HierarchicalMJSimulation, MJSimulation +from judo.tasks import get_task_registration from judo.tasks.cartpole import Cartpole, CartpoleConfig from judo.tasks.cylinder_push import CylinderPush, CylinderPushConfig @@ -41,15 +42,16 @@ def test_simulation_data_step(temp_np_seed: Callable) -> None: def test_spot_simulation_init() -> None: - """Test PolicyMJSimulation initializes with a Spot task and C++ systems.""" - sim = PolicyMJSimulation(init_task="spot_base") + """Test HierarchicalMJSimulation initializes with a Spot task and C++ systems.""" + sim = HierarchicalMJSimulation(init_task="spot_base") assert sim._systems is not None - assert sim.task.locomotion_policy_path is not None + assert sim.task.uses_locomotion_policy + assert get_task_registration(sim.task.name).locomotion_policy_path is not None def test_spot_simulation_step() -> None: - """Test PolicyMJSimulation steps correctly with Spot locomotion policy.""" - sim = PolicyMJSimulation(init_task="spot_base") + """Test HierarchicalMJSimulation steps correctly with Spot locomotion policy.""" + sim = HierarchicalMJSimulation(init_task="spot_base") qpos_before = sim.task.data.qpos.copy() command = np.zeros(sim.task.nu) sim.step(command) diff --git a/tests/test_spot_tasks.py b/tests/test_spot_tasks.py index 5ee3784a..90744b90 100644 --- a/tests/test_spot_tasks.py +++ b/tests/test_spot_tasks.py @@ -8,12 +8,12 @@ import pytest from judo.tasks.spot import ( - SpotBase, SpotBoxPush, SpotNavigate, SpotTireRoll, SpotTireUpright, ) +from judo.tasks.spot.spot_base import SpotBase # (TaskClass, expected_nu) # nu depends on use_arm, use_gripper, use_legs, use_torso: diff --git a/tests/test_tasks/test_spot.py b/tests/test_tasks/test_spot.py index daccac71..530357b9 100644 --- a/tests/test_tasks/test_spot.py +++ b/tests/test_tasks/test_spot.py @@ -4,7 +4,9 @@ import numpy as np -from judo.tasks.spot import SpotBase, SpotTireUpright +from judo.tasks import get_task_registration +from judo.tasks.spot import SpotTireUpright +from judo.tasks.spot.spot_base import SpotBase def test_spot_base_init() -> None: @@ -12,7 +14,8 @@ def test_spot_base_init() -> None: task = SpotBase() assert task.name == "spot_base" assert task.physics_substeps == 2 - assert task.locomotion_policy_path is not None + assert task.uses_locomotion_policy + assert get_task_registration(task.name).locomotion_policy_path is not None # Base + arm = 3 + 7 = 10 assert task.nu == 10 @@ -22,7 +25,8 @@ def test_spot_tire_upright_init() -> None: task = SpotTireUpright() assert task.name == "spot_tire_upright" assert task.physics_substeps == 2 - assert task.locomotion_policy_path is not None + assert task.uses_locomotion_policy + assert get_task_registration(task.name).locomotion_policy_path is not None # Base + arm + legs + leg_selection = 3 + 7 + 6 + 1 = 17 assert task.nu == 17