-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathmain.py
More file actions
73 lines (66 loc) · 3.11 KB
/
main.py
File metadata and controls
73 lines (66 loc) · 3.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Copyright (c) 2019 UniMoRe, Matteo Spallanzani
import argparse
from quantlab.protocol.logbook import Logbook
from quantlab.indiv.daemon import get_topo
from quantlab.treat.daemon import get_algo, get_data
from quantlab.protocol.rooms import train, test
# Command Line Interface
parser = argparse.ArgumentParser(description='QuantLab')
parser.add_argument('--problem',    help='MNIST/CIFAR-10/ImageNet/COCO')
parser.add_argument('--topology',   help='Network topology')
parser.add_argument('--exp_id',     help='Experiment to launch/resume', default=None)
parser.add_argument('--load',       help='Checkpoint to load: best/last/i_epoch', default='best')
# `choices` makes an invalid mode fail fast instead of silently running nothing
parser.add_argument('--mode',       help='Experiment mode: train/test', default='train',
                    choices=['train', 'test'])
# BUG FIX: without `type=int` a command-line value arrives as `str`, and the
# later `logbook.i_epoch % args.ckpt_every` raises TypeError; the int default
# of 50 only masked this when the flag was omitted.
parser.add_argument('--ckpt_every', help='Frequency of checkpoints (in epochs)',
                    type=int, default=50)
args = parser.parse_args()
# create/retrieve experiment logbook
# (keyed by problem/topology/exp_id; `load` selects which checkpoint to resume from)
logbook = Logbook(args.problem, args.topology, args.exp_id, args.load)
# create/retrieve network and treatment
# NOTE(review): net_maybe_par looks like a possibly data-parallel wrapper around
# `net` (it is the one handed to `train` below) — confirm against get_topo
net, net_maybe_par, device, loss_fn = get_topo(logbook)
# thr exposes step()/state_dict() like a scheduler; presumably the quantization
# "thermostat" — verify in quantlab.treat.daemon.get_algo
thr, opt, lr_sched = get_algo(logbook, net)
# train/validation/test data loaders, in that order
train_l, valid_l, test_l = get_data(logbook)
# run experiment
if args.mode == 'train':

    def _make_checkpoint():
        """Assemble the full training-state snapshot that Logbook serializes."""
        return {'indiv': {'net': net.state_dict()},
                'treat': {
                    'thermostat': thr.state_dict(),
                    'optimizer': opt.state_dict(),
                    'lr_scheduler': lr_sched.state_dict(),
                    'i_epoch': logbook.i_epoch
                },
                'protocol': {'metrics': logbook.metrics}}

    # BUG FIX: `--ckpt_every` is parsed without `type=int`, so a command-line
    # value is a string and the modulo below would raise TypeError; coerce once.
    ckpt_every = int(args.ckpt_every)
    # resume from the epoch after the loaded checkpoint, up to the configured max
    for _ in range(logbook.i_epoch + 1, logbook.config['treat']['max_epoch'] + 1):
        logbook.start_epoch()
        thr.step()
        # train one epoch
        net.train()
        train_stats = train(logbook, net_maybe_par, device, loss_fn, opt, train_l)
        # validate
        net.eval()
        valid_stats = test(logbook, net, device, loss_fn, valid_l, valid=True)
        stats = {**train_stats, **valid_stats}
        # update learning rate: metric-driven schedulers (e.g. ReduceLROnPlateau,
        # whose step() takes a `metrics` argument) get the monitored statistic;
        # all others step unconditionally
        if 'metrics' in lr_sched.step.__code__.co_varnames:
            lr_sched_metric = stats[logbook.config['treat']['lr_scheduler']['step_metric']]
            lr_sched.step(lr_sched_metric)
        else:
            lr_sched.step()
        # save model if update metric has improved...
        if logbook.is_better(stats):
            logbook.store_checkpoint(_make_checkpoint(), is_best=True)
        # ...and/or if checkpoint epoch
        if (logbook.i_epoch % ckpt_every) == 0:
            logbook.store_checkpoint(_make_checkpoint())

elif args.mode == 'test':
    # evaluate the restored network on the held-out test set
    net.eval()
    test_stats = test(logbook, net, device, loss_fn, test_l)