-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy path test.py
More file actions
executable file
·45 lines (38 loc) · 1.21 KB
/
test.py
File metadata and controls
executable file
·45 lines (38 loc) · 1.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#!/usr/bin/env python
import argparse
import os

import torch
import torch.nn as nn
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage, Normalize, Resize
from tqdm import tqdm

import utils
from model.rpnet import Net
# Command-line interface.
# --checkpoint and --test are required: without them the script would only
# fail later with an obscure error from torch.load / utils.load_all_image.
parser = argparse.ArgumentParser(description="PyTorch DeepDehazing")
parser.add_argument("--rb", type=int, default=13, help="number of residual blocks")
parser.add_argument("--checkpoint", type=str, required=True,
                    help="path to load model checkpoint")
parser.add_argument("--test", type=str, required=True,
                    help="path to load test images")
parser.add_argument("--output", type=str, default="output",
                    help="directory to write output images to (created if missing)")
opt = parser.parse_args()
print(opt)

# Build the network and load its weights onto the CPU first. map_location
# lets a checkpoint saved on any GPU index load on this machine; the model
# is moved to the GPU below.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
net = Net(opt.rb)
checkpoint = torch.load(opt.checkpoint, map_location="cpu")
net.load_state_dict(checkpoint['state_dict'])
net.eval()
# device_ids is left as None, meaning "all visible CUDA devices". This avoids
# assuming exactly four GPUs, which crashed on smaller machines. Restrict the
# GPU set with CUDA_VISIBLE_DEVICES if needed.
net = nn.DataParallel(net).cuda()
print(net)

# Make sure the destination exists before the first save.
os.makedirs(opt.output, exist_ok=True)

images = utils.load_all_image(opt.test)
# Inference only: disable autograd for the whole loop.
with torch.no_grad():
    for im_path in tqdm(images):
        filename = os.path.basename(im_path)
        print(filename)
        # The context manager closes the file handle after the pixels are read.
        # NOTE(review): images are not converted to RGB; confirm how Net
        # handles grayscale/RGBA inputs (ToTensor keeps their channel count).
        with Image.open(im_path) as img:
            w, h = img.size  # PIL reports (width, height)
            print(w, h)
            # ToTensor yields (C, H, W); unsqueeze adds the batch dim -> (1, C, H, W).
            batch = ToTensor()(img).unsqueeze(0)
        out = net(batch.cuda())
        # Clamp to the valid [0, 1] range before converting back to an image.
        out = torch.clamp(out, 0., 1.).cpu()[0]
        ToPILImage()(out).save(os.path.join(opt.output, filename))