-
Notifications
You must be signed in to change notification settings - Fork 12
/
segnet.py
115 lines (93 loc) · 4.46 KB
/
segnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch.nn as nn
import torch.nn.functional as F
class SegNet(nn.Module):
    """SegNet: A Deep Convolutional Encoder-Decoder Architecture for
    Image Segmentation. https://arxiv.org/abs/1511.00561

    See https://github.com/alexgkendall/SegNet-Tutorial for original models.

    The network chains five encoder modules and five decoder modules. Each
    encoder returns its pooled features together with the pooling indices
    and the pre-pooling size; the decoders consume those indices and sizes
    in reverse order to unpool. A final 3x3 convolution maps the last
    decoder output to ``num_classes`` channels.

    Args:
        num_classes (int): number of classes to segment
        n_init_features (int): number of input features in the first
            convolution
        drop_rate (float): dropout rate passed to each encoder/decoder module
        filter_config (sequence of 5 ints): number of output features at
            each level; any sequence (list or tuple) is accepted

    Raises:
        ValueError: if ``filter_config`` does not hold exactly 5 entries.
    """

    def __init__(self, num_classes, n_init_features=1, drop_rate=0.5,
                 filter_config=(64, 128, 256, 512, 512)):
        super(SegNet, self).__init__()

        # Normalise to a tuple: the concatenations below would raise
        # TypeError for a list, although a list is a documented input.
        filter_config = tuple(filter_config)
        if len(filter_config) != 5:
            raise ValueError(
                "filter_config must contain exactly 5 entries, got %d"
                % len(filter_config))

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()

        # setup number of conv-bn-relu blocks per module and number of filters
        encoder_n_layers = (2, 2, 3, 3, 3)
        encoder_filter_config = (n_init_features,) + filter_config
        decoder_n_layers = (3, 3, 3, 2, 1)
        # Decoder widths mirror the encoder; the last decoder keeps
        # filter_config[0] features, which is what the classifier expects.
        decoder_filter_config = filter_config[::-1] + (filter_config[0],)

        for i in range(0, 5):
            # encoder architecture
            self.encoders.append(_Encoder(encoder_filter_config[i],
                                          encoder_filter_config[i + 1],
                                          encoder_n_layers[i], drop_rate))

            # decoder architecture
            self.decoders.append(_Decoder(decoder_filter_config[i],
                                          decoder_filter_config[i + 1],
                                          decoder_n_layers[i], drop_rate))

        # final classifier (equivalent to a fully connected layer)
        self.classifier = nn.Conv2d(filter_config[0], num_classes, 3, 1, 1)

    def forward(self, x):
        """Run the encoder path, then the decoder path, then the classifier.

        Args:
            x: input tensor fed to the first encoder
                (assumed to be (N, n_init_features, H, W) -- TODO confirm
                against callers).

        Returns:
            The output of the final convolution, with ``num_classes``
            channels.
        """
        indices = []
        unpool_sizes = []
        feat = x

        # encoder path, keep track of pooling indices and features size
        for encoder in self.encoders:
            (feat, ind), size = encoder(feat)
            indices.append(ind)
            unpool_sizes.append(size)

        # decoder path: the first decoder undoes the last encoder's pooling,
        # so indices and sizes are consumed in reverse order
        for decoder, ind, size in zip(self.decoders,
                                      reversed(indices),
                                      reversed(unpool_sizes)):
            feat = decoder(feat, ind, size)

        return self.classifier(feat)
class _Encoder(nn.Module):
    """Encoder layer follows VGG rules + keeps pooling indices.

    The feature stack is assembled from ``n_blocks`` as follows:

    * always one conv(n_in_feat -> n_out_feat) / batch-norm / ReLU triple;
    * if ``n_blocks > 1``, a second conv(n_out_feat -> n_out_feat) /
      batch-norm / ReLU triple;
    * if ``n_blocks == 3``, a trailing dropout layer (no third convolution
      is added).

    Args:
        n_in_feat (int): number of input features
        n_out_feat (int): number of output features
        n_blocks (int): number of conv-batch-relu block inside the encoder
        drop_rate (float): dropout rate to use
    """

    def __init__(self, n_in_feat, n_out_feat, n_blocks=2, drop_rate=0.5):
        super(_Encoder, self).__init__()

        stack = self._conv_bn_relu(n_in_feat, n_out_feat)
        if n_blocks > 1:
            stack.extend(self._conv_bn_relu(n_out_feat, n_out_feat))
        if n_blocks == 3:
            stack.append(nn.Dropout(drop_rate))

        self.features = nn.Sequential(*stack)

    @staticmethod
    def _conv_bn_relu(n_in, n_out):
        """Return a 3x3 conv (stride 1, padding 1), batch-norm, ReLU list."""
        return [
            nn.Conv2d(n_in, n_out, 3, 1, 1),
            nn.BatchNorm2d(n_out),
            nn.ReLU(inplace=True),
        ]

    def forward(self, x):
        """Apply the feature stack, then 2x2 max pooling with stride 2.

        Returns:
            A pair ``((pooled, indices), size)`` where ``pooled`` and
            ``indices`` come from the max pooling and ``size`` is the size
            of the features before pooling.
        """
        activations = self.features(x)
        pooled_and_indices = F.max_pool2d(activations, 2, 2,
                                          return_indices=True)
        return pooled_and_indices, activations.size()
class _Decoder(nn.Module):
    """Decoder layer decodes the features by unpooling with respect to
    the pooling indices of the corresponding encoder part.

    The feature stack is assembled from ``n_blocks`` as follows:

    * ``n_blocks == 1``: a single conv(n_in_feat -> n_out_feat) /
      batch-norm / ReLU triple;
    * ``n_blocks > 1``: a conv(n_in_feat -> n_in_feat) triple followed by a
      conv(n_in_feat -> n_out_feat) triple;
    * ``n_blocks == 3``: additionally a trailing dropout layer (no third
      convolution is added).

    Args:
        n_in_feat (int): number of input features
        n_out_feat (int): number of output features
        n_blocks (int): number of conv-batch-relu block inside the decoder
        drop_rate (float): dropout rate to use
    """

    def __init__(self, n_in_feat, n_out_feat, n_blocks=2, drop_rate=0.5):
        super(_Decoder, self).__init__()

        # With a single block there is no second convolution to change the
        # width, so the first one must map straight to n_out_feat; otherwise
        # n_out_feat would be silently ignored. For n_in_feat == n_out_feat
        # this is the same layer as before.
        first_out_feat = n_in_feat if n_blocks > 1 else n_out_feat

        layers = [nn.Conv2d(n_in_feat, first_out_feat, 3, 1, 1),
                  nn.BatchNorm2d(first_out_feat),
                  nn.ReLU(inplace=True)]

        if n_blocks > 1:
            layers += [nn.Conv2d(n_in_feat, n_out_feat, 3, 1, 1),
                       nn.BatchNorm2d(n_out_feat),
                       nn.ReLU(inplace=True)]
            if n_blocks == 3:
                layers += [nn.Dropout(drop_rate)]

        self.features = nn.Sequential(*layers)

    def forward(self, x, indices, size):
        """Unpool ``x`` and apply the feature stack.

        Args:
            x: features to decode.
            indices: max-pooling indices used to place the values of ``x``.
            size: target size of the unpooled tensor.

        Returns:
            The decoded features, with ``n_out_feat`` channels.
        """
        # kernel 2, stride 2, padding 0: a 2x2 / stride-2 max unpooling
        unpooled = F.max_unpool2d(x, indices, 2, 2, 0, size)
        return self.features(unpooled)