[Docs] Add docstring for MMArchitectureQuant & NativeQuantizer #425

Merged
merged 21 commits · Jan 17, 2023
Changes from 15 commits
@@ -21,8 +21,8 @@
_delete_=True,
type='mmrazor.MMArchitectureQuant',
architecture=_base_.model,
float_checkpoint='/tmp/humu/mobilenet_v2_batch256_imagenet' +
'_20200708-3b2dc3af.pth',
float_checkpoint='https://download.openmmlab.com/mmclassification/v0/mobil'
'enet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth',
quantizer=dict(
type='mmrazor.OpenVINOQuantizer',
global_qconfig=global_qconfig,
@@ -32,3 +32,8 @@
'mmcls.models.heads.ClsHead._get_loss',
'mmcls.models.heads.ClsHead._get_predictions'
])))

model_wrapper_cfg = dict(
type='mmrazor.MMArchitectureQuantDDP',
broadcast_buffers=False,
find_unused_parameters=True)
@@ -23,7 +23,8 @@
_delete_=True,
type='mmrazor.MMArchitectureQuant',
architecture=_base_.model,
float_checkpoint='/tmp/humu/resnet18_8xb32_in1k_20210831-fbbb1da6.pth',
float_checkpoint='https://download.openmmlab.com/mmclassification/v0/resne'
't/resnet18_8xb32_in1k_20210831-fbbb1da6.pth',
quantizer=dict(
type='mmrazor.OpenVINOQuantizer',
global_qconfig=global_qconfig,
@@ -33,3 +34,5 @@
'mmcls.models.heads.ClsHead._get_loss',
'mmcls.models.heads.ClsHead._get_predictions'
])))

model_wrapper_cfg = dict(type='mmrazor.MMArchitectureQuantDDP', )
@@ -23,7 +23,8 @@
_delete_=True,
type='mmrazor.MMArchitectureQuant',
architecture=_base_.model,
float_checkpoint='/tmp/humu/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
float_checkpoint='https://download.openmmlab.com/mmclassification/v0/resne'
't/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
quantizer=dict(
type='mmrazor.OpenVINOQuantizer',
global_qconfig=global_qconfig,
@@ -33,3 +34,5 @@
'mmcls.models.heads.ClsHead._get_loss',
'mmcls.models.heads.ClsHead._get_predictions'
])))

model_wrapper_cfg = dict(type='mmrazor.MMArchitectureQuantDDP', )
64 changes: 64 additions & 0 deletions configs/quantization/qat/minmax_openvino_resnet18_8xb32_in1k.py
@@ -0,0 +1,64 @@
_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py']

train_dataloader = dict(batch_size=1024)

global_qconfig = dict(
w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'),
a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'),
w_fake_quant=dict(type='mmrazor.FakeQuantize'),
a_fake_quant=dict(type='mmrazor.FakeQuantize'),
w_qscheme=dict(
qdtype='qint8',
bit=8,
is_symmetry=True,
is_symmetric_range=True,
),
a_qscheme=dict(
qdtype='quint8',
bit=8,
is_symmetry=True,
averaging_constant=0.1,
),
)

model = dict(
_delete_=True,
type='mmrazor.MMArchitectureQuant',
architecture=_base_.model,
float_checkpoint='https://download.openmmlab.com/mmclassification/v0/resne'
't/resnet18_8xb32_in1k_20210831-fbbb1da6.pth',
quantizer=dict(
type='mmrazor.OpenVINOQuantizer',
global_qconfig=global_qconfig,
tracer=dict(
type='mmrazor.CustomTracer',
skipped_methods=[
'mmcls.models.heads.ClsHead._get_loss',
'mmcls.models.heads.ClsHead._get_predictions'
])))

optim_wrapper = dict(
optimizer=dict(type='SGD', lr=0.004, momentum=0.9, weight_decay=0.0001))

# learning policy
param_scheduler = dict(
_delete_=True,
type='CosineAnnealingLR',
T_max=100,
by_epoch=True,
begin=0,
end=100)

model_wrapper_cfg = dict(
type='mmrazor.MMArchitectureQuantDDP',
broadcast_buffers=False,
find_unused_parameters=True)

# train, val, test setting
train_cfg = dict(
_delete_=True,
type='mmrazor.QATEpochBasedLoop',
max_epochs=100,
val_interval=1)
val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop')
# test_cfg = val_cfg
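
A minimal launch sketch for the new QAT config above (a hedged example, not part of this PR; it assumes mmrazor and mmcls are installed so their registries resolve, and the work_dir is a hypothetical choice):

from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'configs/quantization/qat/minmax_openvino_resnet18_8xb32_in1k.py')
cfg.work_dir = './work_dirs/qat_resnet18'  # hypothetical output directory
runner = Runner.from_cfg(cfg)
runner.train()  # runs the mmrazor.QATEpochBasedLoop defined in train_cfg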
95 changes: 72 additions & 23 deletions mmrazor/models/algorithms/quantization/mm_architecture.py
@@ -10,7 +10,7 @@

from mmrazor.models.task_modules.tracer import build_graphmodule
from mmrazor.registry import MODEL_WRAPPERS, MODELS
from ..base import BaseAlgorithm
from ..base import BaseAlgorithm, BaseModel

try:
from torch.ao.quantization import FakeQuantizeBase
@@ -29,35 +29,44 @@ class MMArchitectureQuant(BaseAlgorithm):
"""General quantization.

Args:
architecture (dict | :obj:`BaseModel`): The config of
:class:`BaseModel` or built model.
quantizer (dict | :obj:`BaseModel`): The config of
:class:`BaseQuantizer` or built model.
export_mode (str): The mode of the model to be exported. Defaults to
predict.
architecture (dict | :obj:`BaseModel`): The config of the model to be
quantized.
Collaborator: The type hint in the docstring had better be the same as in __init__.
quantizer (dict | :obj:`BaseModel`): The quantizer to support different
backend types.
qmodel_modes (list): The available mode of runner.
data_preprocessor (dict | torch.nn.Module | None): The pre-process
config of :class:`BaseDataPreprocessor`. Defaults to None.
pretrained_ckpt (str, Optional): The path of pretrained checkpoint.
Defaults to None.
init_cfg (dict): The weight initialized config for
:class:`BaseModule`.
forward_modes (tuple): The modes of the forward method in OpenMMLab
architecture, which could be 'tensor', 'predict', or 'loss'. Each mode
generates a different graph of the quantized model.
Collaborator: Add a note to explain why we need to generate multiple graphs.

float_checkpoint (str, Optional): The path of the pretrained FP checkpoint.
Quantization is different from other tasks; we recommend using
`float_checkpoint` as the pretrained model. Defaults to None.
init_cfg (dict): The weight initialized config for :class:`BaseModule`.

Note:
forward_modes (tuple): In OpenMMLab architecture, different modes
will trace different graphs of the quantized model.
"""

def __init__(self,
architecture,
quantizer,
data_preprocessor=None,
forward_modes=('tensor', 'predict', 'loss'),
float_checkpoint: Optional[str] = None,
input_shapes=(1, 3, 224, 224),
init_cfg=None):
def __init__(
self,
architecture: Union[Dict, BaseModel],
quantizer: Union[Dict, BaseModel],
# data_preprocessor: Union[Dict, torch.nn.Module, None] = None,
data_preprocessor=None,
Collaborator: add type hint

forward_modes: Union[tuple, str] = ('tensor', 'predict', 'loss'),
Collaborator: type hint should not include str

float_checkpoint: Optional[str] = None,
input_shapes: tuple = (1, 3, 224, 224),
init_cfg: Optional[dict] = None):

if data_preprocessor is None:
data_preprocessor = {}
# The build process is in MMEngine, so we need to add scope here.
# Default to mmcls.ClsDataPreprocessor.
data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor')
super().__init__(architecture, data_preprocessor, init_cfg)
# If we have a float_checkpoint, we load it as pretrain.
if float_checkpoint:
_ = load_checkpoint(self.architecture, float_checkpoint)
self.architecture._is_init = True
@@ -68,9 +77,23 @@ def __init__(self,

self.qmodels = self._build_qmodels(self.architecture)

self.sync_qparams('predict')
self.sync_qparams(forward_modes[0])

def sync_qparams(self, src_mode):
Collaborator: add type hint

"""Sync all quantize parameters in different `forward_modes`. We have
three modes to generate three graphs, but in training, only one graph
will be update, so we need to sync qparams in the other two graphs.

Args:
src_mode (str): The modes of forward method.

Note:
`traverse()` function recursively traverses all module to sync
quantized graph generated from different `forward_modes`.
This is because We have different mode ('tensor', 'predict',
'loss') in OpenMMLab architecture which have different graph
in some subtle ways, so we need to sync them here.
"""

def traverse(module, prefix):
for name, child in module._modules.items():
Expand All @@ -84,10 +107,10 @@ def traverse(module, prefix):
if src_param.shape == param.shape:
param.data.copy_(src_param)
else:
requirs_grad = param.requires_grad
param.requires_grad = False
# requirs_grad = param.requires_grad
# param.requires_grad = False
param.resize_(src_param.shape)
param.requires_grad = requirs_grad
# param.requires_grad = requirs_grad
param.data.copy_(src_param)
for name, buffer in child.named_buffers():
buffer_name = f'{child_name}.{name}'
@@ -107,6 +130,30 @@ def traverse(module, prefix):
traverse(self.qmodels[mode], '')

def _build_qmodels(self, model):
Collaborator: add type hint

"""Build quantized models from the given model.

Args:
model (dict | :obj:`BaseModel`): The given FP (float) model.

Example:
The main body of the graph is all the same, but the last one or two
ops will differ, as shown below.

self.qmodels['tensor'].graph.print_tabular()
opcode target args
call_module head.fc (activation_post_process_38,)
output output (head_fc,)

self.qmodels['loss'].graph.print_tabular()
opcode target args
call_method _get_loss (head, head_fc, data_samples)
output output (_get_loss,)

self.qmodels['predict'].graph.print_tabular()
opcode target args
call_method _get_predictions (head, head_fc, data_samples)
output output (_get_predictions,)
"""

qmodels = nn.ModuleDict()

@@ -138,6 +185,8 @@ def forward(self,
return self.architecture(inputs, data_samples, mode)

def calibrate_step(self, data):
Collaborator: add type hint

"""PTQ method need calibrate by cali data."""

data = self.data_preprocessor(data, False)
return self._run_forward(data, mode='predict')

Collaborator: MMArchitectureQuantDDP needs a docstring and type hints as well.
