-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
153 lines (127 loc) · 5.53 KB
/
train.py
File metadata and controls
153 lines (127 loc) · 5.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import os, argparse, logging, random
from collections import defaultdict
import torch, dgl
import torch.optim as optim
import torch.nn.functional as F
import utils
from datasets import (
load_reaction_dataset,
load_molecule_dict,
build_dataloader,
check_molecule_dict)
from trainers.retrosynthesis import (
create_retrosynthesis_trainer,
create_retrosynthesis_evaluator)
from trainers.utils import collect_embeddings
from models import load_module
from models.similarity import CosineSimilarity
from models.loss import SimCLR
# Let cuDNN auto-tune convolution algorithms; fastest when input shapes are
# mostly stable across iterations (introduces nondeterminism in kernel choice).
torch.backends.cudnn.benchmark = True
# Single-GPU training: everything is pinned to the first CUDA device.
device = torch.device('cuda:0')
def main(args):
    """Train a retrosynthesis module with a SimCLR contrastive objective.

    Loads the reaction dataset and candidate-molecule dictionary, then runs
    ``args.num_iterations`` optimizer steps. Every ``args.eval_freq``
    iterations the model is evaluated on the validation split and the
    nearest-neighbor table (used for hard negatives — presumably; confirm
    against trainers.retrosynthesis) is refreshed from the current molecule
    embeddings. Checkpoints 'last.pth' (every eval) and 'best.pth' (on new
    best top-1 accuracy) are written to ``args.logdir``.

    Args:
        args: parsed argparse.Namespace (see the __main__ block for flags).

    Raises:
        FileExistsError: if ``args.logdir`` already exists and ``--resume``
            was not given (deliberate guard against clobbering an old run).
    """
    if not args.resume:
        # Intentionally no exist_ok: a fresh run must not overwrite old logs.
        os.makedirs(args.logdir)
    utils.set_logging_options(args.logdir)
    logger = logging.getLogger('train')
    # Record the exact command line for reproducibility.
    logger.info(' '.join(os.sys.argv))
    logger.info(args)

    ### DATASETS
    datasets = load_reaction_dataset(args.datadir)
    mol_dict = load_molecule_dict(args.mol_dict)
    # Sanity check that every molecule referenced by the reactions exists
    # in the candidate dictionary.
    check_molecule_dict(mol_dict, datasets)

    ### DATALOADERS
    trainloader = build_dataloader(datasets['train'], batch_size=args.batch_size, num_iterations=args.num_iterations)

    ### MODELS
    module = load_module(args).to(device)

    ### LOSS
    sim_fn = CosineSimilarity()
    # SimCLR contrastive loss with temperature args.tau.
    loss_fn = SimCLR(sim_fn, args.tau).to(device)

    ### OPTIMIZER
    params = list(module.parameters())
    optimizer = optim.SGD(params, lr=args.lr, weight_decay=args.wd, momentum=0.9)

    ### TRAINER
    # smiles -> list of neighbor molecules; shared (mutated in the loop below)
    # with the train step, which reads it each call.
    nearest_neighbors = defaultdict(list)
    train_step = create_retrosynthesis_trainer(module,
                                               loss_fn,
                                               forward=not args.backward_only,
                                               nearest_neighbors=nearest_neighbors,
                                               num_neighbors=args.num_neighbors,
                                               device=device)
    # beam=1: greedy decoding; evaluate returns per-beam accuracies.
    evaluate = create_retrosynthesis_evaluator(module, sim_fn, device=device, beam=1)

    ### TRAINING
    if args.resume:
        # Restore model/optimizer state and counters from the last checkpoint.
        ckpt = torch.load(os.path.join(args.logdir, 'last.pth'), map_location='cpu')
        module.load_state_dict(ckpt['module'])
        optimizer.load_state_dict(ckpt['optim'])
        iteration = ckpt['iteration']
        best_acc = ckpt['best_acc']
    else:
        iteration = 0
        best_acc = 0

    def save(name='last.pth'):
        # Closure: reads the *current* iteration/best_acc from main's scope.
        torch.save({
            'module': module.state_dict(),
            'optim': optimizer.state_dict(),
            'iteration': iteration,
            'best_acc': best_acc,
            'args': vars(args),
        }, os.path.join(args.logdir, name))

    # Embeddings of every candidate molecule; refreshed only at eval time,
    # so the neighbor table below is built from slightly stale embeddings.
    embeddings = collect_embeddings(module, mol_dict, device=device)
    for reactions in trainloader:
        iteration += 1
        if iteration > args.num_iterations:
            break

        # Refresh the nearest-neighbor table right after each evaluation
        # (iteration 1, eval_freq+1, ...), using the embeddings computed then.
        if (iteration-1) % args.eval_freq == 0:
            logger.info('Update nearest_neighbors ...')
            with torch.no_grad():
                keys = module.construct_keys(embeddings)
                # L2-normalize so the einsum below is cosine similarity.
                keys = F.normalize(keys).to(device)
                # Chunk the similarity matrix (512 rows at a time) to bound memory.
                for i in range(0, keys.shape[0], 512):
                    _, indices = torch.einsum('ik, jk -> ij', keys[i:i+512], keys).topk(args.num_neighbors+1, dim=1)
                    for j, neighbors in enumerate(indices.tolist()):
                        # Row j of chunk i is molecule i+j; drop neighbors[0],
                        # which is the molecule itself (self-similarity = 1).
                        nearest_neighbors[mol_dict[i+j].smiles] = [mol_dict[k] for k in neighbors[1:]]
            logger.info('Update nearest_neighbors ... done')

        # TRAINING
        optimizer.zero_grad()
        outputs = train_step(reactions)
        outputs['loss'].backward()
        if args.clip is not None:
            # Global-norm gradient clipping over all trainable parameters.
            torch.nn.utils.clip_grad_norm_(params, args.clip)
        optimizer.step()

        # LOGGING
        logger.info('[Iter {}] [BatchSize {}] [Loss {:.4f}] [BatchAcc {:.4f}]'.format(
            iteration, len(reactions), outputs['loss'].item(), outputs['acc'].item()))

        if iteration % args.eval_freq == 0:
            # Recompute embeddings with the current weights, then validate.
            embeddings = collect_embeddings(module, mol_dict, device=device)
            acc, _ = evaluate(mol_dict, datasets['val'], embeddings)
            acc = acc[0]  # top-1 (beam=1) accuracy
            if best_acc < acc:
                logger.info(f'[Iter {iteration}] [BEST {acc:.4f}]')
                best_acc = acc
                save('best.pth')
            save()  # always refresh 'last.pth' for --resume
            logger.info(f'[Iter {iteration}] [Val Acc {acc:.4f}]')
if __name__ == '__main__':
    # Command-line interface, declared as data: (flag, add_argument kwargs).
    arg_specs = [
        # Model arguments
        ('--num-layers', dict(type=int, default=5)),
        ('--dropout', dict(type=float, default=0)),
        ('--use-label', dict(action='store_true')),
        ('--use-sum', dict(action='store_true')),
        # Optimization arguments
        ('--num-iterations', dict(type=int, default=200000)),
        ('--batch-size', dict(type=int, default=64)),
        ('--lr', dict(type=float, default=1e-2)),
        ('--wd', dict(type=float, default=1e-5)),
        ('--clip', dict(type=float, default=5.0)),
        ('--tau', dict(type=float, default=0.1)),
        ('--eval-freq', dict(type=int, default=1000)),
        ('--num-neighbors', dict(type=int, default=0)),
        # Training arguments
        ('--logdir', dict(type=str, required=True)),
        ('--resume', dict(action='store_true')),
        ('--datadir', dict(type=str, default='data/uspto_50k')),
        ('--mol-dict', dict(type=str, default='data/uspto_candidates')),
        ('--backward-only', dict(action='store_true')),
    ]
    cli = argparse.ArgumentParser()
    for flag, opts in arg_specs:
        cli.add_argument(flag, **opts)
    main(cli.parse_args())