-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcalculate_lpips.py
executable file
·71 lines (56 loc) · 2.44 KB
/
calculate_lpips.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import argparse
import cv2
import glob
import numpy as np
import os.path as osp
from torchvision.transforms.functional import normalize
# from tqdm import tqdm
from pathlib import Path
from basicsr.utils import img2tensor
# https://download.pytorch.org/models/vgg16-397923af.pth
try:
import lpips
except ImportError:
print('Please install lpips: pip install lpips')
def _list_files(folder):
    """Return the sorted paths of the regular files directly inside ``folder``."""
    # Sub-directories cannot be decoded as images, so drop them before the two
    # folders are paired up by position.
    return sorted(path for path in glob.glob(osp.join(folder, '*')) if osp.isfile(path))


def _read_image(path):
    """Load ``path`` as a 3-channel BGR float32 array with values divided by 255.

    Raises:
        OSError: if OpenCV cannot read or decode the file.
    """
    # IMREAD_COLOR always yields 8-bit, 3-channel BGR data. IMREAD_UNCHANGED
    # would pass grayscale, alpha and 16-bit images straight through, which
    # neither the 3-channel normalisation nor the /255 scaling below expects.
    img = cv2.imread(path, cv2.IMREAD_COLOR)
    if img is None:  # cv2.imread signals failure by returning None, not by raising
        raise OSError(f'Could not read image: {path}')
    return img.astype(np.float32) / 255.


def calculate_lpips(args):
    """Compute LPIPS (VGG backbone) between restored and ground-truth images.

    Files in ``args.restored_folder`` and ``args.gt_folder`` are paired by
    their position in sorted order, not by filename. If the folders hold
    different numbers of files, a warning is printed and only the first
    ``min(len(gt), len(restored))`` pairs are evaluated.

    The average and per-file scores are appended to
    ``<args.out_path>/lpips.txt``; the directory is created if needed.

    Args:
        args: namespace with ``restored_folder``, ``gt_folder`` and
            ``out_path`` string attributes.

    Returns:
        float: the average LPIPS over all evaluated pairs.

    Raises:
        ValueError: if there is no image pair to evaluate.
        OSError: if an image file cannot be read.
    """
    img_list = _list_files(args.gt_folder)
    restored_list = _list_files(args.restored_folder)
    num_pairs = min(len(img_list), len(restored_list))
    if len(img_list) != len(restored_list):
        print(f'Warning: found {len(restored_list)} restored and {len(img_list)} '
              f'ground-truth files; only the first {num_pairs} pairs '
              f'(in sorted order) are evaluated.')
    if num_pairs == 0:
        raise ValueError(f'No image pairs to evaluate in {args.restored_folder!r} '
                         f'and {args.gt_folder!r}.')
    # Truncate explicitly so the report below lists exactly the evaluated pairs.
    img_list = img_list[:num_pairs]
    restored_list = restored_list[:num_pairs]

    import torch  # local import: only needed once there is real work to do

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)  # RGB, normalized to [-1,1]
    # mean/std of 0.5 maps the [0, 1] tensors from img2tensor onto [-1, 1].
    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]

    lpips_all = []
    # Inference only: no_grad avoids building an autograd graph for every pair.
    with torch.no_grad():
        for restored_path, img_path in zip(restored_list, img_list):
            img_gt, img_restored = img2tensor(
                [_read_image(img_path), _read_image(restored_path)], bgr2rgb=True, float32=True)
            normalize(img_gt, mean, std, inplace=True)
            normalize(img_restored, mean, std, inplace=True)
            lpips_val = loss_fn_vgg(img_restored.unsqueeze(0).to(device),
                                    img_gt.unsqueeze(0).to(device))
            lpips_all.append(lpips_val.item())

    lpips_avg = sum(lpips_all) / len(lpips_all)

    out_dir = Path(args.out_path)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Append mode: repeated runs accumulate reports in the same file.
    with open(out_dir / 'lpips.txt', 'a', encoding='utf-8') as f:
        # NOTE(review): the header reads 'Identity Metric' although this file
        # reports LPIPS; kept verbatim in case downstream parsers rely on it.
        f.write('Identity Metric\n')
        f.write(f'Average LPIPS {lpips_avg}\n')
        f.write('filename | LPIPS\n')
        for restored_path, lpips_val in zip(restored_list, lpips_all):
            f.write(f'{Path(restored_path).stem} | {lpips_val}\n')
    return lpips_avg
def build_parser():
    """Build the command-line parser for the LPIPS evaluation script.

    The original single-dash spellings (``-restored_folder``, ``-gt_folder``)
    are kept for backward compatibility; ``--`` aliases are accepted as well.
    """
    parser = argparse.ArgumentParser(
        description='Compute LPIPS between restored images and their ground truth.')
    parser.add_argument('-restored_folder', '--restored_folder', dest='restored_folder',
                        type=str, required=True,
                        help='Folder containing the restored images.')
    parser.add_argument('-gt_folder', '--gt_folder', dest='gt_folder',
                        type=str, required=True,
                        help='Folder containing the ground-truth images.')
    parser.add_argument(
        '--out_path',
        type=str,
        default='metrics',
        help='Output directory; results are appended to <out_path>/lpips.txt.',
    )
    return parser


def main(argv=None):
    """Parse ``argv`` (defaults to ``sys.argv[1:]``) and run the evaluation."""
    args = build_parser().parse_args(argv)
    calculate_lpips(args)


if __name__ == '__main__':
    main()