forked from mankasto/iLRG
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cal_metrics.py
61 lines (52 loc) · 2.12 KB
/
cal_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import cv2
import lpips
import torch
import numpy as np
def get_images(images_list, img_size=28):
    """Load image files into a single float tensor batch.

    Every path is read with ``cv2.imread``, resized to a square of
    ``img_size`` x ``img_size`` pixels, converted to float32 and divided
    by 255, then rearranged from (height, width, channel) to
    (channel, height, width). The per-image tensors are stacked along a
    new leading axis and moved to ``device``.

    Note: this function reads the module-level name ``device``; it must
    be defined before the first call (the ``__main__`` section of this
    file sets it).

    Args:
        images_list: Iterable of image file paths.
        img_size: Side length, in pixels, that each image is resized to.

    Returns:
        A tensor of shape (number of images, channels, img_size, img_size)
        located on ``device``.

    Raises:
        OSError: If a path cannot be read as an image.
    """
    images = []
    for image_path in images_list:
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None
            # instead of raising; fail here with the offending path rather
            # than with an opaque assertion inside cv2.resize.
            raise OSError(f'Could not read image file: {image_path}')
        image = cv2.resize(image, (img_size, img_size))
        image = image.astype(np.float32) / 255
        # NOTE(review): the channel order is left exactly as cv2.imread
        # delivers it (OpenCV documents BGR for colour images) - confirm
        # whether downstream consumers expect RGB.
        image = torch.from_numpy(image).permute(2, 0, 1)
        images.append(image)
    return torch.stack(images).to(device)
def psnr(img_batch, ref_batch, batched=False, factor=1.0):
    """Peak signal-to-noise ratio between ``img_batch`` and ``ref_batch``.

    With ``batched=True`` a single value is computed from the mean squared
    error taken over the whole batch. Otherwise ``img_batch`` must be
    4-dimensional; one value is computed per sample along the first axis
    and the per-sample values are averaged.

    ``factor`` is the peak value: the score is
    ``10 * log10(factor ** 2 / mse)``. A zero error yields ``inf`` and a
    non-finite error yields ``nan``. The result is returned as a Python
    float.
    """
    def _score(candidate, reference):
        err = ((candidate - reference) ** 2).mean()
        if not torch.isfinite(err):
            return img_batch.new_tensor(float('nan'))
        if err > 0:
            return 10 * torch.log10(factor ** 2 / err)
        # Finite and not positive: the two inputs are identical.
        return img_batch.new_tensor(float('inf'))

    if batched:
        return _score(img_batch.detach(), ref_batch).item()

    # Unpacking four values enforces the 4-dimensional input requirement.
    num_samples, _, _, _ = img_batch.shape
    detached = img_batch.detach()
    per_sample = [
        _score(detached[idx, :, :, :], ref_batch[idx, :, :, :])
        for idx in range(num_samples)
    ]
    return torch.stack(per_sample, dim=0).mean().item()
if __name__ == '__main__':
    # Script entry point: compares two sets of reconstructed images
    # ("ig" and "ours") against ground-truth images using LPIPS and PSNR,
    # then prints both scores for each set.

    # `device` must remain a module-level name because get_images() reads it.
    # Fall back to the CPU so the script still runs without a CUDA device.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    dataset = 'mnist'
    img_size = 28 if dataset == 'mnist' else 32
    lpips_loss = lpips.LPIPS(net='vgg', spatial=False).to(device)

    # os.listdir returns entries in arbitrary order. The three folders are
    # compared index by index, so sort each listing to get a deterministic
    # pairing (assumes corresponding images sort to the same position in
    # every folder - verify the file naming).
    x = [os.path.join(f'images/{dataset}_gt', _) for _ in sorted(os.listdir(f'images/{dataset}_gt'))]
    y = [os.path.join(f'images/{dataset}_ig', _) for _ in sorted(os.listdir(f'images/{dataset}_ig'))]
    z = [os.path.join(f'images/{dataset}_ours', _) for _ in sorted(os.listdir(f'images/{dataset}_ours'))]
    x = get_images(x, img_size)
    y = get_images(y, img_size)
    z = get_images(z, img_size)

    # A pairwise comparison is only meaningful with one reconstruction per
    # ground-truth image; fail early with a clear message otherwise.
    if not (x.shape[0] == y.shape[0] == z.shape[0]):
        raise ValueError(
            'Image counts differ: gt=%d, ig=%d, ours=%d'
            % (x.shape[0], y.shape[0], z.shape[0])
        )

    with torch.no_grad():
        # get_images() divides pixel values by 255, while LPIPS expects
        # inputs in [-1, 1] unless normalize=True is passed.
        lpips_score = lpips_loss(y, x, normalize=True).squeeze().mean()
        lpips_score2 = lpips_loss(z, x, normalize=True).squeeze().mean()
    psnr1 = psnr(y, x)
    psnr2 = psnr(z, x)
    print('LPIPS: IG-%.2f; Ours-%.2f' % (lpips_score.item(), lpips_score2.item()))
    print('PSNR: IG-%.2f; Ours-%.2f' % (psnr1, psnr2))