From 3b33fdf4e311862afcb93936547e03131001c22d Mon Sep 17 00:00:00 2001 From: xiexinch Date: Thu, 22 Apr 2021 20:20:17 +0800 Subject: [PATCH 1/5] init --- mmseg/models/necks/upsample_neck.py | 35 +++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 mmseg/models/necks/upsample_neck.py diff --git a/mmseg/models/necks/upsample_neck.py b/mmseg/models/necks/upsample_neck.py new file mode 100644 index 0000000000..4d7c6ac392 --- /dev/null +++ b/mmseg/models/necks/upsample_neck.py @@ -0,0 +1,35 @@ +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Conv2d + +from ..builder import NECKS + + +@NECKS.register_module() +class UpsampleNeck(nn.Module): + """Upsample Network.""" + + def __init__(self, + in_channels, + out_channels, + scales=[0.5, 1, 2, 4], + num_outs=4): + super(UpsampleNeck, self).__init__() + assert len(scales) == num_outs + self.scales = scales + self.num_outs = num_outs + self.convs = [ + Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1) + for _ in range(num_outs) + ] + + def forward(self, x): + outs = [] + print(len(self.convs)) + for i in range(self.num_outs): + scale = self.scales[i] + x = self.convs[i](x) + outs.append( + F.interpolate(x, size=x.shape[:2] * scale, mode='bilinear')) + return tuple(outs) From 6c00f926a9e0bf5d75125164554b34301ae8e975 Mon Sep 17 00:00:00 2001 From: xiexinch Date: Fri, 23 Apr 2021 18:54:16 +0800 Subject: [PATCH 2/5] upsample v1.0 --- mmseg/models/necks/__init__.py | 3 ++- mmseg/models/necks/upsample_neck.py | 15 +++++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/mmseg/models/necks/__init__.py b/mmseg/models/necks/__init__.py index 0093021eba..3eb794b3e8 100644 --- a/mmseg/models/necks/__init__.py +++ b/mmseg/models/necks/__init__.py @@ -1,3 +1,4 @@ from .fpn import FPN +from .upsample_neck import UpsampleNeck -__all__ = ['FPN'] +__all__ = ['FPN', 'UpsampleNeck'] diff --git a/mmseg/models/necks/upsample_neck.py b/mmseg/models/necks/upsample_neck.py index 4d7c6ac392..31de5fab2f 100644 --- a/mmseg/models/necks/upsample_neck.py +++ b/mmseg/models/necks/upsample_neck.py @@ -1,3 +1,4 @@ +import numpy as np import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import Conv2d @@ -18,18 +19,24 @@ def __init__(self, assert len(scales) == num_outs self.scales = scales self.num_outs = num_outs + self.conv1 = Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.convs = [ Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1) + out_channels, out_channels, kernel_size=3, stride=1, padding=1) for _ in range(num_outs) ] def forward(self, x): + x = self.conv1(x) + outs = [] - print(len(self.convs)) for i in range(self.num_outs): - scale = self.scales[i] x = self.convs[i](x) outs.append( - F.interpolate(x, size=x.shape[:2] * scale, mode='bilinear')) + F.interpolate( + x, + size=list( + (np.array(x.shape[2:]) * self.scales[i]).astype(int)), + mode='bilinear')) return tuple(outs) From de771f43a995b6a95be395a021e2f43e30e1cb44 Mon Sep 17 00:00:00 2001 From: xiexinch Date: Fri, 23 Apr 2021 19:10:03 +0800 Subject: [PATCH 3/5] fix errors --- mmseg/models/necks/upsample_neck.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/mmseg/models/necks/upsample_neck.py b/mmseg/models/necks/upsample_neck.py index 31de5fab2f..57683ce0d7 100644 --- a/mmseg/models/necks/upsample_neck.py +++ b/mmseg/models/necks/upsample_neck.py @@ -19,8 +19,7 @@ def __init__(self, assert len(scales) == num_outs self.scales = scales self.num_outs = num_outs - self.conv1 = Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.conv1 = Conv2d(in_channels, out_channels, 3, 1, 1) self.convs = [ Conv2d( out_channels, out_channels, kernel_size=3, stride=1, padding=1) @@ -29,14 +28,12 @@ def __init__(self, def forward(self, x): x = self.conv1(x) - outs = [] for i in range(self.num_outs): - x = self.convs[i](x) - outs.append( - F.interpolate( - x, - size=list( - (np.array(x.shape[2:]) * self.scales[i]).astype(int)), - mode='bilinear')) + x_resize = F.interpolate( + x, + size=list( + (np.array(x.shape[2:]) * self.scales[i]).astype(int)), + mode='bilinear') + outs.append(self.convs[i](x_resize)) return tuple(outs) From 8cd7604e1ab452eabf220c0e6856353ae0cf5713 Mon Sep 17 00:00:00 2001 From: xiexinch Date: Sat, 24 Apr 2021 02:27:40 +0800 Subject: [PATCH 4/5] change to in_channels list --- mmseg/models/necks/upsample_neck.py | 41 ++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/mmseg/models/necks/upsample_neck.py b/mmseg/models/necks/upsample_neck.py index 57683ce0d7..c77aad45e4 100644 --- a/mmseg/models/necks/upsample_neck.py +++ b/mmseg/models/necks/upsample_neck.py @@ -1,7 +1,7 @@ import numpy as np import torch.nn as nn import torch.nn.functional as F -from mmcv.cnn import Conv2d +from mmcv.cnn import ConvModule from ..builder import NECKS @@ -16,24 +16,41 @@ def __init__(self, scales=[0.5, 1, 2, 4], num_outs=4): super(UpsampleNeck, self).__init__() + assert isinstance(in_channels, list) assert len(scales) == num_outs + self.in_channels = in_channels + self.out_channels = out_channels self.scales = scales self.num_outs = num_outs - self.conv1 = Conv2d(in_channels, out_channels, 3, 1, 1) - self.convs = [ - Conv2d( - out_channels, out_channels, kernel_size=3, stride=1, padding=1) - for _ in range(num_outs) - ] + self.lateral_convs = nn.ModuleList() + self.convs = nn.ModuleList() + for in_channel in in_channels: + self.lateral_convs.append( + ConvModule(in_channel, out_channels, kernel_size=1)) + for _ in range(num_outs): + self.convs.append( + ConvModule( + out_channels, + out_channels, + kernel_size=3, + padding=1, + stride=1)) - def forward(self, x): - x = self.conv1(x) + def forward(self, inputs): + assert len(inputs) == len(self.in_channels) + inputs = [ + lateral_conv(inputs[i]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + # support num_outs == 1 or num_outs == 4 + if len(inputs) == 1: + inputs = inputs * self.num_outs outs = [] for i in range(self.num_outs): x_resize = F.interpolate( - x, - size=list( - (np.array(x.shape[2:]) * self.scales[i]).astype(int)), + inputs[i], + size=list((np.array(inputs[i].shape[2:]) * + self.scales[i]).astype(int)), mode='bilinear') outs.append(self.convs[i](x_resize)) return tuple(outs) From a6ae77b35c10cb58e3bf60c649e377fcb1d6f0de Mon Sep 17 00:00:00 2001 From: xiexinch Date: Sun, 25 Apr 2021 12:03:19 +0800 Subject: [PATCH 5/5] add unittest, docstring, norm/act config and rename --- mmseg/models/necks/__init__.py | 4 +- mmseg/models/necks/multilevel_neck.py | 70 +++++++++++++++++++ mmseg/models/necks/upsample_neck.py | 56 --------------- .../test_necks/test_multilevel_neck.py | 28 ++++++++ 4 files changed, 100 insertions(+), 58 deletions(-) create mode 100644 mmseg/models/necks/multilevel_neck.py delete mode 100644 mmseg/models/necks/upsample_neck.py create mode 100644 tests/test_models/test_necks/test_multilevel_neck.py diff --git a/mmseg/models/necks/__init__.py b/mmseg/models/necks/__init__.py index 3eb794b3e8..9b9d3d5b3f 100644 --- a/mmseg/models/necks/__init__.py +++ b/mmseg/models/necks/__init__.py @@ -1,4 +1,4 @@ from .fpn import FPN -from .upsample_neck import UpsampleNeck +from .multilevel_neck import MultiLevelNeck -__all__ = ['FPN', 'UpsampleNeck'] +__all__ = ['FPN', 'MultiLevelNeck'] diff --git a/mmseg/models/necks/multilevel_neck.py b/mmseg/models/necks/multilevel_neck.py new file mode 100644 index 0000000000..7e13813b16 --- /dev/null +++ b/mmseg/models/necks/multilevel_neck.py @@ -0,0 +1,70 @@ +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from ..builder import NECKS + + +@NECKS.register_module() +class MultiLevelNeck(nn.Module): + """MultiLevelNeck. + + A neck structure connect vit backbone and decoder_heads. + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + scales (List[int]): Scale factors for each input feature map. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + scales=[0.5, 1, 2, 4], + norm_cfg=None, + act_cfg=None): + super(MultiLevelNeck, self).__init__() + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.scales = scales + self.num_outs = len(scales) + self.lateral_convs = nn.ModuleList() + self.convs = nn.ModuleList() + for in_channel in in_channels: + self.lateral_convs.append( + ConvModule( + in_channel, + out_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + for _ in range(self.num_outs): + self.convs.append( + ConvModule( + out_channels, + out_channels, + kernel_size=3, + padding=1, + stride=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, inputs): + assert len(inputs) == len(self.in_channels) + print(inputs[0].shape) + inputs = [ + lateral_conv(inputs[i]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + # for len(inputs) not equal to self.num_outs + if len(inputs) == 1: + inputs = [inputs[0] for _ in range(self.num_outs)] + outs = [] + for i in range(self.num_outs): + x_resize = F.interpolate( + inputs[i], scale_factor=self.scales[i], mode='bilinear') + outs.append(self.convs[i](x_resize)) + return tuple(outs) diff --git a/mmseg/models/necks/upsample_neck.py b/mmseg/models/necks/upsample_neck.py deleted file mode 100644 index c77aad45e4..0000000000 --- a/mmseg/models/necks/upsample_neck.py +++ /dev/null @@ -1,56 +0,0 @@ -import numpy as np -import torch.nn as nn -import torch.nn.functional as F -from mmcv.cnn import ConvModule - -from ..builder import NECKS - - -@NECKS.register_module() -class UpsampleNeck(nn.Module): - """Upsample Network.""" - - def __init__(self, - in_channels, - out_channels, - scales=[0.5, 1, 2, 4], - num_outs=4): - super(UpsampleNeck, self).__init__() - assert isinstance(in_channels, list) - assert len(scales) == num_outs - self.in_channels = in_channels - self.out_channels = out_channels - self.scales = scales - self.num_outs = num_outs - self.lateral_convs = nn.ModuleList() - self.convs = nn.ModuleList() - for in_channel in in_channels: - self.lateral_convs.append( - ConvModule(in_channel, out_channels, kernel_size=1)) - for _ in range(num_outs): - self.convs.append( - ConvModule( - out_channels, - out_channels, - kernel_size=3, - padding=1, - stride=1)) - - def forward(self, inputs): - assert len(inputs) == len(self.in_channels) - inputs = [ - lateral_conv(inputs[i]) - for i, lateral_conv in enumerate(self.lateral_convs) - ] - # support num_outs == 1 or num_outs == 4 - if len(inputs) == 1: - inputs = inputs * self.num_outs - outs = [] - for i in range(self.num_outs): - x_resize = F.interpolate( - inputs[i], - size=list((np.array(inputs[i].shape[2:]) * - self.scales[i]).astype(int)), - mode='bilinear') - outs.append(self.convs[i](x_resize)) - return tuple(outs) diff --git a/tests/test_models/test_necks/test_multilevel_neck.py b/tests/test_models/test_necks/test_multilevel_neck.py new file mode 100644 index 0000000000..8fb2fc9280 --- /dev/null +++ b/tests/test_models/test_necks/test_multilevel_neck.py @@ -0,0 +1,28 @@ +import torch + +from mmseg.models import MultiLevelNeck + + +def test_multilevel_neck(): + + # Test multi feature maps + in_channels = [256, 512, 1024, 2048] + inputs = [torch.randn(1, c, 14, 14) for i, c in enumerate(in_channels)] + + neck = MultiLevelNeck(in_channels, 256) + outputs = neck(inputs) + assert outputs[0].shape == torch.Size([1, 256, 7, 7]) + assert outputs[1].shape == torch.Size([1, 256, 14, 14]) + assert outputs[2].shape == torch.Size([1, 256, 28, 28]) + assert outputs[3].shape == torch.Size([1, 256, 56, 56]) + + # Test one feature map + in_channels = [768] + inputs = [torch.randn(1, 768, 14, 14)] + + neck = MultiLevelNeck(in_channels, 256) + outputs = neck(inputs) + assert outputs[0].shape == torch.Size([1, 256, 7, 7]) + assert outputs[1].shape == torch.Size([1, 256, 14, 14]) + assert outputs[2].shape == torch.Size([1, 256, 28, 28]) + assert outputs[3].shape == torch.Size([1, 256, 56, 56])