Implementation of Model-Agnostic Meta-Learning (MAML) for Reinforcement Learning using PyTorch and TorchRL.
This codebase implements MAML for continuous control tasks with:
- Inner Loop: Vanilla Policy Gradient (VPG) with differentiable updates
- Outer Loop: PPO or TRPO for stable meta-training
- Functional Architecture: `torch.func` for stateless execution and `vmap` for task parallelization
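The differentiable inner update can be sketched with `torch.func`: holding parameters as an explicit dict keeps the adapted weights inside the autograd graph, so the outer loop can differentiate through the adaptation. The policy, loss, and shapes below are illustrative stand-ins, not the repo's actual API.

```python
import torch
from torch import nn
from torch.func import functional_call, grad

# Illustrative policy; the repo's actor-critic lives in policies.py.
policy = nn.Sequential(nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, 2))
params = dict(policy.named_parameters())

def inner_loss(params, obs, target):
    # Stand-in surrogate loss; a real VPG inner loss would use
    # log-probabilities weighted by advantages.
    out = functional_call(policy, params, (obs,))
    return ((out - target) ** 2).mean()

inner_lr = 0.1
obs, target = torch.randn(16, 2), torch.randn(16, 2)

# One differentiable inner step: the adapted params are ordinary tensor
# expressions of the initial params, which is what the outer loop needs.
grads = grad(inner_loss)(params, obs, target)
adapted = {k: p - inner_lr * grads[k] for k, p in params.items()}
```

Because the update is a pure function of `(params, data)`, the same adaptation step can be mapped over a batch of per-task rollouts with `torch.func.vmap`.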
```bash
uv sync
```

```bash
# Train with TRPO (default)
uv run python run.py --cfg configs/config.yaml

# Train with PPO
uv run python run.py --cfg configs/config.yaml --cfg.algorithm ppo

# Evaluate
uv run python run.py --mode eval --checkpoint checkpoints/model.pt --cfg configs/nav_eval.yaml
```

Override any config parameter via CLI:

```bash
uv run python run.py --cfg configs/config.yaml --cfg.outer.lr 0.0001 --cfg.inner.num_steps 3
```

```bash
# 1. run MAML
uv run python run.py --cfg configs/nav.yaml --cfg.wandb.name maml_nav

# 2. train pretrained baseline
uv run python run.py --cfg configs/nav.yaml --cfg.inner.num_steps 0 --cfg.wandb.name pretrained_nav

# 3. train oracle
uv run python run.py --cfg configs/nav.yaml --cfg.oracle true --cfg.inner.num_steps 0 --cfg.wandb.name oracle_nav

# 4. run evaluation
uv run python run.py --mode eval --checkpoint checkpoints/maml_nav/model.pt --pretrained_checkpoint checkpoints/pretrained_nav/model.pt --oracle_checkpoint checkpoints/oracle_nav/model.pt --cfg configs/nav_eval.yaml --cfg.wandb.name eval_nav
```

| Environment | Task | Obs Dim | Act Dim |
|---|---|---|---|
| `navigation` | Reach 2D goal position | 2 | 2 |
| `ant` | Match target velocity | 27 | 8 |
```
src/maml_rl/
├── maml.py          # Inner loop (VPG), outer loop (PPO/TRPO), FunctionalPolicy
├── training.py      # Meta-training loop
├── evaluation.py    # Evaluation against baselines
├── policies.py      # Actor-critic architecture
├── envs/
│   ├── base.py      # MetaEnv protocol
│   ├── factory.py   # Environment registry
│   ├── navigation.py
│   └── ant.py
└── utils/
    ├── returns.py      # GAE computation
    ├── device.py       # Device/WandB setup
    └── optimization.py # TRPO conjugate gradient
```
The evaluation compares MAML against:
| Baseline | Description |
|---|---|
| Random Init | Randomly initialized policy adapted on test tasks |
| Pretrained | Policy trained without meta-learning, then adapted |
| Oracle | Policy trained with task parameters in observation (upper bound) |
Implement the `MetaEnv` protocol (see `src/maml_rl/envs/base.py`) and register it. See `navigation.py` or `ant.py` for reference implementations.
Your environment class must implement:
| Method | Purpose |
|---|---|
| `set_task(task)` | Configure env for a specific task |
| `sample_tasks(num_tasks, low, high)` | Sample task specifications |
| `get_task_obs_dim()` | Dimension of task params for oracle |
| `make_vec_env(tasks, ...)` | Create parallel env |
| `make_oracle_vec_env(tasks, ...)` | Create parallel env with task in obs |
| `get_oracle(tasks, device, checkpoint)` | Load oracle policy |
- Create env class in `src/maml_rl/envs/my_env.py` extending `gymnasium.Env`
- Create oracle wrapper that appends task params to observations
- Register in `src/maml_rl/envs/factory.py`:
  ```python
  from maml_rl.envs.my_env import MyEnv

  ENV_REGISTRY["my_env"] = MyEnv
  ```
- Create config `configs/my_env.yaml` with env name and task bounds
- Train: `uv run python run.py --cfg configs/my_env.yaml`
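The task-related half of such a class might look like the following sketch. Method names follow the protocol table above; the `gymnasium.Env` step/reset machinery and the goal geometry are omitted, and the 2D goal task is only an illustrative assumption.

```python
import random

class MyEnv:
    """Hypothetical custom env showing the MetaEnv task methods."""

    def __init__(self):
        self.task = None

    def set_task(self, task):
        """Configure the env for one task, e.g. a 2D goal position."""
        self.task = task

    @staticmethod
    def sample_tasks(num_tasks, low=-0.5, high=0.5):
        """Sample task specifications uniformly from [low, high]^2."""
        return [
            (random.uniform(low, high), random.uniform(low, high))
            for _ in range(num_tasks)
        ]

    @staticmethod
    def get_task_obs_dim():
        """Dimension of the task params appended for the oracle."""
        return 2

tasks = MyEnv.sample_tasks(4)
env = MyEnv()
env.set_task(tasks[0])
```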
Following the MAML paper, an oracle is a policy that receives task parameters (goal position, velocity) directly in its observation. Since the oracle knows the task, it provides an upper-bound baseline for evaluation.
Oracle environments extend the observation space by appending task parameters:
| Environment | Standard Obs | Oracle Obs |
|---|---|---|
| `navigation` | `[x, y]` (dim 2) | `[x, y, goal_x, goal_y]` (dim 4) |
| `ant` | state (dim 27) | `[state, goal_velocity]` (dim 28) |
```bash
uv run python run.py --cfg configs/nav.yaml --cfg.oracle true --cfg.inner.num_steps 0
```

This trains a policy on oracle observations (task params included). Use the checkpoint during evaluation with `--oracle_checkpoint`.
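An oracle wrapper along these lines would implement the augmentation in the table above (appending task params to each observation). The class and env names here are illustrative, not the repo's actual wrapper.

```python
import numpy as np

class OracleObsWrapper:
    """Hypothetical wrapper appending task params to every observation."""

    def __init__(self, env, task_params):
        self.env = env
        self.task_params = np.asarray(task_params, dtype=np.float32)

    def _augment(self, obs):
        return np.concatenate(
            [np.asarray(obs, dtype=np.float32), self.task_params]
        )

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        return self._augment(obs), info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self._augment(obs), reward, terminated, truncated, info


class _PointEnv:
    """Minimal stand-in env with a gymnasium-style 5-tuple step API."""

    def reset(self, **kwargs):
        return np.zeros(2, dtype=np.float32), {}

    def step(self, action):
        return np.zeros(2, dtype=np.float32), 0.0, False, False, {}


oracle_env = OracleObsWrapper(_PointEnv(), task_params=[0.3, -0.2])
obs, _ = oracle_env.reset()  # obs now has dim 4: [x, y, goal_x, goal_y]
```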
Key config options in `configs/base.py`:

| Parameter | Default | Description |
|---|---|---|
| `algorithm` | `"trpo"` | Outer loop algorithm (`trpo` or `ppo`) |
| `inner.num_steps` | `1` | Inner loop gradient steps |
| `inner.lr` | `0.1` | Inner loop learning rate |
| `outer.lr` | `0.0003` | Outer loop learning rate (PPO only) |
| `env.num_tasks` | `8` | Number of parallel tasks |
| `env.max_steps` | `200` | Maximum steps per episode |
| `rollout_steps` | `200` | Steps collected per task for adaptation |
| `gamma` | `0.99` | Discount factor |
| `lam` | `0.95` | GAE lambda |
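As a sketch, these defaults would correspond to a YAML override file along the following lines. The file name is hypothetical and the key nesting is inferred from the `--cfg.inner.num_steps`-style CLI overrides, so check `configs/base.py` for the authoritative schema.

```yaml
# hypothetical configs/example.yaml
algorithm: trpo      # or ppo
inner:
  num_steps: 1
  lr: 0.1
outer:
  lr: 0.0003         # used by PPO only
env:
  num_tasks: 8
  max_steps: 200
rollout_steps: 200
gamma: 0.99
lam: 0.95
```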
TRPO-specific options:

| Parameter | Default | Description |
|---|---|---|
| `trpo.max_kl` | `0.01` | Maximum KL divergence constraint |
| `trpo.damping` | `0.1` | Conjugate gradient damping factor |
| `trpo.cg_iters` | `10` | Conjugate gradient iterations |
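`trpo.damping` and `trpo.cg_iters` feed the conjugate-gradient solve (see `utils/optimization.py`). A generic sketch of such a solver, not the repo's exact code: it approximately solves `(H + damping * I) x = g` using only Hessian-vector products, so the Fisher matrix `H` is never materialized.

```python
import torch

def conjugate_gradient(hvp, g, cg_iters=10, damping=0.1, tol=1e-10):
    """Solve (H + damping * I) x = g given a Hessian-vector product hvp(v)."""
    x = torch.zeros_like(g)
    r = g.clone()  # residual g - A @ x, with x = 0
    p = r.clone()
    rdotr = torch.dot(r, r)
    for _ in range(cg_iters):
        Ap = hvp(p) + damping * p
        alpha = rdotr / torch.dot(p, Ap)
        x += alpha * p
        r -= alpha * Ap
        new_rdotr = torch.dot(r, r)
        if new_rdotr < tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x

# Example: a small SPD matrix stands in for the Fisher matrix.
A = torch.tensor([[4.0, 1.0], [1.0, 3.0]])
g = torch.tensor([1.0, 2.0])
x = conjugate_gradient(lambda v: A @ v, g, damping=0.0)
```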
PPO-specific options:

| Parameter | Default | Description |
|---|---|---|
| `outer.clip_eps` | `0.2` | PPO clipping epsilon |
| `outer.ppo_epochs` | `5` | PPO epochs per outer step |
| `outer.entropy_coef` | `0.0` | Entropy bonus coefficient |
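For reference, the clipped surrogate that `outer.clip_eps` controls can be sketched as follows (a generic PPO clip loss, not necessarily the repo's exact implementation): the probability ratio is clipped to `[1 - eps, 1 + eps]` so each of the `ppo_epochs` outer updates stays close to the policy that collected the post-adaptation data.

```python
import torch

def ppo_clip_loss(log_probs, old_log_probs, advantages, clip_eps=0.2):
    """Negative clipped PPO surrogate (to be minimized)."""
    ratio = torch.exp(log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    # Pessimistic bound: take the smaller of the two surrogates.
    return -torch.min(unclipped, clipped).mean()
```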