diff --git a/mmcv/cnn/utils/flops_counter.py b/mmcv/cnn/utils/flops_counter.py index 2c89ebebeab..27aec347a5f 100644 --- a/mmcv/cnn/utils/flops_counter.py +++ b/mmcv/cnn/utils/flops_counter.py @@ -56,7 +56,8 @@ def get_model_complexity_info(model, ``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``, ``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``. - BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``, - ``nn.BatchNorm3d``. + ``nn.BatchNorm3d``, ``nn.GroupNorm``, ``nn.InstanceNorm1d``, + ``InstanceNorm2d``, ``InstanceNorm3d``, ``nn.LayerNorm``. - Linear: ``nn.Linear``. - Deconvolution: ``nn.ConvTranspose2d``. - Upsample: ``nn.Upsample``. @@ -426,11 +427,12 @@ def pool_flops_counter_hook(module, input, output): module.__flops__ += int(np.prod(input.shape)) -def bn_flops_counter_hook(module, input, output): +def norm_flops_counter_hook(module, input, output): input = input[0] batch_flops = np.prod(input.shape) - if module.affine: + if (getattr(module, 'affine', False) + or getattr(module, 'elementwise_affine', False)): batch_flops *= 2 module.__flops__ += int(batch_flops) @@ -577,10 +579,15 @@ def get_modules_mapping(): nn.AdaptiveAvgPool2d: pool_flops_counter_hook, nn.AdaptiveMaxPool3d: pool_flops_counter_hook, nn.AdaptiveAvgPool3d: pool_flops_counter_hook, - # BNs - nn.BatchNorm1d: bn_flops_counter_hook, - nn.BatchNorm2d: bn_flops_counter_hook, - nn.BatchNorm3d: bn_flops_counter_hook, + # normalizations + nn.BatchNorm1d: norm_flops_counter_hook, + nn.BatchNorm2d: norm_flops_counter_hook, + nn.BatchNorm3d: norm_flops_counter_hook, + nn.GroupNorm: norm_flops_counter_hook, + nn.InstanceNorm1d: norm_flops_counter_hook, + nn.InstanceNorm2d: norm_flops_counter_hook, + nn.InstanceNorm3d: norm_flops_counter_hook, + nn.LayerNorm: norm_flops_counter_hook, # FC nn.Linear: linear_flops_counter_hook, mmcv.cnn.bricks.Linear: linear_flops_counter_hook, diff --git a/tests/test_cnn/test_flops_counter.py b/tests/test_cnn/test_flops_counter.py index cc8b8f1658a..99a53b75714 100644 --- a/tests/test_cnn/test_flops_counter.py +++ b/tests/test_cnn/test_flops_counter.py @@ -32,9 +32,15 @@ {'model': nn.AdaptiveAvgPool1d(2), 'input': (3, 16), 'flops': 48.0, 'params': 0}, # noqa: E501 {'model': nn.AdaptiveAvgPool2d(2), 'input': (3, 16, 16), 'flops': 768.0, 'params': 0}, # noqa: E501 {'model': nn.AdaptiveAvgPool3d(2), 'input': (3, 3, 16, 16), 'flops': 2304.0, 'params': 0}, # noqa: E501 - {'model': nn.BatchNorm1d(3, 8), 'input': (3, 16), 'flops': 96.0, 'params': 6.0}, # noqa: E501 - {'model': nn.BatchNorm2d(3, 8), 'input': (3, 16, 16), 'flops': 1536.0, 'params': 6.0}, # noqa: E501 - {'model': nn.BatchNorm3d(3, 8), 'input': (3, 3, 16, 16), 'flops': 4608.0, 'params': 6.0}, # noqa: E501 + {'model': nn.BatchNorm1d(3), 'input': (3, 16), 'flops': 96.0, 'params': 6.0}, # noqa: E501 + {'model': nn.BatchNorm2d(3), 'input': (3, 16, 16), 'flops': 1536.0, 'params': 6.0}, # noqa: E501 + {'model': nn.BatchNorm3d(3), 'input': (3, 3, 16, 16), 'flops': 4608.0, 'params': 6.0}, # noqa: E501 + {'model': nn.GroupNorm(2, 6), 'input': (6, 16, 16), 'flops': 3072.0, 'params': 12.0}, # noqa: E501 + {'model': nn.InstanceNorm1d(3, affine=True), 'input': (3, 16), 'flops': 96.0, 'params': 6.0}, # noqa: E501 + {'model': nn.InstanceNorm2d(3, affine=True), 'input': (3, 16, 16), 'flops': 1536.0, 'params': 6.0}, # noqa: E501 + {'model': nn.InstanceNorm3d(3, affine=True), 'input': (3, 3, 16, 16), 'flops': 4608.0, 'params': 6.0}, # noqa: E501 + {'model': nn.LayerNorm((3, 16, 16)), 'input': (3, 16, 16), 'flops': 1536.0, 'params': 1536.0}, # noqa: E501 + {'model': nn.LayerNorm((3, 16, 16), elementwise_affine=False), 'input': (3, 16, 16), 'flops': 768.0, 'params': 0}, # noqa: E501 {'model': nn.Linear(1024, 2), 'input': (1024, ), 'flops': 2048.0, 'params': 2050.0}, # noqa: E501 {'model': nn.ConvTranspose2d(3, 8, 3), 'input': (3, 16, 16), 'flops': 57888, 'params': 224.0}, # noqa: E501 {'model': nn.Upsample((32, 32)), 'input': (3, 16, 16), 'flops': 3072.0, 'params': 0} # noqa: E501