Skip to content

async collect buffer for VLA RL#122

Open
yhnsu wants to merge 11 commits intomainfrom
yhn/rl_vla
Open

async collect buffer for VLA RL#122
yhnsu wants to merge 11 commits intomainfrom
yhn/rl_vla

Conversation

@yhnsu
Copy link
Collaborator

@yhnsu yhnsu commented Feb 6, 2026

RL Training Framework Guide

TensorDict-based RL framework supporting standard PPO and asynchronous VLA training.


Quick Start

Configuration

{
  "trainer": {
    "buffer_size": 2048,
    "model_type": "standard"  // or "vla"
  },
  "policy": {"name": "actor_critic"},
  "algorithm": {
    "name": "ppo",
    "cfg": {
      "learning_rate": 3e-4,
      "gamma": 0.99,
      "n_epochs": 10,
      "batch_size": 64
    }
  }
}

Run Training

python embodichain/agents/rl/train.py --config configs/agents/rl/my_config.json

Architecture

Trainer → Collector (sync/async) → Buffer (standard/vla) → Algorithm (PPO)

Components:

  • Collector: Gather data from environment (SyncCollector / AsyncCollector)
  • Buffer: Store transitions (RolloutBuffer / VLABuffer)
  • Algorithm: Update policy (PPO)
  • Trainer: Coordinate training loop

Training Modes

Standard Mode (Default)

For: Normal models (<100ms inference/step)

SyncCollector → Collect 2048 steps → Train → Clear buffer → Repeat

Config: {"trainer": {"model_type": "standard"}}

Pros: Simple, stable, low memory, no staleness

VLA Async Mode

For: Large models (>1 sec inference/step)

Background: AsyncCollector → Continuously collect → VLABuffer
Main:       Wait for buffer full → Train → Repeat

Config: {"trainer": {"model_type": "vla"}}

Pros: 2-3x speedup via parallel collection
Cons: Data staleness, higher memory


Collectors

SyncCollector

Collects complete rollout synchronously:

from embodichain.agents.rl.collector import SyncCollector

collector = SyncCollector(env, policy, device, callback)
rollout = collector.collect(num_steps=2048)  # [T, N, ...]

AsyncCollector

Runs in background thread:

from embodichain.agents.rl.collector import AsyncCollector

collector = AsyncCollector(env, policy, buffer, device, callback)
collector.start()   # Begin background collection
# ... buffer fills automatically ...
collector.stop()    # Stop collection

Buffers

RolloutBuffer (Standard)

Single-use buffer:

from embodichain.agents.rl.buffer import RolloutBuffer

buffer = RolloutBuffer(buffer_size=2048, device=device)
buffer.add(rollout)  # [T, N, ...]
data = buffer.get(flatten=True)  # [T*N, ...], auto-clears

VLABuffer (Async)

Circular FIFO buffer:

from embodichain.agents.rl.buffer import VLABuffer

buffer = VLABuffer(buffer_size=4096, device=device)
buffer.add(transition)  # Single step
data = buffer.get(flatten=True)  # [buffer_size, ...] when full

Circular behavior: [T0,T1,T2,T3] → add T4 → [T4,T1,T2,T3] (T0 overwritten)


VLA Integration

1. Implement Model

class MyVLAModel(nn.Module):
    def forward(self, obs: TensorDict) -> TensorDict:
        # Add 'action', 'sample_log_prob', 'value'
        ...
    def get_value(self, obs: TensorDict) -> TensorDict:
        # Add 'value'
        ...
    def evaluate_actions(self, obs: TensorDict) -> TensorDict:
        # Add 'sample_log_prob', 'entropy', 'value'
        ...

2. Implement Loading

Edit embodichain/agents/rl/models/vla_policy.py:

def load_vla_model(model_path, model_class, model_config, device):
    model = MyVLAModel(**model_config)
    model.load_state_dict(torch.load(model_path))
    return model.to(device)

3. Configure

{
  "trainer": {"model_type": "vla"},
  "policy": {
    "name": "vla",
    "vla_config": {
      "model_path": "checkpoints/vla.pt",
      "model_class": "MyVLAModel",
      "model_config": {}
    }
  }
}

Common APIs

Trainer

from embodichain.agents.rl.utils import Trainer

trainer = Trainer(
    policy, env, algorithm,
    buffer_size=2048,
    model_type="standard",  # or "vla"
    ...
)
trainer.train(total_timesteps=1000000)

Buffer Methods

buffer.add(data)            # Add data
data = buffer.get(flatten=True)  # Retrieve data
buffer.is_full()            # Check ready status
buffer.clear()              # Clear buffer
buffer.get_stats()          # Statistics

Algorithm

from embodichain.agents.rl.algo import PPO, PPOCfg

algorithm = PPO(PPOCfg(...), policy)
losses = algorithm.update(rollout)  # Returns loss dict

FAQ

Q: When use VLA mode?
A: Inference >100ms/step AND GPU training fast

Q: Buffer size?
A: Standard: 2048-4096 (rollout size). VLA: 2048-4096 (buffer capacity)

Q: Data staleness impact?
A: Minor. PPO robust to staleness. 2-3x speedup >> small penalty

Q: Debug data flow?
A: buffer.get_stats() or _print_tensordict_tree(rollout) in ppo.py


Workflows

Standard

collector = SyncCollector(env, policy, device, callback)
while step < total:
    rollout = collector.collect(num_steps=2048)
    buffer.add(rollout)
    data = buffer.get(flatten=True)
    losses = algorithm.update(data)

VLA

collector = AsyncCollector(env, policy, buffer, device, callback)
collector.start()
while step < total:
    while not buffer.is_full():
        time.sleep(0.1)
    data = buffer.get(flatten=True)
    losses = algorithm.update(data)
collector.stop()

File Structure

embodichain/agents/rl/
├── train.py              # Entry point
├── algo/ppo.py          # PPO algorithm
├── buffer/
│   ├── standard_buffer.py  # RolloutBuffer
│   └── vla_buffer.py       # VLABuffer
├── collector/
│   ├── base.py             # BaseCollector
│   ├── sync_collector.py   # SyncCollector
│   └── async_collector.py  # AsyncCollector
├── models/
│   ├── actor_critic.py     # Standard policy
│   └── vla_policy.py       # VLA wrapper
└── utils/trainer.py     # Training coordinator

References

Copilot AI review requested due to automatic review settings February 6, 2026 04:22
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This pull request introduces a comprehensive refactoring of the RL training framework to use TensorDict-based data flow, replacing the previous tensor-based approach. The PR adds support for two training modes: standard synchronous PPO and asynchronous VLA training designed for scenarios with slow model inference.

Changes:

  • Migrated entire RL pipeline to TensorDict-based architecture for structured, extensible data flow
  • Introduced dual buffer system: RolloutBuffer (standard) and VLABuffer (async with FIFO)
  • Added AsyncCollector for background data collection in VLA mode with thread-based parallelism
  • Refactored Policy interface to use TensorDict inputs/outputs with in-place modifications
  • Updated PPO algorithm to work with TensorDict rollouts and removed dependency on gym spaces
  • Modified configuration to use buffer_size instead of rollout_steps and added action_dim requirement

Reviewed changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated 15 comments.

Show a summary per file
File Description
embodichain/agents/rl/utils/trainer.py Refactored to support dual training modes (sync/async) with TensorDict
embodichain/agents/rl/utils/helper.py Added dict_to_tensordict, compute_gae, and logging utilities
embodichain/agents/rl/utils/async_collector.py New async data collector for VLA mode with background thread
embodichain/agents/rl/buffer/rollout_buffer.py Renamed/refactored to VLABuffer with circular indexing
embodichain/agents/rl/buffer/standard_buffer.py New RolloutBuffer for standard PPO mode
embodichain/agents/rl/buffer/init.py Updated exports for dual buffer system
embodichain/agents/rl/algo/ppo.py Refactored to use TensorDict data flow throughout
embodichain/agents/rl/algo/base.py Updated base algorithm interface for TensorDict
embodichain/agents/rl/models/policy.py Changed interface to TensorDict-based methods
embodichain/agents/rl/models/actor_critic.py Implemented TensorDict-based policy with in-place modifications
embodichain/agents/rl/models/init.py Removed gymnasium dependency, added action_dim parameter
embodichain/agents/rl/train.py Added action_dim requirement, removed gym space dependency
tests/agents/test_rl.py Updated test to use buffer_size parameter
configs/agents/rl/push_cube/train_config.json Updated config with buffer_size, action_dim, and eval_freq
configs/agents/rl/basic/cart_pole/train_config.json Updated config with buffer_size
docs/source/tutorial/rl.rst Updated documentation to reference buffer_size
pyproject.toml Added tensordict>=0.5.0 dependency
Comments suppressed due to low confidence (1)

embodichain/agents/rl/train.py:289

  • The buffer_type parameter is not read from the trainer config and not passed to the Trainer constructor (line 273-289). This means the VLA async mode introduced in this PR cannot be used, as it will always default to "standard" mode. Add buffer_type = trainer_cfg.get("buffer_type", "standard") before the Trainer initialization and pass it as buffer_type=buffer_type to the Trainer constructor.
    trainer = Trainer(
        policy=policy,
        env=env,
        algorithm=algo,
        buffer_size=buffer_size,
        batch_size=algo_cfg["batch_size"],
        writer=writer,
        eval_freq=eval_freq if enable_eval else 0,  # Disable eval if not enabled
        save_freq=save_freq,
        checkpoint_dir=checkpoint_dir,
        exp_name=exp_name,
        use_wandb=use_wandb,
        eval_env=eval_env,  # None if enable_eval=False
        event_cfg=train_event_cfg,
        eval_event_cfg=eval_event_cfg if enable_eval else {},
        num_eval_episodes=num_eval_episodes,
    )

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.


# Update global step
num_envs = tensordict.batch_size[0]
self.global_step += num_envs
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The self.global_step variable is updated from the async collector thread (line 182 via callback) and potentially read from the main thread (lines 214, 244, 255). This creates a race condition. Consider using a thread-safe counter (e.g., threading.Lock protection or multiprocessing.Value) or tracking steps only in one thread.

Copilot uses AI. Check for mistakes.
Copilot AI review requested due to automatic review settings February 6, 2026 07:51
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 26 out of 26 changed files in this pull request and generated 14 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +113 to +158
def get(self, flatten: bool = True) -> TensorDict:
"""Get valid data from buffer.

Args:
flatten: If True, return flattened [size, ...]. Currently only supports True.

Returns:
TensorDict with batch_size=[size, ...] containing valid data
"""
if not self._initialized or self.size == 0:
raise ValueError("Buffer is empty")

if not flatten:
raise NotImplementedError("Only flatten=True is supported for VLABuffer")

# Return first 'size' elements (valid data)
# Note: Data is in insertion order up to write_pos, then wraps
if self.size < self.buffer_size:
# Buffer not yet full, data is [0:size]
return self.buffer[: self.size]
else:
# Buffer full, need to rearrange to maintain temporal order
# Oldest data is at write_pos, newest at write_pos-1
indices = (
torch.arange(
self.write_pos,
self.write_pos + self.buffer_size,
device=self.device,
)
% self.buffer_size
)
return self.buffer[indices]

def clear(self) -> None:
"""Clear buffer (reset pointers, keep pre-allocated memory)."""
self.write_pos = 0
self.size = 0
# Keep buffer allocated for reuse

def __len__(self) -> int:
"""Return current number of valid transitions."""
return self.size

def is_full(self) -> bool:
"""Check if buffer is at full buffer_size."""
return self.size >= self.buffer_size
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The VLABuffer.get() and is_full() methods are called from the main thread while AsyncCollector writes to the buffer from a background thread, but these methods lack thread safety. The read of self.size and self.write_pos could return inconsistent values if a write is in progress. Additionally, buffer.get() performs complex operations (checking size, slicing buffer) that should be atomic with respect to concurrent writes. Consider adding thread synchronization or document that external locking is required.

Copilot uses AI. Check for mistakes.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'v added lock for get/is_full for thread safety

if deterministic:
action = mean
else:
dist = Normal(mean, std)
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The distribution is created twice when deterministic=False. Line 130 creates dist = Normal(mean, std), then lines 136-137 create it again. This is wasteful. Consider refactoring to create the distribution once and use either dist.mean or dist.sample() based on the deterministic flag.

Suggested change
dist = Normal(mean, std)

Copilot uses AI. Check for mistakes.
Comment on lines +195 to +201
next_value_td = TensorDict(
{"observation": next_obs_for_td},
batch_size=next_td.batch_size,
device=self.device,
)
self.policy.get_value(next_value_td)
next_td["value"] = next_value_td["value"]
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The policy is accessed from both the background collector thread (lines 145-146, 200) and potentially from the main training thread during algorithm.update(). PyTorch tensors and models are not thread-safe by default. Concurrent access to the policy parameters during forward passes and gradient updates can lead to race conditions and corrupted gradients. Consider using locks to synchronize policy access, or ensure the policy is not being updated while the collector is running (e.g., by stopping collection during training).

Suggested change
next_value_td = TensorDict(
{"observation": next_obs_for_td},
batch_size=next_td.batch_size,
device=self.device,
)
self.policy.get_value(next_value_td)
next_td["value"] = next_value_td["value"]
# Protect policy access with lock to avoid races with training thread
with self._lock:
next_value_td = TensorDict(
{"observation": next_obs_for_td},
batch_size=next_td.batch_size,
device=self.device,
)
self.policy.get_value(next_value_td)
next_td["value"] = next_value_td["value"]

Copilot uses AI. Check for mistakes.

losses = self.algorithm.update(data)
self._log_train(losses)

Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After buffer.get() is called on line 238, the VLABuffer is not cleared (unlike RolloutBuffer which auto-clears). Since the buffer is full (size == buffer_size), the is_full() check on line 232 will immediately return True in the next iteration, causing the training loop to repeatedly train on the same data without waiting for new transitions. The buffer should be cleared after get(), or the is_full() logic should be modified to track whether data has been consumed.

Suggested change
# Clear async buffer after consumption to avoid retraining on stale data
if hasattr(self.buffer, "clear"):
self.buffer.clear()

Copilot uses AI. Check for mistakes.
Comment on lines +121 to +122
# Prepare next iteration - use the converted TensorDict
current_td = next_obs_td
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The collector does not handle episode resets when done=True. After an episode terminates (done flag is set), the environment should be reset to get a fresh initial observation for the next episode. Currently, the collector continues using next_obs even after termination, which could contain stale data. Most RL environments auto-reset on episode end, but this should be made explicit or documented as a requirement.

Suggested change
# Prepare next iteration - use the converted TensorDict
current_td = next_obs_td
# Prepare next iteration:
# - if episode is done, reset env to get a fresh initial observation
# - otherwise, continue from next_obs_td
if done.any():
reset_result = self.env.reset()
# Support both Gym/Gymnasium-style (obs, info) and plain-obs resets
if isinstance(reset_result, tuple):
reset_obs = reset_result[0]
else:
reset_obs = reset_result
current_td = dict_to_tensordict(reset_obs, self.device)
else:
current_td = next_obs_td

Copilot uses AI. Check for mistakes.
Comment on lines +114 to +115
# Store complete transition
rollout_list.append(current_td.clone())
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling .clone() on every transition creates a full copy of the TensorDict including all nested tensors, which can be memory-intensive for large rollouts. Since current_td is reassigned to next_obs_td on line 122 (which is a fresh TensorDict), the clone may be unnecessary. Consider whether a shallow copy or reference would suffice, or document why deep cloning is required here.

Suggested change
# Store complete transition
rollout_list.append(current_td.clone())
# Store complete transition (no clone needed: current_td is not mutated afterwards)
rollout_list.append(current_td)

Copilot uses AI. Check for mistakes.
Copy link
Collaborator

@yangchen73 yangchen73 Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so. Every loop does current_td["next"] = next_td and then current_td = next_obs_td. If we don't use clone(), every appended element is the same TensorDict reference. Then the next loop overwrites its contents. As a result, every entry in rollout_list points to the same modified data.

Comment on lines +240 to +242
# Update global step based on collected data (main thread only)
batch_size = data.batch_size[0] if len(data.batch_size) > 0 else 0
self.global_step += batch_size
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The global_step update in async mode only counts batch_size from the returned data (line 242), not the actual number of environment steps taken. Since VLABuffer is continuously being written to by AsyncCollector (which tracks steps in _step_count), the global_step will not accurately reflect the total number of environment interactions. Consider synchronizing global_step with the collector's _step_count, or documenting this discrepancy.

Suggested change
# Update global step based on collected data (main thread only)
batch_size = data.batch_size[0] if len(data.batch_size) > 0 else 0
self.global_step += batch_size
# Update global step.
# Prefer the collector's step count (actual env interactions) if available,
# otherwise fall back to counting processed batch size.
batch_size = data.batch_size[0] if len(data.batch_size) > 0 else 0
steps_from_collector = getattr(collector, "_step_count", None)
if isinstance(steps_from_collector, int) and steps_from_collector > self.global_step:
self.global_step = steps_from_collector
else:
self.global_step += batch_size

Copilot uses AI. Check for mistakes.
Comment on lines 146 to 167
@@ -166,30 +154,113 @@ def on_step(obs, actions, reward, done, info, next_obs):
self.curr_ret[done_idx] = 0
self.curr_len[done_idx] = 0

# Update global step and observation
# next_obs is already flattened in algorithm's collect_rollout
self.obs = next_obs
self.global_step += next_obs.shape[0]

if isinstance(info, dict):
rewards_dict = info.get("rewards")
metrics_dict = info.get("metrics")
# Log environment metrics
if isinstance(env_info, dict):
rewards_dict = env_info.get("rewards")
metrics_dict = env_info.get("metrics")
self._log_scalar_dict("rewards", rewards_dict)
self._log_scalar_dict("metrics", metrics_dict)
log_dict = {}
log_dict.update(self._pack_log_dict("rewards", rewards_dict))
log_dict.update(self._pack_log_dict("metrics", metrics_dict))
log_dict.update(pack_log_dict("rewards", rewards_dict))
log_dict.update(pack_log_dict("metrics", metrics_dict))
if log_dict and self.use_wandb:
wandb.log(log_dict, step=self.global_step)
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The on_step callback modifies shared state (self.curr_ret, self.curr_len, self.ret_window, self.len_window, self.global_step) without thread synchronization. In async mode, this callback runs in the AsyncCollector background thread while the main thread could be accessing these same variables (e.g., in _log_train). This can cause race conditions and data corruption. Use threading.Lock to protect access to these shared variables, or ensure they're only accessed from one thread.

Copilot uses AI. Check for mistakes.
Comment on lines +58 to +60
def collect(self, **kwargs) -> TensorDict:
"""Collect data from environment.

Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overridden method signature does not match call, where it is passed too many arguments. Overriding method method SyncCollector.collect matches the call.
Overridden method signature does not match call, where it is passed an argument named 'num_steps'. Overriding method method SyncCollector.collect matches the call.

Suggested change
def collect(self, **kwargs) -> TensorDict:
"""Collect data from environment.
def collect(self, num_steps: int, **kwargs) -> TensorDict:
"""Collect data from environment.
Args:
num_steps: Number of steps to collect.

Copilot uses AI. Check for mistakes.
Comment on lines +38 to +46
def collect(self, num_steps: int) -> TensorDict:
"""Collect a synchronous rollout.

Args:
num_steps: Number of steps to collect

Returns:
TensorDict with batch_size=[T, N] containing full rollout
"""
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method requires 2 positional arguments, whereas overridden BaseCollector.collect requires 1.

Suggested change
def collect(self, num_steps: int) -> TensorDict:
"""Collect a synchronous rollout.
Args:
num_steps: Number of steps to collect
Returns:
TensorDict with batch_size=[T, N] containing full rollout
"""
def collect(self, num_steps: int | None = None) -> TensorDict:
"""Collect a synchronous rollout.
Args:
num_steps: Number of steps to collect.
Returns:
TensorDict with batch_size=[T, N] containing full rollout
"""
if num_steps is None:
raise TypeError("num_steps must be provided for SyncCollector.collect()")

Copilot uses AI. Check for mistakes.
Copy link
Collaborator

@yangchen73 yangchen73 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • need to use clone
  • add locks

Copilot AI review requested due to automatic review settings March 2, 2026 03:01
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 26 out of 26 changed files in this pull request and generated 14 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +250 to +253
self.global_step += data.batch_size[0] if data.batch_size else 0

losses = self.algorithm.update(data)
self._log_train(losses)
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In async mode, the collector thread calls policy.forward/get_value while the main thread simultaneously runs algorithm.update() (backprop + optimizer.step) on the same policy. PyTorch modules/optimizers are not thread-safe, and concurrent CUDA kernels/parameter updates can lead to nondeterminism or crashes. Consider pausing collection during update, protecting policy access with a lock, or using a separate inference copy of the policy for the collector (periodically synced).

Copilot uses AI. Check for mistakes.
Comment on lines +124 to +128
### VLABuffer (Async)

Circular FIFO buffer:

```python
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The guide describes VLABuffer as a circular FIFO buffer with step-level add(transition) and shows buffer.get(flatten=True), but the implementation only accepts full rollouts via add_rollout() and PPO’s GAE requires the unflattened [N, T] layout. Please update this section to match the actual VLABuffer API and required shapes, otherwise users will hit AttributeErrors or compute incorrect advantages.

Copilot uses AI. Check for mistakes.
log_dict.update(pack_log_dict("rewards", rewards_dict))
log_dict.update(pack_log_dict("metrics", metrics_dict))
if log_dict and self.use_wandb:
wandb.log(log_dict, step=self.global_step)
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

on_step_callback logs to W&B using self.global_step, but global_step is only incremented once per rollout in _train_sync/_train_async. Because on_step_callback is invoked for every env step, these logs will repeatedly use the same step value (overwriting or producing a flat x-axis). Consider incrementing global_step inside the callback (e.g., by num_envs per env.step) or passing an explicit per-step counter into the callback for logging.

Suggested change
wandb.log(log_dict, step=self.global_step)
# Use a dedicated per-environment-step counter for W&B logging.
# Lazily initialize it so we don't depend on __init__ details.
env_log_step = getattr(self, "_env_log_step", 0)
# Increment by the number of parallel environments (reward batch size).
if isinstance(reward, torch.Tensor):
env_log_step += reward.shape[0]
else:
env_log_step += 1
self._env_log_step = env_log_step
wandb.log(log_dict, step=env_log_step)

Copilot uses AI. Check for mistakes.
Comment on lines 51 to +53
@abstractmethod
def get_action(
self, obs: torch.Tensor, deterministic: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Sample an action from the policy.
def forward(self, tensordict: TensorDict) -> TensorDict:
"""Forward pass that adds action to the input tensordict (in-place).
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Policy.forward() in the abstract base class does not accept a deterministic argument, but Trainer._eval_once calls self.policy.forward(..., deterministic=True) and concrete implementations (ActorCritic/VLAPolicy) already support it. This makes the interface inconsistent and can cause TypeErrors for any other Policy implementation that follows the base signature. Consider updating the abstract method signature to include deterministic: bool = False (and documenting the expected behavior).

Copilot uses AI. Check for mistakes.
Comment on lines +245 to +249
while step < total:
rollout = collector.collect(num_steps=2048)
buffer.add(rollout)
data = buffer.get(flatten=True)
losses = algorithm.update(data)
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This workflow calls buffer.get(flatten=True) and passes the flattened result to algorithm.update(). With the current PPO implementation, flattened input is treated as a [size, 1] rollout, making GAE effectively run with T=1 and producing incorrect advantages/targets. Update the example to pass the unflattened [N, T] rollout into update (flatten only inside PPO for minibatching).

Copilot uses AI. Check for mistakes.
Comment on lines 10 to 14
"num_envs": 64,
"iterations": 1000,
"rollout_steps": 1024,
"buffer_size": 1024,
"eval_freq": 2,
"save_freq": 200,
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This config was updated to use trainer.buffer_size, but train.py now requires policy.action_dim to be present. As-is, running this config will raise "Missing 'action_dim' in policy config". Consider adding an explicit action_dim to the policy block (and optionally trainer.model_type) so the example remains runnable.

Copilot uses AI. Check for mistakes.
Comment on lines +79 to +82
# Initialize observation and get num_envs (needed for VLA buffer)
obs, _ = env.reset()
self.obs_tensordict = dict_to_tensordict(obs, device)
num_envs = self.obs_tensordict.batch_size[0]
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trainer.init calls env.reset() to infer num_envs / seed obs, but both SyncCollector/AsyncCollector (via BaseCollector) also call env.reset(). This extra reset can be expensive and can change the initial state/episode accounting. Since the env exposes num_envs (used elsewhere in train.py), consider using env.num_envs (and/or env.device) instead of resetting here, and avoid storing an unused obs_tensordict.

Suggested change
# Initialize observation and get num_envs (needed for VLA buffer)
obs, _ = env.reset()
self.obs_tensordict = dict_to_tensordict(obs, device)
num_envs = self.obs_tensordict.batch_size[0]
# Initialize num_envs without forcing a reset when possible
if hasattr(env, "num_envs"):
num_envs = env.num_envs
# No need to create an initial obs_tensordict here; collectors will reset the env.
self.obs_tensordict = None
else:
# Fallback for environments that do not expose num_envs
obs, _ = env.reset()
self.obs_tensordict = dict_to_tensordict(obs, device)
num_envs = self.obs_tensordict.batch_size[0]

Copilot uses AI. Check for mistakes.
Comment on lines +17 to 21
"""Helper utilities for RL training.

This module provides utility functions for RL algorithms.
"""

Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

flatten_dict_observation() was removed from this module, but embodichain/lab/gym/envs/base_env.py still imports it to build flattened_observation_space. This will raise ImportError at runtime. Either restore a compatible flatten_dict_observation helper here (for backward compatibility) or update base_env.py in this PR to use the new TensorDict-based utilities.

Copilot uses AI. Check for mistakes.
__all__ = ["RolloutBuffer"]
Provides two buffer implementations:
- RolloutBuffer: Standard PPO buffer (single rollout, use and discard)
- VLABuffer: VLA buffer (FIFO multi-rollout accumulation for slow inference)
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The buffer module docstring says VLABuffer is “FIFO multi-rollout accumulation”, but VLABuffer currently stores exactly one rollout (present/None). This mismatch is likely to confuse users; update the docstring to reflect the current behavior, or adjust VLABuffer to match the documented semantics.

Suggested change
- VLABuffer: VLA buffer (FIFO multi-rollout accumulation for slow inference)
- VLABuffer: VLA buffer (single-rollout accumulation optimized for slow inference)

Copilot uses AI. Check for mistakes.
Comment on lines +96 to 99
# Ensure 2D format [T, N] for GAE computation
if len(rollout.batch_size) == 1:
rollout = rollout.unsqueeze(1) # [size] -> [size, 1]

Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PPO.update() claims to support receiving a flattened rollout (batch_size=[size]) by doing rollout.unsqueeze(1), but that turns it into [size, 1] and makes GAE run with T=1 (incorrect unless the original rollout length was 1). Either require callers to pass an unflattened [N, T] / [T, N] rollout, or reshape using known (N, T) metadata (e.g., carry rollout_length/num_envs) before computing GAE.

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants