# rbpn.py (forked from amanchadha/iSeeBetter)
# RBPN generator for video super-resolution, plus the combined
# image/adversarial/perceptual/TV generator loss used to train it.
import torch
import torch.nn as nn
from base_networks import *  # ConvBlock, DeconvBlock, ResnetBlock
from dbpns import Net as DBPNS
from torchvision.models import vgg16

class Net(nn.Module):
    """RBPN generator: fuses a target LR frame with its neighbor frames and
    optical flow via iterative back-projection, then reconstructs the SR frame."""

    def __init__(self, num_channels, base_filter, feat, num_stages, n_resblock, nFrames, scale_factor):
        super(Net, self).__init__()
        # Typical settings: base_filter=256, feat=64
        self.nFrames = nFrames

        # Transposed-convolution hyperparameters realizing the requested upscale
        if scale_factor == 2:
            kernel = 6
            stride = 2
            padding = 2
        elif scale_factor == 4:
            kernel = 8
            stride = 4
            padding = 2
        elif scale_factor == 8:
            kernel = 12
            stride = 8
            padding = 2
        else:
            raise ValueError('scale_factor must be 2, 4, or 8, got %s' % scale_factor)

        # Initial feature extraction
        self.feat0 = ConvBlock(num_channels, base_filter, 3, 1, 1, activation='prelu', norm=None)
        # 8 input channels: target frame (3) + neighbor frame (3) + 2-channel flow (assumes RGB)
        self.feat1 = ConvBlock(8, base_filter, 3, 1, 1, activation='prelu', norm=None)

        # DBPNS: single-image back-projection module applied to the target features
        self.DBPN = DBPNS(base_filter, feat, num_stages, scale_factor)

        # Res-Block1: encode a neighbor/flow feature and upsample it to HR
        modules_body1 = [
            ResnetBlock(base_filter, kernel_size=3, stride=1, padding=1, bias=True, activation='prelu', norm=None)
            for _ in range(n_resblock)]
        modules_body1.append(DeconvBlock(base_filter, feat, kernel, stride, padding, activation='prelu', norm=None))
        self.res_feat1 = nn.Sequential(*modules_body1)

        # Res-Block2: refine the HR residual between target and neighbor features
        modules_body2 = [
            ResnetBlock(feat, kernel_size=3, stride=1, padding=1, bias=True, activation='prelu', norm=None)
            for _ in range(n_resblock)]
        modules_body2.append(ConvBlock(feat, feat, 3, 1, 1, activation='prelu', norm=None))
        self.res_feat2 = nn.Sequential(*modules_body2)

        # Res-Block3: project the fused HR feature back down to LR for the next step
        modules_body3 = [
            ResnetBlock(feat, kernel_size=3, stride=1, padding=1, bias=True, activation='prelu', norm=None)
            for _ in range(n_resblock)]
        modules_body3.append(ConvBlock(feat, base_filter, kernel, stride, padding, activation='prelu', norm=None))
        self.res_feat3 = nn.Sequential(*modules_body3)

        # Reconstruction: concatenate the nFrames-1 HR features into the SR output
        self.output = ConvBlock((nFrames - 1) * feat, num_channels, 3, 1, 1, activation=None, norm=None)

        # Kaiming initialization for all convolutions and transposed convolutions
        for m in self.modules():
            classname = m.__class__.__name__
            if classname.find('Conv2d') != -1:
                torch.nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif classname.find('ConvTranspose2d') != -1:
                torch.nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
    def forward(self, x, neigbor, flow):
        # x: target LR frame; neigbor: list of nFrames-1 LR neighbor frames;
        # flow: list of matching 2-channel optical-flow maps
        ### Initial feature extraction
        feat_input = self.feat0(x)
        feat_frame = []
        for j in range(len(neigbor)):
            feat_frame.append(self.feat1(torch.cat((x, neigbor[j], flow[j]), 1)))

        ### Projection: iterative back-projection across the neighbor frames
        Ht = []
        for j in range(len(neigbor)):
            h0 = self.DBPN(feat_input)          # HR feature from the current LR state
            h1 = self.res_feat1(feat_frame[j])  # HR feature from neighbor j
            e = h0 - h1                         # back-projection error
            e = self.res_feat2(e)               # refine the error
            h = h0 + e                          # corrected HR feature
            Ht.append(h)
            feat_input = self.res_feat3(h)      # back to LR for the next iteration

        ### Reconstruction
        out = torch.cat(Ht, 1)
        output = self.output(out)
        return output
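
# A minimal smoke-test sketch for Net, assuming base_networks and dbpns from
# the iSeeBetter repo are importable. base_filter=256 and feat=64 match the
# defaults noted in __init__; num_stages=3, n_resblock=5, and nFrames=7 are
# assumed illustrative values. Shapes follow forward(): x is the target LR
# frame, neigbor holds nFrames-1 neighbor frames, flow the 2-channel flows.
def _demo_net():
    nFrames, scale = 7, 4
    model = Net(num_channels=3, base_filter=256, feat=64, num_stages=3,
                n_resblock=5, nFrames=nFrames, scale_factor=scale)
    x = torch.randn(1, 3, 32, 32)
    neigbor = [torch.randn(1, 3, 32, 32) for _ in range(nFrames - 1)]
    flow = [torch.randn(1, 2, 32, 32) for _ in range(nFrames - 1)]
    with torch.no_grad():
        sr = model(x, neigbor, flow)
    print(sr.shape)  # expected: torch.Size([1, 3, 128, 128])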

class GeneratorLoss(nn.Module):
    """Weighted sum of image (MSE), adversarial, perceptual (VGG16), and TV losses."""

    def __init__(self):
        super(GeneratorLoss, self).__init__()
        # Note: torchvision >= 0.13 replaces pretrained=True with weights=...
        vgg = vgg16(pretrained=True)
        # Frozen VGG16 convolutional features for the perceptual loss
        loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
        for param in loss_network.parameters():
            param.requires_grad = False
        self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()

    def forward(self, out_labels, hr_est, hr_img, idx):
        # idx is unused here; it is kept to match the training loop's call signature.
        # Adversarial loss: push the discriminator's scores on fakes higher
        adversarial_loss = -torch.mean(out_labels)
        # Perceptual loss: MSE in frozen VGG16 feature space
        perception_loss = self.mse_loss(self.loss_network(hr_est), self.loss_network(hr_img))
        # Image loss: pixel-wise MSE
        image_loss = self.mse_loss(hr_est, hr_img)
        # Total-variation loss: encourages spatial smoothness
        tv_loss = self.tv_loss(hr_est)
        return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss
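
# A hedged usage sketch for GeneratorLoss. Instantiating it downloads the
# pretrained VGG16 weights on first use. out_labels stands in for the
# discriminator's scores on generated frames; all tensors here are random
# placeholders chosen only to illustrate the expected shapes.
def _demo_generator_loss():
    criterion = GeneratorLoss()
    hr_est = torch.rand(1, 3, 128, 128)  # generator output (SR estimate)
    hr_img = torch.rand(1, 3, 128, 128)  # ground-truth HR frame
    out_labels = torch.rand(1, 1)        # placeholder discriminator scores
    loss = criterion(out_labels, hr_est, hr_img, idx=0)
    print(loss.item())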

class TVLoss(nn.Module):
    """Total-variation loss: mean squared difference between neighboring pixels."""

    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        # Number of vertical / horizontal neighbor pairs per sample
        count_h = self.tensor_size(x[:, :, 1:, :])
        count_w = self.tensor_size(x[:, :, :, 1:])
        # Squared differences between vertically / horizontally adjacent pixels
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    @staticmethod
    def tensor_size(t):
        return t.size()[1] * t.size()[2] * t.size()[3]
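
# Sanity check for TVLoss: a constant image has zero total variation, while
# noise does not. Only this self-contained check runs on direct execution;
# _demo_net and _demo_generator_loss need the repo's modules and VGG weights,
# so call them manually if those are available.
if __name__ == "__main__":
    tv = TVLoss()
    flat = torch.ones(1, 3, 16, 16)
    noisy = torch.rand(1, 3, 16, 16)
    print(tv(flat).item())   # 0.0: no differences between neighboring pixels
    print(tv(noisy).item())  # > 0 for random noise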