diff --git a/learned_optimization/optimizers/optax_opts.py b/learned_optimization/optimizers/optax_opts.py index 39424f9..6578b90 100644 --- a/learned_optimization/optimizers/optax_opts.py +++ b/learned_optimization/optimizers/optax_opts.py @@ -186,8 +186,9 @@ def __init__(self, epsilon_root=1e-8): opt = optax.chain( optax.scale_by_adam( - b1=beta1, b2=beta2, eps=epsilon, eps_root=epsilon_root), - optax.scale_by_schedule(piecewise_linear(times, vals=lrs)), + b1=beta1, b2=beta2, eps=epsilon, eps_root=epsilon_root + ), + optax.scale_by_schedule(piecewise_linear(times, vals=lrs)), # pytype: disable=wrong-arg-types # jax-arraylike optax.scale(-1), ) super().__init__(opt)