
Add patience (early stopping) based on validation loss#1408

Open
BitcrushedHeart wants to merge 3 commits into Nerogar:master from BitcrushedHeart:patience

Conversation

@BitcrushedHeart (Contributor)

Description

Adds early stopping to standard training based on validation loss. When enabled, the trainer monitors validation loss after each validation run and saves a lightweight checkpoint whenever a new minimum is reached. If loss doesn't improve for a configurable number of consecutive validation checks, training stops and the best checkpoint is restored as the final output model.

This mirrors the same pattern already used for DPO early stopping (accuracy-based), but operates on the general validation loss path in the General tab.

How It Works

Each validation run is a "tick." The trainer keeps a running minimum of the total average validation loss across all concepts. When loss improves, the counter resets and the current parameter state is saved to {workspace}/backup/patience-best.pt. When loss doesn't improve, the counter increments. Once the counter reaches the configured threshold (patience_epochs, default 5), training stops via commands.stop().
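The tick/counter logic described above can be sketched roughly as follows. This is an illustrative stand-in, not the actual GenericTrainer code: `patience_epochs` and the stop-on-threshold behaviour follow the PR description, while the class and method names are hypothetical.

```python
# Hypothetical sketch of the patience "tick" logic. The real trainer saves a
# checkpoint to {workspace}/backup/patience-best.pt on improvement and calls
# commands.stop() when the threshold is hit; here those are just comments.
class PatienceTracker:
    def __init__(self, patience_epochs: int = 5):
        self.patience_epochs = patience_epochs
        self.best_loss = float("inf")
        self.best_step = None
        self.bad_ticks = 0

    def tick(self, val_loss: float, step: int) -> bool:
        """One validation run; returns True when training should stop."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_step = step
            self.bad_ticks = 0   # improvement: reset the counter
            # ... save lightweight checkpoint here ...
        else:
            self.bad_ticks += 1  # no improvement: increment the counter
        return self.bad_ticks >= self.patience_epochs
```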

At the end of training, if a best checkpoint exists, the trainer restores those weights before saving the final model. The user gets the best-performing weights, not the last (potentially overfit) ones. The log output makes this explicit:

Patience triggered at step 280. Best checkpoint from step 180 (val_loss: 0.043200)
Restoring patience best checkpoint from step 180 (val_loss: 0.043200)

The best checkpoint is a separate file from regular interval saves and backups. Users who already have "save every N steps" configured will see both their regular checkpoints and the best checkpoint appearing independently. No existing backup/checkpoint behaviour is changed.

UI

Two new fields in the General tab, filling previously empty cells next to Dataloader Threads and Train Device:

| Left | Right |
| --- | --- |
| Validation | Validate after |
| Dataloader Threads | Patience (toggle) |
| Train Device | Early Stop After (input, default 5) |
| Multi-GPU Device Indexes | |

Enabling the Patience toggle auto-enables Validation if it's off. There's also a training-time guard that does the same thing with a console warning, in case someone edits the config JSON directly.
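The training-time guard amounts to a small consistency check on the config. A minimal sketch, assuming boolean `patience` and `validation` fields as named in the PR; the config object and warning text are illustrative:

```python
# Hypothetical guard: if patience is enabled but validation is not (e.g. the
# config JSON was edited by hand), force validation on and warn on the console.
def enforce_validation(config) -> None:
    if getattr(config, "patience", False) and not getattr(config, "validation", False):
        print("WARNING: patience requires validation; enabling validation.")
        config.validation = True
```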

Changes

  • modules/util/config/TrainConfig.py: Two new fields (patience, patience_epochs), config version bump (10→11), migration
  • modules/ui/TrainUI.py: Patience toggle + Early Stop After input in General tab, auto-enable callback
  • modules/trainer/GenericTrainer.py: Patience check after validation, lightweight checkpoint save/restore, training-time guard
  • tests/test_patience.py: 8 tests covering counter increment, reset, stop trigger, checkpoint save/restore, auto-enable, and TensorBoard logging
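The config version bump (10→11) implies a migration that gives older configs the new fields with safe defaults. A sketch of that idea, with an assumed dict-based config and a hypothetical version key; only the field names and defaults come from the PR description:

```python
# Illustrative migration: old configs gain the two new fields with defaults.
def migrate_v10_to_v11(config: dict) -> dict:
    config.setdefault("patience", False)     # early stopping off by default
    config.setdefault("patience_epochs", 5)  # default threshold from the PR
    config["__version"] = 11                 # hypothetical version key
    return config
```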

Testing Notes

All 8 unit tests pass. Tests use dummy tensors and mock objects to verify the patience logic independently of the full training pipeline — counter behaviour, stop triggering, and parameter save/restore are all exercised.

Tested on Windows 11, Python 3.11.9.

When enabled, tracks validation loss across runs and saves the best
checkpoint. If loss doesn't improve for N consecutive validation
checks, stops training and restores the best weights as the final
output model.