Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,22 @@ pixi run build
If you prefer named environments, `pixi install -e dev` is equivalent for this repo,
and the corresponding task form is `pixi run -e dev ...`.

### Cleaning the build

If native extensions misbehave (e.g. after a mujoco / judo / pybind11 version
bump, or if you see odd shape mismatches between Python and the C++ rollout
backend), wipe the build artifacts before rebuilding:

```bash
# Remove C++ build dirs, deployed .so files, and __pycache__ directories
pixi run clean-build
pixi run build

# Full reset: also removes .judo-src/ and .pixi/ (forces fresh clone + reinstall)
pixi run clean-all
pixi run build # auto-runs `pixi install` first
```

## Run

```bash
Expand Down
70 changes: 38 additions & 32 deletions pixi.lock

Large diffs are not rendered by default.

50 changes: 47 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ allow-direct-references = true
[project]
name = "sumo"
version = "0.0.1"
requires-python = ">=3.10"
dependencies = [
"judo-rai @ git+https://github.com/bdaiinstitute/judo.git@dta/fix_for_sumo",
"judo-rai @ git+https://github.com/rai-opensource/judo.git@dta/fix_for_sumo",
"numpy",
"mujoco>=3.5.0, <3.6.0",
"mujoco>=3.6.0,<3.7",
"h5py",
Comment thread
dta-bdai marked this conversation as resolved.
"tyro",
"tqdm",
Expand Down Expand Up @@ -74,7 +75,7 @@ sumo = "sumo"
cmd = """
sh -c '
if [ ! -d .judo-src ]; then
git clone --depth 1 --branch dta/fix_for_sumo https://github.com/bdaiinstitute/judo.git .judo-src
git clone --depth 1 --branch dta/fix_for_sumo https://github.com/rai-opensource/judo.git .judo-src
fi &&
sed -i.bak "s|libonnxruntime.so\\*|libonnxruntime.*|" .judo-src/mujoco_extensions/CMakeLists.txt &&
mkdir -p .judo-src/build &&
Expand Down Expand Up @@ -127,6 +128,48 @@ depends-on = ["build-g1-ext", "build-judo-ext"]
cwd = "g1_extensions"
cmd = "rm -rf build/"

# Re-resolve the judo-rai git dependency to the latest commit on its tracked
# branch, then rebuild C++ extensions. Use this when the judo branch has new
# commits and `pixi install` keeps installing the old pinned commit.
[tool.pixi.tasks.update-judo]
cmd = "pixi update judo-rai && pixi run clean-build && pixi run build"

# Remove all C++ build artifacts (g1_extensions and judo mujoco_extensions),
# the installed extension .so files in site-packages, and Python caches.
# Use this when mujoco / judo / pybind11 versions change to avoid stale objects.
[tool.pixi.tasks.clean-build]
cmd = """
sh -c '
set -e
rm -rf g1_extensions/build/
rm -rf .judo-src/build/
rm -f .judo-src/mujoco_extensions/policy_rollout/policy_rollout_pybind*.so
SITE=$(python -c "import site; print(site.getsitepackages()[0])" 2>/dev/null || true)
if [ -n "$SITE" ]; then
rm -f "$SITE"/mujoco_extensions/policy_rollout/policy_rollout_pybind*.so
rm -f "$SITE"/g1_extensions/_g1_extensions*.so
fi
find . -path ./.pixi -prune -o -type d -name __pycache__ -print -exec rm -rf {} + 2>/dev/null || true
echo "Cleaned C++ build artifacts and Python caches."
'
"""

# Full nuke: clean-build + remove the cloned judo source AND the entire pixi env.
# Forces fresh judo clone and full reinstall on next `pixi install` / `pixi run build`.
# Use when judo git revision in pixi.lock has been updated or when site-packages
# may have stale judo files.
[tool.pixi.tasks.clean-all]
depends-on = ["clean-build"]
cmd = """
sh -c '
set -e
rm -rf .judo-src
rm -rf .pixi
rm -rf outputs/ run_mpc/results/ out/
echo "Full clean done. Run: pixi install && pixi run build"
'
"""

[tool.pytest.ini_options]
markers = [
"g1_extensions: marks tests that require the optional g1_extensions extension",
Expand Down Expand Up @@ -200,4 +243,5 @@ exclude = [
"**/__pycache__",
]
stubPath = "typings"
useLibraryCodeForTypes = true
reportAttributeAccessIssue = "warning"
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2025-2026 Robotics and AI Institute LLC dba RAI Institute. All rights reserved.

from judo.app.dora.controller import ControllerNode as JudoControllerNode
from judo.app.dora.controller_node import ControllerNode as JudoControllerNode

import sumo.controller # noqa: F401 -- register controller/optimizer overrides
import sumo.tasks # noqa: F401 -- register all sumo tasks
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2025-2026 Robotics and AI Institute LLC dba RAI Institute. All rights reserved.

from judo.app.dora.simulation import SimulationNode as JudoSimulationNode
from judo.app.dora.simulation_node import SimulationNode as JudoSimulationNode

import sumo.tasks # noqa: F401 -- register all sumo tasks
from sumo.app.dora.g1_simulation import G1Simulation
Expand All @@ -10,7 +10,7 @@ class SimulationNode(JudoSimulationNode):
"""Simulation node with G1 backend support."""

def __init__(self, init_task: str = "spot_box_push", **kwargs) -> None:
kwargs.setdefault("custom_backends", {"mujoco_g1": G1Simulation})
kwargs.setdefault("backend_registry", {"mujoco_g1": G1Simulation})
super().__init__(init_task=init_task, **kwargs)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2025-2026 Robotics and AI Institute LLC dba RAI Institute. All rights reserved.

from judo.app.dora.visualization import VisualizationNode as JudoVisualizationNode
from judo.app.dora.visualization_node import VisualizationNode as JudoVisualizationNode

import sumo.controller # noqa: F401 -- register controller/optimizer overrides
import sumo.tasks # noqa: F401 -- register all sumo tasks
Expand Down
6 changes: 3 additions & 3 deletions sumo/configs/sumo_default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ task: spot_box_push

node_definitions:
simulation:
_target_: sumo.app.dora.simulation.SimulationNode
_target_: sumo.app.dora.simulation_node.SimulationNode
visualization:
_target_: sumo.app.dora.visualization.VisualizationNode
_target_: sumo.app.dora.visualization_node.VisualizationNode
controller:
_target_: sumo.app.dora.controller.ControllerNode
_target_: sumo.app.dora.controller_node.ControllerNode
6 changes: 1 addition & 5 deletions sumo/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

from __future__ import annotations

from typing import Literal

from judo.controller.controller import Controller, ControllerConfig, make_spline
from judo.controller.controller import make_controller as _judo_make_controller
from omegaconf import DictConfig
Expand All @@ -18,16 +16,14 @@ def make_controller(
init_optimizer: str,
task_registration_cfg: DictConfig | None = None,
optimizer_registration_cfg: DictConfig | None = None,
rollout_backend: Literal["mujoco"] = "mujoco",
) -> Controller:
"""Make a controller with G1 backend support."""
return _judo_make_controller(
init_task=init_task,
init_optimizer=init_optimizer,
task_registration_cfg=task_registration_cfg,
optimizer_registration_cfg=optimizer_registration_cfg,
rollout_backend=rollout_backend,
custom_rollout_backends={"mujoco_g1": G1RolloutBackend},
rollout_backend_registry={"mujoco_g1": G1RolloutBackend},
)


Expand Down
16 changes: 9 additions & 7 deletions sumo/run_mpc/run_mpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,22 @@ class RunMPCConfig:

def _create_sim(task_name: str):
"""Create the right simulation backend for a task."""
task_cls, _ = get_registered_tasks()[task_name]
default_backend = getattr(task_cls, "default_backend", None)
task_entry = get_registered_tasks()[task_name]
simulation_backend = task_entry.simulation_backend

if default_backend == "mujoco_g1":
if simulation_backend == "mujoco_g1":
require_g1_extensions()
return G1Simulation(init_task=task_name)

if task_cls().uses_locomotion_policy:
if simulation_backend == "mujoco_hierarchical":
require_mujoco_extensions()
from judo.simulation import get_simulation_backend

return get_simulation_backend("mujoco_policy")(init_task=task_name)
return get_simulation_backend("mujoco_hierarchical")(init_task=task_name)

from judo.simulation import get_simulation_backend

return get_simulation_backend("mujoco")(init_task=task_name)
return get_simulation_backend(simulation_backend)(init_task=task_name)


def _make_condition_checker(method):
Expand Down Expand Up @@ -246,7 +246,9 @@ def run_mpc(config: RunMPCConfig) -> list[dict]:
# Create controller
controller_config = ControllerConfig()
controller_config.set_override(config.init_task)
controller = Controller(controller_config, task, optimizer, custom_rollout_backends={"mujoco_g1": G1RolloutBackend})
controller = Controller(
controller_config, task, optimizer, rollout_backend_registry={"mujoco_g1": G1RolloutBackend}
)

# Set up visualization
viser_model = None
Expand Down
69 changes: 44 additions & 25 deletions sumo/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2025-2026 Robotics and AI Institute LLC dba RAI Institute. All rights reserved.

from judo.tasks import get_registered_tasks, register_task
from judo.tasks import TaskRegistration, get_registered_tasks, register_task
from judo.tasks.spot.spot_constants import SPOT_LOCOMOTION_POLICY_PATH

G1_TASK_NAMES = (
"g1_base",
Expand All @@ -16,11 +17,23 @@
from sumo.tasks.g1.g1_door import G1Door, G1DoorConfig
from sumo.tasks.g1.g1_table_push import G1TablePush, G1TablePushConfig

register_task("g1_base", G1Base, G1BaseConfig)
register_task("g1_box", G1Box, G1BoxConfig)
register_task("g1_chair_push", G1ChairPush, G1ChairPushConfig)
register_task("g1_door", G1Door, G1DoorConfig)
register_task("g1_table_push", G1TablePush, G1TablePushConfig)
register_task("g1_base", G1Base, G1BaseConfig, rollout_backend="mujoco_g1", simulation_backend="mujoco_g1")
register_task("g1_box", G1Box, G1BoxConfig, rollout_backend="mujoco_g1", simulation_backend="mujoco_g1")
register_task(
"g1_chair_push",
G1ChairPush,
G1ChairPushConfig,
rollout_backend="mujoco_g1",
simulation_backend="mujoco_g1",
)
register_task("g1_door", G1Door, G1DoorConfig, rollout_backend="mujoco_g1", simulation_backend="mujoco_g1")
register_task(
"g1_table_push",
G1TablePush,
G1TablePushConfig,
rollout_backend="mujoco_g1",
simulation_backend="mujoco_g1",
)

SPOT_TASK_NAMES = (
"spot_base",
Expand Down Expand Up @@ -63,29 +76,35 @@
from sumo.tasks.spot.spot_tire_stack import SpotTireStack, SpotTireStackConfig
from sumo.tasks.spot.spot_tire_upright import SpotTireUpright, SpotTireUprightConfig

register_task("spot_base", SpotBase, SpotBaseConfig)
register_task("spot_box_push", SpotBoxPush, SpotBoxPushConfig)
register_task("spot_chair_push", SpotChairPush, SpotChairPushConfig)
register_task("spot_cone_push", SpotConePush, SpotConePushConfig)
register_task("spot_rack_push", SpotRackPush, SpotRackPushConfig)
register_task("spot_tire_push", SpotTirePush, SpotTirePushConfig)
register_task("spot_box_upright", SpotBoxUpright, SpotBoxUprightConfig)
register_task("spot_chair_upright", SpotChairUpright, SpotChairUprightConfig)
register_task("spot_cone_upright", SpotConeUpright, SpotConeUprightConfig)
register_task("spot_rack_upright", SpotRackUpright, SpotRackUprightConfig)
register_task("spot_tire_upright", SpotTireUpright, SpotTireUprightConfig)
register_task("spot_chair_ramp", SpotChairRamp, SpotChairRampConfig)
register_task("spot_barrier_upright", SpotBarrierUpright, SpotBarrierUprightConfig)
register_task("spot_barrier_drag", SpotBarrierDrag, SpotBarrierDragConfig)
register_task("spot_tire_roll", SpotTireRoll, SpotTireRollConfig)
register_task("spot_tire_stack", SpotTireStack, SpotTireStackConfig)
register_task("spot_tire_rack_drag", SpotTireRackDrag, SpotTireRackDragConfig)
register_task("spot_rugged_box_push", SpotRuggedBoxPush, SpotRuggedBoxPushConfig)
_SPOT_REGISTRATION_KWARGS = {
"rollout_backend": "mujoco_hierarchical",
"simulation_backend": "mujoco_hierarchical",
"locomotion_policy_path": str(SPOT_LOCOMOTION_POLICY_PATH),
}

register_task("spot_base", SpotBase, SpotBaseConfig, **_SPOT_REGISTRATION_KWARGS)
register_task("spot_box_push", SpotBoxPush, SpotBoxPushConfig, **_SPOT_REGISTRATION_KWARGS)
register_task("spot_chair_push", SpotChairPush, SpotChairPushConfig, **_SPOT_REGISTRATION_KWARGS)
register_task("spot_cone_push", SpotConePush, SpotConePushConfig, **_SPOT_REGISTRATION_KWARGS)
register_task("spot_rack_push", SpotRackPush, SpotRackPushConfig, **_SPOT_REGISTRATION_KWARGS)
register_task("spot_tire_push", SpotTirePush, SpotTirePushConfig, **_SPOT_REGISTRATION_KWARGS)
register_task("spot_box_upright", SpotBoxUpright, SpotBoxUprightConfig, **_SPOT_REGISTRATION_KWARGS)
register_task("spot_chair_upright", SpotChairUpright, SpotChairUprightConfig, **_SPOT_REGISTRATION_KWARGS)
register_task("spot_cone_upright", SpotConeUpright, SpotConeUprightConfig, **_SPOT_REGISTRATION_KWARGS)
register_task("spot_rack_upright", SpotRackUpright, SpotRackUprightConfig, **_SPOT_REGISTRATION_KWARGS)
register_task("spot_tire_upright", SpotTireUpright, SpotTireUprightConfig, **_SPOT_REGISTRATION_KWARGS)
register_task("spot_chair_ramp", SpotChairRamp, SpotChairRampConfig, **_SPOT_REGISTRATION_KWARGS)
register_task("spot_barrier_upright", SpotBarrierUpright, SpotBarrierUprightConfig, **_SPOT_REGISTRATION_KWARGS)
register_task("spot_barrier_drag", SpotBarrierDrag, SpotBarrierDragConfig, **_SPOT_REGISTRATION_KWARGS)
register_task("spot_tire_roll", SpotTireRoll, SpotTireRollConfig, **_SPOT_REGISTRATION_KWARGS)
register_task("spot_tire_stack", SpotTireStack, SpotTireStackConfig, **_SPOT_REGISTRATION_KWARGS)
register_task("spot_tire_rack_drag", SpotTireRackDrag, SpotTireRackDragConfig, **_SPOT_REGISTRATION_KWARGS)
register_task("spot_rugged_box_push", SpotRuggedBoxPush, SpotRuggedBoxPushConfig, **_SPOT_REGISTRATION_KWARGS)

SUMO_TASK_NAMES = G1_TASK_NAMES + SPOT_TASK_NAMES


def get_sumo_registered_tasks() -> dict[str, tuple[type, type]]:
def get_sumo_registered_tasks() -> dict[str, TaskRegistration]:
"""Return only the task registrations owned by sumo."""
registered_tasks = get_registered_tasks()
return {task_name: registered_tasks[task_name] for task_name in SUMO_TASK_NAMES if task_name in registered_tasks}
1 change: 0 additions & 1 deletion sumo/tasks/g1/g1_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ class G1Base(Task[ConfigT], Generic[ConfigT]):
"""Base class for G1 tasks."""

config_t: type[G1BaseConfig] = G1BaseConfig # type: ignore[assignment]
default_backend = "mujoco_g1" # Use G1-specific backend

def _process_spec(self) -> None:
"""No-op for G1 tasks (meshes are local)."""
Expand Down
12 changes: 4 additions & 8 deletions sumo/tasks/g1/g1_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,20 +190,16 @@ def reset_pose(self) -> np.ndarray:
]
)

def success(
self, model: MjModel, data: MjData, config: G1BoxConfig, metadata: dict[str, Any] | None = None
) -> bool:
def success(self, model: MjModel, data: MjData, metadata: dict[str, Any] | None = None) -> bool:
"""Check if the box is in the goal position."""
object_pos = data.qpos[..., self.object_pose_idx[0:2]] # XY only
object_vel = data.qvel[..., self.object_vel_idx[0:3]]
goal_pos = np.array(config.goal_position[:2]) # XY only
goal_pos = np.array(self.config.goal_position[:2]) # XY only
Comment thread
dta-bdai marked this conversation as resolved.
position_check = np.linalg.norm(object_pos - goal_pos, axis=-1, ord=np.inf) < POSITION_TOLERANCE
velocity_check = np.linalg.norm(object_vel, axis=-1) < VELOCITY_TOLERANCE
return position_check and velocity_check

def failure(
self, model: MjModel, data: MjData, config: G1BoxConfig, metadata: dict[str, Any] | None = None
) -> bool:
def failure(self, model: MjModel, data: MjData, metadata: dict[str, Any] | None = None) -> bool:
"""Check if G1 has fallen."""
body_height = data.qpos[..., self.body_pose_idx[2]]
return bool(body_height <= config.fall_threshold)
return bool(body_height <= self.config.fall_threshold)
12 changes: 4 additions & 8 deletions sumo/tasks/g1/g1_chair_push.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,20 +184,16 @@ def reset_pose(self) -> np.ndarray:
]
)

def success(
self, model: MjModel, data: MjData, config: G1ChairPushConfig, metadata: dict[str, Any] | None = None
) -> bool:
def success(self, model: MjModel, data: MjData, metadata: dict[str, Any] | None = None) -> bool:
"""Check if the chair is in the goal position."""
object_pos = data.qpos[..., self.object_pose_idx[0:2]] # XY only
object_vel = data.qvel[..., self.object_vel_idx[0:3]]
goal_pos = np.array(config.goal_position[:2]) # XY only
goal_pos = np.array(self.config.goal_position[:2]) # XY only
Comment thread
dta-bdai marked this conversation as resolved.
position_check = np.linalg.norm(object_pos - goal_pos, axis=-1, ord=np.inf) < POSITION_TOLERANCE
velocity_check = np.linalg.norm(object_vel, axis=-1) < VELOCITY_TOLERANCE
return position_check and velocity_check

def failure(
self, model: MjModel, data: MjData, config: G1ChairPushConfig, metadata: dict[str, Any] | None = None
) -> bool:
def failure(self, model: MjModel, data: MjData, metadata: dict[str, Any] | None = None) -> bool:
"""Check if G1 has fallen."""
body_height = data.qpos[..., self.body_pose_idx[2]]
return bool(body_height <= config.fall_threshold)
return bool(body_height <= self.config.fall_threshold)
12 changes: 4 additions & 8 deletions sumo/tasks/g1/g1_door.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,18 +208,14 @@ def reset_pose(self) -> np.ndarray:
]
)

def success(
self, model: MjModel, data: MjData, config: G1DoorConfig, metadata: dict[str, Any] | None = None
) -> bool:
def success(self, model: MjModel, data: MjData, metadata: dict[str, Any] | None = None) -> bool:
"""Check if the robot has reached the target XY position."""
body_pos = data.qpos[..., self.body_pose_idx[0:2]] # Get XY position only
goal_pos = np.array(config.goal_position[:2]) # Get XY goal only
goal_pos = np.array(self.config.goal_position[:2]) # Get XY goal only
position_check = np.linalg.norm(body_pos - goal_pos, axis=-1) < POSITION_TOLERANCE
Comment thread
dta-bdai marked this conversation as resolved.
return bool(position_check)

def failure(
self, model: MjModel, data: MjData, config: G1DoorConfig, metadata: dict[str, Any] | None = None
) -> bool:
def failure(self, model: MjModel, data: MjData, metadata: dict[str, Any] | None = None) -> bool:
"""Check if G1 has fallen."""
body_height = data.qpos[..., self.body_pose_idx[2]]
return bool(body_height <= config.fall_threshold)
return bool(body_height <= self.config.fall_threshold)
Loading