# train.py
import argparse

import wandb
import lightning as L
from lightning.pytorch.loggers import WandbLogger

from src import GLUEDataModule, GLUETransformer

EPOCHS = 3  # do not change this


def train_experiment(wandb_project: str, checkpoint_dir: str, **kwargs):
    # Fan the single batch-size setting out to the separate train/eval
    # arguments expected by the data module and the model.
    if "batch_size" in kwargs:
        batch_size = kwargs.pop("batch_size")
        kwargs["train_batch_size"] = batch_size
        kwargs["eval_batch_size"] = batch_size

    wandb.init(
        project=wandb_project,
        name=f"distilbert-base-uncased-{'-'.join('{}_{}'.format(key, val) for key, val in kwargs.items())}",
        config=kwargs,
        reinit="finish_previous",  # finish any previous run, so one script can launch several runs
    )
    logger = WandbLogger(project=wandb_project, save_dir=checkpoint_dir)  # use your experiment tracking tool's logger

    L.seed_everything(42)

    dm = GLUEDataModule(
        model_name_or_path="distilbert-base-uncased",
        task_name="mrpc",
        **kwargs,
    )
    dm.setup("fit")

    model = GLUETransformer(
        model_name_or_path="distilbert-base-uncased",
        num_labels=dm.num_labels,
        eval_splits=dm.eval_splits,
        task_name=dm.task_name,
        **kwargs,
    )

    trainer = L.Trainer(
        max_epochs=EPOCHS,
        accelerator="auto",
        devices=1,
        logger=logger,
    )
    trainer.fit(model, datamodule=dm)
    wandb.finish()
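
# Because wandb.init(reinit="finish_previous") closes the active run before
# starting a new one, train_experiment can be called repeatedly from a single
# process. A minimal sketch of a manual learning-rate sweep (project name and
# values below are illustrative, not part of this repo):
#
#     for lr in (1e-5, 2e-5, 5e-5):
#         train_experiment("my-glue-project", "models", learning_rate=lr)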

if __name__ == "__main__":
    parser = argparse.ArgumentParser("python train.py")

    # Command-line arguments for wandb
    parser.add_argument("--wandb-project", type=str, required=True, help="Name of the Weights & Biases project")
    parser.add_argument("--checkpoint-dir", type=str, default="models", help="Directory to store checkpoints in")

    # Command-line arguments for all supported hyperparameters
    parser.add_argument("-bs", "--batch-size", type=int, help="Training & evaluation batch size")
    parser.add_argument("-lr", "--learning-rate", type=float, help="Optimizer learning rate")
    parser.add_argument("-ws", "--warmup-steps", type=int, help="Number of warmup steps")
    parser.add_argument("-wd", "--weight-decay", type=float, help="Weight decay (L2 regularization)")
    parser.add_argument("-o", "--optimizer", choices=["AdamW", "Adam", "NAdam", "SGD"], help="Optimizer to use")
    parser.add_argument("-adm-b", "--adam-betas", type=lambda s: tuple(map(float, s.split(","))), help="Adam betas as 'beta1,beta2' (comma-separated)")
    parser.add_argument("-adm-e", "--adam-eps", type=float, help="Adam epsilon")
    parser.add_argument("-sgd-m", "--sgd-momentum", type=float, help="SGD momentum")
    parser.add_argument("-sgd-d", "--sgd-dampening", type=float, help="SGD dampening")
    # default=None (instead of store_true's usual False) so the flag is dropped below when unset
    parser.add_argument("-sgd-n", "--sgd-nesterov", action="store_true", default=None, help="Enable Nesterov momentum")
    args = parser.parse_args()

    # Keep only the arguments the user actually set; compare against None
    # rather than truthiness so explicit zeros (e.g. --weight-decay 0) survive.
    args = {k: v for k, v in vars(args).items() if v is not None}

    # Start the experiment
    train_experiment(**args)
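
# Example invocation (project name and hyperparameter values are illustrative):
#
#     python train.py --wandb-project my-glue-project -bs 32 -lr 2e-5 -o AdamW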