Skip to content

Commit

Permalink
reset norm running status after prepare_from_supernet (#81)
Browse files Browse the repository at this point in the history
  • Loading branch information
wutongshenqiu authored Apr 2, 2022
1 parent 9147394 commit 45b8ce7
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
10 changes: 10 additions & 0 deletions mmrazor/models/pruners/structure_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ def prepare_from_supernet(self, supernet):

self.channel_spaces = self.build_channel_spaces(name2module)

self._reset_norm_running_stats(supernet)

@abstractmethod
def sample_subnet(self):
"""Sample a subnet from the supernet.
Expand Down Expand Up @@ -804,3 +806,11 @@ def concat_backward_parser(self, grad_fn, module2name, var2module,
if cur_path.pop(-1) != f'{name}_item_{i}':
print(f'{name}_item_{i}')
cur_path.pop(-1)

@staticmethod
def _reset_norm_running_stats(supernet):
from torch.nn.modules.batchnorm import _NormBase

for module in supernet.modules():
if isinstance(module, _NormBase):
module.reset_parameters()
58 changes: 58 additions & 0 deletions tests/test_models/test_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def test_ratio_pruner():
type='RatioPruner',
ratios=[1 / 8, 2 / 8, 3 / 8, 4 / 8, 5 / 8, 6 / 8, 7 / 8, 1.0])

_test_reset_bn_running_stats(architecture_cfg, pruner_cfg, False)
with pytest.raises(AssertionError):
_test_reset_bn_running_stats(architecture_cfg, pruner_cfg, True)

imgs = torch.randn(16, 3, 224, 224)
label = torch.randint(0, 1000, (16, ))

Expand Down Expand Up @@ -161,3 +165,57 @@ def test_ratio_pruner():
assert isinstance(subnet_dict, dict)
pruner.deploy_subnet(architecture, subnet_dict)
architecture.forward_dummy(imgs)


def _test_reset_bn_running_stats(architecture_cfg, pruner_cfg, should_fail):
import os
import random

import numpy as np

def set_seed(seed: int) -> None:
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)

output_list = []

def output_hook(self, input, output) -> None:
output_list.append(output)

set_seed(1024)

imgs = torch.randn(16, 3, 224, 224)

torch_rng_state = torch.get_rng_state()
np_rng_state = np.random.get_state()
random_rng_state = random.getstate()

architecture1 = ARCHITECTURES.build(architecture_cfg)
pruner1 = PRUNERS.build(pruner_cfg)
if should_fail:
pruner1._reset_norm_running_stats = lambda *_: None
set_seed(1)
pruner1.prepare_from_supernet(architecture1)
architecture1.model.head.fc.register_forward_hook(output_hook)
architecture1.eval()
architecture1(imgs, return_loss=False)

set_seed(1024)
torch.set_rng_state(torch_rng_state)
np.random.set_state(np_rng_state)
random.setstate(random_rng_state)

architecture2 = ARCHITECTURES.build(architecture_cfg)
pruner2 = PRUNERS.build(pruner_cfg)
if should_fail:
pruner2._reset_norm_running_stats = lambda *_: None
set_seed(2)
pruner2.prepare_from_supernet(architecture2)
architecture2.model.head.fc.register_forward_hook(output_hook)
architecture2.eval()
architecture2(imgs, return_loss=False)

assert torch.equal(output_list[0].norm(p='fro'),
output_list[1].norm(p='fro'))

0 comments on commit 45b8ce7

Please sign in to comment.