forked from dabasajay/Image-Caption-Generator
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path train_val.py
More file actions
40 lines (33 loc) · 1.44 KB
/
train_val.py
File metadata and controls
40 lines (33 loc) · 1.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from pickle import load
from utils.model import *
from utils.load_data import *
# Load the training and validation splits listed in the Flickr8k index files.
#   X1train / X1val : image features
#   X2train / X2val : text (caption) features
# loadTrainData additionally returns max_length, which is used below to size
# the RNN model and the data generators.
train_images_file = 'train_val_data/Flickr_8k.trainImages.txt'
val_images_file = 'train_val_data/Flickr_8k.devImages.txt'

X1train, X2train, max_length = loadTrainData(
    path=train_images_file,
    preprocessDataReady=False,
)
X1val, X2val = loadValData(path=val_images_file)
# Load the pickled tokenizer and derive the vocabulary size from it.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load a tokenizer.pkl produced by this project, never an untrusted one.
tokenizer_path = 'model_data/tokenizer.pkl'
# Use a context manager so the file handle is closed after loading
# (the previous `load(open(...))` leaked the handle).
with open(tokenizer_path, 'rb') as tokenizer_file:
    tokenizer = load(tokenizer_file)
# +1 on top of the number of known words; presumably because word indices
# start at 1 and index 0 is reserved (e.g. for padding), as with a Keras
# Tokenizer — TODO confirm against how tokenizer.pkl is built.
vocab_size = len(tokenizer.word_index) + 1
# Report the maximum sequence length returned by loadTrainData
# (the original author observed 34 for the Flickr8k training split).
print('Max Length : ', max_length)

# Image features were already extracted by the CNN, so only the
# RNN (caption-generating) part of the model has to be built here.
model = defineRNNmodel(vocab_size, max_length)
# Train the model one epoch at a time so that a checkpoint is written
# after every epoch instead of only at the end of training.
epochs = 20
# One generator step per entry of the text-feature collections.
# NOTE(review): assumes data_generator yields one batch per entry of
# X2train / X2val — confirm in utils.load_data.
steps_train = len(X2train)
steps_val = len(X2val)
for i in range(epochs):
    # Fresh train/validation generators for this epoch.
    generator_train = data_generator(X1train, X2train, tokenizer, max_length)
    generator_val = data_generator(X1val, X2val, tokenizer, max_length)
    # Fit for exactly one epoch.
    # NOTE(review): fit_generator is deprecated in TensorFlow >= 2.1, where
    # Model.fit accepts generators directly; it is kept here because the
    # Keras version this script targets is not visible from this file.
    model.fit_generator(
        generator_train,
        epochs=1,
        steps_per_epoch=steps_train,
        verbose=1,
        validation_data=generator_val,
        validation_steps=steps_val,
    )
    # Checkpoint after every epoch: model_data/model_0.h5, model_1.h5, ...
    model.save('model_data/model_' + str(i) + '.h5')
# Evaluate the model on validation data and output the BLEU score
# evaluate_model(model, X1val, X2val, tokenizer, max_length)