
# CroSTAta: Cross-State Transition Attention Transformer for Robotic Manipulation

This repository contains the implementation code for CroSTAta: Cross-State Transition Attention Transformer for Robotic Manipulation. CroSTAta introduces a novel cross-state transition attention mechanism designed to improve robotic manipulation tasks by better capturing temporal dependencies and state transitions in sequential data.
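The repository's architecture is not reproduced here, but the core idea of attending over state *transitions* rather than individual states can be illustrated with a minimal sketch. The function name, feature shapes, and weight matrices below are all hypothetical, not the repository's implementation; this only shows one plausible way to form transition tokens from consecutive state pairs and run causal attention over them:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def transition_attention(states, Wq, Wk, Wv):
    """Causal attention over state transitions (s_t, s_{t+1}).

    states: (T, d) sequence of state features.
    Returns: (T-1, d_v) attended transition features.
    """
    # Pair each state with its successor to form transition tokens.
    transitions = np.concatenate([states[:-1], states[1:]], axis=-1)  # (T-1, 2d)
    Q, K, V = transitions @ Wq, transitions @ Wk, transitions @ Wv
    scores = Q @ K.T / np.sqrt(K.shape[-1])  # (T-1, T-1)
    # Causal mask: a transition may only attend to itself and earlier transitions.
    mask = np.triu(np.ones_like(scores, dtype=bool), k=1)
    scores = np.where(mask, -np.inf, scores)
    return softmax(scores) @ V

rng = np.random.default_rng(0)
T, d, dk = 6, 4, 8
states = rng.standard_normal((T, d))
Wq, Wk, Wv = (rng.standard_normal((2 * d, dk)) for _ in range(3))
out = transition_attention(states, Wq, Wk, Wv)
print(out.shape)  # (5, 8)
```

Six states yield five transition tokens, so the output has one fewer row than the input sequence.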

CroSTAta Banner

## Structure

Source code outline:

```text
src/
│
├───algos/        : algorithm implementations
├───common/       : shared utilities and environments
│   ├───datasets/ : dataset handling and processing
│   └───envs/     : environment wrappers and utilities
├───configs/      : agent- and environment-specific algorithm configurations
│   └───agents/   : LeRobot (training) and ManiSkill (inference) configurations
├───data/         : utilities for training datasets
├───docs/         : documentation files
├───save/         : model checkpoints (created at runtime)
├───wandb/        : experiment tracking logs (created at runtime)
├───train.py      : offline training script
└───predict.py    : evaluation and inference script
```

## Installation

```shell
git clone git@github.com:iit-DLSLab/croSTAta.git
cd croSTAta/
pip install -r requirements.txt
```

## Data

| Task Name | Dataset |
|---|---|
| StackCube-v1 | ManiSkill_StackCube-v1_recovery |
| PegInsertionSide-v1 | ManiSkill_PegInsertionSide-v1_recovery |
| TwoRobotStackCube-v1 | ManiSkill_TwoRobotStackCube-v1_recovery |
| UnitreeG1TransportBox-v1 | ManiSkill_UnitreeG1TransportBox-v1_recovery |

## Training

```shell
python train.py --task <env name, e.g. PegInsertionSide-v1> --envsim maniskill \
    --num_envs 1 --val_episodes 100 \
    --agent Maniskill/<cfg_file, e.g. maniskill_sl_inference_cfg> \
    --device cuda --sim_device cuda \
    [--wandb] \
    [--resume --checkpoint save/<model_name>.zip --wandb_run <id>]
```

Flags in brackets are optional: `--wandb` enables experiment tracking, and `--resume` with `--checkpoint` and `--wandb_run` continues a previous run.

## Evaluation

```shell
python predict.py --task PegInsertionSide-v1 --envsim maniskill \
    --num_envs 1 --val_episodes 100 --agent Maniskill/<cfg_file> \
    --device cuda --sim_device cuda \
    --resume --checkpoint save/<model_name>
```

Note: by default, `sl_agent` inference uses the policy's `predict_batch` method (batch prediction), which was used for the official evaluation of the method and baselines. For more efficient inference, use the policy's `predict` method (prediction with caching). Results from the two implementations may differ slightly.
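The difference between the two modes amounts to recomputing causal attention over the whole sequence each step versus keeping a growing key/value cache. The repository's `predict_batch` and `predict` are only borrowed as names here; the bodies below are an illustrative NumPy sketch, not the actual implementation. In exact arithmetic the two modes agree, which is why results only *slightly* differ in practice (floating-point ordering and implementation details):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def predict_batch(q, k, v):
    """Full-sequence causal attention, recomputed from scratch each call."""
    T = q.shape[0]
    scores = q @ k.T / np.sqrt(k.shape[-1])
    scores = np.where(np.triu(np.ones((T, T), dtype=bool), 1), -np.inf, scores)
    return softmax(scores) @ v

def predict_cached(q, k, v):
    """Step-by-step attention with a growing key/value cache."""
    outs, k_cache, v_cache = [], [], []
    for t in range(q.shape[0]):
        k_cache.append(k[t])
        v_cache.append(v[t])
        K, V = np.stack(k_cache), np.stack(v_cache)
        # Query at step t only sees the cached (i.e. past and current) keys,
        # so no explicit causal mask is needed.
        scores = q[t] @ K.T / np.sqrt(K.shape[-1])
        outs.append(softmax(scores) @ V)
    return np.stack(outs)

rng = np.random.default_rng(1)
q, k, v = (rng.standard_normal((5, 8)) for _ in range(3))
print(np.allclose(predict_batch(q, k, v), predict_cached(q, k, v)))  # True
```

The cached variant avoids the quadratic recomputation of past attention scores at every step, which is the efficiency gain the note refers to.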

## Checkpoints

We provide baseline and method checkpoints for all tasks at this link.

## Cite

```bibtex
@article{minelli2025crostata,
  title={CroSTAta: Cross-State Transition Attention Transformer for Robotic Manipulation},
  author={Minelli, Giovanni and Turrisi, Giulio and Barasuol, Victor and Semini, Claudio},
  year={2025}
}
```