From a8643b02e914d75f1d4ded3472e4f726cd77b78c Mon Sep 17 00:00:00 2001 From: gaoyang07 <1546308416@qq.com> Date: Thu, 20 Oct 2022 22:38:14 +0800 Subject: [PATCH 1/2] fix counter mapping bug --- .../counters/flops_params_counter.py | 67 ++++++++++++++++++- .../test_estimators/test_flops_params.py | 21 ++++++ 2 files changed, 86 insertions(+), 2 deletions(-) diff --git a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py index f31208248..a091f3246 100644 --- a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py +++ b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py @@ -1,8 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. import sys +import warnings from functools import partial -from typing import Dict +from typing import Dict, List +import mmcv import torch import torch.nn as nn @@ -409,6 +411,12 @@ def add_flops_params_counter_hook_function(module): else: counter_type = get_counter_type(module) + if counter_type not in TASK_UTILS._module_dict.keys(): + old_counter_type = counter_type + counter_type = \ + module.__class__.__base__.__name__ + 'Counter' + warnings.warn(f'`{old_counter_type}` not in ' + f'op_counters. Using `{counter_type}`') if (disabled_counters is None or counter_type not in disabled_counters): counter = TASK_UTILS.build( @@ -503,9 +511,13 @@ def get_counter_type(module): def is_supported_instance(module): - """Judge whether the module is in TASK_UTILS registry or not.""" + """Judge whether the module can be countered or not.""" if get_counter_type(module) in TASK_UTILS._module_dict.keys(): return True + else: + for op in get_modules_list(): + if issubclass(module.__class__.__base__, op): + return True return False @@ -518,3 +530,54 @@ def remove_flops_params_counter_hook_function(module): del module.__flops__ if hasattr(module, '__params__'): del module.__params__ + + +def get_modules_list() -> List: + return [ + # convolutions + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + mmcv.cnn.bricks.Conv2d, + mmcv.cnn.bricks.Conv3d, + # activations + nn.ReLU, + nn.PReLU, + nn.ELU, + nn.LeakyReLU, + nn.ReLU6, + # poolings + nn.MaxPool1d, + nn.AvgPool1d, + nn.AvgPool2d, + nn.MaxPool2d, + nn.MaxPool3d, + nn.AvgPool3d, + mmcv.cnn.bricks.MaxPool2d, + mmcv.cnn.bricks.MaxPool3d, + nn.AdaptiveMaxPool1d, + nn.AdaptiveAvgPool1d, + nn.AdaptiveMaxPool2d, + nn.AdaptiveAvgPool2d, + nn.AdaptiveMaxPool3d, + nn.AdaptiveAvgPool3d, + # normalizations + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + nn.GroupNorm, + nn.InstanceNorm1d, + nn.InstanceNorm2d, + nn.InstanceNorm3d, + nn.LayerNorm, + # FC + nn.Linear, + mmcv.cnn.bricks.Linear, + # Upscale + nn.Upsample, + nn.UpsamplingNearest2d, + nn.UpsamplingBilinear2d, + # Deconvolution + nn.ConvTranspose2d, + mmcv.cnn.bricks.ConvTranspose2d, + ] diff --git a/tests/test_models/test_task_modules/test_estimators/test_flops_params.py b/tests/test_models/test_task_modules/test_estimators/test_flops_params.py index 60bcef4ba..839f9ce9e 100644 --- a/tests/test_models/test_task_modules/test_estimators/test_flops_params.py +++ b/tests/test_models/test_task_modules/test_estimators/test_flops_params.py @@ -4,6 +4,7 @@ import pytest import torch +from mmcv.cnn.bricks import Conv2dAdaptivePadding from torch import Tensor from torch.nn import Conv2d, Module, Parameter @@ -127,6 +128,15 @@ def test_estimate(self) -> None: self.assertGreater(flops_count, 0) self.assertGreater(params_count, 0) + fool_conv2d = Conv2dAdaptivePadding(3, 32, 3) + results = estimator.estimate( + model=fool_conv2d, flops_params_cfg=flops_params_cfg) + flops_count = results['flops'] + params_count = results['params'] + + self.assertGreater(flops_count, 0) + self.assertGreater(params_count, 0) + def test_register_module(self) -> None: fool_add_constant = FoolConvModule() flops_params_cfg = dict(input_shape=(1, 3, 224, 224)) @@ -151,6 +161,17 @@ def test_disable_sepc_counter(self) -> None: self.assertLess(rest_flops_count, 45.158) self.assertLess(rest_params_count, 0.701) + fool_conv2d = Conv2dAdaptivePadding(3, 32, 3) + flops_params_cfg = dict( + input_shape=(1, 3, 224, 224), disabled_counters=['Conv2dCounter']) + rest_results = estimator.estimate( + model=fool_conv2d, flops_params_cfg=flops_params_cfg) + rest_flops_count = rest_results['flops'] + rest_params_count = rest_results['params'] + + self.assertEqual(rest_flops_count, 0) + self.assertEqual(rest_params_count, 0) + def test_estimate_spec_module(self) -> None: fool_add_constant = FoolConvModule() flops_params_cfg = dict( From ac78ed2bcfda8498b03a04847af72a6979d602b3 Mon Sep 17 00:00:00 2001 From: gaoyang07 <1546308416@qq.com> Date: Fri, 21 Oct 2022 19:45:28 +0800 Subject: [PATCH 2/2] move judgment into get_counter_type & update UT --- .../counters/flops_params_counter.py | 39 ++++++++++++------- .../test_estimators/test_flops_params.py | 8 ++-- 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py index a091f3246..df0c867c6 100644 --- a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py +++ b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. import sys -import warnings from functools import partial from typing import Dict, List @@ -411,12 +410,6 @@ def add_flops_params_counter_hook_function(module): else: counter_type = get_counter_type(module) - if counter_type not in TASK_UTILS._module_dict.keys(): - old_counter_type = counter_type - counter_type = \ - module.__class__.__base__.__name__ + 'Counter' - warnings.warn(f'`{old_counter_type}` not in ' - f'op_counters. Using `{counter_type}`') if (disabled_counters is None or counter_type not in disabled_counters): counter = TASK_UTILS.build( @@ -505,19 +498,35 @@ def add_flops_params_counter_variable_or_reset(module): module.__params__ = 0 -def get_counter_type(module): - """Get counter type of the module based on the module class name.""" - return module.__class__.__name__ + 'Counter' +def get_counter_type(module) -> str: + """Get counter type of the module based on the module class name. + + If the current module counter_type is not in TASK_UTILS._module_dict, + it will search the base classes of the module to see if it matches any + base class counter_type. + + Returns: + str: Counter type (or the base counter type) of the current module. + """ + counter_type = module.__class__.__name__ + 'Counter' + if counter_type not in TASK_UTILS._module_dict.keys(): + old_counter_type = counter_type + assert nn.Module in module.__class__.mro() + for base_cls in module.__class__.mro(): + if base_cls in get_modules_list(): + counter_type = base_cls.__name__ + 'Counter' + from mmengine import MMLogger + logger = MMLogger.get_current_instance() + logger.warning(f'`{old_counter_type}` not in op_counters. ' + f'Using `{counter_type}` instead.') + break + return counter_type def is_supported_instance(module): - """Judge whether the module can be countered or not.""" + """Judge whether the module is in TASK_UTILS registry or not.""" if get_counter_type(module) in TASK_UTILS._module_dict.keys(): return True - else: - for op in get_modules_list(): - if issubclass(module.__class__.__base__, op): - return True return False diff --git a/tests/test_models/test_task_modules/test_estimators/test_flops_params.py b/tests/test_models/test_task_modules/test_estimators/test_flops_params.py index 839f9ce9e..2acb58e95 100644 --- a/tests/test_models/test_task_modules/test_estimators/test_flops_params.py +++ b/tests/test_models/test_task_modules/test_estimators/test_flops_params.py @@ -125,8 +125,8 @@ def test_estimate(self) -> None: flops_count = results['flops'] params_count = results['params'] - self.assertGreater(flops_count, 0) - self.assertGreater(params_count, 0) + self.assertEqual(flops_count, 44.158) + self.assertEqual(params_count, 0.001) fool_conv2d = Conv2dAdaptivePadding(3, 32, 3) results = estimator.estimate( @@ -134,8 +134,8 @@ def test_estimate(self) -> None: flops_count = results['flops'] params_count = results['params'] - self.assertGreater(flops_count, 0) - self.assertGreater(params_count, 0) + self.assertEqual(flops_count, 44.958) + self.assertEqual(params_count, 0.001) def test_register_module(self) -> None: fool_add_constant = FoolConvModule()