-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdemo_ILVR.py
127 lines (90 loc) · 3.47 KB
/
demo_ILVR.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
###############################################################################
# Code adapted and modified from
# https://github.com/jychoi118/ilvr_adm
###############################################################################
import os
import tqdm
from options.test_options import TestOptions
from data.VGGface2HQ import VGGFace2HQDataset
from utils.guided_diffusion import logger
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from models.models import create_model
from utils.plot import plot_batch
class Transform:
    """Rescale values from the ``[0, 1]`` range to ``[-1, 1]``.

    Used as the last step of the test-time ``transforms.Compose`` pipeline,
    after ``ToTensor`` has produced values in ``[0, 1]``. Works on anything
    supporting ``*`` and ``-`` (tensors, arrays, plain numbers).

    The class is stateless, so it relies on the default constructor. (A
    misspelled ``__int__`` override previously turned ``int(Transform())``
    into an ``AttributeError`` instead of a normal ``TypeError``.)
    """

    def __call__(self, x):
        # Affine map [0, 1] -> [-1, 1]; inverse of DeTransform.
        return x * 2 - 1
class DeTransform:
    """Rescale values from the ``[-1, 1]`` range back to ``[0, 1]``.

    Inverse of ``Transform``; used to bring model inputs/outputs back into a
    displayable range before plotting. Works on anything supporting ``+``
    and ``/`` (tensors, arrays, plain numbers).

    The class is stateless, so it relies on the default constructor (a
    misspelled ``__int__`` override was removed).
    """

    def __call__(self, x):
        # Affine map [-1, 1] -> [0, 1]; inverse of Transform.
        return (x + 1) / 2
def _build_sample_grid(img_source, model, detransform):
    """Assemble one batch of face-swap results into a square tile grid.

    With ``n = img_source.shape[0]`` the grid has ``(n + 1) ** 2`` tiles laid
    out row-major:

    * row 0: a blank (all-zero) tile, followed by every source image;
    * row ``i + 1``: source image ``i``, followed by
      ``model.swap(img_source, infer)`` where ``infer`` is image ``i``
      repeated ``n`` times.

    Args:
        img_source: 4-D batch in channel-first layout (``ToTensor`` output,
            batched), with values in the model range ``[-1, 1]``.
        model: object exposing ``swap(source_batch, infer_batch)`` that returns
            a batch of at least ``n`` images shaped like ``img_source``.
        detransform: callable mapping model-range tensors back to ``[0, 1]``.

    Returns:
        numpy array of tiles in channel-last layout, ready for ``plot_batch``.
    """
    n = img_source.shape[0]
    source_imgs = detransform(img_source.cpu()).numpy()

    # Top-left corner stays blank so the header row of sources lines up with
    # the per-row source image shown in column 0 below it.
    tiles = [torch.zeros_like(img_source[0]).cpu().numpy()]
    tiles.extend(source_imgs)

    for i in range(n):
        tiles.append(source_imgs[i])
        image_infer = img_source[i].repeat(n, 1, 1, 1)
        img_fake = detransform(model.swap(img_source, image_infer).cpu()).numpy()
        tiles.extend(img_fake[:n])

    # (N, C, H, W) -> (N, H, W, C) for plotting.
    return np.stack(tiles, axis=0).transpose(0, 2, 3, 1)


if __name__ == '__main__':
    opt = TestOptions().parse()
    opt.model = 'ILVR'

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print("Initiating model...")
    model = create_model(opt)
    print("Model initiated.")

    # ToTensor yields values in [0, 1]; Transform rescales them to [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((opt.image_size, opt.image_size)),
        Transform(),
    ])
    detransform = DeTransform()

    print("Generating data loaders...")
    test_data = VGGFace2HQDataset(opt, isTrain=False, transform=transform, is_same_ID=True, auto_same_ID=False)
    test_loader = DataLoader(dataset=test_data, batch_size=opt.batchSize, shuffle=True, num_workers=opt.nThreads,
                             worker_init_fn=test_data.set_worker)
    print("Dataloaders ready.")

    # Create the output folders once, up front; makedirs also creates any
    # missing parent (e.g. opt.output_path itself) and tolerates existing dirs.
    output_pth = os.path.join(opt.output_path, opt.name)
    sample_path = os.path.join(output_pth, 'samples')
    os.makedirs(sample_path, exist_ok=True)

    print("creating samples...")
    count = 0
    model.eval()
    for (img_source, _), _, _ in tqdm.tqdm(test_loader):
        if count >= opt.ntest:
            break
        count += img_source.shape[0]

        # Display at most 8 images per grid. Size from the actual batch: the
        # loader keeps a short final batch (drop_last is not set), and using
        # opt.batchSize here would index past its end.
        sample_size = min(8, img_source.shape[0])
        img_source = img_source[:sample_size].to(device)

        with torch.no_grad():
            imgs = _build_sample_grid(img_source, model, detransform)
        plot_batch(imgs, os.path.join(sample_path, 'sample_' + str(count) + '.jpg'))
        print(f"created {count} samples")
    print("sampling complete")