-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathinference.py
68 lines (53 loc) · 2.13 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import argparse
import time
import torch
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage
from utils.config import Config
from model import FasterRCNNVGG16
from trainer import FasterRCNNTrainer
from data.util import read_image
from utils.vis_tool import vis_bbox
from utils import array_tool as at
from SRGAN import Generator
# Single-image inference script.
#
# Loads an SRGAN generator checkpoint from ``epochs/<model_name>``, upscales one
# input image by ``--upscale_factor``, saves the result as
# ``out_srf_<factor>_<basename of input>``, and prints the elapsed time.
# A Faster R-CNN model/trainer is also constructed (and moved to the GPU in GPU
# mode), but detection on the super-resolved output is not wired up yet (see TODO).
parser = argparse.ArgumentParser(description='Test Single Image')
parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor')
parser.add_argument('--test_mode', default='GPU', type=str, choices=['GPU', 'CPU'], help='using GPU or CPU')
parser.add_argument('--model_name', default='netG_epoch_4_100.pth', type=str, help='generator model epoch name')
# The original script referenced IMAGE_NAME without ever defining it (NameError);
# expose it as an option whose default keeps the previously hard-coded demo image.
parser.add_argument('--image_name', default='misc/demo.jpg', type=str, help='low resolution input image path')
opt = parser.parse_args()

FasterRCNNOpt = Config()
UPSCALE_FACTOR = opt.upscale_factor
TEST_MODE = opt.test_mode == 'GPU'  # True -> run on CUDA, False -> run on CPU
MODEL_NAME = opt.model_name
IMAGE_NAME = opt.image_name

gan_model = Generator(UPSCALE_FACTOR).eval()
faster_rcnn = FasterRCNNVGG16()
trainer = FasterRCNNTrainer(faster_rcnn)
model_path = os.path.join('epochs', MODEL_NAME)
if TEST_MODE:
    gan_model.cuda()
    trainer.cuda()
    gan_model.load_state_dict(torch.load(model_path))
else:
    # Remap every stored tensor onto the CPU so a GPU-saved checkpoint loads
    # on a machine without CUDA.
    gan_model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))
# TODO: load the Faster R-CNN weights (previously sketched as
# trainer.load('epochs/samir_fast_rcnn_epoch60.pth')) and run
# trainer.faster_rcnn.predict on the super-resolved image, visualising the
# result with vis_bbox.

# Load the input with PIL so ToTensor performs the HWC -> CHW transpose and the
# [0, 255] -> [0, 1] scaling itself.
# NOTE(review): the earlier code used data.util.read_image; if that returns a
# CHW float array, ToTensor would mis-transpose it and skip the scaling.
# convert('RGB') assumes the generator expects 3-channel input -- confirm.
image = Image.open(IMAGE_NAME).convert('RGB')
image = ToTensor()(image).unsqueeze(0)  # add a batch dimension
if TEST_MODE:
    image = image.cuda()

start = time.perf_counter()  # time.clock() was removed in Python 3.8
# torch.no_grad() replaces the deprecated Variable(..., volatile=True), which has
# had no effect since PyTorch 0.4 and therefore no longer disabled autograd.
with torch.no_grad():
    out = gan_model(image)
# Clamp before ToPILImage: it multiplies by 255 and casts to uint8, so values
# outside [0, 1] would otherwise wrap around.
out_img = ToPILImage()(out[0].clamp(0, 1).cpu())
# Use only the basename so an input path with directories does not yield an
# output path inside a non-existent "out_srf_<factor>_<dir>" directory.
out_img.save('out_srf_' + str(UPSCALE_FACTOR) + '_' + os.path.basename(IMAGE_NAME))
elapsed = time.perf_counter() - start  # includes the host copy and the save
print('cost' + str(elapsed) + 's')