Skip to content

Commit

Permalink
add docstring
Browse files Browse the repository at this point in the history
base: lk:add docstring for IndexDict
  • Loading branch information
liukai committed Aug 29, 2022
1 parent 0eefd45 commit 37d9fb4
Show file tree
Hide file tree
Showing 12 changed files with 123 additions and 68 deletions.
13 changes: 6 additions & 7 deletions mmrazor/models/mutables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,18 @@
from .derived_mutable import DerivedMutable
from .mutable_channel import (BaseMutableChannel, MutableChannel,
MutableChannelContainer, OneShotMutableChannel,
SimpleMutableChannel, SlimmableMutableChannel,
StackMutableChannel)
SimpleMutableChannel, SlimmableMutableChannel)
from .mutable_channel.groups import (MUTABLECHANNELGROUP, MutableChannelGroup,
SimpleChannelGroup)
SequentialChannelGroup)
from .mutable_module import (DiffChoiceRoute, DiffMutableModule, DiffMutableOP,
OneShotMutableModule, OneShotMutableOP)
from .mutable_value import MutableValue, OneShotMutableValue

__all__ = [
'OneShotMutableOP', 'OneShotMutableModule', 'DiffMutableOP',
'DiffChoiceRoute', 'DiffMutableModule', 'DerivedMutable', 'MutableValue',
'OneShotMutableValue', 'SimpleChannelGroup', 'SimpleMutableChannel',
'MutableChannelGroup', 'BaseMutableChannel', 'MutableChannelContainer',
'MUTABLECHANNELGROUP', 'StackMutableChannel', 'BaseMutable',
'MutableChannel', 'SlimmableMutableChannel', 'OneShotMutableChannel'
'OneShotMutableValue', 'SimpleMutableChannel', 'MutableChannelGroup',
'BaseMutableChannel', 'MutableChannelContainer', 'MUTABLECHANNELGROUP',
'BaseMutable', 'MutableChannel', 'SlimmableMutableChannel',
'OneShotMutableChannel', 'SequentialChannelGroup'
]
13 changes: 7 additions & 6 deletions mmrazor/models/mutables/mutable_channel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_mutable_channel import BaseMutableChannel
from .groups import (MUTABLECHANNELGROUP, MutableChannelGroup,
SimpleChannelGroup)
SequentialChannelGroup)
from .mutable_channel import MutableChannel
from .mutable_channel_container import MutableChannelContainer
from .one_shot_mutable_channel import OneShotMutableChannel
from .sequential_mutable_channel import SquentialMutableChannel
from .simple_mutable_channel import SimpleMutableChannel
from .slimmable_mutable_channel import SlimmableMutableChannel
from .stack_mutable_channel import StackMutableChannel

__all__ = [
'SimpleMutableChannel', 'SimpleChannelGroup', 'MutableChannelGroup',
'OneShotChannelGroup', 'BaseMutableChannel', 'MutableChannelContainer',
'StackMutableChannel', 'MUTABLECHANNELGROUP', 'MutableChannel',
'OneShotMutableChannel', 'SlimmableMutableChannel'
'SimpleMutableChannel', 'MutableChannelGroup', 'OneShotChannelGroup',
'BaseMutableChannel', 'MutableChannelContainer', 'StackMutableChannel',
'MUTABLECHANNELGROUP', 'MutableChannel', 'OneShotMutableChannel',
'SlimmableMutableChannel', 'SquentialMutableChannel',
'SequentialChannelGroup'
]
28 changes: 9 additions & 19 deletions mmrazor/models/mutables/mutable_channel/base_mutable_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,8 @@


class BaseMutableChannel(BaseMutable, DerivedMethodMixin):
"""A type of ``MUTABLES`` for single path supernet such as AutoSlim. In
single path supernet, each module only has one choice invoked at the same
time. A path is obtained by sampling all the available choices. It is the
base class for one shot channel mutables.
Args:
num_channels (int): The raw number of channels.
init_cfg (dict, optional): initialization configuration dict for
``BaseModule``. OpenMMLab has implement 5 initializer including
`Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
and `Pretrained`.
"""
"""BaseMutableChannel works as a channel mask for dynamic ops to select
channels."""

def __init__(self, num_channels: int, **kwargs):
super().__init__(**kwargs)
Expand All @@ -28,30 +18,28 @@ def __init__(self, num_channels: int, **kwargs):

@abstractproperty
def current_choice(self):
"""get current choice."""
raise NotImplementedError()

@current_choice.setter
def current_choice(self):
"""set current choice."""
raise NotImplementedError()

@abstractproperty
def current_mask(self):
"""Return a mask indicating the channel selection."""
raise NotImplementedError()

@property
def activated_channels(self):
"""Number of activated channels."""
return (self.current_mask == 1).sum().item()

# implementation of abstract methods

def fix_chosen(self, chosen=None):
"""Fix mutable with subnet config. This operation would convert
`unfixed` mode to `fixed` mode. The :attr:`is_fixed` will be set to
True and only the selected operations can be retained.
Args:
chosen (str): The chosen key in ``MUTABLE``. Defaults to None.
"""
"""Fix the mutable with chosen."""
if chosen is not None:
self.current_choice = chosen

Expand All @@ -63,9 +51,11 @@ def fix_chosen(self, chosen=None):
self.is_fixed = True

def dump_chosen(self):
"""dump current choice to a dict."""
raise NotImplementedError()

def num_choices(self) -> int:
"""Number of available choices."""
raise NotImplementedError()

# others
Expand Down
6 changes: 4 additions & 2 deletions mmrazor/models/mutables/mutable_channel/groups/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .mutable_channel_group import MUTABLECHANNELGROUP, MutableChannelGroup
from .simple_channel_group import SimpleChannelGroup
from .sequential_channel_group import SequentialChannelGroup

__all__ = ['MutableChannelGroup', 'SimpleChannelGroup', 'MUTABLECHANNELGROUP']
__all__ = [
'MutableChannelGroup', 'SequentialChannelGroup', 'MUTABLECHANNELGROUP'
]
15 changes: 12 additions & 3 deletions mmrazor/models/mutables/mutable_channel/groups/channel_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def colloct_groups(self) -> List['ChannelGroup']:
return groups


# ChannelGroup
# Channel && ChannelGroup


class Channel:
Expand Down Expand Up @@ -325,7 +325,7 @@ def list_repr(lit: List):
return s


# Converter
# Group to ChannelGroup Converter


class Graph2ChannelGroups:
Expand All @@ -336,6 +336,11 @@ def __init__(
graph: PruneGraph,
channel_group_cfg: Union[Dict,
Type[ChannelGroup]] = ChannelGroup) -> None:
"""
Args:
graph (PruneGraph): input prune-graph
channel_group_cfg: the config for generating groups
"""
self.graph = graph
if isinstance(channel_group_cfg, dict):
self.channel_group_class = MODELS.module_dict[
Expand All @@ -350,6 +355,7 @@ def __init__(
# group operations

def new_channel_group(self, num_channels):
"""Initialize a ChannelGroup."""
return self.channel_group_class(num_channels,
**self.channel_group_args)

Expand Down Expand Up @@ -425,7 +431,8 @@ def add_input_related(self,
node: PruneNode,
index=None,
expand_ratio=1):
"""add some channels of a prune-node to a channel-group."""
"""Add channels of a prune-node to a the input-related channels of a
channel-group."""
if index is None:
index = (0, node.in_channels)
group.add_input_related(
Expand All @@ -438,6 +445,8 @@ def add_output_related(self,
node: PruneNode,
index=None,
expand_ratio=1):
"""Add channels of a prune-node to a the output-related channels of a
channel-group."""
if index is None:
index = (0, node.out_channels)
group.add_ouptut_related(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@


class MutableChannelGroup(ChannelGroup, BaseModule):
"""MutableChannelGroup manages channels with dependency."""
"""MutableChannelGroup inherents from ChannelGroup, which manages channels
with dependency.
MutableChannelGroup defines the core interfaces for pruning.
"""

def __init__(self, num_channels: int) -> None:
"""
Expand Down Expand Up @@ -82,7 +86,7 @@ def prepare_model(cls, model: nn.Module):
def prepare_for_pruning(self):
"""Post process after parse groups.
For example, we need to register mutable to dynamic-ops
For example, we need to register mutables to dynamic-ops.
"""
raise NotImplementedError()

Expand Down Expand Up @@ -121,7 +125,7 @@ def traverse(module):

@staticmethod
def _register_mask_container(model: nn.Module, container_class):
"""register mask container for dynamic ops."""
"""register channel container for dynamic ops."""
for module in model.modules():
if isinstance(module, DynamicChannelMixin):
if module.get_mutable_attr('in_channels') is None:
Expand Down Expand Up @@ -151,7 +155,7 @@ def _register_mask_container(model: nn.Module, container_class):

def _register_mask(self, mutable_channel: SimpleMutableChannel):

# register MutableMask
# register mutable_channel
for channel in self.input_related + self.output_related:
module = channel.module
if isinstance(module, DynamicChannelMixin):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@


@MODELS.register_module()
class SimpleChannelGroup(MutableChannelGroup):
class SequentialChannelGroup(MutableChannelGroup):
"""SimpleChannelGroup defines a simple pruning algorithhm.
The type of choice of SimpleChannelGroup is float. It indicates what ratio
of channels are remained from left to right.
"""

def __init__(self, num_channels) -> None:
super().__init__(num_channels)
Expand All @@ -25,16 +30,18 @@ def __init__(self, num_channels) -> None:

@property
def current_choice(self) -> float:
"""return current choice."""
return self.mutable_channel.activated_channels / self.num_channels

@current_choice.setter
def current_choice(self, choice: float):
"""Current choice setter will be executed in mutator."""
"""set choice."""
int_choice = self._get_int_choice(choice)
mask = self._generate_mask(int_choice)
self.mutable_channel.current_choice = mask

def sample_choice(self):
"""Sample a choice in (0,1]"""
return max(1, int(
random.random() * self.num_channels)) / self.num_channels

Expand All @@ -45,6 +52,11 @@ def prepare_model(
cls,
model: nn.Module,
):
"""Prepare a model, including two steps:
1. replace torch modules with dynamic ops.
2. register channel containers
"""
cls._replace_with_dynamic_ops(
model, {
nn.Conv2d: DynamicConv2d,
Expand All @@ -56,7 +68,7 @@ def prepare_model(
cls._register_mask_container(model, MutableChannelContainer)

def prepare_for_pruning(self):

"""Prepare for pruning, including register mutable channels."""
# register MutableMask
self._register_mask(self.mutable_channel)

Expand All @@ -70,5 +82,6 @@ def _generate_mask(self, choice: int) -> torch.Tensor:

# interface
def fix_chosen(self, choice=None):
"""fix chosen."""
super().fix_chosen(choice)
self.mutable_channel.fix_chosen()
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@

@MODELS.register_module()
class MutableChannelContainer(BaseMutableChannel):
"""MutableChannelContainer is a container for BaseMutableChannel.
The mask of MutableChannelContainer consists of all masks of stored
BaseMutableChannel.
"""

def __init__(self, num_channels: int, **kwargs):
super().__init__(num_channels, **kwargs)
Expand All @@ -20,41 +25,56 @@ def __init__(self, num_channels: int, **kwargs):

@property
def current_choice(self):
"""Get current choices."""
if len(self.mutable_channels) == 0:
return torch.ones([self.num_channels]).bool()
else:
self._full_with_empty_mask()
self._assert_mask_valid()
self._full_empty_range()
self._assert_mutables_valid()
mutable_channels = list(self.mutable_channels.values())
masks = [mutable.current_mask for mutable in mutable_channels]
mask = torch.cat(masks)
return mask.bool()

@current_choice.setter
def current_choice(self, choice):
"""Set current choices.
However, MutableChannelContainer doesn't support directly set mask. You
can change the mask of MutableChannelContainer by changing its stored
BaseMutableChannel.
"""
raise NotImplementedError()

@property
def current_mask(self):
"""Return current mask."""
return self.current_choice.bool()

# basic extension

def register_mutable(self, mutable_channel: SimpleMutableChannel,
start: int, end: int):
"""Register/Store BaseMutableChannel in the MutableChannelContainer in
the range [start,end)"""

self.mutable_channels[(start, end)] = mutable_channel

# private methods

def _assert_mask_valid(self):
def _assert_mutables_valid(self):
"""Assert the current stored BaseMutableChannels are valid to generate
mask."""
assert len(self.mutable_channels) > 0
last_end = 0
for start, end in self.mutable_channels:
assert start == last_end
last_end = end
assert last_end == self.num_channels

def _full_with_empty_mask(self):
def _full_empty_range(self):
"""Add SimpleMutableChannels in the range without any stored
BaseMutableChannel."""
last_end = 0
for start, end in copy.copy(self.mutable_channels):
if last_end < start:
Expand Down
Loading

0 comments on commit 37d9fb4

Please sign in to comment.