-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
38 lines (25 loc) · 954 Bytes
/
train.py
File metadata and controls
38 lines (25 loc) · 954 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from ColourNet import ColourNet
import torch
import os
import time
from Plotter import Plotter
from DataHandler import DataHandler
def train_model(batch_size=8, reset=False, device='cpu', iter=20000):
    """Train the ColourNet generator, checkpointing whenever the loss improves.

    Args:
        batch_size: Samples per batch; forwarded to ``DataHandler``.
        reset: If True, do not resume even when a ``ColourNet_G.pth`` file
            exists in the current working directory.
        device: Device string forwarded to both ``DataHandler`` and
            ``ColourNet``.
        iter: Number of training iterations (batches) to run. The name
            shadows the builtin but is kept for backward compatibility.
    """
    data_handler = DataHandler(batch_size, split='train', device=device)
    colourNet = ColourNet(device)
    plotter = Plotter(mode='train')

    # Resume from an existing checkpoint unless a reset was requested.
    # NOTE(review): assumes ColourNet.load() reads 'ColourNet_G.pth' from the
    # working directory -- confirm against ColourNet.
    if not reset and os.path.exists('ColourNet_G.pth'):
        colourNet.load()

    best_loss = float('inf')
    for i in range(iter):
        grayscale, color = data_handler.get_batch()

        # Training the Generator
        loss = colourNet.train_G(grayscale, color)

        # Save the model only if the loss has decreased. best_loss must be
        # updated here; otherwise every iteration beats float('inf') and the
        # checkpoint is overwritten unconditionally.
        # NOTE(review): if train_G returns a tensor still attached to the
        # autograd graph, consider storing float(loss) instead.
        if loss < best_loss:
            best_loss = loss
            colourNet.save()

        plotter.add_loss(loss)

        # Percentage of the training file list consumed so far. It counts
        # samples drawn, so it exceeds 100 once training passes one epoch.
        # Guard against an empty file list to avoid ZeroDivisionError.
        num_files = len(data_handler.files)
        progress = (
            (i * data_handler.batch_size * 100) / num_files if num_files else 0.0
        )
        print(i, progress, "% ", "Total Loss: ", loss)

    print('Training finished')
# Script entry point: only train when this file is executed directly, not
# when it is imported. All of train_model's default arguments are used.
if __name__ == '__main__':
    train_model()