-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_test.py
More file actions
24 lines (18 loc) · 867 Bytes
/
train_test.py
File metadata and controls
24 lines (18 loc) · 867 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import capsnet as cn
import datasets
import os
def main(working_directory='/tmp/models/mnistcaps/', batch_size=100,
         n_train=55000, epochs=50):
    """Train a CapsNet on MNIST, save its weights, then evaluate it.

    The weights file is ``<working_directory>/weights/trained_model.h5``.
    If that file already exists, it is loaded before training, so a new run
    continues from the last saved weights instead of starting from scratch.

    Args:
        working_directory: Root output directory. It is passed to
            ``cn.CapsNet`` as ``save_dir``, and the ``weights/``
            subdirectory is created under it.
        batch_size: Mini-batch size given to both ``datasets.MNIST`` and
            ``capsnet.train``.
        n_train: Number of training examples. ``n_train // batch_size`` is
            passed as the third positional argument of ``capsnet.train``
            (presumably steps per epoch). NOTE(review): 55000 presumably
            matches the training split exposed by ``datasets.MNIST``; confirm.
        epochs: Number of training epochs, forwarded to ``capsnet.train``.
    """
    weights_dir = os.path.join(working_directory, 'weights')
    weights_file = os.path.join(weights_dir, 'trained_model.h5')

    data = datasets.MNIST(batch_size=batch_size)
    # input_shape=[784] with reshape=[28, 28, 1] presumably means flattened
    # 28x28 single-channel images reshaped inside the model; confirm in capsnet.
    capsnet = cn.CapsNet(input_shape=[784], n_class=10, reshape=[28, 28, 1],
                         save_dir=working_directory)
    capsnet.train_model.summary()

    if not os.path.isdir(weights_dir):
        os.makedirs(weights_dir)
    # Resume only when the weights file itself exists. Checking just for the
    # directory crashes in load_weights when an earlier run created weights/
    # but never got as far as saving trained_model.h5.
    if os.path.isfile(weights_file):
        capsnet.load_weights(weights_file)

    model = capsnet.train(data, batch_size, n_train // batch_size,
                          epochs=epochs)
    model.save_weights(weights_file, overwrite=True)
    print('Trained model saved to \'%s\'' % weights_file)
    capsnet.test(data)


if __name__ == "__main__":
    main()