-
Notifications
You must be signed in to change notification settings - Fork 675
/
Copy pathrdn.py
105 lines (85 loc) · 3.13 KB
/
rdn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# Residual Dense Network for Image Super-Resolution
# https://arxiv.org/abs/1802.08797
from model import common
import torch
import torch.nn as nn
def make_model(args, parent=False):
    """Build and return an RDN network configured from ``args``.

    ``parent`` is accepted but is not used by this function.
    """
    network = RDN(args)
    return network
class RDB_Conv(nn.Module):
    """A single dense layer: Conv2d + ReLU whose output is appended to its input.

    Args:
        inChannels: number of channels of the incoming feature map.
        growRate: number of channels produced by the convolution.
        kSize: square kernel size; padding is ``(kSize - 1) // 2`` so that an
            odd kernel at stride 1 keeps the spatial size unchanged.
    """

    def __init__(self, inChannels, growRate, kSize=3):
        super(RDB_Conv, self).__init__()
        pad = (kSize - 1) // 2
        conv_layer = nn.Conv2d(inChannels, growRate, kSize, padding=pad, stride=1)
        # Kept as a two-element Sequential so parameter names stay "conv.0.*".
        self.conv = nn.Sequential(conv_layer, nn.ReLU())

    def forward(self, x):
        """Return ``x`` concatenated (along dim 1) with the activated conv output."""
        new_features = self.conv(x)
        return torch.cat([x, new_features], dim=1)
class RDB(nn.Module):
    """Residual dense block.

    A chain of ``RDB_Conv`` layers, each of which concatenates its output onto
    its input, followed by a 1x1 convolution that reduces the accumulated
    channels back to ``growRate0`` and a residual connection to the block input.

    Args:
        growRate0: channel count of the block's input and output (G0).
        growRate: channels added by each ``RDB_Conv`` layer (G).
        nConvLayers: number of ``RDB_Conv`` layers in the block (C).
        kSize: kernel size forwarded to every ``RDB_Conv`` layer. Previously
            this argument was accepted but ignored, so the inner layers always
            used the ``RDB_Conv`` default of 3; the default here is also 3, so
            callers that do not pass ``kSize`` see no change.
    """

    def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
        super(RDB, self).__init__()
        G0 = growRate0
        G = growRate
        C = nConvLayers

        # Layer c receives the block input plus the outputs of the c layers
        # before it, hence G0 + c*G input channels.
        self.convs = nn.Sequential(*[
            RDB_Conv(G0 + c * G, G, kSize) for c in range(C)
        ])

        # Local Feature Fusion
        self.LFF = nn.Conv2d(G0 + C * G, G0, 1, padding=0, stride=1)

    def forward(self, x):
        """Return the fused dense features added to the block input ``x``."""
        return self.LFF(self.convs(x)) + x
class RDN(nn.Module):
    """Residual Dense Network for image super-resolution.

    The constructor reads the following attributes from ``args``:
        scale: sequence whose first element is the upscaling factor (2, 3 or 4).
        G0: channel count of the shallow features and of every RDB output.
        RDNkSize: kernel size used for the non-1x1 convolutions.
        RDNconfig: ``'A'`` or ``'B'``; selects (number of RDBs, conv layers
            per RDB, growth rate).
        n_colors: channel count of the input and output images.

    Raises:
        KeyError: if ``args.RDNconfig`` is not ``'A'`` or ``'B'``.
        ValueError: if the scale is not 2, 3 or 4.
    """

    def __init__(self, args):
        super(RDN, self).__init__()
        r = args.scale[0]
        G0 = args.G0
        kSize = args.RDNkSize
        pad = (kSize - 1) // 2

        # number of RDB blocks, conv layers, out channels
        self.D, C, G = {
            'A': (20, 6, 32),
            'B': (16, 8, 64),
        }[args.RDNconfig]

        # Reject an unsupported scale before allocating any layers.
        if r not in (2, 3, 4):
            raise ValueError("scale must be 2 or 3 or 4.")

        # Shallow feature extraction net
        self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=pad, stride=1)
        self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=pad, stride=1)

        # Residual dense blocks and dense feature fusion.
        # kSize is passed on so the blocks follow the same kernel-size setting
        # as the rest of the network.
        self.RDBs = nn.ModuleList()
        for _ in range(self.D):
            self.RDBs.append(
                RDB(growRate0=G0, growRate=G, nConvLayers=C, kSize=kSize)
            )

        # Global Feature Fusion: 1x1 conv over the concatenated RDB outputs,
        # followed by a kSize conv.
        self.GFF = nn.Sequential(*[
            nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1),
            nn.Conv2d(G0, G0, kSize, padding=pad, stride=1)
        ])

        # Up-sampling net
        if r == 2 or r == 3:
            # Single PixelShuffle step: expand to G*r*r channels, shuffle to G.
            self.UPNet = nn.Sequential(*[
                nn.Conv2d(G0, G * r * r, kSize, padding=pad, stride=1),
                nn.PixelShuffle(r),
                nn.Conv2d(G, args.n_colors, kSize, padding=pad, stride=1)
            ])
        else:
            # r == 4: two successive x2 PixelShuffle steps.
            self.UPNet = nn.Sequential(*[
                nn.Conv2d(G0, G * 4, kSize, padding=pad, stride=1),
                nn.PixelShuffle(2),
                nn.Conv2d(G, G * 4, kSize, padding=pad, stride=1),
                nn.PixelShuffle(2),
                nn.Conv2d(G, args.n_colors, kSize, padding=pad, stride=1)
            ])

    def forward(self, x):
        """Run the network on ``x`` and return the output of the up-sampling net."""
        shallow = self.SFENet1(x)
        x = self.SFENet2(shallow)

        # Keep every RDB output; all of them are concatenated for global fusion.
        block_outputs = []
        for block in self.RDBs:
            x = block(x)
            block_outputs.append(x)

        fused = self.GFF(torch.cat(block_outputs, 1))
        # Global residual connection back to the first shallow feature map.
        return self.UPNet(fused + shallow)