forked from wzx0826/LBNet
-
Notifications
You must be signed in to change notification settings - Fork 2
/
utility.py
104 lines (77 loc) · 2.32 KB
/
utility.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import math
import time
import random
import numpy as np
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lrs
def set_seed(seed):
    """Seed every random number generator this module knows about.

    The same value is handed, in order, to Python's ``random`` module,
    NumPy's global generator, ``torch.manual_seed`` and finally the CUDA
    seeding routine.

    Args:
        seed: Seed value passed unchanged to each generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # With exactly one visible GPU only the current device is seeded;
    # in every other case (zero or several GPUs) all devices are seeded.
    single_gpu = torch.cuda.device_count() == 1
    seed_cuda = torch.cuda.manual_seed if single_gpu else torch.cuda.manual_seed_all
    seed_cuda(seed)
class timer:
    """Stopwatch that can accumulate elapsed time over several intervals.

    ``acc`` stores the accumulated seconds and ``t0`` the wall-clock time
    (from ``time.time()``) at which the current interval started.
    """

    def __init__(self):
        self.acc = 0
        self.tic()

    def tic(self):
        """Mark now as the start of a new interval."""
        self.t0 = time.time()

    def toc(self):
        """Return the seconds elapsed since the last ``tic``."""
        now = time.time()
        return now - self.t0

    def hold(self):
        """Add the time elapsed since the last ``tic`` to the accumulator."""
        self.acc = self.acc + self.toc()

    def release(self):
        """Return the accumulated time and zero the accumulator."""
        total, self.acc = self.acc, 0
        return total

    def reset(self):
        """Discard any accumulated time."""
        self.acc = 0
def quantize(img, rgb_range):
    """Snap pixel values onto a 256-level grid within ``[0, rgb_range]``.

    Values are scaled so that ``rgb_range`` maps to 255, clamped to
    ``[0, 255]``, rounded, and then scaled back to the original range.

    Args:
        img: Tensor of pixel values (it is not modified).
        rgb_range: Value that corresponds to full intensity (255).

    Returns:
        A new tensor holding the quantized values.
    """
    to_8bit = 255 / rgb_range
    levels = (img * to_8bit).clamp(0, 255).round()
    return levels / to_8bit
def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
    """Compute the PSNR in dB between a super-resolved image and its reference.

    Args:
        sr: Super-resolved tensor. It is indexed with four axes, so it is
            presumably laid out as (N, C, H, W) -- confirm against callers.
            If its last two dimensions exceed those of ``hr`` it is cropped
            to the size of ``hr`` and a notice is printed.
        hr: Reference tensor with the same layout as ``sr``.
        scale: Upscaling factor; controls how many border pixels are ignored.
        rgb_range: Value the pixel difference is normalised by.
        benchmark: When true, ``scale`` border pixels are shaved and a
            multi-channel difference is collapsed to one channel using the
            weights 65.738 / 129.057 / 25.064 (divided by 256); these look
            like the BT.601 luma coefficients -- confirm. When false,
            ``scale + 6`` border pixels are shaved and all channels are used.

    Returns:
        The PSNR as a Python float, or ``math.inf`` when the compared
        regions are identical (mean squared error of zero).
    """
    hr_height, hr_width = hr.size(-2), hr.size(-1)
    if sr.size(-2) > hr_height or sr.size(-1) > hr_width:
        print("the dimention of sr image is not equal to hr's! ")
        sr = sr[:, :, :hr_height, :hr_width]

    # detach() replaces the legacy ``.data`` access: values only, no autograd.
    diff = (sr - hr).detach() / rgb_range

    if benchmark:
        shave = scale
        if diff.size(1) > 1:
            weights = diff.new_tensor([65.738, 129.057, 25.064]).view(1, 3, 1, 1)
            # Multiply first, then divide by 256, to keep the original rounding.
            diff = (diff * weights / 256).sum(dim=1, keepdim=True)
    else:
        shave = scale + 6

    valid = diff[:, :, shave:-shave, shave:-shave]
    mse = valid.pow(2).mean().item()

    # Identical images give mse == 0; log10(0) would raise ValueError.
    if mse == 0:
        return math.inf
    return -10 * math.log10(mse)
def make_optimizer(opt, my_model):
    """Create an Adam optimizer over the trainable parameters of a model.

    Args:
        opt: Options object providing ``beta1``, ``beta2``, ``epsilon``,
            ``lr`` and ``weight_decay``.
        my_model: Object exposing ``parameters()``; only parameters whose
            ``requires_grad`` is true are handed to the optimizer.

    Returns:
        The configured ``torch.optim.Adam`` instance.
    """
    trainable_params = (p for p in my_model.parameters() if p.requires_grad)
    return optim.Adam(
        trainable_params,
        betas=(opt.beta1, opt.beta2),
        eps=opt.epsilon,
        lr=opt.lr,
        weight_decay=opt.weight_decay,
    )
def make_scheduler(opt, my_optimizer):
    """Create a cosine-annealing learning-rate scheduler for an optimizer.

    Args:
        opt: Options object providing ``epochs`` (converted to ``float`` and
            used as the annealing period ``T_max``) and ``eta_min``.
        my_optimizer: Optimizer whose learning rate is scheduled.

    Returns:
        The ``CosineAnnealingLR`` scheduler bound to ``my_optimizer``.
    """
    annealing_period = float(opt.epochs)
    return lrs.CosineAnnealingLR(
        my_optimizer,
        T_max=annealing_period,
        eta_min=opt.eta_min,
    )
def init_model(args):
    """Set the feature width and head count for the selected model variant.

    ``args`` is modified in place: ``'LBNet'`` gets ``n_feats=32`` and
    ``num_heads=8``; ``'LBNet-T'`` gets ``n_feats=18`` and ``num_heads=6``.
    Any other ``args.model`` value leaves ``args`` untouched.

    Args:
        args: Options object with a ``model`` attribute.
    """
    variants = (
        ('LBNet', 32, 8),
        ('LBNet-T', 18, 6),
    )
    for variant_name, width, heads in variants:
        if args.model == variant_name:
            args.n_feats = width
            args.num_heads = heads
            return