From 99088c81a8db77ec4bb3fb16f8703a82e2433210 Mon Sep 17 00:00:00 2001 From: young <45293661+yyz561@users.noreply.github.com> Date: Tue, 7 Sep 2021 19:45:14 +0800 Subject: [PATCH] [Fix] Restrict the warning message (#1267) * restrict the warning message * and an important keyword in warning description * a more elegant way of condition * link format code too long * fix the stupid spelling mistake * Use issubclass to restrict warning message. * maybe this version is more elegant. * conv + bias + norm warning pytest * 'created' a warning, hahaha * isort and yapf format revision * isort and yapf format revision * flake8 fail issue * I have to right this way in order to solve the conflicts between yapf and flake8, sigh... * fixed test bug * Add ruby windows installer source. * Simplified the code and remove ruby source from CONTRIBUTING.md * use _BatchNorm to simplify the code * bug fix and add instanceNorm case into warning * change the warning message to make it more clear * fix unit test --- mmcv/cnn/bricks/conv_module.py | 8 +++++--- tests/test_cnn/test_conv_module.py | 20 ++++++++++++++++++-- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/mmcv/cnn/bricks/conv_module.py b/mmcv/cnn/bricks/conv_module.py index e71d243e15..4f19f1d0cf 100644 --- a/mmcv/cnn/bricks/conv_module.py +++ b/mmcv/cnn/bricks/conv_module.py @@ -3,6 +3,7 @@ import torch.nn as nn +from mmcv.utils import _BatchNorm, _InstanceNorm from ..utils import constant_init, kaiming_init from .activation import build_activation_layer from .conv import build_conv_layer @@ -104,9 +105,6 @@ def __init__(self, bias = not self.with_norm self.with_bias = bias - if self.with_norm and self.with_bias: - warnings.warn('ConvModule has norm and bias at the same time') - if self.with_explicit_padding: pad_cfg = dict(type=padding_mode) self.padding_layer = build_padding_layer(pad_cfg, padding) @@ -147,6 +145,10 @@ def __init__(self, norm_channels = in_channels self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels) self.add_module(self.norm_name, norm) + if self.with_bias: + if isinstance(norm, (_BatchNorm, _InstanceNorm)): + warnings.warn( + 'Unnecessary conv bias before batch/instance norm') else: self.norm_name = None diff --git a/tests/test_cnn/test_conv_module.py b/tests/test_cnn/test_conv_module.py index 63dc61ab87..e231ef3ae3 100644 --- a/tests/test_cnn/test_conv_module.py +++ b/tests/test_cnn/test_conv_module.py @@ -1,3 +1,4 @@ +import warnings from unittest.mock import patch import pytest @@ -161,12 +162,27 @@ def test_bias(): conv = ConvModule(3, 8, 2, bias=False) assert conv.conv.bias is None - # bias: True, with norm + # bias: True, with batch norm with pytest.warns(UserWarning) as record: ConvModule(3, 8, 2, bias=True, norm_cfg=dict(type='BN')) assert len(record) == 1 assert record[0].message.args[ - 0] == 'ConvModule has norm and bias at the same time' + 0] == 'Unnecessary conv bias before batch/instance norm' + + # bias: True, with instance norm + with pytest.warns(UserWarning) as record: + ConvModule(3, 8, 2, bias=True, norm_cfg=dict(type='IN')) + assert len(record) == 1 + assert record[0].message.args[ + 0] == 'Unnecessary conv bias before batch/instance norm' + + # bias: True, with other norm + with pytest.warns(UserWarning) as record: + norm_cfg = dict(type='GN', num_groups=1) + ConvModule(3, 8, 2, bias=True, norm_cfg=norm_cfg) + warnings.warn('No warnings') + assert len(record) == 1 + assert record[0].message.args[0] == 'No warnings' def conv_forward(self, x):