Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 64 additions & 4 deletions modules/trainer/GenericTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,18 @@ def start(self):
self.model, self.model_setup, self.model.train_progress, is_validation=True
)

if self.config.patience and not self.config.validation:
print("Warning: Patience enabled without Validation. Auto-enabling Validation.")
self.config.validation = True
self.validation_data_loader = self.create_data_loader(
self.model, self.model_setup, self.model.train_progress, is_validation=True
)

self._patience_counter = 0
self._patience_best_loss = float('inf')
self._patience_best_step = -1
self._patience_best_backup_path: str | None = None

def __save_config_to_workspace(self):
path = path_util.canonical_join(self.config.workspace_dir, "config")
os.makedirs(Path(path).absolute(), exist_ok=True)
Expand Down Expand Up @@ -404,15 +416,51 @@ def __validate(self, train_progress: TrainProgress):
average_loss,
train_progress.global_step)

if len(concept_counts) > 1:
total_loss = sum(accumulated_loss_per_concept[key] for key in concept_counts)
total_count = sum(concept_counts[key] for key in concept_counts)
total_average_loss = total_loss / total_count
total_loss = sum(accumulated_loss_per_concept[key] for key in concept_counts)
total_count = sum(concept_counts[key] for key in concept_counts)
total_average_loss = total_loss / total_count

if len(concept_counts) > 1:
self.tensorboard.add_scalar("loss/validation_step/total_average",
total_average_loss,
train_progress.global_step)

if self.config.patience:
self.__check_patience(total_average_loss, train_progress)

def __check_patience(self, val_loss: float, train_progress: TrainProgress):
    """Early-stopping bookkeeping run after each validation pass.

    Tracks the lowest validation loss seen so far; on a new best, snapshots
    the model weights and resets the stale counter, otherwise increments it.
    When the counter reaches `config.patience_epochs`, training is stopped.
    """
    step = train_progress.global_step

    if val_loss < self._patience_best_loss:
        # New best: back up weights first, then record the new best state.
        self._patience_best_backup_path = self.__save_patience_best(train_progress)
        self._patience_best_step = step
        self._patience_best_loss = val_loss
        self._patience_counter = 0
    else:
        self._patience_counter += 1

    self.tensorboard.add_scalar("patience/counter", self._patience_counter, step)
    self.tensorboard.add_scalar("patience/best_val_loss", self._patience_best_loss, step)

    if self._patience_counter >= self.config.patience_epochs:
        print(f"Patience triggered at step {step}. "
              f"Best checkpoint from step {self._patience_best_step} "
              f"(val_loss: {self._patience_best_loss:.6f})")
        self.commands.stop()

def __save_patience_best(self, train_progress: TrainProgress) -> str:
    """Snapshot the current trainable parameters as the patience-best checkpoint.

    Saves CPU clones of all parameters to a fixed path under the workspace
    backup directory. The write is atomic (temp file + os.replace) so a crash
    or OOM mid-save cannot leave a truncated `patience-best.pt` behind that
    `end()` would later try to restore.

    Returns:
        The path of the freshly saved checkpoint, or on failure the previous
        best path (so a still-valid older snapshot remains restorable), or ""
        if no snapshot was ever saved.
    """
    best_path = os.path.join(self.config.workspace_dir, "backup", "patience-best.pt")
    os.makedirs(os.path.dirname(best_path), exist_ok=True)
    tmp_path = best_path + ".tmp"
    try:
        # Clone to CPU so the snapshot is detached from live training tensors.
        state = [p.data.clone().cpu() for p in self.parameters]
        torch.save(state, tmp_path)
        os.replace(tmp_path, best_path)  # atomic rename on POSIX and Windows
        print(f"Saved patience best checkpoint (val_loss improved) at step {train_progress.global_step}")
    except Exception:
        traceback.print_exc()
        print("Could not save patience best checkpoint.")
        # Best-effort cleanup of a partially written temp file.
        try:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        except OSError:
            pass
        # Keep the previous (still intact) best so end() can restore it.
        return self._patience_best_backup_path or ""
    return best_path

def __save_backup_config(self, backup_path):
config_path = os.path.join(backup_path, "onetrainer_config")
args_path = path_util.canonical_join(config_path, "args.json")
Expand Down Expand Up @@ -856,6 +904,18 @@ def end(self):

if self.model.ema:
self.model.ema.copy_ema_to(self.parameters, store_temp=False)

if (self.config.patience
and self._patience_best_backup_path
and os.path.isfile(self._patience_best_backup_path)):
print(f"Restoring patience best checkpoint from step {self._patience_best_step} "
f"(val_loss: {self._patience_best_loss:.6f})")
self.callbacks.on_update_status("Restoring best validation checkpoint")
best_state = torch.load(self._patience_best_backup_path, map_location=self.temp_device)
for param, saved in zip(self.parameters, best_state, strict=True):
param.data.copy_(saved)
del best_state

if os.path.isdir(self.config.output_model_destination) and self.config.output_model_format.is_single_file():
save_path = os.path.join(
self.config.output_model_destination,
Expand Down
21 changes: 21 additions & 0 deletions modules/ui/TrainUI.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,26 @@ def create_general_tab(self, master):
tooltip="Number of threads used for the data loader. Increase if your GPU has room during caching, decrease if it's going out of memory during caching.")
components.entry(frame, 10, 1, self.ui_state, "dataloader_threads", required=True)

components.label(frame, 10, 2, "Patience",
tooltip="Enable early stopping based on validation loss. "
"Training stops when validation loss has not improved "
"for a set number of consecutive validation checks. "
"Automatically enables Validation when turned on.")
components.switch(frame, 10, 3, self.ui_state, "patience",
command=self._on_patience_toggle)

components.label(frame, 11, 0, "Train Device",
tooltip="The device used for training. Can be \"cuda\", \"cuda:0\", \"cuda:1\" etc. Default:\"cuda\". Must be \"cuda\" for multi-GPU training.")
components.entry(frame, 11, 1, self.ui_state, "train_device", required=True)

components.label(frame, 11, 2, "Early Stop After",
tooltip="Number of consecutive validation checks without improvement "
"before training stops. You control how often validation runs "
"with the 'Validate after' setting above. The best checkpoint "
"(lowest validation loss) is automatically saved and restored "
"as the final output model.")
components.entry(frame, 11, 3, self.ui_state, "patience_epochs")

components.label(frame, 12, 0, "Multi-GPU",
tooltip="Enable multi-GPU training")
components.switch(frame, 12, 1, self.ui_state, "multi_gpu")
Expand Down Expand Up @@ -871,6 +887,11 @@ def _on_always_on_tensorboard_toggle(self):
if not (self.training_thread and self.train_config.tensorboard):
self._stop_always_on_tensorboard()

def _on_patience_toggle(self):
    """UI callback for the Patience switch.

    Early stopping is driven by validation loss, so enabling patience
    force-enables validation in both the config and the bound UI variable.
    """
    wants_patience = self.train_config.patience
    validation_enabled = self.train_config.validation
    if wants_patience and not validation_enabled:
        self.train_config.validation = True
        self.ui_state.get_var("validation").set(True)

def _set_training_button_style(self, mode: str):
if not self.training_button:
return
Expand Down
13 changes: 12 additions & 1 deletion modules/util/config/TrainConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ class TrainConfig(BaseConfig):
validation: bool
validate_after: float
validate_after_unit: TimeUnit
patience: bool
patience_epochs: int
continue_last_backup: bool
prevent_overwrites: bool
include_train_config: ConfigPart
Expand Down Expand Up @@ -560,7 +562,7 @@ class TrainConfig(BaseConfig):
def __init__(self, data: list[(str, Any, type, bool)]):
super().__init__(
data,
config_version=10,
config_version=11,
config_migrations={
0: self.__migration_0,
1: self.__migration_1,
Expand All @@ -572,6 +574,7 @@ def __init__(self, data: list[(str, Any, type, bool)]):
7: self.__migration_7,
8: self.__migration_8,
9: self.__migration_9,
10: self.__migration_10,
}
)

Expand Down Expand Up @@ -791,6 +794,12 @@ def replace_dtype(part: str):

return migrated_data

def __migration_10(self, data: dict) -> dict:
    """Migrate a config dict from version 10 to 11.

    Adds the early-stopping keys introduced with the patience feature,
    leaving any values already present untouched.
    """
    migrated_data = dict(data)
    if "patience" not in migrated_data:
        migrated_data["patience"] = False
    if "patience_epochs" not in migrated_data:
        migrated_data["patience_epochs"] = 5
    return migrated_data

def weight_dtypes(self) -> ModelWeightDtypes:
return ModelWeightDtypes(
self.train_dtype,
Expand Down Expand Up @@ -943,6 +952,8 @@ def default_values() -> 'TrainConfig':
data.append(("validation", False, bool, False))
data.append(("validate_after", 1, int, False))
data.append(("validate_after_unit", TimeUnit.EPOCH, TimeUnit, False))
data.append(("patience", False, bool, False))
data.append(("patience_epochs", 5, int, False))
data.append(("continue_last_backup", False, bool, False))
data.append(("prevent_overwrites", False, bool, False))
data.append(("include_train_config", ConfigPart.NONE, ConfigPart, False))
Expand Down