-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
42 lines (30 loc) · 1.09 KB
/
main.py
File metadata and controls
42 lines (30 loc) · 1.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from Network import Cnn, NetworkStuff
from DataBalancing import Regularisation
import torchvision
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn as nn
# Entry-point script: builds the CNN together with its loss, optimiser and
# learning-rate scheduler, hands them to NetworkStuff, then runs the
# load / submit / visualise steps at the bottom. The commented-out calls are
# manual toggles for switching between training and inference-only runs.

regularisation = Regularisation()
# NOTE(review): presumably rebalances the dataset (class lives in the
# DataBalancing module) — confirm before re-enabling.
# regularisation.regularise()

# Release memory held by PyTorch's CUDA caching allocator from any previous
# work in this process; tensors that are still referenced are not freed.
torch.cuda.empty_cache()

# Fall back to CPU instead of crashing outright when CUDA is unavailable; on a
# CUDA machine this is exactly the original hard-coded "cuda" device.
# NOTE(review): NetworkStuff may move batches to "cuda" itself — confirm it
# honours the model's device before relying on CPU runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Cnn().to(device)

loss_func = nn.CrossEntropyLoss()

# AdamW with the AMSGrad variant; lr and weight decay use the torch defaults.
optimiser = optim.AdamW(model.parameters(), amsgrad=True)

# Multiply the learning rate by 0.1 after the monitored metric (default
# mode="min") fails to improve for 2 consecutive scheduler.step() calls.
# Bound to `scheduler` rather than `lr_scheduler` so the
# torch.optim.lr_scheduler module imported above is not shadowed.
scheduler = lr_scheduler.ReduceLROnPlateau(optimiser, factor=0.1, patience=2)

net_stuff = NetworkStuff(model,
                         optimiser,
                         loss_func,
                         scheduler,
                         train_dir='data/train/simpsons_dataset',
                         test_dir='data/test/testset',
                         save_name='model8_adamW',
                         use_scheduler=True,
                         # Effectively "train until stopped"; presumably
                         # bounded in practice by early stopping or a manual
                         # interrupt inside NetworkStuff — verify.
                         epochs=100000,
                         batch_size=336
                         )

# NOTE(review): presumably restores weights saved under save_name — confirm in
# Network.NetworkStuff.load_model.
net_stuff.load_model()
# net_stuff.train(clear_history=False)
net_stuff.submit()
# net_stuff.plotter()
net_stuff.draw_prediction(type_='val')