Skip to content

Commit

Permalink
Merge e9a8f3c into 00870b9
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouzaida authored Mar 18, 2021
2 parents 00870b9 + e9a8f3c commit cf0b651
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 10 deletions.
21 changes: 14 additions & 7 deletions mmcv/cnn/utils/flops_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def get_model_complexity_info(model,
``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``,
``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``.
- BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``,
``nn.BatchNorm3d``.
``nn.BatchNorm3d``, ``nn.GroupNorm``, ``nn.InstanceNorm1d``,
``InstanceNorm2d``, ``InstanceNorm3d``, ``nn.LayerNorm``.
- Linear: ``nn.Linear``.
- Deconvolution: ``nn.ConvTranspose2d``.
- Upsample: ``nn.Upsample``.
Expand Down Expand Up @@ -426,11 +427,12 @@ def pool_flops_counter_hook(module, input, output):
module.__flops__ += int(np.prod(input.shape))


def bn_flops_counter_hook(module, input, output):
def norm_flops_counter_hook(module, input, output):
input = input[0]

batch_flops = np.prod(input.shape)
if module.affine:
if (getattr(module, 'affine', False)
or getattr(module, 'elementwise_affine', False)):
batch_flops *= 2
module.__flops__ += int(batch_flops)

Expand Down Expand Up @@ -577,10 +579,15 @@ def get_modules_mapping():
nn.AdaptiveAvgPool2d: pool_flops_counter_hook,
nn.AdaptiveMaxPool3d: pool_flops_counter_hook,
nn.AdaptiveAvgPool3d: pool_flops_counter_hook,
# BNs
nn.BatchNorm1d: bn_flops_counter_hook,
nn.BatchNorm2d: bn_flops_counter_hook,
nn.BatchNorm3d: bn_flops_counter_hook,
# normalizations
nn.BatchNorm1d: norm_flops_counter_hook,
nn.BatchNorm2d: norm_flops_counter_hook,
nn.BatchNorm3d: norm_flops_counter_hook,
nn.GroupNorm: norm_flops_counter_hook,
nn.InstanceNorm1d: norm_flops_counter_hook,
nn.InstanceNorm2d: norm_flops_counter_hook,
nn.InstanceNorm3d: norm_flops_counter_hook,
nn.LayerNorm: norm_flops_counter_hook,
# FC
nn.Linear: linear_flops_counter_hook,
mmcv.cnn.bricks.Linear: linear_flops_counter_hook,
Expand Down
12 changes: 9 additions & 3 deletions tests/test_cnn/test_flops_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,15 @@
{'model': nn.AdaptiveAvgPool1d(2), 'input': (3, 16), 'flops': 48.0, 'params': 0}, # noqa: E501
{'model': nn.AdaptiveAvgPool2d(2), 'input': (3, 16, 16), 'flops': 768.0, 'params': 0}, # noqa: E501
{'model': nn.AdaptiveAvgPool3d(2), 'input': (3, 3, 16, 16), 'flops': 2304.0, 'params': 0}, # noqa: E501
{'model': nn.BatchNorm1d(3, 8), 'input': (3, 16), 'flops': 96.0, 'params': 6.0}, # noqa: E501
{'model': nn.BatchNorm2d(3, 8), 'input': (3, 16, 16), 'flops': 1536.0, 'params': 6.0}, # noqa: E501
{'model': nn.BatchNorm3d(3, 8), 'input': (3, 3, 16, 16), 'flops': 4608.0, 'params': 6.0}, # noqa: E501
{'model': nn.BatchNorm1d(3), 'input': (3, 16), 'flops': 96.0, 'params': 6.0}, # noqa: E501
{'model': nn.BatchNorm2d(3), 'input': (3, 16, 16), 'flops': 1536.0, 'params': 6.0}, # noqa: E501
{'model': nn.BatchNorm3d(3), 'input': (3, 3, 16, 16), 'flops': 4608.0, 'params': 6.0}, # noqa: E501
{'model': nn.GroupNorm(2, 6), 'input': (6, 16, 16), 'flops': 3072.0, 'params': 12.0}, # noqa: E501
{'model': nn.InstanceNorm1d(3, affine=True), 'input': (3, 16), 'flops': 96.0, 'params': 6.0}, # noqa: E501
{'model': nn.InstanceNorm2d(3, affine=True), 'input': (3, 16, 16), 'flops': 1536.0, 'params': 6.0}, # noqa: E501
{'model': nn.InstanceNorm3d(3, affine=True), 'input': (3, 3, 16, 16), 'flops': 4608.0, 'params': 6.0}, # noqa: E501
{'model': nn.LayerNorm((3, 16, 16)), 'input': (3, 16, 16), 'flops': 1536.0, 'params': 1536.0}, # noqa: E501
{'model': nn.LayerNorm((3, 16, 16), elementwise_affine=False), 'input': (3, 16, 16), 'flops': 768.0, 'params': 0}, # noqa: E501
{'model': nn.Linear(1024, 2), 'input': (1024, ), 'flops': 2048.0, 'params': 2050.0}, # noqa: E501
{'model': nn.ConvTranspose2d(3, 8, 3), 'input': (3, 16, 16), 'flops': 57888, 'params': 224.0}, # noqa: E501
{'model': nn.Upsample((32, 32)), 'input': (3, 16, 16), 'flops': 3072.0, 'params': 0} # noqa: E501
Expand Down

0 comments on commit cf0b651

Please sign in to comment.