Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 43 additions & 19 deletions judo/app/dora/controller.py → judo/app/dora/controller_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import time
from threading import Lock
from typing import Callable

import pyarrow as pa
from dora_utils.dataclasses import from_event, to_arrow
Expand All @@ -23,38 +24,60 @@ def __init__(
max_workers: int | None = None,
task_registration_cfg: DictConfig | None = None,
optimizer_registration_cfg: DictConfig | None = None,
make_controller_fn: Callable | None = None,
) -> None:
"""Initialize the controller node."""
"""Initialize the controller node.

Args:
init_task: Name of the task to initialize.
init_optimizer: Name of the optimizer to initialize (e.g., "cem", "ps", "mppi").
node_id: Identifier for this dora node.
max_workers: Maximum number of worker threads for dora (None = auto).
task_registration_cfg: Optional config for task registration overrides.
optimizer_registration_cfg: Optional config for optimizer registration overrides.
make_controller_fn: Optional factory function to create Controller instances.
Defaults to judo.controller.make_controller. Allows custom controller creation.
"""
super().__init__(node_id=node_id, max_workers=max_workers)
self.controller = make_controller(
init_task=init_task,
init_optimizer=init_optimizer,
task_registration_cfg=task_registration_cfg,
optimizer_registration_cfg=optimizer_registration_cfg,
)
self._make_controller_fn = make_controller_fn or make_controller
self._task_registration_cfg = task_registration_cfg
self._optimizer_registration_cfg = optimizer_registration_cfg
self.controller = self._build_controller(init_task, init_optimizer)
self._paused = False
self.write_controls()
self.lock = Lock()

def _build_controller(self, task_name: str, optimizer_name: str) -> Controller:
"""Build controller using the task's registered rollout backend."""
return self._make_controller_fn(
init_task=task_name,
init_optimizer=optimizer_name,
task_registration_cfg=self._task_registration_cfg,
optimizer_registration_cfg=self._optimizer_registration_cfg,
)

def _current_optimizer_name(self) -> str:
"""Look up the name of the current optimizer from the registry.

Returns "cem" as a safe default if no registry entry matches the active optimizer instance.
"""
for name, (cls, _) in self.controller.available_optimizers.items():
if isinstance(self.controller.optimizer, cls):
return name
return "cem"
Comment thread
dta-bdai marked this conversation as resolved.

@on_event("INPUT", "task")
def update_task(self, event: dict) -> None:
"""Updates the task type."""
new_task = event["value"].to_numpy(zero_copy_only=False)[0]
task_entry = self.controller.available_tasks.get(new_task)
if task_entry is not None:
task_cls, _ = task_entry
with self.lock:
task = task_cls()
optimizer = self.controller.optimizer_cls(self.controller.optimizer_config_cls(), task.nu)
self.controller = Controller(
controller_config=self.controller.controller_cfg,
task=task,
optimizer=optimizer,
)
self.write_controls()
else:
if task_entry is None:
raise ValueError(f"Task {new_task} not found in task registry.")

with self.lock:
self.controller = self._build_controller(new_task, self._current_optimizer_name())
self.write_controls()

@on_event("INPUT", "task_reset")
def reset_task(self, event: dict) -> None:
"""Resets the task."""
Expand All @@ -75,6 +98,7 @@ def update_optimizer(self, event: dict) -> None:
if optimizer_entry is not None:
optimizer_cls, optimizer_config_cls = optimizer_entry
optimizer_config = optimizer_config_cls()
optimizer_config.set_override(self.controller.task.name)
optimizer = optimizer_cls(optimizer_config, self.controller.task.nu)
with self.lock:
self.controller.optimizer = optimizer
Expand Down
42 changes: 30 additions & 12 deletions judo/app/dora/simulation.py → judo/app/dora/simulation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from omegaconf import DictConfig

from judo.app.structs import SplineData
from judo.simulation import get_simulation_backend
from judo.simulation import DEFAULT_SIMULATION_BACKEND_REGISTRY
from judo.simulation.base import Simulation
from judo.tasks import get_registered_tasks


class SimulationNode(DoraNode):
Expand All @@ -21,26 +23,40 @@ def __init__(
init_task: str = "cylinder_push",
max_workers: int | None = None,
task_registration_cfg: DictConfig | None = None,
simulation_backend: str = "mujoco",
backend_registry: dict[str, type[Simulation]] | None = None,
) -> None:
"""Initialize the simulation node."""
"""Initialize the simulation node.

Args:
node_id: Identifier for this dora node.
init_task: Name of the task to initialize.
max_workers: Maximum number of worker threads for dora (None = auto).
task_registration_cfg: Optional config for task registration overrides.
backend_registry: Optional mapping of backend names → Simulation classes. Checked first before built-in registry.
"""
super().__init__(node_id=node_id, max_workers=max_workers)
self._simulation_backend = simulation_backend
self._task_registration_cfg = task_registration_cfg
self._backend_registry = dict(DEFAULT_SIMULATION_BACKEND_REGISTRY)
self._backend_registry.update(backend_registry or {})
self._init_sim(init_task)
self.control_spline: Callable | None = None
self.write_states()

def _resolve_backend(self, backend_name: str) -> type[Simulation]:
"""Resolve a simulation backend class by name from merged registry."""
backend_cls = self._backend_registry.get(backend_name)
if backend_cls is None:
raise KeyError(f"Unknown simulation backend: {backend_name!r}")
return backend_cls

def _init_sim(self, task_name: str) -> None:
"""Initialize simulation, auto-upgrading to policy backend if needed."""
backend = self._simulation_backend
_sim_backend = get_simulation_backend(backend)
self.sim = _sim_backend(init_task=task_name, task_registration_cfg=self._task_registration_cfg)
"""Initialize simulation using the task's registered simulation backend."""
task_entry = get_registered_tasks().get(task_name)
if task_entry is None:
raise ValueError(f"Task {task_name} not found in task registry.")

# Auto-upgrade to policy backend if task requires locomotion policy
if backend == "mujoco" and self.sim.task.uses_locomotion_policy:
_sim_backend = get_simulation_backend("mujoco_policy")
self.sim = _sim_backend(init_task=task_name, task_registration_cfg=self._task_registration_cfg)
sim_backend_cls = self._resolve_backend(task_entry.simulation_backend)
self.sim = sim_backend_cls(init_task=task_name, task_registration_cfg=self._task_registration_cfg)

@on_event("INPUT", "task")
def update_task(self, event: dict) -> None:
Expand Down Expand Up @@ -85,6 +101,8 @@ def write_states(self) -> None:
"""Reads data from simulation and writes to output topic."""
arr, metadata = to_arrow(self.sim.sim_state)
self.node.send_output("states", arr, metadata)
arr, metadata = to_arrow(self.sim.render_pose)
self.node.send_output("render_pose", arr, metadata)

@on_event("INPUT", "sim_pause")
def set_paused_status(self, event: dict) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from omegaconf import DictConfig
from viser import GuiFolderHandle, GuiImageHandle, GuiInputHandle, IcosphereHandle, MeshHandle

from judo.app.structs import MujocoState
from judo.app.structs import RenderPose
from judo.tasks import TaskRegistration
from judo.visualizers.visualizer import Visualizer

ElementType = GuiImageHandle | GuiInputHandle | GuiFolderHandle | MeshHandle | IcosphereHandle
Expand All @@ -29,8 +30,24 @@ def __init__(
optimizer_override_cfg: DictConfig | None = None,
sim_pause_button: bool = True,
geom_exclude_substring: str = "collision",
available_tasks: dict[str, TaskRegistration] | None = None,
) -> None:
"""Initialize the visualization node."""
"""Initialize the visualization node (Viser web GUI for task/optimizer control).

Args:
node_id: Identifier for this dora node.
max_workers: Maximum number of worker threads for dora (None = auto).
init_task: Name of the task to initialize.
init_optimizer: Name of the optimizer to initialize (e.g., "cem", "ps").
task_registration_cfg: Optional config for task registration overrides.
optimizer_registration_cfg: Optional config for optimizer registration overrides.
controller_override_cfg: Optional config overrides for the controller.
optimizer_override_cfg: Optional config overrides for the optimizer.
sim_pause_button: Whether to display a simulation pause button in the GUI.
geom_exclude_substring: Geometry name substring to exclude from visualization (default "collision" hides collision shapes).
available_tasks: Optional pre-computed mapping of task names to TaskRegistration entries
for the task selector. If None, tasks are inferred from the task registry.
"""
super().__init__(node_id=node_id, max_workers=max_workers)
self.visualizer = Visualizer(
init_task=init_task,
Expand All @@ -41,6 +58,7 @@ def __init__(
optimizer_override_cfg=optimizer_override_cfg,
sim_pause_button=sim_pause_button,
geom_exclude_substring=geom_exclude_substring,
available_tasks=available_tasks,
)

def write_sim_pause(self) -> None:
Expand Down Expand Up @@ -85,7 +103,7 @@ def write_task_config(self) -> None:
self.node.send_output("task_config", *to_arrow(self.visualizer.task_config))
self.visualizer.task_config_updated.clear()

@on_event("INPUT", "states")
@on_event("INPUT", "render_pose")
def update_states(self, event: dict) -> None:
"""Callback to update states on receiving a new state measurement."""
if self.visualizer.controller_config.spline_order == "cubic" and self.visualizer.optimizer_config.num_nodes < 4:
Expand All @@ -96,11 +114,11 @@ def update_states(self, event: dict) -> None:
break
self.visualizer.optimizer_config_updated.set()

state_msg = from_arrow(event["value"], event["metadata"], MujocoState)
render_pose_msg = from_arrow(event["value"], event["metadata"], RenderPose)
try:
with self.visualizer.task_lock:
self.visualizer.data.xpos[:] = state_msg.xpos
self.visualizer.data.xquat[:] = state_msg.xquat
self.visualizer.data.xpos[:] = render_pose_msg.xpos
self.visualizer.data.xquat[:] = render_pose_msg.xquat
self.visualizer.viser_model.set_data(self.visualizer.data)
except ValueError:
# we're switching tasks and the new task has a different number of xpos/xquat
Expand Down
10 changes: 8 additions & 2 deletions judo/app/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ class MujocoState:
time: float
qpos: np.ndarray
qvel: np.ndarray
xpos: np.ndarray
xquat: np.ndarray
mocap_pos: np.ndarray
mocap_quat: np.ndarray
sim_metadata: dict[str, Any]
Expand Down Expand Up @@ -82,3 +80,11 @@ def spline(self) -> interp1d:
fill_value=fill_value, # type: ignore
bounds_error=not self.extrapolate,
)


@dataclass
class RenderPose:
"""Struct for visualization poses used by the renderer."""

xpos: np.ndarray
xquat: np.ndarray
62 changes: 57 additions & 5 deletions judo/app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,73 @@ def get_class_from_string(class_path: str) -> type:


def register_tasks_from_cfg(task_registration_cfg: DictConfig) -> None:
"""Register custom tasks."""
"""Register custom tasks.

Args:
task_registration_cfg: Mapping keyed by task name. Each value must contain:
`task`: import path to the Task class
`config`: import path to the TaskConfig class

Optional keys:
`rollout_backend`: rollout backend registry key for controllers
`simulation_backend`: simulation backend registry key for the simulation node
`locomotion_policy_path`: path to a low-level policy used by hierarchical tasks

Example schema:
{
"cylinder_push": {
"task": "judo.tasks.cylinder_push.CylinderPush",
"config": "judo.tasks.cylinder_push.CylinderPushConfig",
"rollout_backend": "mujoco",
"simulation_backend": "mujoco",
}
}
"""
for task_name in task_registration_cfg.keys():
task_dict = task_registration_cfg.get(task_name, {})
assert set(task_dict.keys()) == {"task", "config"}, (
"Task registration must be a dict with keys 'task' and 'config'."
allowed_keys = {"task", "config", "rollout_backend", "simulation_backend", "locomotion_policy_path"}
assert set(task_dict.keys()).issubset(allowed_keys) and {"task", "config"}.issubset(task_dict.keys()), (
"Task registration must include 'task' and 'config', and may optionally include "
"'rollout_backend', 'simulation_backend', and 'locomotion_policy_path'."
)
assert isinstance(task_dict["task"], str), "Task must be a string path to the task class."
assert isinstance(task_dict["config"], str), "Task config must be a string path to the config class."
task_cls = get_class_from_string(task_dict["task"])
task_config_cls = get_class_from_string(task_dict["config"])
register_task(str(task_name), task_cls, task_config_cls)
rollout_backend = task_dict.get("rollout_backend", "mujoco")
simulation_backend = task_dict.get("simulation_backend", "mujoco")
locomotion_policy_path = task_dict.get("locomotion_policy_path", None)
assert isinstance(rollout_backend, str), "rollout_backend must be a string."
assert isinstance(simulation_backend, str), "simulation_backend must be a string."
assert locomotion_policy_path is None or isinstance(locomotion_policy_path, str), (
"locomotion_policy_path must be a string if provided."
)
register_task(
str(task_name),
task_cls,
task_config_cls,
rollout_backend=rollout_backend,
simulation_backend=simulation_backend,
locomotion_policy_path=locomotion_policy_path,
)


def register_optimizers_from_cfg(optimizer_registration_cfg: DictConfig) -> None:
"""Register custom optimizers."""
"""Register custom optimizers.

Args:
optimizer_registration_cfg: Mapping keyed by optimizer name. Each value must contain:
`optimizer`: import path to the Optimizer class
`config`: import path to the OptimizerConfig class

Example schema:
{
"cem": {
"optimizer": "judo.optimizers.cem.CrossEntropyMethod",
"config": "judo.optimizers.cem.CrossEntropyMethodConfig",
}
}
"""
for optimizer_name in optimizer_registration_cfg.keys():
optimizer_dict = optimizer_registration_cfg.get(optimizer_name, {})
assert set(optimizer_dict.keys()) == {"optimizer", "config"}, (
Expand Down
12 changes: 12 additions & 0 deletions judo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,19 @@ def _warm_caches() -> None:
pass # non-Spot tasks don't need this


def _require_mujoco_extensions() -> None:
"""Fail fast if mujoco_extensions is unavailable in the current environment."""
try:
import mujoco_extensions # noqa: F401, PLC0415
except Exception as e: # pragma: no cover - environment dependent
raise RuntimeError(
"mujoco_extensions is required but could not be imported. Build it with: pixi run build"
) from e


def app() -> None:
"""Entry point for the judo CLI."""
_require_mujoco_extensions()
_warm_caches()
# we store judo_dora_default in the config store so that custom dora configs outside of judo can inherit from it
cs = ConfigStore.instance()
Expand All @@ -168,6 +179,7 @@ def main_benchmark(cfg: DictConfig) -> None:

def benchmark() -> None:
"""Entry point for benchmarking."""
_require_mujoco_extensions()
# we store benchmark_default in the config store so that custom configs located outside of judo can inherit from it
cs = ConfigStore.instance()
with initialize_config_dir(config_dir=str(CONFIG_PATH), version_base="1.3"):
Expand Down
Loading