Skip to content

refactor: separate optimizer algorithms from training strategies#50

Draft
roomote-v0[bot] wants to merge 1 commit intomainfrom
feature/separate-optimizer-trainer
Draft

refactor: separate optimizer algorithms from training strategies#50
roomote-v0[bot] wants to merge 1 commit intomainfrom
feature/separate-optimizer-trainer

Conversation

@roomote-v0
Copy link

@roomote-v0 roomote-v0 bot commented Feb 26, 2026

Opened by @roomote on behalf of Facundo Quiroga

Summary

This PR refactors the coupled optimizer+training-loop design into two distinct concerns, making it possible to train RNN models (and in the future, GANs) with gradient accumulation, while keeping full backward compatibility with the existing API.

Problem

The existing design combines optimization algorithms (SGD, Adam, etc.) with the training loop (batching, epochs) into a single BatchedGradientOptimizer class hierarchy. Gradients are computed once per batch and immediately applied -- there is no way to accumulate gradients across batches, which is needed for stable RNN training.

Solution

1. Pure Optimizer algorithms (optimizer.py)

New classes that only know how to update parameters given gradients:

  • SGD - simple gradient descent
  • MomentumSGD - momentum-based SGD
  • NesterovMomentumSGD - Nesterov accelerated gradient
  • RMSpropOptimizer - RMSprop adaptive learning rate
  • AdamOptimizer - Adam (adaptive moments)
  • SignSGDOptimizer - sign-based SGD

Each implements a simple step(parameters, gradients, epoch, iteration) interface.

2. Trainer strategies (trainers.py)

New classes that handle the training loop and delegate parameter updates to an Optimizer:

  • SupervisedTrainer - standard batch training for feedforward models (MLPs, CNNs)
  • RecurrentTrainer - supports gradient accumulation across multiple batches before updating, enabling stable RNN training

3. Backward compatibility

Wrapper classes (GradientDescent, Adam, RMSprop, etc.) preserve the existing API:

# Old API (still works)
optimizer = nn.GradientDescent(batch_size=32, epochs=100, lr=0.1)
history = optimizer.optimize(model, x, y, error)

# New API (separated concerns)
optimizer = nn.SGD(lr=0.1)
trainer = nn.SupervisedTrainer(optimizer, batch_size=32, epochs=100)
history = trainer.train(model, x, y, error)

# RNN training with gradient accumulation
optimizer = nn.AdamOptimizer(lr=0.001)
trainer = nn.RecurrentTrainer(optimizer, batch_size=16, epochs=50, gradient_accumulation_steps=4)
history = trainer.train(rnn_model, x_seq, y_seq, error)

Testing

  • All existing passing tests continue to pass (test_gradients, test_linear_regression)
  • Pre-existing test failures (classification tests, regression network threshold) are unrelated to this change
  • Manually verified new API (SupervisedTrainer, RecurrentTrainer) and backward-compat wrappers

View task on Roo Code Cloud

Split the coupled optimizer+training-loop design into two distinct
concerns:

1. Optimizer (optimizer.py): Pure parameter update algorithms (SGD,
   Adam, RMSprop, etc.) that only know how to update parameters given
   gradients. These can now be reused across different training
   strategies.

2. Trainer (trainers.py): Training strategies that handle the training
   loop (batching, epochs, forward/backward passes) and delegate
   parameter updates to an Optimizer.

New classes:
- SGD, MomentumSGD, NesterovMomentumSGD, RMSpropOptimizer,
  AdamOptimizer, SignSGDOptimizer (pure optimizers)
- SupervisedTrainer: standard batch training for feedforward models
- RecurrentTrainer: supports gradient accumulation across batches,
  enabling stable RNN training

Backward-compatible wrappers (GradientDescent, Adam, RMSprop, etc.)
preserve the existing API so all guides and tests continue working.
@roomote-v0
Copy link
Author

roomote-v0 bot commented Feb 26, 2026

Rooviewer Clock   See task

Found 2 issues to address:

  • Dead optimize_batch() methods on backward-compat wrappers -- SupervisedTrainer.train() calls self.optimizer.step() directly and never dispatches through optimize_batch(), making these overrides unreachable dead code (comment)
  • optimizer.initialize() unconditionally resets state on every train() call -- stateful optimizers (Momentum, Adam, RMSprop) lose accumulated buffers when train()/optimize() is called more than once, unlike the old code which preserved state (comment)

Mention @roomote in a comment to request specific changes to this pull request or fix all unresolved issues.

Comment on lines +309 to +310
def optimize_batch(self, model: Model, δEδps: ParameterSet, epoch: int, iteration: int):
self.optimizer.step(model.get_parameters(), δEδps, epoch, iteration)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optimize_batch() is dead code on all backward-compat wrappers. SupervisedTrainer.train() calls self.optimizer.step() directly and never dispatches through optimize_batch(). These methods are unreachable -- anyone subclassing BatchedGradientOptimizer and overriding optimize_batch() (which was the old extension point) would silently have their override ignored.

Consider either routing the training loop through optimize_batch() so the old override contract is honoured, or removing these methods entirely to avoid confusion.

Fix it with Roo Code or mention @roomote and request a fix.

batches = n // self.batch_size
history = []
model.set_phase(Phase.Training)
self.optimizer.initialize(model.get_parameters())
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.optimizer.initialize() is called unconditionally at the start of every train() invocation. For stateful optimizers (MomentumSGD, AdamOptimizer, RMSpropOptimizer), this zeroes out momentum buffers and moment estimates each time. The old code used a self.first flag set once in __init__, so calling optimize() multiple times on the same optimizer instance preserved accumulated state. With this change, any code that calls train()/optimize() more than once (e.g., warm-starting, curriculum learning, or resuming training) silently loses all optimizer state.

A straightforward fix: track an initialized flag on the optimizer and skip re-initialization if it has already been called, or let the caller decide via a reset_state parameter.

Fix it with Roo Code or mention @roomote and request a fix.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant