main.py

import argparse
import sys
from pathlib import Path

import torch
import torch.optim as optim

from dataloader import get_data
from transformer.models import Transformer
from transformer.labelsmooth import LabelSmoothing
from transformer.warmupoptim import WarmUpOptim
from trainer import Trainer


def argument_parsing(preparse=False):
    parser = argparse.ArgumentParser(description="Transformer Argparser",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # load data
    parser.add_argument("-rt", "--root_dir", required=True,
                        help="Root directory of the dataset")
    parser.add_argument("-dt", "--data_type", type=str, default="multi30k",
                        help="Dataset type: wmt14, multi30k, iwslt")
    parser.add_argument("-maxlen", "--max_length", type=int, default=None,
                        help="Maximum sentence length")
    parser.add_argument("-minfreq", "--min_freq", type=int, default=1,
                        help="Minimum frequency for a token to enter the vocabulary")
    # model
    parser.add_argument("-nl", "--n_layer", type=int, default=6,
                        help="Number of layers in the encoder / decoder")
    parser.add_argument("-nh", "--n_head", type=int, default=8,
                        help="Number of heads in the multi-head attention sublayer")
    parser.add_argument("-dm", "--d_model", type=int, default=512,
                        help="Dimension of the model")
    parser.add_argument("-dk", "--d_k", type=int, default=64,
                        help="Dimension of the key vectors")
    parser.add_argument("-dv", "--d_v", type=int, default=64,
                        help="Dimension of the value vectors")
    parser.add_argument("-df", "--d_f", type=int, default=2048,
                        help="Inner dimension of the position-wise feed-forward layer")
    parser.add_argument("-pad", "--pad_idx", type=int,
                        help="Pad index of the vocabulary (inferred from the vocab if omitted)")
    parser.add_argument("-pospad", "--pos_pad_idx", type=int,
                        help="Position pad index (defaults to 0 if omitted)")
    parser.add_argument("-drop", "--drop_rate", type=float, default=0.1,
                        help="Dropout rate")
    parser.add_argument("-lws", "--linear_weight_share", action="store_true",
                        help="Share the weight matrix between the decoder embedding layer "
                             "and the pre-softmax linear transformation")
    parser.add_argument("-ews", "--embed_weight_share", action="store_true",
                        help="Share the weight matrix between the decoder and encoder "
                             "embedding layers")
    parser.add_argument("-conv", "--use_conv", action="store_true",
                        help="Use a convolution operation in the PositionWiseFFN layer")
    # loss function
    parser.add_argument("-eps", "--smooth_eps", type=float, default=0.1,
                        help="Label smoothing epsilon value")
    # optimizer
    parser.add_argument("-warm", "--warmup_steps", type=int, default=4000,
                        help="Warmup steps for the learning rate schedule")
    parser.add_argument("-b1", "--beta1", type=float, default=0.9,
                        help="Beta1 value for the Adam optimizer")
    parser.add_argument("-b2", "--beta2", type=float, default=0.98,
                        help="Beta2 value for the Adam optimizer")
    # training: passed through to 'trainer.py'
    parser.add_argument("-encsos", "--enc_sos_idx", type=int,
                        help="Encoder SOS index of the vocabulary")
    parser.add_argument("-enceos", "--enc_eos_idx", type=int,
                        help="Encoder EOS index of the vocabulary")
    parser.add_argument("-decsos", "--dec_sos_idx", type=int,
                        help="Decoder SOS index of the vocabulary")
    parser.add_argument("-deceos", "--dec_eos_idx", type=int,
                        help="Decoder EOS index of the vocabulary")
    parser.add_argument("-step", "--n_step", type=int, default=30,
                        help="Total number of training steps")
    # others
    parser.add_argument("-bt", "--batch", type=int, default=64,
                        help="Mini-batch size")
    parser.add_argument("-cuda", "--use_cuda", action="store_true",
                        help="Use CUDA if available")
    parser.add_argument("-svp", "--save_path", type=str, default="./saved_model/model.pt",
                        help="Path to save the model")
    parser.add_argument("-load", "--load_path", type=str,
                        help="Load a previously saved model for transfer learning")
    parser.add_argument("-vb", "--verbose", type=int, default=0,
                        help="Verbosity level")
    parser.add_argument("-met", "--metrics_method", type=str, default="acc",
                        help="Metric used for model selection: 'acc' or 'loss'")
    if preparse:
        return parser
    return parser.parse_args()
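

# A hypothetical invocation, assuming Multi30k data lives under ./data
# (every flag below is defined in argument_parsing above; the paths are
# illustrative assumptions, not fixed by this script):
#
#   python main.py --root_dir ./data --data_type multi30k --use_cuda \
#       --batch 64 --n_step 30 --metrics_method acc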


def main(args):
    # set up the paths used to load data and to save the model
    Path(args.root_dir).mkdir(parents=True, exist_ok=True)
    Path(args.save_path).parent.mkdir(parents=True, exist_ok=True)

    device = "cuda" if (torch.cuda.is_available() and args.use_cuda) else "cpu"
    print(sys.version)  # log the Python version for reproducibility
    print(f"Using {device}")

    print("Loading Data...")
    (src, trg), (train, valid, _), (train_loader, valid_loader, _) = get_data(args)
    src_vocab_len = len(src.vocab.stoi)
    trg_vocab_len = len(trg.vocab.stoi)
    # check vocab size
    print(f"SRC vocab {src_vocab_len}, TRG vocab {trg_vocab_len}")

    enc_max_seq_len = args.max_length
    dec_max_seq_len = args.max_length

    # resolve special-token indices, falling back to the vocabulary
    # when they were not given on the command line
    pad_idx = src.vocab.stoi.get("<pad>") if args.pad_idx is None else args.pad_idx
    enc_sos_idx = src.vocab.stoi.get("<s>") if args.enc_sos_idx is None else args.enc_sos_idx
    enc_eos_idx = src.vocab.stoi.get("</s>") if args.enc_eos_idx is None else args.enc_eos_idx
    dec_sos_idx = trg.vocab.stoi.get("<s>") if args.dec_sos_idx is None else args.dec_sos_idx
    dec_eos_idx = trg.vocab.stoi.get("</s>") if args.dec_eos_idx is None else args.dec_eos_idx
    pos_pad_idx = 0 if args.pos_pad_idx is None else args.pos_pad_idx
print("Building Model...")
model = Transformer(enc_vocab_len=src_vocab_len,
enc_max_seq_len=enc_max_seq_len,
dec_vocab_len=trg_vocab_len,
dec_max_seq_len=dec_max_seq_len,
n_layer=args.n_layer,
n_head=args.n_head,
d_model=args.d_model,
d_k=args.d_k,
d_v=args.d_v,
d_f=args.d_f,
pad_idx=pad_idx,
pos_pad_idx=pos_pad_idx,
drop_rate=args.drop_rate,
use_conv=args.use_conv,
linear_weight_share=args.linear_weight_share,
embed_weight_share=args.embed_weight_share).to(device)
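
    # Note: "Attention Is All You Need" sets d_k = d_v = d_model / n_head
    # (512 / 8 = 64), which matches the argument defaults above; whether
    # other combinations are supported depends on the Transformer class.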

    if args.load_path is not None:
        print(f"Load Model {args.load_path}")
        model.load_state_dict(torch.load(args.load_path))
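
    # Label smoothing (Szegedy et al., 2016): the one-hot target is presumably
    # softened to q'(k) = (1 - eps) * 1[k == y] + eps / K over the K target
    # classes, with <pad> positions masked out; the exact variant is defined
    # in transformer.labelsmooth.LabelSmoothing.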
    # build the loss function using label smoothing
    loss_function = LabelSmoothing(trg_vocab_size=trg_vocab_len,
                                   pad_idx=pad_idx,  # resolved index; args.pad_idx may be None
                                   eps=args.smooth_eps)
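
    # WarmUpOptim presumably wraps Adam with the schedule from the paper:
    #   lr = d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
    # i.e. linear warmup over the first warmup_steps steps, then decay
    # proportional to the inverse square root of the step number.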
    optimizer = WarmUpOptim(warmup_steps=args.warmup_steps,
                            d_model=args.d_model,
                            optimizer=optim.Adam(model.parameters(),
                                                 betas=(args.beta1, args.beta2),
                                                 eps=1e-9))  # eps = 1e-9 as in the paper

    trainer = Trainer(optimizer=optimizer,
                      train_loader=train_loader,
                      test_loader=valid_loader,
                      n_step=args.n_step,
                      device=device,
                      save_path=args.save_path,
                      enc_sos_idx=enc_sos_idx,
                      enc_eos_idx=enc_eos_idx,
                      dec_sos_idx=dec_sos_idx,
                      dec_eos_idx=dec_eos_idx,
                      metrics_method=args.metrics_method,
                      verbose=args.verbose)

    print("Start Training...")
    trainer.main(model=model, loss_function=loss_function)


if __name__ == "__main__":
    args = argument_parsing()
    main(args)
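
# Run `python main.py -h` to list every option together with its default
# value (ArgumentDefaultsHelpFormatter appends the defaults to the help text).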