-
Notifications
You must be signed in to change notification settings - Fork 2
/
test.py
51 lines (41 loc) · 1.67 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
from tqdm import tqdm
import torch
from models.models import DecompModelInference
from tools.dataloader import get_loader
from tools.visualizer import Visualizer
from tools.args import BaseOptions
from tools.helper_functions import TrainTools
# Parse command-line options (without saving them) and build the helpers
# used by the evaluation loop below.
opt = BaseOptions().parse(save=False)
visualizer = Visualizer(opt)
tt = TrainTools(opt)

eval_data = get_loader(opt, 'test')
# NOTE(review): if get_loader returns a torch DataLoader, len() counts
# batches, not images, so the "images" label below is only exact for
# batch_size == 1 — confirm. The metric averages further down divide by
# this same count, i.e. they are means over loop iterations.
dataset_size = len(eval_data)
print(f'# evaluation images = {dataset_size}')
if dataset_size == 0:
    # Without this guard the script fails later with a ZeroDivisionError
    # (averaging) or a NameError (saving the last batch's images).
    raise SystemExit('No evaluation data found for the test split; nothing to evaluate.')

# NOTE(review): not used in this script; load_model presumably locates the
# checkpoint from its own arguments/options — confirm before removing.
ckpt_path = os.path.join(opt.checkpoints_dir, opt.name)

# Build the generator, restore its weights ('G' at checkpoint `opt.which`),
# and switch the model to eval mode for inference.
my_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DecompModelInference().to(my_device)
model.load_model(model.netG, model.optimizerG, 'G', opt.which)
model.eval()
print('currently running: ', opt.name)

# Reported metric name -> key in the per-batch loss dict returned by the
# model. Insertion order fixes the order of the printed results.
_METRIC_SOURCES = {'SSIM': 'SSIM', 'PSNR': 'PSNR', 'Perc_loss': 'G_perc'}
total_loss = dict.fromkeys(_METRIC_SOURCES, 0)

with torch.no_grad():
    for data in tqdm(eval_data):
        out, loss_dict = model(*tt.process_inputs(data, mode='test'))
        for metric, source_key in _METRIC_SOURCES.items():
            total_loss[metric] += loss_dict[source_key]

# Average each accumulated metric over the number of loop iterations.
total_loss = {k: v / dataset_size for k, v in total_loss.items()}

# Report the averaged metrics rounded to 4 decimals; tensor values are
# converted to Python numbers first so the printout is plain.
errors = {
    k: round(v.item() if isinstance(v, torch.Tensor) else v, 4)
    for k, v in total_loss.items()
}
visualizer.print_test_errors(errors)
print(errors)

# Save reference / input / output images of the LAST batch only: `data`
# and `out` still hold the values from the final loop iteration.
visualizer.better_save(data['reference'], 'test', 'ref')
visualizer.better_save(data['input'], 'test', 'in')
visualizer.better_save(out['fake'], 'test', 'out')
# Optional extra outputs (disabled):
#visualizer.better_save(out['r_map'], 'test', 'r')
#visualizer.better_save(out['s_map'], 'test', 's')
#visualizer.better_save(out['new_s_map'], 'test', 's_new')