[Feature] Support FastFCN #885

Merged Sep 30, 2021 · 25 commits

Changes from 1 commit
3 changes: 2 additions & 1 deletion configs/_base_/models/fastfcn_r50-d32.py
@@ -16,7 +16,8 @@
     neck=dict(
         type='JPU',
         in_channels=(256, 512, 1024, 2048),
-        out_channels=512,
+        start_level=1,
+        end_level=-1,
         dilations=(1, 2, 4, 8),
         align_corners=False,
         out_indices=(0, 1, 2, 3),
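
For reference, a minimal sketch of the neck entry as it reads after this commit (the surrounding model dict is omitted; out_channels is no longer passed because JPU now derives it as in_channels[start_level], i.e. 512 here):

    neck = dict(
        type='JPU',
        in_channels=(256, 512, 1024, 2048),
        start_level=1,
        end_level=-1,
        dilations=(1, 2, 4, 8),
        align_corners=False,
        out_indices=(0, 1, 2, 3))
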
91 changes: 48 additions & 43 deletions mmseg/models/necks/jpu.py
@@ -20,8 +20,12 @@ class JPU(BaseModule):
         in_channels (Tuple[int], optional): The number of input channels
             for each convolution operations before upsampling.
             Default: (256, 512, 1024, 2048).
-        out_channels (int): The number of output channels. Default: 512.
-        dilations (tuple[int]): Dilation rate of each layer.
+        start_level (int): Index of the start input backbone level used to
+            build the feature pyramid. Default: 1.
+        end_level (int): Index of the end input backbone level (exclusive) to
+            build the feature pyramid. Default: -1, which means the last level.
+        dilations (tuple[int]): Dilation rate of each Depthwise
+            Separable ConvModule.
         out_indices (Tuple[int] | int, optional): Output from which stages.
             Default: (0, 1, 2, 3).
         align_corners (bool, optional): The align_corners argument of
@@ -38,7 +42,8 @@ class JPU(BaseModule):
 
     def __init__(self,
                  in_channels=(256, 512, 1024, 2048),
-                 out_channels=512,
+                 start_level=1,
+                 end_level=-1,
                  dilations=(1, 2, 4, 8),
                  out_indices=(0, 1, 2, 3),
                  align_corners=False,
@@ -48,39 +53,39 @@ def __init__(self,
                  init_cfg=None):
         super(JPU, self).__init__(init_cfg=init_cfg)
         assert isinstance(in_channels, tuple)
-        assert len(in_channels) == 4, 'Length of input channels \
-            must be 4!'
-
-        assert len(dilations) == 4, 'Length of dilations \
-            must be 4!'
-
-        assert out_channels == in_channels[1], 'Output channels must \
-            be the same with in_channels[1]!'
-
         assert isinstance(dilations, tuple)
         self.in_channels = in_channels
-        self.out_channels = out_channels
+        self.out_channels = in_channels[start_level]
+        self.start_level = start_level
+        self.num_ins = len(in_channels)
+        if end_level == -1:
+            self.backbone_end_level = self.num_ins
+        else:
+            self.backbone_end_level = end_level
+            assert end_level <= len(in_channels)
 
         self.dilations = dilations
         self.align_corners = align_corners
         self.out_indices = out_indices
 
         # Note: Names of operations below are referenced from original paper.
-        for i in range(3):
-            conv_name = f'conv{i+3}'
+        self.conv_layers = nn.ModuleList()
+        self.dilation_layers = nn.ModuleList()
+        for i in range(self.start_level, self.backbone_end_level):
             conv_layer = nn.Sequential(
                 ConvModule(
-                    self.in_channels[i - 3],
+                    self.in_channels[i],
                     self.out_channels,
                     kernel_size=3,
                     padding=1,
                     bias=False,
                     conv_cfg=conv_cfg,
                     norm_cfg=norm_cfg,
                     act_cfg=act_cfg))
-            self.add_module(conv_name, conv_layer)
-        for i in range(4):
-            dilation_name = f'dilation{i+1}'
+            self.conv_layers.append(conv_layer)
+        for i in range(len(dilations)):
             dilation_layer = nn.Sequential(
                 DepthwiseSeparableConvModule(
-                    in_channels=3 * self.out_channels,
+                    in_channels=(len(in_channels) - 1) * self.out_channels,
                     out_channels=self.out_channels,
                     kernel_size=3,
                     stride=1,
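
To make the refactor concrete, a worked example with the defaults from fastfcn_r50-d32.py above (a sketch of the arithmetic only; the variable names are illustrative, not code from this PR):

    in_channels = (256, 512, 1024, 2048)
    start_level, end_level = 1, -1

    out_channels = in_channels[start_level]            # 512, replaces the old kwarg
    backbone_end_level = len(in_channels)              # end_level == -1 means last level
    num_convs = backbone_end_level - start_level       # 3, was the hard-coded range(3)
    branch_in = (len(in_channels) - 1) * out_channels  # 1536, was 3 * out_channels
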
@@ -90,35 +95,35 @@ def __init__(self,
                     dw_act_cfg=None,
                     pw_norm_cfg=norm_cfg,
                     pw_act_cfg=act_cfg))
-            self.add_module(dilation_name, dilation_layer)
+            self.dilation_layers.append(dilation_layer)
 
     def forward(self, inputs):
         """Forward function."""
-        x_8 = inputs[1]
-        x_16 = inputs[2]
-        x_32 = inputs[3]
-        feats = [self.conv5(x_32), self.conv4(x_16), self.conv3(x_8)]
+        assert len(inputs) == len(self.in_channels), 'Length of inputs must \
+            be the same with self.in_channels!'
+
+        feats = [
+            self.conv_layers[i](inputs[i + 1])
+            for i in range(len(self.in_channels) - 1)
+        ]
 
-        _, _, h, w = feats[-1].size()
-        feats[-2] = resize(
-            feats[-2],
-            size=(h, w),
-            mode='bilinear',
-            align_corners=self.align_corners)
-        feats[-3] = resize(
-            feats[-3],
-            size=(h, w),
-            mode='bilinear',
-            align_corners=self.align_corners)
-        feat = torch.cat(feats, dim=1)
+        _, _, h, w = feats[0].size()
+        for i in range(len(feats) - 1):
+            feats[i + 1] = resize(
+                feats[i + 1],
+                size=(h, w),
+                mode='bilinear',
+                align_corners=self.align_corners)
+
+        feat = torch.cat(feats[::-1], dim=1)
 
         feat = torch.cat([
-            self.dilation1(feat),
-            self.dilation2(feat),
-            self.dilation3(feat),
-            self.dilation4(feat)
+            self.dilation_layers[i](feat) for i in range(len(self.dilations))
         ],
                          dim=1)
 
-        outs = [inputs[0], inputs[1], inputs[2], feat]
+        outs = []
+        for i in range(len(inputs) - 1):
+            outs.append(inputs[i])
+        outs.append(feat)
         outs = [outs[i] for i in self.out_indices]
         return tuple(outs)
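
A minimal usage sketch of the refactored neck; the spatial sizes mirror the shapes asserted in tests/test_models/test_necks/test_jpu.py below, with an assumed batch size of 2:

    import torch

    from mmseg.models.necks import JPU

    neck = JPU(
        in_channels=(256, 512, 1024, 2048),
        start_level=1,
        end_level=-1,
        dilations=(1, 2, 4, 8),
        out_indices=(0, 1, 2, 3))

    inputs = [
        torch.randn(2, 256, 128, 256),  # stride-4 feature, passed through as-is
        torch.randn(2, 512, 64, 128),   # stride-8 feature
        torch.randn(2, 1024, 32, 64),   # stride-16 feature
        torch.randn(2, 2048, 16, 32),   # stride-32 feature
    ]
    outs = neck(inputs)

    # outs[0:3] are inputs[0:3] untouched; outs[3] concatenates the four
    # dilation branches (4 * 512 channels) at the stride-8 resolution.
    assert outs[3].shape == torch.Size([2, 2048, 64, 128])
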
5 changes: 5 additions & 0 deletions tests/test_models/test_necks/test_jpu.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import pytest
 import torch
 
 from mmseg.models.necks import JPU
@@ -23,3 +24,7 @@ def test_fastfcn_neck():
     assert feat[1].shape == torch.Size([batch_size, 512, 64, 128])
     assert feat[2].shape == torch.Size([batch_size, 1024, 32, 64])
     assert feat[3].shape == torch.Size([batch_size, 2048, 64, 128])
+
+    with pytest.raises(AssertionError):
+        # FastFCN input and in_channels constraints.
+        JPU(in_channels=(128, 256, 512, 1024), start_level=1, end_level=5)
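
The negative case exercises the new bound check in JPU.__init__ (end_level=5 exceeds the four backbone levels). As a complementary sanity check (an illustration, not part of this PR), the same channels pass once end_level is in range:

    from mmseg.models.necks import JPU

    neck = JPU(in_channels=(128, 256, 512, 1024), start_level=1, end_level=-1)
    assert neck.out_channels == 256  # derived as in_channels[start_level]
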