diff --git a/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py index df667c141..d7c9cdf47 100644 --- a/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py @@ -17,12 +17,13 @@ qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1), ) +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth' # noqa: E501 + model = dict( _delete_=True, type='mmrazor.MMArchitectureQuant', architecture=_base_.model, - float_checkpoint='/tmp/humu/mobilenet_v2_batch256_imagenet' + - '_20200708-3b2dc3af.pth', + float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', global_qconfig=global_qconfig, @@ -32,3 +33,8 @@ 'mmcls.models.heads.ClsHead._get_loss', 'mmcls.models.heads.ClsHead._get_predictions' ]))) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) diff --git a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py index 56da13de9..5ba1eec85 100644 --- a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py @@ -19,11 +19,13 @@ qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1), ) +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 + model = dict( _delete_=True, type='mmrazor.MMArchitectureQuant', architecture=_base_.model, - float_checkpoint='/tmp/humu/resnet18_8xb32_in1k_20210831-fbbb1da6.pth', + float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', global_qconfig=global_qconfig, @@ -33,3 +35,5 @@ 'mmcls.models.heads.ClsHead._get_loss', 'mmcls.models.heads.ClsHead._get_predictions' ]))) + +model_wrapper_cfg = dict(type='mmrazor.MMArchitectureQuantDDP', ) diff --git a/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py index 09e103bfc..bd734ee40 100644 --- a/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py @@ -19,11 +19,13 @@ qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1), ) +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa: E501 + model = dict( _delete_=True, type='mmrazor.MMArchitectureQuant', architecture=_base_.model, - float_checkpoint='/tmp/humu/resnet50_8xb32_in1k_20210831-ea4938fc.pth', + float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', global_qconfig=global_qconfig, @@ -33,3 +35,4 @@ 'mmcls.models.heads.ClsHead._get_loss', 'mmcls.models.heads.ClsHead._get_predictions' ]))) +model_wrapper_cfg = dict(type='mmrazor.MMArchitectureQuantDDP', ) diff --git a/configs/quantization/qat/minmax_openvino_resnet18_8xb32_in1k.py b/configs/quantization/qat/minmax_openvino_resnet18_8xb32_in1k.py new file mode 100644 index 000000000..8aa11d6b3 --- /dev/null +++ b/configs/quantization/qat/minmax_openvino_resnet18_8xb32_in1k.py @@ -0,0 +1,65 @@ +_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] + 
+train_dataloader = dict(batch_size=1024)
+
+global_qconfig = dict(
+    w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'),
+    a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'),
+    w_fake_quant=dict(type='mmrazor.FakeQuantize'),
+    a_fake_quant=dict(type='mmrazor.FakeQuantize'),
+    w_qscheme=dict(
+        qdtype='qint8',
+        bit=8,
+        is_symmetry=True,
+        is_symmetric_range=True,
+    ),
+    a_qscheme=dict(
+        qdtype='quint8',
+        bit=8,
+        is_symmetry=True,
+        averaging_constant=0.1,
+    ),
+)
+
+float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth'  # noqa: E501
+
+model = dict(
+    _delete_=True,
+    type='mmrazor.MMArchitectureQuant',
+    architecture=_base_.model,
+    float_checkpoint=float_checkpoint,
+    quantizer=dict(
+        type='mmrazor.OpenVINOQuantizer',
+        global_qconfig=global_qconfig,
+        tracer=dict(
+            type='mmrazor.CustomTracer',
+            skipped_methods=[
+                'mmcls.models.heads.ClsHead._get_loss',
+                'mmcls.models.heads.ClsHead._get_predictions'
+            ])))
+
+optim_wrapper = dict(
+    optimizer=dict(type='SGD', lr=0.004, momentum=0.9, weight_decay=0.0001))
+
+# learning policy
+param_scheduler = dict(
+    _delete_=True,
+    type='CosineAnnealingLR',
+    T_max=100,
+    by_epoch=True,
+    begin=0,
+    end=100)
+
+model_wrapper_cfg = dict(
+    type='mmrazor.MMArchitectureQuantDDP',
+    broadcast_buffers=False,
+    find_unused_parameters=True)
+
+# train, val, test setting
+train_cfg = dict(
+    _delete_=True,
+    type='mmrazor.QATEpochBasedLoop',
+    max_epochs=100,
+    val_interval=1)
+val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop')
+# test_cfg = val_cfg
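As a quick orientation for the three PTQ configs and the new QAT config above: like other OpenMMLab configs, they are driven by the standard MMEngine runner. A minimal launch sketch in Python (the `work_dir` value is a placeholder, not part of this patch):

    from mmengine.config import Config
    from mmengine.runner import Runner

    cfg = Config.fromfile(
        'configs/quantization/qat/minmax_openvino_resnet18_8xb32_in1k.py')
    cfg.work_dir = './work_dirs/minmax_qat_r18'  # placeholder output dir
    runner = Runner.from_cfg(cfg)
    runner.train()  # runs mmrazor.QATEpochBasedLoop, as set in train_cfg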
diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py
index f5cf30f10..9feb3fb53 100644
--- a/mmrazor/models/algorithms/quantization/mm_architecture.py
+++ b/mmrazor/models/algorithms/quantization/mm_architecture.py
@@ -10,7 +10,7 @@
 from mmrazor.models.task_modules.tracer import build_graphmodule
 from mmrazor.registry import MODEL_WRAPPERS, MODELS
 
-from ..base import BaseAlgorithm
+from ..base import BaseAlgorithm, BaseModel
 
 try:
     from torch.ao.quantization import FakeQuantizeBase
@@ -29,35 +29,43 @@ class MMArchitectureQuant(BaseAlgorithm):
     """General quantization.
 
     Args:
-        architecture (dict | :obj:`BaseModel`): The config of
-            :class:`BaseModel` or built model.
-        quantizer (dict | :obj:`BaseModel`): The config of
-            :class:`BaseQuantizer` or built model.
-        export_mode (str): The mode of the model to be exported. Defaults to
-            predict.
-        qmodel_modes (list): The available mode of runner.
-        data_preprocessor (dict | torch.nn.Module | None): The pre-process
+        architecture (Union[Dict, BaseModel]): The config of the model to be
+            quantized, or the built model itself.
+        quantizer (Union[Dict, BaseModel]): The quantizer that adapts the
+            model to different backend types.
+        data_preprocessor (Optional[Dict]): The pre-process
             config of :class:`BaseDataPreprocessor`. Defaults to None.
-        pretrained_ckpt (str, Optional): The path of pretrained checkpoint.
-            Defaults to None.
-        init_cfg (dict): The weight initialized config for
-            :class:`BaseModule`.
+        forward_modes (Tuple): The modes of the forward method in OpenMMLab
+            architectures: 'tensor', 'predict' and 'loss'. Each mode
+            generates a different graph of the quantized model.
+        float_checkpoint (Optional[str]): The path of a pretrained FP
+            checkpoint. Unlike most other tasks, quantization usually starts
+            from a well-trained model, so we recommend setting
+            `float_checkpoint` as the pretrained model. Defaults to None.
+        input_shapes (Tuple): The shape of the dummy input used to trace the
+            model. Defaults to (1, 3, 224, 224).
+        init_cfg (Optional[Dict]): The weight initialization config for
+            :class:`BaseModule`.
+
+    Note:
+        In OpenMMLab architectures, the different `forward_modes` will
+        trace slightly different graphs of the quantized model.
     """
 
     def __init__(self,
-                 architecture,
-                 quantizer,
-                 data_preprocessor=None,
-                 forward_modes=('tensor', 'predict', 'loss'),
+                 architecture: Union[Dict, BaseModel],
+                 quantizer: Union[Dict, BaseModel],
+                 data_preprocessor: Optional[Dict] = None,
+                 forward_modes: Tuple = ('tensor', 'predict', 'loss'),
                  float_checkpoint: Optional[str] = None,
-                 input_shapes=(1, 3, 224, 224),
-                 init_cfg=None):
+                 input_shapes: Tuple = (1, 3, 224, 224),
+                 init_cfg: Optional[Dict] = None):
 
         if data_preprocessor is None:
             data_preprocessor = {}
         # The build process is in MMEngine, so we need to add scope here.
+        # Default to mmcls.ClsDataPreprocessor.
         data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor')
         super().__init__(architecture, data_preprocessor, init_cfg)
+        # If we have a float_checkpoint, load it as the pretrained weights.
        if float_checkpoint:
             _ = load_checkpoint(self.architecture, float_checkpoint)
             self.architecture._is_init = True
@@ -70,7 +78,22 @@
 
         self.sync_qparams('predict')
 
-    def sync_qparams(self, src_mode):
+    def sync_qparams(self, src_mode: str):
+        """Sync all quantize parameters across the different `forward_modes`.
+        We may have more than one forward mode, and each mode generates its
+        own graph. During training only one graph is updated, so its qparams
+        need to be synced to the other graphs.
+
+        Args:
+            src_mode (str): The forward mode to use as the source of qparams.
+
+        Note:
+            The `traverse()` function recursively visits all modules to sync
+            the quantized graphs generated from the different
+            `forward_modes`. The modes ('tensor', 'predict', 'loss') in
+            OpenMMLab architectures produce graphs that differ in subtle
+            ways, which is why they need to be synced here.
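+
+        Example:
+            >>> # Usage sketch added for clarity; `algo` stands for any
+            >>> # built MMArchitectureQuant instance (hypothetical name).
+            >>> # After a training step updates the 'loss' graph:
+            >>> algo.sync_qparams(src_mode='loss')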
+        """
 
         def traverse(module, prefix):
             for name, child in module._modules.items():
@@ -84,10 +107,10 @@ def traverse(module, prefix):
                     if src_param.shape == param.shape:
                         param.data.copy_(src_param)
                     else:
-                        requirs_grad = param.requires_grad
-                        param.requires_grad = False
+                        # requirs_grad = param.requires_grad
+                        # param.requires_grad = False
                         param.resize_(src_param.shape)
-                        param.requires_grad = requirs_grad
+                        # param.requires_grad = requirs_grad
                         param.data.copy_(src_param)
                 for name, buffer in child.named_buffers():
                     buffer_name = f'{child_name}.{name}'
@@ -106,7 +129,31 @@ def traverse(module, prefix):
                 continue
             traverse(self.qmodels[mode], '')
 
-    def _build_qmodels(self, model):
+    def _build_qmodels(self, model: BaseModel):
+        """Build quantized models from the given model.
+
+        Args:
+            model (BaseModel): The given FP (floating-point) model.
+
+        Example:
+            The main body of each graph is identical; only the last one or
+            two ops differ, as shown below.
+
+            self.qmodels['tensor'].graph.print_tabular()
+            opcode       target            args
+            call_module  head.fc           (activation_post_process_38,)
+            output       output            (head_fc,)
+
+            self.qmodels['loss'].graph.print_tabular()
+            opcode       target            args
+            call_method  _get_loss         (head, head_fc, data_samples)
+            output       output            (_get_loss,)
+
+            self.qmodels['predict'].graph.print_tabular()
+            opcode       target            args
+            call_method  _get_predictions  (head, head_fc, data_samples)
+            output       output            (_get_predictions,)
+        """
 
         qmodels = nn.ModuleDict()
 
@@ -137,19 +184,27 @@ def forward(self,
         else:
             return self.architecture(inputs, data_samples, mode)
 
-    def calibrate_step(self, data):
+    def calibrate_step(self, data: Union[Dict, Tuple, List]):
+        """PTQ methods need to be calibrated with calibration data."""
+
         data = self.data_preprocessor(data, False)
         return self._run_forward(data, mode='predict')
 
 
 @MODEL_WRAPPERS.register_module()
 class MMArchitectureQuantDDP(MMDistributedDataParallel):
-    """DDPwapper for GeneralQuant."""
+    """DDP wrapper for MMArchitectureQuant.
+
+    Args:
+        device_ids (Optional[Union[List, int, torch.device]]): Devices on
+            which to run DDP.
+    """
 
     def __init__(self,
                  *,
                  device_ids: Optional[Union[List, int, torch.device]] = None,
                  **kwargs) -> None:
+
         if device_ids is None:
             if os.environ.get('LOCAL_RANK') is not None:
                 device_ids = [int(os.environ['LOCAL_RANK'])]
@@ -159,8 +214,26 @@
         self.module.qmodels = self.module._build_qmodels(
             self.module.architecture)
 
-    def calibrate_step(self, data):
+    def calibrate_step(self, data: Union[Dict, Tuple, List]):
+        """PTQ methods need to be calibrated with calibration data."""
+
         return self.module.calibrate_step(data)
 
-    def sync_qparams(self, src):
+    def sync_qparams(self, src: str):
+        """Same as in `MMArchitectureQuant`: sync all quantize parameters
+        across the different `forward_modes`. Several modes may each
+        generate a graph, but during training only one graph is updated, so
+        its qparams need to be synced to the other graphs.
+
+        Args:
+            src (str): The source forward mode.
+
+        Note:
+            The `traverse()` function recursively visits all modules to sync
+            the quantized graphs generated from the different
+            `forward_modes`. The modes ('tensor', 'predict', 'loss') in
+            OpenMMLab architectures produce graphs that differ in subtle
+            ways, which is why they need to be synced here.
+        """
+        self.module.sync_qparams(src)
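The core of the `traverse()` helper above is a shape-tolerant copy of quantization parameters between the per-mode graphs. A minimal standalone sketch of just that copy step, with illustrative names that are not part of the patch:

    import torch

    # `src`/`dst` stand in for the same fake-quant submodule taken from two
    # different mode graphs, e.g. 'loss' vs. 'predict'.
    src = {'scale': torch.full((8, ), 0.02), 'zero_point': torch.zeros(8)}
    dst = {'scale': torch.ones(1), 'zero_point': torch.zeros(1)}

    for name, param in dst.items():
        src_param = src[name]
        if src_param.shape != param.shape:
            # Per-channel observers can change a qparam's shape during
            # calibration, hence the resize before the copy.
            param.resize_(src_param.shape)
        param.data.copy_(src_param)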
diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py
index b3f2002e5..d0534d361 100644
--- a/mmrazor/models/quantizers/native_quantizer.py
+++ b/mmrazor/models/quantizers/native_quantizer.py
@@ -1,17 +1,23 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
+import torch.nn as nn
+from mmengine.config import Config
 
 try:
     from torch.ao.quantization import enable_fake_quant
     from torch.ao.quantization.fx import prepare
+    from torch.ao.quantization.fx.graph_module import ObservedGraphModule
     from torch.ao.quantization.qconfig_mapping import QConfigMapping
     from torch.ao.quantization.quantize_fx import _fuse_fx
+    from torch.fx.graph_module import GraphModule
     from torch.nn.intrinsic.qat import modules as qat_fused_modules
     from torch.nn.qat import modules as qat_modules
 except ImportError:
     from mmrazor.utils import get_package_placeholder, get_placeholder
+    GraphModule = get_placeholder('torch>=1.13')
+    ObservedGraphModule = get_placeholder('torch>=1.13')
     enable_fake_quant = get_placeholder('torch>=1.13')
     prepare = get_placeholder('torch>=1.13')
     QConfigMapping = get_placeholder('torch>=1.13')
@@ -56,17 +62,43 @@
 
 @MODELS.register_module()
 class NativeQuantizer(BaseQuantizer):
-    """tmp."""
+    """Native quantizer based on PyTorch's FX graph mode quantization.
+
+    Args:
+        global_qconfig (Union[Dict, Config]): Config of the quantization
+            details for weights and activations, including the observer,
+            fake-quant and qscheme.
+        no_observer_modules (Optional[List]): Modules that do not need an
+            observer. Different backends use this to mark the modules whose
+            outputs need no observation.
+        tracer (Dict): Config for the tracer that traces modules for
+            torch.fx.
+        extra_redundant_fakequants (Dict): Extra redundant fakequants to be
+            deleted around the configured modules, functions and methods.
+
+    Examples:
+        >>> global_qconfig = dict(
+        ...     w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'),
+        ...     a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'),
+        ...     w_fake_quant=dict(type='mmrazor.FakeQuantize'),
+        ...     a_fake_quant=dict(type='mmrazor.FakeQuantize'),
+        ...     w_qscheme=dict(
+        ...         qdtype='qint8', bit=8, is_symmetry=True,
+        ...         is_symmetric_range=True),
+        ...     a_qscheme=dict(
+        ...         qdtype='quint8', bit=8, is_symmetry=True,
+        ...         averaging_constant=0.1))
+    """
 
     # backend: 'native'
     # support_w_modes = ['per_tensor', 'per_channel']
     # support_a_modes = ['per_tensor']
 
     def __init__(self,
-                 global_qconfig,
-                 no_observer_modules=None,
-                 tracer=dict(type='CustomTracer'),
-                 extra_redundant_fakequants=dict(
+                 global_qconfig: Union[Dict, Config],
+                 no_observer_modules: Optional[List] = None,
+                 tracer: Dict = dict(type='CustomTracer'),
+                 extra_redundant_fakequants: Dict = dict(
                      extra_module_prev_wo_fakequant=tuple(),
                      extra_module_next_wo_fakequant=tuple(),
                      extra_function_prev_wo_fakequant=tuple(),
@@ -117,7 +149,28 @@
         return ['per_tensor']
 
     def prepare(self, model, graph_module):
-        """tmp."""
+        """Prepare the graph as an ObservedGraphModule.
+
+        Args:
+            graph_module (GraphModule): The traced GraphModule before
+                fusion.
+
+        Returns:
+            ObservedGraphModule: The GraphModule after fusion and
+            observation.
+
+        Notes:
+            After `_fuse_fx()`, the conv, BN and ReLU modules in
+            `graph_module` are fused into the modules in
+            `SUPPORT_QAT_MODULES`; after `prepare()`, the graph becomes
+            observed.
+
+            We keep `is_qat=True` because, in PyTorch, `_fuse_fx()` with
+            `is_qat=False` only fuses modules into `nn.Sequential`. In
+            MMRazor we aim to add more PTQ algorithms, such as AdaRound, to
+            the pipeline; those methods come with extra fake-quant
+            operations that need to be fused into our `SUPPORT_QAT_MODULES`
+            types, and QAT-style fusion is a tricky but workable way to
+            achieve that.
+        """
+
         graph_module = _fuse_fx(
             graph_module=graph_module,
             is_qat=True,
@@ -134,18 +187,41 @@
         return prepared
 
     def post_process_weight_fakequant(self,
-                                      observed_module,
-                                      keep_fake_quant=False):
-        """tmp."""
+                                      observed_module: ObservedGraphModule,
+                                      keep_fake_quant: bool = False):
+        """Apply weight fake-quant for the supported QAT modules.
+
+        Args:
+            observed_module (ObservedGraphModule): The module after fusion
+                and observation.
+            keep_fake_quant (bool, optional): Whether to keep the fake-quant
+                ops; this depends on the backend. Defaults to False.
+
+        Note:
+            `post_process_weight_fakequant()` is necessary so that the
+            `SUPPORT_QAT_MODULES` are converted back to normal modules and
+            BN is really folded into the conv layers.
+        """
 
         def traverse(module):
             for name, child in module.named_children():
+                # Traverse `SUPPORT_QAT_MODULES` recursively.
                 if isinstance(child, SUPPORT_QAT_MODULES):
+                    # Fake-quantize the weight once, in case a PTQ method
+                    # such as AdaRound has attached weight-specific
+                    # operations: quantizing runs them and dequantizing
+                    # bakes the quantization loss into the weight in advance.
                     weight_fakequant = child.weight_fake_quant
                     child.weight.data = weight_fakequant(child.weight.data)
+                    # `to_float()` fuses BN into conv or conv_relu and
+                    # converts the QAT module back to a normal module.
+                    # source url: https://github.com/pytorch/pytorch/blob/master/torch/nn/intrinsic/qat/modules/conv_fused.py  # noqa: E501
                     float_child = child.to_float()
+                    # Whether to keep the fake-quant structure is decided by
+                    # the backend: some backends need it kept explicitly,
+                    # others do not.
+                    # TODO add deploy doc link
                     if keep_fake_quant:
                         for m in float_child.modules():
                             setattr(m, 'qconfig', self.qconfig.convert())
@@ -166,12 +242,24 @@
         observed_module.apply(enable_fake_quant)
         traverse(observed_module)
 
-    def prepare_for_mmdeploy(self, model, dummy_input, checkpoint):
-        """tmp."""
+    def prepare_for_mmdeploy(self, model: nn.Module, dummy_input: Tuple,
+                             checkpoint: Optional[str]):
+        """Prepare the model as an observed model for mmdeploy. Subclasses
+        implement this for their backend."""
         raise NotImplementedError
 
-    def del_redundant_fakequant(self, prepared):
-        """tmp."""
+    def del_redundant_fakequant(self, prepared: GraphModule):
+        """Delete redundant fakequant ops in the prepared model.
+
+        Returns:
+            GraphModule: The prepared model after the redundant fakequant
+            ops have been deleted.
+
+        Notes:
+            Subclasses can configure which redundant nodes to delete by
+            overriding properties, e.g.:
+                @property
+                def module_prev_wo_fakequant(self):
+                    return (torch.nn.ReLU6, torch.nn.Identity)
+        """
         extra_module_prev_wo_fakequant = self.extra_redundant_fakequants.get(
             'extra_module_prev_wo_fakequant', tuple())
         prepared = del_fakequant_before_module(
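The note in `del_redundant_fakequant()` above only hints at the override mechanism, so here is a slightly fuller sketch; the subclass name is hypothetical, and it assumes the property listed in that note exists on `NativeQuantizer` as this patch describes:

    import torch

    from mmrazor.models.quantizers.native_quantizer import NativeQuantizer
    from mmrazor.registry import MODELS

    @MODELS.register_module()
    class ToyBackendQuantizer(NativeQuantizer):
        """Hypothetical quantizer that drops fake-quant around some ops."""

        @property
        def module_prev_wo_fakequant(self):
            # No fake-quant node is kept before these module types.
            return (torch.nn.ReLU6, torch.nn.Identity)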
diff --git a/tests/test_models/test_algorithms/test_mm_architecture.py b/tests/test_models/test_algorithms/test_mm_architecture.py
new file mode 100644
index 000000000..4862bff91
--- /dev/null
+++ b/tests/test_models/test_algorithms/test_mm_architecture.py
@@ -0,0 +1,166 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import shutil
+import tempfile
+from unittest import TestCase
+
+import torch
+import torch.nn as nn
+
+try:
+    from torch.fx import GraphModule
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    GraphModule = get_placeholder('torch>=1.13')
+
+from mmengine.model import BaseModel
+
+from mmrazor import digit_version
+from mmrazor.models.algorithms import MMArchitectureQuant
+from mmrazor.registry import MODELS
+
+
+class BasicBlock(nn.Module):
+
+    def __init__(self, in_channels, out_channels):
+        super(BasicBlock, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.mid_channels = out_channels
+
+        self.norm1 = nn.BatchNorm2d(self.mid_channels)
+        self.norm2 = nn.BatchNorm2d(out_channels)
+        self.conv1 = nn.Conv2d(in_channels, self.mid_channels, 1)
+        self.conv2 = nn.Conv2d(self.mid_channels, out_channels, 1)
+
+        self.relu = nn.ReLU6()
+        self.drop_path = nn.Identity()
+
+    def forward(self, x):
+
+        def _inner_forward(x):
+            identity = x
+
+            out = self.conv1(x)
+            out = self.norm1(out)
+            out = self.relu(out)
+
+            out = self.conv2(out)
+            out = self.norm2(out)
+
+            out = self.drop_path(out)
+
+            out += identity
+
+            return out
+
+        out = _inner_forward(x)
+
+        out = self.relu(out)
+
+        return out
+
+
+@MODELS.register_module()
+class ToyQuantModel(BaseModel):
+
+    def __init__(self):
+        super().__init__()
+        self.stem_layer = nn.Sequential(
+            nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3), nn.ReLU())
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.block = BasicBlock(3, 3)
+        self.block2 = BasicBlock(3, 3)
+        self.gap = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(3, 4)
+
+    def forward(self, x):
+        x = self.stem_layer(x)
+        x = self.maxpool(x)
+        x = self.block(x)
+        x = self.block2(x)
+        x = self.gap(x)
+        x = x.flatten(1)
+        x = self.fc(x)
+        return x
+
+
+class TestMMArchitectureQuant(TestCase):
+
+    def setUp(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+        self.temp_dir = tempfile.mkdtemp()
+        filename = 'fp_model.pth'
+        filename = os.path.join(self.temp_dir, filename)
+        toymodel = ToyQuantModel()
+        torch.save(toymodel.state_dict(), filename)
+
+        global_qconfig = dict(
+            w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'),
+            a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'),
+            w_fake_quant=dict(type='mmrazor.FakeQuantize'),
+            a_fake_quant=dict(type='mmrazor.FakeQuantize'),
+            w_qscheme=dict(
+                qdtype='qint8',
+                bit=8,
+                is_symmetry=True,
+                is_symmetric_range=True),
+            a_qscheme=dict(
+                qdtype='quint8',
+                bit=8,
+                is_symmetry=True,
+                averaging_constant=0.1),
+        )
+        alg_kwargs = dict(
+            type='mmrazor.MMArchitectureQuant',
+            architecture=dict(type='ToyQuantModel'),
+            float_checkpoint=filename,
+            quantizer=dict(
+                type='mmrazor.OpenVINOQuantizer',
+                global_qconfig=global_qconfig,
+                tracer=dict(
+                    type='mmrazor.CustomTracer',
+                    skipped_methods=[
+                        'mmcls.models.heads.ClsHead._get_loss',
+                        'mmcls.models.heads.ClsHead._get_predictions'
+                    ])))
+        self.alg_kwargs = alg_kwargs
+        self.toy_model = MODELS.build(self.alg_kwargs)
+
+    def tearDown(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+        shutil.rmtree(self.temp_dir)
+
+    def test_init(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+        assert isinstance(self.toy_model, MMArchitectureQuant)
+        assert hasattr(self.toy_model, 'quantizer')
+
+    def 
test_sync_qparams(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + mode = self.toy_model.forward_modes[0] + self.toy_model.sync_qparams(mode) + w_loss = self.toy_model.qmodels['loss'].block.conv1.state_dict( + )['weight'] + w_tensor = self.toy_model.qmodels['tensor'].block.conv1.state_dict( + )['weight'] + w_pred = self.toy_model.qmodels['predict'].block.conv1.state_dict( + )['weight'] + assert w_loss.equal(w_pred) + assert w_loss.equal(w_tensor) + + def test_build_qmodels(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + for forward_modes in self.toy_model.forward_modes: + qmodels = self.toy_model.qmodels[forward_modes] + assert isinstance(qmodels, GraphModule) + + def test_calibrate_step(self): + # TODO + pass diff --git a/tests/test_models/test_quantizers/test_native_quantizer.py b/tests/test_models/test_quantizers/test_native_quantizer.py new file mode 100644 index 000000000..afd6011ed --- /dev/null +++ b/tests/test_models/test_quantizers/test_native_quantizer.py @@ -0,0 +1,228 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import collections +from unittest import TestCase + +import torch +import torch.nn as nn + +from mmrazor import digit_version +from mmrazor.models.quantizers.native_quantizer import SUPPORT_QAT_MODULES +from mmrazor.models.task_modules.tracer import CustomTracer +from mmrazor.models.task_modules.tracer.fx.custom_tracer import \ + build_graphmodule +from mmrazor.registry import MODELS +from mmrazor.structures.quantization import BackendConfigs, QConfigHander + +try: + from torch.ao.quantization.fx import prepare + from torch.ao.quantization.fx.graph_module import ObservedGraphModule + from torch.ao.quantization.qconfig_mapping import QConfigMapping + from torch.ao.quantization.quantize_fx import _fuse_fx + from torch.fx import GraphModule +except ImportError: + from mmrazor.utils import get_placeholder + GraphModule = get_placeholder('torch>=1.13') + ObservedGraphModule = get_placeholder('torch>=1.13') + QConfigMapping = get_placeholder('torch>=1.13') + prepare = get_placeholder('torch>=1.13') + _fuse_fx = get_placeholder('torch>=1.13') + + +class BasicBlock(nn.Module): + + def __init__(self, in_channels, out_channels): + super(BasicBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.mid_channels = out_channels + + self.norm1 = nn.BatchNorm2d(self.mid_channels) + self.norm2 = nn.BatchNorm2d(out_channels) + self.conv1 = nn.Conv2d(in_channels, self.mid_channels, 1) + self.conv2 = nn.Conv2d(self.mid_channels, out_channels, 1) + + self.relu = nn.ReLU6() + self.drop_path = nn.Identity() + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + out = self.drop_path(out) + + out += identity + + return out + + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class ToyQuantModel(nn.Module): + + def __init__(self): + super().__init__() + self.stem_layer = nn.Sequential( + nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3), nn.ReLU()) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.block = BasicBlock(3, 3) + self.block2 = BasicBlock(3, 3) + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(3, 4) + + def forward(self, x): + x = self.stem_layer(x) + x = self.maxpool(x) + x = self.block(x) + x = self.block2(x) + x = 
self.gap(x)
+        x = x.flatten(1)
+        x = self.fc(x)
+        return x
+
+
+global_qconfig = dict(
+    w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'),
+    a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'),
+    w_fake_quant=dict(type='mmrazor.FakeQuantize'),
+    a_fake_quant=dict(type='mmrazor.FakeQuantize'),
+    w_qscheme=dict(
+        qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True),
+    a_qscheme=dict(
+        qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1))
+
+no_observer_modules = [
+    'torch.nn.Conv2d',
+]
+
+q_kwargs = dict(
+    type='mmrazor.NativeQuantizer',
+    global_qconfig=global_qconfig,
+    no_observer_modules=no_observer_modules,
+    tracer=dict(type='CustomTracer'),
+)
+
+
+class TestNativeQuantizer(TestCase):
+    """Test the NativeQuantizer built from `q_kwargs` above."""
+
+    def setUp(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+        self.q_kwargs = q_kwargs
+        self.tracer = CustomTracer()
+        self.backend_config = BackendConfigs['native']
+        self.qconfig = QConfigHander(global_qconfig)
+        self.qconfig_mapping = QConfigMapping().set_global(
+            self.qconfig.convert())
+        self.example_inputs = (torch.randn(1, 3, 224, 224), )
+        self.native_quantizer = MODELS.build(self.q_kwargs)
+
+    def tearDown(self):
+        pass
+
+    def swap_ff_with_fxff(self, model):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
+        modules_to_swap = []
+        for name, module in model.named_children():
+            if isinstance(module, torch.ao.nn.quantized.FloatFunctional):
+                modules_to_swap.append(name)
+            else:
+                self.swap_ff_with_fxff(module)
+
+        for name in modules_to_swap:
+            del model._modules[name]
+            model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional()
+
+    def test_init(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+        native_quantizer = MODELS.build(self.q_kwargs)
+        no_ob_dict = collections.OrderedDict()
+        no_ob_dict = no_ob_dict.fromkeys(native_quantizer.no_observer_modules,
+                                         None)
+        assert native_quantizer.qconfig_mapping.object_type_qconfigs == \
+            no_ob_dict
+
+    def test_prepare(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+        toy_model = ToyQuantModel()
+        toy_model.eval()
+
+        self.swap_ff_with_fxff(toy_model)
+        traced_graph = self.tracer.trace(toy_model)
+        graph_module = build_graphmodule(toy_model, traced_graph)
+
+        graph_module = _fuse_fx(
+            graph_module=graph_module,
+            is_qat=True,
+            backend_config=self.backend_config)
+        assert isinstance(graph_module, GraphModule)
+        prepared = prepare(
+            model=graph_module,
+            qconfig_mapping=self.qconfig_mapping,
+            is_qat=True,
+            node_name_to_scope=self.tracer.node_name_to_scope,
+            example_inputs=self.example_inputs,
+            backend_config=self.backend_config)
+        assert isinstance(prepared, ObservedGraphModule)
+
+        prepared = self.native_quantizer.del_redundant_fakequant(prepared)
+        assert isinstance(prepared, GraphModule)
+
+    def test_post_process_weight_fakequant(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+        toy_model = ToyQuantModel()
+        toy_model.eval()
+
+        self.swap_ff_with_fxff(toy_model)
+        traced_graph = self.tracer.trace(toy_model)
+        graph_module = build_graphmodule(toy_model, traced_graph)
+
+        graph_module = _fuse_fx(
+            graph_module=graph_module,
+            is_qat=True,
+            backend_config=self.backend_config)
+        assert isinstance(graph_module, GraphModule)
+        prepared = prepare(
+            model=graph_module,
+            qconfig_mapping=self.qconfig_mapping,
+            is_qat=True,
+            node_name_to_scope=self.tracer.node_name_to_scope,
+            example_inputs=self.example_inputs,
+            backend_config=self.backend_config)
+        assert isinstance(prepared, ObservedGraphModule)
+
+        prepared = self.native_quantizer.del_redundant_fakequant(prepared)
+        assert isinstance(prepared, GraphModule)
+
+        # `post_process_weight_fakequant` mutates its argument in place, so
+        # keep an independent copy for the `keep_fake_quant=True` branch.
+        import copy
+        prepared_no_fq = copy.deepcopy(prepared)
+
+        self.native_quantizer.post_process_weight_fakequant(prepared)
+        for _, child in prepared.named_children():
+            assert not isinstance(child, SUPPORT_QAT_MODULES)
+        self.native_quantizer.post_process_weight_fakequant(
+            prepared_no_fq, True)
+        for _, child in prepared_no_fq.named_children():
+            assert not isinstance(child, SUPPORT_QAT_MODULES)
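The still-missing `test_calibrate_step` above would exercise the PTQ entry point added in mm_architecture.py. A hedged sketch of that flow, assuming `alg_kwargs` is built exactly as in `TestMMArchitectureQuant.setUp()` (the batch format follows ClsDataPreprocessor conventions; everything else is a placeholder):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from mmrazor.registry import MODELS

    algo = MODELS.build(alg_kwargs)  # `alg_kwargs` as in the test above

    # Placeholder calibration batches.
    images = torch.rand(8, 3, 224, 224)
    loader = DataLoader(TensorDataset(images), batch_size=4)

    for (batch, ) in loader:
        algo.calibrate_step({'inputs': batch})

    # Propagate the collected observer statistics to every mode graph.
    algo.sync_qparams(src_mode='predict')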