diff --git a/mmrazor/models/pruners/structure_pruning.py b/mmrazor/models/pruners/structure_pruning.py
index a1aa3f7b4..af1eae2b0 100644
--- a/mmrazor/models/pruners/structure_pruning.py
+++ b/mmrazor/models/pruners/structure_pruning.py
@@ -203,6 +203,8 @@ def prepare_from_supernet(self, supernet):
 
         self.channel_spaces = self.build_channel_spaces(name2module)
 
+        self._reset_norm_running_stats(supernet)
+
     @abstractmethod
     def sample_subnet(self):
         """Sample a subnet from the supernet.
@@ -804,3 +806,25 @@ def concat_backward_parser(self, grad_fn, module2name, var2module,
                 if cur_path.pop(-1) != f'{name}_item_{i}':
                     print(f'{name}_item_{i}')
         cur_path.pop(-1)
+
+    @staticmethod
+    def _reset_norm_running_stats(supernet):
+        """Reset the running statistics of all norm layers in a supernet.
+
+        Only the tracked statistics (``running_mean``, ``running_var`` and
+        ``num_batches_tracked``) are restored to their initial values. The
+        learnable affine ``weight`` and ``bias`` are left untouched.
+
+        Args:
+            supernet (:obj:`torch.nn.Module`): Model whose ``_NormBase``
+                submodules, e.g. batch norm layers, are reset in place.
+        """
+        # ``_NormBase`` is torch's private base class for batch norm layers.
+        from torch.nn.modules.batchnorm import _NormBase
+
+        for module in supernet.modules():
+            if isinstance(module, _NormBase):
+                # ``reset_parameters()`` would additionally re-initialize the
+                # affine weight and bias, discarding their current values.
+                module.reset_running_stats()
diff --git a/tests/test_models/test_pruner.py b/tests/test_models/test_pruner.py
index a3c76740a..835c4894f 100644
--- a/tests/test_models/test_pruner.py
+++ b/tests/test_models/test_pruner.py
@@ -35,6 +35,10 @@ def test_ratio_pruner():
         type='RatioPruner',
         ratios=[1 / 8, 2 / 8, 3 / 8, 4 / 8, 5 / 8, 6 / 8, 7 / 8, 1.0])
 
+    _test_reset_bn_running_stats(architecture_cfg, pruner_cfg, False)
+    with pytest.raises(AssertionError):
+        _test_reset_bn_running_stats(architecture_cfg, pruner_cfg, True)
+
     imgs = torch.randn(16, 3, 224, 224)
     label = torch.randint(0, 1000, (16, ))
 
@@ -161,3 +165,75 @@ def test_ratio_pruner():
     assert isinstance(subnet_dict, dict)
     pruner.deploy_subnet(architecture, subnet_dict)
     architecture.forward_dummy(imgs)
+
+
+def _test_reset_bn_running_stats(architecture_cfg, pruner_cfg, should_fail):
+    """Check that ``prepare_from_supernet`` resets norm running stats.
+
+    Two architectures are built from the same RNG state, prepared by their
+    pruners under different seeds and run on the same input in eval mode.
+    The Frobenius norms of the two ``model.head.fc`` outputs must be equal.
+
+    Args:
+        architecture_cfg: Config passed to ``ARCHITECTURES.build``.
+        pruner_cfg: Config passed to ``PRUNERS.build``.
+        should_fail (bool): If True, ``_reset_norm_running_stats`` of both
+            pruners is replaced by a no-op before preparing the supernets.
+
+    Raises:
+        AssertionError: If the two output norms differ.
+    """
+    import os
+    import random
+
+    import numpy as np
+
+    def set_seed(seed: int) -> None:
+        random.seed(seed)
+        os.environ['PYTHONHASHSEED'] = str(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+
+    output_list = []
+
+    def output_hook(self, input, output) -> None:
+        output_list.append(output)
+
+    set_seed(1024)
+
+    imgs = torch.randn(16, 3, 224, 224)
+
+    # Snapshot the RNG states so the second architecture can be built from
+    # exactly the same state as the first one.
+    torch_rng_state = torch.get_rng_state()
+    np_rng_state = np.random.get_state()
+    random_rng_state = random.getstate()
+
+    architecture1 = ARCHITECTURES.build(architecture_cfg)
+    pruner1 = PRUNERS.build(pruner_cfg)
+    if should_fail:
+        pruner1._reset_norm_running_stats = lambda *_: None
+    # The two pruners are prepared under different seeds (1 and 2).
+    set_seed(1)
+    pruner1.prepare_from_supernet(architecture1)
+    architecture1.model.head.fc.register_forward_hook(output_hook)
+    architecture1.eval()
+    architecture1(imgs, return_loss=False)
+
+    set_seed(1024)
+    torch.set_rng_state(torch_rng_state)
+    np.random.set_state(np_rng_state)
+    random.setstate(random_rng_state)
+
+    architecture2 = ARCHITECTURES.build(architecture_cfg)
+    pruner2 = PRUNERS.build(pruner_cfg)
+    if should_fail:
+        pruner2._reset_norm_running_stats = lambda *_: None
+    set_seed(2)
+    pruner2.prepare_from_supernet(architecture2)
+    architecture2.model.head.fc.register_forward_hook(output_hook)
+    architecture2.eval()
+    architecture2(imgs, return_loss=False)
+
+    assert torch.equal(output_list[0].norm(p='fro'),
+                       output_list[1].norm(p='fro'))