-
Notifications
You must be signed in to change notification settings - Fork 1
/
infer.py
102 lines (81 loc) · 3.54 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import numpy as np
import torch
import os
import argparse
import time
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from PIL import Image
import utils
from network import Generator, Discriminator
from dataset import ImgPairDataset
# Module-wide settings used when building the test-time DataLoader and transforms.
BATCH_SIZE: int = 1    # image pairs per DataLoader batch
IMAGE_SIZE: int = 256  # side length, in pixels, of the square crop fed to the generators
def infer(args):
    """Run both pretrained CycleGAN generators over the test split and save images.

    For every batch from ``ImgPairDataset(args.dataroot, ..., 'test')`` this runs
    ``g_AtoB`` on ``batch['A']`` and ``g_BtoA`` on ``batch['B']``. It then writes
    four PNGs named ``%03d.png`` (by batch index) into sub-folders of
    ``args.dataroot``:

    * ``testA_inStyleOfB`` -- output of the A->B generator
    * ``testA_before``     -- the transformed A input
    * ``testB_inStyleOfA`` -- output of the B->A generator
    * ``testB_before``     -- the transformed B input

    After the loop, it prints the mean wall-clock time of the A->B forward pass.

    Args:
        args: parsed command-line namespace providing ``dataroot``, ``gpu``
            (CUDA device index, or ``None`` to run on the CPU), ``modelAtoB``
            and ``modelBtoA`` (paths handed to ``torch.load``, whose result is
            loaded with ``load_state_dict``).
    """
    # Choose the tensor type / device. ``dtype`` must be bound on both paths;
    # previously it only existed when --gpu was given, so CPU runs crashed.
    if args.gpu is not None:
        dtype = torch.cuda.FloatTensor
        torch.cuda.set_device(args.gpu)
        print("Current device: %s" % torch.cuda.get_device_name(args.gpu))
    else:
        dtype = torch.FloatTensor
        print("Current device: CPU")

    # Define networks and load pretrained parameters. map_location='cpu' lets a
    # checkpoint saved from a GPU load on a CPU-only machine; load_state_dict
    # then copies the weights onto whatever device the model already lives on.
    g_AtoB = Generator().type(dtype)
    g_BtoA = Generator().type(dtype)
    g_AtoB.load_state_dict(torch.load(args.modelAtoB, map_location='cpu'))
    g_BtoA.load_state_dict(torch.load(args.modelBtoA, map_location='cpu'))
    g_AtoB.eval()
    g_BtoA.eval()

    # Test-split preprocessing.
    dataset_transform = transforms.Compose([
        transforms.Resize(int(IMAGE_SIZE * 1.2), Image.BICUBIC),  # scale shortest side to 1.2 * IMAGE_SIZE
        # NOTE(review): a *random* crop (together with shuffle=True below) makes
        # results differ between runs; CenterCrop / shuffle=False would make
        # inference reproducible. Kept as-is to preserve existing behavior.
        transforms.RandomCrop((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),  # PIL image [0, 255] -> float tensor in [0, 1]
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
    ])
    dataloader = DataLoader(ImgPairDataset(args.dataroot, dataset_transform, 'test'),
                            batch_size=BATCH_SIZE,
                            shuffle=True)

    # Output folders (created if missing).
    A_inStyleOfB_folder = os.path.join(args.dataroot, 'testA_inStyleOfB')
    A_folder = os.path.join(args.dataroot, 'testA_before')
    B_inStyleOfA_folder = os.path.join(args.dataroot, 'testB_inStyleOfA')
    B_folder = os.path.join(args.dataroot, 'testB_before')
    for folder in (A_inStyleOfB_folder, A_folder, B_inStyleOfA_folder, B_folder):
        os.makedirs(folder, exist_ok=True)

    time_array = []  # per-batch A->B forward-pass durations, in seconds
    # Pure inference: no autograd graph is needed, which saves memory and time.
    with torch.no_grad():
        for idx, batch in enumerate(dataloader):
            real_A = batch['A'].type(dtype)
            real_B = batch['B'].type(dtype)

            # Only the A->B pass is timed. The .cpu() copy waits for the
            # generator's result, so the interval covers the whole forward pass.
            start = time.perf_counter()
            A_inStyleOfB = g_AtoB(real_A).cpu()
            end = time.perf_counter()
            B_inStyleOfA = g_BtoA(real_B).cpu()
            time_array.append(end - start)

            # [0] takes the first (with BATCH_SIZE == 1, only) image of the batch.
            img_name = '%03d.png' % idx
            utils.save_image(os.path.join(A_inStyleOfB_folder, img_name), A_inStyleOfB[0])
            utils.save_image(os.path.join(B_inStyleOfA_folder, img_name), B_inStyleOfA[0])
            utils.save_image(os.path.join(A_folder, img_name), real_A.cpu()[0])
            utils.save_image(os.path.join(B_folder, img_name), real_B.cpu()[0])

    if time_array:
        print("Average A->B inference time: %.4f s over %d batches"
              % (sum(time_array) / len(time_array), len(time_array)))
def main(argv=None):
    """Parse command-line options and run ``infer`` with them.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which matches the original behavior.
    """
    parser = argparse.ArgumentParser(description='Apply CycleGAN with trained models onto a folder of images')
    parser.add_argument("--dataroot", type=str, required=True,
                        help="path to dataset in defined file hierarchy")
    parser.add_argument("--gpu", type=int, default=None,
                        help="ID of GPU to be used")
    # Both model paths are passed directly to torch.load, so they name a
    # saved state_dict *file*, not a folder.
    parser.add_argument("--modelAtoB", type=str, required=True,
                        help="path to saved state_dict file for the A to B generator")
    parser.add_argument("--modelBtoA", type=str, required=True,
                        help="path to saved state_dict file for the B to A generator")
    args = parser.parse_args(argv)
    infer(args)


if __name__ == '__main__':
    main()