Feature Request: Add learning_rate_schedule_fn argument to ppo.train #621

@JummerCloth

Description

Hi Brax team,

JummerCloth again. Thanks for adding the parametric action distribution! I’d like to suggest another small but useful feature enhancement to the ppo.train function: adding a new argument learning_rate_schedule_fn to allow more flexible control over the learning rate schedule. This would make it easier to plug in schedules like optax.cosine_decay_schedule or any other custom schedule without modifying the core training logic.

Motivation

Right now, ppo.train uses a plain Adam optimizer with a fixed learning rate. Although Adam adapts its per-parameter step sizes, the fixed base rate limits experimentation and fine-tuning in practice, especially for users who want to explore different scheduling strategies (e.g., cosine decay, piecewise constant, exponential decay). We benchmarked brax against rsl_rl and found that the adaptive learning rate in rsl_rl helped a lot on our tasks and made rsl_rl outperform brax. While that adaptive learning rate may be challenging to implement, we found that a simple learning_rate_schedule_fn, such as cosine decay, is sufficient for a significant performance improvement on our tasks using brax.

Proposed Change

Add a new optional argument to ppo.train:

def train(..., learning_rate_schedule_fn: Optional[Callable[[int], float]] = None, ...):

Then modify the learning rate schedule setup as follows:

if learning_rate_schedule_fn is None:
    # Default to a constant schedule if no custom schedule is provided
    learning_rate_schedule_fn = optax.constant_schedule(value=learning_rate)

optimizer = optax.adam(learning_rate=learning_rate_schedule_fn)

This change would be backwards compatible and would not affect existing users.

Example Usage

learning_rate_schedule_fn = optax.cosine_decay_schedule(
    init_value=0.001,
    decay_steps=100_000,
    alpha=0.1,
)

ppo.train(..., learning_rate_schedule_fn=learning_rate_schedule_fn)
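
For intuition about what the schedule above produces, here is a plain-Python restatement of the cosine decay formula with the same parameters. This is a sketch for illustration only (the function cosine_decay below is my own, with the exponent fixed at 1), not the optax implementation itself.

```python
import math


def cosine_decay(step: int, init_value: float, decay_steps: int, alpha: float) -> float:
    """Plain-Python restatement of cosine decay with a floor of alpha * init_value."""
    # Past decay_steps the rate stays at its floor.
    step = min(step, decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
    return init_value * ((1.0 - alpha) * cosine + alpha)


for step in (0, 50_000, 100_000, 200_000):
    print(step, cosine_decay(step, init_value=0.001, decay_steps=100_000, alpha=0.1))
```

The rate starts at 0.001, is roughly 0.00055 halfway through, and settles at 0.0001 from step 100_000 onward. Note that when a schedule is passed to optax.adam, the step it receives is the number of optimizer updates, not environment steps, so decay_steps would need to be chosen with ppo.train's batching settings in mind.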

Let me know if this would be a welcome addition. I’d be happy to contribute a PR if that’s helpful!

Thanks for the awesome work on Brax!

JummerCloth
