diff --git a/mmcv/cnn/bricks/conv_module.py b/mmcv/cnn/bricks/conv_module.py index e71d243e15..4f19f1d0cf 100644 --- a/mmcv/cnn/bricks/conv_module.py +++ b/mmcv/cnn/bricks/conv_module.py @@ -3,6 +3,7 @@ import torch.nn as nn +from mmcv.utils import _BatchNorm, _InstanceNorm from ..utils import constant_init, kaiming_init from .activation import build_activation_layer from .conv import build_conv_layer @@ -104,9 +105,6 @@ def __init__(self, bias = not self.with_norm self.with_bias = bias - if self.with_norm and self.with_bias: - warnings.warn('ConvModule has norm and bias at the same time') - if self.with_explicit_padding: pad_cfg = dict(type=padding_mode) self.padding_layer = build_padding_layer(pad_cfg, padding) @@ -147,6 +145,10 @@ def __init__(self, norm_channels = in_channels self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels) self.add_module(self.norm_name, norm) + if self.with_bias: + if isinstance(norm, (_BatchNorm, _InstanceNorm)): + warnings.warn( + 'Unnecessary conv bias before batch/instance norm') else: self.norm_name = None diff --git a/tests/test_cnn/test_conv_module.py b/tests/test_cnn/test_conv_module.py index 63dc61ab87..e231ef3ae3 100644 --- a/tests/test_cnn/test_conv_module.py +++ b/tests/test_cnn/test_conv_module.py @@ -1,3 +1,4 @@ +import warnings from unittest.mock import patch import pytest @@ -161,12 +162,27 @@ def test_bias(): conv = ConvModule(3, 8, 2, bias=False) assert conv.conv.bias is None - # bias: True, with norm + # bias: True, with batch norm with pytest.warns(UserWarning) as record: ConvModule(3, 8, 2, bias=True, norm_cfg=dict(type='BN')) assert len(record) == 1 assert record[0].message.args[ - 0] == 'ConvModule has norm and bias at the same time' + 0] == 'Unnecessary conv bias before batch/instance norm' + + # bias: True, with instance norm + with pytest.warns(UserWarning) as record: + ConvModule(3, 8, 2, bias=True, norm_cfg=dict(type='IN')) + assert len(record) == 1 + assert record[0].message.args[ + 0] == 'Unnecessary conv bias before batch/instance norm' + + # bias: True, with other norm + with pytest.warns(UserWarning) as record: + norm_cfg = dict(type='GN', num_groups=1) + ConvModule(3, 8, 2, bias=True, norm_cfg=norm_cfg) + warnings.warn('No warnings') + assert len(record) == 1 + assert record[0].message.args[0] == 'No warnings' def conv_forward(self, x):