-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* init * upsample v1.0 * fix errors * change to in_channels list * add unittest, docstring, norm/act config and rename Co-authored-by: xiexinch <test767803@foxmail.com>
- Loading branch information
谢昕辰
and
xiexinch
authored
Apr 25, 2021
1 parent
84fb600
commit 98ef5ac
Showing
3 changed files
with
100 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .fpn import FPN | ||
from .multilevel_neck import MultiLevelNeck | ||
|
||
__all__ = ['FPN'] | ||
__all__ = ['FPN', 'MultiLevelNeck'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from mmcv.cnn import ConvModule | ||
|
||
from ..builder import NECKS | ||
|
||
|
||
@NECKS.register_module() | ||
class MultiLevelNeck(nn.Module): | ||
"""MultiLevelNeck. | ||
A neck structure connect vit backbone and decoder_heads. | ||
Args: | ||
in_channels (List[int]): Number of input channels per scale. | ||
out_channels (int): Number of output channels (used at each scale). | ||
scales (List[int]): Scale factors for each input feature map. | ||
norm_cfg (dict): Config dict for normalization layer. Default: None. | ||
act_cfg (dict): Config dict for activation layer in ConvModule. | ||
Default: None. | ||
""" | ||
|
||
def __init__(self, | ||
in_channels, | ||
out_channels, | ||
scales=[0.5, 1, 2, 4], | ||
norm_cfg=None, | ||
act_cfg=None): | ||
super(MultiLevelNeck, self).__init__() | ||
assert isinstance(in_channels, list) | ||
self.in_channels = in_channels | ||
self.out_channels = out_channels | ||
self.scales = scales | ||
self.num_outs = len(scales) | ||
self.lateral_convs = nn.ModuleList() | ||
self.convs = nn.ModuleList() | ||
for in_channel in in_channels: | ||
self.lateral_convs.append( | ||
ConvModule( | ||
in_channel, | ||
out_channels, | ||
kernel_size=1, | ||
norm_cfg=norm_cfg, | ||
act_cfg=act_cfg)) | ||
for _ in range(self.num_outs): | ||
self.convs.append( | ||
ConvModule( | ||
out_channels, | ||
out_channels, | ||
kernel_size=3, | ||
padding=1, | ||
stride=1, | ||
norm_cfg=norm_cfg, | ||
act_cfg=act_cfg)) | ||
|
||
def forward(self, inputs): | ||
assert len(inputs) == len(self.in_channels) | ||
print(inputs[0].shape) | ||
inputs = [ | ||
lateral_conv(inputs[i]) | ||
for i, lateral_conv in enumerate(self.lateral_convs) | ||
] | ||
# for len(inputs) not equal to self.num_outs | ||
if len(inputs) == 1: | ||
inputs = [inputs[0] for _ in range(self.num_outs)] | ||
outs = [] | ||
for i in range(self.num_outs): | ||
x_resize = F.interpolate( | ||
inputs[i], scale_factor=self.scales[i], mode='bilinear') | ||
outs.append(self.convs[i](x_resize)) | ||
return tuple(outs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import torch | ||
|
||
from mmseg.models import MultiLevelNeck | ||
|
||
|
||
def test_multilevel_neck(): | ||
|
||
# Test multi feature maps | ||
in_channels = [256, 512, 1024, 2048] | ||
inputs = [torch.randn(1, c, 14, 14) for i, c in enumerate(in_channels)] | ||
|
||
neck = MultiLevelNeck(in_channels, 256) | ||
outputs = neck(inputs) | ||
assert outputs[0].shape == torch.Size([1, 256, 7, 7]) | ||
assert outputs[1].shape == torch.Size([1, 256, 14, 14]) | ||
assert outputs[2].shape == torch.Size([1, 256, 28, 28]) | ||
assert outputs[3].shape == torch.Size([1, 256, 56, 56]) | ||
|
||
# Test one feature map | ||
in_channels = [768] | ||
inputs = [torch.randn(1, 768, 14, 14)] | ||
|
||
neck = MultiLevelNeck(in_channels, 256) | ||
outputs = neck(inputs) | ||
assert outputs[0].shape == torch.Size([1, 256, 7, 7]) | ||
assert outputs[1].shape == torch.Size([1, 256, 14, 14]) | ||
assert outputs[2].shape == torch.Size([1, 256, 28, 28]) | ||
assert outputs[3].shape == torch.Size([1, 256, 56, 56]) |