diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 53a184a3d..e00ed24c8 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -31,6 +31,44 @@ jobs:
         python-version: [3.7]
         torch: [1.6.0, 1.7.0, 1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0]
         include:
+          - torch: 1.6.0
+            torch_version: 1.6
+            torchvision: 0.7.0
+          - torch: 1.7.0
+            torch_version: 1.7
+            torchvision: 0.8.1
+          - torch: 1.7.0
+            torch_version: 1.7
+            torchvision: 0.8.1
+            python-version: 3.8
+          - torch: 1.8.0
+            torch_version: 1.8
+            torchvision: 0.9.0
+          - torch: 1.8.0
+            torch_version: 1.8
+            torchvision: 0.9.0
+            python-version: 3.8
+          - torch: 1.9.0
+            torch_version: 1.9
+            torchvision: 0.10.0
+          - torch: 1.9.0
+            torch_version: 1.9
+            torchvision: 0.10.0
+            python-version: 3.8
+          - torch: 1.10.0
+            torch_version: 1.10
+            torchvision: 0.11.0
+          - torch: 1.10.0
+            torch_version: 1.10
+            torchvision: 0.11.0
+            python-version: 3.8
+          - torch: 1.11.0
+            torch_version: 1.11
+            torchvision: 0.12.0
+          - torch: 1.11.0
+            torch_version: 1.11
+            torchvision: 0.12.0
+            python-version: 3.8
           - torch: 1.12.0
             torch_version: 1.12
             torchvision: 0.13.0
diff --git a/mmrazor/engine/__init__.py b/mmrazor/engine/__init__.py
index da6cec34d..603aa3d77 100644
--- a/mmrazor/engine/__init__.py
+++ b/mmrazor/engine/__init__.py
@@ -4,15 +4,14 @@
 from .optimizers import SeparateOptimWrapperConstructor
 from .runner import (AutoSlimGreedySearchLoop, DartsEpochBasedTrainLoop,
                      DartsIterBasedTrainLoop, EvolutionSearchLoop,
-                     GreedySamplerTrainLoop, SelfDistillValLoop,
-                     SingleTeacherDistillValLoop, SlimmableValLoop,
-                     SubnetValLoop)
+                     GreedySamplerTrainLoop, PTQLoop, QATEpochBasedLoop,
+                     SelfDistillValLoop, SingleTeacherDistillValLoop,
+                     SlimmableValLoop, SubnetValLoop)

 __all__ = [
     'SeparateOptimWrapperConstructor', 'DumpSubnetHook',
     'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
     'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
     'GreedySamplerTrainLoop', 'EstimateResourcesHook', 'SelfDistillValLoop',
-    'AutoSlimGreedySearchLoop', 'SubnetValLoop', 'StopDistillHook',
-    'DMCPSubnetHook'
+    'AutoSlimGreedySearchLoop', 'SubnetValLoop', 'PTQLoop', 'QATEpochBasedLoop'
 ]
diff --git a/mmrazor/engine/runner/__init__.py b/mmrazor/engine/runner/__init__.py
index 647d8b410..2ca6c0dbb 100644
--- a/mmrazor/engine/runner/__init__.py
+++ b/mmrazor/engine/runner/__init__.py
@@ -13,6 +13,6 @@
     'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', 'GreedySamplerTrainLoop', 'SubnetValLoop', 'SelfDistillValLoop', - 'ItePruneValLoop', 'AutoSlimGreedySearchLoop', 'PTQLoop', - 'QATEpochBasedLoop' + 'ItePruneValLoop', 'AutoSlimGreedySearchLoop', 'QATEpochBasedLoop', + 'PTQLoop' ] diff --git a/mmrazor/engine/runner/iteprune_val_loop.py b/mmrazor/engine/runner/iteprune_val_loop.py index bbca5d53a..2a627f398 100644 --- a/mmrazor/engine/runner/iteprune_val_loop.py +++ b/mmrazor/engine/runner/iteprune_val_loop.py @@ -52,7 +52,6 @@ def _save_fix_subnet(self): file.write(fix_subnet) torch.save({'state_dict': static_model.state_dict()}, osp.join(self.runner.work_dir, weight_name)) - self.runner.logger.info( 'export finished and ' f'{subnet_name}, ' diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py index 2a0aa812f..e90715910 100644 --- a/mmrazor/engine/runner/quantization_loops.py +++ b/mmrazor/engine/runner/quantization_loops.py @@ -4,9 +4,18 @@ import torch from mmengine.evaluator import Evaluator from mmengine.runner import EpochBasedTrainLoop, TestLoop, ValLoop -from torch.ao.quantization import (disable_observer, enable_fake_quant, - enable_observer) -from torch.nn.intrinsic.qat import freeze_bn_stats + +try: + from torch.ao.quantization import (disable_observer, enable_fake_quant, + enable_observer) + from torch.nn.intrinsic.qat import freeze_bn_stats +except ImportError: + from mmrazor.utils import get_placeholder + disable_observer = get_placeholder('torch>=1.13') + enable_fake_quant = get_placeholder('torch>=1.13') + enable_observer = get_placeholder('torch>=1.13') + freeze_bn_stats = get_placeholder('torch>=1.13') + from torch.utils.data import DataLoader from mmrazor.registry import LOOPS diff --git a/mmrazor/models/algorithms/nas/autoslim.py b/mmrazor/models/algorithms/nas/autoslim.py index dc8d54c0e..77bb6cacc 100644 --- a/mmrazor/models/algorithms/nas/autoslim.py +++ b/mmrazor/models/algorithms/nas/autoslim.py @@ -75,6 +75,8 @@ def __init__(self, self._optim_wrapper_count_status_reinitialized = False self.norm_training = norm_training + self.bn_training_mode = bn_training_mode + def _build_mutator(self, mutator: VALID_MUTATOR_TYPE = None) -> ChannelMutator: """Build mutator.""" diff --git a/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py b/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py index 937aaa156..f510acd76 100644 --- a/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py +++ b/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py @@ -10,6 +10,7 @@ from mmrazor.models.mutables import MutableChannelUnit from mmrazor.models.mutators import ChannelMutator from mmrazor.registry import MODELS +from mmrazor.utils import ValidFixMutable from ..base import BaseAlgorithm LossResults = Dict[str, torch.Tensor] @@ -97,6 +98,8 @@ class ItePruneAlgorithm(BaseAlgorithm): mutator_cfg (Union[Dict, ChannelMutator], optional): The config of a mutator. Defaults to dict( type='ChannelMutator', channel_unit_cfg=dict( type='SequentialMutableChannelUnit')). + fix_subnet (str | dict | :obj:`FixSubnet`): The path of yaml file or + loaded dict or built :obj:`FixSubnet`. Defaults to None. data_preprocessor (Optional[Union[Dict, nn.Module]], optional): Defaults to None. target_pruning_ratio (dict, optional): The prune-target. 
The template @@ -118,6 +121,7 @@ def __init__(self, type='ChannelMutator', channel_unit_cfg=dict( type='SequentialMutableChannelUnit')), + fix_subnet: Optional[ValidFixMutable] = None, data_preprocessor: Optional[Union[Dict, nn.Module]] = None, target_pruning_ratio: Optional[Dict[str, float]] = None, step_freq=1, diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index c14aae08c..f5cf30f10 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -7,12 +7,17 @@ from mmengine.runner import load_checkpoint from mmengine.structures import BaseDataElement from torch import nn -from torch.ao.quantization import FakeQuantizeBase -from mmrazor.models.task_modules import build_graphmodule +from mmrazor.models.task_modules.tracer import build_graphmodule from mmrazor.registry import MODEL_WRAPPERS, MODELS from ..base import BaseAlgorithm +try: + from torch.ao.quantization import FakeQuantizeBase +except ImportError: + from mmrazor.utils import get_placeholder + FakeQuantizeBase = get_placeholder('torch>=1.13') + LossResults = Dict[str, torch.Tensor] TensorResults = Union[Tuple[torch.Tensor], torch.Tensor] PredictResults = List[BaseDataElement] diff --git a/mmrazor/models/fake_quants/base.py b/mmrazor/models/fake_quants/base.py index 1d4c6dfe0..45aed7421 100644 --- a/mmrazor/models/fake_quants/base.py +++ b/mmrazor/models/fake_quants/base.py @@ -1,4 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from torch.ao.quantization import FakeQuantize +try: + from torch.ao.quantization import FakeQuantize +except ImportError: + from mmrazor.utils import get_placeholder + FakeQuantize = get_placeholder('torch>=1.13') BaseFakeQuantize = FakeQuantize diff --git a/mmrazor/models/fake_quants/torch_fake_quants.py b/mmrazor/models/fake_quants/torch_fake_quants.py index ad1a0d966..b477929ad 100644 --- a/mmrazor/models/fake_quants/torch_fake_quants.py +++ b/mmrazor/models/fake_quants/torch_fake_quants.py @@ -2,10 +2,14 @@ import inspect from typing import List -import torch.ao.quantization.fake_quantize as torch_fake_quant_src - from mmrazor.registry import MODELS +try: + import torch.ao.quantization.fake_quantize as torch_fake_quant_src +except ImportError: + from mmrazor.utils import get_package_placeholder + torch_fake_quant_src = get_package_placeholder('torch>=1.13') + def register_torch_fake_quants() -> List[str]: """Register fake_quants in ``torch.ao.quantization.fake_quantize`` to the diff --git a/mmrazor/models/losses/__init__.py b/mmrazor/models/losses/__init__.py index 3509acd5c..65e2108fd 100644 --- a/mmrazor/models/losses/__init__.py +++ b/mmrazor/models/losses/__init__.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from .ab_loss import ABLoss -from .adaround_loss import AdaRoundLoss from .at_loss import ATLoss from .crd_loss import CRDLoss from .cross_entropy_loss import CrossEntropyLoss diff --git a/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py b/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py index cc008b0b8..3aca98c95 100644 --- a/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py +++ b/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py @@ -4,11 +4,13 @@ from mmrazor.models.mutables import OneShotMutableChannelUnit from mmrazor.registry import MODELS +from ..group_mixin import DynamicSampleMixin from .channel_mutator import ChannelMutator, ChannelUnitType @MODELS.register_module() -class OneShotChannelMutator(ChannelMutator[OneShotMutableChannelUnit]): +class OneShotChannelMutator(ChannelMutator[OneShotMutableChannelUnit], + DynamicSampleMixin): """OneShotChannelMutator based on ChannelMutator. It use OneShotMutableChannelUnit by default. diff --git a/mmrazor/models/mutators/group_mixin.py b/mmrazor/models/mutators/group_mixin.py index 569f01ebc..3ecd44b74 100644 --- a/mmrazor/models/mutators/group_mixin.py +++ b/mmrazor/models/mutators/group_mixin.py @@ -8,6 +8,11 @@ from mmrazor.models.mutables.mutable_module import MutableModule from .base_mutator import MUTABLE_TYPE +if sys.version_info < (3, 8): + from typing_extensions import Protocol +else: + from typing import Protocol + class GroupMixin(): """A mixin for :class:`BaseMutator`, which can group mutables by @@ -259,3 +264,66 @@ def _check_valid_groups(self, alias2mutable_names: Dict[str, List[str]], f'When a mutable is set alias attribute :{alias_key},' f'the corresponding module name {mutable_name} should ' f'not be used in `custom_group` {custom_group}.') + + +class MutatorProtocol(Protocol): # pragma: no cover + + @property + def mutable_class_type(self) -> Type[BaseMutable]: + ... + + @property + def search_groups(self) -> Dict: + ... 
+ + +class OneShotSampleMixin: + """Sample mixin for one-shot mutators.""" + + def sample_choices(self: MutatorProtocol) -> Dict: + """Sample choices for each group in search_groups.""" + random_choices = dict() + for group_id, modules in self.search_groups.items(): + random_choices[group_id] = modules[0].sample_choice() + + return random_choices + + def set_choices(self: MutatorProtocol, choices: Dict) -> None: + """Set choices for each group in search_groups.""" + for group_id, modules in self.search_groups.items(): + choice = choices[group_id] + for module in modules: + module.current_choice = choice + + +class DynamicSampleMixin(OneShotSampleMixin): + + def sample_choices(self: MutatorProtocol, kind: str = 'random') -> Dict: + """Sample choices for each group in search_groups.""" + random_choices = dict() + for group_id, modules in self.search_groups.items(): + if kind == 'max': + random_choices[group_id] = modules[0].max_choice + elif kind == 'min': + random_choices[group_id] = modules[0].min_choice + else: + random_choices[group_id] = modules[0].sample_choice() + return random_choices + + @property + def max_choice(self: MutatorProtocol) -> Dict: + """Get max choices for each group in search_groups.""" + max_choice = dict() + for group_id, modules in self.search_groups.items(): + max_choice[group_id] = modules[0].max_choice + + return max_choice + + @property + def min_choice(self: MutatorProtocol) -> Dict: + """Get min choices for each group in search_groups.""" + min_choice = dict() + for group_id, modules in self.search_groups.items(): + min_choice[group_id] = modules[0].min_choice + + return min_choice diff --git a/mmrazor/models/mutators/value_mutator/__init__.py b/mmrazor/models/mutators/value_mutator/__init__.py new file mode 100644 index 000000000..a29577bb1 --- /dev/null +++ b/mmrazor/models/mutators/value_mutator/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .dynamic_value_mutator import DynamicValueMutator +from .value_mutator import ValueMutator + +__all__ = ['ValueMutator', 'DynamicValueMutator'] diff --git a/mmrazor/models/mutators/value_mutator/dynamic_value_mutator.py b/mmrazor/models/mutators/value_mutator/dynamic_value_mutator.py new file mode 100644 index 000000000..d8d081343 --- /dev/null +++ b/mmrazor/models/mutators/value_mutator/dynamic_value_mutator.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmrazor.models.mutables import OneShotMutableValue +from mmrazor.registry import MODELS +from ..group_mixin import DynamicSampleMixin +from .value_mutator import ValueMutator + + +@MODELS.register_module() +class DynamicValueMutator(ValueMutator, DynamicSampleMixin): + """Dynamic value mutator with type as `OneShotMutableValue`.""" + + @property + def mutable_class_type(self): + return OneShotMutableValue diff --git a/mmrazor/models/mutators/value_mutator/value_mutator.py b/mmrazor/models/mutators/value_mutator/value_mutator.py new file mode 100644 index 000000000..5127cbe37 --- /dev/null +++ b/mmrazor/models/mutators/value_mutator/value_mutator.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Type + +from torch.nn import Module + +from mmrazor.models.mutables import MutableValue +from mmrazor.registry import MODELS +from ..base_mutator import BaseMutator +from ..group_mixin import GroupMixin + + +@MODELS.register_module() +class ValueMutator(BaseMutator[MutableValue], GroupMixin): + """The base class for mutable based mutator. 
All subclasses should implement
+    the following APIs:
+
+    - ``mutable_class_type``
+    Args:
+        custom_group (list[list[str]], optional): User-defined search groups.
+            All searchable modules that are not in ``custom_group`` will be
+            grouped separately.
+    """
+
+    def __init__(self,
+                 custom_group: Optional[List[List[str]]] = None,
+                 init_cfg: Optional[Dict] = None) -> None:
+        super().__init__(init_cfg)
+
+        if custom_group is None:
+            custom_group = []
+        self._custom_group = custom_group
+        self._search_groups: Optional[Dict[int, List[MutableValue]]] = None
+
+    # TODO
+    # should be a class property
+    @property
+    def mutable_class_type(self) -> Type[MutableValue]:
+        """Corresponding mutable class type.
+
+        Returns:
+            Type[MUTABLE_TYPE]: Mutable class type.
+        """
+        return MutableValue
+
+    def prepare_from_supernet(self, supernet: Module) -> None:
+        """Do some necessary preparations with the supernet.
+
+        Note:
+            For a mutable-based mutator, the search group must be built first.
+        Args:
+            supernet (:obj:`torch.nn.Module`): The supernet to be searched
+                in your algorithm.
+        """
+        self._search_groups = self.build_search_groups(supernet,
+                                                       self.mutable_class_type,
+                                                       self._custom_group)
+
+    @property
+    def search_groups(self) -> Dict[int, List[MutableValue]]:
+        """Search group of the supernet.
+
+        Note:
+            For a mutable-based mutator, the search group is composed of
+            the corresponding mutables.
+        Raises:
+            RuntimeError: Called before the search group has been built.
+        Returns:
+            Dict[int, List[MUTABLE_TYPE]]: Search group.
+        """
+        if self._search_groups is None:
+            raise RuntimeError(
+                'Call `prepare_from_supernet` before accessing search groups!')
+        return self._search_groups
diff --git a/mmrazor/models/observers/base.py b/mmrazor/models/observers/base.py
index a68410eb0..ce226cb48 100644
--- a/mmrazor/models/observers/base.py
+++ b/mmrazor/models/observers/base.py
@@ -1,4 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from torch.ao.quantization.observer import UniformQuantizationObserverBase
+try:
+    from torch.ao.quantization.observer import UniformQuantizationObserverBase
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    UniformQuantizationObserverBase = get_placeholder('torch>=1.13')

 BaseObserver = UniformQuantizationObserverBase
diff --git a/mmrazor/models/observers/torch_observers.py b/mmrazor/models/observers/torch_observers.py
index 8e0e81d58..5dc24609f 100644
--- a/mmrazor/models/observers/torch_observers.py
+++ b/mmrazor/models/observers/torch_observers.py
@@ -2,10 +2,14 @@
 import inspect
 from typing import List

-import torch.ao.quantization.observer as torch_observer_src
-
 from mmrazor.registry import MODELS

+try:
+    import torch.ao.quantization.observer as torch_observer_src
+except ImportError:
+    from mmrazor.utils import get_package_placeholder
+    torch_observer_src = get_package_placeholder('torch>=1.13')
+

 def register_torch_observers() -> List[str]:
     """Register observers in ``torch.ao.quantization.observer`` to the
diff --git a/mmrazor/models/quantizers/academic_quantizer.py b/mmrazor/models/quantizers/academic_quantizer.py
index 6a6500791..768f51c53 100644
--- a/mmrazor/models/quantizers/academic_quantizer.py
+++ b/mmrazor/models/quantizers/academic_quantizer.py
@@ -1,16 +1,26 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
-from torch.ao.quantization.fx import prepare
-from torch.ao.quantization.fx.custom_config import (FuseCustomConfig,
-                                                    PrepareCustomConfig)
-from torch.ao.quantization.qconfig_mapping import QConfigMapping
-from torch.ao.quantization.quant_type import _quant_type_from_str
-from torch.ao.quantization.quantize_fx import _fuse_fx

 from mmrazor.registry import MODELS
 from mmrazor.structures.quantization import BackendConfigs, QConfigHander
 from .base import BaseQuantizer

+try:
+    from torch.ao.quantization.fx import prepare
+    from torch.ao.quantization.fx.custom_config import (FuseCustomConfig,
+                                                        PrepareCustomConfig)
+    from torch.ao.quantization.qconfig_mapping import QConfigMapping
+    from torch.ao.quantization.quant_type import _quant_type_from_str
+    from torch.ao.quantization.quantize_fx import _fuse_fx
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    prepare = get_placeholder('torch>=1.13')
+    FuseCustomConfig = get_placeholder('torch>=1.13')
+    PrepareCustomConfig = get_placeholder('torch>=1.13')
+    QConfigMapping = get_placeholder('torch>=1.13')
+    _quant_type_from_str = get_placeholder('torch>=1.13')
+    _fuse_fx = get_placeholder('torch>=1.13')
+
 GLOBAL_DICT_KEY = '_global_'
 OBJECT_TYPE_DICT_KEY = 'object_type'
 MODULE_NAME_REGEX_DICT_KEY = 'module_name_regex'
@@ -23,6 +33,7 @@

 @MODELS.register_module()
 class AcademicQuantizer(BaseQuantizer):
+    """Quantizer for academic research, based on PyTorch FX graph mode."""

     def __init__(self,
                  qconfig_mapping,
@@ -37,6 +48,7 @@ def __init__(self,
         self.example_inputs = (torch.randn(1, 3, 224, 224), )

     def prepare(self, model, graph_module):
+        """Prepare the graph module for quantization."""
         preserved_attributes = self.prepare_custom_config.preserved_attributes
         for attr_name in preserved_attributes:
             setattr(graph_module, attr_name, getattr(model, attr_name))
@@ -60,6 +72,7 @@ def prepare(self, model, graph_module):
         return prepared

     def gen_qconfig_mapping(self, qconfig_mapping):
+        """Generate the ``QConfigMapping`` from a qconfig-mapping dict."""
         conf = QConfigMapping()
         if GLOBAL_DICT_KEY in qconfig_mapping:
             qconfig = QConfigHander(qconfig_mapping[GLOBAL_DICT_KEY]).convert()
@@ -86,6 +99,7 @@ def gen_qconfig_mapping(self, qconfig_mapping):
         return conf

     def gen_prepare_custom_config(self, prepare_custom_config):
+        """Generate the ``PrepareCustomConfig`` from a config dict."""
         conf = PrepareCustomConfig()
         if prepare_custom_config is None:
             return conf
diff --git a/mmrazor/models/quantizers/base.py b/mmrazor/models/quantizers/base.py
index d98fbd786..0f14917ac 100644
--- a/mmrazor/models/quantizers/base.py
+++ b/mmrazor/models/quantizers/base.py
@@ -8,6 +8,7 @@


 class BaseQuantizer(BaseModule):
+    """Base class for quantizers."""

     def __init__(self, tracer):
         super().__init__()
@@ -15,11 +16,11 @@ def __init__(self, tracer):

     @abstractmethod
     def prepare(self, model, graph_module):
+        """Prepare the model for quantization."""
         pass

     def swap_ff_with_fxff(self, model):
-        r""" Swap FloatFunctional with FXFloatFunctional
-        """
+        """Swap FloatFunctional with FXFloatFunctional."""
         modules_to_swap = []
         for name, module in model.named_children():
             if isinstance(module, torch.ao.nn.quantized.FloatFunctional):
diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py
index 84be1edfb..b3f2002e5 100644
--- a/mmrazor/models/quantizers/native_quantizer.py
+++ b/mmrazor/models/quantizers/native_quantizer.py
@@ -1,45 +1,62 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Tuple + import torch -from torch.ao.quantization import enable_fake_quant -from torch.ao.quantization.fx import prepare -from torch.ao.quantization.qconfig_mapping import QConfigMapping -from torch.ao.quantization.quantize_fx import _fuse_fx -from torch.nn.intrinsic.qat import modules as qat_fused_modules -from torch.nn.qat import modules as qat_modules +try: + from torch.ao.quantization import enable_fake_quant + from torch.ao.quantization.fx import prepare + from torch.ao.quantization.qconfig_mapping import QConfigMapping + from torch.ao.quantization.quantize_fx import _fuse_fx + from torch.nn.intrinsic.qat import modules as qat_fused_modules + from torch.nn.qat import modules as qat_modules +except ImportError: + from mmrazor.utils import get_package_placeholder, get_placeholder + enable_fake_quant = get_placeholder('torch>=1.13') + prepare = get_placeholder('torch>=1.13') + QConfigMapping = get_placeholder('torch>=1.13') + _fuse_fx = get_placeholder('torch>=1.13') + qat_fused_modules = get_package_placeholder('torch>=1.13') + qat_modules = get_package_placeholder('torch>=1.13') + +from mmrazor import digit_version from mmrazor.models.task_modules.tracer.fx import ( del_fakequant_after_function, del_fakequant_after_method, del_fakequant_after_module, del_fakequant_after_op, del_fakequant_before_function, del_fakequant_before_method, del_fakequant_before_module, del_fakequant_before_op) - from mmrazor.models.utils import str2class from mmrazor.registry import MODELS from mmrazor.structures.quantization import BackendConfigs, QConfigHander from .base import BaseQuantizer -SUPPORT_QAT_MODULES = ( - qat_fused_modules.ConvBn1d, qat_fused_modules.ConvBn2d, - qat_fused_modules.ConvBn3d, qat_fused_modules.ConvBnReLU1d, - qat_fused_modules.ConvBnReLU2d, qat_fused_modules.ConvBnReLU3d, - qat_fused_modules.ConvReLU1d, qat_fused_modules.ConvReLU2d, - qat_fused_modules.ConvReLU3d, qat_fused_modules.LinearBn1d, - qat_fused_modules.LinearReLU, qat_modules.Conv1d, qat_modules.Conv2d, - qat_modules.Conv3d, qat_modules.Linear) - -MERGE_BN_MAPPINGS = { - qat_fused_modules.ConvBn1d: qat_modules.Conv1d, - qat_fused_modules.ConvBn2d: qat_modules.Conv2d, - qat_fused_modules.ConvBn3d: qat_modules.Conv3d, - qat_fused_modules.ConvBnReLU1d: qat_fused_modules.ConvReLU1d, - qat_fused_modules.ConvBnReLU2d: qat_fused_modules.ConvReLU2d, - qat_fused_modules.ConvBnReLU3d: qat_fused_modules.ConvReLU3d, - qat_fused_modules.LinearBn1d: qat_modules.Linear -} +if digit_version(torch.__version__) >= digit_version('1.13.0'): + SUPPORT_QAT_MODULES: Tuple = ( + qat_fused_modules.ConvBn1d, qat_fused_modules.ConvBn2d, + qat_fused_modules.ConvBn3d, qat_fused_modules.ConvBnReLU1d, + qat_fused_modules.ConvBnReLU2d, qat_fused_modules.ConvBnReLU3d, + qat_fused_modules.ConvReLU1d, qat_fused_modules.ConvReLU2d, + qat_fused_modules.ConvReLU3d, qat_fused_modules.LinearBn1d, + qat_fused_modules.LinearReLU, qat_modules.Conv1d, qat_modules.Conv2d, + qat_modules.Conv3d, qat_modules.Linear) + + MERGE_BN_MAPPINGS: Dict = { + qat_fused_modules.ConvBn1d: qat_modules.Conv1d, + qat_fused_modules.ConvBn2d: qat_modules.Conv2d, + qat_fused_modules.ConvBn3d: qat_modules.Conv3d, + qat_fused_modules.ConvBnReLU1d: qat_fused_modules.ConvReLU1d, + qat_fused_modules.ConvBnReLU2d: qat_fused_modules.ConvReLU2d, + qat_fused_modules.ConvBnReLU3d: qat_fused_modules.ConvReLU3d, + qat_fused_modules.LinearBn1d: qat_modules.Linear + } +else: + SUPPORT_QAT_MODULES = () + MERGE_BN_MAPPINGS = {} @MODELS.register_module() class 
NativeQuantizer(BaseQuantizer):
+    """Quantizer for the native PyTorch backend (fbgemm/qnnpack)."""

     # backend: 'native'
     # support_w_modes = ['per_tensor', 'per_channel']
@@ -52,12 +69,12 @@ def __init__(self,
                  extra_redundant_fakequants=dict(
                      extra_module_prev_wo_fakequant=tuple(),
                      extra_module_next_wo_fakequant=tuple(),
-                     extra_function_prev_wo_fakequant = tuple(),
-                     extra_function_next_wo_fakequant = tuple(),
-                     extra_method_prev_wo_fakequant = tuple(),
-                     extra_method_next_wo_fakequant = tuple(),
-                     extra_op_prev_wo_fakequant = tuple(),
-                     extra_op_next_wo_fakequant = tuple())):
+                     extra_function_prev_wo_fakequant=tuple(),
+                     extra_function_next_wo_fakequant=tuple(),
+                     extra_method_prev_wo_fakequant=tuple(),
+                     extra_method_next_wo_fakequant=tuple(),
+                     extra_op_prev_wo_fakequant=tuple(),
+                     extra_op_next_wo_fakequant=tuple())):
         super().__init__(tracer)
         self.qconfig = QConfigHander(global_qconfig)
         if self.qconfig.w_qscheme.is_per_channel:
@@ -86,17 +103,21 @@ def __init__(self,

     @property
     def backend(self):
+        """The backend name."""
         return 'native'

     @property
     def support_w_modes(self):
+        """Supported quantization modes for weights."""
         return ['per_tensor', 'per_channel']

     @property
     def support_a_modes(self):
+        """Supported quantization modes for activations."""
         return ['per_tensor']

     def prepare(self, model, graph_module):
+        """Fuse the graph module and insert fake quantizers."""
         graph_module = _fuse_fx(
             graph_module=graph_module,
             is_qat=True,
@@ -115,6 +136,7 @@ def prepare(self, model, graph_module):

     def post_process_weight_fakequant(self,
                                       observed_module,
                                       keep_fake_quant=False):
+        """Merge BN into QAT modules and handle weight fake quantizers."""

         def traverse(module):
             for name, child in module.named_children():
@@ -145,70 +167,104 @@ def traverse(module):
         traverse(observed_module)

     def prepare_for_mmdeploy(self, model, dummy_input, checkpoint):
+        """Prepare the model for deployment with mmdeploy."""
         raise NotImplementedError

     def del_redundant_fakequant(self, prepared):
-        extra_module_prev_wo_fakequant = self.extra_redundant_fakequants.get('extra_module_prev_wo_fakequant', tuple())
+        """Delete redundant fake quantizers from the prepared model."""
+        extra_module_prev_wo_fakequant = self.extra_redundant_fakequants.get(
+            'extra_module_prev_wo_fakequant', tuple())
         prepared = del_fakequant_before_module(
-            prepared, self.module_prev_wo_fakequant + extra_module_prev_wo_fakequant, inplace=True)
+            prepared,
+            self.module_prev_wo_fakequant + extra_module_prev_wo_fakequant,
+            inplace=True)

-        extra_module_next_wo_fakequant = self.extra_redundant_fakequants.get('extra_module_next_wo_fakequant', tuple())
+        extra_module_next_wo_fakequant = self.extra_redundant_fakequants.get(
+            'extra_module_next_wo_fakequant', tuple())
         prepared = del_fakequant_after_module(
-            prepared, self.module_next_wo_fakequant + extra_module_next_wo_fakequant, inplace=True)
+            prepared,
+            self.module_next_wo_fakequant + extra_module_next_wo_fakequant,
+            inplace=True)

-        extra_function_prev_wo_fakequant = self.extra_redundant_fakequants.get('extra_function_prev_wo_fakequant', tuple())
+        extra_function_prev_wo_fakequant = self.extra_redundant_fakequants.get(
+            'extra_function_prev_wo_fakequant', tuple())
         prepared = del_fakequant_before_method(
-            prepared, self.function_prev_wo_fakequant + extra_function_prev_wo_fakequant, inplace=True)
+            prepared,
+            self.function_prev_wo_fakequant + extra_function_prev_wo_fakequant,
+            inplace=True)

-        extra_function_next_wo_fakequant = self.extra_redundant_fakequants.get('extra_function_next_wo_fakequant', tuple())
+        extra_function_next_wo_fakequant = self.extra_redundant_fakequants.get(
+            'extra_function_next_wo_fakequant', tuple())
         prepared = del_fakequant_after_method(
-            prepared, self.function_next_wo_fakequant + extra_function_next_wo_fakequant, inplace=True)
+            prepared,
+            self.function_next_wo_fakequant + extra_function_next_wo_fakequant,
+            inplace=True)

-        extra_method_prev_wo_fakequant = self.extra_redundant_fakequants.get('extra_method_prev_wo_fakequant', tuple())
+        extra_method_prev_wo_fakequant = self.extra_redundant_fakequants.get(
+            'extra_method_prev_wo_fakequant', tuple())
         prepared = del_fakequant_before_function(
-            prepared, self.method_prev_wo_fakequant + extra_method_prev_wo_fakequant, inplace=True)
+            prepared,
+            self.method_prev_wo_fakequant + extra_method_prev_wo_fakequant,
+            inplace=True)

-        extra_method_next_wo_fakequant = self.extra_redundant_fakequants.get('extra_method_next_wo_fakequant', tuple())
+        extra_method_next_wo_fakequant = self.extra_redundant_fakequants.get(
+            'extra_method_next_wo_fakequant', tuple())
         prepared = del_fakequant_after_function(
-            prepared, self.method_next_wo_fakequant + extra_method_next_wo_fakequant, inplace=True)
+            prepared,
+            self.method_next_wo_fakequant + extra_method_next_wo_fakequant,
+            inplace=True)

-        extra_op_prev_wo_fakequant = self.extra_redundant_fakequants.get('extra_op_prev_wo_fakequant', tuple())
+        extra_op_prev_wo_fakequant = self.extra_redundant_fakequants.get(
+            'extra_op_prev_wo_fakequant', tuple())
         prepared = del_fakequant_before_op(
-            prepared, self.op_prev_wo_fakequant + extra_op_prev_wo_fakequant, inplace=True)
+            prepared,
+            self.op_prev_wo_fakequant + extra_op_prev_wo_fakequant,
+            inplace=True)

-        extra_op_next_wo_fakequant = self.extra_redundant_fakequants.get('extra_op_next_wo_fakequant', tuple())
+        extra_op_next_wo_fakequant = self.extra_redundant_fakequants.get(
+            'extra_op_next_wo_fakequant', tuple())
         prepared = del_fakequant_after_op(
-            prepared, self.op_next_wo_fakequant + extra_op_next_wo_fakequant, inplace=True)
+            prepared,
+            self.op_next_wo_fakequant + extra_op_next_wo_fakequant,
+            inplace=True)
         return prepared

     @property
     def module_prev_wo_fakequant(self):
+        """Module types whose preceding fake quantizers should be deleted."""
         return tuple()

     @property
     def module_next_wo_fakequant(self):
+        """Module types whose following fake quantizers should be deleted."""
         return tuple()

     @property
     def function_prev_wo_fakequant(self):
+        """Functions whose preceding fake quantizers should be deleted."""
         return tuple()

     @property
     def function_next_wo_fakequant(self):
+        """Functions whose following fake quantizers should be deleted."""
         return tuple()

     @property
     def method_prev_wo_fakequant(self):
+        """Methods whose preceding fake quantizers should be deleted."""
         return tuple()

     @property
     def method_next_wo_fakequant(self):
+        """Methods whose following fake quantizers should be deleted."""
         return tuple()

     @property
     def op_prev_wo_fakequant(self):
+        """Ops whose preceding fake quantizers should be deleted."""
         return tuple()

     @property
     def op_next_wo_fakequant(self):
+        """Ops whose following fake quantizers should be deleted."""
         return tuple()
diff --git a/mmrazor/models/quantizers/openvino_quantizer.py b/mmrazor/models/quantizers/openvino_quantizer.py
index 0b13b23f9..23abf40da 100644
--- a/mmrazor/models/quantizers/openvino_quantizer.py
+++ b/mmrazor/models/quantizers/openvino_quantizer.py
@@ -1,8 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Tuple

 import torch
-from torch.ao.quantization import disable_observer
+
+try:
+    from torch.ao.quantization import disable_observer
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    disable_observer = get_placeholder('torch>=1.13')

 from mmrazor.models.task_modules.tracer.fx import build_graphmodule
 from mmrazor.registry import MODELS
@@ -19,21 +23,24 @@ class OpenVINOQuantizer(NativeQuantizer):

     @property
     def backend(self):
+        """The backend name."""
         return 'openvino'

     @property
     def support_w_modes(self):
+        """Supported quantization modes for weights."""
         return ['per_tensor', 'per_channel']

     @property
     def support_a_modes(self):
+        """Supported quantization modes for activations."""
         return ['per_tensor']

     def prepare_for_mmdeploy(self,
                              model,
                              dummy_input=(1, 3, 224, 224),
                              checkpoint=None):
-
+        """Prepare the model for deployment with mmdeploy."""
         self.swap_ff_with_fxff(model)
         graph = self.tracer.trace(model)
         graph_module = build_graphmodule(model, graph)
@@ -52,16 +59,20 @@ def prepare_for_mmdeploy(self,

     @property
     def module_prev_wo_fakequant(self):
+        """Module types whose preceding fake quantizers should be deleted."""
         return (torch.nn.ReLU6, torch.nn.Identity)

     @property
     def module_next_wo_fakequant(self):
+        """Module types whose following fake quantizers should be deleted."""
         return (torch.nn.MaxPool2d, )

     @property
     def method_next_wo_fakequant(self):
+        """Methods whose following fake quantizers should be deleted."""
         return ('flatten', )

     @property
     def op_prev_wo_fakequant(self):
+        """Ops whose preceding fake quantizers should be deleted."""
         return ('output', )
diff --git a/mmrazor/models/quantizers/tensorrt_quantizer.py b/mmrazor/models/quantizers/tensorrt_quantizer.py
index 4d9868c4f..36e3f2be7 100644
--- a/mmrazor/models/quantizers/tensorrt_quantizer.py
+++ b/mmrazor/models/quantizers/tensorrt_quantizer.py
@@ -1,6 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
-from torch.ao.quantization import disable_observer
+
+try:
+    from torch.ao.quantization import disable_observer
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    disable_observer = get_placeholder('torch>=1.13')

 from mmrazor.models.task_modules.tracer.fx.custom_tracer import \
     build_graphmodule
@@ -24,21 +29,24 @@ def __init__(self,

     @property
     def backend(self):
+        """The backend name."""
         return 'tensorrt'

     @property
     def support_w_modes(self):
+        """Supported quantization modes for weights."""
         return ['per_tensor', 'per_channel']

     @property
     def support_a_modes(self):
+        """Supported quantization modes for activations."""
         return ['per_tensor']

     def prepare_for_mmdeploy(self,
                              model,
                              dummy_input=(1, 3, 224, 224),
                              checkpoint=None):
-
+        """Prepare the model for deployment with mmdeploy."""
         self.swap_ff_with_fxff(model)
         graph = self.tracer.trace(model)
         graph_module = build_graphmodule(model, graph)
diff --git a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py
index 0e118290e..2d33e9875 100644
--- a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py
+++ b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py
@@ -5,18 +5,32 @@
 import torch
 import torch.nn as nn
+
+try:
+    from torch._C import ScriptObject  # type: ignore[attr-defined]
+    from torch.ao.quantization.quantize_fx import QuantizationTracer
+    from torch.fx import Graph, GraphModule, Tracer
+    from torch.fx._symbolic_trace import (_autowrap_check,
+                                          _patch_wrapped_functions, _Patcher)
+    from torch.fx.proxy import Proxy
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    ScriptObject = get_placeholder('torch>=1.13')
+    QuantizationTracer = get_placeholder('torch>=1.13')
+    GraphModule = get_placeholder('torch>=1.13')
+    Tracer = get_placeholder('torch>=1.13')
+    Graph = get_placeholder('torch>=1.13')
+    _autowrap_check = get_placeholder('torch>=1.13')
+    _patch_wrapped_functions = get_placeholder('torch>=1.13')
+    _Patcher = get_placeholder('torch>=1.13')
+    Proxy = get_placeholder('torch>=1.13')
+
 from mmengine.utils import import_modules_from_strings
-from torch._C import ScriptObject  # type: ignore[attr-defined]
-from torch.ao.quantization.quantize_fx import QuantizationTracer
-from torch.fx import GraphModule, Tracer
-from torch.fx._symbolic_trace import (Graph, _autowrap_check,
-                                      _patch_wrapped_functions, _Patcher)
-from torch.fx.proxy import Proxy

 from mmrazor.registry import TASK_UTILS

-_orig_module_call: Callable = torch.nn.Module.__call__
-_orig_module_getattr: Callable = torch.nn.Module.__getattr__
+_orig_module_call: Callable = nn.Module.__call__
+_orig_module_getattr: Callable = nn.Module.__getattr__


 class UntracedMethodRegistry:
@@ -59,13 +73,12 @@ def method(*args, **kwargs):
         return wrapped_method


-def custom_symbolic_trace(
-        root: Union[torch.nn.Module, Callable[..., Any]],
-        concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule:
+def custom_symbolic_trace(root: Union[nn.Module, Callable[..., Any]],
+                          concrete_args: Optional[Dict[str, Any]] = None):
     """Modified `symbolic_trace` function.

     Args:
-        root (Union[torch.nn.Module, Callable]): Module or function to be
+        root (Union[nn.Module, Callable]): Module or function to be
             traced and converted into a Graph representation.
         concrete_args (Optional[Dict[str, any]]): Inputs to be partially
             specialized.
@@ -75,12 +88,12 @@
     """
     tracer = CustomTracer()
     graph = tracer.trace(root, concrete_args)
-    name = root.__class__.__name__ if isinstance(
-        root, torch.nn.Module) else root.__name__
+    name = root.__class__.__name__ if isinstance(root,
+                                                 nn.Module) else root.__name__
     return GraphModule(tracer.root, graph, name)


-def _prepare_module_dict(model: nn.Module, fx_graph: torch.fx.Graph):
+def _prepare_module_dict(model: nn.Module, fx_graph):
     """If there is a class method that can not be traced by the symbolic
     tracer, a ``call_method`` ``Node`` will be inserted into the ``Graph`` in
     ``CustomTracer``.
@@ -128,7 +141,7 @@ def _prepare_module_dict(model: nn.Module, fx_graph: torch.fx.Graph):

     Args:
         model (nn.Module): The original model.
-        fx_graph (torch.fx.Graph): The fx Graph traced by fx tracer.
+        fx_graph (Graph): The fx Graph traced by fx tracer.
     """

     def _get_attrs(target, attrs):
@@ -157,9 +170,7 @@ def _get_attrs(target, attrs):
     return module_dict


-def build_graphmodule(model: nn.Module,
-                      fx_graph: torch.fx.Graph,
-                      name: str = 'GraphModule'):
+def build_graphmodule(model: nn.Module, fx_graph, name: str = 'GraphModule'):
     modules = dict(model.named_modules())
     module_dict = _prepare_module_dict(model, fx_graph)
     modules.update(module_dict)
@@ -228,7 +239,7 @@ def register_skipped_methods(self):
             method_registry = UntracedMethodRegistry(method)
             method_registry.__set_name__(imported_cls, method_str)

-    def call_method(self, m: torch.nn.Module, name, method, args, kwargs):
+    def call_method(self, m: nn.Module, name, method, args, kwargs):
         """Method that specifies the behavior of this ``Tracer`` when it
         encounters a call to an ``nn.Module`` instance.
@@ -266,7 +277,7 @@ def call_method(self, m: torch.nn.Module, name, method, args, kwargs): return self.create_proxy('call_method', name, args, kwargs) def trace(self, root, concrete_args=None): - if isinstance(root, torch.nn.Module): + if isinstance(root, nn.Module): self.root = root fn = type(root).forward self.submodule_paths = { @@ -274,7 +285,7 @@ def trace(self, root, concrete_args=None): for name, mod in root.named_modules() } else: - self.root = torch.nn.Module() + self.root = nn.Module() fn = root tracer_cls: Optional[Type['Tracer']] = getattr(self, '__class__', None) @@ -286,7 +297,7 @@ def trace(self, root, concrete_args=None): # used downstream in create_arg self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {} - def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): + def collect_tensor_attrs(m: nn.Module, prefix_atoms: List[str]): for k, v in m.__dict__.items(): if isinstance(v, (torch.Tensor, ScriptObject)): self.tensor_attrs[v] = '.'.join(prefix_atoms + [k]) @@ -298,8 +309,7 @@ def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): assert isinstance(fn, FunctionType) fn_globals = fn.__globals__ # run before it gets patched - fn, args = self.create_args_for_root(fn, - isinstance(root, torch.nn.Module), + fn, args = self.create_args_for_root(fn, isinstance(root, nn.Module), concrete_args) # Reduce number of get_attr calls @@ -328,15 +338,12 @@ def forward(*args, **kwargs): with _Patcher() as patcher: # allow duplicate patches to support the case of nested calls patcher.patch_method( - torch.nn.Module, + nn.Module, '__getattr__', module_getattr_wrapper, deduplicate=False) patcher.patch_method( - torch.nn.Module, - '__call__', - module_call_wrapper, - deduplicate=False) + nn.Module, '__call__', module_call_wrapper, deduplicate=False) for name, value in UntracedMethodRegistry.method_dict.items(): wrapped = value['wrapped'] @@ -363,8 +370,7 @@ def is_skipped_method(self, m): custom = isinstance(m, mods) return custom - def is_leaf_module(self, m: torch.nn.Module, - module_qualified_name: str) -> bool: + def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool: # return super().is_leaf_module(m, module_qualified_name) leaf = super().is_leaf_module(m, module_qualified_name) return leaf diff --git a/mmrazor/models/task_modules/tracer/fx/graph_utils.py b/mmrazor/models/task_modules/tracer/fx/graph_utils.py index fe8d620c2..5e3ddc2f4 100644 --- a/mmrazor/models/task_modules/tracer/fx/graph_utils.py +++ b/mmrazor/models/task_modules/tracer/fx/graph_utils.py @@ -2,8 +2,13 @@ import copy from typing import Any, List, Tuple -import torch.fx -from torch.ao.quantization.fake_quantize import FakeQuantizeBase +import torch + +try: + from torch.ao.quantization.fake_quantize import FakeQuantizeBase +except ImportError: + from mmrazor.utils import get_placeholder + FakeQuantizeBase = get_placeholder('torch>=1.13') def _get_attrs(target: torch.nn.Module, attr: str) -> Any: @@ -67,9 +72,9 @@ def recursive_find_erased_nodes(node, prepared_model): return nodes_to_erase -def del_fakequant_before_op(prepared_model: torch.fx.GraphModule, +def del_fakequant_before_op(prepared_model, target_ops: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: + inplace: bool = True): """Delete useless fakequant before nodes whose ``op`` attribute (node.op) is in `target_ops`. 
@@ -104,9 +109,9 @@ def del_fakequant_before_op(prepared_model: torch.fx.GraphModule, return prepared_model -def del_fakequant_after_op(prepared_model: torch.fx.GraphModule, +def del_fakequant_after_op(prepared_model, target_ops: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: + inplace: bool = True): """Delete useless fakequant after nodes whose ``op`` attribute (node.op) is in `target_ops`. @@ -145,9 +150,9 @@ def del_fakequant_after_op(prepared_model: torch.fx.GraphModule, return prepared_model -def del_fakequant_before_method(prepared_model: torch.fx.GraphModule, +def del_fakequant_before_method(prepared_model, method_patterns: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: + inplace: bool = True): """Delete useless fakequant before nodes whose op attribute (node.op) is `call_method` and target attribute (node.target) is in `target_patterns`. @@ -182,9 +187,9 @@ def del_fakequant_before_method(prepared_model: torch.fx.GraphModule, return prepared_model -def del_fakequant_after_method(prepared_model: torch.fx.GraphModule, +def del_fakequant_after_method(prepared_model, method_patterns: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: + inplace: bool = True): """Delete useless fakequant after nodes whose op attribute (node.op) is `call_method` and target attribute (node.target) is in `target_patterns`. @@ -224,10 +229,9 @@ def del_fakequant_after_method(prepared_model: torch.fx.GraphModule, return prepared_model -def del_fakequant_before_function( - prepared_model: torch.fx.GraphModule, - function_patterns: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: +def del_fakequant_before_function(prepared_model, + function_patterns: Tuple, + inplace: bool = True): """Delete useless fakequant before nodes whose op attribute (node.op) is `call_function` and target attribute (node.target) is in `target_patterns`. @@ -262,9 +266,9 @@ def del_fakequant_before_function( return prepared_model -def del_fakequant_after_function(prepared_model: torch.fx.GraphModule, +def del_fakequant_after_function(prepared_model, function_patterns: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: + inplace: bool = True): """Delete useless fakequant after nodes whose op attribute (node.op) is `call_function` and target attribute (node.target) is in `target_patterns`. @@ -304,9 +308,9 @@ def del_fakequant_after_function(prepared_model: torch.fx.GraphModule, return prepared_model -def del_fakequant_before_module(prepared_model: torch.fx.GraphModule, +def del_fakequant_before_module(prepared_model, module_patterns: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: + inplace: bool = True): """Delete useless fakequant before modules whose type are in `module_patterns`. @@ -340,9 +344,9 @@ def del_fakequant_before_module(prepared_model: torch.fx.GraphModule, return prepared_model -def del_fakequant_after_module(prepared_model: torch.fx.GraphModule, +def del_fakequant_after_module(prepared_model, module_patterns: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: + inplace: bool = True): """Delete useless fakequant after modules whose type are in `module_patterns`. diff --git a/mmrazor/structures/quantization/backend_config/academic.py b/mmrazor/structures/quantization/backend_config/academic.py index 5983c3996..4348e7179 100644 --- a/mmrazor/structures/quantization/backend_config/academic.py +++ b/mmrazor/structures/quantization/backend_config/academic.py @@ -1,23 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import torch -from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig + +try: + from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig +except ImportError: + from mmrazor.utils import get_placeholder + BackendConfig = get_placeholder('torch>=1.13') + DTypeConfig = get_placeholder('torch>=1.13') from .common_operator_config_utils import (_get_conv_configs, _get_linear_configs) -# =================== -# | DTYPE CONFIGS | -# =================== - -# weighted op int8 dtype config -# this is config for ops that has quantized weights, like linear, conv -weighted_op_int8_dtype_config = DTypeConfig( - input_dtype=torch.quint8, - output_dtype=torch.quint8, - weight_dtype=torch.qint8, - bias_dtype=torch.float, -) - # ===================== # | BACKEND CONFIGS | # ===================== @@ -25,6 +18,19 @@ def get_academic_backend_config() -> BackendConfig: """Return the `BackendConfig` for academic reseaching.""" + + # =================== + # | DTYPE CONFIGS | + # =================== + # weighted op int8 dtype config + # this is config for ops that has quantized weights, like linear, conv + weighted_op_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + ) + conv_dtype_configs = [weighted_op_int8_dtype_config] linear_dtype_configs = [weighted_op_int8_dtype_config] diff --git a/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py b/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py index 2a855e687..0a381d5d0 100644 --- a/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py +++ b/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py @@ -5,39 +5,71 @@ import torch import torch.nn as nn -import torch.nn.functional as F -import torch.nn.intrinsic as nni -import torch.nn.intrinsic.qat as nniqat -import torch.nn.qat as nnqat -import torch.nn.quantized._reference as nnqr -from torch.ao.quantization.backend_config import (BackendPatternConfig, - DTypeConfig, ObservationType) -from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize -from torch.ao.quantization.fuser_method_mappings import ( - fuse_conv_bn, fuse_conv_bn_relu, fuse_convtranspose_bn, fuse_linear_bn, - reverse2, reverse3, reverse_sequential_wrapper2) -from torch.ao.quantization.qconfig_mapping import _FIXED_QPARAMS_OP_TO_OBSERVER + +from mmrazor import digit_version + +try: + import torch.nn.functional as F + import torch.nn.intrinsic as nni + import torch.nn.intrinsic.qat as nniqat + import torch.nn.qat as nnqat + import torch.nn.quantized._reference as nnqr + from torch.ao.quantization.backend_config import (BackendPatternConfig, + DTypeConfig, + ObservationType) + from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize + from torch.ao.quantization.fuser_method_mappings import ( + fuse_conv_bn, fuse_conv_bn_relu, fuse_convtranspose_bn, fuse_linear_bn, + reverse2, reverse3, reverse_sequential_wrapper2) + from torch.ao.quantization.qconfig_mapping import \ + _FIXED_QPARAMS_OP_TO_OBSERVER +except ImportError: + from mmrazor.utils import get_package_placeholder, get_placeholder + F = get_package_placeholder('torch>=1.13') + nni = get_package_placeholder('torch>=1.13') + nniqat = get_package_placeholder('torch>=1.13') + nnqat = get_package_placeholder('torch>=1.13') + nnqr = get_package_placeholder('torch>=1.13') + BackendPatternConfig = get_placeholder('torch>=1.13') + DTypeConfig = 
get_placeholder('torch>=1.13') + ObservationType = get_placeholder('torch>=1.13') + FixedQParamsFakeQuantize = get_placeholder('torch>=1.13') + fuse_conv_bn = get_placeholder('torch>=1.13') + fuse_conv_bn_relu = get_placeholder('torch>=1.13') + fuse_convtranspose_bn = get_placeholder('torch>=1.13') + fuse_linear_bn = get_placeholder('torch>=1.13') + reverse2 = get_placeholder('torch>=1.13') + reverse3 = get_placeholder('torch>=1.13') + reverse_sequential_wrapper2 = get_placeholder('torch>=1.13') + _FIXED_QPARAMS_OP_TO_OBSERVER = get_placeholder('torch>=1.13') _ConvMetadata = namedtuple('_ConvMetadata', [ 'root', 'transpose', 'bn', 'reference', 'transpose_reference', 'fused_conv_relu', 'fused_conv_bn', 'fused_conv_bn_relu', 'qat', 'relu_qat', 'bn_qat', 'bn_relu_qat', 'func' ]) -_Conv1dMetadata = _ConvMetadata(nn.Conv1d, nn.ConvTranspose1d, nn.BatchNorm1d, - nnqr.Conv1d, nnqr.ConvTranspose1d, - nni.ConvReLU1d, nni.ConvBn1d, nni.ConvBnReLU1d, - nnqat.Conv1d, nniqat.ConvReLU1d, - nniqat.ConvBn1d, nniqat.ConvBnReLU1d, F.conv1d) -_Conv2dMetadata = _ConvMetadata(nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, - nnqr.Conv2d, nnqr.ConvTranspose2d, - nni.ConvReLU2d, nni.ConvBn2d, nni.ConvBnReLU2d, - nnqat.Conv2d, nniqat.ConvReLU2d, - nniqat.ConvBn2d, nniqat.ConvBnReLU2d, F.conv2d) -_Conv3dMetadata = _ConvMetadata(nn.Conv3d, nn.ConvTranspose3d, nn.BatchNorm3d, - nnqr.Conv3d, nnqr.ConvTranspose3d, - nni.ConvReLU3d, nni.ConvBn3d, nni.ConvBnReLU3d, - nnqat.Conv3d, nniqat.ConvReLU3d, - nniqat.ConvBn3d, nniqat.ConvBnReLU3d, F.conv3d) + +if digit_version(torch.__version__) >= digit_version('1.13.0'): + _Conv1dMetadata = _ConvMetadata( + nn.Conv1d, nn.ConvTranspose1d, nn.BatchNorm1d, nnqr.Conv1d, + nnqr.ConvTranspose1d, nni.ConvReLU1d, nni.ConvBn1d, nni.ConvBnReLU1d, + nnqat.Conv1d, nniqat.ConvReLU1d, nniqat.ConvBn1d, nniqat.ConvBnReLU1d, + F.conv1d) + _Conv2dMetadata = _ConvMetadata( + nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, nnqr.Conv2d, + nnqr.ConvTranspose2d, nni.ConvReLU2d, nni.ConvBn2d, nni.ConvBnReLU2d, + nnqat.Conv2d, nniqat.ConvReLU2d, nniqat.ConvBn2d, nniqat.ConvBnReLU2d, + F.conv2d) + _Conv3dMetadata = _ConvMetadata( + nn.Conv3d, nn.ConvTranspose3d, nn.BatchNorm3d, nnqr.Conv3d, + nnqr.ConvTranspose3d, nni.ConvReLU3d, nni.ConvBn3d, nni.ConvBnReLU3d, + nnqat.Conv3d, nniqat.ConvReLU3d, nniqat.ConvBn3d, nniqat.ConvBnReLU3d, + F.conv3d) +else: + toy_val = _ConvMetadata(*[i for i in range(13)]) + _Conv1dMetadata = toy_val + _Conv2dMetadata = toy_val + _Conv3dMetadata = toy_val def _get_binary_op_configs( diff --git a/mmrazor/structures/quantization/backend_config/mapping.py b/mmrazor/structures/quantization/backend_config/mapping.py index 4c87a73b9..b9cc5372b 100644 --- a/mmrazor/structures/quantization/backend_config/mapping.py +++ b/mmrazor/structures/quantization/backend_config/mapping.py @@ -1,12 +1,23 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+import torch + +from mmrazor import digit_version from .academic import get_academic_backend_config from .native import get_native_backend_config from .openvino import get_openvino_backend_config from .tensorrt import get_tensorrt_backend_config -BackendConfigs = { - 'academic': get_academic_backend_config(), - 'native': get_native_backend_config(), - 'tensorrt': get_tensorrt_backend_config(), - 'openvino': get_openvino_backend_config() -} +if digit_version(torch.__version__) >= digit_version('1.13.0'): + BackendConfigs = { + 'academic': get_academic_backend_config(), + 'native': get_native_backend_config(), + 'tensorrt': get_tensorrt_backend_config(), + 'openvino': get_openvino_backend_config() + } +else: + BackendConfigs = { + 'academic': None, + 'native': None, + 'tensorrt': None, + 'openvino': None + } diff --git a/mmrazor/structures/quantization/backend_config/native.py b/mmrazor/structures/quantization/backend_config/native.py index d771b6012..94c35d535 100644 --- a/mmrazor/structures/quantization/backend_config/native.py +++ b/mmrazor/structures/quantization/backend_config/native.py @@ -1,6 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig + +try: + from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig +except ImportError: + from mmrazor.utils import get_placeholder + BackendConfig = get_placeholder('torch>=1.13') + DTypeConfig = get_placeholder('torch>=1.13') from .common_operator_config_utils import ( # noqa: F401,F403 _get_binary_op_configs, _get_bn_configs, _get_cat_config, @@ -8,68 +14,6 @@ _get_fixed_qparams_op_configs, _get_linear_configs, _get_ln_configs, _get_rnn_op_configs, _get_share_qparams_op_configs) -# =================== -# | DTYPE CONFIGS | -# =================== - -# weighted op int8 dtype config -# this is config for ops that has quantized weights, like linear, conv -weighted_op_int8_dtype_config = DTypeConfig( - input_dtype=torch.quint8, - output_dtype=torch.quint8, - weight_dtype=torch.qint8, - bias_dtype=torch.float, -) - -default_op_quint8_dtype_config = DTypeConfig( - input_dtype=torch.quint8, - output_dtype=torch.quint8, -) - -default_dynamic_int8_dtype_config = DTypeConfig( - input_dtype=torch.quint8, - output_dtype=torch.float, - weight_dtype=torch.qint8, - bias_dtype=torch.float, - # currently the dtype check is not yet enabled, so we provided the - # dtype_configs but it is not really used yet, - # we will enable it a bit later after we moved everything to - # backend_config_dict - is_dynamic=True, -) - -default_dynamic_float16_dtype_config = DTypeConfig( - input_dtype=torch.float16, - output_dtype=torch.float, - weight_dtype=torch.float16, - bias_dtype=torch.float, - # currently the dtype check is not yet enabled, so we provided the - # dtype_configs but it is not really used yet, we will enable it a bit - # later after we moved everything to backend_config_dict - is_dynamic=True, -) - -# Needed for LayerNorm and f.layer_norm, since currently the kernel only -# supports float weights -input_output_only_quint8_dtype_config = DTypeConfig( - input_dtype=torch.quint8, - output_dtype=torch.quint8, - weight_dtype=torch.float, - bias_dtype=torch.float, -) - -weight_only_quint8_dtype_config = DTypeConfig( - input_dtype=torch.float, - output_dtype=torch.float, - weight_dtype=torch.quint8, -) - -weight_only_quint4x2_dtype_config = DTypeConfig( - input_dtype=torch.float, - output_dtype=torch.float, - weight_dtype=torch.quint4x2, -) - # 
===================== # | BACKEND CONFIGS | # ===================== @@ -80,6 +24,68 @@ def get_native_backend_config() -> BackendConfig: (fbgemm/qnnpack).""" # TODO: express this BackendConfig as a union of the FBGEMM and QNNPACK # BackendConfigs + + # =================== + # | DTYPE CONFIGS | + # =================== + # weighted op int8 dtype config + # this is config for ops that has quantized weights, like linear, conv + weighted_op_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + ) + + default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + ) + + default_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + # currently the dtype check is not yet enabled, so we provided the + # dtype_configs but it is not really used yet, + # we will enable it a bit later after we moved everything to + # backend_config_dict + is_dynamic=True, + ) + + default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + # currently the dtype check is not yet enabled, so we provided the + # dtype_configs but it is not really used yet, we will enable it a bit + # later after we moved everything to backend_config_dict + is_dynamic=True, + ) + + # Needed for LayerNorm and f.layer_norm, since currently the kernel only + # supports float weights + input_output_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.float, + bias_dtype=torch.float, + ) + + weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, + ) + + weight_only_quint4x2_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint4x2, + ) + conv_dtype_configs = [weighted_op_int8_dtype_config] linear_dtype_configs = [ weighted_op_int8_dtype_config, diff --git a/mmrazor/structures/quantization/backend_config/openvino.py b/mmrazor/structures/quantization/backend_config/openvino.py index fd24eed17..d990d4ef9 100644 --- a/mmrazor/structures/quantization/backend_config/openvino.py +++ b/mmrazor/structures/quantization/backend_config/openvino.py @@ -1,8 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from torch.ao.quantization.backend_config import (BackendConfig, - BackendPatternConfig, - DTypeConfig, ObservationType) + +try: + from torch.ao.quantization.backend_config import (BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType) +except ImportError: + from mmrazor.utils import get_placeholder + BackendConfig = get_placeholder('torch>=1.13') + BackendPatternConfig = get_placeholder('torch>=1.13') + DTypeConfig = get_placeholder('torch>=1.13') + ObservationType = get_placeholder('torch>=1.13') from .common_operator_config_utils import (_get_binary_op_configs, _get_conv_configs, diff --git a/mmrazor/structures/quantization/backend_config/tensorrt.py b/mmrazor/structures/quantization/backend_config/tensorrt.py index abb585c6a..53305f650 100644 --- a/mmrazor/structures/quantization/backend_config/tensorrt.py +++ b/mmrazor/structures/quantization/backend_config/tensorrt.py @@ -1,8 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. 
diff --git a/mmrazor/structures/quantization/backend_config/tensorrt.py b/mmrazor/structures/quantization/backend_config/tensorrt.py
index abb585c6a..53305f650 100644
--- a/mmrazor/structures/quantization/backend_config/tensorrt.py
+++ b/mmrazor/structures/quantization/backend_config/tensorrt.py
@@ -1,8 +1,17 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
-from torch.ao.quantization.backend_config import (BackendConfig,
-                                                  BackendPatternConfig,
-                                                  DTypeConfig, ObservationType)
+
+try:
+    from torch.ao.quantization.backend_config import (BackendConfig,
+                                                      BackendPatternConfig,
+                                                      DTypeConfig,
+                                                      ObservationType)
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    BackendConfig = get_placeholder('torch>=1.13')
+    BackendPatternConfig = get_placeholder('torch>=1.13')
+    DTypeConfig = get_placeholder('torch>=1.13')
+    ObservationType = get_placeholder('torch>=1.13')
 
 from .common_operator_config_utils import (_get_binary_op_configs,
                                            _get_conv_configs,
diff --git a/mmrazor/structures/quantization/qconfig.py b/mmrazor/structures/quantization/qconfig.py
index 3dca49730..e0fdf113d 100644
--- a/mmrazor/structures/quantization/qconfig.py
+++ b/mmrazor/structures/quantization/qconfig.py
@@ -3,7 +3,12 @@
 
 import torch
 from mmengine.config import Config
-from torch.ao.quantization import QConfig
+
+try:
+    from torch.ao.quantization import QConfig
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    QConfig = get_placeholder('torch>=1.13')
 
 from mmrazor.registry import MODELS
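All of these guarded imports lean on mmrazor.utils.get_placeholder, whose implementation is not part of this diff. A hypothetical sketch of the contract the call sites assume: the import always succeeds, and only using the placeholder raises an informative error.

    # Sketch only, not the verbatim mmrazor.utils implementation.
    def get_placeholder(string: str) -> type:

        class PlaceHolder:

            def __init__(self, *args, **kwargs):
                # Deferred failure: raised only when the missing symbol
                # is actually used, not at import time.
                raise ImportError(
                    f'`{string}` is required to use this feature, but it '
                    'is not available in the current environment.')

        return PlaceHolder
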
diff --git a/tests/data/models.py b/tests/data/models.py
index 33fb0c624..0347b9147 100644
--- a/tests/data/models.py
+++ b/tests/data/models.py
@@ -78,7 +78,6 @@ def untracable_method(self, x):
         x = x * -2
         return x
 
-
 @MODELS.register_module()
 class UntracableBackBone(nn.Module):
@@ -123,7 +122,6 @@ def forward(self, x):
         x_last = self.conv2(x_attn)
         return self.head(x_last)
 
-
 @MODELS.register_module()
 class LinearHeadForTest(Module):
@@ -704,7 +702,6 @@ def current_choice(self):
     def current_choice(self, choice):
         super().current_choice(choice)
 
-
 class DynamicLinearModel(nn.Module):
     """
     x
diff --git a/tests/test_models/test_mutators/test_value_mutator.py b/tests/test_models/test_mutators/test_value_mutator.py
new file mode 100644
index 000000000..a76257a9e
--- /dev/null
+++ b/tests/test_models/test_mutators/test_value_mutator.py
@@ -0,0 +1,66 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import unittest
+
+import torch
+
+from mmrazor.models.mutables import MutableValue
+from mmrazor.models.mutators import DynamicValueMutator
+from tests.data.models import DynamicAttention, DynamicMMBlock
+
+
+class TestValueMutator(unittest.TestCase):
+
+    def test_models_with_predefined_dynamic_op(self):
+        for Model in [
+                DynamicAttention,
+        ]:
+            with self.subTest(model=Model):
+                model = Model()
+                value_mutator = DynamicValueMutator()
+                value_mutator.prepare_from_supernet(model)
+                value_choices = value_mutator.sample_choices()
+                value_mutator.set_choices(value_choices)
+
+                mutable_value_space = []
+                for name, module in model.named_modules():
+                    if isinstance(module, MutableValue):
+                        mutable_value_space.append(name)
+                    elif hasattr(module, 'source_mutables'):
+                        for each_mutables in module.source_mutables:
+                            if isinstance(each_mutables, MutableValue):
+                                mutable_value_space.append(each_mutables)
+                assert len(
+                    value_mutator.search_groups) == len(mutable_value_space)
+
+                x = torch.rand([2, 3, 224, 224])
+                y = model(x)
+                self.assertEqual(list(y.shape), [2, 624])
+
+    def test_models_with_multiple_value(self):
+        for Model in [
+                DynamicMMBlock,
+        ]:
+            with self.subTest(model=Model):
+                model = Model()
+                value_mutator = DynamicValueMutator()
+                value_mutator.prepare_from_supernet(model)
+                value_choices = value_mutator.sample_choices()
+                value_mutator.set_choices(value_choices)
+
+                # TODO check DynamicMMBlock
+                mutable_value_space = []
+                for name, module in model.named_modules():
+                    if isinstance(module, MutableValue):
+                        mutable_value_space.append(name)
+                    elif hasattr(module, 'source_mutables'):
+                        for each_mutables in module.source_mutables:
+                            if isinstance(each_mutables, MutableValue):
+                                mutable_value_space.append(each_mutables)
+                count = 0
+                for values in value_mutator.search_groups.values():
+                    count += len(values)
+                assert count == len(mutable_value_space)
+
+                x = torch.rand([2, 3, 224, 224])
+                y = model(x)
+                self.assertEqual(list(y[-1].shape), [2, 1984, 1, 1])
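Distilled from the new test, the DynamicValueMutator workflow it exercises is: prepare from a supernet, sample one choice per search group, apply the choices, then run a forward pass. Model class and input shape are taken from the test; the import only resolves from the repo root, since DynamicAttention is a test fixture.

    import torch

    from mmrazor.models.mutators import DynamicValueMutator
    from tests.data.models import DynamicAttention  # repo test fixture

    model = DynamicAttention()
    mutator = DynamicValueMutator()
    mutator.prepare_from_supernet(model)           # collect MutableValues
    mutator.set_choices(mutator.sample_choices())  # fix one subnet
    y = model(torch.rand([2, 3, 224, 224]))        # expected shape: [2, 624]
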
-from unittest import TestCase
-
-from mmrazor.models.task_modules import CustomTracer, UntracedMethodRegistry
-from mmrazor.testing import ConvBNReLU
-
-
-class testCustomTracer(TestCase):
-
-    def test_init(self):
-        tracer = CustomTracer()
-        assert tracer.skipped_methods.__len__() == 0
-
-    def test_trace(self):
-        tracer = CustomTracer()
-        model = ConvBNReLU(3, 3, norm_cfg=dict(type='BN'))
-        graph = tracer.trace(model)  # noqa: F841
-
-    def test_auto_skip_call_module(self):
-        pass
-
-    def test_auto_skip_call_method(self):
-        pass
-
-    def test_configurable_skipped_methods(self):
-        pass
-
-
-class testUntracedMethodRgistry(TestCase):
-
-    def test_init(self):
-        self.assertEqual(len(UntracedMethodRegistry.method_dict), 0)
-
-    def test_add_method(self):
-        pass
diff --git a/tests/test_models/test_task_modules/test_graph_utils.py b/tests/test_models/test_task_modules/test_graph_utils.py
index 7250bee95..d8f53c03c 100644
--- a/tests/test_models/test_task_modules/test_graph_utils.py
+++ b/tests/test_models/test_task_modules/test_graph_utils.py
@@ -4,13 +4,21 @@
 import torch
 import torch.nn as nn
-from torch.ao.quantization import QConfigMapping
-from torch.ao.quantization.fake_quantize import FakeQuantizeBase
-from torch.ao.quantization.fx import prepare
-from torch.ao.quantization.quantize_fx import _fuse_fx
-from mmrazor.models.task_modules import build_graphmodule
-from mmrazor.models.task_modules.tracer import CustomTracer
+try:
+    from torch.ao.quantization import QConfigMapping
+    from torch.ao.quantization.fake_quantize import FakeQuantizeBase
+    from torch.ao.quantization.fx import prepare
+    from torch.ao.quantization.quantize_fx import _fuse_fx
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    QConfigMapping = get_placeholder('torch>=1.13')
+    FakeQuantizeBase = get_placeholder('torch>=1.13')
+    prepare = get_placeholder('torch>=1.13')
+    _fuse_fx = get_placeholder('torch>=1.13')
+
+from mmrazor import digit_version
+from mmrazor.models.task_modules.tracer import CustomTracer, build_graphmodule
 
 from mmrazor.models.task_modules.tracer.fx import (
     del_fakequant_after_function, del_fakequant_after_method,
     del_fakequant_after_module, del_fakequant_after_op,
@@ -106,6 +114,9 @@ def forward(self, x):
 class TestGraphUtils(TestCase):
 
     def setUp(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
         self.tracer = CustomTracer()
         self.backend_config = BackendConfigs['native']
         self.qconfig = QConfigHander(global_qconfig)
@@ -114,6 +125,9 @@ def setUp(self):
         self.example_inputs = (torch.randn(1, 3, 224, 224), )
 
     def swap_ff_with_fxff(self, model):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
         modules_to_swap = []
         for name, module in model.named_children():
             if isinstance(module, torch.ao.nn.quantized.FloatFunctional):
@@ -126,6 +140,9 @@ def swap_ff_with_fxff(self, model):
             model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional()
 
     def test_del_fakequant_before_op(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
         model_to_quantize = ToyModel()
         model_to_quantize.eval()
@@ -170,6 +187,9 @@ def test_del_fakequant_before_op(self):
             _get_attrs(prepared, args[0].target), FakeQuantizeBase)
 
     def test_del_fakequant_after_op(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
         model_to_quantize = ToyModel()
         model_to_quantize.eval()
@@ -211,6 +231,8 @@ def test_del_fakequant_after_op(self):
             _get_attrs(prepared, node.next.target), FakeQuantizeBase)
 
     def test_del_fakequant_before_method(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
 
         model_to_quantize = ToyModel()
         model_to_quantize.eval()
@@ -259,6 +281,9 @@ def test_del_fakequant_before_method(self):
             _get_attrs(prepared, args[0].target), FakeQuantizeBase)
 
     def test_del_fakequant_after_method(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
         model_to_quantize = ToyModel()
         model_to_quantize.eval()
@@ -303,6 +328,9 @@ def test_del_fakequant_after_method(self):
             _get_attrs(prepared, node.next.target), FakeQuantizeBase)
 
     def test_del_fakequant_before_function(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
         model_to_quantize = ToyModel()
         model_to_quantize.eval()
@@ -356,6 +384,9 @@ def test_del_fakequant_before_function(self):
             _get_attrs(prepared, args[1].target), FakeQuantizeBase)
 
     def test_del_fakequant_after_function(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
         model_to_quantize = ToyModel()
         model_to_quantize.eval()
@@ -400,6 +431,9 @@ def test_del_fakequant_after_function(self):
             _get_attrs(prepared, node.next.target), FakeQuantizeBase)
 
     def test_del_fakequant_before_module(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
         model_to_quantize = ToyModel()
         model_to_quantize.eval()
@@ -452,6 +486,9 @@ def test_del_fakequant_before_module(self):
             _get_attrs(prepared, args[0].target), FakeQuantizeBase)
 
     def test_del_fakequant_after_module(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
         model_to_quantize = ToyModel()
         model_to_quantize.eval()
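The identical two-line torch-version guard now appears in every test method of this file, and again in tests/test_structures/test_qconfig.py below. An equivalent, more compact alternative would be a shared unittest.skipIf decorator; this is a sketch of that option, not what the patch does:

    import unittest

    import torch

    from mmrazor import digit_version

    require_torch_113 = unittest.skipIf(
        digit_version(torch.__version__) < digit_version('1.13.0'),
        'version of torch < 1.13.0')

    # Usage: decorate each test method (or the whole TestCase), e.g.
    # @require_torch_113
    # def test_del_fakequant_before_op(self): ...
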
diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py
index 009640684..c8340f352 100644
--- a/tests/test_registry/test_registry.py
+++ b/tests/test_registry/test_registry.py
@@ -12,6 +12,8 @@
 from mmrazor.models.algorithms.base import BaseAlgorithm
 from mmrazor.models.mutables import OneShotMutableOP
 from mmrazor.registry import MODELS
+from mmrazor.structures import load_fix_subnet
+from mmrazor.utils import ValidFixMutable
 
 
 @MODELS.register_module()
@@ -44,13 +46,15 @@ class MockAlgorithm(BaseAlgorithm):
 
     def __init__(self,
                  architecture: Union[BaseModel, Dict],
-                 _return_architecture_: Optional[bool] = None):
+                 fix_subnet: Optional[ValidFixMutable] = None):
 
         super().__init__(architecture)
 
-        if _return_architecture_ is True:
-            self.return_model = self.architecture
+        if fix_subnet is not None:
+            # According to fix_subnet, delete the unchosen part of supernet
+            load_fix_subnet(self, fix_subnet, prefix='architecture.')
+            self.is_supernet = False
         else:
-            self.return_model = self
+            self.is_supernet = True
 
 
 class TestRegistry(TestCase):
@@ -68,18 +72,34 @@ def test_build_razor_from_cfg(self):
         # model = MODELS.build(self.arch_cfg_path)
         # self.assertIsNotNone(model)
 
-        # test return architecture
+        # test fix subnet
         cfg = Config.fromfile(
-            'tests/data/test_registry/registry_architecture_config.py')
+            'tests/data/test_registry/registry_subnet_config.py')
         model = MODELS.build(cfg.model)
-        self.assertTrue(isinstance(model.return_model, MockModel))
 
-        # test return model
+        # test return architecture
         cfg = Config.fromfile(
             'tests/data/test_registry/registry_architecture_config.py')
-        cfg.model.pop('_return_architecture_')
         model = MODELS.build(cfg.model)
-        self.assertTrue(isinstance(model.return_model, MockAlgorithm))
+        self.assertTrue(isinstance(model, BaseModel))
+
+    def test_build_subnet_prune_from_cfg(self):
+        mutator_cfg = fileio.load('tests/data/test_registry/subnet.json')
+        init_cfg = dict(
+            type='Pretrained',
+            checkpoint='tests/data/test_registry/subnet_weight.pth')
+        # test fix subnet
+        model_cfg = dict(
+            # use mmrazor's build_func
+            type='mmrazor.sub_model',
+            cfg=dict(
+                cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py',
+                pretrained=False),
+            fix_subnet=mutator_cfg,
+            mode='mutator',
+            init_cfg=init_cfg)
+        model = MODELS.build(model_cfg)
+        self.assertTrue(isinstance(model, BaseModel))
 
     def test_build_subnet_prune_from_cfg_by_mutator(self):
         mutator_cfg = fileio.load('tests/data/test_registry/subnet.json')
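For reference, the new test_build_subnet_prune_from_cfg boils down to the standalone build call below. The json and pth paths are the repo's test fixtures, and the mmcls config path assumes mmcls is installed:

    from mmengine import fileio

    from mmrazor.registry import MODELS

    mutator_cfg = fileio.load('tests/data/test_registry/subnet.json')
    model_cfg = dict(
        type='mmrazor.sub_model',  # routed to mmrazor's sub_model build_func
        cfg=dict(
            cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py',
            pretrained=False),
        fix_subnet=mutator_cfg,
        mode='mutator',
        init_cfg=dict(
            type='Pretrained',
            checkpoint='tests/data/test_registry/subnet_weight.pth'))
    model = MODELS.build(model_cfg)  # a pruned subnet built as a BaseModel
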
diff --git a/tests/test_structures/test_qconfig.py b/tests/test_structures/test_qconfig.py
index 045b02c83..4730ab6cc 100644
--- a/tests/test_structures/test_qconfig.py
+++ b/tests/test_structures/test_qconfig.py
@@ -4,8 +4,14 @@
 
 import torch
 from mmengine.config import Config
-from torch.ao.quantization import QConfig
+
+try:
+    from torch.ao.quantization import QConfig
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    QConfig = get_placeholder('torch>=1.13')
+
+from mmrazor import digit_version
 from mmrazor.models.fake_quants import register_torch_fake_quants
 from mmrazor.models.observers import register_torch_observers
 from mmrazor.structures import QConfigHander, QSchemeHander
@@ -17,6 +23,9 @@ class TestQSchemeHander(TestCase):
 
     def test_init(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
         # per_channel
         qscheme = QSchemeHander(is_symmetry=True, is_per_channel=True)
         assert qscheme.torch_qscheme is torch.per_channel_symmetric
@@ -34,6 +43,9 @@ def test_init(self):
         assert qscheme.is_symmetric_range is True
 
     def test_to_observer_params(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
         # qdtype = quint8
         ret_params = QSchemeHander(qdtype='quint8').to_observer_params()
         assert ret_params['dtype'] == torch.quint8
@@ -78,6 +90,9 @@ def setUp(self):
         self.qconfig = Config(self.qconfig_dict)
 
     def test_check_qconfig(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
         assert QConfigHander.check_qconfig(self.qconfig_dict) is True
         assert QConfigHander.check_qconfig(self.qconfig) is True
         qconfig_dict = copy.copy(self.qconfig_dict)
@@ -86,6 +101,9 @@ def test_check_qconfig(self):
         assert QConfigHander.check_qconfig(qconfig_dict) is False
 
     def test_init(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
         # test dict init
         qconfig = QConfigHander(self.qconfig_dict)
         assert hasattr(qconfig, 'w_qscheme')
@@ -105,6 +123,9 @@ def test_init(self):
         assert qconfig.a_qscheme.is_per_channel is True
 
     def test_convert(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
         qconfig = QConfigHander(self.qconfig)
         torch_qconfig = qconfig.convert()
         assert isinstance(torch_qconfig, QConfig)
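Distilled from the assertions above, this is the QSchemeHander surface these tests pin down: the handler maps symmetry and granularity flags onto torch qscheme constants and observer keyword arguments. Runs on torch >= 1.13:

    import torch

    from mmrazor.structures import QSchemeHander

    qscheme = QSchemeHander(is_symmetry=True, is_per_channel=True)
    assert qscheme.torch_qscheme is torch.per_channel_symmetric

    ret_params = QSchemeHander(qdtype='quint8').to_observer_params()
    assert ret_params['dtype'] == torch.quint8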