-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathutils.py
More file actions
98 lines (71 loc) · 2.67 KB
/
utils.py
File metadata and controls
98 lines (71 loc) · 2.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# Keep tensorflow quiet!
# TF_CPP_MIN_LOG_LEVEL has to be in the environment *before* tensorflow is
# imported: the native runtime can emit (and latch its minimum level for)
# log messages while the package is loading, so setting it afterwards is
# too late to silence them.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
# Also restrict the Python-side tensorflow logger to errors only.
tf.get_logger().setLevel('ERROR')
import cv2 as cv
import numpy as np
from keras.models import load_model
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from mlhub.utils import DataResourceNotFoundException
# Path of the trained Keras model file that get_predict_api() loads.
MODEL_FILE = 'cache/models/model.06-2.5489.hdf5'
# Image directory name. Not referenced in the visible part of this module;
# presumably consumed by the scripts that import it (TODO confirm).
IMG_PATH = 'images'
def get_predict_api():
    """Load the colorization model and build a prediction function.

    Returns:
        A ``(predict, model)`` tuple: ``predict`` colorizes one grayscale
        image, ``model`` is the loaded Keras model itself.

    Raises:
        DataResourceNotFoundException: if ``MODEL_FILE`` cannot be loaded
            (``load_model`` raised ``OSError``).
    """
    try:
        model = load_model(MODEL_FILE)
    except OSError as err:
        # Keep the original OSError attached as the cause for debugging.
        raise DataResourceNotFoundException(MODEL_FILE) from err

    def predict(gray):
        """Colorize a single-channel image and return it as 8-bit BGR.

        Args:
            gray: 2-D image array; it is scaled by 1/255 before being fed to
                the model, so 8-bit intensities are assumed (TODO confirm
                with callers). It is also used, unchanged, as the L channel
                of the result.

        Returns:
            ``uint8`` array of shape ``(rows, cols, 3)`` in BGR order, at
            the resolution of the input image.
        """
        h_in, w_in = 256, 256
        h_out, w_out = h_in // 4, w_in // 4
        epsilon = 1e-6  # keeps log() finite for zero probabilities
        T = 0.38  # temperature used to re-sharpen the predicted distribution

        img_rows, img_cols = gray.shape[:2]

        # Table of quantized (a, b) colour values, one row per output class.
        q_ab = np.load("data/pts_in_hull.npy")
        nb_q = q_ab.shape[0]

        L = gray

        # cv.resize takes dsize as (width, height); the interpolation flag
        # must be passed by keyword because the third positional parameter
        # is ``dst``, not ``interpolation``.
        gray = cv.resize(gray, (w_in, h_in), interpolation=cv.INTER_CUBIC)

        x_test = np.empty((1, h_in, w_in, 1), dtype=np.float32)
        x_test[0, :, :, 0] = gray / 255.

        # One probability row over the nb_q colour classes per output pixel.
        X_colorized = model.predict(x_test)
        X_colorized = X_colorized.reshape((h_out * w_out, nb_q))

        # Rescale log-probabilities by 1/T, then renormalize each row.
        X_colorized = np.exp(np.log(X_colorized + epsilon) / T)
        X_colorized = X_colorized / np.sum(X_colorized, 1)[:, np.newaxis]

        # Expected a and b values under the re-weighted distribution.
        q_a = q_ab[:, 0].reshape((1, nb_q))
        q_b = q_ab[:, 1].reshape((1, nb_q))
        X_a = np.sum(X_colorized * q_a, 1).reshape((h_out, w_out))
        X_b = np.sum(X_colorized * q_b, 1).reshape((h_out, w_out))

        # Bring the colour channels back to the input resolution.
        X_a = cv.resize(X_a, (img_cols, img_rows),
                        interpolation=cv.INTER_CUBIC)
        X_b = cv.resize(X_b, (img_cols, img_rows),
                        interpolation=cv.INTER_CUBIC)

        # Offset into the 8-bit range and clip so that interpolation
        # overshoot cannot wrap around in the uint8 cast below.
        X_a = np.clip(X_a + 128, 0, 255)
        X_b = np.clip(X_b + 128, 0, 255)

        out_lab = np.zeros((img_rows, img_cols, 3), dtype=np.int32)
        out_lab[:, :, 0] = L
        out_lab[:, :, 1] = X_a
        out_lab[:, :, 2] = X_b
        out_lab = out_lab.astype(np.uint8)

        return cv.cvtColor(out_lab, cv.COLOR_LAB2BGR)

    return predict, model
def _plot_image(ax, img, cmap=None, label=''):
ax.imshow(img, cmap)
ax.tick_params(axis='both',
which='both',
bottom='off',
top='off',
left='off',
right='off',
labelleft='off',
labelbottom='off')
ax.set_xlabel(label)
def plot_bw_color_comparison(bw, color):
    """Show the grayscale input and its colorized version side by side.

    Args:
        bw: grayscale image, drawn on the left with the 'gray' colormap.
        color: colorized image, drawn on the right.
    """
    # 13 grid columns: six per panel, with column 6 left empty as a gap.
    layout = gridspec.GridSpec(6, 13)
    layout.update(hspace=0.1, wspace=0.001)

    figure = plt.figure(figsize=(7, 3))

    panels = (
        (slice(0, 6), bw, 'gray', 'original image'),
        (slice(7, 13), color, None, 'colorized result'),
    )
    for columns, image, colormap, caption in panels:
        axes = figure.add_subplot(layout[:, columns])
        _plot_image(axes, image, cmap=colormap, label=caption)

    plt.show()