diff --git a/mmseg/models/decode_heads/fcn_head.py b/mmseg/models/decode_heads/fcn_head.py index fb79a0d7c7..0ba09a010f 100644 --- a/mmseg/models/decode_heads/fcn_head.py +++ b/mmseg/models/decode_heads/fcn_head.py @@ -37,28 +37,21 @@ def __init__(self, conv_padding = (kernel_size // 2) * dilation convs = [] - convs.append( - ConvModule( - self.in_channels, - self.channels, - kernel_size=kernel_size, - padding=conv_padding, - dilation=dilation, - conv_cfg=self.conv_cfg, - norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg)) - for i in range(num_convs - 1): + for i in range(num_convs): + _in_channels = self.in_channels if i == 0 else self.channels convs.append( ConvModule( - self.channels, + _in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg)) - if num_convs == 0: + act_cfg=self.act_cfg + )) + + if len(convs) == 0: self.convs = nn.Identity() else: self.convs = nn.Sequential(*convs)