From 32186aa46a96f26ba0625755f93c55659adcd523 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 23 Mar 2026 13:55:45 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 888266025 --- learned_optimization/optimizers/optax_opts.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/learned_optimization/optimizers/optax_opts.py b/learned_optimization/optimizers/optax_opts.py index 39424f9..6578b90 100644 --- a/learned_optimization/optimizers/optax_opts.py +++ b/learned_optimization/optimizers/optax_opts.py @@ -186,8 +186,9 @@ def __init__(self, epsilon_root=1e-8): opt = optax.chain( optax.scale_by_adam( - b1=beta1, b2=beta2, eps=epsilon, eps_root=epsilon_root), - optax.scale_by_schedule(piecewise_linear(times, vals=lrs)), + b1=beta1, b2=beta2, eps=epsilon, eps_root=epsilon_root + ), + optax.scale_by_schedule(piecewise_linear(times, vals=lrs)), # pytype: disable=wrong-arg-types # jax-arraylike optax.scale(-1), ) super().__init__(opt)