-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
96 lines (71 loc) · 2.3 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
NILUT: Conditional Neural Implicit 3D Lookup Tables for Image Enhancement
Utilities for training and plotting.
"""
import torch
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import gc
import time
from skimage import io, color
# Timing utilities
def start_timer():
    """Reset memory bookkeeping and record the wall-clock start time.

    Runs the garbage collector and, when a CUDA device is available, empties
    the CUDA caching allocator, resets the peak-memory statistics and waits
    for pending GPU work so it is not attributed to the timed region.

    The start time is stored in the module-level ``start_time`` variable,
    which ``end_timer_and_print`` reads.
    """
    global start_time
    gc.collect()
    # The CUDA calls below raise on CPU-only machines, so only issue them
    # when a device is actually present.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
    start_time = time.time()
def end_timer_and_print(local_msg):
    """Print the elapsed time since ``start_timer`` and the peak tensor memory.

    Args:
        local_msg: Text printed (after a blank line) before the statistics.

    ``start_timer`` must have been called beforehand: the elapsed time is
    computed against the module-level ``start_time`` it sets. On machines
    without CUDA the reported peak memory is 0 bytes.
    """
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        # Wait for queued GPU work so it is included in the measurement.
        torch.cuda.synchronize()
    end_time = time.time()
    peak_bytes = torch.cuda.max_memory_allocated() if cuda_ready else 0
    print("\n" + local_msg)
    print("Total execution time = {:.3f} sec".format(end_time - start_time))
    print("Max memory used by tensors = {} bytes".format(peak_bytes))
def clean_mem():
    """Run the garbage collector and release cached CUDA memory.

    When a CUDA device is available, also empties the CUDA caching allocator
    and resets the peak-memory statistics. On CPU-only machines only the
    Python garbage collector runs.
    """
    gc.collect()
    # The CUDA statistics call raises when no device is available.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
# Model
def count_parameters(model):
    """Return the number of scalar elements in the trainable parameters.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Load/save and plot images
def load_img(filename, norm=True):
    """Load an image file into a NumPy array.

    Args:
        filename: Path (or file object) handed to ``PIL.Image.open``.
        norm: When true, divide the pixel values by 255 and cast the result
            to ``float32`` (this maps 8-bit data to [0, 1]). When false, the
            array is returned as decoded.

    Returns:
        The image as a NumPy array.
    """
    # Use a context manager so the underlying file handle is always released.
    with Image.open(filename) as handle:
        img = np.array(handle)
    if norm:
        img = img / 255.
        img = img.astype(np.float32)
    return img
def save_rgb(img, filename):
    """Write an RGB image to disk with OpenCV.

    Args:
        img: Array-like image with channels in RGB order. If its maximum is
            at most 1 it is treated as being in the [0, 1] range and
            multiplied by 255 before saving.
        filename: Destination path; the extension selects the encoder.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the file was not written.
    """
    img = np.asarray(img)
    if np.max(img) <= 1:
        img = img * 255
    img = img.astype(np.float32)
    # OpenCV expects BGR channel order.
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    # imwrite reports failure through its boolean return value; surface it
    # instead of silently dropping the image.
    if not cv2.imwrite(filename, bgr):
        raise OSError("Could not write image to {}".format(filename))
def plot_all(images, figsize=(20, 10), axis='off', title=None):
    """Show a sequence of images side by side in a single row.

    Args:
        images: Sequence of images accepted by ``Axes.imshow``.
        figsize: Figure size passed to ``plt.subplots``.
        axis: Argument forwarded to ``Axes.axis`` for every subplot.
        title: Optional figure title; nothing is drawn when it is ``None``.

    Does nothing when ``images`` is empty.
    """
    nplots = len(images)
    if nplots == 0:
        return
    # squeeze=False always yields a 2-D array of axes, so a single image
    # works the same way as several.
    fig, axs = plt.subplots(
        1, nplots, figsize=figsize, dpi=80, constrained_layout=True,
        squeeze=False,
    )
    for ax, image in zip(axs[0], images):
        ax.imshow(image)
        ax.axis(axis)
    if title is not None:
        fig.suptitle(title)
    plt.show()
# Metrics
def np_psnr(y_true, y_pred, max_val=1.0):
    """Peak signal-to-noise ratio (in dB) between two NumPy arrays.

    Args:
        y_true: Reference array (array-like).
        y_pred: Array to compare against the reference.
        max_val: Peak signal value; the default of 1.0 matches images in
            the [0, 1] range.

    Returns:
        ``20 * log10(max_val / sqrt(mse))``, or ``np.inf`` when the inputs
        are identical.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Subtract in floating point (at least float32): unsigned integer arrays
    # would otherwise wrap around and inflate the error.
    work_dtype = np.result_type(y_true.dtype, y_pred.dtype, np.float32)
    diff = y_true.astype(work_dtype, copy=False) - y_pred.astype(work_dtype, copy=False)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return np.inf
    return 20 * np.log10(max_val / np.sqrt(mse))
def pt_psnr(y_true, y_pred, max_val=1.0):
    """Peak signal-to-noise ratio (in dB) between two tensors.

    Args:
        y_true: Reference tensor.
        y_pred: Tensor to compare against the reference.
        max_val: Peak signal value; the default of 1.0 matches images in
            the [0, 1] range.

    Returns:
        A scalar tensor holding ``20 * log10(max_val / sqrt(mse))``. It is
        ``inf`` when the inputs are identical.
    """
    mse = torch.mean((y_true - y_pred) ** 2)
    return 20 * torch.log10(max_val / torch.sqrt(mse))
def deltae_dist(y_true, y_pred):
    """
    Calculate the mean DeltaE distance in the LAB color space.

    Images must be numpy arrays with the color channels on the last axis and
    values expected in [0, 1]. Values outside that range are clipped, and
    both images are quantized to 8 bits before the RGB -> LAB conversion.
    The result is the Euclidean distance in LAB, averaged over all pixels.
    """
    # Clip first: casting out-of-range floats to uint8 would wrap around.
    gt_u8 = (np.clip(y_true, 0, 1) * 255).astype('uint8')
    out_u8 = (np.clip(y_pred, 0, 1) * 255).astype('uint8')
    gt_lab = color.rgb2lab(gt_u8)
    out_lab = color.rgb2lab(out_u8)
    per_pixel = np.sqrt(((gt_lab - out_lab) ** 2).sum(axis=-1))
    return per_pixel.mean()