From 6b1e482d70de13e9ba4513c4c3e17512c62c8733 Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Fri, 25 Nov 2022 15:45:37 +0800 Subject: [PATCH] [Features]Quantize pipeline (#350) * init demo * add customer_tracer * add quantizer * add fake_quant, loop, config * remove CPatcher in custome_tracer * demo_try * init version * modified base.py * pre-rebase * wip of adaround series * adaround experiment * trasfer to s2 * update api * point at sub_reconstruction * pre-checkout * export onnx * add customtracer * fix lint * move custom tracer * fix import * update * updated * retina loss & predict & tesnor DONE * for RFC * Customed FX initialize * add UT init * TDO: UTs * Successfully RUN * update loop * update loop docstrings * update quantizer docstrings * update qscheme docstrings * update qobserver docstrings * update tracer docstrings * update UTs init * update UTs init * fix bugs * fix lsq * refactor quantize pipeline * fix quant * WIP: debug qat * fix lsq bugs * fix qat, docstring in progress * TDO: UTs * fix bugs * fix lsq * refactor quantize pipeline * fix quant * WIP: debug qat * fix lsq bugs * fix qat, docstring in progress * fixed DefaultQconfigs name * fix bugs * add comments and fix typos * delete useless codes * fix bugs and add comments * rename prepare_module_dict * update lsq config Co-authored-by: humu789 Co-authored-by: huangpengsheng Co-authored-by: FreakieHuang Co-authored-by: pppppM --- configs/quantization/ptq/adaround.py | 8 +- configs/quantization/qat/demo.py | 1 - .../qat/lsq_resnet18_8xb16_cifar10.py | 70 ++++ .../qat/lsq_resnet18_8xb32_in1k.py | 75 +++++ .../qat/lsq_resnet50_8xb16_cifar10.py | 37 -- mmrazor/engine/runner/quantization_loops.py | 315 +++++++++++++++--- mmrazor/models/algorithms/__init__.py | 2 +- .../models/algorithms/quantization/base.py | 154 ++++++--- mmrazor/models/fake_quants/lsq.py | 16 + .../units/mutable_channel_unit.py | 4 +- mmrazor/models/observers/__init__.py | 3 +- mmrazor/models/observers/base.py | 3 +- mmrazor/models/observers/minmax.py | 4 +- mmrazor/models/quantizers/base.py | 9 +- mmrazor/models/quantizers/trt_quantizer.py | 4 +- mmrazor/models/task_modules/__init__.py | 1 + .../models/task_modules/tracer/__init__.py | 8 +- .../models/task_modules/tracer/fx/__init__.py | 7 +- .../task_modules/tracer/fx/custom_tracer.py | 88 ++++- mmrazor/registry/registry.py | 1 + mmrazor/structures/quantization/__init__.py | 4 +- .../quantization/backend_default_qconfigs.py | 6 +- tools/ptq_calibrate.py | 73 ++++ 23 files changed, 739 insertions(+), 154 deletions(-) delete mode 100644 configs/quantization/qat/demo.py create mode 100644 configs/quantization/qat/lsq_resnet18_8xb16_cifar10.py create mode 100644 configs/quantization/qat/lsq_resnet18_8xb32_in1k.py delete mode 100644 configs/quantization/qat/lsq_resnet50_8xb16_cifar10.py create mode 100644 tools/ptq_calibrate.py diff --git a/configs/quantization/ptq/adaround.py b/configs/quantization/ptq/adaround.py index 389575dc6..78157c61a 100644 --- a/configs/quantization/ptq/adaround.py +++ b/configs/quantization/ptq/adaround.py @@ -1,12 +1,8 @@ -_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] +_base_ = ['mmcls::resnet/resnet18_8xb16_cifar10.py'] test_cfg = dict( - _delete_=True, type='mmrazor.PTQLoop', - dataloader=_base_.test_dataloader, - evaluator=_base_.test_evaluator, - calibrate_dataloader=_base_.train_dataloader, - batch_num=32, + # reconstruction_cfg=dict( # pattern='layer', # loss=dict( diff --git a/configs/quantization/qat/demo.py 
b/configs/quantization/qat/demo.py deleted file mode 100644 index be3ec6013..000000000 --- a/configs/quantization/qat/demo.py +++ /dev/null @@ -1 +0,0 @@ -_base_ = ['./lsq_resnet50_8xb16_cifar10.py'] diff --git a/configs/quantization/qat/lsq_resnet18_8xb16_cifar10.py b/configs/quantization/qat/lsq_resnet18_8xb16_cifar10.py new file mode 100644 index 000000000..412a6fd87 --- /dev/null +++ b/configs/quantization/qat/lsq_resnet18_8xb16_cifar10.py @@ -0,0 +1,70 @@ +_base_ = ['mmcls::resnet/resnet18_8xb16_cifar10.py'] + +resnet = _base_.model +pretrained_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth' # noqa: E501 + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GeneralQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=10, + # RGB format normalization parameters + mean=[125.307, 122.961, 113.8575], + std=[51.5865, 50.847, 51.255], + # loaded images are already RGB format + to_rgb=False), + architecture=resnet, + pretrained_ckpt=pretrained_ckpt, + quantizer=dict( + type='CustomQuantizer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ], + qconfig=dict( + qtype='affine', + w_observer=dict(type='mmrazor.LSQObserver'), + a_observer=dict(type='mmrazor.LSQObserver'), + w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + w_qscheme=dict( + bit=8, + is_symmetry=False, + is_per_channel=False, + is_pot_scale=False, + ), + a_qscheme=dict( + bit=8, + is_symmetry=False, + is_per_channel=False, + is_pot_scale=False), + ))) + +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.004, momentum=0.9, weight_decay=0.0001)) + +# learning policy +param_scheduler = dict( + _delete_=True, + type='CosineAnnealingLR', + T_max=100, + by_epoch=True, + begin=0, + end=100) + +model_wrapper_cfg = dict( + type='mmrazor.GeneralQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) + +# train, val, test setting +train_cfg = dict( + _delete_=True, + type='mmrazor.QATEpochBasedLoop', + by_epoch=True, + max_epochs=100, + val_interval=1) +val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') +test_cfg = val_cfg diff --git a/configs/quantization/qat/lsq_resnet18_8xb32_in1k.py b/configs/quantization/qat/lsq_resnet18_8xb32_in1k.py new file mode 100644 index 000000000..a0885a52a --- /dev/null +++ b/configs/quantization/qat/lsq_resnet18_8xb32_in1k.py @@ -0,0 +1,75 @@ +_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] + +train_cfg = dict( + _delete_=True, + type='mmrazor.QATEpochBasedLoop', + max_epochs=_base_.train_cfg.max_epochs) + +resnet = _base_.model +ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 +resnet.init_cfg = dict(type='Pretrained', checkpoint=ckpt) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GeneralQuant', + # data_preprocessor = dict( + # num_classes=1000, + # # RGB format normalization parameters + # mean=[123.675, 116.28, 103.53], + # std=[58.395, 57.12, 57.375], + # # convert image from BGR to RGB + # to_rgb=True, + # ), + architecture=resnet, + quantizer=dict( + type='CustomQuantizer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ], + qconfig=dict( + qtype='affine', + w_observer=dict(type='mmrazor.MinMaxObserver'), + a_observer=dict(type='mmrazor.EMAMinMaxObserver'), + 
w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + w_qscheme=dict( + bit=8, + is_symmetry=False, + is_per_channel=False, + is_pot_scale=False, + ), + a_qscheme=dict( + bit=8, + is_symmetry=False, + is_per_channel=False, + is_pot_scale=False), + ))) + +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.004, momentum=0.9, weight_decay=0.0001)) + +# learning policy +param_scheduler = dict( + _delete_=True, + type='CosineAnnealingLR', + T_max=100, + by_epoch=True, + begin=0, + end=100) + +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + interval=5, + max_keep_ckpts=3, + out_dir='/mnt/petrelfs/caoweihan.p/training_ckpt/quant')) + +model_wrapper_cfg = dict( + type='mmrazor.GeneralQuantDDP', + broadcast_buffers=False, + find_unused_parameters=False) + +val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') +test_cfg = val_cfg diff --git a/configs/quantization/qat/lsq_resnet50_8xb16_cifar10.py b/configs/quantization/qat/lsq_resnet50_8xb16_cifar10.py deleted file mode 100644 index a246bc265..000000000 --- a/configs/quantization/qat/lsq_resnet50_8xb16_cifar10.py +++ /dev/null @@ -1,37 +0,0 @@ -_base_ = ['mmcls::resnet/resnet18_8xb16_cifar10.py'] - -train_cfg = dict( - _delete_=True, - type='mmrazor.QATEpochBasedLoop', - max_epochs=_base_.train_cfg.max_epochs, -) - -model = dict( - _delete_=True, - _scope_='mmrazor', - type='GeneralQuant', - architecture={{_base_.model}}, - quantizer=dict( - type='TensorRTQuantizer', - skipped_methods=[ - 'mmcls.models.heads.ClsHead._get_loss', - 'mmcls.models.heads.ClsHead._get_predictions' - ], - qconfig=dict( - qtype='affine', - w_observer=dict(type='mmrazor.MinMaxObserver'), - a_observer=dict(type='mmrazor.EMAMinMaxObserver'), - w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), - a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), - w_qscheme=dict( - bit=2, - is_symmetry=False, - is_per_channel=True, - is_pot_scale=False, - ), - a_qscheme=dict( - bit=4, - is_symmetry=False, - is_per_channel=False, - is_pot_scale=False), - ))) diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py index 2f15f5deb..a2d5d383b 100644 --- a/mmrazor/engine/runner/quantization_loops.py +++ b/mmrazor/engine/runner/quantization_loops.py @@ -1,13 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy import os -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Sequence, Tuple, Union import numpy as np import torch from mmengine.evaluator import Evaluator from mmengine.registry import MODELS -from mmengine.runner import EpochBasedTrainLoop, TestLoop +from mmengine.runner import EpochBasedTrainLoop, TestLoop, ValLoop, autocast +from torch.ao.quantization import disable_observer +from torch.nn.intrinsic.qat import freeze_bn_stats from torch.utils.data import DataLoader from mmrazor.models.task_modules import (ModuleInputsRecorder, @@ -28,12 +30,13 @@ class QATEpochBasedLoop(EpochBasedTrainLoop): dataloader (Dataloader or dict): An iterator to generate one batch of dataset each iteration. max_epochs (int): Total training epochs. - calibrate_dataloader (Dataloader or dict, optional): A dataloader - object or a dict to build a dataloader for calibration. Defaults - to None. val_begin (int): The epoch that begins validating. Defaults to 1. val_interval (int): Validation interval. Defaults to 1. + disable_observer_begin (int): The number of total epochs to update + observers. 
+ freeze_bn_begin (int): The number of total epochs to update batch norm + stats. dynamic_intervals (List[Tuple[int, int]], optional): The first element in the tuple is a milestone and the second element is a interval. The interval is used after the @@ -45,68 +48,296 @@ def __init__( runner, dataloader: Union[DataLoader, Dict], max_epochs: int, - calibrate_dataloader: Union[DataLoader, Dict] = None, val_begin: int = 1, val_interval: int = 1, + disable_observer_begin: int = 3, + freeze_bn_begin: int = 3, dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None: super().__init__(runner, dataloader, max_epochs, val_begin, val_interval, dynamic_intervals) - if isinstance(calibrate_dataloader, dict): - # Determine whether or not different ranks use different seed. - diff_rank_seed = runner._randomness_cfg.get( - 'diff_rank_seed', False) - self.calibrate_dataloader = runner.build_dataloader( - calibrate_dataloader, - seed=runner.seed, - diff_rank_seed=diff_rank_seed) - else: - self.calibrate_dataloader = calibrate_dataloader - self.is_calibrate = True if calibrate_dataloader is not None else False + self.disable_observer_begin = disable_observer_begin + self.freeze_bn_begin = freeze_bn_begin - if self.runner.distributed: - self.model = runner.model.module - else: - self.model = runner.model - - def calibrate(self, calibrate_dataloader) -> None: - self.model.eval() - with torch.no_grad(): - for batch_data in calibrate_dataloader: - self.model(batch_data) - - def run(self) -> None: - """Launch training.""" - self.runner.call_hook('before_train') - - self.model.prepare() - - if self.is_calibrate: - self.model.state = (1, 0) - self.calibrate(self.calibrate_dataloader) - - self.model.state = (1, 1) + # compute metrics + metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) + qat_metrics = dict() + for key, value in metrics.items(): + qat_key = 'qat.' + key + ori_key = 'original.' + key + qat_metrics[qat_key] = value + self.runner.message_hub.log_scalars.pop(f'val/{ori_key}', None) while self._epoch < self._max_epochs: + # state: observer_enabled, fakequant_enabled + self.runner.model.state = (True, True) self.run_epoch() self._decide_current_val_interval() if (self.runner.val_loop is not None and self._epoch >= self.val_begin and self._epoch % self.val_interval == 0): + # observer disabled during evaluation + self.runner.model.state = (False, True) + self.runner.model.sync_param() self.runner.val_loop.run() - self.model.convert() + self.runner.call_hook('after_train') - # self.runner.val_loop.run() + def run_epoch(self) -> None: + """Iterate one epoch.""" + self.runner.call_hook('before_train_epoch') + self.runner.model.train() - self.runner.call_hook('after_train') + # TODO freeze bn + if self._epoch >= self.disable_observer_begin: + self.runner.model.apply(disable_observer) + + if self._epoch >= self.freeze_bn_begin: + self.runner.model.apply(freeze_bn_stats) + + for idx, data_batch in enumerate(self.dataloader): + self.run_iter(idx, data_batch) + + self.runner.call_hook('after_train_epoch') + self._epoch += 1 + + +@LOOPS.register_module() +class QATValLoop(ValLoop): + """`ValLoop` for `QuantizationAwareTraining` + + Args: + runner (Runner): A reference of runner + dataloader (Dataloader or dict): An iterator to generate one batch of + dataset each iteration. + evaluator (Evaluator or dict or list): Used for computing metrics. + fp16 (bool): Whether to enable fp16 validation. Defaults to + False. 
+ """ + + def __init__(self, + runner, + dataloader: Union[DataLoader, Dict], + evaluator: Union[Evaluator, Dict, List], + fp16: bool = False) -> None: + super().__init__(runner, dataloader, evaluator, fp16) + if self.runner.distributed: + assert hasattr(self.runner.model.module, 'architecture') + # TODO: remove hard code after mmcls add data_preprocessor + data_preprocessor = self.runner.model.module.data_preprocessor + self.architecture = self.runner.model.module.architecture + self.architecture.data_preprocessor = data_preprocessor + + else: + assert hasattr(self.runner.model, 'architecture') + # TODO: remove hard code after mmcls add data_preprocessor + data_preprocessor = self.runner.model.data_preprocessor + self.architecture = self.runner.model.architecture + self.architecture.data_preprocessor = data_preprocessor + + def run(self) -> dict: + """Launch validation.""" + self.runner.call_hook('before_val') + self.runner.call_hook('before_val_epoch') + self.runner.model.eval() + for idx, data_batch in enumerate(self.dataloader): + self.run_iter(idx, data_batch, self.runner.model) + + # compute metrics + metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) + qat_metrics = dict() + for key, value in metrics.items(): + qat_key = 'qat.' + key + ori_key = 'original.' + key + qat_metrics[qat_key] = value + self.runner.message_hub.log_scalars.pop(f'val/{ori_key}', None) + + self.runner.call_hook('after_val_epoch', metrics=qat_metrics) + + self.runner.call_hook('before_val_epoch') + self.runner.model.eval() + for idx, data_batch in enumerate(self.dataloader): + self.run_iter(idx, data_batch, self.architecture) + + # compute metrics + metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) + qat_metrics = dict() + for key, value in metrics.items(): + qat_key = 'qat.' + key + ori_key = 'original.' + key + qat_metrics[ori_key] = value + self.runner.message_hub.log_scalars.pop(f'val/{qat_key}', None) + + self.runner.call_hook('after_val_epoch', metrics=qat_metrics) + + self.runner.call_hook('after_val') + return qat_metrics + + @torch.no_grad() + def run_iter(self, idx, data_batch: Sequence[dict], model): + """Iterate one mini-batch. + + Args: + data_batch (Sequence[dict]): Batch of data + from dataloader. + """ + self.runner.call_hook( + 'before_val_iter', batch_idx=idx, data_batch=data_batch) + # outputs should be sequence of BaseDataElement + with autocast(enabled=self.fp16): + outputs = model.val_step(data_batch) + self.evaluator.process(data_samples=outputs, data_batch=data_batch) + self.runner.call_hook( + 'after_val_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=outputs) @LOOPS.register_module() class PTQLoop(TestLoop): """`TestLoop` for Post Training Quantization. + Args: + runner (Runner): A reference of runner + dataloader (Dataloader or dict): An iterator to generate one batch of + dataset each iteration. + evaluator (Evaluator or dict or list): Used for computing metrics. + fp16 (bool, optional): Enable FP16 training mode. Defaults to False. 
+ """ + + def __init__(self, + runner, + dataloader: Union[DataLoader, Dict], + evaluator: Union[Evaluator, Dict, List], + fp16: bool = False): + super().__init__(runner, dataloader, evaluator, fp16) + + def run(self) -> dict: + """Launch test.""" + self.runner.call_hook('before_test') + self.runner.call_hook('before_test_epoch') + self.runner.model.eval() + self.runner.model.state = (True, False) + + for idx, data_batch in enumerate(self.dataloader): + self.run_iter(idx, data_batch) + + # compute metrics + metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) + + self.runner.call_hook('after_test_epoch', metrics=metrics) + self.runner.call_hook('after_test') + + # todo: hard code to save checkpoint on disk + self.runner.save_checkpoint( + self.runner.work_dir, + 'checkpoint_after_ptq.pth', + file_client_args=None, + save_optimizer=False, + save_param_scheduler=False) + + return metrics + + @torch.no_grad() + def run_iter(self, idx, data_batch: Sequence[dict]) -> None: + """Iterate one mini-batch. + + Args: + data_batch (Sequence[dict]): Batch of data from dataloader. + """ + self.runner.call_hook( + 'before_test_iter', batch_idx=idx, data_batch=data_batch) + # predictions should be sequence of BaseDataElement + + outputs = self.runner.model.calibrate_step(data_batch) + + self.runner.call_hook( + 'after_test_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=outputs) + + +# TODO refactor to supoort DDP +@LOOPS.register_module() +class AdaRoundLoop(TestLoop): + """`TestLoop` for Post Training Quantization. + + Args: + runner (Runner): A reference of runner + dataloader (Dataloader or dict): An iterator to generate one batch of + dataset each iteration. + evaluator (Evaluator or dict or list): Used for computing metrics. + calibrate_dataloader (Dataloader or dict, optional): A dataloader + object or a dict to build a dataloader for calibration. Defaults + to None. + batch_num (Optional[int], optional): Total calibration batches. + Defaults to None. + reconstruction_cfg (Optional[Dict], optional): Model reconstruction + configuration. Defaults to None. + fp16 (bool, optional): Enable FP16 training mode. Defaults to False. + """ + + def __init__(self, + runner, + dataloader: Union[DataLoader, Dict], + evaluator: Union[Evaluator, Dict, List], + fp16: bool = False): + super().__init__(runner, dataloader, evaluator, fp16) + + def run(self) -> None: + """Launch test.""" + self.runner.call_hook('before_test') + self.runner.call_hook('before_test_epoch') + self.runner.model.eval() + self.runner.model.state = (1, 0) + + for idx, data_batch in enumerate(self.dataloader): + self.run_iter(idx, data_batch) + + # compute metrics + metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) + + self.runner.call_hook('after_test_epoch', metrics=metrics) + self.runner.call_hook('after_test') + + # todo: hard code to save checkpoint on disk + self.runner.save_checkpoint( + self.runner.work_dir, + 'checkpoint_after_ptq.pth', + file_client_args=None, + save_optimizer=False, + save_param_scheduler=False) + + return metrics + + @torch.no_grad() + def run_iter(self, idx, data_batch: Sequence[dict]) -> None: + """Iterate one mini-batch. + + Args: + data_batch (Sequence[dict]): Batch of data from dataloader. 
+ """ + self.runner.call_hook( + 'before_test_iter', batch_idx=idx, data_batch=data_batch) + # predictions should be sequence of BaseDataElement + + outputs = self.runner.model.calibrate_step(data_batch) + + self.runner.call_hook( + 'after_test_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=outputs) + + +# TODO refactor to supoort DDP +@LOOPS.register_module() +class AdaRoundLoop(TestLoop): + """`TestLoop` for Post Training Quantization. + Args: runner (Runner): A reference of runner dataloader (Dataloader or dict): An iterator to generate one batch of diff --git a/mmrazor/models/algorithms/__init__.py b/mmrazor/models/algorithms/__init__.py index 29c14d222..336d00cee 100644 --- a/mmrazor/models/algorithms/__init__.py +++ b/mmrazor/models/algorithms/__init__.py @@ -15,5 +15,5 @@ 'Darts', 'DartsDDP', 'DCFF', 'SelfDistill', 'DataFreeDistillation', 'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation', 'ItePruneAlgorithm', 'DSNAS', 'DSNASDDP', 'Autoformer', 'BigNAS', - 'BigNASDDP' + 'BigNASDDP', 'GeneralQuant' ] diff --git a/mmrazor/models/algorithms/quantization/base.py b/mmrazor/models/algorithms/quantization/base.py index 718b08725..c97d832ff 100644 --- a/mmrazor/models/algorithms/quantization/base.py +++ b/mmrazor/models/algorithms/quantization/base.py @@ -1,11 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os from typing import Dict, List, Optional, Tuple, Union import torch +from mmengine.model import MMDistributedDataParallel +from mmengine.runner import load_checkpoint from mmengine.structures import BaseDataElement -from torch.fx import GraphModule +from torch import nn +from torch.ao.quantization import FakeQuantizeBase -from mmrazor.registry import MODELS +from mmrazor.models.task_modules import build_graphmodule +from mmrazor.registry import MODEL_WRAPPERS, MODELS from ..base import BaseAlgorithm LossResults = Dict[str, torch.Tensor] @@ -19,13 +24,17 @@ class GeneralQuant(BaseAlgorithm): """General quantization. Args: - Args: architecture (dict | :obj:`BaseModel`): The config of :class:`BaseModel` or built model. quantizer (dict | :obj:`BaseModel`): The config of :class:`BaseQuantizer` or built model. + export_mode (str): The mode of the model to be exported. Defaults to + predict. + qmodel_modes (list): The available mode of runner. data_preprocessor (dict | torch.nn.Module | None): The pre-process config of :class:`BaseDataPreprocessor`. Defaults to None. + pretrained_ckpt (str, Optional): The path of pretrained checkpoint. + Defaults to None. init_cfg (dict): The weight initialized config for :class:`BaseModule`. """ @@ -33,74 +42,94 @@ class GeneralQuant(BaseAlgorithm): def __init__(self, architecture, quantizer, + export_mode: str = 'predict', + qmodel_modes: List[str] = ['tensor', 'predict', 'loss'], data_preprocessor=None, + pretrained_ckpt: Optional[str] = None, init_cfg=None): + if data_preprocessor is None: data_preprocessor = {} # The build process is in MMEngine, so we need to add scope here. 
data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor') - super().__init__(architecture, data_preprocessor, init_cfg) + if pretrained_ckpt: + _ = load_checkpoint(self.architecture, pretrained_ckpt) + self.architecture._is_init = True self.quantizer = MODELS.build(quantizer) - self.observers_enabled = True - self.fake_quants_enabled = True - self.gen_graphs(self.architecture) + self._observers_enabled = True + self._fake_quants_enabled = True + self.export_mode = export_mode + self.qmodel_modes = qmodel_modes + self.qmodels = self._build_qmodels(self.architecture) + + def sync_param(self): + + def traverse(module, prefix): + for name, child in module._modules.items(): + if module is None: + continue + module_name = f'{prefix}{name}' + if isinstance(child, FakeQuantizeBase): + for name, param in child.named_parameters(): + param.data.copy_(self.qmodels['loss'].state_dict() + [f'{module_name}.{name}']) + for name, buffer in child.named_buffers(): + buffer.data.copy_(self.qmodels['loss'].state_dict() + [f'{module_name}.{name}']) + else: + traverse(child, f'{module_name}.') + + for mode in self.qmodel_modes: + if mode == 'loss': + continue + traverse(self.qmodels[mode], '') + + def _build_qmodels(self, model): + + qmodels = nn.ModuleDict() - def gen_graphs(self, model): self.quantizer._swap_ff_with_fxff(model) tracer = self.quantizer.tracer - for mode in ['tensor', 'loss', 'predict']: + + for mode in self.qmodel_modes: concrete_args = {'mode': mode} - if mode == 'tensor': - self.graph_tensor = GraphModule( - model, tracer.trace(model, concrete_args=concrete_args)) - if mode == 'loss': - self.graph_loss = GraphModule( - model, tracer.trace(model, concrete_args=concrete_args)) - if mode == 'predict': - self.graph_predict = GraphModule( - model, tracer.trace(model, concrete_args=concrete_args)) + traced_graph = tracer.trace(model, concrete_args=concrete_args) + + qmodel = build_graphmodule(model, traced_graph) + qmodels[mode] = self.quantizer.prepare(model, qmodel) + + return qmodels def forward(self, inputs: torch.Tensor, data_samples: Optional[List[BaseDataElement]] = None, mode: str = 'tensor') -> ForwardResults: - if mode == 'loss': - return self.graph_loss(inputs, data_samples, mode) - elif mode == 'tensor': - return self.graph_tensor(inputs, data_samples, mode) - elif mode == 'predict': - return self.graph_predict(inputs, data_samples, mode) + if mode in self.qmodels: + qmodel = self.qmodels[mode] + return qmodel(inputs, data_samples, mode) else: - raise RuntimeError(f'Invalid mode "{mode}". 
' -                               'Only supports loss, predict and tensor mode')
+            return self.architecture(inputs, data_samples, mode)
 
-    def calib_step(self, data):
+    def calibrate_step(self, data):
         data = self.data_preprocessor(data, False)
+        self.state = (1, 0)
         return self._run_forward(data, mode='tensor')
 
-    def prepare(self, mode='tensor'):
-        assert mode in ['tensor', 'loss', 'predict']
-        if mode == 'tensor':
-            graph = self.graph_tensor
-        elif mode == 'loss':
-            graph = self.graph_loss
-        else:
-            graph = self.graph_predict
-        self.architecture = self.quantizer.prepare(self.architecture, graph)
-
-    def convert(self):
-        self.architecture = self.quantizer.convert(self.architecture)
+    def convert(self, mode='predict'):
+        qmodel = self.qmodels[self.export_mode]
+        self.qmodels[mode] = self.quantizer.convert(qmodel)
 
     @property
     def state(self):
-        return (self.observers_enabled, self.fake_quants_enabled)
+        return (self._observers_enabled, self._fake_quants_enabled)
 
     @state.setter
-    def state(self, state):
+    def state(self, state: Tuple[bool, bool]):
         observers_enabled, fake_quants_enabled = state
-        for name, submodule in self.architecture.named_modules():
+        qmodel = self.qmodels[self.export_mode]
+        for submodule in qmodel.modules():
             if isinstance(submodule, torch.quantization.FakeQuantize):
                 if observers_enabled:
                     submodule.enable_observer()
@@ -112,5 +141,42 @@ def state(self, state):
             else:
                 submodule.disable_fake_quant()
 
-        self.observers_enabled = observers_enabled
-        self.fake_quants_enabled = fake_quants_enabled
+        self._observers_enabled = observers_enabled
+        self._fake_quants_enabled = fake_quants_enabled
+
+
+@MODEL_WRAPPERS.register_module()
+class GeneralQuantDDP(MMDistributedDataParallel):
+    """DDP wrapper for GeneralQuant."""
+
+    def __init__(self,
+                 *,
+                 device_ids: Optional[Union[List, int, torch.device]] = None,
+                 **kwargs) -> None:
+        if device_ids is None:
+            if os.environ.get('LOCAL_RANK') is not None:
+                device_ids = [int(os.environ['LOCAL_RANK'])]
+        super().__init__(device_ids=device_ids, **kwargs)
+        # After moving all model parameters and buffers to the GPU
+        # (`model.cuda()`), the buffers in the model are no longer the ones
+        # the traced qmodels were built with, so the qmodels are rebuilt
+        # from the wrapped module below.
+ self.module.qmodels = self.module._build_qmodels( + self.module.architecture) + + def calibrate_step(self, data): + return self.module.calibrate_step(data) + + @property + def state(self): + return (self.module._observers_enabled, + self.module._fake_quants_enabled) + + @state.setter + def state(self, state: Tuple[bool]): + self.module.state = state + + def convert(self, mode='predict'): + self.module.convert(mode) + self.module.qmodels[mode].cuda() + + def sync_param(self): + self.module.sync_param() diff --git a/mmrazor/models/fake_quants/lsq.py b/mmrazor/models/fake_quants/lsq.py index 10970a6a3..3e26631eb 100644 --- a/mmrazor/models/fake_quants/lsq.py +++ b/mmrazor/models/fake_quants/lsq.py @@ -48,6 +48,22 @@ def extra_repr(self): self.scale if self.ch_axis == -1 else 'List[%s]' % str(self.scale.shape), # noqa: E501 self.zero_point if self.ch_axis == -1 else 'List') + @torch.jit.export + def calculate_qparams(self): + self.scale.data.clamp_(min=self.eps.item()) + scale = self.scale.detach() + zero_point = self.zero_point.detach().round().clamp( + self.quant_min, self.quant_max).long() + return scale, zero_point + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super(FakeQuantize, self)._save_to_state_dict(destination, prefix, + keep_vars) + destination[prefix + 'scale'] = self.scale if keep_vars \ + else self.scale.detach() + destination[prefix + 'zero_point'] = self.zero_point if keep_vars \ + else self.zero_point.detach() + def forward(self, X): # Learnable fake quantize have to zero_point.float() # to make it learnable. diff --git a/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py index dabe41fab..be7b9ad29 100644 --- a/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py +++ b/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. """This module defines MutableChannelUnit.""" import abc -from collections import Set +# from collections import set from typing import Dict, List, Type, TypeVar import torch @@ -71,7 +71,7 @@ def process_container(container: MutableChannelContainer, if isinstance(derived_choices, torch.Tensor): derived_choices = derived_choices.sum().item() if isinstance(mutable, DerivedMutable): - source_mutables: Set = \ + source_mutables: set = \ mutable._trace_source_mutables() source_channel_mutables = [ mutable for mutable in source_mutables diff --git a/mmrazor/models/observers/__init__.py b/mmrazor/models/observers/__init__.py index 22af9bae9..eac6371e2 100644 --- a/mmrazor/models/observers/__init__.py +++ b/mmrazor/models/observers/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .lsq_observer import LSQObserver from .minmax import EMAMinMaxObserver, MinMaxObserver from .mse import MSEObserver -__all__ = ['MinMaxObserver', 'MSEObserver', 'EMAMinMaxObserver'] +__all__ = ['MinMaxObserver', 'MSEObserver', 'EMAMinMaxObserver', 'LSQObserver'] diff --git a/mmrazor/models/observers/base.py b/mmrazor/models/observers/base.py index e10738664..8d9c40afe 100644 --- a/mmrazor/models/observers/base.py +++ b/mmrazor/models/observers/base.py @@ -6,9 +6,8 @@ from mmrazor.models.utils import pot_quantization, sync_tensor -# from mmengine.model import BaseModule - +# todo: We only support per-tensor quantization currently. class BaseObserver(UniformQuantizationObserverBase): """Modified torch quantization observer. 
diff --git a/mmrazor/models/observers/minmax.py b/mmrazor/models/observers/minmax.py index 099296536..2ec620e60 100644 --- a/mmrazor/models/observers/minmax.py +++ b/mmrazor/models/observers/minmax.py @@ -45,8 +45,8 @@ def forward(self, x_orig): min_val_cur, max_val_cur = torch._aminmax(y, 1) min_val = torch.min(self.min_val, min_val_cur) max_val = torch.max(self.max_val, max_val_cur) - self.min_val.copy_(min_val) - self.max_val.copy_(max_val) + self.min_val = min_val + self.max_val = max_val return x diff --git a/mmrazor/models/quantizers/base.py b/mmrazor/models/quantizers/base.py index ab4cf190a..6f1fb4e31 100644 --- a/mmrazor/models/quantizers/base.py +++ b/mmrazor/models/quantizers/base.py @@ -13,7 +13,7 @@ check_is_valid_qconfig_dict, get_custom_module_class_keys) from mmrazor.registry import MODELS -from mmrazor.structures.quantization import (CheckArgs, DefalutQconfigs, +from mmrazor.structures.quantization import (CheckArgs, DefaultQconfigs, QuantizeScheme, SupportQtypes) @@ -22,7 +22,7 @@ class CustomQuantizer(BaseModule): """Configurable quantizer, base class of quantizers. Args: - qconfig (Dict, optional): QConfig. Defaults to DefalutQconfigs['default']. # noqa: E501 + qconfig (Dict, optional): QConfig. Defaults to DefaultQconfigs['default']. # noqa: E501 is_qat (bool, optional): Is QAT ro not. Defaults to True. skipped_methods (List, optional): Skipped methods list for tracer. Defaults to None. @@ -38,7 +38,7 @@ class CustomQuantizer(BaseModule): """ def __init__(self, - qconfig: Dict = DefalutQconfigs['default'], + qconfig: Dict = DefaultQconfigs['default'], is_qat: bool = True, skipped_methods: List = None, prepare_custom_config_dict: Dict = None, @@ -189,6 +189,9 @@ def build_tracer(self): return tracer def fuse_model(self, graph_module): + if not self.is_qat: + graph_module.eval() + graph_module = _fuse_fx(graph_module, self.is_qat, self.prepare_custom_config_dict) return graph_module diff --git a/mmrazor/models/quantizers/trt_quantizer.py b/mmrazor/models/quantizers/trt_quantizer.py index cc8532a53..9dbe9f594 100644 --- a/mmrazor/models/quantizers/trt_quantizer.py +++ b/mmrazor/models/quantizers/trt_quantizer.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from mmrazor.registry import MODELS -from mmrazor.structures.quantization import DefalutQconfigs +from mmrazor.structures.quantization import DefaultQconfigs from .base import CustomQuantizer @@ -9,7 +9,7 @@ class TensorRTQuantizer(CustomQuantizer): """Quantizer for TensorRT backend.""" def __init__(self, - qconfig=DefalutQconfigs['tensorrt'], + qconfig=DefaultQconfigs['tensorrt'], is_qat=True, skipped_methods=None, prepare_custom_config_dict=None, diff --git a/mmrazor/models/task_modules/__init__.py b/mmrazor/models/task_modules/__init__.py index 931278b8a..1b3811a2a 100644 --- a/mmrazor/models/task_modules/__init__.py +++ b/mmrazor/models/task_modules/__init__.py @@ -3,6 +3,7 @@ from .demo_inputs import * # noqa: F401,F403 from .estimators import ResourceEstimator from .predictor import * # noqa: F401,F403 +# from .fx import * # noqa: F401, F403 from .recorder import * # noqa: F401,F403 from .tracer import * # noqa: F401,F403 diff --git a/mmrazor/models/task_modules/tracer/__init__.py b/mmrazor/models/task_modules/tracer/__init__.py index 1cbc49e71..8f70fb161 100644 --- a/mmrazor/models/task_modules/tracer/__init__.py +++ b/mmrazor/models/task_modules/tracer/__init__.py @@ -2,7 +2,8 @@ from .backward_tracer import BackwardTracer from .channel_analyzer import ChannelAnalyzer # from .razor_tracer import RazorFxTracer -from .fx import CustomTracer, UntracedMethodRegistry, custom_symbolic_trace +from .fx import (CustomTracer, UntracedMethodRegistry, custom_symbolic_trace, + prepare_graph_module) from .loss_calculator import * # noqa: F401,F403 from .parsers import * # noqa: F401,F403 from .path import (Path, PathConcatNode, PathConvNode, PathDepthWiseConvNode, @@ -11,6 +12,7 @@ __all__ = [ 'BackwardTracer', 'PathConvNode', 'PathLinearNode', 'PathNormNode', 'PathConcatNode', 'Path', 'PathList', 'PathNode', 'PathDepthWiseConvNode', - 'ChannelAnalyzer' - 'CustomTracer', 'UntracedMethodRegistry', 'custom_symbolic_trace' + 'ChannelAnalyzer', + 'CustomTracer', 'UntracedMethodRegistry', 'custom_symbolic_trace', + 'prepare_graph_module' ] diff --git a/mmrazor/models/task_modules/tracer/fx/__init__.py b/mmrazor/models/task_modules/tracer/fx/__init__.py index 29c93f83a..998e9ffe1 100644 --- a/mmrazor/models/task_modules/tracer/fx/__init__.py +++ b/mmrazor/models/task_modules/tracer/fx/__init__.py @@ -1,5 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. 
 from .custom_tracer import (CustomTracer, UntracedMethodRegistry,
-                            custom_symbolic_trace)
+                            custom_symbolic_trace, build_graphmodule)
 
-__all__ = ['CustomTracer', 'UntracedMethodRegistry', 'custom_symbolic_trace']
+__all__ = [
+    'CustomTracer', 'UntracedMethodRegistry', 'custom_symbolic_trace',
+    'build_graphmodule'
+]
diff --git a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py
index f69ec2269..1d78d3007 100644
--- a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py
+++ b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py
@@ -4,6 +4,7 @@
 from typing import Any, Callable, Dict, List, Optional, Type, Union
 
 import torch
+import torch.nn as nn
 from mmengine.utils import import_modules_from_strings
 from torch._C import ScriptObject  # type: ignore[attr-defined]
 from torch.ao.quantization.quantize_fx import QuantizationTracer
@@ -14,7 +15,6 @@
 _orig_module_call: Callable = torch.nn.Module.__call__
 _orig_module_getattr: Callable = torch.nn.Module.__getattr__
-# _orig_module_forward_train: Callable = models.BaseDenseHead.forward_train
 
 
 class UntracedMethodRegistry:
@@ -78,6 +78,92 @@ def custom_symbolic_trace(
     return GraphModule(tracer.root, graph, name)
 
 
+def _prepare_module_dict(model: nn.Module, fx_graph: torch.fx.Graph):
+    """If there is a class method that can not be traced by the symbolic
+    tracer, a ``call_method`` ``Node`` will be inserted into the ``Graph`` in
+    ``CustomTracer``.
+
+    For example,
+    ```
+    >>> class Model:
+    ...     def __init__(self):
+    ...         self.head = ClsHead()
+    ...
+    >>> class ClsHead(nn.Module):
+    ...     def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
+    ...         return feats[-1]
+    ...
+    ...     def loss(self, feats: Tuple[torch.Tensor],
+    ...              data_samples: List[ClsDataSample], **kwargs) -> dict:
+    ...         cls_score = self(feats)
+    ...         # The part can not be traced by torch.fx
+    ...         losses = self._get_loss(cls_score, data_samples, **kwargs)
+    ...         return losses
+    ...
+    ...     def _get_loss(self, cls_score: torch.Tensor,
+    ...                   data_samples: List[ClsDataSample], **kwargs):
+    ...         if 'score' in data_samples[0].gt_label:
+    ...             xxx
+    ...         else:
+    ...             xxx
+    ...         losses = xxx
+    ...         return losses
+    ```
+    As ``_get_loss`` can not be traced by torch.fx, ``ClsHead._get_loss``
+    needs to be added to ``skipped_methods`` in ``CustomTracer``. Hence the
+    code above will produce the following Graph::
+
+    .. code-block:: text
+        ... ...
+        %head : [#users=1] = get_attr[target=head]
+        %_get_loss : [#users=1] = call_method[target=_get_loss](args = (%head, %head_fc, %data_samples), kwargs = {}) # noqa: E501
+        return _get_loss
+
+    Hence, the head module in the ``GraphModule`` and that in the original
+    model are the same one (refer to https://github.com/pytorch/pytorch/blob/master/torch/fx/graph_module.py#L346). # noqa: E501
+    So changes made to the graph module (in ``prepare()``) will also modify
+    the original model.
+
+    Args:
+        model (nn.Module): The original model.
+        fx_graph (torch.fx.Graph): The fx Graph traced by fx tracer.
+ """ + + def _get_attrs(target, attrs): + attrs = attrs.split('.') + for att in attrs: + target = getattr(target, att) + return target + + module_dict = dict() + special_nodes = [] + + for node in fx_graph.nodes: + if node.op == 'get_attr': + attr = _get_attrs(model, node.target) + if isinstance(attr, nn.Module): + module_dict[node.target] = nn.Module() + special_nodes.append(node) + elif node.op == 'call_method': + for special_node in special_nodes: + if special_node in node.args or \ + special_node in node.kwargs.values(): + origin_module = getattr(model, special_node.target) + setattr(module_dict[special_node.target], node.target, + getattr(origin_module, node.target)) + + return module_dict + + +def build_graphmodule(model: nn.Module, + fx_graph: torch.fx.Graph, + name: str = 'GraphModule'): + modules = dict(model.named_modules()) + module_dict = _prepare_module_dict(model, fx_graph) + modules.update(module_dict) + return GraphModule(modules, fx_graph, name) + + class CustomTracer(QuantizationTracer): def __init__(self, diff --git a/mmrazor/registry/registry.py b/mmrazor/registry/registry.py index d3a5c5423..d6fd480cd 100644 --- a/mmrazor/registry/registry.py +++ b/mmrazor/registry/registry.py @@ -30,6 +30,7 @@ from mmengine.registry import \ WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS from mmengine.registry import Registry, build_from_cfg +from mmengine.runner import load_checkpoint def build_razor_model_from_cfg( diff --git a/mmrazor/structures/quantization/__init__.py b/mmrazor/structures/quantization/__init__.py index fc2133bf2..9447c2f0f 100644 --- a/mmrazor/structures/quantization/__init__.py +++ b/mmrazor/structures/quantization/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .backend_default_qconfigs import CheckArgs, DefalutQconfigs, SupportQtypes +from .backend_default_qconfigs import CheckArgs, DefaultQconfigs, SupportQtypes from .qscheme import QuantizeScheme -__all__ = ['QuantizeScheme', 'DefalutQconfigs', 'SupportQtypes', 'CheckArgs'] +__all__ = ['QuantizeScheme', 'DefaultQconfigs', 'SupportQtypes', 'CheckArgs'] diff --git a/mmrazor/structures/quantization/backend_default_qconfigs.py b/mmrazor/structures/quantization/backend_default_qconfigs.py index 6a1fde183..590f3208a 100644 --- a/mmrazor/structures/quantization/backend_default_qconfigs.py +++ b/mmrazor/structures/quantization/backend_default_qconfigs.py @@ -19,8 +19,8 @@ is_pot_scale=False, bit=8, symmetric_range=True), - w_fake_quant=dict(type='BaseFakeQuantize'), - a_fake_quant=dict(type='BaseFakeQuantize'), + w_fake_quant=dict(type='FakeQuantize'), + a_fake_quant=dict(type='FakeQuantize'), w_observer=dict(type='MinMaxObserver'), a_observer=dict(type='MinMaxObserver')) @@ -43,4 +43,4 @@ w_observer=dict(type='MinMaxObserver'), a_observer=dict(type='EMAMinMaxObserver')) -DefalutQconfigs = dict(default=Default, tensorrt=TensorRT) +DefaultQconfigs = dict(default=Default, tensorrt=TensorRT) diff --git a/tools/ptq_calibrate.py b/tools/ptq_calibrate.py new file mode 100644 index 000000000..2c00c5b11 --- /dev/null +++ b/tools/ptq_calibrate.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import argparse +import os +import os.path as osp + +from mmengine.config import Config, DictAction +from mmengine.runner import Runner + +from mmrazor.utils import register_all_modules + + +# TODO: support fuse_conv_bn, visualization, and format_only +def parse_args(): + parser = argparse.ArgumentParser( + description='MMRazor test (and eval) a model') + parser.add_argument('config', help='test config file path') + # parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument( + '--work-dir', + help='the directory to save the file containing evaluation metrics') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + return args + + +def main(): + register_all_modules(False) + args = parse_args() + + # load config + cfg = Config.fromfile(args.config) + cfg.launcher = args.launcher + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + # use config filename as default work_dir if cfg.work_dir is None + cfg.work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.config))[0]) + + # cfg.load_from = args.checkpoint + + # build the runner from config + runner = Runner.from_cfg(cfg) + + # start testing + runner.test() + + +if __name__ == '__main__': + main()
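
Usage sketch (not part of the patch): the new `tools/ptq_calibrate.py` entry point only builds an MMEngine `Runner` from a config whose `test_cfg` selects `mmrazor.PTQLoop` and then calls `runner.test()`. The snippet below mirrors that flow in Python; the config path comes from this patch, while the work_dir value is purely illustrative.

    # Equivalent of:
    #   python tools/ptq_calibrate.py configs/quantization/ptq/adaround.py
    from mmengine.config import Config
    from mmengine.runner import Runner

    from mmrazor.utils import register_all_modules

    register_all_modules(False)

    cfg = Config.fromfile('configs/quantization/ptq/adaround.py')
    cfg.work_dir = './work_dirs/ptq_adaround'  # illustrative output directory

    runner = Runner.from_cfg(cfg)
    # PTQLoop sets `runner.model.state = (True, False)` (observers on,
    # fake-quant off), feeds each test batch through `calibrate_step` for
    # calibration, evaluates, and saves 'checkpoint_after_ptq.pth' in work_dir.
    runner.test()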