-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_fn.py
More file actions
65 lines (53 loc) · 2.54 KB
/
model_fn.py
File metadata and controls
65 lines (53 loc) · 2.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import numpy as np
import tensorflow as tf
from vae_celeb import *
from config import *
def model_fn(features, mode):
    """Build the VAE graph and return an EstimatorSpec for ``mode``.

    Args:
        features: Either a dict holding an ``'x'`` tensor and, optionally, a
            ``'z'`` tensor (``serving_input_fn`` in this file supplies both),
            or the ``x`` tensor itself.
        mode: A ``tf.estimator.ModeKeys`` value supplied by the Estimator.

    Returns:
        A ``tf.estimator.EstimatorSpec``: predictions in PREDICT mode, the
        total loss in EVAL mode, and the loss plus an Adam train op (with a
        summary-saving hook) in TRAIN mode.

    Raises:
        NotImplementedError: If ``mode`` is not PREDICT, EVAL or TRAIN.
    """
    # isinstance (not `type(...) is dict`) so dict subclasses are also
    # treated as feature dicts instead of being handed to Vae as the input.
    if isinstance(features, dict):
        x = features['x']
        # Use the same dummy latent as the bare-tensor path when the
        # input_fn does not provide one, instead of failing with KeyError.
        z = features['z'] if 'z' in features else tf.constant(0)
    else:
        x = features
        # Dummy latent; presumably unused by Vae in this case -- TODO confirm.
        z = tf.constant(0)

    # NOTE(review): Vae receives FLAGS.mode, not the `mode` argument, so e.g.
    # during export (PREDICT) the graph is built with whatever FLAGS.mode
    # holds. Confirm this is intended before changing it.
    y, z_mean, z_log_var = Vae(x, z, FLAGS.mode)

    # Reconstruction term: squared error summed over every element,
    # including the batch dimension.
    rec_loss = tf.reduce_sum(tf.squared_difference(x, y))
    # Closed-form KL divergence between N(z_mean, exp(z_log_var)) and N(0, I),
    # also summed over the batch.
    kl_loss = -0.5 * tf.reduce_sum(
        1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
    # Both terms are already scalars (reduce_sum with no axis), so no further
    # reduction is needed; beta weights the KL term.
    total_loss = rec_loss + beta * kl_loss

    # NOTE: the 'sigma' key carries z_log_var (a log-variance, not a standard
    # deviation); the key name is kept because serving clients depend on it.
    predictions = {'x': x, 'y': y, 'mu': z_mean, 'sigma': z_log_var}

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode=mode, loss=total_loss)

    if mode != tf.estimator.ModeKeys.TRAIN:
        raise NotImplementedError()

    # TRAIN: log both loss components every FLAGS.save_steps steps.
    tf.summary.scalar('rec_loss', rec_loss)
    tf.summary.scalar('kl_loss', kl_loss)
    summary_hook = tf.train.SummarySaverHook(save_steps=FLAGS.save_steps,
                                             output_dir=FLAGS.export_path,
                                             summary_op=tf.summary.merge_all())

    optimizer = tf.train.AdamOptimizer(learn_rate)
    # Run pending UPDATE_OPS (e.g. batch-norm statistics, if Vae registers
    # any) before each optimisation step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss=total_loss,
                                      global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=total_loss,
                                      train_op=train_op,
                                      training_hooks=[summary_hook])
def serving_input_fn():
    """Build the serving input receiver used when exporting the estimator.

    Two float32 placeholders are created: ``x``, a batch of inputs whose
    per-example shape is ``input_dim``, and ``z``, a batch of latent vectors
    of length ``z_dim``. The same dict is used both as the receiver tensors
    (the serving signature) and as the features passed to ``model_fn``.

    Returns:
        A ``tf.estimator.export.ServingInputReceiver``.
    """
    # list() lets input_dim be configured as a tuple as well as a list;
    # `[None] + some_tuple` would raise TypeError.
    x = tf.placeholder(dtype=tf.float32,
                       shape=[None] + list(input_dim),
                       name='x')
    z = tf.placeholder(dtype=tf.float32, shape=[None, z_dim], name='z')
    features = {'x': x, 'z': z}
    return tf.estimator.export.ServingInputReceiver(features, features)
def export_tf_model(export_path):
    """Export the trained estimator as a SavedModel.

    Args:
        export_path: Directory holding the estimator's checkpoints (used as
            the Estimator's ``model_dir``). The SavedModel is written beneath
            ``<export_path>/frozen_pb``.

    Returns:
        The path of the export directory created by
        ``Estimator.export_saved_model``.
    """
    import os

    estimator = tf.estimator.Estimator(model_fn, model_dir=export_path)
    # Honour the argument rather than the global FLAGS.export_path, so the
    # export lands next to the checkpoints that were actually loaded.
    return estimator.export_saved_model(os.path.join(export_path, 'frozen_pb'),
                                        serving_input_fn)