From 98ef5ac705aa575ebcd248acc874d257a211fa67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E6=98=95=E8=BE=B0?= Date: Sun, 25 Apr 2021 12:22:09 +0800 Subject: [PATCH] add upsample neck (#512) * init * upsample v1.0 * fix errors * change to in_channels list * add unittest, docstring, norm/act config and rename Co-authored-by: xiexinch --- mmseg/models/necks/__init__.py | 3 +- mmseg/models/necks/multilevel_neck.py | 70 +++++++++++++++++++ .../test_necks/test_multilevel_neck.py | 28 ++++++++ 3 files changed, 100 insertions(+), 1 deletion(-) create mode 100644 mmseg/models/necks/multilevel_neck.py create mode 100644 tests/test_models/test_necks/test_multilevel_neck.py diff --git a/mmseg/models/necks/__init__.py b/mmseg/models/necks/__init__.py index 0093021eba..9b9d3d5b3f 100644 --- a/mmseg/models/necks/__init__.py +++ b/mmseg/models/necks/__init__.py @@ -1,3 +1,4 @@ from .fpn import FPN +from .multilevel_neck import MultiLevelNeck -__all__ = ['FPN'] +__all__ = ['FPN', 'MultiLevelNeck'] diff --git a/mmseg/models/necks/multilevel_neck.py b/mmseg/models/necks/multilevel_neck.py new file mode 100644 index 0000000000..7e13813b16 --- /dev/null +++ b/mmseg/models/necks/multilevel_neck.py @@ -0,0 +1,70 @@ +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from ..builder import NECKS + + +@NECKS.register_module() +class MultiLevelNeck(nn.Module): + """MultiLevelNeck. + + A neck structure connect vit backbone and decoder_heads. + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + scales (List[int]): Scale factors for each input feature map. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + scales=[0.5, 1, 2, 4], + norm_cfg=None, + act_cfg=None): + super(MultiLevelNeck, self).__init__() + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.scales = scales + self.num_outs = len(scales) + self.lateral_convs = nn.ModuleList() + self.convs = nn.ModuleList() + for in_channel in in_channels: + self.lateral_convs.append( + ConvModule( + in_channel, + out_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + for _ in range(self.num_outs): + self.convs.append( + ConvModule( + out_channels, + out_channels, + kernel_size=3, + padding=1, + stride=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, inputs): + assert len(inputs) == len(self.in_channels) + print(inputs[0].shape) + inputs = [ + lateral_conv(inputs[i]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + # for len(inputs) not equal to self.num_outs + if len(inputs) == 1: + inputs = [inputs[0] for _ in range(self.num_outs)] + outs = [] + for i in range(self.num_outs): + x_resize = F.interpolate( + inputs[i], scale_factor=self.scales[i], mode='bilinear') + outs.append(self.convs[i](x_resize)) + return tuple(outs) diff --git a/tests/test_models/test_necks/test_multilevel_neck.py b/tests/test_models/test_necks/test_multilevel_neck.py new file mode 100644 index 0000000000..8fb2fc9280 --- /dev/null +++ b/tests/test_models/test_necks/test_multilevel_neck.py @@ -0,0 +1,28 @@ +import torch + +from mmseg.models import MultiLevelNeck + + +def test_multilevel_neck(): + + # Test multi feature maps + in_channels = [256, 512, 1024, 2048] + inputs = [torch.randn(1, c, 14, 14) for i, c in enumerate(in_channels)] + + neck = MultiLevelNeck(in_channels, 256) + outputs = neck(inputs) + assert outputs[0].shape == torch.Size([1, 256, 7, 7]) + assert outputs[1].shape == torch.Size([1, 256, 14, 14]) + assert outputs[2].shape == torch.Size([1, 256, 28, 28]) + assert outputs[3].shape == torch.Size([1, 256, 56, 56]) + + # Test one feature map + in_channels = [768] + inputs = [torch.randn(1, 768, 14, 14)] + + neck = MultiLevelNeck(in_channels, 256) + outputs = neck(inputs) + assert outputs[0].shape == torch.Size([1, 256, 7, 7]) + assert outputs[1].shape == torch.Size([1, 256, 14, 14]) + assert outputs[2].shape == torch.Size([1, 256, 28, 28]) + assert outputs[3].shape == torch.Size([1, 256, 56, 56])