diff --git a/mmrazor/models/architectures/dynamic_ops/__init__.py b/mmrazor/models/architectures/dynamic_ops/__init__.py index 6b5796688..620c9e4c8 100644 --- a/mmrazor/models/architectures/dynamic_ops/__init__.py +++ b/mmrazor/models/architectures/dynamic_ops/__init__.py @@ -1,12 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .base import DynamicOP -from .default_dynamic_ops import (DynamicBatchNorm, DynamicConv2d, - DynamicGroupNorm, DynamicInstanceNorm, - DynamicLinear) -from .slimmable_dynamic_ops import SwitchableBatchNorm2d +from .bricks.dynamic_conv import BigNasConv2d, DynamicConv2d, OFAConv2d +from .bricks.dynamic_linear import DynamicLinear +from .bricks.dynamic_norm import (DynamicBatchNorm1d, DynamicBatchNorm2d, + DynamicBatchNorm3d, SwitchableBatchNorm2d) +from .mixins.dynamic_conv_mixins import DynamicConvMixin +from .mixins.dynamic_mixins import (DynamicBatchNormMixin, DynamicChannelMixin, + DynamicLinearMixin, DynamicMixin) __all__ = [ - 'DynamicConv2d', 'DynamicLinear', 'DynamicBatchNorm', - 'DynamicInstanceNorm', 'DynamicGroupNorm', 'SwitchableBatchNorm2d', - 'DynamicOP' + 'BigNasConv2d', 'DynamicConv2d', 'OFAConv2d', 'DynamicLinear', + 'DynamicBatchNorm1d', 'DynamicBatchNorm2d', 'DynamicBatchNorm3d', + 'DynamicMixin', 'DynamicChannelMixin', 'DynamicBatchNormMixin', + 'DynamicLinearMixin', 'SwitchableBatchNorm2d', 'DynamicConvMixin' ] diff --git a/mmrazor/models/architectures/dynamic_ops/base.py b/mmrazor/models/architectures/dynamic_ops/base.py deleted file mode 100644 index 2a1720ea2..000000000 --- a/mmrazor/models/architectures/dynamic_ops/base.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from abc import ABC, abstractmethod -from typing import Any, Optional, Set - -from torch import nn - -from mmrazor.models.mutables.base_mutable import BaseMutable - - -class DynamicOP(ABC): - """Base class for dynamic OP. A dynamic OP usually consists of a normal - static OP and mutables, where mutables are used to control the searchable - (mutable) part of the dynamic OP. - - Note: - When the dynamic OP has just been initialized, its forward propagation - logic should be the same as the corresponding static OP. Only after - the searchable part accepts the specific mutable through the - corresponding interface does the part really become dynamic. - - Note: - All subclass should implement ``to_static_op`` API. - - Args: - accepted_mutables (set): The string set of all accepted mutables. - """ - accepted_mutables: Set[str] = set() - - @abstractmethod - def to_static_op(self) -> nn.Module: - """Convert dynamic OP to static OP. - - Note: - The forward result for the same input between dynamic OP and its - corresponding static OP must be same. - - Returns: - nn.Module: Corresponding static OP. - """ - - def check_if_mutables_fixed(self) -> None: - """Check if all mutables are fixed. - - Raises: - RuntimeError: Error if a existing mutable is not fixed. - """ - - def check_fixed(mutable: Optional[BaseMutable]) -> None: - if mutable is not None and not mutable.is_fixed: - raise RuntimeError(f'Mutable {type(mutable)} is not fixed.') - - for mutable in self.accepted_mutables: - check_fixed(getattr(self, f'{mutable}')) - - @staticmethod - def get_current_choice(mutable: BaseMutable) -> Any: - """Get current choice of given mutable. - - Args: - mutable (BaseMutable): Given mutable. - - Raises: - RuntimeError: Error if `current_choice` is None. - - Returns: - Any: Current choice of given mutable. 
- """ - current_choice = mutable.current_choice - if current_choice is None: - raise RuntimeError(f'current choice of mutable {type(mutable)} ' - 'can not be None at runtime') - - return current_choice - - -class ChannelDynamicOP(DynamicOP): - """Base class for dynamic OP with mutable channels. - - Note: - All subclass should implement ``mutable_in`` and ``mutable_out`` APIs. - """ - - @property - @abstractmethod - def mutable_in(self) -> Optional[BaseMutable]: - """Mutable related to input.""" - - @property - @abstractmethod - def mutable_out(self) -> Optional[BaseMutable]: - """Mutable related to output.""" - - @staticmethod - def check_mutable_channels(mutable_channels: BaseMutable) -> None: - """Check if mutable has `currnet_mask` attribute. - - Args: - mutable_channels (BaseMutable): Mutable to be checked. - - Raises: - ValueError: Error if mutable does not have `current_mask` - attribute. - """ - if not hasattr(mutable_channels, 'current_mask'): - raise ValueError( - 'channel mutable must have attribute `current_mask`') diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/__init__.py b/mmrazor/models/architectures/dynamic_ops/bricks/__init__.py index 1e528d3bf..ef101fec6 100644 --- a/mmrazor/models/architectures/dynamic_ops/bricks/__init__.py +++ b/mmrazor/models/architectures/dynamic_ops/bricks/__init__.py @@ -1,14 +1 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .dynamic_conv import BigNasConv2d, DynamicConv2d, OFAConv2d -from .dynamic_linear import DynamicLinear -from .dynamic_mixins import (DynamicBatchNormMixin, DynamicChannelMixin, - DynamicLinearMixin, DynamicMixin) -from .dynamic_norm import (DynamicBatchNorm1d, DynamicBatchNorm2d, - DynamicBatchNorm3d, SwitchableBatchNorm2d) - -__all__ = [ - 'BigNasConv2d', 'DynamicConv2d', 'OFAConv2d', 'DynamicLinear', - 'DynamicBatchNorm1d', 'DynamicBatchNorm2d', 'DynamicBatchNorm3d', - 'DynamicMixin', 'DynamicChannelMixin', 'DynamicBatchNormMixin', - 'DynamicLinearMixin', 'SwitchableBatchNorm2d' -] diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py index 41ec11704..71fc7ab98 100644 --- a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py +++ b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv.py @@ -7,8 +7,8 @@ from mmrazor.models.mutables.base_mutable import BaseMutable from mmrazor.registry import MODELS -from .dynamic_conv_mixins import (BigNasConvMixin, DynamicConvMixin, - OFAConvMixin) +from ..mixins.dynamic_conv_mixins import (BigNasConvMixin, DynamicConvMixin, + OFAConvMixin) GroupWiseConvWarned = False diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_linear.py b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_linear.py index aa7bcbccc..4faa0c8b7 100644 --- a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_linear.py +++ b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_linear.py @@ -6,7 +6,7 @@ from torch import Tensor from mmrazor.models.mutables.base_mutable import BaseMutable -from .dynamic_mixins import DynamicLinearMixin +from ..mixins import DynamicLinearMixin class DynamicLinear(nn.Linear, DynamicLinearMixin): diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_norm.py b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_norm.py index 4490b237e..3b88e6a82 100644 --- a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_norm.py +++ b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_norm.py @@ -8,7 +8,7 @@ from 
mmrazor.models.mutables.base_mutable import BaseMutable from mmrazor.registry import MODELS -from .dynamic_mixins import DynamicBatchNormMixin +from ..mixins.dynamic_mixins import DynamicBatchNormMixin class _DynamicBatchNorm(_BatchNorm, DynamicBatchNormMixin): diff --git a/mmrazor/models/architectures/dynamic_ops/default_dynamic_ops.py b/mmrazor/models/architectures/dynamic_ops/default_dynamic_ops.py deleted file mode 100644 index 2488a49eb..000000000 --- a/mmrazor/models/architectures/dynamic_ops/default_dynamic_ops.py +++ /dev/null @@ -1,333 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -from typing import Optional, Tuple - -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor -from torch.nn.modules import GroupNorm -from torch.nn.modules.batchnorm import _BatchNorm -from torch.nn.modules.instancenorm import _InstanceNorm - -from mmrazor.models.mutables.mutable_channel import MutableChannel -from mmrazor.registry import MODELS -from .base import ChannelDynamicOP - - -class DynamicConv2d(nn.Conv2d, ChannelDynamicOP): - """Applies a 2D convolution over an input signal composed of several input - planes according to the `mutable_in_channels` and `mutable_out_channels` - dynamically. - - Args: - in_channels_cfg (Dict): Config related to `in_channels`. - out_channels_cfg (Dict): Config related to `out_channels`. - """ - accepted_mutables = {'mutable_in_channels', 'mutable_out_channels'} - - def __init__(self, in_channels_cfg, out_channels_cfg, *args, **kwargs): - super(DynamicConv2d, self).__init__(*args, **kwargs) - - in_channels_cfg_ = copy.deepcopy(in_channels_cfg) - in_channels_cfg_.update(dict(num_channels=self.in_channels)) - self.mutable_in_channels = MODELS.build(in_channels_cfg_) - - out_channels_cfg_ = copy.deepcopy(out_channels_cfg) - out_channels_cfg_.update(dict(num_channels=self.out_channels)) - self.mutable_out_channels = MODELS.build(out_channels_cfg_) - - assert isinstance(self.mutable_in_channels, MutableChannel) - assert isinstance(self.mutable_out_channels, MutableChannel) - # TODO - # https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#Conv2d - assert self.padding_mode == 'zeros' - - @property - def mutable_in(self) -> MutableChannel: - """Mutable `in_channels`.""" - return self.mutable_in_channels - - @property - def mutable_out(self) -> MutableChannel: - """Mutable `out_channels`.""" - return self.mutable_out_channels - - def forward(self, input: Tensor) -> Tensor: - """Slice the parameters according to `mutable_in_channels` and - `mutable_out_channels`, and forward.""" - groups = self.groups - if self.groups == self.in_channels == self.out_channels: - groups = input.size(1) - weight, bias = self._get_dynamic_params() - - return F.conv2d(input, weight, bias, self.stride, self.padding, - self.dilation, groups) - - def _get_dynamic_params(self) -> Tuple[Tensor, Optional[Tensor]]: - in_mask = self.mutable_in_channels.current_mask.to(self.weight.device) - out_mask = self.mutable_out_channels.current_mask.to( - self.weight.device) - - if self.groups == 1: - weight = self.weight[out_mask][:, in_mask] - elif self.groups == self.in_channels == self.out_channels: - # depth-wise conv - weight = self.weight[out_mask] - else: - raise NotImplementedError( - 'Current `ChannelMutator` only support pruning the depth-wise ' - '`nn.Conv2d` or `nn.Conv2d` module whose group number equals ' - f'to one, but got {self.groups}.') - - bias = self.bias[out_mask] if self.bias is not None else None - - return weight, bias - - 
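The slicing in `_get_dynamic_params` above is the core of the removed dynamic op: boolean channel masks taken from the mutables select sub-tensors of the full weight and bias before the functional convolution call. A minimal, self-contained sketch of that idea in plain PyTorch (the tensor shapes and mask sizes here are illustrative assumptions, not values from this diff):

```python
import torch
import torch.nn.functional as F

# Full conv weight: (out_channels, in_channels, kH, kW)
conv_weight = torch.randn(10, 4, 3, 3)

# Boolean masks, as produced by a channel mutable's `current_mask`.
out_mask = torch.zeros(10, dtype=torch.bool)
out_mask[:8] = True   # keep the first 8 output channels
in_mask = torch.zeros(4, dtype=torch.bool)
in_mask[:3] = True    # keep the first 3 input channels

# Same slicing pattern as `_get_dynamic_params` uses for groups == 1.
sliced_weight = conv_weight[out_mask][:, in_mask]   # shape (8, 3, 3, 3)

x = torch.rand(1, 3, 32, 32)
out = F.conv2d(x, sliced_weight, bias=None, stride=1, padding=1)
print(out.shape)  # torch.Size([1, 8, 32, 32])
```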
def to_static_op(self) -> nn.Conv2d: - assert self.mutable_in.is_fixed and self.mutable_out.is_fixed - - weight, bias, = self._get_dynamic_params() - groups = self.groups - if groups == self.in_channels == self.out_channels: - groups = self.mutable_in.current_mask.sum().item() - out_channels = weight.size(0) - in_channels = weight.size(1) * groups - - static_conv2d = nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=self.kernel_size, - stride=self.stride, - padding=self.padding, - padding_mode=self.padding_mode, - dilation=self.dilation, - groups=groups, - bias=True if bias is not None else False) - - static_conv2d.weight = nn.Parameter(weight) - if bias is not None: - static_conv2d.bias = nn.Parameter(bias) - - return static_conv2d - - -class DynamicLinear(nn.Linear, ChannelDynamicOP): - """Applies a linear transformation to the incoming data according to the - `mutable_in_features` and `mutable_out_features` dynamically. - - Args: - in_features_cfg (Dict): Config related to `in_features`. - out_features_cfg (Dict): Config related to `out_features`. - """ - accepted_mutables = {'mutable_in_features', 'mutable_out_features'} - - def __init__(self, in_features_cfg, out_features_cfg, *args, **kwargs): - super(DynamicLinear, self).__init__(*args, **kwargs) - - in_features_cfg_ = copy.deepcopy(in_features_cfg) - in_features_cfg_.update(dict(num_channels=self.in_features)) - self.mutable_in_features = MODELS.build(in_features_cfg_) - - out_features_cfg_ = copy.deepcopy(out_features_cfg) - out_features_cfg_.update(dict(num_channels=self.out_features)) - self.mutable_out_features = MODELS.build(out_features_cfg_) - - @property - def mutable_in(self): - """Mutable `in_features`.""" - return self.mutable_in_features - - @property - def mutable_out(self): - """Mutable `out_features`.""" - return self.mutable_out_features - - def forward(self, input: Tensor) -> Tensor: - """Slice the parameters according to `mutable_in_features` and - `mutable_out_features`, and forward.""" - in_mask = self.mutable_in_features.current_mask.to(self.weight.device) - out_mask = self.mutable_out_features.current_mask.to( - self.weight.device) - - weight = self.weight[out_mask][:, in_mask] - bias = self.bias[out_mask] if self.bias is not None else None - - return F.linear(input, weight, bias) - - # TODO - def to_static_op(self) -> nn.Module: - return self - - -class DynamicBatchNorm(_BatchNorm, ChannelDynamicOP): - """Applies Batch Normalization over an input according to the - `mutable_num_features` dynamically. - - Args: - num_features_cfg (Dict): Config related to `num_features`. 
- """ - accepted_mutables = {'mutable_num_features'} - - def __init__(self, num_features_cfg, *args, **kwargs): - super(DynamicBatchNorm, self).__init__(*args, **kwargs) - - num_features_cfg_ = copy.deepcopy(num_features_cfg) - num_features_cfg_.update(dict(num_channels=self.num_features)) - self.mutable_num_features = MODELS.build(num_features_cfg_) - - @property - def mutable_in(self): - """Mutable `num_features`.""" - return self.mutable_num_features - - @property - def mutable_out(self): - """Mutable `num_features`.""" - return self.mutable_num_features - - def forward(self, input: Tensor) -> Tensor: - """Slice the parameters according to `mutable_num_features`, and - forward.""" - if self.momentum is None: - exponential_average_factor = 0.0 - else: - exponential_average_factor = self.momentum - - if self.training and self.track_running_stats: - if self.num_batches_tracked is not None: # type: ignore - self.num_batches_tracked = \ - self.num_batches_tracked + 1 # type: ignore - if self.momentum is None: # use cumulative moving average - exponential_average_factor = 1.0 / float( - self.num_batches_tracked) - else: # use exponential moving average - exponential_average_factor = self.momentum - - if self.training: - bn_training = True - else: - bn_training = (self.running_mean is None) and (self.running_var is - None) - - if self.affine: - out_mask = self.mutable_num_features.current_mask.to( - self.weight.device) - weight = self.weight[out_mask] - bias = self.bias[out_mask] - else: - weight, bias = self.weight, self.bias - - if self.track_running_stats: - out_mask = self.mutable_num_features.current_mask.to( - self.running_mean.device) - running_mean = self.running_mean[out_mask] \ - if not self.training or self.track_running_stats else None - running_var = self.running_var[out_mask] \ - if not self.training or self.track_running_stats else None - else: - running_mean, running_var = self.running_mean, self.running_var - - return F.batch_norm(input, running_mean, running_var, weight, bias, - bn_training, exponential_average_factor, self.eps) - - # TODO - def to_static_op(self) -> nn.Module: - return self - - -class DynamicInstanceNorm(_InstanceNorm, ChannelDynamicOP): - """Applies Instance Normalization over an input according to the - `mutable_num_features` dynamically. - - Args: - num_features_cfg (Dict): Config related to `num_features`. 
- """ - accepted_mutables = {'mutable_num_features'} - - def __init__(self, num_features_cfg, *args, **kwargs): - super(DynamicInstanceNorm, self).__init__(*args, **kwargs) - - num_features_cfg_ = copy.deepcopy(num_features_cfg) - num_features_cfg_.update(dict(num_channels=self.num_features)) - self.mutable_num_features = MODELS.build(num_features_cfg_) - - @property - def mutable_in(self): - """Mutable `num_features`.""" - return self.mutable_num_features - - @property - def mutable_out(self): - """Mutable `num_features`.""" - return self.mutable_num_features - - def forward(self, input: Tensor) -> Tensor: - """Slice the parameters according to `mutable_num_features`, and - forward.""" - if self.affine: - out_mask = self.mutable_num_features.current_mask.to( - self.weight.device) - weight = self.weight[out_mask] - bias = self.bias[out_mask] - else: - weight, bias = self.weight, self.bias - - if self.track_running_stats: - out_mask = self.mutable_num_features.current_mask.to( - self.running_mean.device) - running_mean = self.running_mean[out_mask] - running_var = self.running_var[out_mask] - else: - running_mean, running_var = self.running_mean, self.running_var - - return F.instance_norm(input, running_mean, running_var, weight, bias, - self.training or not self.track_running_stats, - self.momentum, self.eps) - - # TODO - def to_static_op(self) -> nn.Module: - return self - - -class DynamicGroupNorm(GroupNorm, ChannelDynamicOP): - """Applies Group Normalization over a mini-batch of inputs according to the - `mutable_num_channels` dynamically. - - Args: - num_channels_cfg (Dict): Config related to `num_channels`. - """ - accepted_mutables = {'mutable_num_features'} - - def __init__(self, num_channels_cfg, *args, **kwargs): - super(DynamicGroupNorm, self).__init__(*args, **kwargs) - - num_channels_cfg_ = copy.deepcopy(num_channels_cfg) - num_channels_cfg_.update(dict(num_channels=self.num_channels)) - self.mutable_num_channels = MODELS.build(num_channels_cfg_) - - @property - def mutable_in(self): - """Mutable `num_channels`.""" - return self.mutable_num_channels - - @property - def mutable_out(self): - """Mutable `num_channels`.""" - return self.mutable_num_channels - - def forward(self, input: Tensor) -> Tensor: - """Slice the parameters according to `mutable_num_channels`, and - forward.""" - if self.affine: - out_mask = self.mutable_num_channels.current_mask.to( - self.weight.device) - weight = self.weight[out_mask] - bias = self.bias[out_mask] - else: - weight, bias = self.weight, self.bias - - return F.group_norm(input, self.num_groups, weight, bias, self.eps) - - # TODO - def to_static_op(self) -> nn.Module: - return self diff --git a/mmrazor/models/architectures/dynamic_ops/head/__init__.py b/mmrazor/models/architectures/dynamic_ops/head/__init__.py deleted file mode 100644 index ef101fec6..000000000 --- a/mmrazor/models/architectures/dynamic_ops/head/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. diff --git a/mmrazor/models/architectures/dynamic_ops/mixins/__init__.py b/mmrazor/models/architectures/dynamic_ops/mixins/__init__.py new file mode 100644 index 000000000..7a5097bc5 --- /dev/null +++ b/mmrazor/models/architectures/dynamic_ops/mixins/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .dynamic_conv_mixins import DynamicConvMixin +from .dynamic_mixins import (DynamicBatchNormMixin, DynamicChannelMixin, + DynamicLinearMixin, DynamicMixin) + +__all__ = [ + 'DynamicChannelMixin', 'DynamicBatchNormMixin', 'DynamicLinearMixin', + 'DynamicMixin', 'DynamicConvMixin' +] diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv_mixins.py b/mmrazor/models/architectures/dynamic_ops/mixins/dynamic_conv_mixins.py similarity index 100% rename from mmrazor/models/architectures/dynamic_ops/bricks/dynamic_conv_mixins.py rename to mmrazor/models/architectures/dynamic_ops/mixins/dynamic_conv_mixins.py diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_mixins.py b/mmrazor/models/architectures/dynamic_ops/mixins/dynamic_mixins.py similarity index 100% rename from mmrazor/models/architectures/dynamic_ops/bricks/dynamic_mixins.py rename to mmrazor/models/architectures/dynamic_ops/mixins/dynamic_mixins.py diff --git a/mmrazor/models/architectures/dynamic_ops/slimmable_dynamic_ops.py b/mmrazor/models/architectures/dynamic_ops/slimmable_dynamic_ops.py deleted file mode 100644 index a85e39af3..000000000 --- a/mmrazor/models/architectures/dynamic_ops/slimmable_dynamic_ops.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -from typing import Dict - -import torch.nn as nn - -from mmrazor.models.mutables.mutable_channel import MutableChannel -from mmrazor.registry import MODELS -from .base import DynamicOP - - -class SwitchableBatchNorm2d(nn.Module, DynamicOP): - """Employs independent batch normalization for different switches in a - slimmable network. - - To train slimmable networks, ``SwitchableBatchNorm2d`` privatizes all - batch normalization layers for each switch in a slimmable network. - Compared with the naive training approach, it solves the problem of feature - aggregation inconsistency between different switches by independently - normalizing the feature mean and variance during testing. - - Args: - module_name (str): Name of this `SwitchableBatchNorm2d`. - num_features_cfg (Dict): Config related to `num_features`. - eps (float): A value added to the denominator for numerical stability. - Same as that in :obj:`torch.nn._BatchNorm`. Default: 1e-5 - momentum (float): The value used for the running_mean and running_var - computation. Can be set to None for cumulative moving average - (i.e. simple average). Same as that in :obj:`torch.nn._BatchNorm`. - Default: 0.1 - affine (bool): A boolean value that when set to True, this module has - learnable affine parameters. Same as that in - :obj:`torch.nn._BatchNorm`. Default: True - track_running_stats (bool): A boolean value that when set to True, this - module tracks the running mean and variance, and when set to False, - this module does not track such statistics, and initializes - statistics buffers running_mean and running_var as None. When these - buffers are None, this module always uses batch statistics. in both - training and eval modes. Same as that in - :obj:`torch.nn._BatchNorm`. 
Default: True - """ - - def __init__(self, - num_features_cfg: Dict, - eps: float = 1e-5, - momentum: float = 0.1, - affine: bool = True, - track_running_stats: bool = True): - super(SwitchableBatchNorm2d, self).__init__() - - num_features_cfg = copy.deepcopy(num_features_cfg) - candidate_choices = num_features_cfg.pop('candidate_choices') - num_features_cfg.update(dict(num_channels=max(candidate_choices))) - - bns = [ - nn.BatchNorm2d(num_features, eps, momentum, affine, - track_running_stats) - for num_features in candidate_choices - ] - self.bns = nn.ModuleList(bns) - - self.mutable_num_features = MODELS.build(num_features_cfg) - - @property - def mutable_in(self) -> MutableChannel: - """Mutable `num_features`.""" - return self.mutable_num_features - - @property - def mutable_out(self) -> MutableChannel: - """Mutable `num_features`.""" - return self.mutable_num_features - - def forward(self, input): - """Forward computation according to the current switch of the slimmable - networks.""" - idx = self.mutable_num_features.current_choice - return self.bns[idx](input) - - def to_static_op(self) -> nn.Module: - bn_idx = self.mutable_num_features.current_choice - - return self.bns[bn_idx] diff --git a/mmrazor/models/mutables/__init__.py b/mmrazor/models/mutables/__init__.py index 8274a194a..44796ec35 100644 --- a/mmrazor/models/mutables/__init__.py +++ b/mmrazor/models/mutables/__init__.py @@ -2,8 +2,7 @@ from .base_mutable import BaseMutable from .derived_mutable import DerivedMutable from .mutable_channel import (BaseMutableChannel, MutableChannelContainer, - OneShotMutableChannel, SimpleMutableChannel, - SquentialMutableChannel) + SimpleMutableChannel, SquentialMutableChannel) from .mutable_channel.groups import (ChannelGroupType, L1MutableChannelGroup, MutableChannelGroup, OneShotMutableChannelGroup, @@ -20,5 +19,5 @@ 'L1MutableChannelGroup', 'OneShotMutableChannelGroup', 'SimpleMutableChannel', 'MutableChannelGroup', 'SlimmableChannelGroup', 'BaseMutableChannel', 'MutableChannelContainer', 'ChannelGroupType', - 'SquentialMutableChannel', 'BaseMutable', 'OneShotMutableChannel' + 'SquentialMutableChannel', 'BaseMutable' ] diff --git a/mmrazor/models/mutables/derived_mutable.py b/mmrazor/models/mutables/derived_mutable.py index 5e991e9fe..98f680ee9 100644 --- a/mmrazor/models/mutables/derived_mutable.py +++ b/mmrazor/models/mutables/derived_mutable.py @@ -41,25 +41,27 @@ def current_mask(self) -> Tensor: """Current mask.""" -def _expand_choice_fn(mutable: MutableProtocol, expand_ratio: int) -> Callable: +def _expand_choice_fn(mutable: MutableProtocol, + expand_ratio: Union[int, float]) -> Callable: """Helper function to build `choice_fn` for expand derived mutable.""" def fn(): - return mutable.current_choice * expand_ratio + return int(mutable.current_choice * expand_ratio) return fn -def _expand_mask_fn(mutable: MutableProtocol, - expand_ratio: int) -> Callable: # pragma: no cover +def _expand_mask_fn( + mutable: MutableProtocol, + expand_ratio: Union[int, float]) -> Callable: # pragma: no cover """Helper function to build `mask_fn` for expand derived mutable.""" if not hasattr(mutable, 'current_mask'): raise ValueError('mutable must have attribute `currnet_mask`') def fn(): mask = mutable.current_mask - expand_num_channels = mask.size(0) * expand_ratio - expand_choice = mutable.current_choice * expand_ratio + expand_num_channels = int(mask.size(0) * expand_ratio) + expand_choice = int(mutable.current_choice * expand_ratio) expand_mask = torch.zeros(expand_num_channels).bool() 
expand_mask[:expand_choice] = True @@ -131,8 +133,9 @@ def derive_same_mutable(self: MutableProtocol) -> 'DerivedMutable': """Derive same mutable as the source.""" return self.derive_expand_mutable(expand_ratio=1) - def derive_expand_mutable(self: MutableProtocol, - expand_ratio: int) -> 'DerivedMutable': + def derive_expand_mutable( + self: MutableProtocol, + expand_ratio: Union[int, float]) -> 'DerivedMutable': """Derive expand mutable, usually used with `expand_ratio`.""" choice_fn = _expand_choice_fn(self, expand_ratio=expand_ratio) @@ -198,21 +201,18 @@ class DerivedMutable(BaseMutable[CHOICE_TYPE, CHOICE_TYPE], and `Pretrained`. Defaults to None. Examples: - >>> from mmrazor.models.mutables import OneShotMutableChannel - >>> mutable_channel = OneShotMutableChannel( - ... num_channels=3, - ... candidate_choices=[1, 2, 3], - ... candidate_mode='number') + >>> from mmrazor.models.mutables import SquentialMutableChannel + >>> mutable_channel = SquentialMutableChannel(num_channels=3) >>> # derive expand mutable >>> derived_mutable_channel = mutable_channel * 2 >>> # source mutables will be traced automatically >>> derived_mutable_channel.source_mutables - {OneShotMutableChannel(name=unbind, num_channels=3, current_choice=3, choices=[1, 2, 3], activated_channels=3, concat_mutable_name=[])} # noqa: E501 + {SquentialMutableChannel(name=unbind, num_channels=3, current_choice=3)} # noqa: E501 >>> # modify `current_choice` of `mutable_channel` >>> mutable_channel.current_choice = 2 >>> # `current_choice` and `current_mask` of derived mutable will be modified automatically # noqa: E501 >>> derived_mutable_channel - DerivedMutable(current_choice=4, activated_channels=4, source_mutables={OneShotMutableChannel(name=unbind, num_channels=3, current_choice=2, choices=[1, 2, 3], activated_channels=2, concat_mutable_name=[])}, is_fixed=False) # noqa: E501 + DerivedMutable(current_choice=4, activated_channels=4, source_mutables={SquentialMutableChannel(name=unbind, num_channels=3, current_choice=2)}, is_fixed=False) # noqa: E501 """ def __init__(self, diff --git a/mmrazor/models/mutables/mutable_channel/__init__.py b/mmrazor/models/mutables/mutable_channel/__init__.py index 37b7e5ff1..10086bbe2 100644 --- a/mmrazor/models/mutables/mutable_channel/__init__.py +++ b/mmrazor/models/mutables/mutable_channel/__init__.py @@ -3,18 +3,14 @@ from .groups import (ChannelGroupType, L1MutableChannelGroup, MutableChannelGroup, OneShotMutableChannelGroup, SequentialMutableChannelGroup, SlimmableChannelGroup) -from .mutable_channel import MutableChannel from .mutable_channel_container import MutableChannelContainer -from .one_shot_mutable_channel import OneShotMutableChannel from .sequential_mutable_channel import SquentialMutableChannel from .simple_mutable_channel import SimpleMutableChannel -from .slimmable_mutable_channel import SlimmableMutableChannel __all__ = [ 'SimpleMutableChannel', 'L1MutableChannelGroup', 'SequentialMutableChannelGroup', 'MutableChannelGroup', 'OneShotMutableChannelGroup', 'SlimmableChannelGroup', 'BaseMutableChannel', 'MutableChannelContainer', 'SquentialMutableChannel', - 'ChannelGroupType', 'MutableChannel', 'OneShotMutableChannel', - 'SlimmableMutableChannel' + 'ChannelGroupType' ] diff --git a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py index aa6a8bf91..29c9b07ab 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/channel_group.py +++ 
b/mmrazor/models/mutables/mutable_channel/groups/channel_group.py @@ -17,8 +17,7 @@ import torch.nn as nn from torch.nn import Module -from mmrazor.models.architectures.dynamic_ops.bricks.dynamic_mixins import \ - DynamicChannelMixin +from mmrazor.models.architectures.dynamic_ops.mixins import DynamicChannelMixin from mmrazor.structures.graph import ModuleGraph, ModuleNode from mmrazor.utils import IndexDict from ..base_mutable_channel import BaseMutableChannel diff --git a/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py index a40b87187..583c462d8 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/mutable_channel_group.py @@ -1,12 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. """This module defines MutableChannelGroup.""" import abc -from typing import Dict, List, Type, TypeVar, Union +from typing import Any, Dict, List, Type, TypeVar, Union import torch.nn as nn from mmengine.model import BaseModule -from mmrazor.models.architectures.dynamic_ops.bricks import DynamicChannelMixin +import mmrazor.models.architectures.dynamic_ops as dynamic_ops from mmrazor.models.mutables import DerivedMutable from mmrazor.models.mutables.mutable_channel.base_mutable_channel import \ BaseMutableChannel @@ -69,7 +69,7 @@ def traverse(channels: List[Channel]): if channel.is_mutable is False: all_channel_prunable = False break - if isinstance(channel.module, DynamicChannelMixin): + if isinstance(channel.module, dynamic_ops.DynamicChannelMixin): has_dynamic_op = True return has_dynamic_op, all_channel_prunable @@ -144,9 +144,8 @@ def _get_int_choice(self, choice: Union[int, float]) -> int: assert 0 < choice <= self.num_channels, f'{choice}' return choice - def _replace_with_dynamic_ops( - self, model: nn.Module, - dynamicop_map: Dict[Type[nn.Module], Type[DynamicChannelMixin]]): + def _replace_with_dynamic_ops(self, model: nn.Module, + dynamicop_map: Dict[Type[nn.Module], Any]): """Replace torch modules with dynamic-ops.""" def replace_op(model: nn.Module, name: str, module: nn.Module): @@ -178,7 +177,7 @@ def _register_channel_container( model: nn.Module, container_class: Type[MutableChannelContainer]): """register channel container for dynamic ops.""" for module in model.modules(): - if isinstance(module, DynamicChannelMixin): + if isinstance(module, dynamic_ops.DynamicChannelMixin): if module.get_mutable_attr('in_channels') is None: in_channels = 0 if isinstance(module, nn.Conv2d): @@ -209,7 +208,7 @@ def _register_mutable_channel(self, mutable_channel: BaseMutableChannel): # register mutable_channel for channel in self.input_related + self.output_related: module = channel.module - if isinstance(module, DynamicChannelMixin): + if isinstance(module, dynamic_ops.DynamicChannelMixin): container: MutableChannelContainer if channel.output_related and module.get_mutable_attr( 'out_channels') is not None: diff --git a/mmrazor/models/mutables/mutable_channel/groups/sequential_mutable_channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/sequential_mutable_channel_group.py index a19fd9452..e8045e242 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/sequential_mutable_channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/sequential_mutable_channel_group.py @@ -6,8 +6,7 @@ import torch.nn as nn from mmengine import MMLogger -from mmrazor.models.architectures.dynamic_ops.bricks import ( - 
DynamicBatchNorm2d, DynamicConv2d, DynamicLinear) +import mmrazor.models.architectures.dynamic_ops as dynamic_ops from mmrazor.models.utils import make_divisible from mmrazor.registry import MODELS from ..mutable_channel_container import MutableChannelContainer @@ -53,9 +52,9 @@ def prepare_for_pruning(self, model: nn.Module): # register MutableMask self._replace_with_dynamic_ops( model, { - nn.Conv2d: DynamicConv2d, - nn.BatchNorm2d: DynamicBatchNorm2d, - nn.Linear: DynamicLinear + nn.Conv2d: dynamic_ops.DynamicConv2d, + nn.BatchNorm2d: dynamic_ops.DynamicBatchNorm2d, + nn.Linear: dynamic_ops.DynamicLinear }) self._register_channel_container(model, MutableChannelContainer) self._register_mutable_channel(self.mutable_channel) diff --git a/mmrazor/models/mutables/mutable_channel/groups/slimmable_channel_group.py b/mmrazor/models/mutables/mutable_channel/groups/slimmable_channel_group.py index 50525743c..976770193 100644 --- a/mmrazor/models/mutables/mutable_channel/groups/slimmable_channel_group.py +++ b/mmrazor/models/mutables/mutable_channel/groups/slimmable_channel_group.py @@ -4,8 +4,7 @@ import torch.nn as nn -from mmrazor.models.architectures.dynamic_ops.bricks import ( - DynamicConv2d, DynamicLinear, SwitchableBatchNorm2d) +import mmrazor.models.architectures.dynamic_ops as dynamic_ops from mmrazor.registry import MODELS from ..mutable_channel_container import MutableChannelContainer from .one_shot_mutable_channel_group import OneShotMutableChannelGroup @@ -43,9 +42,9 @@ def prepare_for_pruning(self, model: nn.Module): """Prepare for pruning.""" self._replace_with_dynamic_ops( model, { - nn.Conv2d: DynamicConv2d, - nn.BatchNorm2d: SwitchableBatchNorm2d, - nn.Linear: DynamicLinear + nn.Conv2d: dynamic_ops.DynamicConv2d, + nn.BatchNorm2d: dynamic_ops.SwitchableBatchNorm2d, + nn.Linear: dynamic_ops.DynamicLinear }) self.alter_candidates_of_switchbn(self.candidate_choices) self._register_channel_container(model, MutableChannelContainer) @@ -54,7 +53,7 @@ def prepare_for_pruning(self, model: nn.Module): def alter_candidates_of_switchbn(self, candidates): """Change candidates of SwitchableBatchNorm2d.""" for channel in self.output_related + self.input_related: - if isinstance(channel.module, SwitchableBatchNorm2d) and \ - len(channel.module.candidate_bn) == 0: + if isinstance(channel.module, dynamic_ops.SwitchableBatchNorm2d) \ + and len(channel.module.candidate_bn) == 0: channel.module.init_candidates(candidates) self.current_choice = self.max_choice diff --git a/mmrazor/models/mutables/mutable_channel/mutable_channel.py b/mmrazor/models/mutables/mutable_channel/mutable_channel.py deleted file mode 100644 index af2bf2188..000000000 --- a/mmrazor/models/mutables/mutable_channel/mutable_channel.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from abc import abstractmethod -from typing import List - -import torch - -from ..base_mutable import CHOICE_TYPE, CHOSEN_TYPE, BaseMutable -from ..derived_mutable import DerivedMethodMixin - - -class MutableChannel(BaseMutable[CHOICE_TYPE, CHOSEN_TYPE], - DerivedMethodMixin): - """A type of ``MUTABLES`` for single path supernet such as AutoSlim. In - single path supernet, each module only has one choice invoked at the same - time. A path is obtained by sampling all the available choices. It is the - base class for one shot channel mutables. - - Args: - num_channels (int): The raw number of channels. - init_cfg (dict, optional): initialization configuration dict for - ``BaseModule``. 
OpenMMLab has implement 5 initializer including - `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`, - and `Pretrained`. - """ - - def __init__(self, num_channels: int, **kwargs): - super().__init__(**kwargs) - - self.num_channels = num_channels - self._same_mutables: List[MutableChannel] = list() - - # If the input of a module is a concatenation of several modules' - # outputs, we add the mutable out of these modules to the - # `concat_parent_mutables` of this module. - self.concat_parent_mutables: List[MutableChannel] = list() - self.name = 'unbind' - - @property - def same_mutables(self): - """Mutables in `same_mutables` and the current mutable should change - Synchronously.""" - return self._same_mutables - - def register_same_mutable(self, mutable): - """Register the input mutable in `same_mutables`.""" - if isinstance(mutable, list): - # Add a concatenation of mutables to `concat_parent_mutables`. - self.concat_parent_mutables = mutable - return - - if self == mutable: - return - if mutable in self._same_mutables: - return - - self._same_mutables.append(mutable) - for s_mutable in self._same_mutables: - s_mutable.register_same_mutable(mutable) - mutable.register_same_mutable(s_mutable) - - @abstractmethod - def convert_choice_to_mask(self, choice: CHOICE_TYPE) -> torch.Tensor: - """Get the mask according to the input choice.""" - pass - - @property - def current_mask(self): - """The current mask. - - We slice the registered parameters and buffers of a ``nn.Module`` - according to the mask of the corresponding channel mutable. - """ - if len(self.concat_parent_mutables) > 0: - # If the input of a module is a concatenation of several modules' - # outputs, the in_mask of this module is the concatenation of - # these modules' out_mask. - return torch.cat([ - mutable.current_mask for mutable in self.concat_parent_mutables - ]) - else: - return self.convert_choice_to_mask(self.current_choice) - - def bind_mutable_name(self, name: str) -> None: - """Bind a MutableChannel to its name. - - Args: - name (str): Name of this `MutableChannel`. - """ - self.name = name - - def fix_chosen(self, chosen: CHOSEN_TYPE) -> None: - """Fix mutable with subnet config. This operation would convert - `unfixed` mode to `fixed` mode. The :attr:`is_fixed` will be set to - True and only the selected operations can be retained. - - Args: - chosen (str): The chosen key in ``MUTABLE``. Defaults to None. - """ - if self.is_fixed: - raise AttributeError( - 'The mode of current MUTABLE is `fixed`. ' - 'Please do not call `fix_chosen` function again.') - - self.is_fixed = True - - def __repr__(self): - concat_mutable_name = [ - mutable.name for mutable in self.concat_parent_mutables - ] - repr_str = self.__class__.__name__ - repr_str += f'(name={self.name}, ' - repr_str += f'num_channels={self.num_channels}, ' - repr_str += f'concat_mutable_name={concat_mutable_name})' - return repr_str diff --git a/mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py deleted file mode 100644 index 7f6eea3ad..000000000 --- a/mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-from typing import Callable, Dict, List, Optional, Union - -import numpy as np -import torch - -from mmrazor.registry import MODELS -from ..derived_mutable import DerivedMutable -from .mutable_channel import MutableChannel - - -@MODELS.register_module() -class OneShotMutableChannel(MutableChannel[int, Dict]): - """A type of ``MUTABLES`` for single path supernet such as AutoSlim. In - single path supernet, each module only has one choice invoked at the same - time. A path is obtained by sampling all the available choices. It is the - base class for one shot mutable channel. - - Args: - num_channels (int): The raw number of channels. - candidate_choices (List): If `candidate_mode` is "ratio", - candidate_choices is a list of candidate width ratios. If - `candidate_mode` is "number", candidate_choices is a list of - candidate channel number. We note that the width ratio is the ratio - between the number of reserved channels and that of all channels in - a layer. - For example, if `ratios` is [0.25, 0.5], there are 2 cases - for us to choose from when we sample from a layer with 12 channels. - One is sampling the very first 3 channels in this layer, another is - sampling the very first 6 channels in this layer. - candidate_mode (str): One of "ratio" or "number". - init_cfg (dict, optional): initialization configuration dict for - ``BaseModule``. OpenMMLab has implement 5 initializer including - `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`, - and `Pretrained`. - """ - - def __init__(self, - num_channels: int, - candidate_choices: List[Union[int, float]], - candidate_mode: str = 'ratio', - init_cfg: Optional[Dict] = None): - super(OneShotMutableChannel, self).__init__( - num_channels=num_channels, init_cfg=init_cfg) - - self._current_choice = num_channels - assert len(candidate_choices) > 0, \ - f'Number of candidate choices must be greater than 0, ' \ - f'but got: {len(candidate_choices)}' - self._candidate_choices = candidate_choices - assert candidate_mode in ['ratio', 'number'] - self._candidate_mode = candidate_mode - - self._check_candidate_choices() - - def _check_candidate_choices(self): - """Check if the input `candidate_choices` is valid.""" - if self._candidate_mode == 'number': - assert all([num > 0 and num <= self.num_channels - for num in self._candidate_choices]), \ - f'The candidate channel numbers should be in ' \ - f'range(0, {self.num_channels}].' - assert all([isinstance(num, int) - for num in self._candidate_choices]), \ - 'Type of `candidate_choices` should be int.' - else: - assert all([ - ratio > 0 and ratio <= 1 for ratio in self._candidate_choices - ]), 'The candidate ratio should be in range(0, 1].' - - def sample_choice(self) -> int: - """Sample an arbitrary selection from candidate choices. - - Returns: - int: The chosen number of channels. - """ - assert len(self.concat_parent_mutables) == 0 - num_channels = np.random.choice(self.choices) - assert num_channels > 0, \ - f'Sampled number of channels in `Mutable` {self.name}' \ - f' should be a positive integer.' - return num_channels - - @property - def min_choice(self) -> int: - """Minimum number of channels.""" - assert len(self.concat_parent_mutables) == 0 - min_channels = min(self.choices) - assert min_channels > 0, \ - f'Minimum number of channels in `Mutable` {self.name}' \ - f' should be a positive integer.' 
- return min_channels - - @property - def max_choice(self) -> int: - """Maximum number of channels.""" - return max(self.choices) - - @property - def current_choice(self): - """The current choice of the mutable.""" - assert len(self.concat_parent_mutables) == 0 - return self._current_choice - - @current_choice.setter - def current_choice(self, choice: int): - """Set the current choice of the mutable.""" - assert choice in self.choices - self._current_choice = choice - - @property - def choices(self) -> List: - """list: all choices. """ - if self._candidate_mode == 'number': - return self._candidate_choices - candidate_choices = [ - round(ratio * self.num_channels) - for ratio in self._candidate_choices - ] - return candidate_choices - - @property - def num_choices(self) -> int: - return len(self.choices) - - def convert_choice_to_mask(self, choice: int) -> torch.Tensor: - """Get the mask according to the input choice.""" - num_channels = choice - mask = torch.zeros(self.num_channels).bool() - mask[:num_channels] = True - return mask - - def dump_chosen(self) -> Dict: - assert self.current_choice is not None - - return dict( - current_choice=self.current_choice, - origin_channels=self.num_channels) - - def fix_chosen(self, dumped_chosen: Dict) -> None: - if self.is_fixed: - raise RuntimeError('OneShotMutableChannel can not be fixed twice') - - current_choice = dumped_chosen['current_choice'] - origin_channels = dumped_chosen['origin_channels'] - - assert current_choice <= origin_channels - assert origin_channels == self.num_channels - - self.current_choice = current_choice - self.is_fixed = True - - def __repr__(self): - concat_mutable_name = [ - mutable.name for mutable in self.concat_parent_mutables - ] - repr_str = self.__class__.__name__ - repr_str += f'(name={self.name}, ' - repr_str += f'num_channels={self.num_channels}, ' - repr_str += f'current_choice={self.current_choice}, ' - repr_str += f'choices={self.choices}, ' - repr_str += f'activated_channels={self.current_mask.sum().item()}, ' - repr_str += f'concat_mutable_name={concat_mutable_name})' - return repr_str - - def __rmul__(self, other) -> DerivedMutable: - return self * other - - def __mul__(self, other) -> DerivedMutable: - if isinstance(other, int): - return self.derive_expand_mutable(other) - - from ..mutable_value import OneShotMutableValue - - def expand_choice_fn(mutable1: 'OneShotMutableChannel', - mutable2: OneShotMutableValue) -> Callable: - - def fn(): - return mutable1.current_choice * mutable2.current_choice - - return fn - - def expand_mask_fn(mutable1: 'OneShotMutableChannel', - mutable2: OneShotMutableValue) -> Callable: - - def fn(): - mask = mutable1.current_mask - max_expand_ratio = mutable2.max_choice - current_expand_ratio = mutable2.current_choice - expand_num_channels = mask.size(0) * max_expand_ratio - - expand_choice = mutable1.current_choice * current_expand_ratio - expand_mask = torch.zeros(expand_num_channels).bool() - expand_mask[:expand_choice] = True - - return expand_mask - - return fn - - if isinstance(other, OneShotMutableValue): - return DerivedMutable( - choice_fn=expand_choice_fn(self, other), - mask_fn=expand_mask_fn(self, other)) - - raise TypeError(f'Unsupported type {type(other)} for mul!') - - def __floordiv__(self, other) -> DerivedMutable: - if isinstance(other, int): - return self.derive_divide_mutable(other) - if isinstance(other, tuple): - assert len(other) == 2 - return self.derive_divide_mutable(*other) - - raise TypeError(f'Unsupported type {type(other)} for div!') diff --git 
a/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py index 3f9ea8cb6..1bcd00df3 100644 --- a/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py @@ -1,7 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Callable + import torch from mmrazor.registry import MODELS +from ..derived_mutable import DerivedMutable from .base_mutable_channel import BaseMutableChannel # TODO discuss later @@ -54,3 +57,67 @@ def fix_chosen(self, chosen=...): def dump_chosen(self): """Dump chosen.""" return self.current_choice + + # def __mul__(self, other): + # """multiplication.""" + # if isinstance(other, int): + # return self.derive_expand_mutable(other) + # else: + # return None + + # def __floordiv__(self, other): + # """division.""" + # if isinstance(other, int): + # return self.derive_divide_mutable(other) + # else: + # return None + + def __rmul__(self, other) -> DerivedMutable: + return self * other + + def __mul__(self, other) -> DerivedMutable: + if isinstance(other, int) or isinstance(other, float): + return self.derive_expand_mutable(other) + + from ..mutable_value import OneShotMutableValue + + def expand_choice_fn(mutable1: 'SquentialMutableChannel', + mutable2: OneShotMutableValue) -> Callable: + + def fn(): + return mutable1.current_choice * mutable2.current_choice + + return fn + + def expand_mask_fn(mutable1: 'SquentialMutableChannel', + mutable2: OneShotMutableValue) -> Callable: + + def fn(): + mask = mutable1.current_mask + max_expand_ratio = mutable2.max_choice + current_expand_ratio = mutable2.current_choice + expand_num_channels = mask.size(0) * max_expand_ratio + + expand_choice = mutable1.current_choice * current_expand_ratio + expand_mask = torch.zeros(expand_num_channels).bool() + expand_mask[:expand_choice] = True + + return expand_mask + + return fn + + if isinstance(other, OneShotMutableValue): + return DerivedMutable( + choice_fn=expand_choice_fn(self, other), + mask_fn=expand_mask_fn(self, other)) + + raise TypeError(f'Unsupported type {type(other)} for mul!') + + def __floordiv__(self, other) -> DerivedMutable: + if isinstance(other, int): + return self.derive_divide_mutable(other) + if isinstance(other, tuple): + assert len(other) == 2 + return self.derive_divide_mutable(*other) + + raise TypeError(f'Unsupported type {type(other)} for div!') diff --git a/mmrazor/models/mutables/mutable_channel/slimmable_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/slimmable_mutable_channel.py deleted file mode 100644 index 5d4dec0e7..000000000 --- a/mmrazor/models/mutables/mutable_channel/slimmable_mutable_channel.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Optional - -import torch - -from mmrazor.registry import MODELS -from .mutable_channel import MutableChannel - - -@MODELS.register_module() -class SlimmableMutableChannel(MutableChannel[int, Dict[str, int]]): - """A type of ``MUTABLES`` to train several subnet together, such as the - retraining stage in AutoSlim. - - Notes: - We need to set `candidate_choices` after the instantiation of a - `SlimmableMutableChannel` by ourselves. - - Args: - num_channels (int): The raw number of channels. - init_cfg (dict, optional): initialization configuration dict for - ``BaseModule``. 
OpenMMLab has implement 5 initializer including - `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`, - and `Pretrained`. - """ - - def __init__(self, num_channels: int, init_cfg: Optional[Dict] = None): - super(SlimmableMutableChannel, self).__init__( - num_channels=num_channels, init_cfg=init_cfg) - - self.num_channels = num_channels - - @property - def candidate_choices(self) -> List: - """A list of candidate channel numbers.""" - return self._candidate_choices - - @candidate_choices.setter - def candidate_choices(self, choices): - """Set the candidate channel numbers.""" - assert getattr(self, '_candidate_choices', None) is None, \ - f'candidate_choices can be set only when candidate_choices is ' \ - f'None, got: candidate_choices = {self._candidate_choices}' - - assert all([num > 0 and num <= self.num_channels - for num in choices]), \ - f'The candidate channel numbers should be in ' \ - f'range(0, {self.num_channels}].' - assert all([isinstance(num, int) for num in choices]), \ - 'Type of `candidate_choices` should be int.' - - self._candidate_choices = list(choices) - - @property - def choices(self) -> List[int]: - """Return all subnet indexes.""" - assert self._candidate_choices is not None - return list(range(len(self.candidate_choices))) - - def dump_chosen(self) -> Dict: - assert self.current_choice is not None - - return dict( - current_choice=self._candidate_choices[self.current_choice], - origin_channels=self.num_channels) - - def fix_chosen(self, dumped_chosen: Dict) -> None: - chosen = dumped_chosen['current_choice'] - origin_channels = dumped_chosen['origin_channels'] - - assert chosen <= origin_channels - - # TODO - # remove after remove `current_choice` - self.current_choice = self.candidate_choices.index(chosen) - self._chosen = chosen - - super().fix_chosen(chosen) - - @property - def num_choices(self) -> int: - return len(self.choices) - - def convert_choice_to_mask(self, choice: int) -> torch.Tensor: - """Get the mask according to the input choice.""" - if self.is_fixed: - num_channels = self._chosen - elif not hasattr(self, '_candidate_choices'): - # todo: we trace the supernet before set_candidate_choices. - # It's hacky - num_channels = self.num_channels - else: - num_channels = self.candidate_choices[choice] - mask = torch.zeros(self.num_channels).bool() - mask[:num_channels] = True - return mask diff --git a/mmrazor/models/mutables/mutable_value/mutable_value.py b/mmrazor/models/mutables/mutable_value/mutable_value.py index 748d83e78..49a0c870f 100644 --- a/mmrazor/models/mutables/mutable_value/mutable_value.py +++ b/mmrazor/models/mutables/mutable_value/mutable_value.py @@ -222,15 +222,15 @@ def __mul__(self, other) -> DerivedMutable: """Overload `*` operator. Args: - other (int, OneShotMutableChannel): Expand ratio or - OneShotMutableChannel. + other (int, SquentialMutableChannel): Expand ratio or + SquentialMutableChannel. Returns: DerivedMutable: Derived expand mutable. 
""" - from ..mutable_channel import OneShotMutableChannel + from ..mutable_channel import SquentialMutableChannel - if isinstance(other, OneShotMutableChannel): + if isinstance(other, SquentialMutableChannel): return other * self return super().__mul__(other) diff --git a/mmrazor/models/mutators/channel_mutator/base_channel_mutator.py b/mmrazor/models/mutators/channel_mutator/base_channel_mutator.py index 2a76d7e85..42466d13f 100644 --- a/mmrazor/models/mutators/channel_mutator/base_channel_mutator.py +++ b/mmrazor/models/mutators/channel_mutator/base_channel_mutator.py @@ -5,7 +5,7 @@ from mmengine import fileio from torch.nn import Module -from mmrazor.models.architectures.dynamic_ops.bricks import DynamicChannelMixin +from mmrazor.models.architectures.dynamic_ops import DynamicChannelMixin from mmrazor.models.mutables import (BaseMutableChannel, ChannelGroupType, DerivedMutable, MutableChannelContainer, MutableChannelGroup, diff --git a/mmrazor/models/mutators/utils/__init__.py b/mmrazor/models/mutators/utils/__init__.py deleted file mode 100644 index 33f94c667..000000000 --- a/mmrazor/models/mutators/utils/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# yapf: disable -from .default_module_converters import (DEFAULT_MODULE_CONVERTERS, - dynamic_bn_converter, - dynamic_conv2d_converter, - dynamic_gn_converter, - dynamic_in_converter, - dynamic_linear_converter) -# yapf: enable -from .slimmable_bn_converter import switchable_bn_converter - -__all__ = [ - 'dynamic_conv2d_converter', 'dynamic_linear_converter', - 'dynamic_bn_converter', 'dynamic_in_converter', 'dynamic_gn_converter', - 'DEFAULT_MODULE_CONVERTERS', 'switchable_bn_converter' -] diff --git a/mmrazor/models/mutators/utils/default_module_converters.py b/mmrazor/models/mutators/utils/default_module_converters.py deleted file mode 100644 index fdfa5d266..000000000 --- a/mmrazor/models/mutators/utils/default_module_converters.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Callable, Dict, Optional - -from torch import nn -from torch.nn.modules import GroupNorm -from torch.nn.modules.batchnorm import _BatchNorm -from torch.nn.modules.instancenorm import _InstanceNorm - -from ...architectures import (DynamicBatchNorm, DynamicConv2d, - DynamicGroupNorm, DynamicInstanceNorm, - DynamicLinear) - - -def dynamic_conv2d_converter(module: nn.Conv2d, in_channels_cfg: Dict, - out_channels_cfg: Dict) -> DynamicConv2d: - """Convert a nn.Conv2d module to a DynamicConv2d. - - Args: - module (:obj:`torch.nn.Conv2d`): The original Conv2d module. - in_channels_cfg (Dict): Config related to `in_channels`. - out_channels_cfg (Dict): Config related to `out_channels`. - """ - dynamic_conv = DynamicConv2d( - in_channels_cfg=in_channels_cfg, - out_channels_cfg=out_channels_cfg, - in_channels=module.in_channels, - out_channels=module.out_channels, - kernel_size=module.kernel_size, - stride=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - bias=True if module.bias is not None else False, - padding_mode=module.padding_mode) - return dynamic_conv - - -def dynamic_linear_converter(module: nn.Linear, in_channels_cfg: Dict, - out_channels_cfg: Dict) -> DynamicLinear: - """Convert a nn.Linear module to a DynamicLinear. - - Args: - module (:obj:`torch.nn.Linear`): The original Linear module. - in_features_cfg (Dict): Config related to `in_features`. - out_features_cfg (Dict): Config related to `out_features`. 
- """ - dynamic_linear = DynamicLinear( - in_features_cfg=in_channels_cfg, - out_features_cfg=out_channels_cfg, - in_features=module.in_features, - out_features=module.out_features, - bias=True if module.bias is not None else False) - return dynamic_linear - - -def dynamic_bn_converter( - module: _BatchNorm, - in_channels_cfg: Dict, - out_channels_cfg: Optional[Dict] = None) -> DynamicBatchNorm: - """Convert a _BatchNorm module to a DynamicBatchNorm. - - Args: - module (:obj:`torch.nn._BatchNorm`): The original BatchNorm module. - num_features_cfg (Dict): Config related to `num_features`. - """ - dynamic_bn = DynamicBatchNorm( - num_features_cfg=in_channels_cfg, - num_features=module.num_features, - eps=module.eps, - momentum=module.momentum, - affine=module.affine, - track_running_stats=module.track_running_stats) - return dynamic_bn - - -def dynamic_in_converter( - module: _InstanceNorm, - in_channels_cfg: Dict, - out_channels_cfg: Optional[Dict] = None) -> DynamicInstanceNorm: - """Convert a _InstanceNorm module to a DynamicInstanceNorm. - - Args: - module (:obj:`torch.nn._InstanceNorm`): The original InstanceNorm - module. - num_features_cfg (Dict): Config related to `num_features`. - """ - dynamic_in = DynamicInstanceNorm( - num_features_cfg=in_channels_cfg, - num_features=module.num_features, - eps=module.eps, - momentum=module.momentum, - affine=module.affine, - track_running_stats=module.track_running_stats) - return dynamic_in - - -def dynamic_gn_converter( - module: GroupNorm, - in_channels_cfg: Dict, - out_channels_cfg: Optional[Dict] = None) -> DynamicGroupNorm: - """Convert a GroupNorm module to a DynamicGroupNorm. - - Args: - module (:obj:`torch.nn.GroupNorm`): The original GroupNorm module. - num_channels_cfg (Dict): Config related to `num_channels`. - """ - dynamic_gn = DynamicGroupNorm( - num_channels_cfg=in_channels_cfg, - num_channels=module.num_channels, - num_groups=module.num_groups, - eps=module.eps, - affine=module.affine) - return dynamic_gn - - -DEFAULT_MODULE_CONVERTERS: Dict[Callable, Callable] = { - nn.Conv2d: dynamic_conv2d_converter, - nn.Linear: dynamic_linear_converter, - nn.BatchNorm1d: dynamic_bn_converter, - nn.BatchNorm2d: dynamic_bn_converter, - nn.BatchNorm3d: dynamic_bn_converter, - nn.InstanceNorm1d: dynamic_in_converter, - nn.InstanceNorm2d: dynamic_in_converter, - nn.InstanceNorm3d: dynamic_in_converter, - nn.GroupNorm: dynamic_gn_converter -} diff --git a/mmrazor/models/mutators/utils/slimmable_bn_converter.py b/mmrazor/models/mutators/utils/slimmable_bn_converter.py deleted file mode 100644 index bef3077c1..000000000 --- a/mmrazor/models/mutators/utils/slimmable_bn_converter.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict - -from torch.nn.modules.batchnorm import _BatchNorm - -from mmrazor.models.architectures import SwitchableBatchNorm2d - - -def switchable_bn_converter(module: _BatchNorm, in_channels_cfg: Dict, - out_channels_cfg: Dict) -> SwitchableBatchNorm2d: - """Convert a _BatchNorm module to a SwitchableBatchNorm2d. - - Args: - module (:obj:`torch.nn.GroupNorm`): The original BatchNorm module. - num_channels_cfg (Dict): Config related to `num_features`. 
- """ - switchable_bn = SwitchableBatchNorm2d( - num_features_cfg=in_channels_cfg, - eps=module.eps, - momentum=module.momentum, - affine=module.affine, - track_running_stats=module.track_running_stats) - return switchable_bn diff --git a/mmrazor/structures/subnet/fix_subnet.py b/mmrazor/structures/subnet/fix_subnet.py index c29a0b181..9a0485592 100644 --- a/mmrazor/structures/subnet/fix_subnet.py +++ b/mmrazor/structures/subnet/fix_subnet.py @@ -10,7 +10,7 @@ def _dynamic_to_static(model: nn.Module) -> None: # Avoid circular import - from mmrazor.models.architectures.dynamic_ops.bricks import DynamicMixin + from mmrazor.models.architectures.dynamic_ops import DynamicMixin def traverse_children(module: nn.Module) -> None: # TODO @@ -37,7 +37,7 @@ def load_fix_subnet(model: nn.Module, raise TypeError('fix_mutable should be a `str` or `dict`' f'but got {type(fix_mutable)}') - from mmrazor.models.architectures.dynamic_ops.bricks import DynamicMixin + from mmrazor.models.architectures.dynamic_ops import DynamicMixin if isinstance(model, DynamicMixin): raise RuntimeError('Root model can not be dynamic op.') diff --git a/tests/data/models.py b/tests/data/models.py index ee2dd9479..b9fe04085 100644 --- a/tests/data/models.py +++ b/tests/data/models.py @@ -3,12 +3,12 @@ from torch import Tensor import torch.nn as nn import torch -from mmrazor.models.architectures.dynamic_ops.bricks import DynamicBatchNorm2d, DynamicConv2d, DynamicLinear, DynamicChannelMixin +from mmrazor.models.architectures.dynamic_ops import DynamicBatchNorm2d, DynamicConv2d, DynamicLinear, DynamicChannelMixin from mmrazor.models.mutables.mutable_channel import MutableChannelContainer from mmrazor.models.mutables import MutableChannelGroup from mmrazor.models.mutables import DerivedMutable from mmrazor.models.mutables import BaseMutable -from mmrazor.models.mutables import OneShotMutableChannelGroup, SquentialMutableChannel, SimpleMutableChannel +from mmrazor.models.mutables import OneShotMutableChannelGroup, SimpleMutableChannel from mmrazor.registry import MODELS from mmengine.model import BaseModel # this file includes models for tesing. 
diff --git a/tests/test_core/test_graph/test_graph.py b/tests/test_core/test_graph/test_graph.py index 35b4549d4..1383dccd8 100644 --- a/tests/test_core/test_graph/test_graph.py +++ b/tests/test_core/test_graph/test_graph.py @@ -5,7 +5,7 @@ import torch -from mmrazor.models.architectures.dynamic_ops.bricks import DynamicChannelMixin +from mmrazor.models.architectures.dynamic_ops.mixins import DynamicChannelMixin from mmrazor.structures.graph import ModuleGraph from ...data.models import Icep # noqa from ...data.models import MultipleUseModel # noqa diff --git a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_conv.py b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_conv.py index 6082817ea..8eab78af8 100644 --- a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_conv.py +++ b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_conv.py @@ -8,10 +8,10 @@ import torch from torch import nn -from mmrazor.models.architectures.dynamic_ops.bricks import (BigNasConv2d, - DynamicConv2d, - OFAConv2d) -from mmrazor.models.mutables import OneShotMutableChannel, OneShotMutableValue +from mmrazor.models.architectures.dynamic_ops import (BigNasConv2d, + DynamicConv2d, OFAConv2d) +from mmrazor.models.mutables import (OneShotMutableValue, + SquentialMutableChannel) from mmrazor.structures.subnet import export_fix_subnet, load_fix_subnet from ..utils import fix_dynamic_op @@ -39,10 +39,8 @@ def test_dynamic_conv2d_depthwise(self) -> None: with pytest.raises(ValueError): d_conv2d.register_mutable_attr('out_channels', mock_mutable) - mutable_in_channels = OneShotMutableChannel( - 10, candidate_choices=[4, 8, 10], candidate_mode='number') - mutable_out_channels = OneShotMutableChannel( - 10, candidate_choices=[4, 8, 10], candidate_mode='number') + mutable_in_channels = SquentialMutableChannel(10) + mutable_out_channels = SquentialMutableChannel(10) d_conv2d.register_mutable_attr('in_channels', mutable_in_channels) d_conv2d.register_mutable_attr('out_channels', mutable_out_channels) @@ -82,10 +80,8 @@ def test_dynamic_conv2d(bias: bool) -> None: x_max = torch.rand(10, 4, 224, 224) out_before_mutate = d_conv2d(x_max) - mutable_in_channels = OneShotMutableChannel( - 4, candidate_choices=[2, 3, 4], candidate_mode='number') - mutable_out_channels = OneShotMutableChannel( - 10, candidate_choices=[4, 8, 10], candidate_mode='number') + mutable_in_channels = SquentialMutableChannel(4) + mutable_out_channels = SquentialMutableChannel(10) d_conv2d.register_mutable_attr('in_channels', mutable_in_channels) d_conv2d.register_mutable_attr('out_channels', mutable_out_channels) @@ -128,8 +124,7 @@ def test_dynamic_conv2d_mutable_single_channels(is_mutate_in_channels: bool, out_channels: int) -> None: d_conv2d = DynamicConv2d( in_channels=10, out_channels=10, kernel_size=3, stride=1, bias=True) - mutable_channels = OneShotMutableChannel( - 10, candidate_choices=[4, 6, 10], candidate_mode='number') + mutable_channels = SquentialMutableChannel(10) if is_mutate_in_channels: d_conv2d.register_mutable_attr('in_channels', mutable_channels) @@ -172,10 +167,8 @@ def test_dynamic_conv2d_mutable_single_channels(is_mutate_in_channels: bool, def test_kernel_dynamic_conv2d(dynamic_class: Type[nn.Conv2d], kernel_size_list: bool) -> None: - mutable_in_channels = OneShotMutableChannel( - 10, candidate_choices=[4, 8, 10], candidate_mode='number') - mutable_out_channels = OneShotMutableChannel( - 10, candidate_choices=[4, 8, 10], 
candidate_mode='number') + mutable_in_channels = SquentialMutableChannel(10) + mutable_out_channels = SquentialMutableChannel(10) mutable_kernel_size = OneShotMutableValue(value_list=kernel_size_list) @@ -233,7 +226,7 @@ def test_kernel_dynamic_conv2d(dynamic_class: Type[nn.Conv2d], @pytest.mark.parametrize('dynamic_class', [OFAConv2d, BigNasConv2d]) def test_mutable_kernel_dynamic_conv2d_grad( dynamic_class: Type[nn.Conv2d]) -> None: - from mmrazor.models.architectures.dynamic_ops.bricks import \ + from mmrazor.models.architectures.dynamic_ops.mixins import \ dynamic_conv_mixins kernel_size_list = [3, 5, 7] diff --git a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_linear.py b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_linear.py index 0cdaa20b6..ece69ddc0 100644 --- a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_linear.py +++ b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_linear.py @@ -6,19 +6,18 @@ import torch from torch import nn -from mmrazor.models.architectures.dynamic_ops.bricks import ( # noqa - DynamicLinear, DynamicLinearMixin) -from mmrazor.models.mutables import OneShotMutableChannel +from mmrazor.models.mutables import SquentialMutableChannel from mmrazor.structures.subnet import export_fix_subnet, load_fix_subnet from ..utils import fix_dynamic_op +from mmrazor.models.architectures.dynamic_ops import ( # isort:skip + DynamicLinear, DynamicLinearMixin) + @pytest.mark.parametrize('bias', [True, False]) def test_dynamic_linear(bias) -> None: - mutable_in_features = OneShotMutableChannel( - 10, candidate_choices=[4, 8, 10], candidate_mode='number') - mutable_out_features = OneShotMutableChannel( - 10, candidate_choices=[4, 8, 10], candidate_mode='number') + mutable_in_features = SquentialMutableChannel(10) + mutable_out_features = SquentialMutableChannel(10) d_linear = DynamicLinear(in_features=10, out_features=10, bias=bias) @@ -77,8 +76,7 @@ def test_dynamic_linear_mutable_single_features( is_mutate_in_features: Optional[bool], in_features: int, out_features: int) -> None: d_linear = DynamicLinear(in_features=10, out_features=10, bias=True) - mutable_channels = OneShotMutableChannel( - 10, candidate_choices=[4, 6, 10], candidate_mode='number') + mutable_channels = SquentialMutableChannel(10) if is_mutate_in_features is not None: if is_mutate_in_features: diff --git a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_norm.py b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_norm.py index e6cfe103f..ce6ae7b36 100644 --- a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_norm.py +++ b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_norm.py @@ -6,9 +6,11 @@ import torch from torch import nn -from mmrazor.models.architectures.dynamic_ops.bricks import ( - DynamicBatchNorm1d, DynamicBatchNorm2d, DynamicBatchNorm3d, DynamicMixin) -from mmrazor.models.mutables import OneShotMutableChannel +from mmrazor.models.architectures.dynamic_ops import (DynamicBatchNorm1d, + DynamicBatchNorm2d, + DynamicBatchNorm3d, + DynamicMixin) +from mmrazor.models.mutables import SquentialMutableChannel from mmrazor.structures.subnet import export_fix_subnet, load_fix_subnet from ..utils import fix_dynamic_op @@ -22,8 +24,7 @@ def test_dynamic_bn(dynamic_class: Type[nn.modules.batchnorm._BatchNorm], input_shape: Tuple[int], affine: bool, track_running_stats: 
bool) -> None: - mutable_num_features = OneShotMutableChannel( - 10, candidate_choices=[4, 8, 10], candidate_mode='number') + mutable_num_features = SquentialMutableChannel(10) d_bn = dynamic_class( num_features=10, @@ -87,8 +88,7 @@ def test_bn_track_running_stats( dynamic_class: Type[nn.modules.batchnorm._BatchNorm], input_shape: Tuple[int], ) -> None: - mutable_num_features = OneShotMutableChannel( - 10, candidate_choices=[4, 8, 10], candidate_mode='number') + mutable_num_features = SquentialMutableChannel(10) mutable_num_features.current_choice = 8 d_bn = dynamic_class( num_features=10, track_running_stats=True, affine=False) diff --git a/tests/test_models/test_architectures/test_dynamic_op/test_default_dynamic_op.py b/tests/test_models/test_architectures/test_dynamic_op/test_default_dynamic_op.py deleted file mode 100644 index 97277d03a..000000000 --- a/tests/test_models/test_architectures/test_dynamic_op/test_default_dynamic_op.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -"""from unittest import TestCase. - -import pytest -import torch - -from mmrazor.models.architectures import DynamicConv2d -from mmrazor.structures import export_fix_subnet, load_fix_subnet -from .utils import fix_dynamic_op - -class TestDefaultDynamicOP(TestCase): - - def test_dynamic_conv2d(self) -> None: - in_channels_cfg = dict(type='SlimmableMutableChannel', num_channels=4) - out_channels_cfg = dict( - type='SlimmableMutableChannel', num_channels=10) - - d_conv2d = DynamicConv2d( - in_channels_cfg, - out_channels_cfg, - in_channels=4, - out_channels=10, - kernel_size=3, - stride=1, - bias=True) - - d_conv2d.mutable_in.candidate_choices = [2, 3, 4] - d_conv2d.mutable_out.candidate_choices = [4, 8, 10] - - with pytest.raises(AssertionError): - d_conv2d.to_static_op() - - d_conv2d.mutable_in.current_choice = 1 - d_conv2d.mutable_out.current_choice = 0 - - x = torch.rand(10, 3, 224, 224) - out1 = d_conv2d(x) - self.assertEqual(out1.size(1), 4) - - fix_mutables = export_fix_subnet(d_conv2d) - with pytest.raises(RuntimeError): - load_fix_subnet(d_conv2d, fix_mutables) - fix_dynamic_op(d_conv2d, fix_mutables) - - out2 = d_conv2d(x) - self.assertTrue(torch.equal(out1, out2)) - - s_conv2d = d_conv2d.to_static_op() - out3 = s_conv2d(x) - - self.assertTrue(torch.equal(out1, out3)) - - def test_dynamic_conv2d_depthwise(self) -> None: - in_channels_cfg = dict(type='SlimmableMutableChannel', num_channels=10) - out_channels_cfg = dict( - type='SlimmableMutableChannel', num_channels=10) - - d_conv2d = DynamicConv2d( - in_channels_cfg, - out_channels_cfg, - in_channels=10, - out_channels=10, - groups=10, - kernel_size=3, - stride=1, - bias=True) - - d_conv2d.mutable_in.candidate_choices = [4, 8, 10] - d_conv2d.mutable_out.candidate_choices = [4, 8, 10] - - with pytest.raises(AssertionError): - d_conv2d.to_static_op() - - d_conv2d.mutable_in.current_choice = 1 - d_conv2d.mutable_out.current_choice = 1 - - x = torch.rand(10, 8, 224, 224) - out1 = d_conv2d(x) - self.assertEqual(out1.size(1), 8) - - fix_mutables = export_fix_subnet(d_conv2d) - with pytest.raises(RuntimeError): - load_fix_subnet(d_conv2d, fix_mutables) - fix_dynamic_op(d_conv2d, fix_mutables) - - out2 = d_conv2d(x) - self.assertTrue(torch.equal(out1, out2)) - - s_conv2d = d_conv2d.to_static_op() - out3 = s_conv2d(x) - - self.assertTrue(torch.equal(out1, out3)) -""" diff --git a/tests/test_models/test_architectures/test_dynamic_op/utils.py b/tests/test_models/test_architectures/test_dynamic_op/utils.py index 506fe1b5a..ceb2a5d4f 
100644 --- a/tests/test_models/test_architectures/test_dynamic_op/utils.py +++ b/tests/test_models/test_architectures/test_dynamic_op/utils.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Dict, Optional -from mmrazor.models.architectures.dynamic_ops import DynamicOP +from mmrazor.models.architectures.dynamic_ops import DynamicMixin -def fix_dynamic_op(op: DynamicOP, fix_mutables: Optional[Dict] = None) -> None: +def fix_dynamic_op(op: DynamicMixin, + fix_mutables: Optional[Dict] = None) -> None: for name, mutable in op.mutable_attrs.items(): if fix_mutables is not None: diff --git a/tests/test_models/test_mutables/group/test_mutable_channel_groups.py b/tests/test_models/test_mutables/group/test_mutable_channel_groups.py index fbc61cddc..234b98d91 100644 --- a/tests/test_models/test_mutables/group/test_mutable_channel_groups.py +++ b/tests/test_models/test_mutables/group/test_mutable_channel_groups.py @@ -5,8 +5,7 @@ import torch import torch.nn as nn -from mmrazor.models.architectures.dynamic_ops.bricks.dynamic_mixins import \ - DynamicChannelMixin +from mmrazor.models.architectures.dynamic_ops.mixins import DynamicChannelMixin from mmrazor.models.mutables.mutable_channel import ( L1MutableChannelGroup, MutableChannelGroup, SequentialMutableChannelGroup) from mmrazor.models.mutables.mutable_channel.groups.channel_group import ( # noqa diff --git a/tests/test_models/test_mutables/test_channel_mutable.py b/tests/test_models/test_mutables/test_channel_mutable.py deleted file mode 100644 index fd808351c..000000000 --- a/tests/test_models/test_mutables/test_channel_mutable.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -from unittest import TestCase - -import pytest -import torch - -from mmrazor.models import OneShotMutableChannel - - -class TestChannelMutables(TestCase): - - def test_mutable_channel_ratio(self): - with pytest.raises(AssertionError): - # Test invalid `candidate_mode` - OneShotMutableChannel( - num_channels=8, - candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='xxx') - - with pytest.raises(AssertionError): - # Number of candidate choices must be greater than 0 - OneShotMutableChannel( - num_channels=8, - candidate_choices=list(), - candidate_mode='ratio') - - with pytest.raises(AssertionError): - # The candidate ratio should be in range(0, 1]. - OneShotMutableChannel( - num_channels=8, - candidate_choices=[0., 1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='ratio') - - with pytest.raises(AssertionError): - # Minimum number of channels should be a positive integer. 
- out_mutable = OneShotMutableChannel( - num_channels=8, - candidate_choices=[0.01, 1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='ratio') - out_mutable.bind_mutable_name('op') - _ = out_mutable.min_choice - - # Test mutable out - out_mutable = OneShotMutableChannel( - num_channels=8, - candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='ratio') - - random_choice = out_mutable.sample_choice() - assert random_choice in [2, 4, 6, 8] - - max_choice = out_mutable.max_choice - assert max_choice == 8 - out_mutable.current_choice = max_choice - assert torch.equal(out_mutable.current_mask, - torch.ones_like(out_mutable.current_mask).bool()) - - min_choice = out_mutable.min_choice - assert min_choice == 2 - out_mutable.current_choice = min_choice - min_mask = torch.zeros_like(out_mutable.current_mask).bool() - min_mask[:2] = True - assert torch.equal(out_mutable.current_mask, min_mask) - - # Test mutable in with concat_mutable - in_mutable = OneShotMutableChannel( - num_channels=16, - candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='ratio') - out_mutable1 = copy.deepcopy(out_mutable) - out_mutable2 = copy.deepcopy(out_mutable) - in_mutable.register_same_mutable([out_mutable1, out_mutable2]) - choice1 = out_mutable1.sample_choice() - out_mutable1.current_choice = choice1 - choice2 = out_mutable2.sample_choice() - out_mutable2.current_choice = choice2 - assert torch.equal( - in_mutable.current_mask, - torch.cat([out_mutable1.current_mask, out_mutable2.current_mask])) - - with pytest.raises(AssertionError): - # The mask of this in_mutable depends on the out mask of its - # `concat_mutables`, so the `sample_choice` method should not - # be called - in_mutable.sample_choice() - - with pytest.raises(AssertionError): - # The mask of this in_mutable depends on the out mask of its - # `concat_mutables`, so the `min_choice` property should not - # be called - _ = in_mutable.min_choice - - def test_mutable_channel_number(self): - with pytest.raises(AssertionError): - # The candidate ratio should be in range(0, `num_channels`]. - OneShotMutableChannel( - num_channels=8, - candidate_choices=[0, 2, 4, 6, 8], - candidate_mode='number') - - with pytest.raises(AssertionError): - # Type of `candidate_choices` should be int. 
- OneShotMutableChannel( - num_channels=8, - candidate_choices=[0., 2, 4, 6, 8], - candidate_mode='number') - - # Test mutable out - out_mutable = OneShotMutableChannel( - num_channels=8, - candidate_choices=[2, 4, 6, 8], - candidate_mode='number') - - random_choice = out_mutable.sample_choice() - assert random_choice in [2, 4, 6, 8] - - max_choice = out_mutable.max_choice - assert max_choice == 8 - out_mutable.current_choice = max_choice - assert torch.equal(out_mutable.current_mask, - torch.ones_like(out_mutable.current_mask).bool()) - - min_choice = out_mutable.min_choice - assert min_choice == 2 - out_mutable.current_choice = min_choice - min_mask = torch.zeros_like(out_mutable.current_mask).bool() - min_mask[:2] = True - assert torch.equal(out_mutable.current_mask, min_mask) diff --git a/tests/test_models/test_mutables/test_derived_mutable.py b/tests/test_models/test_mutables/test_derived_mutable.py index 99da8dc71..3e87b0654 100644 --- a/tests/test_models/test_mutables/test_derived_mutable.py +++ b/tests/test_models/test_mutables/test_derived_mutable.py @@ -4,18 +4,15 @@ import pytest import torch -from mmrazor.models.mutables import (DerivedMutable, OneShotMutableChannel, - OneShotMutableValue) +from mmrazor.models.mutables import (DerivedMutable, OneShotMutableValue, + SquentialMutableChannel) from mmrazor.models.mutables.base_mutable import BaseMutable class TestDerivedMutable(TestCase): def test_is_fixed(self) -> None: - mc = OneShotMutableChannel( - num_channels=10, - candidate_choices=[2, 8, 10], - candidate_mode='number') + mc = SquentialMutableChannel(num_channels=10) mc.current_choice = 2 mv = OneShotMutableValue(value_list=[2, 3, 4]) @@ -46,10 +43,7 @@ def test_fix_dump_chosen(self) -> None: derived_mutable.fix_chosen(derived_mutable.dump_chosen()) def test_derived_same_mutable(self) -> None: - mc = OneShotMutableChannel( - num_channels=3, - candidate_choices=[1, 2, 3], - candidate_mode='number') + mc = SquentialMutableChannel(num_channels=3) mc_derived = mc.derive_same_mutable() assert mc_derived.source_mutables == {mc} @@ -59,10 +53,8 @@ def test_derived_same_mutable(self) -> None: torch.tensor([1, 1, 0], dtype=torch.bool)) def test_mutable_concat_derived(self) -> None: - mc1 = OneShotMutableChannel( - num_channels=3, candidate_choices=[1, 3], candidate_mode='number') - mc2 = OneShotMutableChannel( - num_channels=4, candidate_choices=[1, 4], candidate_mode='number') + mc1 = SquentialMutableChannel(num_channels=3) + mc2 = SquentialMutableChannel(num_channels=4) ms = [mc1, mc2] mc_derived = DerivedMutable.derive_concat_mutable(ms) @@ -88,10 +80,7 @@ def test_mutable_concat_derived(self) -> None: _ = DerivedMutable.derive_concat_mutable(ms) def test_mutable_channel_derived(self) -> None: - mc = OneShotMutableChannel( - num_channels=3, - candidate_choices=[1, 2, 3], - candidate_mode='number') + mc = SquentialMutableChannel(num_channels=3) mc_derived = mc * 3 assert mc_derived.source_mutables == {mc} @@ -112,10 +101,7 @@ def test_mutable_channel_derived(self) -> None: mc_derived.current_mask.size()) def test_mutable_divide(self) -> None: - mc = OneShotMutableChannel( - num_channels=128, - candidate_choices=[112, 120, 128], - candidate_mode='number') + mc = SquentialMutableChannel(num_channels=128) mc_derived = mc // 8 assert mc_derived.source_mutables == {mc} @@ -138,14 +124,15 @@ def test_mutable_divide(self) -> None: assert mv_derived.current_choice == 16 def test_source_mutables(self) -> None: - useless_fn = lambda x: x # noqa: E731 + + def useless_fn(x): + return x # noqa: 
E731 + with pytest.raises(RuntimeError): _ = DerivedMutable(choice_fn=useless_fn) - mc1 = OneShotMutableChannel( - num_channels=3, candidate_choices=[1, 3], candidate_mode='number') - mc2 = OneShotMutableChannel( - num_channels=4, candidate_choices=[1, 4], candidate_mode='number') + mc1 = SquentialMutableChannel(num_channels=3) + mc2 = SquentialMutableChannel(num_channels=4) ms = [mc1, mc2] mc_derived1 = DerivedMutable.derive_concat_mutable(ms) @@ -180,8 +167,7 @@ def fn(): mask_fn=dict_closure_fn({2: [mc1, mc2]}, {3: dd_mutable})) assert ddd_mutable.source_mutables == mc_derived1.source_mutables - mc3 = OneShotMutableChannel( - num_channels=4, candidate_choices=[2, 4], candidate_mode='number') + mc3 = SquentialMutableChannel(num_channels=4) dddd_mutable = DerivedMutable( choice_fn=dict_closure_fn({ mc1: [2, 3], @@ -191,10 +177,8 @@ def fn(): assert dddd_mutable.source_mutables == {mc1, mc2, mc3} def test_nested_mutables(self) -> None: - source_a = OneShotMutableChannel( - num_channels=2, candidate_choices=[1, 2], candidate_mode='number') - source_b = OneShotMutableChannel( - num_channels=3, candidate_choices=[2, 3], candidate_mode='number') + source_a = SquentialMutableChannel(num_channels=2) + source_b = SquentialMutableChannel(num_channels=3) # derive from derived_c = source_a * 1 diff --git a/tests/test_models/test_mutables/test_dynamic_layer.py b/tests/test_models/test_mutables/test_dynamic_layer.py deleted file mode 100644 index 6864e5edd..000000000 --- a/tests/test_models/test_mutables/test_dynamic_layer.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from unittest import TestCase - -import torch -from torch import nn - -from mmrazor.models.mutators.utils import (dynamic_bn_converter, - dynamic_conv2d_converter, - dynamic_gn_converter, - dynamic_in_converter, - dynamic_linear_converter) - - -class TestDynamicLayer(TestCase): - - def test_dynamic_conv(self): - imgs = torch.rand(2, 8, 16, 16) - - in_channels_cfg = dict( - type='OneShotMutableChannel', - candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='ratio') - - out_channels_cfg = dict( - type='OneShotMutableChannel', - candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='ratio') - - conv = nn.Conv2d(8, 8, 1) - dynamic_conv = dynamic_conv2d_converter(conv, in_channels_cfg, - out_channels_cfg) - # test forward - dynamic_conv(imgs) - - conv = nn.Conv2d(8, 8, 1, groups=8) - dynamic_conv = dynamic_conv2d_converter(conv, in_channels_cfg, - out_channels_cfg) - # test forward - dynamic_conv(imgs) - - conv = nn.Conv2d(8, 8, 1, groups=4) - dynamic_conv = dynamic_conv2d_converter(conv, in_channels_cfg, - out_channels_cfg) - # test forward - with self.assertRaisesRegex(NotImplementedError, - 'only support pruning the depth-wise'): - dynamic_conv(imgs) - - def test_dynamic_linear(self): - imgs = torch.rand(2, 8) - - in_features_cfg = dict( - type='OneShotMutableChannel', - candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='ratio') - - out_features_cfg = dict( - type='OneShotMutableChannel', - candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='ratio') - - linear = nn.Linear(8, 8) - dynamic_linear = dynamic_linear_converter(linear, in_features_cfg, - out_features_cfg) - # test forward - dynamic_linear(imgs) - - def test_dynamic_batchnorm(self): - imgs = torch.rand(2, 8, 16, 16) - - num_features_cfg = dict( - type='OneShotMutableChannel', - candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='ratio') - - bn = nn.BatchNorm2d(8) - dynamic_bn = 
dynamic_bn_converter(bn, num_features_cfg) - # test forward - dynamic_bn(imgs) - - bn = nn.BatchNorm2d(8, momentum=0) - dynamic_bn = dynamic_bn_converter(bn, num_features_cfg) - # test forward - dynamic_bn(imgs) - - bn = nn.BatchNorm2d(8) - bn.train() - dynamic_bn = dynamic_bn_converter(bn, num_features_cfg) - # test forward - dynamic_bn(imgs) - # test num_batches_tracked is not None - dynamic_bn(imgs) - - bn = nn.BatchNorm2d(8, affine=False) - dynamic_bn = dynamic_bn_converter(bn, num_features_cfg) - # test forward - dynamic_bn(imgs) - - bn = nn.BatchNorm2d(8, track_running_stats=False) - dynamic_bn = dynamic_bn_converter(bn, num_features_cfg) - # test forward - dynamic_bn(imgs) - - def test_dynamic_instancenorm(self): - imgs = torch.rand(2, 8, 16, 16) - - num_features_cfg = dict( - type='OneShotMutableChannel', - candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='ratio') - - instance_norm = nn.InstanceNorm2d(8) - dynamic_in = dynamic_in_converter(instance_norm, num_features_cfg) - # test forward - dynamic_in(imgs) - - instance_norm = nn.InstanceNorm2d(8, affine=False) - dynamic_in = dynamic_in_converter(instance_norm, num_features_cfg) - # test forward - dynamic_in(imgs) - - instance_norm = nn.InstanceNorm2d(8, track_running_stats=False) - dynamic_in = dynamic_in_converter(instance_norm, num_features_cfg) - # test forward - dynamic_in(imgs) - - def test_dynamic_groupnorm(self): - imgs = torch.rand(2, 8, 16, 16) - - num_channels_cfg = dict( - type='OneShotMutableChannel', - candidate_choices=[1 / 4, 2 / 4, 3 / 4, 1.0], - candidate_mode='ratio') - - gn = nn.GroupNorm(num_groups=4, num_channels=8) - dynamic_gn = dynamic_gn_converter(gn, num_channels_cfg) - # test forward - dynamic_gn(imgs) - - gn = nn.GroupNorm(num_groups=4, num_channels=8, affine=False) - dynamic_gn = dynamic_gn_converter(gn, num_channels_cfg) - # test forward - dynamic_gn(imgs) diff --git a/tests/test_models/test_mutables/test_mutable_value.py b/tests/test_models/test_mutables/test_mutable_value.py index 0b5ed7947..d7d05b1d5 100644 --- a/tests/test_models/test_mutables/test_mutable_value.py +++ b/tests/test_models/test_mutables/test_mutable_value.py @@ -5,8 +5,8 @@ import pytest import torch -from mmrazor.models.mutables import (MutableValue, OneShotMutableChannel, - OneShotMutableValue) +from mmrazor.models.mutables import (MutableValue, OneShotMutableValue, + SquentialMutableChannel) class TestMutableValue(TestCase): @@ -87,8 +87,7 @@ def test_mul(self) -> None: _ = mv * 1.2 mv = MutableValue(value_list=[1, 2, 3], default_value=3) - mc = OneShotMutableChannel( - num_channels=4, candidate_choices=[2, 4], candidate_mode='number') + mc = SquentialMutableChannel(num_channels=4) with pytest.raises(TypeError): _ = mc * mv diff --git a/tests/test_models/test_mutables/test_sequential_mutable_channel.py b/tests/test_models/test_mutables/test_sequential_mutable_channel.py new file mode 100644 index 000000000..f7f4bb91e --- /dev/null +++ b/tests/test_models/test_mutables/test_sequential_mutable_channel.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from unittest import TestCase + +from mmrazor.models.mutables import SquentialMutableChannel + + +class TestSquentialMutableChannel(TestCase): + + def test_mul_float(self): + channel = SquentialMutableChannel(10) + new_channel = channel * 0.5 + self.assertEqual(new_channel.current_choice, 5) + channel.current_choice = 5 + self.assertEqual(new_channel.current_choice, 2) diff --git a/tests/test_models/test_mutators/test_classical_models/_test_mbv2_channel_mutator.py b/tests/test_models/test_mutators/test_classical_models/_test_mbv2_channel_mutator.py deleted file mode 100644 index 4f05b19a6..000000000 --- a/tests/test_models/test_mutators/test_classical_models/_test_mbv2_channel_mutator.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os -import unittest -from os.path import dirname - -import torch -from mmcls.models import * # noqa: F401,F403 -from mmcls.structures import ClsDataSample - -from mmrazor import digit_version -from mmrazor.models.mutables import SlimmableMutableChannel -from mmrazor.models.mutators import (OneShotChannelMutator, - SlimmableChannelMutator) -from mmrazor.registry import MODELS -from ..utils import load_and_merge_channel_cfgs - -MODEL_CFG = dict( - type='mmcls.ImageClassifier', - backbone=dict(type='MobileNetV2', widen_factor=1.5), - neck=dict(type='GlobalAveragePooling'), - head=dict( - type='LinearClsHead', - num_classes=1000, - in_channels=1920, - loss=dict(type='CrossEntropyLoss', loss_weight=1.0), - topk=(1, 5))) - -ONESHOT_MUTATOR_CFG = dict( - type='OneShotChannelMutator', - skip_prefixes=['head.fc'], - parse_cfg=dict( - type='BackwardTracer', - loss_calculator=dict(type='ImageClassifierPseudoLoss')), - mutable_cfg=dict( - type='OneShotMutableChannel', - candidate_choices=[ - 1 / 8, 2 / 8, 3 / 8, 4 / 8, 5 / 8, 6 / 8, 7 / 8, 1.0 - ], - candidate_mode='ratio')) - - -@unittest.skipIf( - digit_version(torch.__version__) == digit_version('1.8.1'), - 'PyTorch version 1.8.1 is not supported by the Backward Tracer.') -def test_oneshot_channel_mutator() -> None: - imgs = torch.randn(16, 3, 224, 224) - data_samples = [ - ClsDataSample().set_gt_label(torch.randint(0, 1000, (16, ))) - ] - - model = MODELS.build(MODEL_CFG) - mutator: OneShotChannelMutator = MODELS.build(ONESHOT_MUTATOR_CFG) - - mutator.prepare_from_supernet(model) - assert hasattr(mutator, 'name2module') - - # test set_min_choices - mutator.set_min_choices() - for mutables in mutator.search_groups.values(): - for mutable in mutables: - # 1 / 8 is the minimum candidate ratio - assert mutable.current_choice == round(1 / 8 * - mutable.num_channels) - - # test set_max_channel - mutator.set_max_choices() - for mutables in mutator.search_groups.values(): - for mutable in mutables: - # 1.0 is the maximum candidate ratio - assert mutable.current_choice == round(1. 
* mutable.num_channels) - - # test making groups logic - choice_dict = mutator.sample_choices() - assert isinstance(choice_dict, dict) - mutator.set_choices(choice_dict) - model(imgs, data_samples=data_samples, mode='loss') - - -def test_slimmable_channel_mutator() -> None: - imgs = torch.randn(16, 3, 224, 224) - data_samples = [ - ClsDataSample().set_gt_label(torch.randint(0, 1000, (16, ))) - ] - - root_path = dirname(dirname(dirname(dirname(__file__)))) - channel_cfg_paths = [ - os.path.join(root_path, 'data/MBV2_320M.yaml'), - os.path.join(root_path, 'data/MBV2_220M.yaml') - ] - - mutator = SlimmableChannelMutator( - mutable_cfg=dict(type='SlimmableMutableChannel'), - channel_cfgs=load_and_merge_channel_cfgs(channel_cfg_paths), - parse_cfg=dict( - type='BackwardTracer', - loss_calculator=dict(type='ImageClassifierPseudoLoss'))) - - model = MODELS.build(MODEL_CFG) - mutator.prepare_from_supernet(model) - mutator.switch_choices(0) - for name, module in model.named_modules(): - if isinstance(module, SlimmableMutableChannel): - assert module.current_choice == 0 - model(imgs, data_samples=data_samples, mode='loss') - - mutator.switch_choices(1) - for name, module in model.named_modules(): - if isinstance(module, SlimmableMutableChannel): - assert module.current_choice == 1 - model(imgs, data_samples=data_samples, mode='loss') diff --git a/tests/test_models/test_mutators/test_one_shot_mutator.py b/tests/test_models/test_mutators/test_one_shot_mutator.py deleted file mode 100644 index 41921a2be..000000000 --- a/tests/test_models/test_mutators/test_one_shot_mutator.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy - -import pytest -from mmcls.models import * # noqa: F401,F403 -from torch import Tensor -from torch.nn import Module - -from mmrazor.models import OneShotModuleMutator, OneShotMutableModule -from mmrazor.registry import MODELS - -MODEL_CFG = dict( - type='mmcls.ImageClassifier', - backbone=dict( - type='mmcls.ResNet', - depth=50, - num_stages=4, - out_indices=(3, ), - style='pytorch'), - neck=dict(type='mmcls.GlobalAveragePooling'), - head=dict( - type='mmcls.LinearClsHead', - num_classes=1000, - in_channels=2048, - loss=dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0), - topk=(1, 5), - )) - -MUTATOR_CFG = dict(type='OneShotModuleMutator') - -MUTABLE_CFG = dict( - type='OneShotMutableOP', - candidates=dict( - choice1=dict( - type='MBBlock', - in_channels=3, - out_channels=3, - expand_ratio=1, - kernel_size=3), - choice2=dict( - type='MBBlock', - in_channels=3, - out_channels=3, - expand_ratio=1, - kernel_size=5), - choice3=dict( - type='MBBlock', - in_channels=3, - out_channels=3, - expand_ratio=1, - kernel_size=7))) - - -def test_one_shot_mutator_normal_model() -> None: - model = MODELS.build(MODEL_CFG) - mutator: OneShotModuleMutator = MODELS.build(MUTATOR_CFG) - - assert mutator.mutable_class_type == OneShotMutableModule - - with pytest.raises(RuntimeError): - _ = mutator.search_groups - - mutator.prepare_from_supernet(model) - assert len(mutator.search_groups) == 0 - assert len(mutator.sample_choices()) == 0 - - -class _SearchableModel(Module): - - def __init__(self) -> None: - super().__init__() - - self.op1 = MODELS.build(MUTABLE_CFG) - self.op2 = MODELS.build(MUTABLE_CFG) - self.op3 = MODELS.build(MUTABLE_CFG) - - def forward(self, x: Tensor) -> Tensor: - x = self.op1(x) - x = self.op2(x) - x = self.op3(x) - - return x - - -def test_one_shot_mutator_mutable_model() -> None: - model = _SearchableModel() - mutator: 
OneShotModuleMutator = MODELS.build(MUTATOR_CFG) - - mutator.prepare_from_supernet(model) - assert list(mutator.search_groups.keys()) == [0, 1, 2] - - random_choices = mutator.sample_choices() - assert list(random_choices.keys()) == [0, 1, 2] - for choice in random_choices.values(): - assert choice in ['choice1', 'choice2', 'choice3'] - - custom_group = [['op1', 'op2'], ['op3']] - mutator_cfg = copy.deepcopy(MUTATOR_CFG) - mutator_cfg.update({'custom_group': custom_group}) - mutator = MODELS.build(mutator_cfg) - - mutator.prepare_from_supernet(model) - assert list(mutator.search_groups.keys()) == [0, 1] - - random_choices = mutator.sample_choices() - assert list(random_choices.keys()) == [0, 1] - for choice in random_choices.values(): - assert choice in ['choice1', 'choice2', 'choice3'] - - mutator.set_choices(random_choices) - - custom_group.append(['op4']) - mutator_cfg = copy.deepcopy(MUTATOR_CFG) - mutator_cfg.update({'custom_group': custom_group}) - mutator = MODELS.build(mutator_cfg) - with pytest.raises(AssertionError): - mutator.prepare_from_supernet(model) - - -if __name__ == '__main__': - pytest.main() diff --git a/tests/test_models/test_mutators/utils.py b/tests/test_models/test_mutators/utils.py deleted file mode 100644 index 7ddede648..000000000 --- a/tests/test_models/test_mutators/utils.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List - -from mmengine import fileio - -from mmrazor.models.algorithms import SlimmableNetwork - - -def load_and_merge_channel_cfgs(channel_cfg_paths: List[str]) -> Dict: - channel_cfgs = list() - for channel_cfg_path in channel_cfg_paths: - channel_cfg = fileio.load(channel_cfg_path) - channel_cfgs.append(channel_cfg) - - return SlimmableNetwork.merge_channel_cfgs(channel_cfgs)
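For reference, the derived-mutable behaviour that the rewritten tests rely on can be summarised in a short sketch. The semantics (a derived mutable's `current_choice` tracks its source mutable, `*` and `//` create scaled derivations, and `DerivedMutable.derive_concat_mutable` stacks channels) are taken from `test_derived_mutable.py` and the new `test_sequential_mutable_channel.py`; the snippet itself is an illustration, not code from the patch:

```python
from mmrazor.models.mutables import DerivedMutable, SquentialMutableChannel

src = SquentialMutableChannel(10)

# Ratio-derived mutable, as in TestSquentialMutableChannel.test_mul_float.
half = src * 0.5
assert half.current_choice == 5
src.current_choice = 5
assert half.current_choice == 2

# Expand / divide derivations, as in test_mutable_channel_derived and
# test_mutable_divide (expansion factor 3, divisor 8).
expanded = SquentialMutableChannel(3) * 3
divided = SquentialMutableChannel(128) // 8

# Concatenation over two sources, as in test_mutable_concat_derived.
concat = DerivedMutable.derive_concat_mutable(
    [SquentialMutableChannel(3), SquentialMutableChannel(4)])
```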