
Commit

[Docs] Add docstring for MMArchitectureQuant & NativeQuantizer (#425)

* add docstring on mm_architecture& native_quantizer

* add naive openvino r18 qat config & dist_ptq.sh

* Added a more accurate description

* unitest&doc

* checkpoint url

* unitest

* passed_pre_commit

* unitest on native_quantizer& fix bugs

* remove dist_ptq

* add get_placeholder&skipTest

* complete arg descriptions

* fix import bugs

* fix pre-commit

* add get_placeholder

* add typehint and doctring

* update docstring&typehint

* update docstring

* pre-commit

* fix some problems

* fix bug
Niko-zyf authored Jan 17, 2023
1 parent 5d50314 commit 9872c61
Showing 8 changed files with 679 additions and 46 deletions.
@@ -17,12 +17,13 @@
qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1),
)

float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth' # noqa: E501

model = dict(
_delete_=True,
type='mmrazor.MMArchitectureQuant',
architecture=_base_.model,
float_checkpoint='/tmp/humu/mobilenet_v2_batch256_imagenet' +
'_20200708-3b2dc3af.pth',
float_checkpoint=float_checkpoint,
quantizer=dict(
type='mmrazor.OpenVINOQuantizer',
global_qconfig=global_qconfig,
@@ -32,3 +33,8 @@
'mmcls.models.heads.ClsHead._get_loss',
'mmcls.models.heads.ClsHead._get_predictions'
])))

model_wrapper_cfg = dict(
type='mmrazor.MMArchitectureQuantDDP',
broadcast_buffers=False,
find_unused_parameters=True)
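Note: the `model_wrapper_cfg` added above tells MMEngine which wrapper class to build when the job runs on multiple GPUs. A minimal sketch of what this amounts to, assuming MMEngine's standard `MODEL_WRAPPERS` registry handling (the helper below is illustrative and not part of this commit):

from mmrazor.registry import MODEL_WRAPPERS

def wrap_quant_model_for_ddp(model, model_wrapper_cfg):
    """Wrap a built MMArchitectureQuant model for distributed runs."""
    # The runner passes the built model through `default_args`; the remaining
    # keys (broadcast_buffers, find_unused_parameters, ...) are forwarded to
    # MMArchitectureQuantDDP, which hands them to DDP.
    return MODEL_WRAPPERS.build(
        model_wrapper_cfg, default_args=dict(module=model))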
@@ -19,11 +19,13 @@
qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1),
)

float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501

model = dict(
_delete_=True,
type='mmrazor.MMArchitectureQuant',
architecture=_base_.model,
float_checkpoint='/tmp/humu/resnet18_8xb32_in1k_20210831-fbbb1da6.pth',
float_checkpoint=float_checkpoint,
quantizer=dict(
type='mmrazor.OpenVINOQuantizer',
global_qconfig=global_qconfig,
@@ -33,3 +35,5 @@
'mmcls.models.heads.ClsHead._get_loss',
'mmcls.models.heads.ClsHead._get_predictions'
])))

model_wrapper_cfg = dict(type='mmrazor.MMArchitectureQuantDDP', )
@@ -19,11 +19,13 @@
qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1),
)

float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa: E501

model = dict(
_delete_=True,
type='mmrazor.MMArchitectureQuant',
architecture=_base_.model,
float_checkpoint='/tmp/humu/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
float_checkpoint=float_checkpoint,
quantizer=dict(
type='mmrazor.OpenVINOQuantizer',
global_qconfig=global_qconfig,
@@ -33,3 +35,4 @@
'mmcls.models.heads.ClsHead._get_loss',
'mmcls.models.heads.ClsHead._get_predictions'
])))
model_wrapper_cfg = dict(type='mmrazor.MMArchitectureQuantDDP', )
65 changes: 65 additions & 0 deletions configs/quantization/qat/minmax_openvino_resnet18_8xb32_in1k.py
@@ -0,0 +1,65 @@
_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py']

train_dataloader = dict(batch_size=1024)

global_qconfig = dict(
w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'),
a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'),
w_fake_quant=dict(type='mmrazor.FakeQuantize'),
a_fake_quant=dict(type='mmrazor.FakeQuantize'),
w_qscheme=dict(
qdtype='qint8',
bit=8,
is_symmetry=True,
is_symmetric_range=True,
),
a_qscheme=dict(
qdtype='quint8',
bit=8,
is_symmetry=True,
averaging_constant=0.1,
),
)

float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501

model = dict(
_delete_=True,
type='mmrazor.MMArchitectureQuant',
architecture=_base_.model,
float_checkpoint=float_checkpoint,
quantizer=dict(
type='mmrazor.OpenVINOQuantizer',
global_qconfig=global_qconfig,
tracer=dict(
type='mmrazor.CustomTracer',
skipped_methods=[
'mmcls.models.heads.ClsHead._get_loss',
'mmcls.models.heads.ClsHead._get_predictions'
])))

optim_wrapper = dict(
optimizer=dict(type='SGD', lr=0.004, momentum=0.9, weight_decay=0.0001))

# learning policy
param_scheduler = dict(
_delete_=True,
type='CosineAnnealingLR',
T_max=100,
by_epoch=True,
begin=0,
end=100)

model_wrapper_cfg = dict(
type='mmrazor.MMArchitectureQuantDDP',
broadcast_buffers=False,
find_unused_parameters=True)

# train, val, test setting
train_cfg = dict(
_delete_=True,
type='mmrazor.QATEpochBasedLoop',
max_epochs=100,
val_interval=1)
val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop')
# test_cfg = val_cfg
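For reference, the new QAT config can be launched with MMEngine's Runner. A minimal sketch (the work_dir path is an assumption, not part of this commit):

from mmengine.config import Config
from mmengine.runner import Runner

# Build and start QAT training from the config added in this commit.
cfg = Config.fromfile(
    'configs/quantization/qat/minmax_openvino_resnet18_8xb32_in1k.py')
cfg.work_dir = './work_dirs/minmax_openvino_resnet18_8xb32_in1k'  # assumed
runner = Runner.from_cfg(cfg)
runner.train()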
129 changes: 101 additions & 28 deletions mmrazor/models/algorithms/quantization/mm_architecture.py
@@ -10,7 +10,7 @@

from mmrazor.models.task_modules.tracer import build_graphmodule
from mmrazor.registry import MODEL_WRAPPERS, MODELS
from ..base import BaseAlgorithm
from ..base import BaseAlgorithm, BaseModel

try:
from torch.ao.quantization import FakeQuantizeBase
@@ -29,35 +29,43 @@ class MMArchitectureQuant(BaseAlgorithm):
"""General quantization.
Args:
architecture (dict | :obj:`BaseModel`): The config of
:class:`BaseModel` or built model.
quantizer (dict | :obj:`BaseModel`): The config of
:class:`BaseQuantizer` or built model.
export_mode (str): The mode of the model to be exported. Defaults to
predict.
qmodel_modes (list): The available mode of runner.
data_preprocessor (dict | torch.nn.Module | None): The pre-process
architecture (Union[Dict, BaseModel]): The config of the model to be
quantized.
quantizer (Union[Dict, BaseModel]): The quantizer, which supports
different backend types.
qmodel_modes (List): The available modes of the runner.
data_preprocessor (Optional[Dict]): The pre-process
config of :class:`BaseDataPreprocessor`. Defaults to None.
pretrained_ckpt (str, Optional): The path of pretrained checkpoint.
Defaults to None.
init_cfg (dict): The weight initialized config for
:class:`BaseModule`.
forward_modes (Tuple): The modes of the forward method in OpenMMLab
architectures; can be 'tensor', 'predict' or 'loss'. Each mode
generates a different graph of the quantized model.
float_checkpoint (Optional[str]): The path of a pretrained FP checkpoint.
Quantization differs from other tasks, so we recommend using
`float_checkpoint` as the pretrained model. Defaults to None.
init_cfg (Optional[Dict]): The weight initialization config for
:class:`BaseModule`.
Note:
forward_modes (Tuple): In OpenMMLab architectures, different modes
will trace different graphs of the quantized model.
"""

def __init__(self,
architecture,
quantizer,
data_preprocessor=None,
forward_modes=('tensor', 'predict', 'loss'),
architecture: Union[Dict, BaseModel],
quantizer: Union[Dict, BaseModel],
data_preprocessor: Optional[Dict] = None,
forward_modes: Tuple = ('tensor', 'predict', 'loss'),
float_checkpoint: Optional[str] = None,
input_shapes=(1, 3, 224, 224),
init_cfg=None):
input_shapes: Tuple = (1, 3, 224, 224),
init_cfg: Optional[Dict] = None):

if data_preprocessor is None:
data_preprocessor = {}
# The build process is in MMEngine, so we need to add scope here.
# Default to mmcls.ClsDataPreprocessor.
data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor')
super().__init__(architecture, data_preprocessor, init_cfg)
# If we have a float_checkpoint, we load it as pretrain.
if float_checkpoint:
_ = load_checkpoint(self.architecture, float_checkpoint)
self.architecture._is_init = True
@@ -70,7 +78,22 @@ def __init__(self,

self.sync_qparams('predict')

def sync_qparams(self, src_mode):
def sync_qparams(self, src_mode: str):
"""Sync all quantize parameters in different `forward_modes`. We could
have more than one forward mode to generate graphs, each mode will
generate one graph. But in training, only one graph will be update, so
we need to sync qparams in the other graphs.
Args:
src_mode (str): The modes of forward method.
Note:
`traverse()` function recursively traverses all module to sync
quantized graph generated from different `forward_modes`.
This is because We have different mode ('tensor', 'predict',
'loss') in OpenMMLab architecture which have different graph
in some subtle ways, so we need to sync them here.
"""

def traverse(module, prefix):
for name, child in module._modules.items():
@@ -84,10 +107,10 @@ def traverse(module, prefix):
if src_param.shape == param.shape:
param.data.copy_(src_param)
else:
requirs_grad = param.requires_grad
param.requires_grad = False
# requirs_grad = param.requires_grad
# param.requires_grad = False
param.resize_(src_param.shape)
param.requires_grad = requirs_grad
# param.requires_grad = requirs_grad
param.data.copy_(src_param)
for name, buffer in child.named_buffers():
buffer_name = f'{child_name}.{name}'
@@ -106,7 +129,31 @@ def traverse(module, prefix):
continue
traverse(self.qmodels[mode], '')

def _build_qmodels(self, model):
def _build_qmodels(self, model: BaseModel):
"""Build quantized models from the given model.
Args:
model (BaseModel): The given FP (full-precision) model to quantize.
Example:
The main body of each graph is the same, but the last one or two
ops differ, as shown below.
self.qmodels['tensor'].graph.print_tabular()
opcode target args
call_module head.fc (activation_post_process_38,)
output output (head_fc,)
self.qmodels['loss'].graph.print_tabular()
opcode target args
call_method _get_loss (head, head_fc, data_samples)
output output (_get_loss,)
self.qmodels['predict'].graph.print_tabular()
opcode target args
call_method _get_predictions (head, head_fc, data_samples)
output output (_get_predictions,)
"""

qmodels = nn.ModuleDict()

@@ -137,19 +184,27 @@ def forward(self,
else:
return self.architecture(inputs, data_samples, mode)

def calibrate_step(self, data):
def calibrate_step(self, data: Union[Dict, Tuple, List]):
"""PTQ method need calibrate by cali data."""

data = self.data_preprocessor(data, False)
return self._run_forward(data, mode='predict')


@MODEL_WRAPPERS.register_module()
class MMArchitectureQuantDDP(MMDistributedDataParallel):
"""DDPwapper for GeneralQuant."""
"""DDPwapper for GeneralQuant.
Args:
device_ids (Optional[Union[List, int, torch.device]]): devices to run
ddp.
"""

def __init__(self,
*,
device_ids: Optional[Union[List, int, torch.device]] = None,
**kwargs) -> None:

if device_ids is None:
if os.environ.get('LOCAL_RANK') is not None:
device_ids = [int(os.environ['LOCAL_RANK'])]
@@ -159,8 +214,26 @@ def __init__(self,
self.module.qmodels = self.module._build_qmodels(
self.module.architecture)

def calibrate_step(self, data):
def calibrate_step(self, data: Union[Dict, Tuple, List]):
"""PTQ method need calibrate by cali data."""

return self.module.calibrate_step(data)

def sync_qparams(self, src):
def sync_qparams(self, src: str):
"""Same as in 'MMArchitectureQuant'. Sync all quantize parameters in
different `forward_modes`. We could have several modes to generate
graphs, but in training, only one graph will be update, so we need to
sync qparams on the other graphs.
Args:
src (str): The src modes of forward method.
Note:
`traverse()` function recursively traverses all module to sync
quantized graph generated from different `forward_modes`.
This is because We have different mode ('tensor', 'predict',
'loss') in OpenMMLab architecture which have different graph
in some subtle ways, so we need to sync them here.
"""

self.module.sync_qparams(src)
