# Forked from proteus1991/GridDehazeNet (model.py)
"""
paper: GridDehazeNet: Attention-Based Multi-Scale Network for Image Dehazing
file: model.py
about: model for GridDehazeNet
author: Xiaohong Liu
date: 01/08/19
"""
# --- Imports --- #
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from residual_dense_block import RDB
# --- Downsampling block in GridDehazeNet --- #
class DownSample(nn.Module):
    """Halve the spatial resolution and multiply channels by ``stride``.

    Two ReLU-activated convolutions: a strided conv shrinks the feature
    map, then a stride-1 conv expands the channel count by ``stride``.
    """

    def __init__(self, in_channels, kernel_size=3, stride=2):
        super(DownSample, self).__init__()
        same_pad = (kernel_size - 1) // 2  # "same" padding for odd kernels
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size,
                               stride=stride, padding=same_pad)
        self.conv2 = nn.Conv2d(in_channels, stride * in_channels, kernel_size,
                               stride=1, padding=same_pad)

    def forward(self, x):
        shrunk = F.relu(self.conv1(x))
        return F.relu(self.conv2(shrunk))
# --- Upsampling block in GridDehazeNet --- #
class UpSample(nn.Module):
    """Enlarge the feature map to ``output_size`` and divide channels by ``stride``.

    A transposed conv grows the spatial extent, then a stride-1 conv
    reduces the channel count; each is followed by ReLU.
    """

    def __init__(self, in_channels, kernel_size=3, stride=2):
        super(UpSample, self).__init__()
        self.deconv = nn.ConvTranspose2d(in_channels, in_channels, kernel_size,
                                         stride=stride, padding=1)
        self.conv = nn.Conv2d(in_channels, in_channels // stride, kernel_size,
                              stride=1, padding=(kernel_size - 1) // 2)

    def forward(self, x, output_size):
        # output_size pins the exact target shape (resolves the stride-2
        # transposed-conv output ambiguity).
        enlarged = F.relu(self.deconv(x, output_size=output_size))
        return F.relu(self.conv(enlarged))
# --- Main model --- #
class GridDehazeNet(nn.Module):
    """Attention-based multi-scale grid network for image dehazing.

    Features flow through a (height x width) grid of RDB blocks.  The left
    half of the grid (columns 0 .. width//2 - 1) is connected row-to-row by
    DownSample edges; the right half (columns width//2 .. width - 1) by
    UpSample edges.  At every node fed by two streams, the learnable
    per-channel weights in ``self.coefficient`` blend them; with
    ``attention=False`` those weights are frozen at their initial value 1.
    """
    def __init__(self, in_channels=3, depth_rate=16, kernel_size=3, stride=2, height=3, width=6, num_dense_layer=4, growth_rate=16, attention=True):
        super(GridDehazeNet, self).__init__()
        # Grid members live in ModuleDicts keyed by '{row}_{col}'.
        self.rdb_module = nn.ModuleDict()
        self.upsample_module = nn.ModuleDict()
        self.downsample_module = nn.ModuleDict()
        self.height = height
        self.width = width
        self.stride = stride
        self.depth_rate = depth_rate
        # coefficient[i, j, 0/1, :c] weights the two features fused at grid
        # node (i, j).  Sized for the widest row (depth_rate * stride**(height-1))
        # and sliced down to each row's channel count inside forward().
        # Initialized to ones; requires_grad only when attention is enabled.
        self.coefficient = nn.Parameter(torch.Tensor(np.ones((height, width, 2, depth_rate*stride**(height-1)))), requires_grad=attention)
        self.conv_in = nn.Conv2d(in_channels, depth_rate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.conv_out = nn.Conv2d(depth_rate, in_channels, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.rdb_in = RDB(depth_rate, num_dense_layer, growth_rate)
        self.rdb_out = RDB(depth_rate, num_dense_layer, growth_rate)
        # One RDB per horizontal edge; channel count multiplies by `stride`
        # with each deeper row.
        rdb_in_channels = depth_rate
        for i in range(height):
            for j in range(width - 1):
                self.rdb_module.update({'{}_{}'.format(i, j): RDB(rdb_in_channels, num_dense_layer, growth_rate)})
            rdb_in_channels *= stride
        # DownSample edges exist only in the left (encoder) half of the grid.
        _in_channels = depth_rate
        for i in range(height - 1):
            for j in range(width // 2):
                self.downsample_module.update({'{}_{}'.format(i, j): DownSample(_in_channels)})
            _in_channels *= stride
        # UpSample edges exist only in the right (decoder) half; rows are
        # built bottom-up, so _in_channels shrinks back toward depth_rate.
        for i in range(height - 2, -1, -1):
            for j in range(width // 2, width):
                self.upsample_module.update({'{}_{}'.format(i, j): UpSample(_in_channels)})
            _in_channels //= stride

    def forward(self, x):
        inp = self.conv_in(x)
        # x_index[i][j] caches the feature tensor at grid node (row i, col j).
        x_index = [[0 for _ in range(self.width)] for _ in range(self.height)]
        i, j = 0, 0
        x_index[0][0] = self.rdb_in(inp)
        # Top row, encoder half: a plain horizontal RDB chain.
        for j in range(1, self.width // 2):
            x_index[0][j] = self.rdb_module['{}_{}'.format(0, j-1)](x_index[0][j-1])
        # First column: downsample straight down from the top-left node.
        for i in range(1, self.height):
            x_index[i][0] = self.downsample_module['{}_{}'.format(i-1, 0)](x_index[i-1][0])
        # Interior encoder nodes: attention-weighted sum of the horizontal
        # (RDB) stream and the vertical (DownSample) stream.
        for i in range(1, self.height):
            for j in range(1, self.width // 2):
                channel_num = int(2**(i-1)*self.stride*self.depth_rate)
                x_index[i][j] = self.coefficient[i, j, 0, :channel_num][None, :, None, None] * self.rdb_module['{}_{}'.format(i, j-1)](x_index[i][j-1]) + \
                                self.coefficient[i, j, 1, :channel_num][None, :, None, None] * self.downsample_module['{}_{}'.format(i-1, j)](x_index[i-1][j])
        # NOTE: the lines below deliberately rely on i and j retaining their
        # final loop values (i == height-1, j == width//2 - 1).  Bridge the
        # bottom row from the encoder half into the decoder half:
        x_index[i][j+1] = self.rdb_module['{}_{}'.format(i, j)](x_index[i][j])
        k = j  # k == width//2 - 1, the last encoder column
        # Bottom row, decoder half: horizontal RDB chain.
        for j in range(self.width // 2 + 1, self.width):
            x_index[i][j] = self.rdb_module['{}_{}'.format(i, j-1)](x_index[i][j-1])
        # First decoder column (k+1), bottom-up: fuse the horizontal RDB
        # stream with the upsampled feature from the row below.  The
        # target size of each UpSample is pinned to the same-row neighbor.
        for i in range(self.height - 2, -1, -1):
            channel_num = int(2 ** (i-1) * self.stride * self.depth_rate)
            x_index[i][k+1] = self.coefficient[i, k+1, 0, :channel_num][None, :, None, None] * self.rdb_module['{}_{}'.format(i, k)](x_index[i][k]) + \
                              self.coefficient[i, k+1, 1, :channel_num][None, :, None, None] * self.upsample_module['{}_{}'.format(i, k+1)](x_index[i+1][k+1], x_index[i][k].size())
        # Remaining decoder nodes, bottom-up and left-to-right.
        for i in range(self.height - 2, -1, -1):
            for j in range(self.width // 2 + 1, self.width):
                channel_num = int(2 ** (i - 1) * self.stride * self.depth_rate)
                x_index[i][j] = self.coefficient[i, j, 0, :channel_num][None, :, None, None] * self.rdb_module['{}_{}'.format(i, j-1)](x_index[i][j-1]) + \
                                self.coefficient[i, j, 1, :channel_num][None, :, None, None] * self.upsample_module['{}_{}'.format(i, j)](x_index[i+1][j], x_index[i][j-1].size())
        # After the loops, (i, j) == (0, width-1): the top-right output node.
        out = self.rdb_out(x_index[i][j])
        out = F.relu(self.conv_out(out))
        return out