# train_detect.py
#
# Trains a small convolutional network on the CIFAR-10 dataset with Keras,
# logs per-epoch metrics to a CSV file, and saves the trained model to disk.
from __future__ import print_function
import numpy as np
from keras.callbacks import EarlyStopping
from keras.datasets import cifar10
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Flatten
from keras.layers.convolutional import Conv2D
from keras.optimizers import Adam
from keras.layers.pooling import MaxPooling2D
from keras.utils import to_categorical
from keras.callbacks import CSVLogger
# Seed NumPy's global RNG for reproducibility.
# NOTE(review): this only seeds NumPy; the Keras backend may keep separate RNG
# state, so runs are not guaranteed to be bit-identical -- confirm if needed.
np.random.seed(1000)


def scale_pixels(images):
    """Return *images* divided by 255.0.

    For 8-bit image data this maps pixel values from [0, 255] to [0.0, 1.0].
    """
    return images / 255.0


def build_model():
    """Build and compile the CNN classifier.

    The network takes 32x32x3 inputs, stacks four 3x3 convolutions with
    max-pooling and dropout, and ends in a 10-way softmax layer.

    Returns:
        The compiled ``Sequential`` model.
    """
    model = Sequential()

    model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(32, 32, 3)))
    model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))

    model.add(Conv2D(128, kernel_size=(3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(128, kernel_size=(3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))

    model.add(Flatten())
    model.add(Dense(1024, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(10, activation='softmax'))

    # NOTE(review): newer Keras releases spell `lr` as `learning_rate`; the
    # original keyword is kept to match the Keras version this file's import
    # paths suggest -- confirm before upgrading.
    model.compile(loss='categorical_crossentropy',
                  optimizer=Adam(lr=0.0001, decay=1e-6),
                  metrics=['accuracy'])
    return model


def main():
    """Train the model on CIFAR-10, print test loss/accuracy, and save it.

    Side effects: writes per-epoch metrics to ``epochs1.log`` and the trained
    model to ``cifar-10.h5`` in the current working directory.
    """
    # Imported up front rather than just before `model.save`: if h5py is
    # missing we want to fail now, not after the whole training run.
    # (Keras' HDF5 saving relies on h5py.)
    import h5py  # noqa: F401

    # Load the dataset
    (X_train, Y_train), (X_test, Y_test) = cifar10.load_data()

    # Preprocess once and reuse: the test split is needed both as validation
    # data during training and for the final evaluation.
    x_train = scale_pixels(X_train)
    y_train = to_categorical(Y_train)
    x_test = scale_pixels(X_test)
    y_test = to_categorical(Y_test)

    # Create and compile the model
    model = build_model()

    csv = CSVLogger("epochs1.log")

    # Train the model
    # NOTE(review): the test split doubles as validation data here, so the
    # reported test metrics are not from data unseen during model selection.
    model.fit(x_train, y_train,
              batch_size=128,
              shuffle=True,
              epochs=63,
              validation_data=(x_test, y_test),
              callbacks=[csv])

    # Evaluate the model
    scores = model.evaluate(x_test, y_test)

    print('Loss: %.3f' % scores[0])
    print('Accuracy: %.3f' % scores[1])

    model.save('cifar-10.h5')


if __name__ == '__main__':
    main()