[Docs] Add docstring for MMArchitectureQuant & NativeQuantizer #425
@@ -0,0 +1,64 @@
_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py']

train_dataloader = dict(batch_size=1024)

global_qconfig = dict(
    w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'),
    a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'),
    w_fake_quant=dict(type='mmrazor.FakeQuantize'),
    a_fake_quant=dict(type='mmrazor.FakeQuantize'),
    w_qscheme=dict(
        qdtype='qint8',
        bit=8,
        is_symmetry=True,
        is_symmetric_range=True,
    ),
    a_qscheme=dict(
        qdtype='quint8',
        bit=8,
        is_symmetry=True,
        averaging_constant=0.1,
    ),
)

model = dict(
    _delete_=True,
    type='mmrazor.MMArchitectureQuant',
    architecture=_base_.model,
    float_checkpoint='https://download.openmmlab.com/mmclassification/v0/resne'
    't/resnet18_8xb32_in1k_20210831-fbbb1da6.pth',
    quantizer=dict(
        type='mmrazor.OpenVINOQuantizer',
        global_qconfig=global_qconfig,
        tracer=dict(
            type='mmrazor.CustomTracer',
            skipped_methods=[
                'mmcls.models.heads.ClsHead._get_loss',
                'mmcls.models.heads.ClsHead._get_predictions'
            ])))

optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.004, momentum=0.9, weight_decay=0.0001))

# learning policy
param_scheduler = dict(
    _delete_=True,
    type='CosineAnnealingLR',
    T_max=100,
    by_epoch=True,
    begin=0,
    end=100)

model_wrapper_cfg = dict(
    type='mmrazor.MMArchitectureQuantDDP',
    broadcast_buffers=False,
    find_unused_parameters=True)

# train, val, test setting
train_cfg = dict(
    _delete_=True,
    type='mmrazor.QATEpochBasedLoop',
    max_epochs=100,
    val_interval=1)
val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop')
# test_cfg = val_cfg
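For context, a minimal sketch (not part of this diff) of how a QAT config like the one above could be loaded and its model built through the MMEngine/MMRazor registries; the config path used here is hypothetical, and in practice training would normally be launched through MMRazor's tools/train.py, which does this via the Runner.

from mmengine.config import Config

from mmrazor.registry import MODELS

# Hypothetical config path; substitute the config file added by this PR.
cfg = Config.fromfile('configs/quantization/qat/qat_openvino_resnet18_8xb32_in1k.py')
# Builds the mmrazor.MMArchitectureQuant algorithm wrapping the float ResNet-18.
algorithm = MODELS.build(cfg.model)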
@@ -10,7 +10,7 @@

from mmrazor.models.task_modules.tracer import build_graphmodule
from mmrazor.registry import MODEL_WRAPPERS, MODELS
from ..base import BaseAlgorithm
from ..base import BaseAlgorithm, BaseModel

try:
    from torch.ao.quantization import FakeQuantizeBase
@@ -29,35 +29,44 @@ class MMArchitectureQuant(BaseAlgorithm):
    """General quantization.

    Args:
        architecture (dict | :obj:`BaseModel`): The config of
            :class:`BaseModel` or built model.
        quantizer (dict | :obj:`BaseModel`): The config of
            :class:`BaseQuantizer` or built model.
        export_mode (str): The mode of the model to be exported. Defaults to
            predict.
        architecture (dict | :obj:`BaseModel`): The config of the model to be
            quantized.
        quantizer (dict | :obj:`BaseModel`): The quantizer to support
            different backend types.
        qmodel_modes (list): The available modes of the runner.
        data_preprocessor (dict | torch.nn.Module | None): The pre-process
            config of :class:`BaseDataPreprocessor`. Defaults to None.
        pretrained_ckpt (str, Optional): The path of the pretrained checkpoint.
            Defaults to None.
        init_cfg (dict): The weight initialization config for
            :class:`BaseModule`.
        forward_modes (tuple): The modes of the forward method in OpenMMLab
            architecture; can be 'tensor', 'predict', or 'loss'. Each mode
            traces a different graph of the quantized model.
Review comment: add a note explaining why we need to generate multiple graphs.
        float_checkpoint (str, Optional): The path of the pretrained FP
            checkpoint. Quantization differs from other tasks, so we recommend
            using `float_checkpoint` as the pretrained model. Defaults to None.
        init_cfg (dict): The weight initialization config for :class:`BaseModule`.

    Note:
        forward_modes (tuple): In OpenMMLab architecture, different modes
            will trace different graphs of the quantized model.
    """

    def __init__(self,
                 architecture,
                 quantizer,
                 data_preprocessor=None,
                 forward_modes=('tensor', 'predict', 'loss'),
                 float_checkpoint: Optional[str] = None,
                 input_shapes=(1, 3, 224, 224),
                 init_cfg=None):
    def __init__(
            self,
            architecture: Union[Dict, BaseModel],
            quantizer: Union[Dict, BaseModel],
            # data_preprocessor: Union[Dict, torch.nn.Module, None] = None,
            data_preprocessor=None,
Review comment: add type hint.
            forward_modes: Union[tuple, str] = ('tensor', 'predict', 'loss'),
Review comment: the type hint should not include str.
            float_checkpoint: Optional[str] = None,
            input_shapes: tuple = (1, 3, 224, 224),
            init_cfg: Optional[dict] = None):

        if data_preprocessor is None:
            data_preprocessor = {}
        # The build process is in MMEngine, so we need to add scope here.
        # Default to mmcls.ClsDataPreprocessor.
        data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor')
        super().__init__(architecture, data_preprocessor, init_cfg)
        # If we have a float_checkpoint, we load it as pretrain.
        if float_checkpoint:
            _ = load_checkpoint(self.architecture, float_checkpoint)
            self.architecture._is_init = True
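To illustrate the `forward_modes` described in the docstring above, a hedged sketch (not part of this diff) of calling a built algorithm instance in each mode; `algorithm`, `inputs`, and `data_samples` are hypothetical names following the usual OpenMMLab forward conventions.

feats = algorithm(inputs, data_samples, mode='tensor')    # raw network outputs
preds = algorithm(inputs, data_samples, mode='predict')   # post-processed predictions
losses = algorithm(inputs, data_samples, mode='loss')     # dict of losses for training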
@@ -68,9 +77,23 @@ def __init__(self,

        self.qmodels = self._build_qmodels(self.architecture)

        self.sync_qparams('predict')
        self.sync_qparams(forward_modes[0])

    def sync_qparams(self, src_mode):
Review comment: add type hint.
"""Sync all quantize parameters in different `forward_modes`. We have | ||
three modes to generate three graphs, but in training, only one graph | ||
will be update, so we need to sync qparams in the other two graphs. | ||
|
||
Args: | ||
src_mode (str): The modes of forward method. | ||
|
||
Note: | ||
`traverse()` function recursively traverses all module to sync | ||
quantized graph generated from different `forward_modes`. | ||
This is because We have different mode ('tensor', 'predict', | ||
'loss') in OpenMMLab architecture which have different graph | ||
in some subtle ways, so we need to sync them here. | ||
""" | ||
|
||
def traverse(module, prefix): | ||
for name, child in module._modules.items(): | ||
|
@@ -84,10 +107,10 @@ def traverse(module, prefix):
                        if src_param.shape == param.shape:
                            param.data.copy_(src_param)
                        else:
                            requirs_grad = param.requires_grad
                            param.requires_grad = False
                            # requirs_grad = param.requires_grad
                            # param.requires_grad = False
                            param.resize_(src_param.shape)
                            param.requires_grad = requirs_grad
                            # param.requires_grad = requirs_grad
                            param.data.copy_(src_param)
                    for name, buffer in child.named_buffers():
                        buffer_name = f'{child_name}.{name}'
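As a usage illustration (an assumption, not shown in this diff): after a training step has updated the quantization parameters of one mode's graph, they can be propagated to the remaining graphs with `sync_qparams`.

# `algorithm` is a hypothetical, already-built MMArchitectureQuant instance.
# Copy the scales and zero-points updated in the 'loss' graph into the
# 'tensor' and 'predict' graphs.
algorithm.sync_qparams(src_mode='loss')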
@@ -107,6 +130,30 @@ def traverse(module, prefix):
            traverse(self.qmodels[mode], '')

    def _build_qmodels(self, model):
Review comment: add type hint.
"""Build quantized models from the given model. | ||
|
||
Args: | ||
model (dict | :obj:`BaseModel`): the given fp model. | ||
|
||
Example: | ||
The main body of the graph is all the same, but the last one or two | ||
op will have difference, as shown below. | ||
|
||
self.qmodels['tensor'].graph.print_tabular() | ||
opcode target args | ||
call_module head.fc (activation_post_process_38,) | ||
output output (head_fc,) | ||
|
||
self.qmodels['loss'].graph.print_tabular() | ||
opcode target args | ||
call_method _get_loss (head, head_fc, data_samples) | ||
output output (_get_loss,) | ||
|
||
self.qmodels['predict'].graph.print_tabular() | ||
opcode target args | ||
call_method _get_predictions (head, head_fc, data_samples) | ||
output output (_get_predictions,) | ||
""" | ||
|
||
qmodels = nn.ModuleDict() | ||
|
||
|
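Building on the docstring example above, a small hedged sketch (not part of this diff) of how the per-mode traced graphs could be inspected on a built instance; `algorithm` is a hypothetical name.

# Print the FX graph traced for each forward mode; only the tail of each
# graph differs, as described in the docstring example.
for mode in ('tensor', 'predict', 'loss'):
    print(f'--- {mode} ---')
    algorithm.qmodels[mode].graph.print_tabular()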
@@ -138,6 +185,8 @@ def forward(self,
            return self.architecture(inputs, data_samples, mode)

    def calibrate_step(self, data):
Review comment: add type hint.
"""PTQ method need calibrate by cali data.""" | ||
|
||
data = self.data_preprocessor(data, False) | ||
return self._run_forward(data, mode='predict') | ||
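For PTQ, `calibrate_step` would typically be driven by a calibration loop; a minimal hedged sketch, where `algorithm` and `calib_dataloader` are hypothetical (a built MMArchitectureQuant instance and a DataLoader over calibration samples).

# Run calibration batches through the observed model so the observers collect
# activation statistics before the quantized model is evaluated or exported.
for data in calib_dataloader:
    algorithm.calibrate_step(data)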
Review comment: MMArchitectureQuantDDP also needs a docstring and type hints.
Review comment: the type hints in the docstring should match those in `__init__`.