diff --git a/mmrazor/models/pruners/ratio_pruning.py b/mmrazor/models/pruners/ratio_pruning.py index 151677d99..e4a54de0d 100644 --- a/mmrazor/models/pruners/ratio_pruning.py +++ b/mmrazor/models/pruners/ratio_pruning.py @@ -32,20 +32,23 @@ def __init__(self, ratios, **kwargs): self.ratios = ratios self.min_ratio = ratios[0] - def _check_pruner(self, supernet): + def _check_pruner_ratios(self, supernet): + """Check whether the ``ratios`` is correct.""" for module in supernet.model.modules(): if isinstance(module, GroupNorm): num_channels = module.num_channels num_groups = module.num_groups for ratio in self.ratios: new_channels = int(round(num_channels * ratio)) - assert (num_channels * ratio) % num_groups == 0, \ + assert new_channels % num_groups == 0, \ f'Expected number of channels in input of GroupNorm ' \ f'to be divisible by num_groups, but number of ' \ f'channels may be {new_channels} according to ' \ f'ratio {ratio} and num_groups={num_groups}' def prepare_from_supernet(self, supernet): + """Prepare for pruning.""" + self._check_pruner_ratios(supernet) super(RatioPruner, self).prepare_from_supernet(supernet) def get_channel_mask(self, out_mask): diff --git a/tests/test_models/test_pruner.py b/tests/test_models/test_pruner.py index 12323022a..ff4b59dd9 100644 --- a/tests/test_models/test_pruner.py +++ b/tests/test_models/test_pruner.py @@ -212,9 +212,7 @@ def test_ratio_pruner(): frozen_stages=1, norm_cfg=dict(type='BN', requires_grad=True), norm_eval=True, - style='pytorch', - init_cfg=dict( - type='Pretrained', checkpoint='torchvision://resnet50')), + style='pytorch'), neck=dict( type='FPN', in_channels=[256, 512, 1024, 2048], @@ -280,6 +278,15 @@ def test_ratio_pruner(): pruner.deploy_subnet(architecture, subnet_dict) architecture.forward_dummy(imgs) + # test invalid ``ratios`` + # Expected number of channels in input of GroupNorm to be divisible + # by ``num_groups`` + pruner_cfg = dict(type='RatioPruner', ratios=[1 / 10]) + architecture = ARCHITECTURES.build(architecture_cfg) + pruner = PRUNERS.build(pruner_cfg) + with pytest.raises(AssertionError): + pruner.prepare_from_supernet(architecture) + def _test_reset_bn_running_stats(architecture_cfg, pruner_cfg, should_fail): import os