train.py
import os

import torch
from torch.utils.data import DataLoader

from data.dataset import TinyStoriesDataset
from model.bigram import Bigram
from model.transformer import Transformer, TINY_CONFIG, FULL_CONFIG

# GPT-2 BPE tokenizer vocabulary size.
VOCAB_SIZE = 50257

if __name__ == "__main__":
    # Fall back to CPU when no GPU is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    training_data = TinyStoriesDataset("data/train.bin", context_len=512)
    valid_data = TinyStoriesDataset("data/val.bin", context_len=512)

    model = Transformer(**FULL_CONFIG)
    model.to(device)

    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {n_params:,}")

    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    training_loader = DataLoader(
        training_data,
        batch_size=4,
        shuffle=True,
        num_workers=0,
    )
    valid_loader = DataLoader(
        valid_data,
        batch_size=4,
        shuffle=False,  # no need to shuffle validation batches
        num_workers=0,
    )
    num_epochs = 30
    for e in range(num_epochs):
        model.train()
        for i, (x, y) in enumerate(training_loader):
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            outputs = model(x)  # (batch, context_len, VOCAB_SIZE) logits
            # CrossEntropyLoss expects (N, C) logits and (N,) targets.
            loss = loss_function(outputs.view(-1, VOCAB_SIZE), y.view(-1))
            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                print(f"Epoch {e}, Step {i}, Loss: {loss.item():.4f}")
    # Make sure the checkpoint directory exists before saving.
    os.makedirs('checkpoints', exist_ok=True)
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item(),
    }, 'checkpoints/transformer_full.pt')
    print("Model saved!")