Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ChannelGroup #250

Merged
merged 25 commits into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mmrazor/models/architectures/dynamic_ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import DynamicOP
from .default_dynamic_ops import (DynamicBatchNorm, DynamicConv2d,
DynamicGroupNorm, DynamicInstanceNorm,
DynamicLinear)
from .default_dynamic_ops import (ChannelDynamicOP, DynamicBatchNorm,
LKJacky marked this conversation as resolved.
Show resolved Hide resolved
DynamicConv2d, DynamicGroupNorm,
DynamicInstanceNorm, DynamicLinear)
from .slimmable_dynamic_ops import SwitchableBatchNorm2d

__all__ = [
'DynamicConv2d', 'DynamicLinear', 'DynamicBatchNorm',
'DynamicInstanceNorm', 'DynamicGroupNorm', 'SwitchableBatchNorm2d',
'DynamicOP'
'DynamicOP', 'ChannelDynamicOP'
]
26 changes: 16 additions & 10 deletions mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,21 @@ def __init__(self, *args, **kwargs) -> None:
def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d':
"""Convert an instance of nn.Conv2d to a new instance of
DynamicConv2d."""
return cls(
in_channels=module.in_channels,
out_channels=module.out_channels,
kernel_size=module.kernel_size,
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
bias=True if module.bias is not None else False,
padding_mode=module.padding_mode)
# a group-wise conv will not be converted to dynamic conv
LKJacky marked this conversation as resolved.
Show resolved Hide resolved
if module.groups > 1 and not (module.groups == module.out_channels ==
LKJacky marked this conversation as resolved.
Show resolved Hide resolved
module.in_channels):
return module
else:
return cls(
in_channels=module.in_channels,
out_channels=module.out_channels,
kernel_size=module.kernel_size,
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
bias=True if module.bias is not None else False,
padding_mode=module.padding_mode)

@property
def conv_func(self) -> Callable:
Expand Down Expand Up @@ -146,6 +151,7 @@ def __init__(self, *args, **kwargs) -> None:
def convert_from(cls, module: nn.Conv2d) -> 'OFAConv2d':
"""Convert an instance of `nn.Conv2d` to a new instance of
`OFAConv2d`."""

return cls(
in_channels=module.in_channels,
out_channels=module.out_channels,
Expand Down
16 changes: 11 additions & 5 deletions mmrazor/models/mutables/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_mutable import BaseMutable
from .derived_mutable import DerivedMutable
from .mutable_channel import (MutableChannel, OneShotMutableChannel,
SlimmableMutableChannel)
from .mutable_channel import (BaseMutableChannel, MutableChannel,
MutableChannelContainer, OneShotMutableChannel,
SimpleMutableChannel, SlimmableMutableChannel)
from .mutable_channel.groups import (MUTABLECHANNELGROUP, MutableChannelGroup,
SequentialChannelGroup)
from .mutable_module import (DiffChoiceRoute, DiffMutableModule, DiffMutableOP,
OneShotMutableModule, OneShotMutableOP)
from .mutable_value import MutableValue, OneShotMutableValue

__all__ = [
'OneShotMutableOP', 'OneShotMutableModule', 'DiffMutableOP',
'DiffChoiceRoute', 'DiffMutableModule', 'OneShotMutableChannel',
'SlimmableMutableChannel', 'MutableChannel', 'DerivedMutable',
'MutableValue', 'OneShotMutableValue'
'DiffChoiceRoute', 'DiffMutableModule', 'DerivedMutable', 'MutableValue',
'OneShotMutableValue', 'SimpleMutableChannel', 'MutableChannelGroup',
'BaseMutableChannel', 'MutableChannelContainer', 'MUTABLECHANNELGROUP',
'BaseMutable', 'MutableChannel', 'SlimmableMutableChannel',
'OneShotMutableChannel', 'SequentialChannelGroup'
]
25 changes: 24 additions & 1 deletion mmrazor/models/mutables/mutable_channel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
r"""This module defines MutableChannels.
LKJacky marked this conversation as resolved.
Show resolved Hide resolved

----------------------------------------------------------base_mutable_channel.py
BaseMutableChannel
| \
----------------------------------------------------------mutable_channel_container.py
MutableChannelContainer \
----------------------------------------------------------other files
\ other MutableChannels

MutableChannel are mainly used in DynamicOps. It helps DynamicOps to deal
with mutable number of channels.
"""
from .base_mutable_channel import BaseMutableChannel
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clean old mutable channels

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Old MutableChannels are also used by autoslim and slimmable. Remove after refactor?

from .groups import (MUTABLECHANNELGROUP, MutableChannelGroup,
SequentialChannelGroup)
from .mutable_channel import MutableChannel
from .mutable_channel_container import MutableChannelContainer
from .one_shot_mutable_channel import OneShotMutableChannel
from .sequential_mutable_channel import SquentialMutableChannel
from .simple_mutable_channel import SimpleMutableChannel
from .slimmable_mutable_channel import SlimmableMutableChannel

__all__ = [
'OneShotMutableChannel', 'SlimmableMutableChannel', 'MutableChannel'
'SimpleMutableChannel', 'MutableChannelGroup', 'OneShotChannelGroup',
'BaseMutableChannel', 'MutableChannelContainer', 'StackMutableChannel',
'MUTABLECHANNELGROUP', 'MutableChannel', 'OneShotMutableChannel',
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MUTABLECHANNELGROUP just is a typa var, why add it here? Will it be used in other modules?

Copy link
Collaborator Author

@LKJacky LKJacky Sep 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's used in Mutator, or is it declared in Mutator?

'SlimmableMutableChannel', 'SquentialMutableChannel',
'SequentialChannelGroup'
]
83 changes: 83 additions & 0 deletions mmrazor/models/mutables/mutable_channel/base_mutable_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright (c) OpenMMLab. All rights reserved.
""""""
from abc import abstractproperty

import torch

from ..base_mutable import BaseMutable
from ..derived_mutable import DerivedMethodMixin


class BaseMutableChannel(BaseMutable, DerivedMethodMixin):
"""BaseMutableChannel works as a channel mask for DynamicOps to select
channels.

|---------------------------------------|
|in_channel_mutable(BaseMutableChannel) |
LKJacky marked this conversation as resolved.
Show resolved Hide resolved
|---------------------------------------|
| DynamicOp |
|---------------------------------------|
|out_channel_mutable(BaseMutableChannel)|
|---------------------------------------|

Important interfaces:
current_choice: used to get/set mask.
current_mask: get mask(used in DynamicOps to get mask).
"""
LKJacky marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, num_channels: int, **kwargs):
super().__init__(**kwargs)
self.name = ''
self.num_channels = num_channels

# choice
LKJacky marked this conversation as resolved.
Show resolved Hide resolved

@abstractproperty
LKJacky marked this conversation as resolved.
Show resolved Hide resolved
def current_choice(self):
"""get current choice."""
raise NotImplementedError()

@current_choice.setter
def current_choice(self):
"""set current choice."""
raise NotImplementedError()

@abstractproperty
def current_mask(self) -> torch.Tensor:
"""Return a mask indicating the channel selection."""
raise NotImplementedError()
LKJacky marked this conversation as resolved.
Show resolved Hide resolved

@property
def activated_channels(self) -> int:
"""Number of activated channels."""
return (self.current_mask == 1).sum().item()

# implementation of abstract methods
LKJacky marked this conversation as resolved.
Show resolved Hide resolved

def fix_chosen(self, chosen=None):
"""Fix the mutable with chosen."""
if chosen is not None:
self.current_choice = chosen

if self.is_fixed:
raise AttributeError(
'The mode of current MUTABLE is `fixed`. '
'Please do not call `fix_chosen` function again.')

self.is_fixed = True

def dump_chosen(self):
"""dump current choice to a dict."""
raise NotImplementedError()

def num_choices(self) -> int:
"""Number of available choices."""
raise NotImplementedError()

# others
LKJacky marked this conversation as resolved.
Show resolved Hide resolved

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(name={self.name}, '
repr_str += f'num_channels={self.num_channels}, '
return repr_str
21 changes: 21 additions & 0 deletions mmrazor/models/mutables/mutable_channel/groups/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""
----------------------------------------------------------channel_group.py
PruneNode && PruneGraph
|
| Graph2ChannelGroups
|
Channel && ChannelGroup
|
----------------------------------------------------------mutable_channel_group.py
MutableChannelGroup
|
----------------------------------------------------------other files
Subclasses of MutableChannelGroup
"""
from .mutable_channel_group import MUTABLECHANNELGROUP, MutableChannelGroup
from .sequential_channel_group import SequentialChannelGroup

__all__ = [
'MutableChannelGroup', 'SequentialChannelGroup', 'MUTABLECHANNELGROUP'
]
Loading