Skip to content

Commit 76a1b91

Browse files
committed
Adds squeeze and excitation (scSE) modules, resolves #157
1 parent 27833da commit 76a1b91

File tree

2 files changed

+90
-11
lines changed

2 files changed

+90
-11
lines changed

robosat/scse.py

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""Squeeze and Excitation blocks - attention for classification and segmentation
2+
3+
See:
4+
- https://arxiv.org/abs/1709.01507 - Squeeze-and-Excitation Networks
5+
- https://arxiv.org/abs/1803.02579 - Concurrent Spatial and Channel 'Squeeze & Excitation' in Fully Convolutional Networks
6+
7+
"""
8+
9+
import torch
10+
import torch.nn as nn
11+
12+
13+
class SpatialSqChannelEx(nn.Module):
14+
"""Spatial Squeeze and Channel Excitation (cSE) block
15+
See https://arxiv.org/abs/1803.02579 Figure 1 b
16+
"""
17+
18+
def __init__(self, num_in, r):
19+
super().__init__()
20+
self.fc0 = Conv1x1(num_in, num_in // r)
21+
self.fc1 = Conv1x1(num_in // r, num_in)
22+
23+
def forward(self, x):
24+
xx = nn.functional.adaptive_avg_pool2d(x, 1)
25+
xx = self.fc0(xx)
26+
xx = nn.functional.relu(xx, inplace=True)
27+
xx = self.fc1(xx)
28+
xx = torch.sigmoid(xx)
29+
return x * xx
30+
31+
32+
class ChannelSqSpatialEx(nn.Module):
33+
"""Channel Squeeze and Spatial Excitation (sSE) block
34+
See https://arxiv.org/abs/1803.02579 Figure 1 c
35+
"""
36+
37+
def __init__(self, num_in):
38+
super().__init__()
39+
self.conv = Conv1x1(num_in, 1)
40+
41+
def forward(self, x):
42+
xx = self.conv(x)
43+
xx = torch.sigmoid(xx)
44+
return x * xx
45+
46+
47+
class SpatialChannelSqChannelEx(nn.Module):
48+
"""Concurrent Spatial and Channel Squeeze and Channel Excitation (csSE) block
49+
See https://arxiv.org/abs/1803.02579 Figure 1 d
50+
"""
51+
52+
def __init__(self, num_in, r=16):
53+
super().__init__()
54+
55+
self.cse = SpatialSqChannelEx(num_in, r)
56+
self.sse = ChannelSqSpatialEx(num_in)
57+
58+
def forward(self, x):
59+
return self.cse(x) + self.sse(x)
60+
61+
62+
def Conv1x1(num_in, num_out):
63+
return nn.Conv2d(num_in, num_out, kernel_size=1, bias=False)

robosat/unet.py

+27-11
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from torchvision.models import resnet50
1616

17+
from robosat.scse import SpatialChannelSqChannelEx
18+
1719

1820
class ConvRelu(nn.Module):
1921
"""3x3 convolution followed by ReLU activation building block.
@@ -91,10 +93,23 @@ def __init__(self, num_classes, num_filters=32, pretrained=True):
9193

9294
# Todo: make input channels configurable, not hard-coded to three channels for RGB
9395

94-
self.resnet = resnet50(pretrained=pretrained)
95-
9696
# Access resnet directly in forward pass; do not store refs here due to
9797
# https://github.com/pytorch/pytorch/issues/8392
98+
self.resnet = resnet50(pretrained=pretrained)
99+
100+
# seSE blocks to append to encoder and decoder as recommended by
101+
# https://arxiv.org/abs/1803.02579
102+
self.scse0 = SpatialChannelSqChannelEx(64)
103+
self.scse1 = SpatialChannelSqChannelEx(256)
104+
self.scse2 = SpatialChannelSqChannelEx(512)
105+
self.scse3 = SpatialChannelSqChannelEx(1024)
106+
self.scse4 = SpatialChannelSqChannelEx(2048)
107+
108+
self.scse5 = SpatialChannelSqChannelEx(num_filters * 8)
109+
self.scse6 = SpatialChannelSqChannelEx(num_filters * 8)
110+
self.scse7 = SpatialChannelSqChannelEx(num_filters * 2)
111+
self.scse8 = SpatialChannelSqChannelEx(num_filters * 2 * 2)
112+
self.scse9 = SpatialChannelSqChannelEx(num_filters)
98113

99114
self.center = DecoderBlock(2048, num_filters * 8)
100115

@@ -122,20 +137,21 @@ def forward(self, x):
122137
enc0 = self.resnet.conv1(x)
123138
enc0 = self.resnet.bn1(enc0)
124139
enc0 = self.resnet.relu(enc0)
140+
enc0 = self.scse0(enc0)
125141
enc0 = self.resnet.maxpool(enc0)
126142

127-
enc1 = self.resnet.layer1(enc0)
128-
enc2 = self.resnet.layer2(enc1)
129-
enc3 = self.resnet.layer3(enc2)
130-
enc4 = self.resnet.layer4(enc3)
143+
enc1 = self.scse1(self.resnet.layer1(enc0))
144+
enc2 = self.scse2(self.resnet.layer2(enc1))
145+
enc3 = self.scse3(self.resnet.layer3(enc2))
146+
enc4 = self.scse4(self.resnet.layer4(enc3))
131147

132148
center = self.center(nn.functional.max_pool2d(enc4, kernel_size=2, stride=2))
133149

134-
dec0 = self.dec0(torch.cat([enc4, center], dim=1))
135-
dec1 = self.dec1(torch.cat([enc3, dec0], dim=1))
136-
dec2 = self.dec2(torch.cat([enc2, dec1], dim=1))
137-
dec3 = self.dec3(torch.cat([enc1, dec2], dim=1))
138-
dec4 = self.dec4(dec3)
150+
dec0 = self.scse5(self.dec0(torch.cat([enc4, center], dim=1)))
151+
dec1 = self.scse6(self.dec1(torch.cat([enc3, dec0], dim=1)))
152+
dec2 = self.scse7(self.dec2(torch.cat([enc2, dec1], dim=1)))
153+
dec3 = self.scse8(self.dec3(torch.cat([enc1, dec2], dim=1)))
154+
dec4 = self.scse9(self.dec4(dec3))
139155
dec5 = self.dec5(dec4)
140156

141157
return self.final(dec5)

0 commit comments

Comments
 (0)