-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate.py
More file actions
123 lines (105 loc) · 4.3 KB
/
generate.py
File metadata and controls
123 lines (105 loc) · 4.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from __future__ import print_function
import sys
sys.path.append('.')
from utils import get_config
from trainer import IPES_Trainer, to_gray
import argparse
from torch.autograd import Variable
import sys
import torch
import os
import numpy as np
from torchvision import datasets, transforms
from PIL import Image
import pdb
# Name of the trained model whose config and checkpoints are loaded below.
name = 'best'
if not os.path.isdir('./output_new1/outputs/%s' % name):
    # sys.exit instead of `assert 0`: assertions are stripped under `python -O`,
    # which would silently skip this sanity check and fail later with a
    # confusing file-not-found error.
    sys.exit("please change the name to your model name")
# Command-line interface for the generation script.
parser = argparse.ArgumentParser()
parser.add_argument('--output_folder', type=str, default="./", help="output image path")
parser.add_argument('--input_folder', type=str, default="./data/dataset_zheng/", help="input image path")
parser.add_argument('--config', type=str, default='./output_new1/outputs/%s/config.yaml'%name, help="net configuration")
parser.add_argument('--checkpoint_gen', type=str, default="./output_new1/outputs/%s/checkpoints/gen_00100000.pt"%name, help="checkpoint of autoencoders")
parser.add_argument('--checkpoint_id', type=str, default="./output_new1/outputs/%s/checkpoints/id_00100000.pt"%name, help="checkpoint of autoencoders")
parser.add_argument('--batchsize', default=1, type=int, help='batchsize')
parser.add_argument('--a2b', type=int, default=1, help="1 for a2b and others for b2a")
parser.add_argument('--seed', type=int, default=10, help="random seed")
parser.add_argument('--synchronized', action='store_true', help="whether use synchronized style code or not")
# NOTE(review): original help text here was copy-pasted from --synchronized;
# replacement inferred from the flag name — confirm against the code that reads it.
parser.add_argument('--output_only', action='store_true', help="whether to save only the generated output images")
parser.add_argument('--trainer', type=str, default='IPES', help="IPES")
opts = parser.parse_args()
# Seed both the CPU and CUDA RNGs so style sampling is reproducible.
torch.manual_seed(opts.seed)
torch.cuda.manual_seed(opts.seed)
# exist_ok=True replaces the original check-then-create (os.path.exists +
# makedirs), which is racy and more verbose for the same effect.
os.makedirs(opts.output_folder, exist_ok=True)
# Load experiment setting
config = get_config(opts.config)
# Only a single style code is generated per input image.
opts.num_style = 1
# Setup model and data loader.
if opts.trainer == 'IPES':
    trainer = IPES_Trainer(config)
else:
    sys.exit("Only support IPES")
# Load generator weights; strict=False tolerates checkpoint keys that do not
# match the current model definition (e.g. removed/renamed submodules).
state_dict_gen = torch.load(opts.checkpoint_gen)
trainer.gen_a.load_state_dict(state_dict_gen['a'], strict=False)
# At test time both domains share a single generator and id encoder.
trainer.gen_b = trainer.gen_a
state_dict_id = torch.load(opts.checkpoint_id)
trainer.id_a.load_state_dict(state_dict_id['a'])
trainer.id_b = trainer.id_a
trainer.cuda()
trainer.eval()
# Convenience aliases into the trainer's submodules.
encode = trainer.gen_a.encode # encode function
style_encode = trainer.gen_a.encode # encode function
id_encode = trainer.id_a # encode function
decode = trainer.gen_a.decode # decode function
# 256x128 inputs with ImageNet normalization; interpolation=3 is PIL's
# BICUBIC mode (integer form accepted by older torchvision versions).
data_transforms = transforms.Compose([
    transforms.Resize((256,128), interpolation=3),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image_datasets = datasets.ImageFolder(opts.input_folder, data_transforms)
# Two loaders over the same folder: content images one at a time, and a batch
# of 16 structure/background images (only the first batch is used below).
dataloader_content = torch.utils.data.DataLoader(image_datasets, batch_size=1, shuffle=False, num_workers=1)
dataloader_structure = torch.utils.data.DataLoader(image_datasets, batch_size=16, shuffle=False, num_workers=1)
image_paths = image_datasets.imgs
######################################################################
# recover image
# -----------------
def recover(inp):
"""Imshow for Tensor."""
inp = inp.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = inp * 255.0
inp = np.clip(inp, 0, 255)
return inp
# Generate images: decode the structure codes of one fixed batch of
# (grayscaled) background images with the identity feature of each content
# image, saving one output per (background, content) pair.
save_path = './data/gendata/'
if not os.path.isdir(save_path):
    os.mkdir(save_path)
im = {}  # NOTE(review): populated below but never read — presumably debug leftovers.
count = 0
# Take only the first batch (16 images) of structure/background inputs.
data2 = next(iter(dataloader_structure))
bg_img, _ = data2
gray = to_gray(False)
bg_img = gray(bg_img)
bg_img = Variable(bg_img.cuda())
with torch.no_grad():
    for data in dataloader_content:
        id_img, _ = data
        id_img = Variable(id_img.cuda())
        # NOTE(review): c, h, w are unused; n is immediately shadowed below.
        n, c, h, w = id_img.size()
        # Start testing
        # NOTE(review): bg_img is loop-invariant, so this encode is recomputed
        # identically on every iteration — hoisting it would be a speedup.
        s = encode(bg_img)
        f, _ = id_encode(id_img)
        input1 = recover(data[0].squeeze())
        im[count] = input1
        # One decoded output per structure code in the background batch.
        for i in range(s.size(0)):
            s_tmp = s[i,:,:,:]
            outputs = decode(s_tmp.unsqueeze(0), f)
            tmp = recover(outputs[0].data.cpu())
            pic = Image.fromarray(tmp.astype('uint8'))
            # NOTE(review): n, n1, n2 are computed but unused — the saved
            # filename is built from indices only; '/' split is also not
            # Windows-safe if this is ever used.
            n = image_paths[i][0]
            n1 = n.split('/')[-1]
            n2 = n1.split('_')[0]
            pic.save('%s/rainbow_%d_%d.jpg'%(save_path,i,count))
        count +=1