From a18b991fe71908043a1fa2e902434bebeac7e515 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Thu, 12 Jan 2023 19:19:27 +0800 Subject: [PATCH 01/19] modify ptq pipeline and support lsq --- ...tq_openvino_mbv2_8xb32_in1k_calib32xb32.py | 1 + ...penvino_resnet18_8xb32_in1k_calib32xb32.py | 1 + ...penvino_resnet50_8xb32_in1k_calib32xb32.py | 1 + .../ptq/ptq_openvino_retina_r50_1x_coco.py | 51 ++++ .../lsq_openvino_resnet18_8xb16_cifar10.py | 61 ---- .../qat/lsq_openvino_resnet18_8xb32_in1k.py | 71 +++++ mmrazor/engine/runner/quantization_loops.py | 107 ++++++- .../quantization/mm_architecture.py | 68 ++++- mmrazor/models/fake_quants/__init__.py | 7 +- mmrazor/models/fake_quants/lsq.py | 275 ++++++++++++++++++ mmrazor/models/observers/__init__.py | 6 +- mmrazor/models/observers/lsq.py | 129 ++++++++ .../models/quantizers/academic_quantizer.py | 3 + mmrazor/models/quantizers/base.py | 34 +++ mmrazor/models/quantizers/native_quantizer.py | 28 +- .../models/quantizers/openvino_quantizer.py | 1 + .../task_modules/tracer/fx/custom_tracer.py | 30 ++ 17 files changed, 785 insertions(+), 89 deletions(-) create mode 100644 configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco.py delete mode 100644 configs/quantization/qat/lsq_openvino_resnet18_8xb16_cifar10.py create mode 100644 configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py create mode 100644 mmrazor/models/fake_quants/lsq.py create mode 100644 mmrazor/models/observers/lsq.py diff --git a/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py index d7c9cdf47..97333a282 100644 --- a/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py @@ -26,6 +26,7 @@ float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', + is_qat=False, global_qconfig=global_qconfig, tracer=dict( type='mmrazor.CustomTracer', diff --git a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py index 5ba1eec85..36bea3bc9 100644 --- a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py @@ -28,6 +28,7 @@ float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', + is_qat=False, global_qconfig=global_qconfig, tracer=dict( type='mmrazor.CustomTracer', diff --git a/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py index bd734ee40..3f7740b02 100644 --- a/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py @@ -28,6 +28,7 @@ float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', + is_qat=False, global_qconfig=global_qconfig, tracer=dict( type='mmrazor.CustomTracer', diff --git a/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco.py b/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco.py new file mode 100644 index 000000000..36bd81a0a --- /dev/null +++ b/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco.py @@ -0,0 +1,51 @@ +_base_ = ['mmdet::retinanet/retinanet_r50_fpn_1x_coco.py'] + +train_dataloader = dict(batch_size=32) + +test_cfg = dict( + type='mmrazor.PTQLoop', + 
calibrate_dataloader=train_dataloader, + calibrate_steps=32, +) + +retina = _base_.model +# data_preprocessor = retina.data_preprocessor +float_ckpt = '/mnt/petrelfs/caoweihan.p/ckpt/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth' # noqa: E501 + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), +) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='MMArchitectureQuant', + data_preprocessor=dict( + type='mmdet.DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + architecture=retina, + float_checkpoint=float_ckpt, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + is_qat=False, + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmdet.models.dense_heads.base_dense_head.BaseDenseHead.predict_by_feat', # noqa: E501 + 'mmdet.models.dense_heads.anchor_head.AnchorHead.loss_by_feat', + ]))) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) diff --git a/configs/quantization/qat/lsq_openvino_resnet18_8xb16_cifar10.py b/configs/quantization/qat/lsq_openvino_resnet18_8xb16_cifar10.py deleted file mode 100644 index 8076769a9..000000000 --- a/configs/quantization/qat/lsq_openvino_resnet18_8xb16_cifar10.py +++ /dev/null @@ -1,61 +0,0 @@ -_base_ = ['mmcls::resnet/resnet18_8xb16_cifar10.py'] - -resnet = _base_.model -float_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth' # noqa: E501 - -model = dict( - _delete_=True, - _scope_='mmrazor', - type='MMArchitectureQuant', - architecture=resnet, - float_checkpoint=float_ckpt, - quantizer=dict( - type='OpenvinoQuantizer', - skipped_methods=[ - 'mmcls.models.heads.ClsHead._get_loss', - 'mmcls.models.heads.ClsHead._get_predictions' - ], - qconfig=dict( - qtype='affine', - w_observer=dict(type='mmrazor.LSQObserver'), - a_observer=dict(type='mmrazor.LSQObserver'), - w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), - a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), - w_qscheme=dict( - bit=8, - is_symmetry=True, - is_per_channel=True, - is_pot_scale=False, - ), - a_qscheme=dict( - bit=8, - is_symmetry=False, - is_per_channel=False, - is_pot_scale=False), - ))) - -optim_wrapper = dict( - optimizer=dict(type='SGD', lr=0.004, momentum=0.9, weight_decay=0.0001)) - -# learning policy -param_scheduler = dict( - _delete_=True, - type='CosineAnnealingLR', - T_max=100, - by_epoch=True, - begin=0, - end=100) - -model_wrapper_cfg = dict( - type='mmrazor.MMArchitectureQuantDDP', - broadcast_buffers=False, - find_unused_parameters=True) - -# train, val, test setting -train_cfg = dict( - _delete_=True, - type='mmrazor.QATEpochBasedLoop', - max_epochs=100, - val_interval=1) -val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') -# test_cfg = val_cfg diff --git a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py new file mode 100644 index 000000000..fef3ed1f1 --- /dev/null +++ b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py @@ -0,0 +1,71 @@ +_base_ = 
['mmcls::resnet/resnet18_8xb32_in1k.py'] + +resnet = _base_.model +float_ckpt = '/mnt/petrelfs/caoweihan.p/ckpt/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 + +global_qconfig = dict( + w_observer=dict(type='mmrazor.LSQPerChannelObserver'), + a_observer=dict(type='mmrazor.LSQObserver'), + w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), +) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=resnet, + float_checkpoint=float_ckpt, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + is_qat=True, + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.0001, momentum=0.9, weight_decay=0.0001)) + +# learning policy +param_scheduler = dict( + _delete_=True, + type='CosineAnnealingLR', + T_max=100, + by_epoch=True, + begin=0, + end=100) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) + +# train, val, test setting +train_cfg = dict( + _delete_=True, + type='mmrazor.LSQEpochBasedLoop', + max_epochs=100, + val_interval=1) +val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') +test_cfg = val_cfg + +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + interval=-1, + out_dir='/mnt/petrelfs/caoweihan.p/training_ckpt/lsq')) diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py index e90715910..299549dae 100644 --- a/mmrazor/engine/runner/quantization_loops.py +++ b/mmrazor/engine/runner/quantization_loops.py @@ -18,6 +18,8 @@ from torch.utils.data import DataLoader +from mmrazor.models.fake_quants import (enable_param_learning, + enable_static_estimate, enable_val) from mmrazor.registry import LOOPS @@ -30,13 +32,13 @@ class QATEpochBasedLoop(EpochBasedTrainLoop): dataloader (Dataloader or dict): An iterator to generate one batch of dataset each iteration. max_epochs (int): Total training epochs. - val_begin (int): The epoch that begins validating. - Defaults to 1. + val_begin (int): The epoch that begins validating. Defaults to 1. val_interval (int): Validation interval. Defaults to 1. disable_observer_begin (int): The number of total epochs to update - observers. + observers. Defaults to -1, which means observers are enabled + all the time. freeze_bn_begin (int): The number of total epochs to update batch norm - stats. + stats. Defaults to -1, which means no need to freeze bn. dynamic_intervals (List[Tuple[int, int]], optional): The first element in the tuple is a milestone and the second element is a interval. 
The interval is used after the @@ -50,8 +52,8 @@ def __init__( max_epochs: int, val_begin: int = 1, val_interval: int = 1, - disable_observer_begin: int = 3, - freeze_bn_begin: int = 3, + disable_observer_begin: int = -1, + freeze_bn_begin: int = -1, dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None: super().__init__(runner, dataloader, max_epochs, val_begin, val_interval, dynamic_intervals) @@ -59,14 +61,24 @@ def __init__( self.disable_observer_begin = disable_observer_begin self.freeze_bn_begin = freeze_bn_begin + def prepare_for_run_epoch(self): + """Toggle the state of the observers and fake quantizers before qat + training.""" + self.runner.model.apply(enable_fake_quant) + self.runner.model.apply(enable_observer) + + def prepare_for_val(self): + """Toggle the state of the observers and fake quantizers before + validation.""" + self.runner.model.apply(enable_fake_quant) + self.runner.model.apply(disable_observer) + def run(self) -> torch.nn.Module: """Launch training.""" self.runner.call_hook('before_train') while self._epoch < self._max_epochs: - # state: observer_enabled, fakequant_enabled - self.runner.model.apply(enable_fake_quant) - self.runner.model.apply(enable_observer) + self.prepare_for_run_epoch() self.run_epoch() self._decide_current_val_interval() @@ -74,8 +86,8 @@ def run(self) -> torch.nn.Module: and self._epoch >= self.val_begin and self._epoch % self.val_interval == 0): # observer disabled during evaluation - self.runner.model.apply(enable_fake_quant) - self.runner.model.apply(disable_observer) + self.prepare_for_val() + self.runner.model.sync_qparams(src='loss') self.runner.val_loop.run() self.runner.call_hook('after_train') @@ -99,6 +111,79 @@ def run_epoch(self) -> None: self._epoch += 1 +@LOOPS.register_module() +class LSQEpochBasedLoop(QATEpochBasedLoop): + """`EpochBasedLoop` for `LEARNED STEP SIZE QUANTIZATION` + + Paper: Learned Step Size Quantization. + + Args: + runner (Runner): A reference of runner + dataloader (Dataloader or dict): An iterator to generate one batch of + dataset each iteration. + max_epochs (int): Total training epochs. + val_begin (int): The epoch that begins validating. Defaults to 1. + val_interval (int): Validation interval. Defaults to 1. + freeze_bn_begin (int): The number of total epochs to update batch norm + stats. Defaults to -1, which means no need to freeze bn. + dynamic_intervals (List[Tuple[int, int]], optional): The + first element in the tuple is a milestone and the second + element is a interval. The interval is used after the + corresponding milestone. Defaults to None. 
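For reference, a minimal train_cfg sketch that selects this loop, mirroring the lsq_openvino_resnet18_8xb32_in1k.py config added above (the explicit freeze_bn_begin value is illustrative and simply restates the default):

train_cfg = dict(
    _delete_=True,
    type='mmrazor.LSQEpochBasedLoop',
    max_epochs=100,
    val_interval=1,
    freeze_bn_begin=-1)  # -1: batch norm statistics are never frozen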
+ """ + + def __init__( + self, + runner, + dataloader: Union[DataLoader, Dict], + max_epochs: int, + val_begin: int = 1, + val_interval: int = 1, + freeze_bn_begin: int = -1, + dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None: + super().__init__( + runner, + dataloader, + max_epochs, + val_begin, + val_interval, + freeze_bn_begin=freeze_bn_begin, + dynamic_intervals=dynamic_intervals) + + self.is_first_batch = True + + def prepare_for_run_epoch(self): + """Toggle the state of the observers and fake quantizers before qat + training.""" + pass + + def prepare_for_val(self): + """Toggle the state of the observers and fake quantizers before + validation.""" + self.runner.model.apply(enable_val) + + def run_epoch(self) -> None: + """Iterate one epoch.""" + self.runner.call_hook('before_train_epoch') + self.runner.model.train() + + # TODO freeze bn + if self._epoch >= self.freeze_bn_begin: + self.runner.model.apply(freeze_bn_stats) + + for idx, data_batch in enumerate(self.dataloader): + if self.is_first_batch: + # lsq init + self.is_first_batch = False + self.runner.model.apply(enable_static_estimate) + else: + self.runner.model.apply(enable_param_learning) + self.run_iter(idx, data_batch) + + self.runner.call_hook('after_train_epoch') + self._epoch += 1 + + @LOOPS.register_module() class QATValLoop(ValLoop): """`ValLoop` for `QuantizationAwareTraining` diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index afdd7799c..b58486894 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -12,10 +12,13 @@ from ..base import BaseAlgorithm, BaseModel try: - from torch.ao.quantization import FakeQuantizeBase + from torch.ao.quantization import (FakeQuantizeBase, MinMaxObserver, + PerChannelMinMaxObserver) except ImportError: from mmrazor.utils import get_placeholder FakeQuantizeBase = get_placeholder('torch>=1.13') + MinMaxObserver = get_placeholder('torch>=1.13') + PerChannelMinMaxObserver = get_placeholder('torch>=1.13') LossResults = Dict[str, torch.Tensor] TensorResults = Union[Tuple[torch.Tensor], torch.Tensor] @@ -59,23 +62,43 @@ def __init__(self, init_cfg: Optional[Dict] = None): if data_preprocessor is None: - data_preprocessor = {} - # The build process is in MMEngine, so we need to add scope here. - # Default to mmcls.ClsDataPreprocessor. - data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor') + data_preprocessor = getattr(architecture, 'data_preprocessor', + dict()) super().__init__(architecture, data_preprocessor, init_cfg) - # If we have a float_checkpoint, we load it as pretrain. - if float_checkpoint: - _ = load_checkpoint(self.architecture, float_checkpoint) - self.architecture._is_init = True self.quantizer = MODELS.build(quantizer) self.input_shapes = input_shapes self.forward_modes = forward_modes + # Replace syncbn and _BatchNormXd (in mmengine) with batchnorm2d + self.quantizer.convert_batchnorm2d(self.architecture) + + # If we have a float_checkpoint, we load it as pretrain. + if float_checkpoint: + _ = load_checkpoint(self.architecture, float_checkpoint) + self.architecture._is_init = True + self.qmodels = self._build_qmodels(self.architecture) + self.sync_qparams('tensor') + self.reset_observer_and_fakequant_statistics(self) - self.sync_qparams('predict') + def reset_observer_and_fakequant_statistics(self, model): + """Reset the statistics in observers and fake quantizers. 
+ + The forward computation in `_build_qmodels` can modify the original + statistics in observers and fake quantizers. + """ + for module in model.modules(): + if isinstance(module, MinMaxObserver): + module.reset_min_max_vals() + elif isinstance(module, PerChannelMinMaxObserver): + min_val = torch.rand(0, ) + max_val = torch.rand(0, ) + module.min_val.resize_(min_val.shape).copy_(min_val) + module.max_val.resize_(max_val.shape).copy_(max_val) + elif isinstance(module, FakeQuantizeBase): + module.scale.data = torch.ones_like(module.scale) + module.zero_point.data = torch.zeros_like(module.zero_point) def sync_qparams(self, src_mode: str): """Sync all quantize parameters in different `forward_modes`. We could @@ -106,10 +129,10 @@ def traverse(module, prefix): if src_param.shape == param.shape: param.data.copy_(src_param) else: - # requirs_grad = param.requires_grad - # param.requires_grad = False + requirs_grad = param.requires_grad + param.requires_grad = False param.resize_(src_param.shape) - # param.requires_grad = requirs_grad + param.requires_grad = requirs_grad param.data.copy_(src_param) for name, buffer in child.named_buffers(): buffer_name = f'{child_name}.{name}' @@ -160,6 +183,21 @@ def _build_qmodels(self, model: BaseModel): observed_module = self.quantizer.prepare(model, concrete_args) qmodels[mode] = observed_module + is_training = qmodels['tensor'].training + # Avoid random input changing bn's statistics + qmodels['tensor'].eval() + # Originally, the steps to train a qat model is as follows: + # 1. build qmodels 2. convert the model to ddpmodel 3. forward backward + # The shape of `scale` and `zero_point` can be modified during forward. + # We initialize these parameters with per-tensor mode by default for + # convenience. Their shape will be modified during forward if + # per-channel mode is used. It's hacky. Hence we need to input a + # dummy input to make sure the shape has been modified. + device = next(qmodels.parameters()).device + dummy_input = torch.randn(self.input_shapes).to(device) + qmodels['tensor'](dummy_input, None, 'tensor') + qmodels['tensor'].train(mode=is_training) + return qmodels def forward(self, @@ -183,7 +221,7 @@ def calibrate_step(self, data: Union[Dict, Tuple, List]): @MODEL_WRAPPERS.register_module() class MMArchitectureQuantDDP(MMDistributedDataParallel): - """DDPwapper for GeneralQuant. + """DDPwapper for MMArchitectureQuant. Args: device_ids (Optional[Union[List, int, torch.device]]): devices to run @@ -203,6 +241,8 @@ def __init__(self, # (`model.cuda()`), the buffers in model are different. self.module.qmodels = self.module._build_qmodels( self.module.architecture) + self.module.sync_qparams('tensor') + self.module.reset_observer_and_fakequant_statistics(self) def calibrate_step(self, data: Union[Dict, Tuple, List]): """PTQ method need calibrate by cali data.""" diff --git a/mmrazor/models/fake_quants/__init__.py b/mmrazor/models/fake_quants/__init__.py index 9030660f6..950821210 100644 --- a/mmrazor/models/fake_quants/__init__.py +++ b/mmrazor/models/fake_quants/__init__.py @@ -1,5 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. 
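The helper functions exported from this package below are meant to be applied module-wise. A minimal usage sketch (``model`` is assumed to be any QAT-prepared model), mirroring how LSQEpochBasedLoop.run_epoch and prepare_for_val use them above:

from mmrazor.models.fake_quants import (enable_param_learning,
                                        enable_static_estimate, enable_val)

# First calibration batch: observers statically estimate scale / zero_point.
model.apply(enable_static_estimate)
# Later batches: scale (and optionally zero_point) are learned via backprop.
model.apply(enable_param_learning)
# Validation: keep fake quantization on but freeze the quantization parameters.
model.apply(enable_val)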
from .base import BaseFakeQuantize +from .lsq import (LearnableFakeQuantize, enable_param_learning, + enable_static_estimate, enable_val) from .torch_fake_quants import register_torch_fake_quants -__all__ = ['BaseFakeQuantize', 'register_torch_fake_quants'] +__all__ = [ + 'BaseFakeQuantize', 'register_torch_fake_quants', 'LearnableFakeQuantize', + 'enable_val', 'enable_param_learning', 'enable_static_estimate' +] diff --git a/mmrazor/models/fake_quants/lsq.py b/mmrazor/models/fake_quants/lsq.py new file mode 100644 index 000000000..9b9ca0dd0 --- /dev/null +++ b/mmrazor/models/fake_quants/lsq.py @@ -0,0 +1,275 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.nn.parameter import Parameter + +from mmrazor.registry import MODELS + +try: + from torch.ao.quantization import FakeQuantizeBase +except ImportError: + from mmrazor.utils import get_placeholder + FakeQuantizeBase = get_placeholder('torch>=1.13') + + +def enable_param_learning(mod): + """Enables learning of quantization parameters, if applicable. Example + usage:: + + # model is any PyTorch model model.apply(enable_param_learning) + """ + if isinstance(mod, LearnableFakeQuantize): + mod.enable_param_learning() + + +def enable_static_estimate(mod): + """Enables static observer estimates, if applicable. Example usage:: + + # model is any PyTorch model model.apply(enable_static_estimate) + """ + if isinstance(mod, LearnableFakeQuantize): + mod.enable_static_estimate() + + +def enable_val(mod): + """Enable validation, if applicable. Example usage:: + + # model is any PyTorch model model.apply(enable_val) + """ + if isinstance(mod, LearnableFakeQuantize): + mod.enable_val() + + +@MODELS.register_module() +class LearnableFakeQuantize(FakeQuantizeBase): + """This is an extension of the FakeQuantize module in fake_quantize.py, + which supports learning of the scale and zero point parameters through + backpropagation. + + In addition to the attributes in the original FakeQuantize module, the + LearnableFakeQuantize module also includes the following attributes to + support quantization parameter learning. + + * :attr:`fake_quant_enabled` defines the flag for enabling fake + quantization on the output. + + * :attr:`static_enabled` defines the flag for using observer's static + estimation for scale and zero point. + + * :attr:`learning_enabled` defines the flag for enabling backpropagation + for scale and zero point. + + Args: + observer (module): Module for observing statistics on input tensors and + calculating scale and zero-point. + quant_min (int): Minimum quantization value. If unspecified, it will + follow the 8-bit setup. + quant_max (int): Maximum quantization value. If unspecified, it will + follow the 8-bit setup. + scale (float): The initial value of the floating-point scale factor. + Defaults to 1. + zero_point (float): The initial value of the floating-point zero-point. + Defaults to 0. + use_grad_scaling (bool): Whether the gradients for scale and zero point + are normalized by the constant, which is proportional to the square + root of the number of elements in the tensor. The related + literature justifying the use of this particular constant can be + found here: https://openreview.net/pdf?id=rkgO66VKDS. Defaults to + True. + zero_point_trainable (bool): Whether the zero_point is trainable. + Defaults to False. + observer_kwargs (dict | optional): Arguments for the observer module. 
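A minimal construction sketch, mirroring the unit tests added later in this patch (torch>=1.13 is assumed, as elsewhere in this module):

import torch
from torch.ao.quantization import MovingAverageMinMaxObserver

from mmrazor.models import LearnableFakeQuantize

fq_cls = LearnableFakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    zero_point_trainable=True)
fq_module = fq_cls()
out = fq_module(torch.rand(2, 8))  # fake-quantized output with the same shape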
+ """ + + def __init__(self, + observer, + quant_min=0, + quant_max=255, + scale=1., + zero_point=0., + use_grad_scaling=True, + zero_point_trainable=False, + **observer_kwargs): + super(LearnableFakeQuantize, self).__init__() + assert quant_min < quant_max, \ + 'quant_min must be strictly less than quant_max.' + self.quant_min = quant_min + self.quant_max = quant_max + # also pass quant_min and quant_max to observer + observer_kwargs['quant_min'] = quant_min + observer_kwargs['quant_max'] = quant_max + self.use_grad_scaling = use_grad_scaling + + self.scale = Parameter(torch.tensor([scale])) + self.zero_point_trainable = zero_point_trainable + if zero_point_trainable: + self.zero_point = Parameter(torch.tensor([zero_point])) + else: + self.register_buffer('zero_point', torch.tensor([zero_point])) + + self.activation_post_process = observer(**observer_kwargs) + assert \ + torch.iinfo(self.activation_post_process.dtype).min <= quant_min, \ + 'quant_min out of bound' + assert \ + quant_max <= torch.iinfo(self.activation_post_process.dtype).max, \ + 'quant_max out of bound' + self.dtype = self.activation_post_process.dtype + self.qscheme = self.activation_post_process.qscheme + self.ch_axis = self.activation_post_process.ch_axis \ + if hasattr(self.activation_post_process, 'ch_axis') else -1 + self.register_buffer('fake_quant_enabled', + torch.tensor([1], dtype=torch.uint8)) + self.register_buffer('static_enabled', + torch.tensor([1], dtype=torch.uint8)) + self.register_buffer('learning_enabled', + torch.tensor([0], dtype=torch.uint8)) + + bitrange = torch.tensor(quant_max - quant_min + 1).double() + self.bitwidth = int(torch.log2(bitrange).item()) + self.register_buffer('eps', + torch.tensor([torch.finfo(torch.float32).eps])) + + @torch.jit.export + def enable_param_learning(self): + """Enables learning of quantization parameters and disables static + observer estimates. + + Forward path returns fake quantized X. + """ + self.toggle_qparam_learning(enabled=True) \ + .toggle_fake_quant(enabled=True) \ + .toggle_observer_update(enabled=False) + return self + + @torch.jit.export + def enable_static_estimate(self): + """Enables static observer estimates and disables learning of + quantization parameters. + + Forward path returns fake quantized X. + """ + self.toggle_qparam_learning(enabled=False) \ + .toggle_fake_quant(enabled=True) \ + .toggle_observer_update(enabled=True) + + @torch.jit.export + def enable_val(self): + """Disables static observer accumulating data from input and doesn't + update the quantization parameters. + + Forward path returns fake quantized X. + """ + self.toggle_qparam_learning(enabled=False) \ + .toggle_fake_quant(enabled=True) \ + .toggle_observer_update(enabled=False) + + @torch.jit.export + def enable_static_observation(self): + """Enables static observer accumulating data from input but doesn't + update the quantization parameters. + + Forward path returns the original X. 
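Taken together, these four mode switches differ only in how they set the three flag buffers registered in __init__ (values match the unit tests added later in this patch):

    mode                        learning_enabled  static_enabled  fake_quant_enabled
    enable_param_learning              1                 0                  1
    enable_static_estimate             0                 1                  1
    enable_val                         0                 0                  1
    enable_static_observation          0                 1                  0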
+ """ + self.toggle_qparam_learning(enabled=False) \ + .toggle_fake_quant(enabled=False) \ + .toggle_observer_update(enabled=True) + + @torch.jit.export + def toggle_observer_update(self, enabled=True): + """Toggles whether static observer accumulates data from input.""" + self.static_enabled[0] = int(enabled) + return self + + @torch.jit.export + def enable_observer(self, enabled=True): + """Enables static observer accumulating data from input.""" + self.toggle_observer_update(enabled) + + @torch.jit.export + def toggle_qparam_learning(self, enabled=True): + """Toggles whether the quantization parameters are learnable.""" + self.learning_enabled[0] = int(enabled) + self.scale.requires_grad = enabled + if self.zero_point_trainable: + self.zero_point.requires_grad = enabled + return self + + @torch.jit.export + def toggle_fake_quant(self, enabled=True): + """Toggles whether the fake quantization is enabled.""" + self.fake_quant_enabled[0] = int(enabled) + return self + + @torch.jit.export + def observe_quant_params(self): + """Shows the quantization parameters.""" + print('LearnableFakeQuantize Scale: {}'.format(self.scale.detach())) + print('LearnableFakeQuantize Zero Point: {}'.format( + self.zero_point.detach())) + + @torch.jit.export + def calculate_qparams(self): + """Calculate the quantization parameters.""" + self.scale.data.clamp_(min=self.eps.item()) + scale = self.scale.detach() + zero_point = self.zero_point.detach().round().clamp( + self.quant_min, self.quant_max).long() + return scale, zero_point + + def forward(self, X): + """Forward computation. + + Forward path returns fake quantized X. + """ + if self.static_enabled[0] == 1: + self.activation_post_process(X.detach()) + _scale, _zero_point = \ + self.activation_post_process.calculate_qparams() + _scale = _scale.to(self.scale.device) + _zero_point = _zero_point.to(self.zero_point.device) + + if self.qscheme in (torch.per_channel_symmetric, + torch.per_channel_affine): + self.scale.data = torch.ones_like(_scale) + self.zero_point.data = torch.zeros_like(_zero_point.float()) + + self.scale.data.copy_(_scale) + self.zero_point.data.copy_(_zero_point) + else: + self.scale.data.clamp_(min=self.eps.item()) + + if self.fake_quant_enabled[0] == 1: + + if self.use_grad_scaling: + grad_factor = 1.0 / (X.numel() * self.quant_max)**0.5 + else: + grad_factor = 1.0 + if self.qscheme in (torch.per_channel_symmetric, + torch.per_channel_affine): + X = torch._fake_quantize_learnable_per_channel_affine( + X, self.scale, self.zero_point, self.ch_axis, + self.quant_min, self.quant_max, grad_factor) + else: + if not (self.quant_min <= self.zero_point <= self.quant_max): + print(self.quant_min, self.zero_point, self.quant_max) + X = torch._fake_quantize_learnable_per_tensor_affine( + X, self.scale, self.zero_point, self.quant_min, + self.quant_max, grad_factor) + + return X + + @torch.jit.export + def extra_repr(self): + """The printable representational string.""" + repr_str = self.__class__.__name__ + repr_str += '(' + repr_str += f'static_enabled={self.static_enabled}, ' + repr_str += f'fake_quant_enabled={self.fake_quant_enabled}' + repr_str += f'quant_min={self.activation_post_process.quant_min}' + repr_str += f'quant_max={self.activation_post_process.quant_max}' + repr_str += f'dtype={self.dtype}' + repr_str += f'qscheme={self.qscheme}' + repr_str += f'scale={self.scale}' + repr_str += f'zero_point={self.zero_point}' + repr_str += ')' + return repr_str diff --git a/mmrazor/models/observers/__init__.py b/mmrazor/models/observers/__init__.py 
index c82f902f5..84d1677dd 100644 --- a/mmrazor/models/observers/__init__.py +++ b/mmrazor/models/observers/__init__.py @@ -1,5 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. from .base import BaseObserver +from .lsq import LSQObserver, LSQPerChannelObserver from .torch_observers import register_torch_observers -__all__ = ['BaseObserver', 'register_torch_observers'] +__all__ = [ + 'BaseObserver', 'register_torch_observers', 'LSQObserver', + 'LSQPerChannelObserver' +] diff --git a/mmrazor/models/observers/lsq.py b/mmrazor/models/observers/lsq.py new file mode 100644 index 000000000..ccab3b0e6 --- /dev/null +++ b/mmrazor/models/observers/lsq.py @@ -0,0 +1,129 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.distributed as dist + +from mmrazor.registry import MODELS + +try: + from torch.ao.quantization.observer import (MinMaxObserver, + PerChannelMinMaxObserver) +except ImportError: + from mmrazor.utils import get_placeholder + MinMaxObserver = get_placeholder('torch>=1.13') + PerChannelMinMaxObserver = get_placeholder('torch>=1.13') + + +def sync_tensor(tensor): + """Synchronize the target tensor during distributed training.""" + if torch.distributed.is_initialized() and tensor.is_cuda: + tensor.data = tensor.data / dist.get_world_size() + dist.all_reduce(tensor.data) + return tensor + + +class LSQObserverMixIn: + """A mixin class for LSQObserver which can provide the initialized + floating-point scale factor.""" + + def __init__(self): + self.tensor_norm = None + + @torch.jit.export + def _calculate_scale(self): + """Calculate the initialized floating-point scale factor. + + Each layer of weights and each layer of activations has a distinct step + size, represented as a fp32 value, initialized to 2<|v|> / sqrt(Q_p), + computed on either the initial weights values or the first batch of + activations, respectively. + """ + scale = 2 * self.tensor_norm / math.sqrt(self.quant_max) + sync_tensor(scale) + return scale + + +@MODELS.register_module() +class LSQObserver(MinMaxObserver, LSQObserverMixIn): + """LSQ observer. + + Paper: Learned Step Size Quantization. + """ + + def __init__(self, *args, **kwargs): + MinMaxObserver.__init__(self, *args, **kwargs) + LSQObserverMixIn.__init__(self) + + def forward(self, x_orig): + """Records the running minimum, maximum and tensor_norm of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + self.tensor_norm = x.abs().mean() + min_val_cur, max_val_cur = torch.aminmax(x) + min_val = torch.min(min_val_cur, self.min_val) + max_val = torch.max(max_val_cur, self.max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + @torch.jit.export + def calculate_qparams(self): + """Calculates the quantization parameters.""" + _, zero_point = MinMaxObserver.calculate_qparams(self) + scale = LSQObserverMixIn._calculate_scale(self) + return scale, zero_point + + +@MODELS.register_module() +class LSQPerChannelObserver(PerChannelMinMaxObserver, LSQObserverMixIn): + """LSQ per-channel observer. + + Paper: Learned Step Size Quantization. 
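As a worked sketch of the initialization rule in _calculate_scale above, specialized to the per-channel case handled by this observer (the weight tensor, ch_axis=0 and the signed 8-bit quant_max=127 are illustrative):

import math

import torch

w = torch.randn(64, 3, 3, 3)                   # e.g. a conv weight, ch_axis=0
tensor_norm = w.flatten(1).abs().mean(dim=1)   # per-channel <|v|>
scale_init = 2 * tensor_norm / math.sqrt(127)  # 2<|v|> / sqrt(Q_p), one scale per channel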
+ """ + + def __init__(self, *args, **kwargs): + PerChannelMinMaxObserver.__init__(self, *args, **kwargs) + LSQObserverMixIn.__init__(self) + + def forward(self, x_orig): + """Records the per-channel running minimum, maximum and tensor_norm of + ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + min_val = self.min_val + max_val = self.max_val + x_dim = x.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(self.min_val.dtype) + y = torch.flatten(y, start_dim=1) + + self.tensor_norm = y.abs().mean(1) + + if min_val.numel() == 0 or max_val.numel() == 0: + min_val, max_val = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + min_val = torch.min(min_val_cur, min_val) + max_val = torch.max(max_val_cur, max_val) + self.min_val.resize_(min_val.shape) + self.max_val.resize_(max_val.shape) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + @torch.jit.export + def calculate_qparams(self): + """Calculates the quantization parameters.""" + _, zero_point = PerChannelMinMaxObserver.calculate_qparams(self) + scale = LSQObserverMixIn._calculate_scale(self) + return scale, zero_point diff --git a/mmrazor/models/quantizers/academic_quantizer.py b/mmrazor/models/quantizers/academic_quantizer.py index a6cfc257c..c8824e512 100644 --- a/mmrazor/models/quantizers/academic_quantizer.py +++ b/mmrazor/models/quantizers/academic_quantizer.py @@ -103,6 +103,9 @@ def prepare(self, model, concrete_args=None): setattr(graph_module, attr_name, getattr(model, attr_name)) fuse_custom_config = FuseCustomConfig().set_preserved_attributes( preserved_attributes) + + self.sync_module_training_mode(graph_module) + graph_module = _fuse_fx( graph_module=graph_module, is_qat=True, diff --git a/mmrazor/models/quantizers/base.py b/mmrazor/models/quantizers/base.py index 866199735..78c8163c7 100644 --- a/mmrazor/models/quantizers/base.py +++ b/mmrazor/models/quantizers/base.py @@ -3,7 +3,9 @@ from typing import Dict import torch +import torch.nn as nn from mmengine.model import BaseModule +from mmengine.model.utils import _BatchNormXd from mmrazor.registry import TASK_UTILS @@ -24,6 +26,38 @@ def __init__(self, tracer: Dict): super().__init__() self.tracer = TASK_UTILS.build(tracer) + def sync_module_training_mode(self, model, mode=True): + """Synchronize the training modes. + + Note that modes of conv and bn must be the same during ``_fuse_fx``. + """ + for module in model.modules(): + module.training = mode + return + + @staticmethod + def convert_batchnorm2d(model): + """Helper function to convert all :attr:`_BatchNormXd` layers and + :class:`torch.nn.SyncBatchNorm` layers in the model to + :class:`torch.nn.BatchNorm2d` layers. 
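A minimal usage sketch of this helper (it is assumed here that the enclosing class is BaseQuantizer and that it is importable from mmrazor.models.quantizers):

import torch.nn as nn

from mmrazor.models.quantizers import BaseQuantizer

toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.SyncBatchNorm(8), nn.ReLU())
BaseQuantizer.convert_batchnorm2d(toy)
assert isinstance(toy[1], nn.BatchNorm2d)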
+ """ + # todo: Convert all `_BatchNormXd` and `SyncBatchNorm` + # layers to `BatchNorm2d` layers but they may be :attr:`BatchNorm*D` + # layers + module_checklist = [nn.modules.batchnorm.SyncBatchNorm, _BatchNormXd] + + def traverse(module: nn.Module): + for child_name, child in module.named_children(): + if isinstance(child, tuple(module_checklist)): + bn = nn.BatchNorm2d(child.num_features, child.eps, + child.momentum, child.affine, + child.track_running_stats) + setattr(module, child_name, bn) + else: + traverse(child) + + traverse(model) + @abstractmethod def prepare(self, model): """Prepare for quantizing model, which usually includes as follows: diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py index b5de1c028..653cbc931 100644 --- a/mmrazor/models/quantizers/native_quantizer.py +++ b/mmrazor/models/quantizers/native_quantizer.py @@ -9,7 +9,9 @@ from torch.ao.quantization import enable_fake_quant from torch.ao.quantization.fx import prepare from torch.ao.quantization.fx.graph_module import ObservedGraphModule - from torch.ao.quantization.qconfig_mapping import QConfigMapping + from torch.ao.quantization.qconfig_mapping import ( + _FIXED_QPARAMS_OP_TO_OBSERVER, FixedQParamsFakeQuantize, QConfig, + QConfigMapping, default_weight_fake_quant) from torch.ao.quantization.quantize_fx import _fuse_fx from torch.fx.graph_module import GraphModule from torch.nn.intrinsic.qat import modules as qat_fused_modules @@ -24,6 +26,10 @@ _fuse_fx = get_placeholder('torch>=1.13') qat_fused_modules = get_package_placeholder('torch>=1.13') qat_modules = get_package_placeholder('torch>=1.13') + _FIXED_QPARAMS_OP_TO_OBSERVER = get_package_placeholder('torch>=1.13') + FixedQParamsFakeQuantize = get_package_placeholder('torch>=1.13') + QConfig = get_package_placeholder('torch>=1.13') + default_weight_fake_quant = get_package_placeholder('torch>=1.13') from mmrazor import digit_version from mmrazor.models.task_modules.tracer import build_graphmodule @@ -125,6 +131,24 @@ def __init__(self, self.qconfig_mapping.set_object_type(mod, None) else: self.no_observer_modules = no_observer_modules + + fixed_qparams_observer_to_qconfig = {} + for fixed_qparams_op, observer in _FIXED_QPARAMS_OP_TO_OBSERVER.items( + ): + if observer in fixed_qparams_observer_to_qconfig: + fixed_qparams_qconfig = fixed_qparams_observer_to_qconfig[ + observer] + else: + activation = FixedQParamsFakeQuantize.with_args( + observer=observer) + + fixed_qparams_qconfig = QConfig( + activation=activation, weight=default_weight_fake_quant) + fixed_qparams_observer_to_qconfig[ + observer] = fixed_qparams_qconfig + self.qconfig_mapping.set_object_type(fixed_qparams_op, + fixed_qparams_qconfig) + self.backend_config = BackendConfigs[self.backend] self.example_inputs = (torch.randn(1, 3, 224, 224), ) @@ -169,6 +193,8 @@ def prepare(self, model, concrete_args=None): self.swap_ff_with_fxff(model) traced_graph = self.tracer.trace(model, concrete_args=concrete_args) graph_module = build_graphmodule(model, traced_graph) + + self.sync_module_training_mode(graph_module) graph_module = _fuse_fx( graph_module=graph_module, is_qat=True, diff --git a/mmrazor/models/quantizers/openvino_quantizer.py b/mmrazor/models/quantizers/openvino_quantizer.py index cb7d3084b..f8a25bd56 100644 --- a/mmrazor/models/quantizers/openvino_quantizer.py +++ b/mmrazor/models/quantizers/openvino_quantizer.py @@ -59,6 +59,7 @@ def prepare_for_mmdeploy(self, 3. 
post process weight fakequant for exporting .onnx that meet the backend's requirement. """ + self.convert_batchnorm2d(model) observed_model = self.prepare(model) if dummy_input is not None: observed_model(torch.randn(dummy_input)) diff --git a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py index a3cff1167..3be211f68 100644 --- a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py +++ b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import functools +from copy import deepcopy from types import FunctionType from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union @@ -152,6 +153,33 @@ def _get_attrs(target, attrs): return module_dict +def duplicate_reused_nodes(graph: Graph, modules: Dict[str, Any] = {}): + """Deepcopy the shared modules (e.g. shared detection head in RetinaNet) to + make sure modules can be fused correctly. + + Modified from https://github.com/ModelTC/MQBench/blob/main/mqbench/prepare_by_platform.py # noqa: E501 + """ + _dup_prefix = '_dup' + target_dict = dict() + dup_modules = dict() + for node in graph.nodes: + if node.op == 'call_module': + if node.target not in target_dict: + target_dict[node.target] = [node] + else: + target_dict[node.target].append(node) + for key in target_dict: + if len(target_dict[key]) > 1: + for idx, node in enumerate(target_dict[key]): + if idx == 0: + continue + module = deepcopy(modules[node.target]) + node.target += _dup_prefix + str(idx) + dup_modules[node.target] = module + graph.lint() + return graph, dup_modules + + def build_graphmodule(model: torch.nn.Module, fx_graph, name: str = 'GraphModule'): @@ -180,7 +208,9 @@ def build_graphmodule(model: torch.nn.Module, """ modules = dict(model.named_modules()) module_dict = _prepare_module_dict(model, fx_graph) + fx_graph, duplicated_modules = duplicate_reused_nodes(fx_graph, modules) modules.update(module_dict) + modules.update(duplicated_modules) return GraphModule(modules, fx_graph, name) From 7fa425e8aad3c69526b8e169846bc9681144803f Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Thu, 12 Jan 2023 19:43:47 +0800 Subject: [PATCH 02/19] use placeholder --- .../models/task_modules/tracer/fx/custom_tracer.py | 2 ++ mmrazor/models/task_modules/tracer/fx/graph_utils.py | 12 +++++++----- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py index 3be211f68..afbafaf03 100644 --- a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py +++ b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py @@ -53,6 +53,7 @@ class is traced with CustomTracer, the decorated method will be as a leaf def __init__(self, method: FunctionType): self.method = method + self.instances: Dict = dict() self.owner = None def __set_name__(self, owner, name): @@ -302,6 +303,7 @@ def call_method(self, m: torch.nn.Module, name: str, method: Callable, kwargs (Dict): kwargs of the module callsite Return: + The return value from the Module call. In the case that a ``call_module`` node was emitted, this is a ``Proxy`` value. 
Otherwise, it is whatever value was returned from the ``Module`` diff --git a/mmrazor/models/task_modules/tracer/fx/graph_utils.py b/mmrazor/models/task_modules/tracer/fx/graph_utils.py index 5e3ddc2f4..ca1291711 100644 --- a/mmrazor/models/task_modules/tracer/fx/graph_utils.py +++ b/mmrazor/models/task_modules/tracer/fx/graph_utils.py @@ -6,9 +6,11 @@ try: from torch.ao.quantization.fake_quantize import FakeQuantizeBase + from torch.fx import Node except ImportError: from mmrazor.utils import get_placeholder FakeQuantizeBase = get_placeholder('torch>=1.13') + Node = get_placeholder('torch>=1.13') def _get_attrs(target: torch.nn.Module, attr: str) -> Any: @@ -61,11 +63,11 @@ def recursive_find_erased_nodes(node, prepared_model): nodes_to_erase = [] for prev_node in node.args: - if isinstance(prev_node, torch.fx.Node): + if isinstance(prev_node, Node): nodes_to_erase.extend( recursive_find_erased_nodes(prev_node, prepared_model)) for prev_node in node.kwargs.values(): - if isinstance(prev_node, torch.fx.Node): + if isinstance(prev_node, Node): nodes_to_erase.extend( recursive_find_erased_nodes(prev_node, prepared_model)) @@ -94,7 +96,7 @@ def del_fakequant_before_op(prepared_model, new_graph = copy.deepcopy(prepared_model.graph) for node in new_graph.nodes: if node.op in target_ops: - nodes_to_erase: List[torch.fx.Node] = recursive_find_erased_nodes( + nodes_to_erase: List[Node] = recursive_find_erased_nodes( node, prepared_model) for to_erase in nodes_to_erase: assert to_erase.op == 'call_module' and isinstance( @@ -172,7 +174,7 @@ def del_fakequant_before_method(prepared_model, new_graph = copy.deepcopy(prepared_model.graph) for node in new_graph.nodes: if node.op == 'call_method' and node.target in method_patterns: - nodes_to_erase: List[torch.fx.Node] = recursive_find_erased_nodes( + nodes_to_erase: List[Node] = recursive_find_erased_nodes( node, prepared_model) for to_erase in nodes_to_erase: assert to_erase.op == 'call_module' and isinstance( @@ -251,7 +253,7 @@ def del_fakequant_before_function(prepared_model, new_graph = copy.deepcopy(prepared_model.graph) for node in new_graph.nodes: if node.op == 'call_function' and node.target in function_patterns: - nodes_to_erase: List[torch.fx.Node] = recursive_find_erased_nodes( + nodes_to_erase: List[Node] = recursive_find_erased_nodes( node, prepared_model) for to_erase in nodes_to_erase: assert to_erase.op == 'call_module' and isinstance( From 788cc56fba2d54e3c3da4397ee14f399043f6bbb Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Fri, 13 Jan 2023 18:18:19 +0800 Subject: [PATCH 03/19] fix lsq && quantloop --- mmrazor/engine/__init__.py | 10 ++++++---- mmrazor/engine/runner/__init__.py | 5 +++-- mmrazor/engine/runner/quantization_loops.py | 13 +++++++------ mmrazor/models/fake_quants/lsq.py | 20 +++++++++----------- 4 files changed, 25 insertions(+), 23 deletions(-) diff --git a/mmrazor/engine/__init__.py b/mmrazor/engine/__init__.py index ced74bc92..5c28d1160 100644 --- a/mmrazor/engine/__init__.py +++ b/mmrazor/engine/__init__.py @@ -3,14 +3,16 @@ from .optimizers import SeparateOptimWrapperConstructor from .runner import (AutoSlimGreedySearchLoop, DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop, EvolutionSearchLoop, - GreedySamplerTrainLoop, PTQLoop, QATEpochBasedLoop, - SelfDistillValLoop, SingleTeacherDistillValLoop, - SlimmableValLoop, SubnetValLoop) + GreedySamplerTrainLoop, LSQEpochBasedLoop, PTQLoop, + QATEpochBasedLoop, QATValLoop, SelfDistillValLoop, + SingleTeacherDistillValLoop, SlimmableValLoop, + 
SubnetValLoop) __all__ = [ 'SeparateOptimWrapperConstructor', 'DumpSubnetHook', 'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop', 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', 'GreedySamplerTrainLoop', 'EstimateResourcesHook', 'SelfDistillValLoop', - 'AutoSlimGreedySearchLoop', 'SubnetValLoop', 'PTQLoop', 'QATEpochBasedLoop' + 'AutoSlimGreedySearchLoop', 'SubnetValLoop', 'PTQLoop', + 'QATEpochBasedLoop', 'LSQEpochBasedLoop', 'QATValLoop' ] diff --git a/mmrazor/engine/runner/__init__.py b/mmrazor/engine/runner/__init__.py index 2ca6c0dbb..5fe2fd524 100644 --- a/mmrazor/engine/runner/__init__.py +++ b/mmrazor/engine/runner/__init__.py @@ -4,7 +4,8 @@ from .distill_val_loop import SelfDistillValLoop, SingleTeacherDistillValLoop from .evolution_search_loop import EvolutionSearchLoop from .iteprune_val_loop import ItePruneValLoop -from .quantization_loops import PTQLoop, QATEpochBasedLoop +from .quantization_loops import (LSQEpochBasedLoop, PTQLoop, QATEpochBasedLoop, + QATValLoop) from .slimmable_val_loop import SlimmableValLoop from .subnet_sampler_loop import GreedySamplerTrainLoop from .subnet_val_loop import SubnetValLoop @@ -14,5 +15,5 @@ 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', 'GreedySamplerTrainLoop', 'SubnetValLoop', 'SelfDistillValLoop', 'ItePruneValLoop', 'AutoSlimGreedySearchLoop', 'QATEpochBasedLoop', - 'PTQLoop' + 'PTQLoop', 'LSQEpochBasedLoop', 'QATValLoop' ] diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py index 299549dae..df0f4f76d 100644 --- a/mmrazor/engine/runner/quantization_loops.py +++ b/mmrazor/engine/runner/quantization_loops.py @@ -73,7 +73,7 @@ def prepare_for_val(self): self.runner.model.apply(enable_fake_quant) self.runner.model.apply(disable_observer) - def run(self) -> torch.nn.Module: + def run(self): """Launch training.""" self.runner.call_hook('before_train') @@ -87,7 +87,7 @@ def run(self) -> torch.nn.Module: and self._epoch % self.val_interval == 0): # observer disabled during evaluation self.prepare_for_val() - self.runner.model.sync_qparams(src='loss') + self.runner.model.sync_qparams(src_mode='loss') self.runner.val_loop.run() self.runner.call_hook('after_train') @@ -97,11 +97,12 @@ def run_epoch(self) -> None: self.runner.call_hook('before_train_epoch') self.runner.model.train() - # TODO freeze bn - if self._epoch >= self.disable_observer_begin: + # The initialized _epoch equals to 0 so _epoch + 1 + # equal to the current epoch + if self._epoch + 1 >= self.disable_observer_begin: self.runner.model.apply(disable_observer) - if self._epoch >= self.freeze_bn_begin: + if self._epoch + 1 >= self.freeze_bn_begin: self.runner.model.apply(freeze_bn_stats) for idx, data_batch in enumerate(self.dataloader): @@ -168,7 +169,7 @@ def run_epoch(self) -> None: self.runner.model.train() # TODO freeze bn - if self._epoch >= self.freeze_bn_begin: + if self._epoch + 1 >= self.freeze_bn_begin: self.runner.model.apply(freeze_bn_stats) for idx, data_batch in enumerate(self.dataloader): diff --git a/mmrazor/models/fake_quants/lsq.py b/mmrazor/models/fake_quants/lsq.py index 9b9ca0dd0..270140b85 100644 --- a/mmrazor/models/fake_quants/lsq.py +++ b/mmrazor/models/fake_quants/lsq.py @@ -261,15 +261,13 @@ def forward(self, X): @torch.jit.export def extra_repr(self): """The printable representational string.""" - repr_str = self.__class__.__name__ - repr_str += '(' - repr_str += f'static_enabled={self.static_enabled}, ' - repr_str += 
f'fake_quant_enabled={self.fake_quant_enabled}' - repr_str += f'quant_min={self.activation_post_process.quant_min}' - repr_str += f'quant_max={self.activation_post_process.quant_max}' - repr_str += f'dtype={self.dtype}' - repr_str += f'qscheme={self.qscheme}' - repr_str += f'scale={self.scale}' - repr_str += f'zero_point={self.zero_point}' - repr_str += ')' + repr_str = f'static_enabled={self.static_enabled}, ' + repr_str += f'fake_quant_enabled={self.fake_quant_enabled}, ' + repr_str += f'quant_min={self.activation_post_process.quant_min}, ' + repr_str += f'quant_max={self.activation_post_process.quant_max}, ' + repr_str += f'dtype={self.dtype}, ' + repr_str += f'qscheme={self.qscheme}, ' + repr_str += f'scale={self.scale}, ' + repr_str += f'zero_point={self.zero_point}, ' + repr_str += f'zero_point_trainable={self.zero_point_trainable}' return repr_str From c92f824bb2c5cfb61229d34c53c2eb7c039ae2bc Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Fri, 13 Jan 2023 18:20:41 +0800 Subject: [PATCH 04/19] add lsq pytest --- .../test_fake_quants/test_lsq_fake_quants.py | 178 +++++++++++++++++- 1 file changed, 169 insertions(+), 9 deletions(-) diff --git a/tests/test_models/test_fake_quants/test_lsq_fake_quants.py b/tests/test_models/test_fake_quants/test_lsq_fake_quants.py index d6b670bb5..63bf8b167 100644 --- a/tests/test_models/test_fake_quants/test_lsq_fake_quants.py +++ b/tests/test_models/test_fake_quants/test_lsq_fake_quants.py @@ -1,23 +1,183 @@ # Copyright (c) OpenMMLab. All rights reserved. from unittest import TestCase +import torch +from torch.nn.parameter import Parameter + +from mmrazor.models import LearnableFakeQuantize + +try: + from torch.ao.quantization import MovingAverageMinMaxObserver +except ImportError: + from mmrazor.utils import get_placeholder + MovingAverageMinMaxObserver = get_placeholder('torch>=1.13') + class TestLearnableFakeQuantize(TestCase): - def test_init(self): - pass + def setUp(self): + self.zero_point_trainable_fakequant = LearnableFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=True, + zero_point_trainable=True) + + self.zero_point_untrainable_fakequant = \ + LearnableFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=True, + zero_point_trainable=False) def test_repr(self): - pass + fq_module = self.zero_point_untrainable_fakequant() + repr_str = f'static_enabled={torch.tensor([1], dtype=torch.uint8)}, ' + repr_str += f'fake_quant_enabled=' \ + f'{torch.tensor([1], dtype=torch.uint8)}, ' + repr_str += 'quant_min=0, ' + repr_str += 'quant_max=127, ' + repr_str += f'dtype={torch.quint8}, ' + repr_str += f'qscheme={torch.per_tensor_affine}, ' + repr_str += f'scale={Parameter(torch.tensor([1.0]))}, ' + repr_str += f'zero_point={torch.tensor([0.])}, ' + repr_str += 'zero_point_trainable=False' + self.assertEqual(fq_module.extra_repr(), repr_str) + + fq_module = self.zero_point_trainable_fakequant() + repr_str = f'static_enabled={torch.tensor([1], dtype=torch.uint8)}, ' + repr_str += f'fake_quant_enabled=' \ + f'{torch.tensor([1], dtype=torch.uint8)}, ' + repr_str += 'quant_min=0, ' + repr_str += 'quant_max=127, ' + repr_str += f'dtype={torch.quint8}, ' + repr_str += f'qscheme={torch.per_tensor_affine}, ' + repr_str += f'scale={Parameter(torch.tensor([1.0]))}, ' + repr_str += 
f'zero_point={Parameter(torch.tensor([0.]))}, ' + repr_str += 'zero_point_trainable=True' + self.assertEqual(fq_module.extra_repr(), repr_str) def test_calculate_qparams(self): - pass + fq_module = self.zero_point_untrainable_fakequant() + scale, zero_point = fq_module.calculate_qparams() + self.assertEqual(scale, 1.) + self.assertEqual(zero_point, 0.) + + fq_module = self.zero_point_trainable_fakequant() + scale, zero_point = fq_module.calculate_qparams() + self.assertEqual(scale, 1.) + self.assertEqual(zero_point, 0.) def test_forward(self): - pass + fq_module = self.zero_point_untrainable_fakequant() + torch.manual_seed(42) + X = torch.rand(20, 10, dtype=torch.float32) + # Output of fake quant is not identical to input + Y = fq_module(X) + self.assertFalse(torch.equal(Y, X)) + # self.assertNotEqual(Y, X) + fq_module.toggle_fake_quant(False) + X = torch.rand(20, 10, dtype=torch.float32) + Y = fq_module(X) + # Fake quant is disabled,output is identical to input + self.assertTrue(torch.equal(Y, X)) + + # Explicit copy at this point in time, because FakeQuant keeps internal + # state in mutable buffers. + scale = fq_module.scale.clone().detach() + zero_point = fq_module.zero_point.clone().detach() + + fq_module.toggle_observer_update(False) + fq_module.toggle_fake_quant(True) + X = 10.0 * torch.rand(20, 10, dtype=torch.float32) - 5.0 + Y = fq_module(X) + self.assertFalse(torch.equal(Y, X)) + # Observer is disabled, scale and zero-point do not change + self.assertEqual(fq_module.scale, scale) + self.assertEqual(fq_module.zero_point, zero_point) + + fq_module.toggle_observer_update(True) + Y = fq_module(X) + self.assertFalse(torch.equal(Y, X)) + # Observer is enabled, scale and zero-point are different + self.assertNotEqual(fq_module.scale, scale) + self.assertNotEqual(fq_module.zero_point, zero_point) + + fq_module = self.zero_point_trainable_fakequant() + torch.manual_seed(42) + X = torch.rand(20, 10, dtype=torch.float32) + # Output of fake quant is not identical to input + Y = fq_module(X) + self.assertFalse(torch.equal(Y, X)) + # self.assertNotEqual(Y, X) + fq_module.toggle_fake_quant(False) + X = torch.rand(20, 10, dtype=torch.float32) + Y = fq_module(X) + # Fake quant is disabled,output is identical to input + self.assertTrue(torch.equal(Y, X)) + + # Explicit copy at this point in time, because FakeQuant keeps internal + # state in mutable buffers. 
+ scale = fq_module.scale.clone().detach() + zero_point = fq_module.zero_point.clone().detach() + + fq_module.toggle_observer_update(False) + fq_module.toggle_fake_quant(True) + X = 10.0 * torch.rand(20, 10, dtype=torch.float32) - 5.0 + Y = fq_module(X) + self.assertFalse(torch.equal(Y, X)) + # Observer is disabled, scale and zero-point do not change + self.assertEqual(fq_module.scale, scale) + self.assertEqual(fq_module.zero_point, zero_point) + + fq_module.toggle_observer_update(True) + Y = fq_module(X) + self.assertFalse(torch.equal(Y, X)) + # Observer is enabled, scale and zero-point are different + self.assertNotEqual(fq_module.scale, scale) + self.assertNotEqual(fq_module.zero_point, zero_point) + + def test_state(self): + fq_module = self.zero_point_untrainable_fakequant() + + fq_module.enable_param_learning() + self.assertEqual(fq_module.learning_enabled[0], 1) + self.assertEqual(fq_module.scale.requires_grad, 1) + self.assertEqual(fq_module.zero_point.requires_grad, 0) + self.assertEqual(fq_module.fake_quant_enabled[0], 1) + self.assertEqual(fq_module.static_enabled[0], 0) + + fq_module.enable_static_estimate() + self.assertEqual(fq_module.learning_enabled[0], 0) + self.assertEqual(fq_module.scale.requires_grad, 0) + self.assertEqual(fq_module.zero_point.requires_grad, 0) + self.assertEqual(fq_module.fake_quant_enabled[0], 1) + self.assertEqual(fq_module.static_enabled[0], 1) + + fq_module.enable_val() + self.assertEqual(fq_module.learning_enabled[0], 0) + self.assertEqual(fq_module.scale.requires_grad, 0) + self.assertEqual(fq_module.zero_point.requires_grad, 0) + self.assertEqual(fq_module.fake_quant_enabled[0], 1) + self.assertEqual(fq_module.static_enabled[0], 0) + + fq_module.enable_static_observation() + self.assertEqual(fq_module.learning_enabled[0], 0) + self.assertEqual(fq_module.scale.requires_grad, 0) + self.assertEqual(fq_module.zero_point.requires_grad, 0) + self.assertEqual(fq_module.fake_quant_enabled[0], 0) + self.assertEqual(fq_module.static_enabled[0], 1) - def test_load_state_dict(self): - pass + fq_module = self.zero_point_trainable_fakequant() - def test_save_state_dict(self): - pass + fq_module.enable_param_learning() + self.assertEqual(fq_module.learning_enabled[0], 1) + self.assertEqual(fq_module.scale.requires_grad, 1) + self.assertEqual(fq_module.zero_point.requires_grad, 1) + self.assertEqual(fq_module.fake_quant_enabled[0], 1) + self.assertEqual(fq_module.static_enabled[0], 0) From 0e558df8ac9ce322892cb2b8758f4e4b579da03d Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Fri, 13 Jan 2023 18:22:02 +0800 Subject: [PATCH 05/19] add quant loop pytest --- tests/test_runners/test_quantization_loop.py | 416 +++++++++++++++++++ 1 file changed, 416 insertions(+) create mode 100644 tests/test_runners/test_quantization_loop.py diff --git a/tests/test_runners/test_quantization_loop.py b/tests/test_runners/test_quantization_loop.py new file mode 100644 index 000000000..64505b14a --- /dev/null +++ b/tests/test_runners/test_quantization_loop.py @@ -0,0 +1,416 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import copy +import logging +import shutil +import tempfile +from unittest import TestCase + +import torch +import torch.nn as nn +from mmengine.config import Config, ConfigDict +from mmengine.evaluator import BaseMetric +from mmengine.hooks import Hook +from mmengine.logging import MMLogger +from mmengine.model import BaseModel +from mmengine.optim import OptimWrapper +from mmengine.registry import DATASETS, HOOKS, METRICS, MODELS, OPTIM_WRAPPERS +from mmengine.runner import Runner +from torch.ao.quantization.qconfig_mapping import get_default_qconfig_mapping +from torch.nn.intrinsic.qat import ConvBnReLU2d +from torch.utils.data import Dataset + +from mmrazor.engine import (LSQEpochBasedLoop, PTQLoop, QATEpochBasedLoop, + QATValLoop) + +try: + from torch.ao.quantization import QConfigMapping + from torch.ao.quantization.fake_quantize import FakeQuantizeBase + from torch.ao.quantization.fx import prepare + from torch.ao.quantization.quantize_fx import _fuse_fx +except ImportError: + from mmrazor.utils import get_placeholder + QConfigMapping = get_placeholder('torch>=1.13') + FakeQuantizeBase = get_placeholder('torch>=1.13') + prepare = get_placeholder('torch>=1.13') + _fuse_fx = get_placeholder('torch>=1.13') + +from mmrazor import digit_version + + +class ToyDataset(Dataset): + METAINFO = dict() # type: ignore + data = torch.randn(12, 3, 4, 4) + label = torch.ones(12) + + @property + def metainfo(self): + return self.METAINFO + + def __len__(self): + return self.data.size(0) + + def __getitem__(self, index): + return dict(inputs=self.data[index], data_sample=self.label[index]) + + +class MMArchitectureQuant(BaseModel): + + def __init__(self, data_preprocessor=None): + super().__init__(data_preprocessor=data_preprocessor) + self.architecture = ToyModel() + + def calibrate_step(self, data): + data = self.data_preprocessor(data, False) + return self.architecture(**data) + + def swap_ff_with_fxff(self, model): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + modules_to_swap = [] + for name, module in model.named_children(): + if isinstance(module, torch.ao.nn.quantized.FloatFunctional): + modules_to_swap.append(name) + else: + self.swap_ff_with_fxff(module) + + for name in modules_to_swap: + del model._modules[name] + model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional() + + def sync_qparams(self, src_mode): + pass + + def forward(self, inputs, data_sample, mode='tensor'): + return self.architecture(inputs, data_sample, mode) + + +class ToyModel(BaseModel): + + def __init__(self, data_preprocessor=None): + super().__init__(data_preprocessor=data_preprocessor) + qconfig = get_default_qconfig_mapping().to_dict()[''] + self.architecture = nn.Sequential( + ConvBnReLU2d(3, 3, 1, qconfig=qconfig)) + + def forward(self, inputs, data_sample, mode='tensor'): + if isinstance(inputs, list): + inputs = torch.stack(inputs) + if isinstance(data_sample, list): + data_sample = torch.stack(data_sample) + outputs = self.architecture(inputs) + + if mode == 'tensor': + return outputs + elif mode == 'loss': + loss = data_sample.sum() - outputs.sum() + outputs = dict(loss=loss) + return outputs + elif mode == 'predict': + return outputs + + +class ToyOptimWrapper(OptimWrapper): + ... 
+ + +class ToyMetric1(BaseMetric): + + def __init__(self, collect_device='cpu', dummy_metrics=None): + super().__init__(collect_device=collect_device) + self.dummy_metrics = dummy_metrics + + def process(self, data_batch, predictions): + result = {'acc': 1} + self.results.append(result) + + def compute_metrics(self, results): + return dict(acc=1) + + +DEFAULT_CFG = ConfigDict( + model=dict(type='MMArchitectureQuant'), + train_dataloader=dict( + dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + batch_size=3, + num_workers=0), + val_dataloader=dict( + dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=False), + batch_size=3, + num_workers=0), + test_dataloader=dict( + dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=False), + batch_size=3, + num_workers=0), + optim_wrapper=dict( + type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), + val_evaluator=dict(type='ToyMetric1'), + test_evaluator=dict(type='ToyMetric1'), + train_cfg=dict(), + val_cfg=dict(), + test_cfg=dict(), + custom_hooks=[], + data_preprocessor=None, + launcher='none', + env_cfg=dict(dist_cfg=dict(backend='nccl')), +) + + +class TestQATEpochBasedLoop(TestCase): + + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + MODELS.register_module(module=MMArchitectureQuant, force=True) + DATASETS.register_module(module=ToyDataset, force=True) + METRICS.register_module(module=ToyMetric1, force=True) + OPTIM_WRAPPERS.register_module(module=ToyOptimWrapper, force=True) + + default_cfg = copy.deepcopy(DEFAULT_CFG) + default_cfg = Config(default_cfg) + default_cfg.work_dir = self.temp_dir + default_cfg.train_cfg = ConfigDict( + type='mmrazor.QATEpochBasedLoop', + max_epochs=4, + val_begin=1, + val_interval=1, + disable_observer_begin=-1, + freeze_bn_begin=-1, + dynamic_intervals=None) + self.default_cfg = default_cfg + + def tearDown(self): + MODELS.module_dict.pop('MMArchitectureQuant') + DATASETS.module_dict.pop('ToyDataset') + METRICS.module_dict.pop('ToyMetric1') + OPTIM_WRAPPERS.module_dict.pop('ToyOptimWrapper') + + logging.shutdown() + MMLogger._instance_dict.clear() + shutil.rmtree(self.temp_dir) + + def test_init(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_init' + runner = Runner(**cfg) + self.assertIsInstance(runner, Runner) + self.assertIsInstance(runner.train_loop, QATEpochBasedLoop) + + def test_run_epoch(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_train' + runner = Runner.from_cfg(cfg) + runner.train() + + @HOOKS.register_module(force=True) + class TestFreezeBNHook(Hook): + + def __init__(self, freeze_bn_begin): + self.freeze_bn_begin = freeze_bn_begin + + def after_train_epoch(self, runner): + + def check_bn_stats(mod): + if isinstance(mod, ConvBnReLU2d): + assert mod.freeze_bn + assert not mod.bn.training + + if runner.train_loop._epoch + 1 >= self.freeze_bn_begin: + runner.model.apply(check_bn_stats) + + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_freeze_bn' + cfg.custom_hooks = [ + dict(type='TestFreezeBNHook', priority=50, freeze_bn_begin=1) + ] + cfg.train_cfg.freeze_bn_begin = 1 + runner = Runner.from_cfg(cfg) + runner.train() + + @HOOKS.register_module(force=True) + class TestDisableObserverHook(Hook): + + def __init__(self, disable_observer_begin): + self.disable_observer_begin = disable_observer_begin + + def after_train_epoch(self, runner): + + def check_observer_stats(mod): + if isinstance(mod, FakeQuantizeBase): + assert 
mod.fake_quant_enabled[0] == 0 + + if runner.train_loop._epoch + 1 >= self.disable_observer_begin: + runner.model.apply(check_observer_stats) + + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_disable_observer' + cfg.custom_hooks = [ + dict( + type='TestDisableObserverHook', + priority=50, + disable_observer_begin=1) + ] + cfg.train_cfg.disable_observer_begin = 1 + runner = Runner.from_cfg(cfg) + runner.train() + + +class TestLSQEpochBasedLoop(TestCase): + + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + MODELS.register_module(module=MMArchitectureQuant, force=True) + DATASETS.register_module(module=ToyDataset, force=True) + METRICS.register_module(module=ToyMetric1, force=True) + OPTIM_WRAPPERS.register_module(module=ToyOptimWrapper, force=True) + + default_cfg = copy.deepcopy(DEFAULT_CFG) + default_cfg = Config(default_cfg) + default_cfg.work_dir = self.temp_dir + default_cfg.train_cfg = ConfigDict( + type='mmrazor.LSQEpochBasedLoop', + max_epochs=4, + val_begin=1, + val_interval=1, + freeze_bn_begin=-1, + dynamic_intervals=None) + self.default_cfg = default_cfg + + def tearDown(self): + MODELS.module_dict.pop('MMArchitectureQuant') + DATASETS.module_dict.pop('ToyDataset') + METRICS.module_dict.pop('ToyMetric1') + OPTIM_WRAPPERS.module_dict.pop('ToyOptimWrapper') + + logging.shutdown() + MMLogger._instance_dict.clear() + shutil.rmtree(self.temp_dir) + + def test_init(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_init' + runner = Runner(**cfg) + self.assertIsInstance(runner, Runner) + self.assertIsInstance(runner.train_loop, LSQEpochBasedLoop) + + def test_run_epoch(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_train' + runner = Runner.from_cfg(cfg) + runner.train() + + @HOOKS.register_module(force=True) + class TestFreezeBNHook(Hook): + + def __init__(self, freeze_bn_begin): + self.freeze_bn_begin = freeze_bn_begin + + def after_train_epoch(self, runner): + + def check_bn_stats(mod): + if isinstance(mod, ConvBnReLU2d): + assert mod.freeze_bn + assert not mod.bn.training + + if runner.train_loop._epoch + 1 >= self.freeze_bn_begin: + runner.model.apply(check_bn_stats) + + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_freeze_bn' + cfg.custom_hooks = [ + dict(type='TestFreezeBNHook', priority=50, freeze_bn_begin=1) + ] + cfg.train_cfg.freeze_bn_begin = 1 + runner = Runner.from_cfg(cfg) + runner.train() + + +class TestQATValLoop(TestCase): + + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + MODELS.register_module(module=MMArchitectureQuant, force=True) + DATASETS.register_module(module=ToyDataset, force=True) + METRICS.register_module(module=ToyMetric1, force=True) + OPTIM_WRAPPERS.register_module(module=ToyOptimWrapper, force=True) + + default_cfg = copy.deepcopy(DEFAULT_CFG) + default_cfg = Config(default_cfg) + default_cfg.work_dir = self.temp_dir + default_cfg.val_cfg = ConfigDict(type='mmrazor.QATValLoop') + self.default_cfg = default_cfg + + def tearDown(self): + MODELS.module_dict.pop('MMArchitectureQuant') + DATASETS.module_dict.pop('ToyDataset') + METRICS.module_dict.pop('ToyMetric1') + OPTIM_WRAPPERS.module_dict.pop('ToyOptimWrapper') + + logging.shutdown() + MMLogger._instance_dict.clear() + shutil.rmtree(self.temp_dir) + + def test_init(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_init' + runner = Runner(**cfg) + self.assertIsInstance(runner, Runner) + self.assertIsInstance(runner.val_loop, QATValLoop) + + def test_run(self): + 
cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_val' + cfg.pop('train_dataloader') + cfg.pop('train_cfg') + cfg.pop('optim_wrapper') + cfg.pop('test_dataloader') + cfg.pop('test_cfg') + cfg.pop('test_evaluator') + runner = Runner.from_cfg(cfg) + runner.val() + + +class TestPTQLoop(TestCase): + + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + MODELS.register_module(module=MMArchitectureQuant, force=True) + DATASETS.register_module(module=ToyDataset, force=True) + METRICS.register_module(module=ToyMetric1, force=True) + OPTIM_WRAPPERS.register_module(module=ToyOptimWrapper, force=True) + + default_cfg = copy.deepcopy(DEFAULT_CFG) + default_cfg = Config(default_cfg) + default_cfg.work_dir = self.temp_dir + # save_checkpoint in PTQLoop need train_dataloader + default_cfg.train_cfg = ConfigDict(by_epoch=True, max_epochs=3) + default_cfg.test_cfg = ConfigDict( + type='mmrazor.PTQLoop', + calibrate_dataloader=default_cfg.train_dataloader, + calibrate_steps=32) + self.default_cfg = default_cfg + + def tearDown(self): + MODELS.module_dict.pop('MMArchitectureQuant') + DATASETS.module_dict.pop('ToyDataset') + METRICS.module_dict.pop('ToyMetric1') + OPTIM_WRAPPERS.module_dict.pop('ToyOptimWrapper') + + logging.shutdown() + MMLogger._instance_dict.clear() + shutil.rmtree(self.temp_dir) + + def test_init(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_init' + runner = Runner(**cfg) + self.assertIsInstance(runner, Runner) + self.assertIsInstance(runner.test_loop, PTQLoop) + + def test_run(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_val' + runner = Runner.from_cfg(cfg) + runner.test() From b51181f0cb47b13f6b1eda536982664cb7c49532 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Fri, 13 Jan 2023 19:20:53 +0800 Subject: [PATCH 06/19] test lsq observer --- .../test_observer/test_lsq_observer.py | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 tests/test_models/test_observer/test_lsq_observer.py diff --git a/tests/test_models/test_observer/test_lsq_observer.py b/tests/test_models/test_observer/test_lsq_observer.py new file mode 100644 index 000000000..4ae7a361a --- /dev/null +++ b/tests/test_models/test_observer/test_lsq_observer.py @@ -0,0 +1,72 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from unittest import TestCase + +import torch + +from mmrazor.models import LSQObserver, LSQPerChannelObserver + + +class TestLSQObserver(TestCase): + + def setUp(self): + self.lsq = LSQObserver.with_args( + dtype=torch.quint8, + qscheme=torch.per_tensor_symmetric, + reduce_range=False, + quant_min=0, + quant_max=255) + + def test_forward(self): + lsq_observer = self.lsq() + torch.manual_seed(42) + X = torch.rand(20, 10, dtype=torch.float32) + Y = lsq_observer(X) + # Output of observer is identical to input + self.assertTrue(torch.equal(Y, X)) + + X = torch.rand(0, dtype=torch.float32) + Y = lsq_observer(X) + # Output of observer is identical to input + self.assertTrue(torch.equal(Y, X)) + + def test_calculate_qparams(self): + lsq_observer = self.lsq() + X = torch.ones(10, dtype=torch.float32) + _ = lsq_observer(X) + scale, zero_point = lsq_observer.calculate_qparams() + # tensor_norm = 1, quant_max = 255 + self.assertEqual(scale, 2 * torch.tensor([1.]) / (255**0.5)) + self.assertEqual(zero_point, 127) + + +class TestLSQPerChannelObserver(TestCase): + + def setUp(self): + self.lsq = LSQPerChannelObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, + reduce_range=False, + quant_min=-127, + quant_max=127) + + def test_forward(self): + lsq_observer = self.lsq() + torch.manual_seed(42) + X = torch.rand(2, 10, dtype=torch.float32) + Y = lsq_observer(X) + # Output of observer is identical to input + self.assertTrue(torch.equal(Y, X)) + + X = torch.rand(0, dtype=torch.float32) + Y = lsq_observer(X) + # Output of observer is identical to input + self.assertTrue(torch.equal(Y, X)) + + def test_calculate_qparams(self): + lsq_observer = self.lsq() + X = torch.ones(2, 10, dtype=torch.float32) + X[0] -= 1 + _ = lsq_observer(X) + scale, zero_point = lsq_observer.calculate_qparams() + self.assertEqual(scale[0], 2 * torch.tensor([0.]) / (127**0.5)) + self.assertEqual(scale[1], 2 * torch.tensor([1.]) / (127**0.5)) From 9253badc243b399611ccbd152630208933404c6a Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Fri, 13 Jan 2023 19:40:22 +0800 Subject: [PATCH 07/19] fix bug under pt13 --- tests/test_runners/test_quantization_loop.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_runners/test_quantization_loop.py b/tests/test_runners/test_quantization_loop.py index 64505b14a..7a15a5ccb 100644 --- a/tests/test_runners/test_quantization_loop.py +++ b/tests/test_runners/test_quantization_loop.py @@ -15,7 +15,6 @@ from mmengine.optim import OptimWrapper from mmengine.registry import DATASETS, HOOKS, METRICS, MODELS, OPTIM_WRAPPERS from mmengine.runner import Runner -from torch.ao.quantization.qconfig_mapping import get_default_qconfig_mapping from torch.nn.intrinsic.qat import ConvBnReLU2d from torch.utils.data import Dataset @@ -26,6 +25,8 @@ from torch.ao.quantization import QConfigMapping from torch.ao.quantization.fake_quantize import FakeQuantizeBase from torch.ao.quantization.fx import prepare + from torch.ao.quantization.qconfig_mapping import \ + get_default_qconfig_mapping from torch.ao.quantization.quantize_fx import _fuse_fx except ImportError: from mmrazor.utils import get_placeholder @@ -33,6 +34,7 @@ FakeQuantizeBase = get_placeholder('torch>=1.13') prepare = get_placeholder('torch>=1.13') _fuse_fx = get_placeholder('torch>=1.13') + get_default_qconfig_mapping = get_placeholder('torch>=1.13') from mmrazor import digit_version From 7a1ef2d036d15c7142c60859d1191003a9c552ec Mon Sep 17 00:00:00 2001 From: HIT-cwh 
<2892770585@qq.com> Date: Fri, 13 Jan 2023 21:16:56 +0800 Subject: [PATCH 08/19] fix reset_min_max_vals --- .../quantization/mm_architecture.py | 7 +------ mmrazor/models/observers/torch_observers.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index b58486894..767d7c4ce 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -89,13 +89,8 @@ def reset_observer_and_fakequant_statistics(self, model): statistics in observers and fake quantizers. """ for module in model.modules(): - if isinstance(module, MinMaxObserver): + if isinstance(module, (MinMaxObserver, PerChannelMinMaxObserver)): module.reset_min_max_vals() - elif isinstance(module, PerChannelMinMaxObserver): - min_val = torch.rand(0, ) - max_val = torch.rand(0, ) - module.min_val.resize_(min_val.shape).copy_(min_val) - module.max_val.resize_(max_val.shape).copy_(max_val) elif isinstance(module, FakeQuantizeBase): module.scale.data = torch.ones_like(module.scale) module.zero_point.data = torch.zeros_like(module.zero_point) diff --git a/mmrazor/models/observers/torch_observers.py b/mmrazor/models/observers/torch_observers.py index 5dc24609f..0de628a9a 100644 --- a/mmrazor/models/observers/torch_observers.py +++ b/mmrazor/models/observers/torch_observers.py @@ -2,13 +2,33 @@ import inspect from typing import List +import torch + from mmrazor.registry import MODELS try: import torch.ao.quantization.observer as torch_observer_src + from torch.ao.quantization.observer import PerChannelMinMaxObserver except ImportError: from mmrazor.utils import get_package_placeholder torch_observer_src = get_package_placeholder('torch>=1.13') + UniformQuantizationObserverBase = get_package_placeholder('torch>=1.13') + + +@torch.jit.export +def reset_min_max_vals(self): + """Resets the min/max values. + + `min_val` and `max_val` are always on CPU in the PyTorch version of this + method.
+ """ + min_val = torch.rand(0, ) + max_val = torch.rand(0, ) + self.min_val.resize_(min_val.shape).copy_(min_val) + self.max_val.resize_(max_val.shape).copy_(max_val) + + +PerChannelMinMaxObserver.reset_min_max_vals = reset_min_max_vals def register_torch_observers() -> List[str]: From 4cc328afc6b7b4a2fbf2929477ae39d333be8516 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Sat, 14 Jan 2023 08:55:50 +0800 Subject: [PATCH 09/19] fix bugs under pt13 --- mmrazor/models/observers/torch_observers.py | 2 +- tests/test_runners/test_quantization_loop.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/mmrazor/models/observers/torch_observers.py b/mmrazor/models/observers/torch_observers.py index 0de628a9a..996314d27 100644 --- a/mmrazor/models/observers/torch_observers.py +++ b/mmrazor/models/observers/torch_observers.py @@ -12,7 +12,7 @@ except ImportError: from mmrazor.utils import get_package_placeholder torch_observer_src = get_package_placeholder('torch>=1.13') - UniformQuantizationObserverBase = get_package_placeholder('torch>=1.13') + PerChannelMinMaxObserver = get_package_placeholder('torch>=1.13') @torch.jit.export diff --git a/tests/test_runners/test_quantization_loop.py b/tests/test_runners/test_quantization_loop.py index 7a15a5ccb..0ddc578bf 100644 --- a/tests/test_runners/test_quantization_loop.py +++ b/tests/test_runners/test_quantization_loop.py @@ -22,6 +22,7 @@ QATValLoop) try: + from torch.ao.nn.quantized import FloatFunctional, FXFloatFunctional from torch.ao.quantization import QConfigMapping from torch.ao.quantization.fake_quantize import FakeQuantizeBase from torch.ao.quantization.fx import prepare @@ -35,6 +36,8 @@ prepare = get_placeholder('torch>=1.13') _fuse_fx = get_placeholder('torch>=1.13') get_default_qconfig_mapping = get_placeholder('torch>=1.13') + FloatFunctional = get_placeholder('torch>=1.13') + FXFloatFunctional = get_placeholder('torch>=1.13') from mmrazor import digit_version @@ -71,14 +74,14 @@ def swap_ff_with_fxff(self, model): modules_to_swap = [] for name, module in model.named_children(): - if isinstance(module, torch.ao.nn.quantized.FloatFunctional): + if isinstance(module, FloatFunctional): modules_to_swap.append(name) else: self.swap_ff_with_fxff(module) for name in modules_to_swap: del model._modules[name] - model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional() + model._modules[name] = FXFloatFunctional() def sync_qparams(self, src_mode): pass From 9bae28ccc422c53a3ef641f2ad0033e998be650c Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Wed, 18 Jan 2023 10:58:56 +0800 Subject: [PATCH 10/19] fix configs --- .../ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py | 8 ++++++++ .../ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py | 8 ++++++++ .../ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py | 8 ++++++++ ... 
=> ptq_openvino_retina_r50_1x_coco_calib32xb32.py} | 5 ++--- .../qat/lsq_openvino_resnet18_8xb32_in1k.py | 10 ++-------- 5 files changed, 28 insertions(+), 11 deletions(-) rename configs/quantization/ptq/{ptq_openvino_retina_r50_1x_coco.py => ptq_openvino_retina_r50_1x_coco_calib32xb32.py} (88%) diff --git a/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py index 97333a282..e0f2128cc 100644 --- a/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py @@ -22,6 +22,14 @@ model = dict( _delete_=True, type='mmrazor.MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), architecture=_base_.model, float_checkpoint=float_checkpoint, quantizer=dict( diff --git a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py index 36bea3bc9..84d757552 100644 --- a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py @@ -24,6 +24,14 @@ model = dict( _delete_=True, type='mmrazor.MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), architecture=_base_.model, float_checkpoint=float_checkpoint, quantizer=dict( diff --git a/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py index 3f7740b02..03f8a4e22 100644 --- a/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py @@ -24,6 +24,14 @@ model = dict( _delete_=True, type='mmrazor.MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), architecture=_base_.model, float_checkpoint=float_checkpoint, quantizer=dict( diff --git a/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco.py b/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py similarity index 88% rename from configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco.py rename to configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py index 36bd81a0a..59fb9f9df 100644 --- a/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco.py +++ b/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py @@ -9,8 +9,7 @@ ) retina = _base_.model -# data_preprocessor = retina.data_preprocessor -float_ckpt = '/mnt/petrelfs/caoweihan.p/ckpt/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth' # noqa: E501 +float_checkpoint = 'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth' # noqa: E501 global_qconfig = dict( w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), @@ -33,7 +32,7 @@ bgr_to_rgb=True, pad_size_divisor=32), 
architecture=retina, - float_checkpoint=float_ckpt, + float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', is_qat=False, diff --git a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py index fef3ed1f1..3a8a65bb8 100644 --- a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py +++ b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py @@ -1,7 +1,7 @@ _base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] resnet = _base_.model -float_ckpt = '/mnt/petrelfs/caoweihan.p/ckpt/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 global_qconfig = dict( w_observer=dict(type='mmrazor.LSQPerChannelObserver'), @@ -26,7 +26,7 @@ # convert image from BGR to RGB to_rgb=True), architecture=resnet, - float_checkpoint=float_ckpt, + float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', is_qat=True, @@ -63,9 +63,3 @@ val_interval=1) val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') test_cfg = val_cfg - -default_hooks = dict( - checkpoint=dict( - type='CheckpointHook', - interval=-1, - out_dir='/mnt/petrelfs/caoweihan.p/training_ckpt/lsq')) From db1acb3dd524a8fb9da84cba7a894d29cf545ed6 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Wed, 18 Jan 2023 10:59:45 +0800 Subject: [PATCH 11/19] add get_qconfig_mapping --- mmrazor/models/quantizers/native_quantizer.py | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py index 653cbc931..102a6710f 100644 --- a/mmrazor/models/quantizers/native_quantizer.py +++ b/mmrazor/models/quantizers/native_quantizer.py @@ -123,14 +123,20 @@ def __init__(self, assert w_mode in self.support_w_modes assert a_mode in self.support_a_modes - self.qconfig_mapping = QConfigMapping().set_global( - self.qconfig.convert()) - if no_observer_modules: - self.no_observer_modules = str2class(no_observer_modules) - for mod in self.no_observer_modules: - self.qconfig_mapping.set_object_type(mod, None) - else: - self.no_observer_modules = no_observer_modules + self.qconfig_mapping = self.get_qconfig_mapping(no_observer_modules) + + self.backend_config = BackendConfigs[self.backend] + self.example_inputs = (torch.randn(1, 3, 224, 224), ) + + self.extra_redundant_fakequants = extra_redundant_fakequants + + def get_qconfig_mapping(self, no_observer_modules): + qconfig_mapping = QConfigMapping().set_global(self.qconfig.convert()) + + if no_observer_modules is not None: + no_observer_modules = str2class(no_observer_modules) + for mod in no_observer_modules: + qconfig_mapping.set_object_type(mod, None) fixed_qparams_observer_to_qconfig = {} for fixed_qparams_op, observer in _FIXED_QPARAMS_OP_TO_OBSERVER.items( @@ -146,13 +152,10 @@ def __init__(self, activation=activation, weight=default_weight_fake_quant) fixed_qparams_observer_to_qconfig[ observer] = fixed_qparams_qconfig - self.qconfig_mapping.set_object_type(fixed_qparams_op, - fixed_qparams_qconfig) - - self.backend_config = BackendConfigs[self.backend] - self.example_inputs = (torch.randn(1, 3, 224, 224), ) + qconfig_mapping.set_object_type(fixed_qparams_op, + fixed_qparams_qconfig) - self.extra_redundant_fakequants = extra_redundant_fakequants + return qconfig_mapping @property def backend(self): From 
6fae8dd74c3ebfa301f433d75deae80ee8fa83d6 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Wed, 18 Jan 2023 12:10:09 +0800 Subject: [PATCH 12/19] delete is_qat, add doc and fix pytest --- .../ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py | 1 - ..._openvino_resnet18_8xb32_in1k_calib32xb32.py | 1 - ..._openvino_resnet50_8xb32_in1k_calib32xb32.py | 1 - ...q_openvino_retina_r50_1x_coco_calib32xb32.py | 1 - .../qat/lsq_openvino_resnet18_8xb32_in1k.py | 1 - .../algorithms/quantization/mm_architecture.py | 3 +++ mmrazor/models/quantizers/academic_quantizer.py | 3 ++- mmrazor/models/quantizers/native_quantizer.py | 16 ++++++++++++---- tests/test_runners/test_quantization_loop.py | 17 ----------------- 9 files changed, 17 insertions(+), 27 deletions(-) diff --git a/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py index e0f2128cc..7c919c0fd 100644 --- a/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py @@ -34,7 +34,6 @@ float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', - is_qat=False, global_qconfig=global_qconfig, tracer=dict( type='mmrazor.CustomTracer', diff --git a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py index 84d757552..125f46367 100644 --- a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py @@ -36,7 +36,6 @@ float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', - is_qat=False, global_qconfig=global_qconfig, tracer=dict( type='mmrazor.CustomTracer', diff --git a/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py index 03f8a4e22..f629337ed 100644 --- a/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py @@ -36,7 +36,6 @@ float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', - is_qat=False, global_qconfig=global_qconfig, tracer=dict( type='mmrazor.CustomTracer', diff --git a/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py index 59fb9f9df..578f5fe84 100644 --- a/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py @@ -35,7 +35,6 @@ float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', - is_qat=False, global_qconfig=global_qconfig, tracer=dict( type='mmrazor.CustomTracer', diff --git a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py index 3a8a65bb8..0b79232f8 100644 --- a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py +++ b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py @@ -29,7 +29,6 @@ float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', - is_qat=True, global_qconfig=global_qconfig, tracer=dict( type='mmrazor.CustomTracer', diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py 
index 767d7c4ce..06580cbb3 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -178,6 +178,9 @@ def _build_qmodels(self, model: BaseModel): observed_module = self.quantizer.prepare(model, concrete_args) qmodels[mode] = observed_module + # data_samples can not be None in detectors during prediction. + # But we need to make the dummy prediction in _build_qmodels. + # It is more convenient to use `tensor` mode. is_training = qmodels['tensor'].training # Avoid random input changing bn's statistics qmodels['tensor'].eval() diff --git a/mmrazor/models/quantizers/academic_quantizer.py b/mmrazor/models/quantizers/academic_quantizer.py index c8824e512..2d56be6c5 100644 --- a/mmrazor/models/quantizers/academic_quantizer.py +++ b/mmrazor/models/quantizers/academic_quantizer.py @@ -104,7 +104,8 @@ def prepare(self, model, concrete_args=None): fuse_custom_config = FuseCustomConfig().set_preserved_attributes( preserved_attributes) - self.sync_module_training_mode(graph_module) + # set the training modes of all modules to True to `_fuse_fx` correctly + self.sync_module_training_mode(graph_module, mode=True) graph_module = _fuse_fx( graph_module=graph_module, diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py index 102a6710f..1fe620b7e 100644 --- a/mmrazor/models/quantizers/native_quantizer.py +++ b/mmrazor/models/quantizers/native_quantizer.py @@ -123,15 +123,21 @@ def __init__(self, assert w_mode in self.support_w_modes assert a_mode in self.support_a_modes - self.qconfig_mapping = self.get_qconfig_mapping(no_observer_modules) + self.qconfig_mapping = self.gen_qconfig_mapping( + self.qconfig, no_observer_modules) self.backend_config = BackendConfigs[self.backend] self.example_inputs = (torch.randn(1, 3, 224, 224), ) self.extra_redundant_fakequants = extra_redundant_fakequants - def get_qconfig_mapping(self, no_observer_modules): - qconfig_mapping = QConfigMapping().set_global(self.qconfig.convert()) + def gen_qconfig_mapping(self, qconfig, no_observer_modules): + """Convert qconfig in config file to `QConfigMapping`. + + `QConfigMapping` is a custom class for mapping from model ops to + :class:`torch.ao.quantization.QConfig` s. 
+ """ + qconfig_mapping = QConfigMapping().set_global(qconfig.convert()) if no_observer_modules is not None: no_observer_modules = str2class(no_observer_modules) @@ -197,7 +203,9 @@ def prepare(self, model, concrete_args=None): traced_graph = self.tracer.trace(model, concrete_args=concrete_args) graph_module = build_graphmodule(model, traced_graph) - self.sync_module_training_mode(graph_module) + # set the training modes of all modules to True to `_fuse_fx` correctly + self.sync_module_training_mode(graph_module, mode=True) + graph_module = _fuse_fx( graph_module=graph_module, is_qat=True, diff --git a/tests/test_runners/test_quantization_loop.py b/tests/test_runners/test_quantization_loop.py index 0ddc578bf..bafeb203e 100644 --- a/tests/test_runners/test_quantization_loop.py +++ b/tests/test_runners/test_quantization_loop.py @@ -39,8 +39,6 @@ FloatFunctional = get_placeholder('torch>=1.13') FXFloatFunctional = get_placeholder('torch>=1.13') -from mmrazor import digit_version - class ToyDataset(Dataset): METAINFO = dict() # type: ignore @@ -68,21 +66,6 @@ def calibrate_step(self, data): data = self.data_preprocessor(data, False) return self.architecture(**data) - def swap_ff_with_fxff(self, model): - if digit_version(torch.__version__) < digit_version('1.13.0'): - self.skipTest('version of torch < 1.13.0') - - modules_to_swap = [] - for name, module in model.named_children(): - if isinstance(module, FloatFunctional): - modules_to_swap.append(name) - else: - self.swap_ff_with_fxff(module) - - for name in modules_to_swap: - del model._modules[name] - model._modules[name] = FXFloatFunctional() - def sync_qparams(self, src_mode): pass From ebbea723d2cd1ff0275465a9016829ea1cb88bd6 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Wed, 18 Jan 2023 12:20:09 +0800 Subject: [PATCH 13/19] delete useless codes in custom_tracer --- mmrazor/models/task_modules/tracer/fx/custom_tracer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py index afbafaf03..68d5f0809 100644 --- a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py +++ b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py @@ -53,7 +53,6 @@ class is traced with CustomTracer, the decorated method will be as a leaf def __init__(self, method: FunctionType): self.method = method - self.instances: Dict = dict() self.owner = None def __set_name__(self, owner, name): From 553433ac53adbde9fdabe5f46143b49cb1071ae2 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Wed, 18 Jan 2023 12:25:03 +0800 Subject: [PATCH 14/19] skip pytest under pt13 --- .../test_models/test_fake_quants/test_lsq_fake_quants.py | 3 +++ tests/test_models/test_observer/test_lsq_observer.py | 5 +++++ tests/test_runners/test_quantization_loop.py | 9 +++++++++ 3 files changed, 17 insertions(+) diff --git a/tests/test_models/test_fake_quants/test_lsq_fake_quants.py b/tests/test_models/test_fake_quants/test_lsq_fake_quants.py index 63bf8b167..bd8fcbd50 100644 --- a/tests/test_models/test_fake_quants/test_lsq_fake_quants.py +++ b/tests/test_models/test_fake_quants/test_lsq_fake_quants.py @@ -4,6 +4,7 @@ import torch from torch.nn.parameter import Parameter +from mmrazor import digit_version from mmrazor.models import LearnableFakeQuantize try: @@ -16,6 +17,8 @@ class TestLearnableFakeQuantize(TestCase): def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') 
self.zero_point_trainable_fakequant = LearnableFakeQuantize.with_args( observer=MovingAverageMinMaxObserver, quant_min=0, diff --git a/tests/test_models/test_observer/test_lsq_observer.py b/tests/test_models/test_observer/test_lsq_observer.py index 4ae7a361a..a61f95d7f 100644 --- a/tests/test_models/test_observer/test_lsq_observer.py +++ b/tests/test_models/test_observer/test_lsq_observer.py @@ -3,12 +3,15 @@ import torch +from mmrazor import digit_version from mmrazor.models import LSQObserver, LSQPerChannelObserver class TestLSQObserver(TestCase): def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') self.lsq = LSQObserver.with_args( dtype=torch.quint8, qscheme=torch.per_tensor_symmetric, @@ -42,6 +45,8 @@ def test_calculate_qparams(self): class TestLSQPerChannelObserver(TestCase): def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') self.lsq = LSQPerChannelObserver.with_args( dtype=torch.qint8, qscheme=torch.per_channel_symmetric, diff --git a/tests/test_runners/test_quantization_loop.py b/tests/test_runners/test_quantization_loop.py index bafeb203e..ac5e5c501 100644 --- a/tests/test_runners/test_quantization_loop.py +++ b/tests/test_runners/test_quantization_loop.py @@ -18,6 +18,7 @@ from torch.nn.intrinsic.qat import ConvBnReLU2d from torch.utils.data import Dataset +from mmrazor import digit_version from mmrazor.engine import (LSQEpochBasedLoop, PTQLoop, QATEpochBasedLoop, QATValLoop) @@ -150,6 +151,8 @@ def compute_metrics(self, results): class TestQATEpochBasedLoop(TestCase): def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') self.temp_dir = tempfile.mkdtemp() MODELS.register_module(module=MMArchitectureQuant, force=True) DATASETS.register_module(module=ToyDataset, force=True) @@ -248,6 +251,8 @@ def check_observer_stats(mod): class TestLSQEpochBasedLoop(TestCase): def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') self.temp_dir = tempfile.mkdtemp() MODELS.register_module(module=MMArchitectureQuant, force=True) DATASETS.register_module(module=ToyDataset, force=True) @@ -318,6 +323,8 @@ def check_bn_stats(mod): class TestQATValLoop(TestCase): def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') self.temp_dir = tempfile.mkdtemp() MODELS.register_module(module=MMArchitectureQuant, force=True) DATASETS.register_module(module=ToyDataset, force=True) @@ -363,6 +370,8 @@ def test_run(self): class TestPTQLoop(TestCase): def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') self.temp_dir = tempfile.mkdtemp() MODELS.register_module(module=MMArchitectureQuant, force=True) DATASETS.register_module(module=ToyDataset, force=True) From 1c00387fdbfd992d5818e17a20d05abe0b31e9a0 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Wed, 18 Jan 2023 12:26:35 +0800 Subject: [PATCH 15/19] add todo: check freezebn --- mmrazor/models/quantizers/academic_quantizer.py | 1 + mmrazor/models/quantizers/native_quantizer.py | 1 + 2 files changed, 2 insertions(+) diff --git a/mmrazor/models/quantizers/academic_quantizer.py b/mmrazor/models/quantizers/academic_quantizer.py index 2d56be6c5..0dbe6dcdd 100644 --- a/mmrazor/models/quantizers/academic_quantizer.py +++ 
b/mmrazor/models/quantizers/academic_quantizer.py @@ -105,6 +105,7 @@ def prepare(self, model, concrete_args=None): preserved_attributes) # set the training modes of all modules to True to `_fuse_fx` correctly + # todo: check freezebn self.sync_module_training_mode(graph_module, mode=True) graph_module = _fuse_fx( diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py index 1fe620b7e..f7250a073 100644 --- a/mmrazor/models/quantizers/native_quantizer.py +++ b/mmrazor/models/quantizers/native_quantizer.py @@ -204,6 +204,7 @@ def prepare(self, model, concrete_args=None): graph_module = build_graphmodule(model, traced_graph) # set the training modes of all modules to True to `_fuse_fx` correctly + # todo: check freezebn self.sync_module_training_mode(graph_module, mode=True) graph_module = _fuse_fx( From ebd1b9f0dc3f218ddd21824d79db64ad560f4ab4 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Wed, 18 Jan 2023 13:15:31 +0800 Subject: [PATCH 16/19] fix pytest bugs --- .../quantization/mm_architecture.py | 3 -- .../test_algorithms/test_mm_architecture.py | 46 ++++++++++++------- 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index 06580cbb3..d3b0be089 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -61,9 +61,6 @@ def __init__(self, input_shapes: Tuple = (1, 3, 224, 224), init_cfg: Optional[Dict] = None): - if data_preprocessor is None: - data_preprocessor = getattr(architecture, 'data_preprocessor', - dict()) super().__init__(architecture, data_preprocessor, init_cfg) self.quantizer = MODELS.build(quantizer) diff --git a/tests/test_models/test_algorithms/test_mm_architecture.py b/tests/test_models/test_algorithms/test_mm_architecture.py index 4862bff91..639e0f492 100644 --- a/tests/test_models/test_algorithms/test_mm_architecture.py +++ b/tests/test_models/test_algorithms/test_mm_architecture.py @@ -61,11 +61,10 @@ def _inner_forward(x): return out -@MODELS.register_module() -class ToyQuantModel(BaseModel): +class ToyModel(nn.Module): def __init__(self): - super().__init__() + super(ToyModel, self).__init__() self.stem_layer = nn.Sequential( nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3), nn.ReLU()) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) @@ -85,15 +84,34 @@ def forward(self, x): return x +class ToyQuantModel(BaseModel): + + def __init__(self): + super().__init__() + self.architecture = ToyModel() + + def loss(self, outputs, data_samples): + return dict(loss=outputs.sum() - data_samples.sum()) + + def forward(self, inputs, data_samples, mode: str = 'tensor'): + if isinstance(inputs, list): + inputs = torch.stack(inputs) + outputs = self.architecture(inputs) + + return outputs + + class TestMMArchitectureQuant(TestCase): def setUp(self): if digit_version(torch.__version__) < digit_version('1.13.0'): self.skipTest('version of torch < 1.13.0') + + MODELS.register_module(module=ToyQuantModel, force=True) + self.temp_dir = tempfile.mkdtemp() filename = 'fp_model.pth' filename = os.path.join(self.temp_dir, filename) - # import pdb; pdb.set_trace() toymodel = ToyQuantModel() torch.save(toymodel.state_dict(), filename) @@ -120,18 +138,14 @@ def setUp(self): quantizer=dict( type='mmrazor.OpenVINOQuantizer', global_qconfig=global_qconfig, - tracer=dict( - type='mmrazor.CustomTracer', - skipped_methods=[ 
- 'mmcls.models.heads.ClsHead._get_loss', - 'mmcls.models.heads.ClsHead._get_predictions' - ]))) + tracer=dict(type='mmrazor.CustomTracer'))) self.alg_kwargs = alg_kwargs self.toy_model = MODELS.build(self.alg_kwargs) def tearDown(self): if digit_version(torch.__version__) < digit_version('1.13.0'): self.skipTest('version of torch < 1.13.0') + MODELS.module_dict.pop('ToyQuantModel') shutil.rmtree(self.temp_dir) def test_init(self): @@ -145,12 +159,12 @@ def test_sync_qparams(self): self.skipTest('version of torch < 1.13.0') mode = self.toy_model.forward_modes[0] self.toy_model.sync_qparams(mode) - w_loss = self.toy_model.qmodels['loss'].block.conv1.state_dict( - )['weight'] - w_tensor = self.toy_model.qmodels['tensor'].block.conv1.state_dict( - )['weight'] - w_pred = self.toy_model.qmodels['predict'].block.conv1.state_dict( - )['weight'] + w_loss = self.toy_model.qmodels[ + 'loss'].architecture.block.conv1.state_dict()['weight'] + w_tensor = self.toy_model.qmodels[ + 'tensor'].architecture.block.conv1.state_dict()['weight'] + w_pred = self.toy_model.qmodels[ + 'predict'].architecture.block.conv1.state_dict()['weight'] assert w_loss.equal(w_pred) assert w_loss.equal(w_tensor) From 53d760418bec0c7ba8f93b77466b80926dd2c068 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Wed, 18 Jan 2023 13:33:17 +0800 Subject: [PATCH 17/19] fix pytest --- mmrazor/models/quantizers/native_quantizer.py | 1 + .../test_lsq_observer.py | 0 .../test_models/test_quantizers/test_native_quantizer.py | 8 ++------ 3 files changed, 3 insertions(+), 6 deletions(-) rename tests/test_models/{test_observer => test_observers}/test_lsq_observer.py (100%) diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py index f7250a073..1d566b45f 100644 --- a/mmrazor/models/quantizers/native_quantizer.py +++ b/mmrazor/models/quantizers/native_quantizer.py @@ -125,6 +125,7 @@ def __init__(self, self.qconfig_mapping = self.gen_qconfig_mapping( self.qconfig, no_observer_modules) + self.no_observer_modules = no_observer_modules self.backend_config = BackendConfigs[self.backend] self.example_inputs = (torch.randn(1, 3, 224, 224), ) diff --git a/tests/test_models/test_observer/test_lsq_observer.py b/tests/test_models/test_observers/test_lsq_observer.py similarity index 100% rename from tests/test_models/test_observer/test_lsq_observer.py rename to tests/test_models/test_observers/test_lsq_observer.py diff --git a/tests/test_models/test_quantizers/test_native_quantizer.py b/tests/test_models/test_quantizers/test_native_quantizer.py index 62052f66f..06a12c206 100644 --- a/tests/test_models/test_quantizers/test_native_quantizer.py +++ b/tests/test_models/test_quantizers/test_native_quantizer.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-import collections from unittest import TestCase import torch import torch.nn as nn from mmrazor import digit_version +from mmrazor.models.quantizers import NativeQuantizer from mmrazor.models.quantizers.native_quantizer import SUPPORT_QAT_MODULES from mmrazor.models.task_modules.tracer import CustomTracer from mmrazor.models.task_modules.tracer.fx.custom_tracer import \ @@ -155,11 +155,7 @@ def test_init(self): if digit_version(torch.__version__) < digit_version('1.13.0'): self.skipTest('version of torch < 1.13.0') native_quantizer = MODELS.build(self.q_kwargs) - no_ob_dict = collections.OrderedDict() - no_ob_dict = no_ob_dict.fromkeys(native_quantizer.no_observer_modules, - None) - assert native_quantizer.qconfig_mapping.object_type_qconfigs == \ - no_ob_dict + self.assertIsInstance(native_quantizer, NativeQuantizer) def test_prepare(self): if digit_version(torch.__version__) < digit_version('1.13.0'): From 81af85b580243616ff4c55ef755b68daac190160 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Wed, 18 Jan 2023 13:45:25 +0800 Subject: [PATCH 18/19] fix pytest --- tests/test_runners/test_quantization_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_runners/test_quantization_loop.py b/tests/test_runners/test_quantization_loop.py index ac5e5c501..f5f0d5148 100644 --- a/tests/test_runners/test_quantization_loop.py +++ b/tests/test_runners/test_quantization_loop.py @@ -356,7 +356,7 @@ def test_init(self): def test_run(self): cfg = copy.deepcopy(self.default_cfg) - cfg.experiment_name = 'test_val' + cfg.experiment_name = 'test_qat_val' cfg.pop('train_dataloader') cfg.pop('train_cfg') cfg.pop('optim_wrapper') @@ -408,6 +408,6 @@ def test_init(self): def test_run(self): cfg = copy.deepcopy(self.default_cfg) - cfg.experiment_name = 'test_val' + cfg.experiment_name = 'test_ptq_run' runner = Runner.from_cfg(cfg) runner.test() From 60850141f8ab50dc714d51a40506a5a603087424 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Wed, 18 Jan 2023 13:53:45 +0800 Subject: [PATCH 19/19] fix pytest --- tests/test_runners/test_quantization_loop.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_runners/test_quantization_loop.py b/tests/test_runners/test_quantization_loop.py index f5f0d5148..6a300fb91 100644 --- a/tests/test_runners/test_quantization_loop.py +++ b/tests/test_runners/test_quantization_loop.py @@ -184,7 +184,7 @@ def tearDown(self): def test_init(self): cfg = copy.deepcopy(self.default_cfg) - cfg.experiment_name = 'test_init' + cfg.experiment_name = 'test_init_qat_train_loop' runner = Runner(**cfg) self.assertIsInstance(runner, Runner) self.assertIsInstance(runner.train_loop, QATEpochBasedLoop) @@ -283,7 +283,7 @@ def tearDown(self): def test_init(self): cfg = copy.deepcopy(self.default_cfg) - cfg.experiment_name = 'test_init' + cfg.experiment_name = 'test_init_lsq_train_loop' runner = Runner(**cfg) self.assertIsInstance(runner, Runner) self.assertIsInstance(runner.train_loop, LSQEpochBasedLoop) @@ -349,7 +349,7 @@ def tearDown(self): def test_init(self): cfg = copy.deepcopy(self.default_cfg) - cfg.experiment_name = 'test_init' + cfg.experiment_name = 'test_init_qat_val_loop' runner = Runner(**cfg) self.assertIsInstance(runner, Runner) self.assertIsInstance(runner.val_loop, QATValLoop) @@ -401,7 +401,7 @@ def tearDown(self): def test_init(self): cfg = copy.deepcopy(self.default_cfg) - cfg.experiment_name = 'test_init' + cfg.experiment_name = 'test_init_ptq_loop' runner = Runner(**cfg) 
self.assertIsInstance(runner, Runner) self.assertIsInstance(runner.test_loop, PTQLoop)
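
As a quick sanity check of the `reset_min_max_vals` monkey patch introduced in [PATCH 08/19], the short sketch below exercises it directly. It assumes torch>=1.13 and an mmrazor checkout with this series applied; the module path is taken from the diff, while the tensor shapes and the `ch_axis` value are only illustrative.

    import torch
    from torch.ao.quantization.observer import PerChannelMinMaxObserver

    # Importing this module applies the monkey patch from [PATCH 08/19].
    import mmrazor.models.observers.torch_observers  # noqa: F401

    observer = PerChannelMinMaxObserver(ch_axis=0)
    observer(torch.randn(4, 8))  # collect per-channel min/max statistics
    assert observer.min_val.numel() == 4

    # The patched method clears the buffers in place (via resize_/copy_)
    # instead of re-creating them, so calibration can restart from scratch.
    observer.reset_min_max_vals()
    assert observer.min_val.numel() == 0 and observer.max_val.numel() == 0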