|
14 | 14 |
|
15 | 15 | from torchvision.models import resnet50
|
16 | 16 |
|
| 17 | +from robosat.scse import SpatialChannelSqChannelEx |
| 18 | + |
17 | 19 |
|
18 | 20 | class ConvRelu(nn.Module):
|
19 | 21 | """3x3 convolution followed by ReLU activation building block.
|
@@ -91,10 +93,23 @@ def __init__(self, num_classes, num_filters=32, pretrained=True):
|
91 | 93 |
|
92 | 94 | # Todo: make input channels configurable, not hard-coded to three channels for RGB
|
93 | 95 |
|
94 |
| - self.resnet = resnet50(pretrained=pretrained) |
95 |
| - |
96 | 96 | # Access resnet directly in forward pass; do not store refs here due to
|
97 | 97 | # https://github.com/pytorch/pytorch/issues/8392
|
| 98 | + self.resnet = resnet50(pretrained=pretrained) |
| 99 | + |
| 100 | + # seSE blocks to append to encoder and decoder as recommended by |
| 101 | + # https://arxiv.org/abs/1803.02579 |
| 102 | + self.scse0 = SpatialChannelSqChannelEx(64) |
| 103 | + self.scse1 = SpatialChannelSqChannelEx(256) |
| 104 | + self.scse2 = SpatialChannelSqChannelEx(512) |
| 105 | + self.scse3 = SpatialChannelSqChannelEx(1024) |
| 106 | + self.scse4 = SpatialChannelSqChannelEx(2048) |
| 107 | + |
| 108 | + self.scse5 = SpatialChannelSqChannelEx(num_filters * 8) |
| 109 | + self.scse6 = SpatialChannelSqChannelEx(num_filters * 8) |
| 110 | + self.scse7 = SpatialChannelSqChannelEx(num_filters * 2) |
| 111 | + self.scse8 = SpatialChannelSqChannelEx(num_filters * 2 * 2) |
| 112 | + self.scse9 = SpatialChannelSqChannelEx(num_filters) |
98 | 113 |
|
99 | 114 | self.center = DecoderBlock(2048, num_filters * 8)
|
100 | 115 |
|
@@ -122,20 +137,21 @@ def forward(self, x):
|
122 | 137 | enc0 = self.resnet.conv1(x)
|
123 | 138 | enc0 = self.resnet.bn1(enc0)
|
124 | 139 | enc0 = self.resnet.relu(enc0)
|
| 140 | + enc0 = self.scse0(enc0) |
125 | 141 | enc0 = self.resnet.maxpool(enc0)
|
126 | 142 |
|
127 |
| - enc1 = self.resnet.layer1(enc0) |
128 |
| - enc2 = self.resnet.layer2(enc1) |
129 |
| - enc3 = self.resnet.layer3(enc2) |
130 |
| - enc4 = self.resnet.layer4(enc3) |
| 143 | + enc1 = self.scse1(self.resnet.layer1(enc0)) |
| 144 | + enc2 = self.scse2(self.resnet.layer2(enc1)) |
| 145 | + enc3 = self.scse3(self.resnet.layer3(enc2)) |
| 146 | + enc4 = self.scse4(self.resnet.layer4(enc3)) |
131 | 147 |
|
132 | 148 | center = self.center(nn.functional.max_pool2d(enc4, kernel_size=2, stride=2))
|
133 | 149 |
|
134 |
| - dec0 = self.dec0(torch.cat([enc4, center], dim=1)) |
135 |
| - dec1 = self.dec1(torch.cat([enc3, dec0], dim=1)) |
136 |
| - dec2 = self.dec2(torch.cat([enc2, dec1], dim=1)) |
137 |
| - dec3 = self.dec3(torch.cat([enc1, dec2], dim=1)) |
138 |
| - dec4 = self.dec4(dec3) |
| 150 | + dec0 = self.scse5(self.dec0(torch.cat([enc4, center], dim=1))) |
| 151 | + dec1 = self.scse6(self.dec1(torch.cat([enc3, dec0], dim=1))) |
| 152 | + dec2 = self.scse7(self.dec2(torch.cat([enc2, dec1], dim=1))) |
| 153 | + dec3 = self.scse8(self.dec3(torch.cat([enc1, dec2], dim=1))) |
| 154 | + dec4 = self.scse9(self.dec4(dec3)) |
139 | 155 | dec5 = self.dec5(dec4)
|
140 | 156 |
|
141 | 157 | return self.final(dec5)
|
0 commit comments