Skip to content

Commit

Permalink
fix bug (#407)
Browse files Browse the repository at this point in the history
* add resnet50

* fix bug

* fix bug

* fix bug

* refine

* fix bug

* add choice and mask of units to checkpoint (#397)

* add choice and mask of units to checkpoint

* update

* fix bug

* remove device operation

* fix bug

* fix circle ci error

* fix error in numpy for circle ci

* fix bug in requirements

* restore

* add a note

* a new solution

* save mutable_channel.mask as float for dist training

* refine

* mv meta file test

Co-authored-by: liukai <your_email@abc.example>
Co-authored-by: jacky <jacky@xx.com>

* fix bug

* add assert

* fix bug

* change iter to epoch

* bn_imp use abs

Co-authored-by: jacky <jacky@xx.com>
Co-authored-by: liukai <your_email@abc.example>
  • Loading branch information
3 people authored Dec 22, 2022
1 parent a91e2c7 commit 122ee38
Show file tree
Hide file tree
Showing 13 changed files with 183 additions and 59 deletions.
7 changes: 6 additions & 1 deletion tests/test_metafiles.py → .dev_scripts/meta_files_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import unittest
from pathlib import Path

import requests
Expand All @@ -8,7 +9,7 @@
MMRAZOR_ROOT = Path(__file__).absolute().parents[1]


class TestMetafiles:
class TestMetafiles(unittest.TestCase):

def get_metafiles(self, code_path):
"""
Expand Down Expand Up @@ -51,3 +52,7 @@ def test_metafiles(self):
assert model['Name'] == correct_name, \
f'name error in {metafile}, correct name should ' \
f'be {correct_name}'


if __name__ == '__main__':
unittest.main()
29 changes: 29 additions & 0 deletions configs/chex/resnet50/chex_resnet50_7.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
_base_ = ['mmcls::resnet/resnet50_8xb32_in1k.py']

data_preprocessor = {'type': 'mmcls.ClsDataPreprocessor'}
architecture = _base_.model
architecture.update({
'init_cfg': {
'type':
'Pretrained',
'checkpoint':
'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa
}
})

model = dict(
_delete_=True,
_scope_='mmrazor',
type='ChexAlgorithm',
architecture=architecture,
mutator_cfg=dict(
type='ChexMutator',
channel_unit_cfg=dict(
type='ChexUnit', default_args=dict(choice_mode='number', )),
channel_ratio=0.7,
),
delta_t=2,
total_steps=60,
init_growth_rate=0.3,
)
custom_hooks = [{'type': 'mmrazor.ChexHook'}]
27 changes: 14 additions & 13 deletions mmrazor/models/algorithms/pruning/ite_prune_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,23 +204,24 @@ def forward(self,
data_samples: Optional[List[BaseDataElement]] = None,
mode: str = 'tensor') -> ForwardResults:
"""Forward."""
if not hasattr(self, 'prune_config_manager'):
# self._iters_per_epoch() only available after initiation
self.prune_config_manager = self._init_prune_config_manager()

if self.prune_config_manager.is_prune_time(self._iter):
if self.training:
if not hasattr(self, 'prune_config_manager'):
# self._iters_per_epoch() only available after initiation
self.prune_config_manager = self._init_prune_config_manager()
if self.prune_config_manager.is_prune_time(self._iter):

config = self.prune_config_manager.prune_at(self._iter)
config = self.prune_config_manager.prune_at(self._iter)

self.mutator.set_choices(config)
self.mutator.set_choices(config)

logger = MMLogger.get_current_instance()
if (self.by_epoch):
logger.info(
f'The model is pruned at {self._epoch}th epoch once.')
else:
logger.info(
f'The model is pruned at {self._iter}th iter once.')
logger = MMLogger.get_current_instance()
if (self.by_epoch):
logger.info(
f'The model is pruned at {self._epoch}th epoch once.')
else:
logger.info(
f'The model is pruned at {self._iter}th iter once.')

return super().forward(inputs, data_samples, mode)

Expand Down
3 changes: 2 additions & 1 deletion mmrazor/models/chex/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .chex_algorithm import ChexAlgorithm
from .chex_hook import ChexHook
from .chex_mutator import ChexMutator
from .chex_ops import ChexConv2d, ChexLinear, ChexMixin
from .chex_unit import ChexUnit

__all__ = [
'ChexAlgorithm', 'ChexMutator', 'ChexUnit', 'ChexConv2d', 'ChexLinear',
'ChexMixin'
'ChexMixin', 'ChexHook'
]
32 changes: 27 additions & 5 deletions mmrazor/models/chex/chex_algorithm.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import math
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from mmengine import dist
from mmengine.model import BaseModel
from mmengine.model.utils import convert_sync_batchnorm

from mmrazor.models.algorithms import BaseAlgorithm
from mmrazor.registry import MODELS
from mmrazor.utils import print_log
from .chex_mutator import ChexMutator
from .utils import RuntimeInfo

Expand All @@ -26,6 +31,9 @@ def __init__(self,
init_cfg: Optional[Dict] = None):
super().__init__(architecture, data_preprocessor, init_cfg)

if dist.is_distributed():
self.architecture = convert_sync_batchnorm(self.architecture)

self.delta_t = delta_t
self.total_steps = total_steps
self.init_growth_rate = init_growth_rate
Expand All @@ -35,17 +43,31 @@ def __init__(self,

def forward(self, inputs, data_samples=None, mode: str = 'tensor'):
if self.training: #
if RuntimeInfo.iter() % self.delta_t == 0 and \
RuntimeInfo.iter() // self.delta_t < self.total_steps:
self.mutator.prune()
self.mutator.grow(self.growth_ratio)
if RuntimeInfo.epoch() % self.delta_t == 0 and \
RuntimeInfo.epoch() < self.total_steps and \
RuntimeInfo.iter_by_epoch() == 0:
with torch.no_grad():
self.mutator.prune()
print_log(f'prune model with {self.mutator.channel_ratio}')
self.log_choices()

self.mutator.grow(self.growth_ratio)
print_log(f'grow model with {self.growth_ratio}')
self.log_choices()
return super().forward(inputs, data_samples, mode)

@property
def growth_ratio(self):
# return growth ratio in current epoch
def cos():
a = math.pi * RuntimeInfo.epoch() / RuntimeInfo.max_epochs()
a = math.pi * RuntimeInfo.epoch() / self.total_steps
return (math.cos(a) + 1) / 2

return self.init_growth_rate * cos()

def log_choices(self):
if dist.get_rank() == 0:
config = {}
for unit in self.mutator.mutable_units:
config[unit.name] = unit.current_choice
print_log(json.dumps(config, indent=4))
24 changes: 24 additions & 0 deletions mmrazor/models/chex/chex_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import Hook

from mmrazor.registry import HOOKS


@HOOKS.register_module()
class ChexHook(Hook):
pass
# @classmethod
# def algorithm(cls, runner):
# if dist.is_distributed():
# return runner.model.module
# else:
# return runner.model

# def before_val(self, runner) -> None:
# algorithm = self.algorithm(runner)
# if dist.get_rank() == 0:
# config = {}
# for unit in algorithm.mutator.mutable_units:
# config[unit.name] = unit.current_choice
# print_log(json.dumps(config, indent=4))
# print_log(f'growth_ratio: {algorithm.growth_ratio}')
22 changes: 15 additions & 7 deletions mmrazor/models/chex/chex_mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ def prune(self):
step1: get pruning structure
step2: prune based on ChexMixin.prune_imp
"""
choices = self._get_prune_choices()
for unit in self.mutable_units:
unit.prune(choices[unit.name])
with torch.no_grad():
choices = self._get_prune_choices()
for unit in self.mutable_units:
unit.prune(choices[unit.name])

def grow(self, growth_ratio=0.0):
"""Make the model grow.
Expand Down Expand Up @@ -60,9 +61,16 @@ def _get_prune_choices(self):
unit: ChexUnit
bn_imps[unit.name] = unit.bn_imp
bn_imp: torch.Tensor = torch.cat(list(bn_imps.values()), dim=0)
num_remain = int(self.channel_ratio * len(bn_imp))
threshold = bn_imp.topk(num_remain)[0][-1]

num_total_channel = len(bn_imp)
num_min_remained = int(self.channel_ratio * num_total_channel)
threshold = bn_imp.topk(num_min_remained)[0][-1]

num_remained = 0
for unit in self.mutable_units:
num = (bn_imps[unit.name] >= threshold).float().sum().long().item()
choices[unit.name] = num
num = (bn_imps[unit.name] >= threshold).long().sum().item()
choices[unit.name] = max(num, 1)
num_remained += choices[unit.name]
assert num_remained >= num_min_remained, \
f'{num_remained},{num_min_remained}'
return choices
44 changes: 19 additions & 25 deletions mmrazor/models/chex/chex_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,33 +32,28 @@ def prepare_for_pruning(self, model: nn.Module):
def prune(self, num_remaining):
# prune the channels to num_remaining
def get_prune_imp():
prune_imp: torch.Tensor = torch.zeros([self.num_channels])
prune_imp = 0
for channel in self.chex_channels:
module = channel.module
prune_imp = prune_imp.to(
module.prune_imp(num_remaining).device)
prune_imp = prune_imp + module.prune_imp(
num_remaining)[channel.start:channel.end]
return prune_imp

prune_imp = get_prune_imp()
index = prune_imp.topk(num_remaining)[1]
mask: torch.Tensor = torch.zeros([self.num_channels],
device=prune_imp.device)
mask.scatter_(-1, index, 1.0)
mask = mask.bool()
self.mutable_channel.current_choice.data = mask
with torch.no_grad():
prune_imp = get_prune_imp()
index = prune_imp.topk(num_remaining)[1]
self.mutable_channel.mask.fill_(0.0)
self.mutable_channel.mask.data.scatter_(-1, index, 1.0)

def grow(self, num):
assert num >= 0
if num == 0:
return

def get_growth_imp():
growth_imp: torch.Tensor = torch.zeros([self.num_channels])
growth_imp = 0
for channel in self.chex_channels:
module = channel.module
growth_imp = growth_imp.to(module.growth_imp.device)
growth_imp = growth_imp + module.growth_imp[channel.
start:channel.end]
return growth_imp
Expand All @@ -73,23 +68,22 @@ def get_growth_imp():
select_index = index_free[select_index]
else:
select_index = index_free
mask.index_fill_(-1, select_index, 1.0)

self.mutable_channel.current_choice.data = mask
self.mutable_channel.mask.index_fill_(-1, select_index, 1.0)

@property
def bn_imp(self):
imp = torch.zeros([self.num_channels])
num_layers = 0
for channel in self.output_related:
module = channel.module
if isinstance(module, nn.modules.batchnorm._BatchNorm):
imp = imp.to(module.weight.device)
imp = imp + module.weight[channel.start:channel.end]
num_layers += 1
assert num_layers > 0
imp = imp / num_layers
return imp
with torch.no_grad():
imp = 0
num_layers = 0
for channel in self.output_related:
module = channel.module
if isinstance(module, nn.modules.batchnorm._BatchNorm):
imp = imp + module.weight[channel.start:channel.end].abs()
num_layers += 1
assert num_layers > 0
imp = imp / num_layers
return imp

@property
def chex_channels(self):
Expand Down
11 changes: 11 additions & 0 deletions mmrazor/models/chex/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math

from mmengine.logging import MessageHub


Expand All @@ -23,3 +25,12 @@ def max_epochs(cls):
@classmethod
def iter(cls):
return cls.get_info('iter')

@classmethod
def max_iters(cls):
return cls.get_info('max_iters')

@classmethod
def iter_by_epoch(cls):
iter_per_epoch = math.ceil(cls.max_iters() / cls.max_epochs())
return cls.iter() % iter_per_epoch
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def __init__(self, num_channels: int, choice_mode='number', **kwargs):
super().__init__(num_channels, **kwargs)
assert choice_mode in ['ratio', 'number']
self.choice_mode = choice_mode
self.mask = torch.ones([self.num_channels]).bool()

@property
def is_num_mode(self):
Expand All @@ -50,14 +49,13 @@ def current_choice(self, choice: Union[int, float]):
int_choice = self._ratio2num(choice)
else:
int_choice = choice
mask = torch.zeros([self.num_channels], device=self.mask.device)
mask[0:int_choice] = 1
self.mask = mask.bool()
self.mask.fill_(0.0)
self.mask[0:int_choice] = 1.0

@property
def current_mask(self) -> torch.Tensor:
"""Return current mask."""
return self.mask
return self.mask.bool()

# methods for

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ class SimpleMutableChannel(BaseMutableChannel):

def __init__(self, num_channels: int, **kwargs) -> None:
super().__init__(num_channels, **kwargs)
self.mask = torch.ones(num_channels).bool()
mask = torch.ones([self.num_channels
]) # save bool as float for dist training
self.register_buffer('mask', mask)
self.mask: torch.Tensor

# choice

Expand All @@ -32,7 +35,7 @@ def current_choice(self) -> torch.Tensor:
@current_choice.setter
def current_choice(self, choice: torch.Tensor):
"""Set current choice."""
self.mask = choice.to(self.mask.device).bool()
self.mask = choice.to(self.mask.device).float()

@property
def current_mask(self) -> torch.Tensor:
Expand Down
1 change: 1 addition & 0 deletions requirements/tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ interrogate
isort==4.3.21
nbconvert
nbformat
numpy < 1.24.0 # A temporary solution for tests with mmdet.
pytest
xdoctest >= 0.10.0
yapf
Loading

0 comments on commit 122ee38

Please sign in to comment.