earlystop.py
import os

import joblib
import numpy as np
import pandas as pd
import torch

class MonitorBestModelEarlyStopping:
    """
    Stops training early if the validation loss does not improve within a given
    patience window, and saves the best model.
    """

    def __init__(
        self,
        patience=15,
        min_epochs=20,
        saving_checkpoint=True,
        hpset=None,
        output_dir=None,
    ):
        """
        Initializes the early stopping monitor.

        Args:
            patience (int): Number of epochs to wait after the last validation-loss
                improvement before stopping. Default: 15
            min_epochs (int): Minimum number of epochs to run before early stopping
                is considered. Default: 20
            saving_checkpoint (bool): If True, saves the model checkpoint when the
                validation loss improves. Default: True
            hpset (int): Hyperparameter set identifier.
            output_dir (str): Directory in which to save checkpoints and predictions.
        """
        self.patience = patience
        self.min_epochs = min_epochs
        self.counter = 0
        self.early_stop = False
        self.eval_loss_min = np.inf
        self.best_loss_score = None
        self.best_epoch_loss = None
        self.best_opt_metric_score = 0.0  # model will be optimized on this metric
        self.best_epoch = None
        self.saving_checkpoint = saving_checkpoint
        self.hpset = hpset
        self.output_dir = os.path.abspath(output_dir) if output_dir else None
        if self.saving_checkpoint and self.output_dir is None:
            raise ValueError(
                "output_dir must be provided when saving_checkpoint is enabled."
            )

    def __call__(
        self, epoch, eval_loss, eval_opt_metric, model, platt_model, preds, preds_train
    ):
        """
        Checks whether training should be stopped based on the validation loss
        and saves the best model.
        """
        loss_score = -eval_loss
        opt_score = eval_opt_metric  # AUC
        if self.best_loss_score is None:
            self._update_loss_scores(loss_score, eval_loss, epoch)
            self._update_metrics_scores(opt_score, epoch)
        # If validation loss starts increasing, begin counting towards early stopping
        elif loss_score < self.best_loss_score:
            self.counter += 1
            print(
                f"Evaluation loss did not decrease: early stopping counter {self.counter} out of {self.patience}"
            )
            if self.counter >= self.patience and epoch > self.min_epochs:
                self.early_stop = True
        # If validation loss is still decreasing, reset the counter and save the model
        else:
            print(
                f"Epoch {epoch} validation loss decreased ({self.eval_loss_min:.6f} --> {eval_loss:.6f})"
            )
            self._update_loss_scores(loss_score, eval_loss, epoch)
            self._update_metrics_scores(opt_score, epoch)
            self.save_checkpoint_predictions(model, platt_model, preds, preds_train)
            self.counter = 0
        # If the optimization metric (e.g., AUC) improves, update scores and save the model
        if opt_score > self.best_opt_metric_score:
            print(
                f"Epoch {epoch}: AUC improved ({self.best_opt_metric_score:.4f} --> {opt_score:.4f})"
            )
            self._update_metrics_scores(opt_score, epoch)
            self._update_loss_scores(loss_score, eval_loss, epoch)
            self.save_checkpoint_predictions(model, platt_model, preds, preds_train)

    def save_checkpoint_predictions(self, model, platt_model, preds, preds_train):
        """
        Saves the model checkpoint, Platt scaler, and predictions.
        """
        if not self.saving_checkpoint or self.output_dir is None:
            return
        print("Saving checkpoint and predictions")
        base_name = "polarix"
        # Save model checkpoint
        checkpoints_dir = os.path.join(
            self.output_dir, "checkpoints", "final", base_name
        )
        os.makedirs(checkpoints_dir, exist_ok=True)
        filepath_check = os.path.join(
            checkpoints_dir,
            f"{base_name}_hp{self.hpset}_checkpoint.pt",
        )
        torch.save(model.state_dict(), filepath_check)
        # Save Platt scaler model and coefficients
        platt_dir = checkpoints_dir
        filepath_platt = os.path.join(
            platt_dir,
            f"{base_name}_hp{self.hpset}_PlattScaler.pkl",
        )
        joblib.dump(platt_model, filepath_platt)
        alpha = platt_model.coef_[0][0]
        beta = platt_model.intercept_[0]
        coefficients_df = pd.DataFrame({"alpha": [alpha], "beta": [beta]})
        filepath_platt_coef = os.path.join(
            platt_dir,
            f"{base_name}_hp{self.hpset}_PlattScalerCOEF.csv",
        )
        coefficients_df.to_csv(filepath_platt_coef, index=False)
        # Save evaluation predictions
        predictions_dir = os.path.join(
            self.output_dir, "predictions", "final", base_name
        )
        os.makedirs(predictions_dir, exist_ok=True)
        filepath_pred = os.path.join(
            predictions_dir,
            f"{base_name}_hp{self.hpset}_predictions.csv",
        )
        preds.to_csv(filepath_pred, index=False)
        # Save training predictions
        filepath_pred_train = os.path.join(
            predictions_dir,
            f"{base_name}_hp{self.hpset}_predictions_TRAIN.csv",
        )
        preds_train.to_csv(filepath_pred_train, index=False)

    def _update_loss_scores(self, loss_score, eval_loss, epoch):
        self.eval_loss_min = eval_loss
        self.best_loss_score = loss_score
        self.best_epoch_loss = epoch
        print(
            f"Updating loss at epoch {self.best_epoch_loss} -> {self.eval_loss_min:.6f}"
        )

    def _update_metrics_scores(self, opt_score, epoch):
        self.best_opt_metric_score = opt_score
        self.best_epoch = epoch
        print(
            f"Updating Opt metric at epoch {self.best_epoch} -> {self.best_opt_metric_score:.6f}"
        )
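

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes a binary-classification setup where the optimization metric is
# AUC and Platt scaling is fit on validation logits with scikit-learn; the toy
# model, data, hyperparameters, and the "./runs_demo" output directory are all
# hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import roc_auc_score

    # Toy data: 10 features, binary labels driven by the first feature.
    torch.manual_seed(0)
    X_train, X_val = torch.randn(256, 10), torch.randn(64, 10)
    y_train, y_val = (X_train[:, 0] > 0).float(), (X_val[:, 0] > 0).float()

    model = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 1))
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    monitor = MonitorBestModelEarlyStopping(
        patience=5,
        min_epochs=10,
        saving_checkpoint=True,
        hpset=0,
        output_dir="./runs_demo",  # hypothetical output directory
    )

    for epoch in range(1, 101):
        # One training step on the full toy batch.
        model.train()
        optimizer.zero_grad()
        loss = criterion(model(X_train).squeeze(1), y_train)
        loss.backward()
        optimizer.step()

        # Evaluate on the validation split.
        model.eval()
        with torch.no_grad():
            logits_train = model(X_train).squeeze(1)
            logits_val = model(X_val).squeeze(1)
            eval_loss = criterion(logits_val, y_val).item()

        # Platt scaling: logistic regression fit on raw validation logits.
        platt = LogisticRegression()
        platt.fit(logits_val.numpy().reshape(-1, 1), y_val.numpy())
        probs_val = platt.predict_proba(logits_val.numpy().reshape(-1, 1))[:, 1]
        probs_train = platt.predict_proba(logits_train.numpy().reshape(-1, 1))[:, 1]
        eval_auc = roc_auc_score(y_val.numpy(), probs_val)

        preds = pd.DataFrame({"label": y_val.numpy(), "prob": probs_val})
        preds_train = pd.DataFrame({"label": y_train.numpy(), "prob": probs_train})

        # The monitor tracks loss/AUC, saves the best checkpoint, and flags
        # when training should stop.
        monitor(epoch, eval_loss, eval_auc, model, platt, preds, preds_train)
        if monitor.early_stop:
            print(f"Stopping early at epoch {epoch}")
            break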