-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain-multiclass.py
More file actions
98 lines (82 loc) · 2.44 KB
/
train-multiclass.py
File metadata and controls
98 lines (82 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import sys
import os
from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers
from keras.models import Sequential
from keras.layers import Dropout, Flatten, Dense, Activation
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras import callbacks
DEV = False
argvs = sys.argv
argc = len(argvs)
if argc > 1 and (argvs[1] == "--development" or argvs[1] == "-d"):
DEV = True
if DEV:
epochs = 2
else:
epochs = 20
train_data_path = './data/train'
validation_data_path = './data/validation'
"""
Parameters
"""
img_width, img_height = 150, 150
batch_size = 8
samples_per_epoch = 50
validation_steps = 2
nb_filters1 = 32
nb_filters2 = 64
conv1_size = 3
conv2_size = 2
pool_size = 2
classes_num = 3
lr = 0.0004
model = Sequential()
model.add(Convolution2D(nb_filters1, conv1_size, conv1_size, border_mode ="same", input_shape=(img_width, img_height, 3)))
model.add(Activation("relu"))
model.add(MaxPooling2D(pool_size=(pool_size, pool_size)))
model.add(Convolution2D(nb_filters2, conv2_size, conv2_size, border_mode ="same"))
model.add(Activation("relu"))
model.add(MaxPooling2D(pool_size=(pool_size, pool_size), dim_ordering='th'))
model.add(Flatten())
model.add(Dense(256))
model.add(Activation("relu"))
model.add(Dropout(0.5))
model.add(Dense(classes_num, activation='softmax'))
model.compile(loss='categorical_crossentropy',
optimizer=optimizers.RMSprop(lr=lr),
metrics=['accuracy'])
train_datagen = ImageDataGenerator(
rescale=1. / 255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1. / 255)
train_generator = train_datagen.flow_from_directory(
train_data_path,
target_size=(img_height, img_width),
batch_size=batch_size,
class_mode='categorical')
validation_generator = test_datagen.flow_from_directory(
validation_data_path,
target_size=(img_height, img_width),
batch_size=batch_size,
class_mode='categorical')
"""
Tensorboard log
"""
log_dir = './tf-log/'
tb_cb = callbacks.TensorBoard(log_dir=log_dir, histogram_freq=0)
cbks = [tb_cb]
model.fit_generator(
train_generator,
samples_per_epoch=samples_per_epoch,
epochs=epochs,
validation_data=validation_generator,
callbacks=cbks,
validation_steps=validation_steps)
target_dir = './models/'
if not os.path.exists(target_dir):
os.mkdir(target_dir)
model.save('./models/model.h5')
model.save_weights('./models/weights.h5')