From 4fba4e12b3a5f2c357cea847d0a6875e465a93b8 Mon Sep 17 00:00:00 2001 From: liukai Date: Mon, 5 Sep 2022 10:44:33 +0800 Subject: [PATCH 01/25] rebase new dev-1.x --- .../architectures/dynamic_ops/__init__.py | 8 +- .../dynamic_ops/bricks/dynamic_conv.py | 26 +- mmrazor/models/mutables/__init__.py | 16 +- .../mutables/mutable_channel/__init__.py | 25 +- .../mutable_channel/base_mutable_channel.py | 83 +++ .../mutable_channel/groups/__init__.py | 21 + .../mutable_channel/groups/channel_group.py | 612 ++++++++++++++++++ .../groups/mutable_channel_group.py | 245 +++++++ .../groups/sequential_channel_group.py | 73 +++ .../mutable_channel_container.py | 106 +++ .../sequential_mutable_channel.py | 64 ++ .../mutable_channel/simple_mutable_channel.py | 61 ++ mmrazor/structures/graph/module_graph.py | 45 +- mmrazor/utils/__init__.py | 4 +- mmrazor/utils/index_dict.py | 53 ++ tests/data/models.py | 177 ++++- tests/test_core/__init__.py | 1 + tests/test_core/test_graph/__init__.py | 1 + tests/test_core/test_graph/test_graph.py | 132 ++-- tests/test_models/__init__.py | 1 + tests/test_models/test_mutables/__init__.py | 1 + .../test_mutables/group/__init__.py | 1 + .../group/test_mutable_channel_groups.py | 103 +++ tests/test_utils/test_index_dict.py | 16 + 24 files changed, 1779 insertions(+), 96 deletions(-) create mode 100644 mmrazor/models/mutables/mutable_channel/base_mutable_channel.py create mode 100644 mmrazor/models/mutables/mutable_channel/groups/__init__.py create mode 100644 mmrazor/models/mutables/mutable_channel/groups/channel_group.py create mode 100644 mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py create mode 100644 mmrazor/models/mutables/mutable_channel/groups/sequential_channel_group.py create mode 100644 mmrazor/models/mutables/mutable_channel/mutable_channel_container.py create mode 100644 mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py create mode 100644 mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py create mode 100644 mmrazor/utils/index_dict.py create mode 100644 tests/test_core/__init__.py create mode 100644 tests/test_core/test_graph/__init__.py create mode 100644 tests/test_models/__init__.py create mode 100644 tests/test_models/test_mutables/__init__.py create mode 100644 tests/test_models/test_mutables/group/__init__.py create mode 100644 tests/test_models/test_mutables/group/test_mutable_channel_groups.py create mode 100644 tests/test_utils/test_index_dict.py diff --git a/mmrazor/models/architectures/dynamic_ops/__init__.py b/mmrazor/models/architectures/dynamic_ops/__init__.py index 6b5796688..a7c259e38 100644 --- a/mmrazor/models/architectures/dynamic_ops/__init__.py +++ b/mmrazor/models/architectures/dynamic_ops/__init__.py @@ -1,12 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. from .base import DynamicOP -from .default_dynamic_ops import (DynamicBatchNorm, DynamicConv2d, - DynamicGroupNorm, DynamicInstanceNorm, - DynamicLinear) +from .default_dynamic_ops import (ChannelDynamicOP, DynamicBatchNorm, + DynamicConv2d, DynamicGroupNorm, + DynamicInstanceNorm, DynamicLinear) from .slimmable_dynamic_ops import SwitchableBatchNorm2d __all__ = [ 'DynamicConv2d', 'DynamicLinear', 'DynamicBatchNorm', 'DynamicInstanceNorm', 'DynamicGroupNorm', 'SwitchableBatchNorm2d', - 'DynamicOP' + 'DynamicOP', 'ChannelDynamicOP' ] diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py index 8b031e4b1..009b3131b 100644 --- a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py +++ b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py @@ -37,16 +37,21 @@ def __init__(self, *args, **kwargs) -> None: def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d': """Convert an instance of nn.Conv2d to a new instance of DynamicConv2d.""" - return cls( - in_channels=module.in_channels, - out_channels=module.out_channels, - kernel_size=module.kernel_size, - stride=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - bias=True if module.bias is not None else False, - padding_mode=module.padding_mode) + # a group-wise conv will not be converted to dynamic conv + if module.groups > 1 and not (module.groups == module.out_channels == + module.in_channels): + return module + else: + return cls( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + bias=True if module.bias is not None else False, + padding_mode=module.padding_mode) @property def conv_func(self) -> Callable: @@ -146,6 +151,7 @@ def __init__(self, *args, **kwargs) -> None: def convert_from(cls, module: nn.Conv2d) -> 'OFAConv2d': """Convert an instance of `nn.Conv2d` to a new instance of `OFAConv2d`.""" + return cls( in_channels=module.in_channels, out_channels=module.out_channels, diff --git a/mmrazor/models/mutables/__init__.py b/mmrazor/models/mutables/__init__.py index 917364607..659bcd57c 100644 --- a/mmrazor/models/mutables/__init__.py +++ b/mmrazor/models/mutables/__init__.py @@ -1,14 +1,20 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .base_mutable import BaseMutable from .derived_mutable import DerivedMutable -from .mutable_channel import (MutableChannel, OneShotMutableChannel, - SlimmableMutableChannel) +from .mutable_channel import (BaseMutableChannel, MutableChannel, + MutableChannelContainer, OneShotMutableChannel, + SimpleMutableChannel, SlimmableMutableChannel) +from .mutable_channel.groups import (MUTABLECHANNELGROUP, MutableChannelGroup, + SequentialChannelGroup) from .mutable_module import (DiffChoiceRoute, DiffMutableModule, DiffMutableOP, OneShotMutableModule, OneShotMutableOP) from .mutable_value import MutableValue, OneShotMutableValue __all__ = [ 'OneShotMutableOP', 'OneShotMutableModule', 'DiffMutableOP', - 'DiffChoiceRoute', 'DiffMutableModule', 'OneShotMutableChannel', - 'SlimmableMutableChannel', 'MutableChannel', 'DerivedMutable', - 'MutableValue', 'OneShotMutableValue' + 'DiffChoiceRoute', 'DiffMutableModule', 'DerivedMutable', 'MutableValue', + 'OneShotMutableValue', 'SimpleMutableChannel', 'MutableChannelGroup', + 'BaseMutableChannel', 'MutableChannelContainer', 'MUTABLECHANNELGROUP', + 'BaseMutable', 'MutableChannel', 'SlimmableMutableChannel', + 'OneShotMutableChannel', 'SequentialChannelGroup' ] diff --git a/mmrazor/models/mutables/mutable_channel/__init__.py b/mmrazor/models/mutables/mutable_channel/__init__.py index b3bbd3ab3..72d3f7276 100644 --- a/mmrazor/models/mutables/mutable_channel/__init__.py +++ b/mmrazor/models/mutables/mutable_channel/__init__.py @@ -1,8 +1,31 @@ # Copyright (c) OpenMMLab. All rights reserved. +r"""This module defines MutableChannels. + +----------------------------------------------------------base_mutable_channel.py +BaseMutableChannel +| \ +----------------------------------------------------------mutable_channel_container.py +MutableChannelContainer \ +----------------------------------------------------------other files + \ other MutableChannels + +MutableChannel are mainly used in DynamicOps. It helps DynamicOps to deal +with mutable number of channels. +""" +from .base_mutable_channel import BaseMutableChannel +from .groups import (MUTABLECHANNELGROUP, MutableChannelGroup, + SequentialChannelGroup) from .mutable_channel import MutableChannel +from .mutable_channel_container import MutableChannelContainer from .one_shot_mutable_channel import OneShotMutableChannel +from .sequential_mutable_channel import SquentialMutableChannel +from .simple_mutable_channel import SimpleMutableChannel from .slimmable_mutable_channel import SlimmableMutableChannel __all__ = [ - 'OneShotMutableChannel', 'SlimmableMutableChannel', 'MutableChannel' + 'SimpleMutableChannel', 'MutableChannelGroup', 'OneShotChannelGroup', + 'BaseMutableChannel', 'MutableChannelContainer', 'StackMutableChannel', + 'MUTABLECHANNELGROUP', 'MutableChannel', 'OneShotMutableChannel', + 'SlimmableMutableChannel', 'SquentialMutableChannel', + 'SequentialChannelGroup' ] diff --git a/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py new file mode 100644 index 000000000..54ce6568e --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py @@ -0,0 +1,83 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""""" +from abc import abstractproperty + +import torch + +from ..base_mutable import BaseMutable +from ..derived_mutable import DerivedMethodMixin + + +class BaseMutableChannel(BaseMutable, DerivedMethodMixin): + """BaseMutableChannel works as a channel mask for DynamicOps to select + channels. + + |---------------------------------------| + |in_channel_mutable(BaseMutableChannel) | + |---------------------------------------| + | DynamicOp | + |---------------------------------------| + |out_channel_mutable(BaseMutableChannel)| + |---------------------------------------| + + Important interfaces: + current_choice: used to get/set mask. + current_mask: get mask(used in DynamicOps to get mask). + """ + + def __init__(self, num_channels: int, **kwargs): + super().__init__(**kwargs) + self.name = '' + self.num_channels = num_channels + + # choice + + @abstractproperty + def current_choice(self): + """get current choice.""" + raise NotImplementedError() + + @current_choice.setter + def current_choice(self): + """set current choice.""" + raise NotImplementedError() + + @abstractproperty + def current_mask(self) -> torch.Tensor: + """Return a mask indicating the channel selection.""" + raise NotImplementedError() + + @property + def activated_channels(self) -> int: + """Number of activated channels.""" + return (self.current_mask == 1).sum().item() + + # implementation of abstract methods + + def fix_chosen(self, chosen=None): + """Fix the mutable with chosen.""" + if chosen is not None: + self.current_choice = chosen + + if self.is_fixed: + raise AttributeError( + 'The mode of current MUTABLE is `fixed`. ' + 'Please do not call `fix_chosen` function again.') + + self.is_fixed = True + + def dump_chosen(self): + """dump current choice to a dict.""" + raise NotImplementedError() + + def num_choices(self) -> int: + """Number of available choices.""" + raise NotImplementedError() + + # others + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(name={self.name}, ' + repr_str += f'num_channels={self.num_channels}, ' + return repr_str diff --git a/mmrazor/models/mutables/mutable_channel/groups/__init__.py b/mmrazor/models/mutables/mutable_channel/groups/__init__.py new file mode 100644 index 000000000..f73fa24cb --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/groups/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +""" +----------------------------------------------------------channel_group.py +PruneNode && PruneGraph +| +| Graph2ChannelGroups +| +Channel && ChannelGroup +| +----------------------------------------------------------mutable_channel_group.py +MutableChannelGroup +| +----------------------------------------------------------other files +Subclasses of MutableChannelGroup +""" +from .mutable_channel_group import MUTABLECHANNELGROUP, MutableChannelGroup +from .sequential_channel_group import SequentialChannelGroup + +__all__ = [ + 'MutableChannelGroup', 'SequentialChannelGroup', 'MUTABLECHANNELGROUP' +] diff --git a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py new file mode 100644 index 000000000..7a3968029 --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py @@ -0,0 +1,612 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""This module defines ChannelGroup with related modules. + +PruneNode Channel + -------------------> +PruneGraph Graph2ChannelGroups ChannelGroup + +PruneNode and PruneGraph are used to record the computation graph of a model. +A Channel records a slice of the input or output channels of a module. +A ChannelGroup collects all Channels with channel-dependency. +Graph2ChannelGroups is used to parse a PruneGraph and get ChannelGroups +""" + +import copy +from typing import Any, Dict, List, Tuple, Type, TypeVar, Union + +import torch.nn as nn +from torch.nn import Module + +from mmrazor.models.architectures.dynamic_ops.bricks.dynamic_mixins import \ + DynamicChannelMixin +from mmrazor.registry import MODELS +from mmrazor.structures.graph import ModuleGraph, ModuleNode +from mmrazor.utils import IndexDict +from ..base_mutable_channel import BaseMutableChannel + +# PruneNode && PruneGraph + + +class PruneNode(ModuleNode): + """Node class for pruning.""" + + def __init__(self, name: str, obj: Module, module_name='') -> None: + super().__init__(name, obj, module_name=module_name) + self.input_related_groups: IndexDict[ChannelGroup] = IndexDict() + self.output_related_groups: IndexDict[ChannelGroup] = IndexDict() + + # groups operation + + def get_channels(self, + index: Union[None, Tuple[int, int]] = None, + out_related=True, + expand_ratio: int = 1) -> 'Channel': + """PruneChannels: get the channels in the node between a range + + Args: + index (Union[None, Tuple[int, int]]): the channel range for pruning + out_related (Bool): represents if the channels are output channels, + otherwise input channels. + expand_ratio (Bool): expand_ratio of the number of channels + compared with pruning mask. + """ + if index is None: + index = (0, self.out_channels + if out_related is True else self.in_channels) + channel = Channel( + self, + index, + out_related=out_related, + expand_ratio=expand_ratio, + module_name=self.module_name) + return channel + + def output_related_groups_of_prev_nodes( + self) -> List[IndexDict['ChannelGroup']]: + """IndexDict['ChannelGroup']: the output-related + ChannelGroups of previous nodes.""" + groups = [] + for node in self.prev_nodes: + groups.append(node.output_related_groups) + return groups + + # channel + + @property + def act_in_channels(self) -> int: + """Int: activated input channel number""" + if isinstance(self.val, nn.Module): + if isinstance(self.val, DynamicChannelMixin): + mutable: BaseMutableChannel = self.val.get_mutable_attr( + 'in_channels') + return mutable.activated_channels + else: + if isinstance(self.val, nn.Conv2d): + return self.val.in_channels + elif isinstance(self.val, nn.modules.batchnorm._BatchNorm): + return self.val.num_features + elif isinstance(self.val, nn.Linear): + return self.val.in_features + else: + raise NotImplementedError() + elif self.is_bind_node(): + assert len(self.prev_nodes) > 1, '{name} is bind node' + return self.prev_nodes[0].act_in_channels + elif self.is_cat_node(): + return sum([ + node.act_in_channels if node.act_in_channels is not None else 0 + for node in self.prev_nodes + ]) + else: + raise NotImplementedError() + + @property + def act_out_channels(self) -> int: + """Int: activated output channel number""" + if isinstance(self.val, nn.Module): + if isinstance(self.val, DynamicChannelMixin): + mutable: BaseMutableChannel = self.val.get_mutable_attr( + 'out_channels') + return mutable.activated_channels + else: + return self.out_channels + elif self.is_bind_node(): + assert len(self.prev_nodes) > 1, '{name} is bind node' + return self.prev_nodes[0].act_out_channels + elif self.is_cat_node(): + return sum([ + node.act_out_channels + if node.act_out_channels is not None else 0 + for node in self.prev_nodes + ]) + else: + raise NotImplementedError() + + @property + def is_parsed(self): + return len(self.input_related_groups) > 0 or len( + self.output_related_groups) > 0 + + # others + def __repr__(self) -> str: + return (f'{self.name}_{self.act_in_channels}/{self.in_channels}' + f'_{self.act_out_channels}/{self.out_channels}') + + @property + def is_prunable(self) -> bool: + """Bool: if the node prunable""" + return self.basic_type not in ['gwconv2d'] + + @classmethod + def copy_from(cls, node): + if isinstance(node, ModuleNode): + return cls(node.name, node.val, node.module_name) + else: + raise NotImplementedError() + + +PRUNENODE = TypeVar('PRUNENODE', bound=PruneNode) + + +class PruneGraph(ModuleGraph[PRUNENODE]): + """Graph class for pruning.""" + + def __init__(self) -> None: + super().__init__() + + # groups_operation + def colloct_groups(self) -> List['ChannelGroup']: + """Set['ChannelGroup']: collect all ChannelGroups in the graph""" + groups = [] + for node in self.topo_traverse(): + for group in node.input_related_groups.values(): + if group not in groups: + groups.append(group) + for group in node.output_related_groups.values(): + if group not in groups: + groups.append(group) + return groups + + @classmethod + def copy_from(cls, graph, node_converter=PruneNode.copy_from): + graph = super().copy_from(graph, node_converter) + graph.merge_same_module() + return graph + + def merge_same_module(self): + module2node: Dict[Any, List[PruneNode]] = dict() + for node in self: + if isinstance(node.val, Module): + if node.val not in module2node: + module2node[node.val] = [] + if node not in module2node[node.val]: + module2node[node.val].append(node) + for module in module2node: + if len(module2node[module]) > 1: + input_group = IndexDict() + output_group = IndexDict() + for node in module2node[module]: + node.input_related_groups = input_group + node.output_related_groups = output_group + + +# Channel && ChannelGroup + + +class Channel: + """Channel records information about channels for pruning.""" + + def __init__(self, + node: PruneNode, + index, + out_related=True, + expand_ratio=1, + module_name='') -> None: + """ + Args: + node: (PruneNode): prune-node to be recorded + index (Union[None, Tuple[int, int]]): the channel range for pruning + out_related (Bool): represents if the channels are output channels, + otherwise input channels + expand_ratio (Bool): expand_ratio of the number of channels + compared with pruning mask + """ + self.node = node + self.index = index + self.start = index[0] + self.end = index[1] + self.output_related = out_related + self.expand_ratio = expand_ratio + + self.name = node.name + self.module: DynamicChannelMixin = node.val + self.module_name = module_name + + @property + def num_channels(self) -> int: + """Int: number of channels in the Channels""" + return self.index[1] - self.index[0] + + # group related operations + + def slice(self, start: int, end: int) -> 'Channel': + """Channel: a new Channel who manage a slice of the current Channel.""" + channel = Channel( + self.node, + index=(self.start + start, self.start + end), + out_related=self.output_related, + expand_ratio=self.expand_ratio, + module_name=self.module_name) + return channel + + # others + + def __repr__(self) -> str: + return f'{self.name}\t{self.index}\t \ + {"out" if self.output_related else "in"}\t\ + expand:{self.expand_ratio}' + + +class ChannelGroup: + """A manager for Channels.""" + + def __init__(self, num_channels: int) -> None: + """ + Args: + num_channels (int): the dimension of Channels. + """ + + self.num_channels = num_channels + self.output_related: List[Channel] = [] + self.input_related: List[Channel] = [] + self.init_args: Dict = { + } # is used to generate new channel group with same args + + # node operations + + def add_ouptut_related(self, channel: Channel): + """None: add a Channel which is output related""" + assert channel.output_related + assert self.num_channels == channel.num_channels + if channel not in self.output_related: + self.output_related.append(channel) + + def add_input_related(self, channel: Channel): + """None: add a Channel which is input related""" + assert channel.output_related is False + assert self.num_channels == channel.num_channels + if channel not in self.input_related: + self.input_related.append(channel) + + def remove_from_node(self): + """Remove recorded information in all nodes about this group.""" + for channel in self.output_related: + assert channel.index in channel.node.output_related_groups, \ + f'{channel.name}.{channel.index} not exist in node.out_related' + channel.node.output_related_groups.pop(channel.index) + for channel in self.input_related: + assert channel.index in channel.node.input_related_groups, \ + f'{channel.name}.{channel.index} \ + not exist in node.input_related' + + channel.node.input_related_groups.pop(channel.index) + + def apply_for_node(self): + """Register the information about this group for all nodes.""" + for node in self.output_related: + node.node.output_related_groups[node.index] = self + for node in self.input_related: + node.node.input_related_groups[node.index] = self + + # group operations + + @classmethod + def union(cls, groups: List['ChannelGroup']) -> 'ChannelGroup': + """ChannelGroup: Union ChannelGroups and return.""" + group = cls(groups[0].num_channels, + **groups[0].init_args) # type: ignore + for old_group in groups: + for group_module in old_group.input_related: + group.add_input_related(group_module) + for group_module in old_group.output_related: + group.add_ouptut_related(group_module) + return group + + def split(self, nums: List[int]) -> List['ChannelGroup']: + """Split the ChannelGroup and return.""" + assert sum(nums) == self.num_channels + + if len(nums) == 1: + return [self] + else: + groups = [] + start = 0 + for num in nums: + groups.append(self.slice(start, start + num)) + start += num + return groups + + def slice(self, start: int, end: int) -> 'ChannelGroup': + """Get a slice of the ChannelGroup.""" + assert start >= 0 and end <= self.num_channels + group = self.__class__(end - start, **self.init_args) # type: ignore + for module in self.input_related: + group.add_input_related(module.slice(start, end)) + for module in self.output_related: + group.add_ouptut_related(module.slice(start, end)) + return group + + # init + + @classmethod + def parse_channel_groups(cls, + graph: ModuleGraph, + group_args={}) -> List['ChannelGroup']: + """Parse a module-graph and get ChannelGroups.""" + group_graph = PruneGraph.copy_from(graph, PruneNode.copy_from) + + cfg = dict(type=cls.__name__, **group_args) + groups = Graph2ChannelGroups(group_graph, cfg).groups + for group in groups: + group._model = graph._model + return groups + + # to string + + def __repr__(self): + + def add_prefix(string: str, prefix=' '): + str_list = string.split('\n') + str_list = [ + prefix + line if line != '' else line for line in str_list + ] + return '\n'.join(str_list) + + def list_repr(lit: List): + s = '[\n' + for item in lit: + s += add_prefix(item.__repr__(), ' ') + '\n' + s += ']\n' + return s + + s = (f'{self.name}_' + f'\t{len(self.output_related)},{len(self.input_related)}' + f'\t{self.is_prunable}\n') + s += ' output_related:\n' + s += add_prefix(list_repr(self.output_related), ' ' * 4) + s += ' input_related\n' + s += add_prefix(list_repr(self.input_related), ' ' * 4) + return s + + +# Group to ChannelGroup Converter + + +class Graph2ChannelGroups: + """A converter which converts a Graph to a list of ChannelGroups.""" + + def __init__( + self, + graph: PruneGraph, + channel_group_cfg: Union[Dict, + Type[ChannelGroup]] = ChannelGroup) -> None: + """ + Args: + graph (PruneGraph): input prune-graph + channel_group_cfg: the config for generating groups + """ + self.graph = graph + if isinstance(channel_group_cfg, dict): + self.channel_group_class = MODELS.module_dict[ + channel_group_cfg['type']] + self.channel_group_args = copy.copy(channel_group_cfg) + self.channel_group_args.pop('type') + else: + self.channel_group_class = channel_group_cfg + self.channel_group_args = {} + self.groups = self.parse(self.graph) + + # group operations + + def new_channel_group(self, num_channels) -> ChannelGroup: + """Initialize a ChannelGroup.""" + return self.channel_group_class(num_channels, + **self.channel_group_args) + + def union_node_groups( + self, + node_groups_list=List[IndexDict[ChannelGroup]] + ) -> List[ChannelGroup]: + union_groups = [] + for index in copy.copy(node_groups_list[0]): + groups = [node_groups[index] for node_groups in node_groups_list] + group = self.union_groups(groups) + union_groups.append(group) + return union_groups + + def union_groups(self, groups: List[ChannelGroup]) -> ChannelGroup: + """List[ChannelGroup]: union a list of ChannelGroups""" + group = self.channel_group_class.union(groups) + # avoid removing multiple times + groups_set = set(groups) + for old_group in groups_set: + old_group.remove_from_node() + group.apply_for_node() + return group + + def align_node_groups(self, nodes_groups: List[IndexDict[ChannelGroup]]): + """Align the ChannelGroups in the prev nodes. + + Example(pseudocode): + >>> node1 + (0,4):group1, (4,8):group2 + >>> node2 + (0,2):group3, (2,8):group4 + >>> prev_nodes=[node1,node2] + >>> align_prev_output_groups(prev_nodes) + node1: (0,2):group5, (2,4):group6, (4,8):group7 + node2: (0,2):group8, (2,4):group9, (4,8):group10 + """ + + def points2nums(points): + nums = [points[i + 1] - points[i] for i in range(len(points) - 1)] + return nums + + # get points + points = set() + for node_groups in nodes_groups: + start = 0 + for group in node_groups.values(): + points.add(start) + points.add(start + group.num_channels) + start += group.num_channels + points_list = list(points) + points_list.sort() + + # split group + new_groups: List[ChannelGroup] = [] + old_groups: List[ChannelGroup] = [] + for node_groups in nodes_groups: + start = 0 + for group in node_groups.values(): + end = start + group.num_channels + in_points = [ + point for point in points_list if start <= point <= end + ] + in_nums = points2nums(in_points) + if len(in_nums) == 1: + pass + else: + split_groups = group.split(in_nums) + new_groups.extend(split_groups) + old_groups.append(group) + start = end + + # apply + for group in old_groups: + group.remove_from_node() + for group in new_groups: + group.apply_for_node() + + # operations + + def add_input_related(self, + group: ChannelGroup, + node: PruneNode, + index: Tuple[int, int] = None, + expand_ratio: int = 1): + """Add a Channel of a PruneNode to a the input-related channels of a + ChannelGroup.""" + if index is None: + index = (0, node.in_channels) + group.add_input_related( + node.get_channels( + index, out_related=False, expand_ratio=expand_ratio)) + node.input_related_groups[index] = group + + def add_output_related(self, + group: ChannelGroup, + node: PruneNode, + index: Tuple[int, int] = None, + expand_ratio=1): + """Add a Channel of a PruneNode to a the output-related channels of a + ChannelGroup.""" + if index is None: + index = (0, node.out_channels) + group.add_ouptut_related( + node.get_channels( + index, out_related=True, expand_ratio=expand_ratio)) + node.output_related_groups[index] = group + + # parse + + def parse_node(self, node: PruneNode): + """Parse the channels of a node, and create or update ChannelGroups.""" + prev_node_groups = node.output_related_groups_of_prev_nodes() + + if node.is_parsed: + + # align + self.align_node_groups(prev_node_groups + + [node.input_related_groups]) + + # union + prev_node_groups = node.output_related_groups_of_prev_nodes() + self.union_node_groups(prev_node_groups + + [node.input_related_groups]) + + elif node.is_mix_node(): + assert len(prev_node_groups) <= 1 + input_channel = node.prev_nodes[0].out_channels if len( + node.prev_nodes) == 1 else 0 + assert input_channel == 0 or \ + node.in_channels % input_channel == 0 + + # new group and add output-related + current_group = self.new_channel_group(node.out_channels) + self.add_output_related(current_group, node) + + # add input-related + for node_groups in prev_node_groups: + start = 0 + for group in node_groups.values(): + self.add_input_related( + group, + node, + index=(start, start + group.num_channels), + expand_ratio=node.in_channels // + input_channel if input_channel != 0 else 1) + start += group.num_channels + + elif node.is_pass_node(): + assert len(prev_node_groups) <= 1, \ + (f'{node} is a pass node which should' + 'not has more than one pre node') + + # add input-related and output-related + for node_groups in prev_node_groups: + start = 0 + for group in node_groups.values(): + self.add_output_related( + group, node, index=(start, start + group.num_channels)) + self.add_input_related( + group, node, index=(start, start + group.num_channels)) + start += group.num_channels + + elif node.is_bind_node(): + assert len(prev_node_groups) > 1 + for node_groups in prev_node_groups: + assert len(node_groups) > 0, \ + f'{node},{prev_node_groups} is a bind node which \ + should have more than one pre nodes' + + # align + self.align_node_groups(prev_node_groups) + + # union + unoin_groups = self.union_node_groups(prev_node_groups) + + # add output-related + start = 0 + for group in unoin_groups: + self.add_output_related(group, node, + (start, start + group.num_channels)) + start += group.num_channels + + elif node.is_cat_node(): + # add output-related + start = 0 + for node_groups in prev_node_groups: + for group in node_groups.values(): + self.add_output_related( + group, node, (start, start + group.num_channels)) + start += group.num_channels + + else: + raise NotImplementedError(f'{node.basic_type}') + + def parse(self, graph: PruneGraph): + """Parse a module-graph and get ChannelGroups.""" + for node in graph.topo_traverse(): + self.parse_node(node) + return graph.colloct_groups() diff --git a/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py new file mode 100644 index 000000000..55e7a2e16 --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py @@ -0,0 +1,245 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""This module defines MutableChannelGroup.""" +import abc +from typing import Dict, List, Type, TypeVar, Union + +import torch.nn as nn +from mmengine.model import BaseModule + +from mmrazor.models.architectures.dynamic_ops.bricks import DynamicChannelMixin +from mmrazor.models.mutables.mutable_channel.base_mutable_channel import \ + BaseMutableChannel +from ..mutable_channel_container import MutableChannelContainer +from .channel_group import Channel, ChannelGroup + + +class MutableChannelGroup(ChannelGroup, BaseModule): + """MutableChannelGroup inherits from ChannelGroup, which manages channels + with channel-dependency. + + Compared with ChannelGroup, MutableChannelGroup defines the core + interfaces for pruning. By inheriting MutableChannelGroup, we can implement + a variant pruning algorithm. + + Basic Property + + name + is_prunable + + Important interfaces during different stages: + + # Before pruning + prepare_model + prepare_for_pruning + + # Pruning stage + current_choice + sample_choice + + # After pruning + fix_chosen + """ + + def __init__(self, num_channels: int) -> None: + """ + Args: + num_channels (int): dimension of the channels that this + MutableChannelGroup manages. + """ + super().__init__(num_channels) + BaseModule.__init__(self) + + # basic property + + @property + def name(self) -> str: + """str: name of the group""" + first_module = self.output_related[0] if len( + self.output_related) > 0 else self.input_related[0] + name = f'{first_module.name}_{first_module.index}_' + name += f'out_{len(self.output_related)}_in_{len(self.input_related)}' + return name + + @property + def is_prunable(self) -> bool: + """If the channel-group is prunable.""" + + def traverse(channels: List[Channel]): + has_dynamic_op = False + all_channel_prunable = True + for channel in channels: + if channel.node.is_prunable is False: + all_channel_prunable = False + break + if isinstance(channel.module, DynamicChannelMixin): + has_dynamic_op = True + return has_dynamic_op, all_channel_prunable + + input_has_dynamic_op, input_all_prunable = traverse(self.input_related) + output_has_dynamic_op, output_all_prunable = traverse( + self.output_related) + + return len(self.output_related) > 0 \ + and len(self.input_related) > 0 \ + and input_has_dynamic_op \ + and input_all_prunable \ + and output_has_dynamic_op \ + and output_all_prunable + + # before pruning: prepare a model + + @abc.abstractmethod + def prepare_for_pruning(self, model): + """Post process after parse groups. + + For example, we need to register mutables to dynamic-ops. + """ + raise not NotImplementedError + + # pruning: choice-related + + @property + def current_choice(self): + """Choice of this group.""" + raise NotImplementedError() + + @current_choice.setter + def current_choice(self, choice) -> None: + """setter of current_choice.""" + raise NotImplementedError() + + @abc.abstractmethod + def sample_choice(self): + """Randomly sample a valid choice and return.""" + raise NotImplementedError() + + def config_template(self, with_info=False): + if with_info: + return {'info': self._info_dict} + else: + return {} + + # after pruning + + def fix_chosen(self, choice=None): + """Make the channels in this group fixed.""" + if choice is not None: + self.current_choice = choice + + # tools + + def _info_dict(self): + info = { + 'num_channels': self.num_channels, + 'choice': self.current_choice, + 'prunable': self.is_prunable, + 'input_layers': self.input_related, + 'out_related': self.output_related + } + return info + + def _get_int_choice(self, choice: Union[int, float]) -> int: + """Convert ratio of channels to number of channels.""" + if isinstance(choice, float): + choice = max(1, int(self.num_channels * choice)) + assert 0 < choice <= self.num_channels, f'{choice}' + return choice + + def _replace_with_dynamic_ops( + self, model: nn.Module, + dynamicop_map: Dict[Type[nn.Module], Type[DynamicChannelMixin]]): + """Replace torch modules with dynamic-ops.""" + + def replace_op(model: nn.Module, name: str, module: nn.Module): + names = name.split('.') + for sub_name in names[:-1]: + model = getattr(model, sub_name) + + setattr(model, names[-1], module) + + def get_module(model, name): + names = name.split('.') + for sub_name in names: + model = getattr(model, sub_name) + return model + + for channel in self.input_related + self.output_related: + if isinstance(channel.module, nn.Module): + module = get_module(model, channel.module_name) + if type(module) in dynamicop_map: + new_module = dynamicop_map[type(module)].convert_from( + module) + replace_op(model, channel.module_name, new_module) + channel.module = new_module + else: + channel.module = module + + @staticmethod + def _register_mask_container( + model: nn.Module, container_class: Type[MutableChannelContainer]): + """register channel container for dynamic ops.""" + for module in model.modules(): + if isinstance(module, DynamicChannelMixin): + if module.get_mutable_attr('in_channels') is None: + in_channels = 0 + if isinstance(module, nn.Conv2d): + in_channels = module.in_channels + elif isinstance(module, nn.modules.batchnorm._BatchNorm): + in_channels = module.num_features + elif isinstance(module, nn.Linear): + in_channels = module.in_features + else: + raise NotImplementedError() + module.register_mutable_attr('in_channels', + container_class(in_channels)) + if module.get_mutable_attr('out_channels') is None: + out_channels = 0 + if isinstance(module, nn.Conv2d): + out_channels = module.out_channels + elif isinstance(module, nn.modules.batchnorm._BatchNorm): + out_channels = module.num_features + elif isinstance(module, nn.Linear): + out_channels = module.out_features + else: + raise NotImplementedError() + module.register_mutable_attr('out_channels', + container_class(out_channels)) + + def _register_mask(self, mutable_channel: BaseMutableChannel): + + # register mutable_channel + for channel in self.input_related + self.output_related: + module = channel.module + if isinstance(module, DynamicChannelMixin): + container: MutableChannelContainer + if channel.output_related and module.get_mutable_attr( + 'out_channels') is not None: + container = module.get_mutable_attr('out_channels') + elif channel.output_related is False \ + and module.get_mutable_attr('in_channels') is not None: + container = module.get_mutable_attr('in_channels') + else: + raise NotImplementedError() + + if channel.expand_ratio == 1: + mutable_channel_ = mutable_channel + start = channel.start + end = channel.end + else: + mutable_channel_ = mutable_channel.expand_mutable_channel( + channel.expand_ratio) + start = channel.start + end = channel.start + ( + channel.end - channel.start) * channel.expand_ratio + if (start, end) in container.mutable_channels: + # TODO refine assert + existed = container.mutable_channels[(start, end)] + + assert mutable_channel is existed \ + or mutable_channel_ is list( + existed._trace_source_mutables)[0] + else: + container.register_mutable(mutable_channel_, start, end) + + +MUTABLECHANNELGROUP = TypeVar('MUTABLECHANNELGROUP', bound=MutableChannelGroup) diff --git a/mmrazor/models/mutables/mutable_channel/groups/sequential_channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/sequential_channel_group.py new file mode 100644 index 000000000..658228958 --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/groups/sequential_channel_group.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random + +import torch +import torch.nn as nn + +from mmrazor.models.architectures.dynamic_ops.bricks import ( + DynamicBatchNorm2d, DynamicConv2d, DynamicLinear) +from mmrazor.registry import MODELS +from ..mutable_channel_container import MutableChannelContainer +from ..simple_mutable_channel import SimpleMutableChannel +from .mutable_channel_group import MutableChannelGroup + + +@MODELS.register_module() +class SequentialChannelGroup(MutableChannelGroup): + """SimpleChannelGroup defines a simple pruning algorithhm. + + The type of choice of SimpleChannelGroup is float. It indicates what ratio + of channels are remained from left to right. + """ + + def __init__(self, num_channels: int) -> None: + super().__init__(num_channels) + self.mutable_channel: SimpleMutableChannel = SimpleMutableChannel( + self.num_channels) + + # prepare model + + def prepare_for_pruning(self, model: nn.Module): + """Prepare for pruning, including register mutable channels.""" + # register MutableMask + self._replace_with_dynamic_ops( + model, { + nn.Conv2d: DynamicConv2d, + nn.BatchNorm2d: DynamicBatchNorm2d, + nn.Linear: DynamicLinear + }) + self._register_mask_container(model, MutableChannelContainer) + self._register_mask(self.mutable_channel) + + # choice + + @property + def current_choice(self) -> float: + """return current choice.""" + return self.mutable_channel.activated_channels / self.num_channels + + @current_choice.setter + def current_choice(self, choice: float): + """set choice.""" + int_choice = self._get_int_choice(choice) + mask = self._generate_mask(int_choice) + self.mutable_channel.current_choice = mask + + def sample_choice(self) -> float: + """Sample a choice in (0,1]""" + return max(1, int( + random.random() * self.num_channels)) / self.num_channels + + # private methods + + def _generate_mask(self, choice: int) -> torch.Tensor: + """torch.Tesnor: generate mask for pruning""" + mask = torch.zeros([self.num_channels]) + mask[0:choice] = 1 + return mask + + # interface + def fix_chosen(self, choice=None): + """fix chosen.""" + super().fix_chosen(choice) + self.mutable_channel.fix_chosen() diff --git a/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py b/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py new file mode 100644 index 000000000..c7c7111df --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py @@ -0,0 +1,106 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch + +from mmrazor.registry import MODELS +from mmrazor.utils import IndexDict +from .base_mutable_channel import BaseMutableChannel +from .simple_mutable_channel import SimpleMutableChannel + + +@MODELS.register_module() +class MutableChannelContainer(BaseMutableChannel): + """MutableChannelContainer inherits from BaseMutableChannel. However, + it's not a single BaseMutableChannel, but a container for + BaseMutableChannel. The mask of MutableChannelContainer consists of + all masks of stored MutableChannels. + + ----------------------------------------------------------- + | MutableChannelContainer | + ----------------------------------------------------------- + |MutableChannel1| MutableChannel2 |MutableChannel3| + ----------------------------------------------------------- + + Important interfaces: + register_mutable: register/store BaseMutableChannel in the + MutableChannelContainer + """ + + def __init__(self, num_channels: int, **kwargs): + super().__init__(num_channels, **kwargs) + self.mutable_channels: IndexDict[BaseMutableChannel] = IndexDict() + + # choice + + @property + def current_choice(self): + """Get current choices.""" + if len(self.mutable_channels) == 0: + return torch.ones([self.num_channels]).bool() + else: + self._full_empty_range() + self._assert_mutables_valid() + mutable_channels = list(self.mutable_channels.values()) + masks = [mutable.current_mask for mutable in mutable_channels] + mask = torch.cat(masks) + return mask.bool() + + @current_choice.setter + def current_choice(self, choice): + """Set current choices. + + However, MutableChannelContainer doesn't support directly set mask. You + can change the mask of MutableChannelContainer by changing its stored + BaseMutableChannel. + """ + raise NotImplementedError() + + @property + def current_mask(self) -> torch.Tensor: + """Return current mask.""" + return self.current_choice.bool() + + # basic extension + + def register_mutable(self, mutable_channel: BaseMutableChannel, start: int, + end: int): + """Register/Store BaseMutableChannel in the MutableChannelContainer in + the range [start,end)""" + + self.mutable_channels[(start, end)] = mutable_channel + + # private methods + + def _assert_mutables_valid(self): + """Assert the current stored BaseMutableChannels are valid to generate + mask.""" + assert len(self.mutable_channels) > 0 + last_end = 0 + for start, end in self.mutable_channels: + assert start == last_end + last_end = end + assert last_end == self.num_channels + + def _full_empty_range(self): + """Add SimpleMutableChannels in the range without any stored + BaseMutableChannel.""" + last_end = 0 + for start, end in copy.copy(self.mutable_channels): + if last_end < start: + self.register_mutable( + SimpleMutableChannel(last_end - start), last_end, start) + last_end = end + if last_end < self.num_channels: + self.register_mutable( + SimpleMutableChannel(self.num_channels - last_end), last_end, + self.num_channels) + + # others + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(name={self.name}, ' + repr_str += f'num_channels={self.num_channels}, ' + repr_str += f'activated_channels: {self.activated_channels}' + return repr_str diff --git a/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py new file mode 100644 index 000000000..13a39f0a1 --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py @@ -0,0 +1,64 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmrazor.registry import MODELS +from .base_mutable_channel import BaseMutableChannel + + +@MODELS.register_module() +class SquentialMutableChannel(BaseMutableChannel): + """SquentialMutableChannel defines a BaseMutableChannel which switch off + channel mask from right to left sequentially, like '11111000'. + + A choice of SquentialMutableChannel is an integer, which indicates how many + channel are activated from left to right. + """ + + def __init__(self, num_channels: int, **kwargs): + super().__init__(num_channels, **kwargs) + self.mask = torch.ones([self.num_channels]).bool() + + @property + def current_choice(self) -> int: + """Get current choice.""" + return (self.mask == 1).sum().item() + + @current_choice.setter + def current_choice(self, choice: int): + """Set choice.""" + mask = torch.zeros([self.num_channels], device=self.mask.device) + mask[0:choice] = 1 + self.mask = mask.bool() + + @property + def current_mask(self) -> torch.Tensor: + """Return current mask.""" + return self.mask + + # methods for + + def fix_chosen(self, chosen=...): + """Fix chosen.""" + if chosen is ...: + chosen = self.current_choice + assert self.is_fixed is False + self.current_choice = chosen + self.is_fixed = True + + def dump_chosen(self): + """Dump chosen.""" + return self.current_choice + + def __mul__(self, other): + """multiplication.""" + if isinstance(other, int): + return self.derive_expand_mutable(other) + else: + return None + + def __floordiv__(self, other): + """division.""" + if isinstance(other, int): + return self.derive_divide_mutable(other) + else: + return None diff --git a/mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py new file mode 100644 index 000000000..b209cc05c --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py @@ -0,0 +1,61 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from functools import partial + +import torch + +from mmrazor.registry import MODELS +from ..derived_mutable import DerivedMutable +from .base_mutable_channel import BaseMutableChannel + + +@MODELS.register_module() +class SimpleMutableChannel(BaseMutableChannel): + """SimpleMutableChannel is a simple BaseMutableChannel, it directly take a + mask as a choice.""" + + def __init__(self, num_channels, **kwargs) -> None: + super().__init__(num_channels, **kwargs) + self.num_channels = num_channels + self.mask = torch.ones(num_channels).bool() + + # choice + + @property + def current_choice(self) -> torch.Tensor: + """Get current choice.""" + return self.mask.bool() + + @current_choice.setter + def current_choice(self, choice: torch.Tensor): + """Set current choice.""" + self.mask = choice.to(self.mask.device).bool() + + @property + def current_mask(self) -> torch.Tensor: + """Get current mask.""" + return self.current_choice.bool() + + # basic extension + + def expand_mutable_channel(self, expand_ratio) -> DerivedMutable: + """Get a derived SimpleMutableChannel with expanded mask.""" + + def _expand_mask(mutable_channel, expand_ratio): + mask = mutable_channel.current_mask + mask = torch.unsqueeze( + mask, -1).expand(list(mask.shape) + [expand_ratio]).flatten(-2) + return mask + + derive_fun = partial( + _expand_mask, mutable_channel=self, expand_ratio=expand_ratio) + return DerivedMutable(derive_fun, derive_fun, [self]) + + # others + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += '(' + repr_str += f'num_channels={self.num_channels}, ' + repr_str += f'activated_channels: {self.activated_channels}' + repr_str += ')' + return repr_str diff --git a/mmrazor/structures/graph/module_graph.py b/mmrazor/structures/graph/module_graph.py index 8b7f9c920..658224cea 100644 --- a/mmrazor/structures/graph/module_graph.py +++ b/mmrazor/structures/graph/module_graph.py @@ -10,11 +10,16 @@ import torch.nn as nn from torch.nn import Module -from mmrazor.models.task_modules import (BackwardTracer, Path, PathConcatNode, - PathList, PathNode) -from mmrazor.models.task_modules.tracer import ImageClassifierPseudoLoss +from mmrazor.models.task_modules.tracer.backward_tracer import BackwardTracer +from mmrazor.models.task_modules.tracer.loss_calculator import \ + ImageClassifierPseudoLoss +from mmrazor.models.task_modules.tracer.path import (Path, PathConcatNode, + PathList, PathNode) +from mmrazor.registry import TASK_UTILS from .base_graph import BaseGraph, BaseNode +# ModuleNode && ModuleGraph + class ModuleNode(BaseNode): """A node in a computation graph. @@ -30,7 +35,8 @@ class ModuleNode(BaseNode): def __init__(self, name: str, val: Union[Module, str], - expand_ratio: int = 1) -> None: + expand_ratio: int = 1, + module_name='') -> None: """ Args: name (str): the name of the node @@ -56,6 +62,7 @@ def forward(x): 'expand != 1 is only valid when val=="pass"' super().__init__(name, val) self.expand_ratio = expand_ratio + self.module_name = module_name # channel @@ -204,9 +211,9 @@ def check_type(self): class ModuleGraph(BaseGraph[MODULENODE]): """Computatation Graph.""" - def __init__(self) -> None: + def __init__(self, model=None) -> None: super().__init__() - self._model = None + self._model: nn.Module = model # functions to generate module graph. @@ -217,12 +224,16 @@ def init_using_backward_tracer( loss_calculator=ImageClassifierPseudoLoss()), ): """init module graph using backward tracer.""" + if isinstance(backward_tracer, dict): + backward_tracer = TASK_UTILS.build(backward_tracer) path_lists = backward_tracer.trace(model) converter = PathToGraphConverter(path_lists, model) + converter.graph.refresh_module_name() return converter.graph @staticmethod - def init_using_fx_tracer(model: Module, is_extra_leaf_module=None): + def init_using_fx_tracer(model: Module, + fx_tracer={'type': 'RazorFxTracer'}): """init module graph using torch fx tracer.""" pass @@ -260,6 +271,16 @@ def connect_module(pre: Module, next: Module): next._pre = set() next._pre.add(pre) + # others + def refresh_module_name(self): + module2name = {} + for name, module in self._model.named_modules(): + module2name[module] = name + + for node in self: + if isinstance(node.val, nn.Module): + node.module_name = module2name[node.val] + # Converter @@ -267,8 +288,8 @@ def connect_module(pre: Module, next: Module): class GraphConverter: """Base class for converters for ModuleGraph.""" - def __init__(self) -> None: - self.graph = ModuleGraph[ModuleNode]() + def __init__(self, model) -> None: + self.graph = ModuleGraph[ModuleNode](model) self.cat_placeholder_num = 0 self.bind_placeholder_num = 0 self.pass_placeholder_num = 0 @@ -388,15 +409,15 @@ def __init__(self, path_list: PathList, model: Module) -> None: path_list (PathList): path_list generated by backward tracer. model (Module): the model corresponding to the path_list """ - super().__init__() + super().__init__(model) self.path_list = path_list self.cat_dict: Dict[str, str] = {} self.name2module = dict(model.named_modules()) - self._pass(self.path_list) + self._parse(self.path_list) self._post_process() - def _pass(self, path_list: PathList): + def _parse(self, path_list: PathList): """Parse path list.""" self._parse_helper(path_list, []) diff --git a/mmrazor/utils/__init__.py b/mmrazor/utils/__init__.py index 5fb8dc209..8490e8eef 100644 --- a/mmrazor/utils/__init__.py +++ b/mmrazor/utils/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .index_dict import IndexDict from .misc import find_latest_checkpoint from .placeholder import get_placeholder from .setup_env import register_all_modules, setup_multi_processes @@ -9,5 +10,6 @@ __all__ = [ 'find_latest_checkpoint', 'setup_multi_processes', 'register_all_modules', 'FixMutable', 'ValidFixMutable', 'SingleMutatorRandomSubnet', - 'MultiMutatorsRandomSubnet', 'SupportRandomSubnet', 'get_placeholder' + 'MultiMutatorsRandomSubnet', 'SupportRandomSubnet', 'get_placeholder', + 'IndexDict' ] diff --git a/mmrazor/utils/index_dict.py b/mmrazor/utils/index_dict.py new file mode 100644 index 000000000..de7d6ae3a --- /dev/null +++ b/mmrazor/utils/index_dict.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import OrderedDict, Tuple, TypeVar + +VT = TypeVar('VT') + + +class IndexDict(OrderedDict[Tuple[int, int], VT]): + """IndexDict inherents from OrderedDict[Tuple[int, int], VT]. + + The type of the key is a Tuple[a: int,b: int]. It indicates a index range + [a,b). IndexDict can sort index and checkout if indexes overlap + """ + + def __setitem__(self, __k: Tuple[int, int], __v): + """set item.""" + start, end = __k + assert start < end + self._assert_no_over_lap(start, end) + super().__setitem__(__k, __v) + self._sort() + + def _sort(self): + """sort the dict accorrding to index.""" + items = sorted(self.items()) + self.clear() + for k, v in items: + super().__setitem__(k, v) + + def _assert_no_over_lap(self, start, end): + """Assert the index [start,end) has no over lav with existed + indexes.""" + assert (start, end) not in self, 'index overlap' + + def __contains__(self, __o) -> bool: + """Bool: if the index has any overlap with existed indexes""" + if super().__contains__(__o): + return True + else: + self._assert_is_index(__o) + start, end = __o + existed = False + for s, e in self.keys(): + existed = (s <= start < e or s < end < e or + (s < start and end < e)) or existed + + return existed + + def _assert_is_index(self, index): + """Assert the index is an instance of Tuple[int,int]""" + assert isinstance(index, Tuple) \ + and len(index) == 2 \ + and isinstance(index[0], int) \ + and isinstance(index[1], int) diff --git a/tests/data/models.py b/tests/data/models.py index dd328b516..537dba7be 100644 --- a/tests/data/models.py +++ b/tests/data/models.py @@ -7,6 +7,18 @@ # this file includes models for tesing. +class LinearHead(Module): + + def __init__(self, in_channel, num_class=1000) -> None: + super().__init__() + self.pool = nn.AdaptiveAvgPool2d(1) + self.linear = nn.Linear(in_channel, num_class) + + def forward(self, x): + pool = self.pool(x).flatten(1) + return self.linear(pool) + + class MultiConcatModel(Module): """ x---------------- @@ -127,7 +139,7 @@ def forward(self, x: Tensor) -> Tensor: output = self.fc(x_pool) return output - + class ResBlock(Module): """ @@ -233,6 +245,20 @@ def forward(self, x): class GroupWiseConvModel(nn.Module): + """ + x + |op1,bn1 + x1 + |op2,bn2 + x2 + |op3 + x3 + |avg_pool + x_pool + |fc + y + """ + def __init__(self) -> None: super().__init__() self.op1 = nn.Conv2d(3, 8, 3, 1, 1) @@ -240,6 +266,8 @@ def __init__(self) -> None: self.op2 = nn.Conv2d(8, 16, 3, 1, 1, groups=2) self.bn2 = nn.BatchNorm2d(16) self.op3 = nn.Conv2d(16, 32, 3, 1, 1) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(32, 1000) def forward(self, x): x1 = self.op1(x) @@ -291,6 +319,23 @@ def forward(self, x): class MultipleUseModel(nn.Module): + """ + x------------------------ + |conv0 |conv1 |conv2 |conv3 + xs.0 xs.1 xs.2 xs.3 + |convm |convm |convm |convm + xs_.0 xs_.1 xs_.2 xs_.3 + | | | | + +------------------------ + | + x_sum + |conv_last + feature + |avg_pool + pool + |linear + output + """ def __init__(self) -> None: super().__init__() @@ -299,7 +344,7 @@ def __init__(self) -> None: self.conv2 = nn.Conv2d(3, 8, 3, 1, 1) self.conv3 = nn.Conv2d(3, 8, 3, 1, 1) self.conv_multiple_use = nn.Conv2d(8, 16, 3, 1, 1) - self.conv_last = nn.Conv2d(16, 32, 3, 1, 1) + self.conv_last = nn.Conv2d(16*4, 32, 3, 1, 1) self.avg_pool = nn.AdaptiveAvgPool2d(1) self.linear = nn.Linear(32, 1000) @@ -309,17 +354,133 @@ def forward(self, x): for conv in [self.conv0, self.conv1, self.conv2, self.conv3] ] xs_ = [self.conv_multiple_use(x_) for x_ in xs] - x_sum = 0 - for x_ in xs_: - x_sum = x_sum + x_ - feature = self.conv_last(x_sum) + x_cat = torch.cat(xs_, dim=1) + feature = self.conv_last(x_cat) pool = self.avg_pool(feature).flatten(1) return self.linear(pool) +class IcepBlock(nn.Module): + """ + x------------------------ + |op1 |op2 |op3 |op4 + x1 x2 x3 x4 + | | | | + cat---------------------- + | + y_ + """ + + def __init__(self, in_c=3, out_c=32) -> None: + super().__init__() + self.op1 = nn.Conv2d(in_c, out_c, 3, 1, 1) + self.op2 = nn.Conv2d(in_c, out_c, 3, 1, 1) + self.op3 = nn.Conv2d(in_c, out_c, 3, 1, 1) + self.op4 = nn.Conv2d(in_c, out_c, 3, 1, 1) + # self.op5 = nn.Conv2d(out_c*4, out_c, 3) + + def forward(self, x): + x1 = self.op1(x) + x2 = self.op2(x) + x3 = self.op3(x) + x4 = self.op4(x) + y_ = [x1, x2, x3, x4] + y_ = torch.cat(y_, 1) + return y_ + + +class Icep(nn.Module): + + def __init__(self, num_icep_blocks=2) -> None: + super().__init__() + self.icps = nn.Sequential(*[ + IcepBlock(32 * 4 if i != 0 else 3, 32) + for i in range(num_icep_blocks) + ]) + self.op = nn.Conv2d(32 * 4, 32, 1) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(32, 1000) + + def forward(self, x): + y_ = self.icps(x) + y = self.op(y_) + pool = self.avg_pool(y).flatten(1) + return self.fc(pool) + + +class ExpandLineModel(Module): + """ + x + |net0,net1 + |net2 + |net3 + x1 + |fc + output + """ + + def __init__(self) -> None: + super().__init__() + self.net = nn.Sequential( + nn.Conv2d(3, 8, 3, 1, 1), nn.BatchNorm2d(8), nn.ReLU(), + nn.Conv2d(8, 16, 3, 1, 1), nn.BatchNorm2d(16), + nn.AdaptiveAvgPool2d(2)) + self.linear = nn.Linear(64, 1000) + + def forward(self, x): + x1 = self.net(x) + x1 = x1.reshape([x1.shape[0], -1]) + return self.linear(x1) + + +class MultiBindModel(Module): + + def __init__(self) -> None: + super().__init__() + self.conv1 = nn.Conv2d(3, 8, 3, 1, 1) + self.conv2 = nn.Conv2d(3, 8, 3, 1, 1) + self.conv3 = nn.Conv2d(8, 8, 3, 1, 1) + self.head = LinearHead(8, 1000) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(x) + x12 = x1 + x2 + x3 = self.conv3(x12) + x123 = x12 + x3 + return self.head(x123) + + +class DwConvModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.net = nn.Sequential( + nn.Conv2d(3, 48, 3, 1, 1), + nn.BatchNorm2d(48), + nn.ReLU(), + nn.Conv2d(48, 48, 3, 1, 1, groups=48), + nn.BatchNorm2d(48), + nn.ReLU() + ) + self.head = LinearHead(48, 1000) + + def forward(self, x): + return self.head(self.net(x)) + + default_models = [ - LineModel, ResBlock, AddCatModel, ConcatModel, MultiConcatModel, - MultiConcatModel2, GroupWiseConvModel, Xmodel, MultipleUseModel + LineModel, + ResBlock, + AddCatModel, + ConcatModel, + MultiConcatModel, + MultiConcatModel2, + GroupWiseConvModel, + Xmodel, + MultipleUseModel, + Icep, + ExpandLineModel, + DwConvModel, ] diff --git a/tests/test_core/__init__.py b/tests/test_core/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_core/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_core/test_graph/__init__.py b/tests/test_core/test_graph/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_core/test_graph/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_core/test_graph/test_graph.py b/tests/test_core/test_graph/test_graph.py index 2b7995a09..8278988a1 100644 --- a/tests/test_core/test_graph/test_graph.py +++ b/tests/test_core/test_graph/test_graph.py @@ -1,16 +1,29 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os import sys from unittest import TestCase import torch +from mmrazor.models.architectures.dynamic_ops.bricks import DynamicChannelMixin from mmrazor.structures.graph import ModuleGraph -from tests.data.models import (AddCatModel, ConcatModel, LineModel, - MultiConcatModel, MultiConcatModel2, ResBlock) +from ...data.models import Icep # noqa +from ...data.models import MultipleUseModel # noqa +from ...data.models import Xmodel # noqa +from ...data.models import (AddCatModel, ConcatModel, DwConvModel, + ExpandLineModel, GroupWiseConvModel, LineModel, + ModelLibrary, MultiBindModel, MultiConcatModel, + MultiConcatModel2, ResBlock) + +FULL_TEST = os.getenv('FULL_TEST') == 'true' sys.setrecursionlimit(int(1e8)) +def is_dynamic_op_fx(module, name): + return isinstance(module, DynamicChannelMixin) + + class ToyCNNPseudoLoss: def __call__(self, model): @@ -19,59 +32,68 @@ def __call__(self, model): return pseudo_output.sum() -DATA = [ - { - 'model': LineModel, - 'num_nodes': 5, - }, - { - 'model': ResBlock, - 'num_nodes': 7, - }, - { - 'model': ConcatModel, - 'num_nodes': 7, - }, - { - 'model': MultiConcatModel2, - 'num_nodes': 7, - }, - { - 'model': MultiConcatModel, - 'num_nodes': 7, - }, - { - 'model': AddCatModel - }, -] - - class TestGraph(TestCase): - def test_graph_init(self) -> None: - - for data in DATA: + @classmethod + def backward_tracer_passed_models(cls): + '''MultipleUseModel: backward tracer can't distinguish multiple use and + first bind then use.''' + default_models = [ + LineModel, + ResBlock, + AddCatModel, + ConcatModel, + MultiConcatModel, + MultiConcatModel2, + GroupWiseConvModel, + Xmodel, + # MultipleUseModel, # bug + # Icep, bug + ExpandLineModel, + MultiBindModel, + DwConvModel + ] + """ + googlenet return a tuple when training, so it + should trace in eval mode + """ + + torch_models_includes = [ + 'alexnet', + 'densenet', + 'efficientnet', + 'googlenet', + # 'inception', bug + 'mnasnet', + 'mobilenet', + 'regnet', + 'resnet', + 'resnext', + # 'shufflenet', # bug + 'squeezenet', + 'vgg', + 'wide_resnet', + ] + model_library = ModelLibrary(torch_models_includes) + + models = default_models + model_library.export_models( + ) if FULL_TEST else default_models + return models + + def test_init_using_backward_tracer(self) -> None: + TestData = self.backward_tracer_passed_models() + + for data in TestData: with self.subTest(data=data): - model = data['model']() - # print(model) - graphs = [ - ModuleGraph.init_using_backward_tracer(model), - ] - - unit_num = len(graphs[0].nodes) - - for graph in graphs: - - # check channels - try: - graph.check() - except Exception as e: - self.fail(str(e) + '\n' + str(graph)) - - # check number of nodes - self.assertEqual(unit_num, len(graph.nodes)) - if 'num_nodes' in data: - self.assertEqual( - len(graph), - data['num_nodes'], - msg=f'{graph.nodes}') + model = data() + model.eval() + graph = ModuleGraph.init_using_backward_tracer(model) + + # check channels + self._valid_graph(graph) + + def _valid_graph(self, graph: ModuleGraph): + try: + graph.check() + except Exception as e: + self.fail(str(e) + '\n' + str(graph)) diff --git a/tests/test_models/__init__.py b/tests/test_models/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_models/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_models/test_mutables/__init__.py b/tests/test_models/test_mutables/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_models/test_mutables/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_models/test_mutables/group/__init__.py b/tests/test_models/test_mutables/group/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_models/test_mutables/group/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_models/test_mutables/group/test_mutable_channel_groups.py b/tests/test_models/test_mutables/group/test_mutable_channel_groups.py new file mode 100644 index 000000000..f6c17ab97 --- /dev/null +++ b/tests/test_models/test_mutables/group/test_mutable_channel_groups.py @@ -0,0 +1,103 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List +from unittest import TestCase + +import torch +import torch.nn as nn + +from mmrazor.models.architectures.dynamic_ops.bricks.dynamic_mixins import \ + DynamicChannelMixin +from mmrazor.models.mutables.mutable_channel import (MutableChannelGroup, + SequentialChannelGroup) +from mmrazor.models.mutables.mutable_channel.groups.channel_group import ( # noqa + Channel, PruneNode) +from mmrazor.structures.graph import ModuleGraph as ModuleGraph +from ....test_core.test_graph.test_graph import TestGraph + +MUTABLE_CFG = dict(type='SimpleMutableChannl') +TRACER_CFG = dict( + type='BackwardTracer', + loss_calculator=dict(type='ImageClassifierPseudoLoss')) + +# DEVICE = torch.device('cuda:0') if torch.cuda.is_available() \ +# else torch.device('cpu') +DEVICE = torch.device('cpu') +GROUPS: List[MutableChannelGroup] = [SequentialChannelGroup] + +DefaultChannelGroup = SequentialChannelGroup + + +class TestMutableChannelGroup(TestCase): + + def _test_a_graph(self, model, graph): + try: + groups = DefaultChannelGroup.parse_channel_groups(graph) + for group in groups: + group.prepare_for_pruning(model) + prunable_groups = [group for group in groups if group.is_prunable] + + for group in prunable_groups: + choice = group.sample_choice() + group.current_choice = choice + self.assertAlmostEqual(group.current_choice, choice, delta=0.1) + x = torch.rand([2, 3, 224, 224]).to(DEVICE) + y = model(x) + self.assertSequenceEqual(y.shape, [2, 1000]) + + except Exception as e: + self.fail(f'{e}') + + def _test_a_model_using_backward_tracer(self, model): + model.eval() + model = model.to(DEVICE) + graph = ModuleGraph.init_using_backward_tracer(model) + self._test_a_graph(model, graph) + + def test_with_backward_tracer(self): + test_models = TestGraph.backward_tracer_passed_models() + for model_data in test_models: + with self.subTest(model=model_data): + model = model_data() + self._test_a_model_using_backward_tracer(model) + + def test_group_split(self): + layer = nn.Conv2d(3, 16, 3) + node = PruneNode('layer', layer) + channel1 = Channel(node, (8, 16), True) + channel2 = Channel(node, (0, 8), True) + group = DefaultChannelGroup(8) + group.add_ouptut_related(channel1) + group.add_ouptut_related(channel2) + + groups = group.split([2, 6]) + self.assertEqual(groups[0].output_related[0].index, (8, 10)) + self.assertEqual(groups[0].output_related[1].index, (0, 2)) + self.assertEqual(groups[1].output_related[0].index, (10, 16)) + self.assertEqual(groups[1].output_related[1].index, (2, 8)) + + def test_replace_with_dynamic_ops(self): + model_datas = TestGraph.backward_tracer_passed_models() + for model_data in model_datas: + for group_type in GROUPS: + with self.subTest(model=model_data, group=group_type): + model: nn.Module = model_data() + graph = ModuleGraph.init_using_backward_tracer(model) + groups: List[ + MutableChannelGroup] = group_type.parse_channel_groups( + graph) + + for group in groups: + group.prepare_for_pruning(model) + + for module in model.modules(): + if isinstance(module, nn.Conv2d)\ + and module.groups == module.in_channels\ + and module.groups == 1: + self.assertTrue( + isinstance(module, DynamicChannelMixin)) + if isinstance(module, nn.Linear): + self.assertTrue( + isinstance(module, DynamicChannelMixin)) + if isinstance(module, nn.BatchNorm2d): + self.assertTrue( + isinstance(module, DynamicChannelMixin)) diff --git a/tests/test_utils/test_index_dict.py b/tests/test_utils/test_index_dict.py new file mode 100644 index 000000000..767dd806c --- /dev/null +++ b/tests/test_utils/test_index_dict.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import unittest + +from mmrazor.utils.index_dict import IndexDict + + +class TestIndexDict(unittest.TestCase): + + def test_dict(self): + dict = IndexDict() + dict[(4, 5)] = 2 + dict[(1, 3)] = 1 + + self.assertSequenceEqual(list(dict.keys()), [(1, 3), (4, 5)]) + with self.assertRaisesRegex(AssertionError, 'overlap'): + dict[2, 3] = 3 From 74c1e8bad7e718f871d36a2fa88a5d50ca602e34 Mon Sep 17 00:00:00 2001 From: liukai Date: Tue, 6 Sep 2022 16:49:24 +0800 Subject: [PATCH 02/25] modification for adding config_template --- .../mutable_channel/groups/channel_group.py | 137 +++++++++++++++--- .../groups/mutable_channel_group.py | 54 +++---- .../group/test_mutable_channel_groups.py | 4 +- 3 files changed, 147 insertions(+), 48 deletions(-) diff --git a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py index 7a3968029..782f1c7a5 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py @@ -53,12 +53,15 @@ def get_channels(self, if index is None: index = (0, self.out_channels if out_related is True else self.in_channels) + name = self.module_name if isinstance(self.val, + nn.Module) else self.name channel = Channel( - self, + name, + self.val, index, + self, out_related=out_related, - expand_ratio=expand_ratio, - module_name=self.module_name) + expand_ratio=expand_ratio) return channel def output_related_groups_of_prev_nodes( @@ -197,11 +200,12 @@ class Channel: """Channel records information about channels for pruning.""" def __init__(self, - node: PruneNode, + name, + module, index, + node: PruneNode = None, out_related=True, - expand_ratio=1, - module_name='') -> None: + expand_ratio=1) -> None: """ Args: node: (PruneNode): prune-node to be recorded @@ -211,32 +215,73 @@ def __init__(self, expand_ratio (Bool): expand_ratio of the number of channels compared with pruning mask """ - self.node = node + self.name = name + self.module: DynamicChannelMixin = module self.index = index self.start = index[0] self.end = index[1] + + self.node = node + self.output_related = out_related self.expand_ratio = expand_ratio - self.name = node.name - self.module: DynamicChannelMixin = node.val - self.module_name = module_name + @classmethod + def init_using_cfg(cls, model: nn.Module, config: Dict): + """init a Channel using a config which can be generated by + self.config_template()""" + name = config['name'] + start = config['start'] + end = config['end'] + expand_ratio = config['expand_ratio'] + is_output = config['is_output_related'] + + name2module = dict(model.named_modules()) + name2module.pop('') + module = name2module[name] if name in name2module else None + return Channel( + name, + module, (start, end), + out_related=is_output, + expand_ratio=expand_ratio) + + # config + def config_template(self): + """Generate a config template which can be used to initialize a Channel + by cls.init_using_cfg(**kwargs)""" + return { + 'name': self.name, + 'start': self.start, + 'end': self.end, + 'expand_ratio': self.expand_ratio, + 'is_output_related': self.output_related + } @property def num_channels(self) -> int: """Int: number of channels in the Channels""" return self.index[1] - self.index[0] + @property + def is_prunable(self) -> bool: + if isinstance(self.module, nn.Conv2d): + if self.module.groups != 1 and not (self.module.groups == + self.module.in_channels == + self.module.out_channels): + return False + return True + # group related operations def slice(self, start: int, end: int) -> 'Channel': """Channel: a new Channel who manage a slice of the current Channel.""" channel = Channel( - self.node, + name=self.name, + module=self.module, index=(self.start + start, self.start + end), + node=self.node, out_related=self.output_related, - expand_ratio=self.expand_ratio, - module_name=self.module_name) + expand_ratio=self.expand_ratio) return channel # others @@ -247,6 +292,7 @@ def __repr__(self) -> str: expand:{self.expand_ratio}' +@MODELS.register_module() class ChannelGroup: """A manager for Channels.""" @@ -262,6 +308,34 @@ def __init__(self, num_channels: int) -> None: self.init_args: Dict = { } # is used to generate new channel group with same args + @classmethod + def init_using_cfg(cls, model: nn.Module, config: Dict): + """init a ChannelGroup using a config which can be generated by + self.config_template()""" + config = copy.deepcopy(config) + if 'channels' in config: + channels = config.pop('channels') + else: + channels = None + group = cls(**(config['init_args'])) + if channels is not None: + for channel_config in channels['input_related']: + group.add_input_related( + Channel.init_using_cfg(model, channel_config)) + for channel_config in channels['output_related']: + group.add_ouptut_related( + Channel.init_using_cfg(model, channel_config)) + return group + + @property + def name(self) -> str: + """str: name of the group""" + first_module = self.output_related[0] if len( + self.output_related) > 0 else self.input_related[0] + name = f'{first_module.name}_{first_module.index}_' + name += f'out_{len(self.output_related)}_in_{len(self.input_related)}' + return name + # node operations def add_ouptut_related(self, channel: Channel): @@ -281,11 +355,13 @@ def add_input_related(self, channel: Channel): def remove_from_node(self): """Remove recorded information in all nodes about this group.""" for channel in self.output_related: - assert channel.index in channel.node.output_related_groups, \ + assert channel.node is not None \ + and channel.index in channel.node.output_related_groups, \ f'{channel.name}.{channel.index} not exist in node.out_related' channel.node.output_related_groups.pop(channel.index) for channel in self.input_related: - assert channel.index in channel.node.input_related_groups, \ + assert channel.node is not None \ + and channel.index in channel.node.input_related_groups, \ f'{channel.name}.{channel.index} \ not exist in node.input_related' @@ -293,10 +369,12 @@ def remove_from_node(self): def apply_for_node(self): """Register the information about this group for all nodes.""" - for node in self.output_related: - node.node.output_related_groups[node.index] = self - for node in self.input_related: - node.node.input_related_groups[node.index] = self + for channel in self.output_related: + assert channel.node is not None + channel.node.output_related_groups[channel.index] = self + for channel in self.input_related: + assert channel.node is not None + channel.node.input_related_groups[channel.index] = self # group operations @@ -351,6 +429,27 @@ def parse_channel_groups(cls, group._model = graph._model return groups + # config + def config_template(self, with_init_args=False, with_channels=False): + """Generate a config template which can be used to initialize a + ChannelGroup by cls.init_using_cfg(**kwargs)""" + config = {} + if with_init_args: + config['init_args'] = {'num_channels': self.num_channels} + if with_channels: + config['channels'] = self._channel_dict() + return config + + # tools + def _channel_dict(self): + info = { + 'input_related': + [channel.config_template() for channel in self.input_related], + 'output_related': + [channel.config_template() for channel in self.output_related], + } + return info + # to string def __repr__(self): diff --git a/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py index 55e7a2e16..a498a47ec 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py @@ -49,16 +49,15 @@ def __init__(self, num_channels: int) -> None: super().__init__(num_channels) BaseModule.__init__(self) - # basic property + @classmethod + def init_from_channel_group(cls, group: ChannelGroup, args: Dict): + args['num_channels'] = group.num_channels + mutable_group = cls(**args) + mutable_group.input_related = group.input_related + mutable_group.output_related = group.output_related + return mutable_group - @property - def name(self) -> str: - """str: name of the group""" - first_module = self.output_related[0] if len( - self.output_related) > 0 else self.input_related[0] - name = f'{first_module.name}_{first_module.index}_' - name += f'out_{len(self.output_related)}_in_{len(self.input_related)}' - return name + # basic property @property def is_prunable(self) -> bool: @@ -68,7 +67,7 @@ def traverse(channels: List[Channel]): has_dynamic_op = False all_channel_prunable = True for channel in channels: - if channel.node.is_prunable is False: + if channel.is_prunable is False: all_channel_prunable = False break if isinstance(channel.module, DynamicChannelMixin): @@ -113,12 +112,6 @@ def sample_choice(self): """Randomly sample a valid choice and return.""" raise NotImplementedError() - def config_template(self, with_info=False): - if with_info: - return {'info': self._info_dict} - else: - return {} - # after pruning def fix_chosen(self, choice=None): @@ -128,15 +121,22 @@ def fix_chosen(self, choice=None): # tools - def _info_dict(self): - info = { - 'num_channels': self.num_channels, - 'choice': self.current_choice, - 'prunable': self.is_prunable, - 'input_layers': self.input_related, - 'out_related': self.output_related - } - return info + def config_template(self, + with_init_args=False, + with_channels=False) -> Dict: + """Return the config template of this group. By default, the config + template only includes a key 'choice'. + + Args: + with_init_args (bool): if the config includes args for + initialization. + with_channels (bool): if the config includes info about + channels. the config with info about channels can used to + parse channel groups without tracer. + """ + config = super().config_template(with_init_args, with_channels) + config['choice'] = self.current_choice + return config def _get_int_choice(self, choice: Union[int, float]) -> int: """Convert ratio of channels to number of channels.""" @@ -165,11 +165,11 @@ def get_module(model, name): for channel in self.input_related + self.output_related: if isinstance(channel.module, nn.Module): - module = get_module(model, channel.module_name) + module = get_module(model, channel.name) if type(module) in dynamicop_map: new_module = dynamicop_map[type(module)].convert_from( module) - replace_op(model, channel.module_name, new_module) + replace_op(model, channel.name, new_module) channel.module = new_module else: channel.module = module diff --git a/tests/test_models/test_mutables/group/test_mutable_channel_groups.py b/tests/test_models/test_mutables/group/test_mutable_channel_groups.py index f6c17ab97..fa67e2c15 100644 --- a/tests/test_models/test_mutables/group/test_mutable_channel_groups.py +++ b/tests/test_models/test_mutables/group/test_mutable_channel_groups.py @@ -63,8 +63,8 @@ def test_with_backward_tracer(self): def test_group_split(self): layer = nn.Conv2d(3, 16, 3) node = PruneNode('layer', layer) - channel1 = Channel(node, (8, 16), True) - channel2 = Channel(node, (0, 8), True) + channel1 = Channel(node.name, layer, (8, 16), node, True) + channel2 = Channel(node.name, layer, (0, 8), node, True) group = DefaultChannelGroup(8) group.add_ouptut_related(channel1) group.add_ouptut_related(channel2) From 7e2532a0eb058c6f0df755d08a21d470abc5a0c2 Mon Sep 17 00:00:00 2001 From: liukai Date: Tue, 6 Sep 2022 17:18:26 +0800 Subject: [PATCH 03/25] add docstring to channel_group.py --- .../mutable_channel/groups/channel_group.py | 149 +++++++++++------- 1 file changed, 88 insertions(+), 61 deletions(-) diff --git a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py index 782f1c7a5..61308fbad 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py @@ -30,11 +30,27 @@ class PruneNode(ModuleNode): """Node class for pruning.""" + # init + def __init__(self, name: str, obj: Module, module_name='') -> None: + """ + Args: + name (str): node name. + obj (Module): Module + module_name: the name of the module in the model. + """ super().__init__(name, obj, module_name=module_name) self.input_related_groups: IndexDict[ChannelGroup] = IndexDict() self.output_related_groups: IndexDict[ChannelGroup] = IndexDict() + @classmethod + def copy_from(cls, node): + """Copy from a ModuleNode.""" + if isinstance(node, ModuleNode): + return cls(node.name, node.val, node.module_name) + else: + raise NotImplementedError() + # groups operation def get_channels(self, @@ -73,7 +89,7 @@ def output_related_groups_of_prev_nodes( groups.append(node.output_related_groups) return groups - # channel + # channel properties @property def act_in_channels(self) -> int: @@ -127,25 +143,20 @@ def act_out_channels(self) -> int: @property def is_parsed(self): + """If this node have been parsed.""" return len(self.input_related_groups) > 0 or len( self.output_related_groups) > 0 - # others - def __repr__(self) -> str: - return (f'{self.name}_{self.act_in_channels}/{self.in_channels}' - f'_{self.act_out_channels}/{self.out_channels}') - @property def is_prunable(self) -> bool: """Bool: if the node prunable""" return self.basic_type not in ['gwconv2d'] - @classmethod - def copy_from(cls, node): - if isinstance(node, ModuleNode): - return cls(node.name, node.val, node.module_name) - else: - raise NotImplementedError() + # others + + def __repr__(self) -> str: + return (f'{self.name}_{self.act_in_channels}/{self.in_channels}' + f'_{self.act_out_channels}/{self.out_channels}') PRUNENODE = TypeVar('PRUNENODE', bound=PruneNode) @@ -154,12 +165,20 @@ def copy_from(cls, node): class PruneGraph(ModuleGraph[PRUNENODE]): """Graph class for pruning.""" - def __init__(self) -> None: - super().__init__() + # init + + @classmethod + def copy_from(cls, graph, node_converter=PruneNode.copy_from): + """Copy from a module graph.""" + assert isinstance(graph, ModuleGraph) + graph = super().copy_from(graph, node_converter) + graph._merge_same_module() + return graph # groups_operation + def colloct_groups(self) -> List['ChannelGroup']: - """Set['ChannelGroup']: collect all ChannelGroups in the graph""" + """List['ChannelGroup']: collect all ChannelGroups in the graph.""" groups = [] for node in self.topo_traverse(): for group in node.input_related_groups.values(): @@ -170,13 +189,11 @@ def colloct_groups(self) -> List['ChannelGroup']: groups.append(group) return groups - @classmethod - def copy_from(cls, graph, node_converter=PruneNode.copy_from): - graph = super().copy_from(graph, node_converter) - graph.merge_same_module() - return graph + # private methods - def merge_same_module(self): + def _merge_same_module(self): + """Let all nodes that refer to the same module use the same + input_related_groups and output_related_groups.""" module2node: Dict[Any, List[PruneNode]] = dict() for node in self: if isinstance(node.val, Module): @@ -199,6 +216,8 @@ def merge_same_module(self): class Channel: """Channel records information about channels for pruning.""" + # init + def __init__(self, name, module, @@ -245,7 +264,7 @@ def init_using_cfg(cls, model: nn.Module, config: Dict): out_related=is_output, expand_ratio=expand_ratio) - # config + # config template def config_template(self): """Generate a config template which can be used to initialize a Channel by cls.init_using_cfg(**kwargs)""" @@ -257,6 +276,8 @@ def config_template(self): 'is_output_related': self.output_related } + # basic properties + @property def num_channels(self) -> int: """Int: number of channels in the Channels""" @@ -264,14 +285,16 @@ def num_channels(self) -> int: @property def is_prunable(self) -> bool: + """If the channel is prunable.""" if isinstance(self.module, nn.Conv2d): + # group-wise conv if self.module.groups != 1 and not (self.module.groups == self.module.in_channels == self.module.out_channels): return False return True - # group related operations + # node operations def slice(self, start: int, end: int) -> 'Channel': """Channel: a new Channel who manage a slice of the current Channel.""" @@ -327,6 +350,21 @@ def init_using_cfg(cls, model: nn.Module, config: Dict): Channel.init_using_cfg(model, channel_config)) return group + @classmethod + def parse_channel_groups(cls, + graph: ModuleGraph, + group_args={}) -> List['ChannelGroup']: + """Parse a module-graph and get ChannelGroups.""" + group_graph = PruneGraph.copy_from(graph, PruneNode.copy_from) + + cfg = dict(type=cls.__name__, **group_args) + groups = Graph2ChannelGroups(group_graph, cfg).groups + for group in groups: + group._model = graph._model + return groups + + # basic property + @property def name(self) -> str: """str: name of the group""" @@ -336,6 +374,18 @@ def name(self) -> str: name += f'out_{len(self.output_related)}_in_{len(self.input_related)}' return name + # config template + + def config_template(self, with_init_args=False, with_channels=False): + """Generate a config template which can be used to initialize a + ChannelGroup by cls.init_using_cfg(**kwargs)""" + config = {} + if with_init_args: + config['init_args'] = {'num_channels': self.num_channels} + if with_channels: + config['channels'] = self._channel_dict() + return config + # node operations def add_ouptut_related(self, channel: Channel): @@ -414,43 +464,7 @@ def slice(self, start: int, end: int) -> 'ChannelGroup': group.add_ouptut_related(module.slice(start, end)) return group - # init - - @classmethod - def parse_channel_groups(cls, - graph: ModuleGraph, - group_args={}) -> List['ChannelGroup']: - """Parse a module-graph and get ChannelGroups.""" - group_graph = PruneGraph.copy_from(graph, PruneNode.copy_from) - - cfg = dict(type=cls.__name__, **group_args) - groups = Graph2ChannelGroups(group_graph, cfg).groups - for group in groups: - group._model = graph._model - return groups - - # config - def config_template(self, with_init_args=False, with_channels=False): - """Generate a config template which can be used to initialize a - ChannelGroup by cls.init_using_cfg(**kwargs)""" - config = {} - if with_init_args: - config['init_args'] = {'num_channels': self.num_channels} - if with_channels: - config['channels'] = self._channel_dict() - return config - - # tools - def _channel_dict(self): - info = { - 'input_related': - [channel.config_template() for channel in self.input_related], - 'output_related': - [channel.config_template() for channel in self.output_related], - } - return info - - # to string + # others def __repr__(self): @@ -477,6 +491,18 @@ def list_repr(lit: List): s += add_prefix(list_repr(self.input_related), ' ' * 4) return s + # private methods + + def _channel_dict(self) -> Dict: + """Return channel config.""" + info = { + 'input_related': + [channel.config_template() for channel in self.input_related], + 'output_related': + [channel.config_template() for channel in self.output_related], + } + return info + # Group to ChannelGroup Converter @@ -516,6 +542,7 @@ def union_node_groups( self, node_groups_list=List[IndexDict[ChannelGroup]] ) -> List[ChannelGroup]: + """Union groups of nodes.""" union_groups = [] for index in copy.copy(node_groups_list[0]): groups = [node_groups[index] for node_groups in node_groups_list] @@ -587,7 +614,7 @@ def points2nums(points): for group in new_groups: group.apply_for_node() - # operations + # node operations def add_input_related(self, group: ChannelGroup, From 8e11bce16ae9d65b062cbe078f4946693449d4ee Mon Sep 17 00:00:00 2001 From: liukai Date: Tue, 6 Sep 2022 17:25:49 +0800 Subject: [PATCH 04/25] add docstring to mutable_channel_group.py --- .../groups/mutable_channel_group.py | 41 ++++++++++--------- .../groups/sequential_channel_group.py | 4 +- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py index a498a47ec..e7976235f 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py @@ -51,6 +51,7 @@ def __init__(self, num_channels: int) -> None: @classmethod def init_from_channel_group(cls, group: ChannelGroup, args: Dict): + """Initialize a MutalbeChannelGroup from a ChannelGroup.""" args['num_channels'] = group.num_channels mutable_group = cls(**args) mutable_group.input_related = group.input_related @@ -85,6 +86,25 @@ def traverse(channels: List[Channel]): and output_has_dynamic_op \ and output_all_prunable + # config template + + def config_template(self, + with_init_args=False, + with_channels=False) -> Dict: + """Return the config template of this group. By default, the config + template only includes a key 'choice'. + + Args: + with_init_args (bool): if the config includes args for + initialization. + with_channels (bool): if the config includes info about + channels. the config with info about channels can used to + parse channel groups without tracer. + """ + config = super().config_template(with_init_args, with_channels) + config['choice'] = self.current_choice + return config + # before pruning: prepare a model @abc.abstractmethod @@ -121,23 +141,6 @@ def fix_chosen(self, choice=None): # tools - def config_template(self, - with_init_args=False, - with_channels=False) -> Dict: - """Return the config template of this group. By default, the config - template only includes a key 'choice'. - - Args: - with_init_args (bool): if the config includes args for - initialization. - with_channels (bool): if the config includes info about - channels. the config with info about channels can used to - parse channel groups without tracer. - """ - config = super().config_template(with_init_args, with_channels) - config['choice'] = self.current_choice - return config - def _get_int_choice(self, choice: Union[int, float]) -> int: """Convert ratio of channels to number of channels.""" if isinstance(choice, float): @@ -175,7 +178,7 @@ def get_module(model, name): channel.module = module @staticmethod - def _register_mask_container( + def _register_channel_container( model: nn.Module, container_class: Type[MutableChannelContainer]): """register channel container for dynamic ops.""" for module in model.modules(): @@ -205,7 +208,7 @@ def _register_mask_container( module.register_mutable_attr('out_channels', container_class(out_channels)) - def _register_mask(self, mutable_channel: BaseMutableChannel): + def _register_mutable_channel(self, mutable_channel: BaseMutableChannel): # register mutable_channel for channel in self.input_related + self.output_related: diff --git a/mmrazor/models/mutables/mutable_channel/groups/sequential_channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/sequential_channel_group.py index 658228958..9ee5d81de 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/sequential_channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/sequential_channel_group.py @@ -36,8 +36,8 @@ def prepare_for_pruning(self, model: nn.Module): nn.BatchNorm2d: DynamicBatchNorm2d, nn.Linear: DynamicLinear }) - self._register_mask_container(model, MutableChannelContainer) - self._register_mask(self.mutable_channel) + self._register_channel_container(model, MutableChannelContainer) + self._register_mutable_channel(self.mutable_channel) # choice From 474b6940d0ccef22a8deace85a8aa7bc08f7fc57 Mon Sep 17 00:00:00 2001 From: liukai Date: Tue, 6 Sep 2022 17:50:27 +0800 Subject: [PATCH 05/25] rm channel_group_cfg from Graph2ChannelGroups --- .../mutable_channel/groups/channel_group.py | 32 +++++++------------ 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py index 61308fbad..17500f40e 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py @@ -12,7 +12,7 @@ """ import copy -from typing import Any, Dict, List, Tuple, Type, TypeVar, Union +from typing import Any, Dict, List, Tuple, TypeVar, Union import torch.nn as nn from torch.nn import Module @@ -350,15 +350,20 @@ def init_using_cfg(cls, model: nn.Module, config: Dict): Channel.init_using_cfg(model, channel_config)) return group + @classmethod + def init_from_channel_group(cls, group: 'ChannelGroup', args: Dict): + return group + @classmethod def parse_channel_groups(cls, graph: ModuleGraph, group_args={}) -> List['ChannelGroup']: """Parse a module-graph and get ChannelGroups.""" group_graph = PruneGraph.copy_from(graph, PruneNode.copy_from) - - cfg = dict(type=cls.__name__, **group_args) - groups = Graph2ChannelGroups(group_graph, cfg).groups + groups = Graph2ChannelGroups(group_graph).groups + groups = [ + cls.init_from_channel_group(group, group_args) for group in groups + ] for group in groups: group._model = graph._model return groups @@ -510,33 +515,20 @@ def _channel_dict(self) -> Dict: class Graph2ChannelGroups: """A converter which converts a Graph to a list of ChannelGroups.""" - def __init__( - self, - graph: PruneGraph, - channel_group_cfg: Union[Dict, - Type[ChannelGroup]] = ChannelGroup) -> None: + def __init__(self, graph: PruneGraph) -> None: """ Args: graph (PruneGraph): input prune-graph channel_group_cfg: the config for generating groups """ self.graph = graph - if isinstance(channel_group_cfg, dict): - self.channel_group_class = MODELS.module_dict[ - channel_group_cfg['type']] - self.channel_group_args = copy.copy(channel_group_cfg) - self.channel_group_args.pop('type') - else: - self.channel_group_class = channel_group_cfg - self.channel_group_args = {} self.groups = self.parse(self.graph) # group operations def new_channel_group(self, num_channels) -> ChannelGroup: """Initialize a ChannelGroup.""" - return self.channel_group_class(num_channels, - **self.channel_group_args) + return ChannelGroup(num_channels) def union_node_groups( self, @@ -552,7 +544,7 @@ def union_node_groups( def union_groups(self, groups: List[ChannelGroup]) -> ChannelGroup: """List[ChannelGroup]: union a list of ChannelGroups""" - group = self.channel_group_class.union(groups) + group = ChannelGroup.union(groups) # avoid removing multiple times groups_set = set(groups) for old_group in groups_set: From 65205331e7d400dd2a801c9714ccfb2e1c8a3e12 Mon Sep 17 00:00:00 2001 From: liukai Date: Wed, 7 Sep 2022 11:27:13 +0800 Subject: [PATCH 06/25] change choice type of SequentialChannelGroup from float to int --- .../groups/sequential_channel_group.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/mmrazor/models/mutables/mutable_channel/groups/sequential_channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/sequential_channel_group.py index 9ee5d81de..29da40e60 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/sequential_channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/sequential_channel_group.py @@ -16,8 +16,8 @@ class SequentialChannelGroup(MutableChannelGroup): """SimpleChannelGroup defines a simple pruning algorithhm. - The type of choice of SimpleChannelGroup is float. It indicates what ratio - of channels are remained from left to right. + The type of choice of SimpleChannelGroup is int. It indicates what ratio of + channels are remained from left to right. """ def __init__(self, num_channels: int) -> None: @@ -42,21 +42,20 @@ def prepare_for_pruning(self, model: nn.Module): # choice @property - def current_choice(self) -> float: + def current_choice(self) -> int: """return current choice.""" - return self.mutable_channel.activated_channels / self.num_channels + return self.mutable_channel.activated_channels @current_choice.setter - def current_choice(self, choice: float): + def current_choice(self, choice: int): """set choice.""" - int_choice = self._get_int_choice(choice) - mask = self._generate_mask(int_choice) + assert 0 < choice <= self.num_channels + mask = self._generate_mask(choice) self.mutable_channel.current_choice = mask - def sample_choice(self) -> float: + def sample_choice(self) -> int: """Sample a choice in (0,1]""" - return max(1, int( - random.random() * self.num_channels)) / self.num_channels + return random.randint(1, self.num_channels) # private methods @@ -66,7 +65,6 @@ def _generate_mask(self, choice: int) -> torch.Tensor: mask[0:choice] = 1 return mask - # interface def fix_chosen(self, choice=None): """fix chosen.""" super().fix_chosen(choice) From 3139b225f0a59cb7ab7244d46e9611a7064c2576 Mon Sep 17 00:00:00 2001 From: liukai Date: Thu, 8 Sep 2022 09:49:36 +0800 Subject: [PATCH 07/25] add a warning about group-wise conv --- .../dynamic_ops/bricks/dynamic_conv.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py index 009b3131b..b045f6dd7 100644 --- a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py +++ b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py @@ -10,6 +10,8 @@ from .dynamic_conv_mixins import (BigNasConvMixin, DynamicConvMixin, OFAConvMixin) +GroupWiseConvWarned = False + @MODELS.register_module() class DynamicConv2d(nn.Conv2d, DynamicConvMixin): @@ -40,6 +42,17 @@ def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d': # a group-wise conv will not be converted to dynamic conv if module.groups > 1 and not (module.groups == module.out_channels == module.in_channels): + global GroupWiseConvWarned + if GroupWiseConvWarned is False: + from mmengine import MMLogger + logger = MMLogger.get_instance( + 'mmrazor', logger_name='mmrazor') + logger.warning( + ('Group-wise convolutional layers are not supported to be' + 'pruned now, so they are not converted to new' + 'DynamicConvs.')) + GroupWiseConvWarned = True + return module else: return cls( From 52e1a1623158869edf6021881003142a828fd320 Mon Sep 17 00:00:00 2001 From: liukai Date: Thu, 8 Sep 2022 09:54:13 +0800 Subject: [PATCH 08/25] restore __init__ of dynamic op --- mmrazor/models/architectures/dynamic_ops/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mmrazor/models/architectures/dynamic_ops/__init__.py b/mmrazor/models/architectures/dynamic_ops/__init__.py index a7c259e38..6b5796688 100644 --- a/mmrazor/models/architectures/dynamic_ops/__init__.py +++ b/mmrazor/models/architectures/dynamic_ops/__init__.py @@ -1,12 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. from .base import DynamicOP -from .default_dynamic_ops import (ChannelDynamicOP, DynamicBatchNorm, - DynamicConv2d, DynamicGroupNorm, - DynamicInstanceNorm, DynamicLinear) +from .default_dynamic_ops import (DynamicBatchNorm, DynamicConv2d, + DynamicGroupNorm, DynamicInstanceNorm, + DynamicLinear) from .slimmable_dynamic_ops import SwitchableBatchNorm2d __all__ = [ 'DynamicConv2d', 'DynamicLinear', 'DynamicBatchNorm', 'DynamicInstanceNorm', 'DynamicGroupNorm', 'SwitchableBatchNorm2d', - 'DynamicOP', 'ChannelDynamicOP' + 'DynamicOP' ] From 0dd4dafc1927a6826bc14af16965f7e78f42348c Mon Sep 17 00:00:00 2001 From: liukai Date: Thu, 8 Sep 2022 09:55:44 +0800 Subject: [PATCH 09/25] in_channel_mutable -> mutable_in_channel --- .../models/mutables/mutable_channel/base_mutable_channel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py index 54ce6568e..f48a49b3b 100644 --- a/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py @@ -13,11 +13,11 @@ class BaseMutableChannel(BaseMutable, DerivedMethodMixin): channels. |---------------------------------------| - |in_channel_mutable(BaseMutableChannel) | + |mutable_in_channel(BaseMutableChannel) | |---------------------------------------| | DynamicOp | |---------------------------------------| - |out_channel_mutable(BaseMutableChannel)| + |mutable_out_channel(BaseMutableChannel)| |---------------------------------------| Important interfaces: From ef7239036d04c9bdd7ce43f81ec696864138818e Mon Sep 17 00:00:00 2001 From: liukai Date: Thu, 8 Sep 2022 10:08:55 +0800 Subject: [PATCH 10/25] rm abstractproperty --- .../mutables/mutable_channel/base_mutable_channel.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py index f48a49b3b..5cc8c9d74 100644 --- a/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. """""" -from abc import abstractproperty +from abc import abstractmethod import torch @@ -32,17 +32,20 @@ def __init__(self, num_channels: int, **kwargs): # choice - @abstractproperty + @property # type: ignore + @abstractmethod def current_choice(self): """get current choice.""" raise NotImplementedError() - @current_choice.setter + @current_choice.setter # type: ignore + @abstractmethod def current_choice(self): """set current choice.""" raise NotImplementedError() - @abstractproperty + @property # type: ignore + @abstractmethod def current_mask(self) -> torch.Tensor: """Return a mask indicating the channel selection.""" raise NotImplementedError() From f6daf66a44822a054a4b8f3f7dffd6a1221c5419 Mon Sep 17 00:00:00 2001 From: liukai Date: Thu, 8 Sep 2022 10:11:20 +0800 Subject: [PATCH 11/25] add a comment about VT --- mmrazor/utils/index_dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmrazor/utils/index_dict.py b/mmrazor/utils/index_dict.py index de7d6ae3a..8fc163c77 100644 --- a/mmrazor/utils/index_dict.py +++ b/mmrazor/utils/index_dict.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import OrderedDict, Tuple, TypeVar -VT = TypeVar('VT') +VT = TypeVar('VT') # Value type class IndexDict(OrderedDict[Tuple[int, int], VT]): From d110291b40bb34a41932d4c88e132b95cac3a3b3 Mon Sep 17 00:00:00 2001 From: liukai Date: Thu, 8 Sep 2022 10:31:10 +0800 Subject: [PATCH 12/25] rm registry for ChannelGroup --- mmrazor/models/mutables/mutable_channel/groups/channel_group.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py index 17500f40e..fb4898d8a 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py @@ -19,7 +19,6 @@ from mmrazor.models.architectures.dynamic_ops.bricks.dynamic_mixins import \ DynamicChannelMixin -from mmrazor.registry import MODELS from mmrazor.structures.graph import ModuleGraph, ModuleNode from mmrazor.utils import IndexDict from ..base_mutable_channel import BaseMutableChannel @@ -315,7 +314,6 @@ def __repr__(self) -> str: expand:{self.expand_ratio}' -@MODELS.register_module() class ChannelGroup: """A manager for Channels.""" From 475674fea5706badc5857e0088be439e6061f773 Mon Sep 17 00:00:00 2001 From: liukai Date: Thu, 8 Sep 2022 10:35:21 +0800 Subject: [PATCH 13/25] MUTABLECHANNELGROUP -> ChannelGroupType --- mmrazor/models/mutables/__init__.py | 4 ++-- mmrazor/models/mutables/mutable_channel/__init__.py | 4 ++-- mmrazor/models/mutables/mutable_channel/groups/__init__.py | 6 ++---- .../mutable_channel/groups/mutable_channel_group.py | 2 +- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/mmrazor/models/mutables/__init__.py b/mmrazor/models/mutables/__init__.py index 659bcd57c..82b029450 100644 --- a/mmrazor/models/mutables/__init__.py +++ b/mmrazor/models/mutables/__init__.py @@ -4,7 +4,7 @@ from .mutable_channel import (BaseMutableChannel, MutableChannel, MutableChannelContainer, OneShotMutableChannel, SimpleMutableChannel, SlimmableMutableChannel) -from .mutable_channel.groups import (MUTABLECHANNELGROUP, MutableChannelGroup, +from .mutable_channel.groups import (ChannelGroupType, MutableChannelGroup, SequentialChannelGroup) from .mutable_module import (DiffChoiceRoute, DiffMutableModule, DiffMutableOP, OneShotMutableModule, OneShotMutableOP) @@ -14,7 +14,7 @@ 'OneShotMutableOP', 'OneShotMutableModule', 'DiffMutableOP', 'DiffChoiceRoute', 'DiffMutableModule', 'DerivedMutable', 'MutableValue', 'OneShotMutableValue', 'SimpleMutableChannel', 'MutableChannelGroup', - 'BaseMutableChannel', 'MutableChannelContainer', 'MUTABLECHANNELGROUP', + 'BaseMutableChannel', 'MutableChannelContainer', 'ChannelGroupType', 'BaseMutable', 'MutableChannel', 'SlimmableMutableChannel', 'OneShotMutableChannel', 'SequentialChannelGroup' ] diff --git a/mmrazor/models/mutables/mutable_channel/__init__.py b/mmrazor/models/mutables/mutable_channel/__init__.py index 72d3f7276..da35489c2 100644 --- a/mmrazor/models/mutables/mutable_channel/__init__.py +++ b/mmrazor/models/mutables/mutable_channel/__init__.py @@ -13,7 +13,7 @@ with mutable number of channels. """ from .base_mutable_channel import BaseMutableChannel -from .groups import (MUTABLECHANNELGROUP, MutableChannelGroup, +from .groups import (ChannelGroupType, MutableChannelGroup, SequentialChannelGroup) from .mutable_channel import MutableChannel from .mutable_channel_container import MutableChannelContainer @@ -25,7 +25,7 @@ __all__ = [ 'SimpleMutableChannel', 'MutableChannelGroup', 'OneShotChannelGroup', 'BaseMutableChannel', 'MutableChannelContainer', 'StackMutableChannel', - 'MUTABLECHANNELGROUP', 'MutableChannel', 'OneShotMutableChannel', + 'ChannelGroupType', 'MutableChannel', 'OneShotMutableChannel', 'SlimmableMutableChannel', 'SquentialMutableChannel', 'SequentialChannelGroup' ] diff --git a/mmrazor/models/mutables/mutable_channel/groups/__init__.py b/mmrazor/models/mutables/mutable_channel/groups/__init__.py index f73fa24cb..fbfc37d5c 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/__init__.py +++ b/mmrazor/models/mutables/mutable_channel/groups/__init__.py @@ -13,9 +13,7 @@ ----------------------------------------------------------other files Subclasses of MutableChannelGroup """ -from .mutable_channel_group import MUTABLECHANNELGROUP, MutableChannelGroup +from .mutable_channel_group import ChannelGroupType, MutableChannelGroup from .sequential_channel_group import SequentialChannelGroup -__all__ = [ - 'MutableChannelGroup', 'SequentialChannelGroup', 'MUTABLECHANNELGROUP' -] +__all__ = ['MutableChannelGroup', 'SequentialChannelGroup', 'ChannelGroupType'] diff --git a/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py index e7976235f..0ed83f712 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py @@ -245,4 +245,4 @@ def _register_mutable_channel(self, mutable_channel: BaseMutableChannel): container.register_mutable(mutable_channel_, start, end) -MUTABLECHANNELGROUP = TypeVar('MUTABLECHANNELGROUP', bound=MutableChannelGroup) +ChannelGroupType = TypeVar('ChannelGroupType', bound=MutableChannelGroup) From 27c55efb9a5ddbc61e9b3b0d4ffd0575150662c4 Mon Sep 17 00:00:00 2001 From: liukai Date: Thu, 8 Sep 2022 10:42:17 +0800 Subject: [PATCH 14/25] refine docstring of IndexDict --- mmrazor/utils/index_dict.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/mmrazor/utils/index_dict.py b/mmrazor/utils/index_dict.py index 8fc163c77..ef809466d 100644 --- a/mmrazor/utils/index_dict.py +++ b/mmrazor/utils/index_dict.py @@ -5,10 +5,17 @@ class IndexDict(OrderedDict[Tuple[int, int], VT]): - """IndexDict inherents from OrderedDict[Tuple[int, int], VT]. + """IndexDict inherents from OrderedDict[Tuple[int, int], VT]. Each + IndexDict object is a OrderDict object which using index(Tuple[int,int]) as + key and Any as value. - The type of the key is a Tuple[a: int,b: int]. It indicates a index range - [a,b). IndexDict can sort index and checkout if indexes overlap + The key type is Tuple[a: int,b: int]. It indicates a range in + the [a,b). + + IndexDict has three features: + 1. ensure a key always is a index(Tuple[int,int]). + 1. ensure the the indexes are sorted by ascending order. + 2. ensure there is no overlap among indexes. """ def __setitem__(self, __k: Tuple[int, int], __v): From a18919d578ab3363bbd03de809623338ed542cae Mon Sep 17 00:00:00 2001 From: liukai Date: Thu, 8 Sep 2022 11:44:44 +0800 Subject: [PATCH 15/25] update docstring --- mmrazor/models/mutables/mutable_channel/__init__.py | 13 ------------- .../mutable_channel/base_mutable_channel.py | 10 +++++++--- .../mutable_channel/simple_mutable_channel.py | 8 ++++++-- 3 files changed, 13 insertions(+), 18 deletions(-) diff --git a/mmrazor/models/mutables/mutable_channel/__init__.py b/mmrazor/models/mutables/mutable_channel/__init__.py index da35489c2..0b11531f4 100644 --- a/mmrazor/models/mutables/mutable_channel/__init__.py +++ b/mmrazor/models/mutables/mutable_channel/__init__.py @@ -1,17 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -r"""This module defines MutableChannels. - -----------------------------------------------------------base_mutable_channel.py -BaseMutableChannel -| \ -----------------------------------------------------------mutable_channel_container.py -MutableChannelContainer \ -----------------------------------------------------------other files - \ other MutableChannels - -MutableChannel are mainly used in DynamicOps. It helps DynamicOps to deal -with mutable number of channels. -""" from .base_mutable_channel import BaseMutableChannel from .groups import (ChannelGroupType, MutableChannelGroup, SequentialChannelGroup) diff --git a/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py index 5cc8c9d74..b60b98184 100644 --- a/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py @@ -20,9 +20,13 @@ class BaseMutableChannel(BaseMutable, DerivedMethodMixin): |mutable_out_channel(BaseMutableChannel)| |---------------------------------------| - Important interfaces: - current_choice: used to get/set mask. - current_mask: get mask(used in DynamicOps to get mask). + All subclasses should implement the following APIs: + + - ``current_choice`` + - ``current_mask`` + + Args: + num_channels (int): number(dimension) of channels(mask). """ def __init__(self, num_channels: int, **kwargs): diff --git a/mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py index b209cc05c..13e3ceb7b 100644 --- a/mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py @@ -11,9 +11,13 @@ @MODELS.register_module() class SimpleMutableChannel(BaseMutableChannel): """SimpleMutableChannel is a simple BaseMutableChannel, it directly take a - mask as a choice.""" + mask as a choice. - def __init__(self, num_channels, **kwargs) -> None: + Args: + num_channels (int): number of channels. + """ + + def __init__(self, num_channels: int, **kwargs) -> None: super().__init__(num_channels, **kwargs) self.num_channels = num_channels self.mask = torch.ones(num_channels).bool() From 1bd66a462c6a37a668a9c877d01abc0742586282 Mon Sep 17 00:00:00 2001 From: liukai Date: Thu, 8 Sep 2022 15:10:41 +0800 Subject: [PATCH 16/25] update docstring --- .../mutable_channel/groups/channel_group.py | 69 +++++++++++++------ .../sequential_mutable_channel.py | 4 ++ 2 files changed, 52 insertions(+), 21 deletions(-) diff --git a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py index fb4898d8a..c2ce9670e 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py @@ -27,24 +27,35 @@ class PruneNode(ModuleNode): - """Node class for pruning.""" + """Node class for pruning. + + Args: + name (str): node name. + obj (Module): module. + module_name (str, optional): the name of the module in the model. + Defaults to ''. + """ # init def __init__(self, name: str, obj: Module, module_name='') -> None: - """ - Args: - name (str): node name. - obj (Module): Module - module_name: the name of the module in the model. - """ super().__init__(name, obj, module_name=module_name) self.input_related_groups: IndexDict[ChannelGroup] = IndexDict() self.output_related_groups: IndexDict[ChannelGroup] = IndexDict() @classmethod def copy_from(cls, node): - """Copy from a ModuleNode.""" + """Copy from a ModuleNode. + + Args: + node (ModuleNode): node to be copied. + + Raises: + NotImplementedError: _description_ + + Returns: + PruneNode : Prunenode copied from the ModuleNode. + """ if isinstance(node, ModuleNode): return cls(node.name, node.val, node.module_name) else: @@ -56,7 +67,7 @@ def get_channels(self, index: Union[None, Tuple[int, int]] = None, out_related=True, expand_ratio: int = 1) -> 'Channel': - """PruneChannels: get the channels in the node between a range + """Get the channels in the module of the node between a range. Args: index (Union[None, Tuple[int, int]]): the channel range for pruning @@ -64,6 +75,8 @@ def get_channels(self, otherwise input channels. expand_ratio (Bool): expand_ratio of the number of channels compared with pruning mask. + Returns: + Channel """ if index is None: index = (0, self.out_channels @@ -174,7 +187,7 @@ def copy_from(cls, graph, node_converter=PruneNode.copy_from): graph._merge_same_module() return graph - # groups_operation + # group operations def colloct_groups(self) -> List['ChannelGroup']: """List['ChannelGroup']: collect all ChannelGroups in the graph.""" @@ -264,6 +277,7 @@ def init_using_cfg(cls, model: nn.Module, config: Dict): expand_ratio=expand_ratio) # config template + def config_template(self): """Generate a config template which can be used to initialize a Channel by cls.init_using_cfg(**kwargs)""" @@ -315,13 +329,17 @@ def __repr__(self) -> str: class ChannelGroup: - """A manager for Channels.""" + """A group of Channels. + + A ChannelGroup has two list, input_related and output_related, to store + the Channels. These Channels are dependent on each other, and have to + have the same number of activated number of channels. + + Args: + num_channels (int): the number of channels of Channel object. + """ def __init__(self, num_channels: int) -> None: - """ - Args: - num_channels (int): the dimension of Channels. - """ self.num_channels = num_channels self.output_related: List[Channel] = [] @@ -350,6 +368,15 @@ def init_using_cfg(cls, model: nn.Module, config: Dict): @classmethod def init_from_channel_group(cls, group: 'ChannelGroup', args: Dict): + """Initial a object of current class from a ChannelGroup object. + + Args: + group (ChannelGroup) + args (Dict): arguments to initial the object of current class. + + Returns: + Type(cls) + """ return group @classmethod @@ -511,14 +538,14 @@ def _channel_dict(self) -> Dict: class Graph2ChannelGroups: - """A converter which converts a Graph to a list of ChannelGroups.""" + """Graph2ChannelGroups parses a PruneGraph and return ChannelGroups. + + Args: + graph (PruneGraph): input prune-graph. + """ def __init__(self, graph: PruneGraph) -> None: - """ - Args: - graph (PruneGraph): input prune-graph - channel_group_cfg: the config for generating groups - """ + self.graph = graph self.groups = self.parse(self.graph) diff --git a/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py index 13a39f0a1..8150802be 100644 --- a/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py @@ -12,9 +12,13 @@ class SquentialMutableChannel(BaseMutableChannel): A choice of SquentialMutableChannel is an integer, which indicates how many channel are activated from left to right. + + Args: + num_channels (int): number of channels. """ def __init__(self, num_channels: int, **kwargs): + super().__init__(num_channels, **kwargs) self.mask = torch.ones([self.num_channels]).bool() From a229f30913d49bba59fd3ea16bcec8202efe6f18 Mon Sep 17 00:00:00 2001 From: liukai Date: Thu, 8 Sep 2022 15:15:54 +0800 Subject: [PATCH 17/25] is_prunable -> is_mutable --- .../mutables/mutable_channel/groups/channel_group.py | 11 +++++------ .../mutable_channel/groups/mutable_channel_group.py | 6 +++--- .../group/test_mutable_channel_groups.py | 2 +- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py index c2ce9670e..2d10113ab 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py @@ -160,8 +160,8 @@ def is_parsed(self): self.output_related_groups) > 0 @property - def is_prunable(self) -> bool: - """Bool: if the node prunable""" + def is_mutable(self) -> bool: + """Bool: if the the channels of the node is mutable(prunable)""" return self.basic_type not in ['gwconv2d'] # others @@ -175,8 +175,7 @@ def __repr__(self) -> str: class PruneGraph(ModuleGraph[PRUNENODE]): - """Graph class for pruning.""" - + """Subclass of ModuleGraph for pruning.""" # init @classmethod @@ -297,7 +296,7 @@ def num_channels(self) -> int: return self.index[1] - self.index[0] @property - def is_prunable(self) -> bool: + def is_mutable(self) -> bool: """If the channel is prunable.""" if isinstance(self.module, nn.Conv2d): # group-wise conv @@ -514,7 +513,7 @@ def list_repr(lit: List): s = (f'{self.name}_' f'\t{len(self.output_related)},{len(self.input_related)}' - f'\t{self.is_prunable}\n') + f'\t{self.is_mutable}\n') s += ' output_related:\n' s += add_prefix(list_repr(self.output_related), ' ' * 4) s += ' input_related\n' diff --git a/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py index 0ed83f712..1ce6446a7 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py @@ -24,7 +24,7 @@ class MutableChannelGroup(ChannelGroup, BaseModule): Basic Property name - is_prunable + is_mutable Important interfaces during different stages: @@ -61,14 +61,14 @@ def init_from_channel_group(cls, group: ChannelGroup, args: Dict): # basic property @property - def is_prunable(self) -> bool: + def is_mutable(self) -> bool: """If the channel-group is prunable.""" def traverse(channels: List[Channel]): has_dynamic_op = False all_channel_prunable = True for channel in channels: - if channel.is_prunable is False: + if channel.is_mutable is False: all_channel_prunable = False break if isinstance(channel.module, DynamicChannelMixin): diff --git a/tests/test_models/test_mutables/group/test_mutable_channel_groups.py b/tests/test_models/test_mutables/group/test_mutable_channel_groups.py index fa67e2c15..50e6212e9 100644 --- a/tests/test_models/test_mutables/group/test_mutable_channel_groups.py +++ b/tests/test_models/test_mutables/group/test_mutable_channel_groups.py @@ -34,7 +34,7 @@ def _test_a_graph(self, model, graph): groups = DefaultChannelGroup.parse_channel_groups(graph) for group in groups: group.prepare_for_pruning(model) - prunable_groups = [group for group in groups if group.is_prunable] + prunable_groups = [group for group in groups if group.is_mutable] for group in prunable_groups: choice = group.sample_choice() From a3de0dbb17e13d1ba13acb24a03f84b674d5adbd Mon Sep 17 00:00:00 2001 From: liukai Date: Thu, 8 Sep 2022 15:57:05 +0800 Subject: [PATCH 18/25] update docstring --- .../mutable_channel/groups/channel_group.py | 26 ++++----- .../groups/mutable_channel_group.py | 57 ++++++++----------- .../groups/sequential_channel_group.py | 5 +- 3 files changed, 41 insertions(+), 47 deletions(-) diff --git a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py index 2d10113ab..bc9eab0a1 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py @@ -225,7 +225,16 @@ def _merge_same_module(self): class Channel: - """Channel records information about channels for pruning.""" + """Channel records information about channels for pruning. + + Args: + node: (PruneNode): prune-node to be recorded + index (Union[None, Tuple[int, int]]): the channel range for pruning + out_related (Bool): represents if the channels are output channels, + otherwise input channels + expand_ratio (Bool): expand_ratio of the number of channels + compared with pruning mask + """ # init @@ -236,15 +245,6 @@ def __init__(self, node: PruneNode = None, out_related=True, expand_ratio=1) -> None: - """ - Args: - node: (PruneNode): prune-node to be recorded - index (Union[None, Tuple[int, int]]): the channel range for pruning - out_related (Bool): represents if the channels are output channels, - otherwise input channels - expand_ratio (Bool): expand_ratio of the number of channels - compared with pruning mask - """ self.name = name self.module: DynamicChannelMixin = module self.index = index @@ -338,6 +338,8 @@ class ChannelGroup: num_channels (int): the number of channels of Channel object. """ + # init methods + def __init__(self, num_channels: int) -> None: self.num_channels = num_channels @@ -392,7 +394,7 @@ def parse_channel_groups(cls, group._model = graph._model return groups - # basic property + # tools @property def name(self) -> str: @@ -403,8 +405,6 @@ def name(self) -> str: name += f'out_{len(self.output_related)}_in_{len(self.input_related)}' return name - # config template - def config_template(self, with_init_args=False, with_channels=False): """Generate a config template which can be used to initialize a ChannelGroup by cls.init_using_cfg(**kwargs)""" diff --git a/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py index 1ce6446a7..8af3bef56 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py @@ -14,38 +14,33 @@ class MutableChannelGroup(ChannelGroup, BaseModule): - """MutableChannelGroup inherits from ChannelGroup, which manages channels - with channel-dependency. - - Compared with ChannelGroup, MutableChannelGroup defines the core - interfaces for pruning. By inheriting MutableChannelGroup, we can implement - a variant pruning algorithm. - - Basic Property - - name - is_mutable - - Important interfaces during different stages: - - # Before pruning - prepare_model - prepare_for_pruning - - # Pruning stage - current_choice - sample_choice - - # After pruning - fix_chosen - """ + # init methods def __init__(self, num_channels: int) -> None: - """ + """MutableChannelGroup inherits from ChannelGroup, which manages + channels with channel-dependency. + + Compared with ChannelGroup, MutableChannelGroup defines the core + interfaces for pruning. By inheriting MutableChannelGroup, + we can implement a variant pruning and nas algorithm. + + These apis includes + - basic property + - name + - is_mutable + - before pruning + - prepare_for_pruning + - pruning stage + - current_choice + - sample_choice + - after pruning + - fix_chosen + Args: - num_channels (int): dimension of the channels that this - MutableChannelGroup manages. + num_channels (int): dimension of the channels of the Channel + objects in the group. """ + super().__init__(num_channels) BaseModule.__init__(self) @@ -58,7 +53,7 @@ def init_from_channel_group(cls, group: ChannelGroup, args: Dict): mutable_group.output_related = group.output_related return mutable_group - # basic property + # properties @property def is_mutable(self) -> bool: @@ -86,8 +81,6 @@ def traverse(channels: List[Channel]): and output_has_dynamic_op \ and output_all_prunable - # config template - def config_template(self, with_init_args=False, with_channels=False) -> Dict: @@ -139,7 +132,7 @@ def fix_chosen(self, choice=None): if choice is not None: self.current_choice = choice - # tools + # private methods def _get_int_choice(self, choice: Union[int, float]) -> int: """Convert ratio of channels to number of channels.""" diff --git a/mmrazor/models/mutables/mutable_channel/groups/sequential_channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/sequential_channel_group.py index 29da40e60..093b3f91e 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/sequential_channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/sequential_channel_group.py @@ -18,6 +18,9 @@ class SequentialChannelGroup(MutableChannelGroup): The type of choice of SimpleChannelGroup is int. It indicates what ratio of channels are remained from left to right. + + Args: + num_channels (int): number of channels. """ def __init__(self, num_channels: int) -> None: @@ -25,8 +28,6 @@ def __init__(self, num_channels: int) -> None: self.mutable_channel: SimpleMutableChannel = SimpleMutableChannel( self.num_channels) - # prepare model - def prepare_for_pruning(self, model: nn.Module): """Prepare for pruning, including register mutable channels.""" # register MutableMask From 8d7cbc3b6e80fb952f7b1ad4b1b11ea9ff0c679a Mon Sep 17 00:00:00 2001 From: liukai Date: Fri, 9 Sep 2022 14:15:29 +0800 Subject: [PATCH 19/25] fix error in pre-commit --- .../models/mutables/mutable_channel/groups/channel_group.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py index bc9eab0a1..ef004db22 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py @@ -460,8 +460,7 @@ def apply_for_node(self): @classmethod def union(cls, groups: List['ChannelGroup']) -> 'ChannelGroup': """ChannelGroup: Union ChannelGroups and return.""" - group = cls(groups[0].num_channels, - **groups[0].init_args) # type: ignore + group = cls(groups[0].num_channels) # type: ignore for old_group in groups: for group_module in old_group.input_related: group.add_input_related(group_module) From 3e2791a326c2261dc7a0cb0a6006ea2772437676 Mon Sep 17 00:00:00 2001 From: liukai Date: Fri, 9 Sep 2022 15:32:00 +0800 Subject: [PATCH 20/25] update unittest --- mmrazor/models/mutables/__init__.py | 6 ++-- .../sequential_mutable_channel.py | 14 -------- .../group/test_mutable_channels.py | 34 +++++++++++++++++++ 3 files changed, 38 insertions(+), 16 deletions(-) create mode 100644 tests/test_models/test_mutables/group/test_mutable_channels.py diff --git a/mmrazor/models/mutables/__init__.py b/mmrazor/models/mutables/__init__.py index 82b029450..14a60d066 100644 --- a/mmrazor/models/mutables/__init__.py +++ b/mmrazor/models/mutables/__init__.py @@ -3,7 +3,8 @@ from .derived_mutable import DerivedMutable from .mutable_channel import (BaseMutableChannel, MutableChannel, MutableChannelContainer, OneShotMutableChannel, - SimpleMutableChannel, SlimmableMutableChannel) + SimpleMutableChannel, SlimmableMutableChannel, + SquentialMutableChannel) from .mutable_channel.groups import (ChannelGroupType, MutableChannelGroup, SequentialChannelGroup) from .mutable_module import (DiffChoiceRoute, DiffMutableModule, DiffMutableOP, @@ -16,5 +17,6 @@ 'OneShotMutableValue', 'SimpleMutableChannel', 'MutableChannelGroup', 'BaseMutableChannel', 'MutableChannelContainer', 'ChannelGroupType', 'BaseMutable', 'MutableChannel', 'SlimmableMutableChannel', - 'OneShotMutableChannel', 'SequentialChannelGroup' + 'OneShotMutableChannel', 'SequentialChannelGroup', + 'SquentialMutableChannel' ] diff --git a/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py index 8150802be..f3436f60e 100644 --- a/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py @@ -52,17 +52,3 @@ def fix_chosen(self, chosen=...): def dump_chosen(self): """Dump chosen.""" return self.current_choice - - def __mul__(self, other): - """multiplication.""" - if isinstance(other, int): - return self.derive_expand_mutable(other) - else: - return None - - def __floordiv__(self, other): - """division.""" - if isinstance(other, int): - return self.derive_divide_mutable(other) - else: - return None diff --git a/tests/test_models/test_mutables/group/test_mutable_channels.py b/tests/test_models/test_mutables/group/test_mutable_channels.py new file mode 100644 index 000000000..f8a19c923 --- /dev/null +++ b/tests/test_models/test_mutables/group/test_mutable_channels.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import unittest + +import pytest +import torch + +from mmrazor.models.mutables import (SimpleMutableChannel, + SquentialMutableChannel) + + +class TestMutableChannels(unittest.TestCase): + + def test_SquentialMutableChannel(self): + mutable_channel = SquentialMutableChannel(4) + mutable_channel.current_choice = 3 + self.assertEqual(mutable_channel.activated_channels, + mutable_channel.current_choice) + self.assertTrue( + (mutable_channel.current_mask == torch.tensor([1, 1, 1, + 0]).bool()).all()) + channel_str = mutable_channel.__repr__() + self.assertEqual(channel_str, + 'SquentialMutableChannel(name=, num_channels=4, ') + + mutable_channel.fix_chosen() + mutable_channel.dump_chosen() + + def test_SimpleMutableChannel(self): + channel = SimpleMutableChannel(4) + channel.current_choice = torch.tensor([1, 0, 0, 0]).bool() + self.assertEqual(channel.activated_channels, 1) + channel.fix_chosen() + with pytest.raises(NotImplementedError): + channel.dump_chosen() From 46d5ca67e402537ad09d4a31abdbcf3a99b526f0 Mon Sep 17 00:00:00 2001 From: liukai Date: Fri, 9 Sep 2022 17:29:12 +0800 Subject: [PATCH 21/25] add return type --- .../mutable_channel/groups/channel_group.py | 36 ++++++++++--------- .../groups/mutable_channel_group.py | 3 +- .../mutable_channel_container.py | 2 +- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py index ef004db22..10acff8b9 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py @@ -44,7 +44,7 @@ def __init__(self, name: str, obj: Module, module_name='') -> None: self.output_related_groups: IndexDict[ChannelGroup] = IndexDict() @classmethod - def copy_from(cls, node): + def copy_from(cls, node) -> 'PruneNode': """Copy from a ModuleNode. Args: @@ -154,7 +154,7 @@ def act_out_channels(self) -> int: raise NotImplementedError() @property - def is_parsed(self): + def is_parsed(self) -> bool: """If this node have been parsed.""" return len(self.input_related_groups) > 0 or len( self.output_related_groups) > 0 @@ -179,7 +179,9 @@ class PruneGraph(ModuleGraph[PRUNENODE]): # init @classmethod - def copy_from(cls, graph, node_converter=PruneNode.copy_from): + def copy_from(cls, + graph, + node_converter=PruneNode.copy_from) -> 'PruneGraph': """Copy from a module graph.""" assert isinstance(graph, ModuleGraph) graph = super().copy_from(graph, node_converter) @@ -257,7 +259,7 @@ def __init__(self, self.expand_ratio = expand_ratio @classmethod - def init_using_cfg(cls, model: nn.Module, config: Dict): + def init_using_cfg(cls, model: nn.Module, config: Dict) -> 'Channel': """init a Channel using a config which can be generated by self.config_template()""" name = config['name'] @@ -277,7 +279,7 @@ def init_using_cfg(cls, model: nn.Module, config: Dict): # config template - def config_template(self): + def config_template(self) -> Dict: """Generate a config template which can be used to initialize a Channel by cls.init_using_cfg(**kwargs)""" return { @@ -340,7 +342,7 @@ class ChannelGroup: # init methods - def __init__(self, num_channels: int) -> None: + def __init__(self, num_channels: int): self.num_channels = num_channels self.output_related: List[Channel] = [] @@ -349,7 +351,7 @@ def __init__(self, num_channels: int) -> None: } # is used to generate new channel group with same args @classmethod - def init_using_cfg(cls, model: nn.Module, config: Dict): + def init_using_cfg(cls, model: nn.Module, config: Dict) -> 'ChannelGroup': """init a ChannelGroup using a config which can be generated by self.config_template()""" config = copy.deepcopy(config) @@ -368,7 +370,8 @@ def init_using_cfg(cls, model: nn.Module, config: Dict): return group @classmethod - def init_from_channel_group(cls, group: 'ChannelGroup', args: Dict): + def init_from_channel_group(cls, group: 'ChannelGroup', + args: Dict) -> 'ChannelGroup': """Initial a object of current class from a ChannelGroup object. Args: @@ -390,8 +393,6 @@ def parse_channel_groups(cls, groups = [ cls.init_from_channel_group(group, group_args) for group in groups ] - for group in groups: - group._model = graph._model return groups # tools @@ -405,7 +406,9 @@ def name(self) -> str: name += f'out_{len(self.output_related)}_in_{len(self.input_related)}' return name - def config_template(self, with_init_args=False, with_channels=False): + def config_template(self, + with_init_args=False, + with_channels=False) -> Dict: """Generate a config template which can be used to initialize a ChannelGroup by cls.init_using_cfg(**kwargs)""" config = {} @@ -418,14 +421,14 @@ def config_template(self, with_init_args=False, with_channels=False): # node operations def add_ouptut_related(self, channel: Channel): - """None: add a Channel which is output related""" + """Add a Channel which is output related.""" assert channel.output_related assert self.num_channels == channel.num_channels if channel not in self.output_related: self.output_related.append(channel) def add_input_related(self, channel: Channel): - """None: add a Channel which is input related""" + """Add a Channel which is input related.""" assert channel.output_related is False assert self.num_channels == channel.num_channels if channel not in self.input_related: @@ -494,7 +497,7 @@ def slice(self, start: int, end: int) -> 'ChannelGroup': # others - def __repr__(self): + def __repr__(self) -> str: def add_prefix(string: str, prefix=' '): str_list = string.split('\n') @@ -511,8 +514,7 @@ def list_repr(lit: List): return s s = (f'{self.name}_' - f'\t{len(self.output_related)},{len(self.input_related)}' - f'\t{self.is_mutable}\n') + f'\t{len(self.output_related)},{len(self.input_related)}\n') s += ' output_related:\n' s += add_prefix(list_repr(self.output_related), ' ' * 4) s += ' input_related\n' @@ -746,7 +748,7 @@ def parse_node(self, node: PruneNode): else: raise NotImplementedError(f'{node.basic_type}') - def parse(self, graph: PruneGraph): + def parse(self, graph: PruneGraph) -> List[ChannelGroup]: """Parse a module-graph and get ChannelGroups.""" for node in graph.topo_traverse(): self.parse_node(node) diff --git a/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py index 8af3bef56..3cd92f7fd 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py @@ -45,7 +45,8 @@ def __init__(self, num_channels: int) -> None: BaseModule.__init__(self) @classmethod - def init_from_channel_group(cls, group: ChannelGroup, args: Dict): + def init_from_channel_group(cls, group: ChannelGroup, + args: Dict) -> 'MutableChannelGroup': """Initialize a MutalbeChannelGroup from a ChannelGroup.""" args['num_channels'] = group.num_channels mutable_group = cls(**args) diff --git a/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py b/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py index c7c7111df..107ef35f4 100644 --- a/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py +++ b/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py @@ -34,7 +34,7 @@ def __init__(self, num_channels: int, **kwargs): # choice @property - def current_choice(self): + def current_choice(self) -> torch.Tensor: """Get current choices.""" if len(self.mutable_channels) == 0: return torch.ones([self.num_channels]).bool() From 0a53daf701d1876ada4103a28646424f959b9f79 Mon Sep 17 00:00:00 2001 From: liukai Date: Fri, 9 Sep 2022 17:43:14 +0800 Subject: [PATCH 22/25] unify init_xxx apit --- .../mutable_channel/groups/channel_group.py | 18 +++++++++--------- .../group/test_mutable_channel_groups.py | 4 ++-- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py index 10acff8b9..11a54a50e 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py @@ -259,7 +259,7 @@ def __init__(self, self.expand_ratio = expand_ratio @classmethod - def init_using_cfg(cls, model: nn.Module, config: Dict) -> 'Channel': + def init_from_cfg(cls, model: nn.Module, config: Dict) -> 'Channel': """init a Channel using a config which can be generated by self.config_template()""" name = config['name'] @@ -281,7 +281,7 @@ def init_using_cfg(cls, model: nn.Module, config: Dict) -> 'Channel': def config_template(self) -> Dict: """Generate a config template which can be used to initialize a Channel - by cls.init_using_cfg(**kwargs)""" + by cls.init_from_cfg(**kwargs)""" return { 'name': self.name, 'start': self.start, @@ -351,7 +351,7 @@ def __init__(self, num_channels: int): } # is used to generate new channel group with same args @classmethod - def init_using_cfg(cls, model: nn.Module, config: Dict) -> 'ChannelGroup': + def init_from_cfg(cls, model: nn.Module, config: Dict) -> 'ChannelGroup': """init a ChannelGroup using a config which can be generated by self.config_template()""" config = copy.deepcopy(config) @@ -363,10 +363,10 @@ def init_using_cfg(cls, model: nn.Module, config: Dict) -> 'ChannelGroup': if channels is not None: for channel_config in channels['input_related']: group.add_input_related( - Channel.init_using_cfg(model, channel_config)) + Channel.init_from_cfg(model, channel_config)) for channel_config in channels['output_related']: group.add_ouptut_related( - Channel.init_using_cfg(model, channel_config)) + Channel.init_from_cfg(model, channel_config)) return group @classmethod @@ -384,9 +384,9 @@ def init_from_channel_group(cls, group: 'ChannelGroup', return group @classmethod - def parse_channel_groups(cls, - graph: ModuleGraph, - group_args={}) -> List['ChannelGroup']: + def init_from_graph(cls, + graph: ModuleGraph, + group_args={}) -> List['ChannelGroup']: """Parse a module-graph and get ChannelGroups.""" group_graph = PruneGraph.copy_from(graph, PruneNode.copy_from) groups = Graph2ChannelGroups(group_graph).groups @@ -410,7 +410,7 @@ def config_template(self, with_init_args=False, with_channels=False) -> Dict: """Generate a config template which can be used to initialize a - ChannelGroup by cls.init_using_cfg(**kwargs)""" + ChannelGroup by cls.init_from_cfg(**kwargs)""" config = {} if with_init_args: config['init_args'] = {'num_channels': self.num_channels} diff --git a/tests/test_models/test_mutables/group/test_mutable_channel_groups.py b/tests/test_models/test_mutables/group/test_mutable_channel_groups.py index 50e6212e9..e54e0f6a1 100644 --- a/tests/test_models/test_mutables/group/test_mutable_channel_groups.py +++ b/tests/test_models/test_mutables/group/test_mutable_channel_groups.py @@ -31,7 +31,7 @@ class TestMutableChannelGroup(TestCase): def _test_a_graph(self, model, graph): try: - groups = DefaultChannelGroup.parse_channel_groups(graph) + groups = DefaultChannelGroup.init_from_graph(graph) for group in groups: group.prepare_for_pruning(model) prunable_groups = [group for group in groups if group.is_mutable] @@ -83,7 +83,7 @@ def test_replace_with_dynamic_ops(self): model: nn.Module = model_data() graph = ModuleGraph.init_using_backward_tracer(model) groups: List[ - MutableChannelGroup] = group_type.parse_channel_groups( + MutableChannelGroup] = group_type.init_from_graph( graph) for group in groups: From dd8fb69975d177bbc5fee33a6531b6b56da284ae Mon Sep 17 00:00:00 2001 From: liukai Date: Fri, 9 Sep 2022 18:01:46 +0800 Subject: [PATCH 23/25] add unitest about init of MutableChannelGroup --- .../groups/mutable_channel_group.py | 5 +- .../group/test_mutable_channel_groups.py | 97 +++++++++++++++---- 2 files changed, 82 insertions(+), 20 deletions(-) diff --git a/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py index 3cd92f7fd..eaf08df8e 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py @@ -45,8 +45,9 @@ def __init__(self, num_channels: int) -> None: BaseModule.__init__(self) @classmethod - def init_from_channel_group(cls, group: ChannelGroup, - args: Dict) -> 'MutableChannelGroup': + def init_from_channel_group(cls, + group: ChannelGroup, + args: Dict = {}) -> 'MutableChannelGroup': """Initialize a MutalbeChannelGroup from a ChannelGroup.""" args['num_channels'] = group.num_channels mutable_group = cls(**args) diff --git a/tests/test_models/test_mutables/group/test_mutable_channel_groups.py b/tests/test_models/test_mutables/group/test_mutable_channel_groups.py index e54e0f6a1..255d898e0 100644 --- a/tests/test_models/test_mutables/group/test_mutable_channel_groups.py +++ b/tests/test_models/test_mutables/group/test_mutable_channel_groups.py @@ -10,8 +10,9 @@ from mmrazor.models.mutables.mutable_channel import (MutableChannelGroup, SequentialChannelGroup) from mmrazor.models.mutables.mutable_channel.groups.channel_group import ( # noqa - Channel, PruneNode) + Channel, ChannelGroup, PruneNode) from mmrazor.structures.graph import ModuleGraph as ModuleGraph +from ....data.models import LineModel from ....test_core.test_graph.test_graph import TestGraph MUTABLE_CFG = dict(type='SimpleMutableChannl') @@ -29,23 +30,74 @@ class TestMutableChannelGroup(TestCase): - def _test_a_graph(self, model, graph): - try: - groups = DefaultChannelGroup.init_from_graph(graph) - for group in groups: - group.prepare_for_pruning(model) - prunable_groups = [group for group in groups if group.is_mutable] - - for group in prunable_groups: - choice = group.sample_choice() - group.current_choice = choice - self.assertAlmostEqual(group.current_choice, choice, delta=0.1) - x = torch.rand([2, 3, 224, 224]).to(DEVICE) - y = model(x) - self.assertSequenceEqual(y.shape, [2, 1000]) - - except Exception as e: - self.fail(f'{e}') + def test_init_from_graph(self): + model = LineModel() + # init using tracer + graph = ModuleGraph.init_using_backward_tracer(model) + groups = DefaultChannelGroup.init_from_graph(graph) + self._test_groups(groups, model) + + def test_init_from_cfg(self): + model = LineModel() + # init using tracer + + config = { + 'init_args': { + 'num_channels': 8 + }, + 'channels': { + 'input_related': [{ + 'name': 'net.1', + 'start': 0, + 'end': 8, + 'expand_ratio': 1, + 'is_output_related': False + }, { + 'name': 'net.3', + 'start': 0, + 'end': 8, + 'expand_ratio': 1, + 'is_output_related': False + }], + 'output_related': [{ + 'name': 'net.0', + 'start': 0, + 'end': 8, + 'expand_ratio': 1, + 'is_output_related': True + }, { + 'name': 'net.1', + 'start': 0, + 'end': 8, + 'expand_ratio': 1, + 'is_output_related': True + }] + } + } + groups = [DefaultChannelGroup.init_from_cfg(model, config)] + self._test_groups(groups, model) + + def test_init_from_channel_group(self): + model = LineModel() + # init using tracer + graph = ModuleGraph.init_using_backward_tracer(model) + groups: List[ChannelGroup] = ChannelGroup.init_from_graph(graph) + mutable_groups = [ + DefaultChannelGroup.init_from_channel_group(group) + for group in groups + ] + self._test_groups(mutable_groups, model) + + def _test_groups(self, groups: List[MutableChannelGroup], model): + prunable_groups = [group for group in groups if group.is_mutable] + + for group in prunable_groups: + choice = group.sample_choice() + group.current_choice = choice + self.assertAlmostEqual(group.current_choice, choice, delta=0.1) + x = torch.rand([2, 3, 224, 224]).to(DEVICE) + y = model(x) + self.assertSequenceEqual(y.shape, [2, 1000]) def _test_a_model_using_backward_tracer(self, model): model.eval() @@ -101,3 +153,12 @@ def test_replace_with_dynamic_ops(self): if isinstance(module, nn.BatchNorm2d): self.assertTrue( isinstance(module, DynamicChannelMixin)) + + def _test_a_graph(self, model, graph): + try: + groups = DefaultChannelGroup.init_from_graph(graph) + for group in groups: + group.prepare_for_pruning(model) + self._test_groups(groups, model) + except Exception as e: + self.fail(f'{e}') From fa41dcd0bc05cd0b6fc9b7be5e2970f467f3c780 Mon Sep 17 00:00:00 2001 From: liukai Date: Tue, 13 Sep 2022 11:43:07 +0800 Subject: [PATCH 24/25] update according to reviews --- mmrazor/models/mutables/__init__.py | 4 ++-- .../mutables/mutable_channel/__init__.py | 4 ++-- .../mutable_channel/base_mutable_channel.py | 4 +++- .../mutable_channel/groups/__init__.py | 6 +++-- .../groups/sequential_channel_group.py | 10 ++++----- .../mutable_channel_container.py | 22 ++++++++----------- .../sequential_mutable_channel.py | 2 ++ .../mutable_channel/simple_mutable_channel.py | 11 ---------- .../group/test_mutable_channel_groups.py | 8 +++---- .../group/test_mutable_channels.py | 5 +++-- 10 files changed, 34 insertions(+), 42 deletions(-) diff --git a/mmrazor/models/mutables/__init__.py b/mmrazor/models/mutables/__init__.py index 14a60d066..2dc96ee77 100644 --- a/mmrazor/models/mutables/__init__.py +++ b/mmrazor/models/mutables/__init__.py @@ -6,7 +6,7 @@ SimpleMutableChannel, SlimmableMutableChannel, SquentialMutableChannel) from .mutable_channel.groups import (ChannelGroupType, MutableChannelGroup, - SequentialChannelGroup) + SequentialMutableChannelGroup) from .mutable_module import (DiffChoiceRoute, DiffMutableModule, DiffMutableOP, OneShotMutableModule, OneShotMutableOP) from .mutable_value import MutableValue, OneShotMutableValue @@ -17,6 +17,6 @@ 'OneShotMutableValue', 'SimpleMutableChannel', 'MutableChannelGroup', 'BaseMutableChannel', 'MutableChannelContainer', 'ChannelGroupType', 'BaseMutable', 'MutableChannel', 'SlimmableMutableChannel', - 'OneShotMutableChannel', 'SequentialChannelGroup', + 'OneShotMutableChannel', 'SequentialMutableChannelGroup', 'SquentialMutableChannel' ] diff --git a/mmrazor/models/mutables/mutable_channel/__init__.py b/mmrazor/models/mutables/mutable_channel/__init__.py index 0b11531f4..c5c911ea2 100644 --- a/mmrazor/models/mutables/mutable_channel/__init__.py +++ b/mmrazor/models/mutables/mutable_channel/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .base_mutable_channel import BaseMutableChannel from .groups import (ChannelGroupType, MutableChannelGroup, - SequentialChannelGroup) + SequentialMutableChannelGroup) from .mutable_channel import MutableChannel from .mutable_channel_container import MutableChannelContainer from .one_shot_mutable_channel import OneShotMutableChannel @@ -14,5 +14,5 @@ 'BaseMutableChannel', 'MutableChannelContainer', 'StackMutableChannel', 'ChannelGroupType', 'MutableChannel', 'OneShotMutableChannel', 'SlimmableMutableChannel', 'SquentialMutableChannel', - 'SequentialChannelGroup' + 'SequentialMutableChannelGroup' ] diff --git a/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py index b60b98184..28f1e4854 100644 --- a/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py @@ -85,6 +85,8 @@ def num_choices(self) -> int: def __repr__(self): repr_str = self.__class__.__name__ - repr_str += f'(name={self.name}, ' + repr_str += '(' repr_str += f'num_channels={self.num_channels}, ' + repr_str += f'activated_channels={self.activated_channels}' + repr_str += ')' return repr_str diff --git a/mmrazor/models/mutables/mutable_channel/groups/__init__.py b/mmrazor/models/mutables/mutable_channel/groups/__init__.py index fbfc37d5c..c8ba8d4d4 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/__init__.py +++ b/mmrazor/models/mutables/mutable_channel/groups/__init__.py @@ -14,6 +14,8 @@ Subclasses of MutableChannelGroup """ from .mutable_channel_group import ChannelGroupType, MutableChannelGroup -from .sequential_channel_group import SequentialChannelGroup +from .sequential_channel_group import SequentialMutableChannelGroup -__all__ = ['MutableChannelGroup', 'SequentialChannelGroup', 'ChannelGroupType'] +__all__ = [ + 'MutableChannelGroup', 'SequentialMutableChannelGroup', 'ChannelGroupType' +] diff --git a/mmrazor/models/mutables/mutable_channel/groups/sequential_channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/sequential_channel_group.py index 093b3f91e..e6bd4cb2d 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/sequential_channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/sequential_channel_group.py @@ -12,12 +12,12 @@ from .mutable_channel_group import MutableChannelGroup +# TODO change the name of SequentialMutableChannelGroup @MODELS.register_module() -class SequentialChannelGroup(MutableChannelGroup): - """SimpleChannelGroup defines a simple pruning algorithhm. - - The type of choice of SimpleChannelGroup is int. It indicates what ratio of - channels are remained from left to right. +class SequentialMutableChannelGroup(MutableChannelGroup): + """SequentialMutableChannelGroup accepts a intger as the choice, which + indicates the number of the channels are remained from left to right, like + 11110000. Args: num_channels (int): number of channels. diff --git a/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py b/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py index 107ef35f4..5b6070821 100644 --- a/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py +++ b/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py @@ -39,7 +39,7 @@ def current_choice(self) -> torch.Tensor: if len(self.mutable_channels) == 0: return torch.ones([self.num_channels]).bool() else: - self._full_empty_range() + self._fill_unregistered_range() self._assert_mutables_valid() mutable_channels = list(self.mutable_channels.values()) masks = [mutable.current_mask for mutable in mutable_channels] @@ -82,9 +82,14 @@ def _assert_mutables_valid(self): last_end = end assert last_end == self.num_channels - def _full_empty_range(self): - """Add SimpleMutableChannels in the range without any stored - BaseMutableChannel.""" + def _fill_unregistered_range(self): + """Fill with SimpleMutableChannels in the range without any stored + BaseMutableChannel. + + For example, if a MutableChannelContainer has 10 channels, and only the + [0,5) is registered with BaseMutableChannels, this method will + automatically register BaseMutableChannels in the range [5,10). + """ last_end = 0 for start, end in copy.copy(self.mutable_channels): if last_end < start: @@ -95,12 +100,3 @@ def _full_empty_range(self): self.register_mutable( SimpleMutableChannel(self.num_channels - last_end), last_end, self.num_channels) - - # others - - def __repr__(self): - repr_str = self.__class__.__name__ - repr_str += f'(name={self.name}, ' - repr_str += f'num_channels={self.num_channels}, ' - repr_str += f'activated_channels: {self.activated_channels}' - return repr_str diff --git a/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py index f3436f60e..3f9ea8cb6 100644 --- a/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py @@ -4,6 +4,8 @@ from mmrazor.registry import MODELS from .base_mutable_channel import BaseMutableChannel +# TODO discuss later + @MODELS.register_module() class SquentialMutableChannel(BaseMutableChannel): diff --git a/mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py index 13e3ceb7b..7ece698d6 100644 --- a/mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py @@ -19,7 +19,6 @@ class SimpleMutableChannel(BaseMutableChannel): def __init__(self, num_channels: int, **kwargs) -> None: super().__init__(num_channels, **kwargs) - self.num_channels = num_channels self.mask = torch.ones(num_channels).bool() # choice @@ -53,13 +52,3 @@ def _expand_mask(mutable_channel, expand_ratio): derive_fun = partial( _expand_mask, mutable_channel=self, expand_ratio=expand_ratio) return DerivedMutable(derive_fun, derive_fun, [self]) - - # others - - def __repr__(self): - repr_str = self.__class__.__name__ - repr_str += '(' - repr_str += f'num_channels={self.num_channels}, ' - repr_str += f'activated_channels: {self.activated_channels}' - repr_str += ')' - return repr_str diff --git a/tests/test_models/test_mutables/group/test_mutable_channel_groups.py b/tests/test_models/test_mutables/group/test_mutable_channel_groups.py index 255d898e0..882fdcf8f 100644 --- a/tests/test_models/test_mutables/group/test_mutable_channel_groups.py +++ b/tests/test_models/test_mutables/group/test_mutable_channel_groups.py @@ -7,8 +7,8 @@ from mmrazor.models.architectures.dynamic_ops.bricks.dynamic_mixins import \ DynamicChannelMixin -from mmrazor.models.mutables.mutable_channel import (MutableChannelGroup, - SequentialChannelGroup) +from mmrazor.models.mutables.mutable_channel import ( + MutableChannelGroup, SequentialMutableChannelGroup) from mmrazor.models.mutables.mutable_channel.groups.channel_group import ( # noqa Channel, ChannelGroup, PruneNode) from mmrazor.structures.graph import ModuleGraph as ModuleGraph @@ -23,9 +23,9 @@ # DEVICE = torch.device('cuda:0') if torch.cuda.is_available() \ # else torch.device('cpu') DEVICE = torch.device('cpu') -GROUPS: List[MutableChannelGroup] = [SequentialChannelGroup] +GROUPS: List[MutableChannelGroup] = [SequentialMutableChannelGroup] -DefaultChannelGroup = SequentialChannelGroup +DefaultChannelGroup = SequentialMutableChannelGroup class TestMutableChannelGroup(TestCase): diff --git a/tests/test_models/test_mutables/group/test_mutable_channels.py b/tests/test_models/test_mutables/group/test_mutable_channels.py index f8a19c923..c93a43842 100644 --- a/tests/test_models/test_mutables/group/test_mutable_channels.py +++ b/tests/test_models/test_mutables/group/test_mutable_channels.py @@ -19,8 +19,9 @@ def test_SquentialMutableChannel(self): (mutable_channel.current_mask == torch.tensor([1, 1, 1, 0]).bool()).all()) channel_str = mutable_channel.__repr__() - self.assertEqual(channel_str, - 'SquentialMutableChannel(name=, num_channels=4, ') + self.assertEqual( + channel_str, + 'SquentialMutableChannel(num_channels=4, activated_channels=3)') mutable_channel.fix_chosen() mutable_channel.dump_chosen() From 07caed3c2a97d73ef92ac5129cfd1e1832610173 Mon Sep 17 00:00:00 2001 From: liukai Date: Wed, 14 Sep 2022 14:07:45 +0800 Subject: [PATCH 25/25] sequential_channel_group -> sequential_mutable_channel_group --- mmrazor/models/mutables/mutable_channel/groups/__init__.py | 2 +- ...ial_channel_group.py => sequential_mutable_channel_group.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename mmrazor/models/mutables/mutable_channel/groups/{sequential_channel_group.py => sequential_mutable_channel_group.py} (100%) diff --git a/mmrazor/models/mutables/mutable_channel/groups/__init__.py b/mmrazor/models/mutables/mutable_channel/groups/__init__.py index c8ba8d4d4..c3c25edd0 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/__init__.py +++ b/mmrazor/models/mutables/mutable_channel/groups/__init__.py @@ -14,7 +14,7 @@ Subclasses of MutableChannelGroup """ from .mutable_channel_group import ChannelGroupType, MutableChannelGroup -from .sequential_channel_group import SequentialMutableChannelGroup +from .sequential_mutable_channel_group import SequentialMutableChannelGroup __all__ = [ 'MutableChannelGroup', 'SequentialMutableChannelGroup', 'ChannelGroupType' diff --git a/mmrazor/models/mutables/mutable_channel/groups/sequential_channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/sequential_mutable_channel_group.py similarity index 100% rename from mmrazor/models/mutables/mutable_channel/groups/sequential_channel_group.py rename to mmrazor/models/mutables/mutable_channel/groups/sequential_mutable_channel_group.py