-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
53 lines (49 loc) · 1.84 KB
/
train.py
File metadata and controls
53 lines (49 loc) · 1.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from tokenizers import Tokenizer
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from model.LLM import SimpleGPT, visualize_tokens_line
from model.atten import selfAttention
if __name__ == '__main__':
    import sys
    import os
    # Make the project root importable so `Data.LLMData` resolves when this
    # file is run directly as a script.
    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
    from Data.LLMData import LLMDataset

    # --- Hyperparameters ----------------------------------------------------
    max_length = 256    # context window: tokens per training sample
    step = 1            # stride between consecutive samples in the dataset
    vocab_size = 15099  # must match the tokenizer's vocabulary size
    output_dim = 256    # model embedding dimension

    tokenizer = Tokenizer.from_file("./Tokenizer/tokenizer.json")
    torch.manual_seed(1234)  # reproducible init and DataLoader shuffling

    with open('./Tokenizer/input.txt', 'r', encoding='utf-8') as f:
        txt = f.read()
    dataset = LLMDataset(txt, tokenizer, max_length, step)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Fall back to CPU when CUDA is unavailable instead of crashing on
    # the unconditional .cuda() calls the original used.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    GPT = SimpleGPT(vocab_size, max_length, output_dim).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(GPT.parameters(), lr=3e-4)

    best_loss = float('inf')  # clearer sentinel than an arbitrary large int
    PATH = "./GPT.pth"
    for epoch in range(100):
        total_loss = 0.0
        for inputs, targets in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
            inputs = inputs.to(device)
            targets = targets.to(device)
            # outputs presumably has shape (batch, seq, vocab_size) — the
            # reshape below flattens it for token-level cross-entropy.
            outputs = GPT(inputs)
            loss = criterion(outputs.reshape(-1, vocab_size), targets.reshape(-1))
            optimizer.zero_grad()
            loss.backward()
            # Clip gradients to stabilize training.
            nn.utils.clip_grad_norm_(GPT.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)
        # Checkpoint only when this epoch improves on the best average loss.
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(GPT.state_dict(), PATH)
        print("epoch: ", epoch, "loss: ", avg_loss)
        print("best_loss: ", best_loss)