-
Notifications
You must be signed in to change notification settings - Fork 22
/
fid_evaluation.py
90 lines (72 loc) · 3.4 KB
/
fid_evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""
Contains code for logging approximate FID scores during training.
If you want to output ground-truth images from the training dataset, you can
run this file as a script.
"""
import os
import torch
import copy
import argparse
from torchvision.utils import save_image
from pytorch_fid import fid_score
from tqdm import tqdm
import datasets
def output_real_images(dataloader, num_imgs, real_dir):
    """Write up to ``num_imgs`` images from ``dataloader`` into ``real_dir`` as JPEGs.

    Files are named with a zero-padded running index (``00000.jpg``,
    ``00001.jpg``, ...). Each image is clamped to [-1, 1] and rescaled to
    [0, 1] before saving (assumes the loader yields images in [-1, 1] —
    TODO confirm against ``datasets.get_dataset``).

    Args:
        dataloader: iterable of ``(images, labels)`` batches that also exposes
            ``.dataset`` (e.g. a ``torch.utils.data.DataLoader``).
        num_imgs: maximum number of images to write; capped at the number of
            samples in the dataset.
        real_dir: existing directory the images are written into.
    """
    # Cap at the number of *samples*; len(dataloader) would count batches.
    num_imgs = min(num_imgs, len(dataloader.dataset))
    if num_imgs <= 0:
        return
    img_counter = 0
    # Iterate the loader directly: this also writes the final partial batch,
    # and it stops cleanly if the loader runs out (e.g. drop_last=True).
    for real_imgs, _ in dataloader:
        for img in real_imgs:
            # Clamp to [-1, 1], then map to [0, 1]. This is the same
            # arithmetic as save_image(normalize=True, value_range=(-1, 1)),
            # but it avoids the range/value_range keyword, whose name differs
            # across torchvision versions.
            scaled = (img.clamp(-1, 1) + 1) / 2
            save_image(scaled, os.path.join(real_dir, f'{img_counter:0>5}.jpg'))
            img_counter += 1
            if img_counter >= num_imgs:
                return
def setup_evaluation(dataset_name, dataset_path, generated_dir, target_size=128, num_imgs=8000):
    """Prepare the directories used for FID evaluation and return the real-image directory.

    Real images are written only when their directory is missing or empty, so
    repeated calls reuse images written by an earlier call.

    Args:
        dataset_name: dataset name passed to ``datasets.get_dataset``; also part
            of the real-image directory name.
        dataset_path: dataset location forwarded to ``datasets.get_dataset``.
        generated_dir: directory for generated images; created if not ``None``.
        target_size: image size requested from the dataset; also part of the
            real-image directory name.
        num_imgs: maximum number of real images to write.

    Returns:
        ``EvalImages/<dataset_name>_real_images_<target_size>`` (relative to
        the current working directory).
    """
    real_dir = os.path.join('EvalImages', dataset_name + '_real_images_' + str(target_size))
    # Generate when the directory is missing, or exists but is empty (e.g. a
    # previous run crashed right after creating it). A partially filled
    # directory is still reused as-is.
    if not os.path.isdir(real_dir) or not os.listdir(real_dir):
        os.makedirs(real_dir, exist_ok=True)
        # The second return value (channel count) is not needed here.
        dataloader, _ = datasets.get_dataset(dataset_name, img_size=target_size, dataset_path=dataset_path)
        print('outputting real images...')
        output_real_images(dataloader, num_imgs, real_dir)
        print('...done')
    if generated_dir is not None:
        os.makedirs(generated_dir, exist_ok=True)
    return real_dir
def output_images(generator, input_metadata, rank, world_size, output_dir, num_imgs=2048):
    """Render images with ``generator`` and save them as JPEGs in ``output_dir``.

    Work is split across ``world_size`` processes: rank ``r`` writes the
    images with indices ``r, r + world_size, r + 2 * world_size, ...`` that
    are below ``num_imgs``. Together the ranks produce ``00000.jpg`` up to
    index ``num_imgs - 1`` without filename collisions.

    Args:
        generator: wrapped model exposing ``.module`` (presumably a DDP
            wrapper — TODO confirm) whose module has ``ldist``, ``z_dim``,
            ``device`` and ``staged_forward``.
        input_metadata: rendering options; deep-copied, so the caller's dict
            is never modified.
        rank: index of this process.
        world_size: total number of processes.
        output_dir: existing directory the images are written into.
        num_imgs: total number of images across all ranks.
    """
    metadata = copy.deepcopy(input_metadata)
    # Evaluation overrides: fixed resolution and batch size, eval-specific
    # h/v stddev and sample distance when the *_eval keys are present, and
    # psi = 1 (presumably disables truncation — TODO confirm).
    metadata['img_size'] = 128
    metadata['batch_size'] = 4
    metadata['h_stddev'] = metadata.get('h_stddev_eval', metadata['h_stddev'])
    metadata['v_stddev'] = metadata.get('v_stddev_eval', metadata['v_stddev'])
    metadata['sample_dist'] = metadata.get('sample_dist_eval', metadata['sample_dist'])
    metadata['psi'] = 1

    # NOTE(review): the generator is left in eval mode; the caller is
    # responsible for switching it back to train mode.
    generator.eval()
    model = generator.module
    ldist = model.ldist
    batch_size = metadata['batch_size']
    img_counter = rank
    # Only rank 0 shows a bar. It advances by world_size per image as an
    # estimate of global progress. `desc` must be a keyword argument, because
    # tqdm's first positional parameter is the iterable.
    pbar = tqdm(desc="generating images", total=num_imgs, disable=(rank != 0))
    with torch.no_grad():
        while img_counter < num_imgs:
            z = torch.randn((batch_size, model.z_dim), device=model.device)
            l_batch = ldist.sample(batch_size)
            generated_imgs = model.staged_forward(z, l_batch, **metadata)['rgb']
            for img in generated_imgs:
                # Stop mid-batch so that no file with index >= num_imgs is written.
                if img_counter >= num_imgs:
                    break
                # Clamp to [-1, 1], then map to [0, 1]: the same arithmetic as
                # save_image(normalize=True, value_range=(-1, 1)), without the
                # version-dependent range/value_range keyword.
                scaled = (img.clamp(-1, 1) + 1) / 2
                save_image(scaled, os.path.join(output_dir, f'{img_counter:0>5}.jpg'))
                img_counter += world_size
                pbar.update(world_size)
    pbar.close()
def calculate_fid(dataset_name, generated_dir, target_size=256, batch_size=128, device='cuda', dims=2048):
    """Compute FID between the real-image directory and ``generated_dir``.

    Real images are read from
    ``EvalImages/<dataset_name>_real_images_<target_size>``, the naming scheme
    used by ``setup_evaluation``. NOTE(review): this function's default
    ``target_size`` (256) differs from ``setup_evaluation``'s default (128),
    so pass it explicitly to make both point at the same directory.

    Args:
        dataset_name: dataset name used in the real-image directory name.
        generated_dir: directory containing the generated images.
        target_size: image size used in the real-image directory name.
        batch_size: batch size forwarded to pytorch_fid.
        device: device forwarded to pytorch_fid; the default keeps the
            previous hard-coded ``'cuda'``.
        dims: feature dimensionality forwarded to pytorch_fid.

    Returns:
        The FID value returned by ``fid_score.calculate_fid_given_paths``.
    """
    real_dir = os.path.join('EvalImages', dataset_name + '_real_images_' + str(target_size))
    fid = fid_score.calculate_fid_given_paths([real_dir, generated_dir], batch_size, device, dims)
    # Release cached GPU memory left over from the FID computation.
    torch.cuda.empty_cache()
    return fid
if __name__ == '__main__':
    # Script entry point: write ground-truth (real) images from the training
    # dataset so they can be reused by later FID evaluations.
    parser = argparse.ArgumentParser(
        description='Write ground-truth images from a training dataset for FID evaluation.')
    parser.add_argument('--dataset', type=str, default='CelebA',
                        help='dataset name passed to datasets.get_dataset')
    parser.add_argument('--dataset_path', type=str,
                        help='dataset location forwarded to datasets.get_dataset')
    parser.add_argument('--img_size', type=int, default=128,
                        help='image size of the saved real images')
    parser.add_argument('--num_imgs', type=int, default=8000,
                        help='maximum number of real images to write')
    opt = parser.parse_args()
    real_images_dir = setup_evaluation(opt.dataset, opt.dataset_path, None, target_size=opt.img_size, num_imgs=opt.num_imgs)
    # setup_evaluation skips generation when the real-image directory is
    # already present, so always report where the images are.
    print(f'real images: {real_images_dir}')