Hi Brax team,
JummerCloth again. Thanks for adding the parametric action distribution! I’d like to suggest another small but useful feature enhancement to the ppo.train function: adding a new argument learning_rate_schedule_fn to allow more flexible control over the learning rate schedule. This would make it easier to plug in schedules like optax.cosine_decay_schedule or any other custom schedule without modifying the core training logic.
Motivation
Right now, ppo.train uses the Adam optimizer with a fixed learning rate. While Adam adapts per-parameter step sizes, in practice the fixed base rate limits experimentation and fine-tuning, especially for users who want to explore different scheduling strategies (e.g., cosine decay, piecewise constant, exponential decay). We benchmarked Brax against rsl_rl and found that rsl_rl's adaptive learning rate helped substantially on our tasks and let rsl_rl outperform Brax. While a fully adaptive learning rate may be challenging to implement, we found that a simple learning_rate_schedule_fn, such as cosine decay, is enough to significantly improve our tasks' performance with Brax.
Proposed Change
Add a new optional argument to ppo.train:
```python
def train(..., learning_rate_schedule_fn: Optional[Callable[[int], float]] = None, ...):
```
Then modify the learning rate schedule setup as follows:
```python
if learning_rate_schedule_fn is None:
    # Default to a constant schedule if no custom schedule is provided.
    learning_rate_schedule_fn = optax.constant_schedule(value=learning_rate)
optimizer = optax.adam(learning_rate=learning_rate_schedule_fn)
```
This change would be backwards compatible and would not affect existing users.
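Since the proposed argument is just a `Callable[[int], float]`, both the default and the cosine-decay case reduce to plain functions of the gradient-update count. Here is a minimal pure-Python sketch of the two schedules involved; the helper names are illustrative, not from the Brax or optax codebases, though the cosine formula matches what optax.cosine_decay_schedule computes:

```python
import math
from typing import Callable

def constant_schedule(value: float) -> Callable[[int], float]:
    """Same value at every step (what the fallback amounts to)."""
    return lambda step: value

def cosine_decay_schedule(
    init_value: float, decay_steps: int, alpha: float = 0.0
) -> Callable[[int], float]:
    """Cosine decay from init_value down to init_value * alpha."""
    def schedule(step: int) -> float:
        progress = min(step, decay_steps) / decay_steps
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return init_value * ((1.0 - alpha) * cosine + alpha)
    return schedule

lr = cosine_decay_schedule(init_value=0.001, decay_steps=100_000, alpha=0.1)
print(lr(0))        # 0.001 at the start
print(lr(100_000))  # decays to roughly init_value * alpha = 1e-4
```

Past decay_steps the schedule clamps to its final value, so long training runs are safe.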
Example Usage
```python
learning_rate_schedule_fn = optax.cosine_decay_schedule(
    init_value=0.001,
    decay_steps=100_000,
    alpha=0.1,
)
ppo.train(..., learning_rate_schedule_fn=learning_rate_schedule_fn)
```
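Because the argument accepts any `Callable[[int], float]`, users would not be limited to optax schedules; any Python function of the step count works. A hypothetical hand-rolled warmup schedule, purely for illustration (this helper is not part of the Brax or optax APIs):

```python
from typing import Callable

def linear_warmup(
    init_value: float, peak_value: float, warmup_steps: int
) -> Callable[[int], float]:
    """Ramp linearly from init_value to peak_value, then hold constant."""
    def schedule(step: int) -> float:
        if step >= warmup_steps:
            return peak_value
        frac = step / warmup_steps
        return init_value + frac * (peak_value - init_value)
    return schedule

# Hypothetical usage with the proposed argument:
# ppo.train(..., learning_rate_schedule_fn=linear_warmup(1e-5, 3e-4, 1_000))
```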
Let me know if this would be a welcome addition — I’d be happy to contribute a PR if that’s helpful!
Thanks for the awesome work on Brax!
JummerCloth