Added a more accurate description
Niko-zyf committed Jan 6, 2023
1 parent 46e4700 commit 61a7406
Showing 3 changed files with 56 additions and 22 deletions.
@@ -2,12 +2,6 @@

train_dataloader = dict(batch_size=1024)

# test_cfg = dict(
# type='mmrazor.PTQLoop',
# calibrate_dataloader=train_dataloader,
# calibrate_steps=32,
# )

global_qconfig = dict(
w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'),
a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'),
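For readers unfamiliar with these observer types, here is a minimal sketch of the statistics they gather, written with the standard torch.ao.quantization observers of the same names (an assumption that the mmrazor types mirror them): per-channel min/max for weights, moving-average min/max for activations.

import torch
from torch.ao.quantization.observer import (
    MovingAverageMinMaxObserver, PerChannelMinMaxObserver)

w_obs = PerChannelMinMaxObserver(ch_axis=0)  # one (scale, zero_point) per output channel
a_obs = MovingAverageMinMaxObserver()        # running EMA of the activation range

w_obs(torch.randn(8, 3, 3, 3))               # feed a weight tensor, record min/max
a_obs(torch.randn(1, 8, 16, 16))             # feed activations, update the EMA
scales, zero_points = w_obs.calculate_qparams()
print(scales.shape)                          # torch.Size([8]) -> per-channel qparams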
34 changes: 28 additions & 6 deletions mmrazor/models/algorithms/quantization/mm_architecture.py
@@ -40,7 +40,8 @@ class MMArchitectureQuant(BaseAlgorithm):
init_cfg (dict): The weight initialized config for :class:`BaseModule`.
Note:
forward_modes (tuple):
forward_modes (tuple): In the OpenMMLab architecture, different modes
will trace different graphs of the quantized model.
"""

def __init__(
@@ -49,7 +50,7 @@ def __init__(
quantizer: Union[Dict, BaseModel],
# data_preprocessor: Union[Dict, torch.nn.Module, None] = None,
data_preprocessor=None,
forward_modes: Union[tuple, str] = ('tensor'),
forward_modes: Union[tuple, str] = ('tensor', 'predict', 'loss'),
float_checkpoint: Optional[str] = None,
input_shapes: tuple = (1, 3, 224, 224),
init_cfg: Optional[dict] = None):
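For context, the three default modes correspond to the three branches of BaseModel.forward in mmengine. A minimal sketch of how a caller selects them (the helper name is ours, for illustration only):

from mmengine.model import BaseModel

# Sketch: given any OpenMMLab model (a BaseModel) and a batch, each forward
# mode takes a different branch of BaseModel.forward, so fx-tracing each
# mode yields a slightly different graph -- hence one qmodel per mode.
def run_all_modes(model: BaseModel, inputs, data_samples):
    feats = model(inputs, data_samples, mode='tensor')    # raw network outputs
    losses = model(inputs, data_samples, mode='loss')     # dict of training losses
    preds = model(inputs, data_samples, mode='predict')   # post-processed results
    return feats, losses, preds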
@@ -74,7 +75,10 @@ def __init__(
self.sync_qparams(forward_modes[0])

def sync_qparams(self, src_mode):
"""Sync all quantize parameters in different `forward_modes`.
"""Sync all quantize parameters in different `forward_modes`. We have
three modes to generate three graphs, but in training, only one
graph will be update, so we need to sync qparams in the other two
graphs.
Args:
src_mode (str): The modes of forward method.
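A minimal sketch of the syncing idea (not mmrazor's exact traversal; it assumes the per-mode graphs share state-dict keys, and the helper name is ours):

import torch

# Copy the scales/zero-points learned by the src_mode graph into the graphs
# of the remaining modes. state_dict() returns references to the live
# tensors, so copy_() updates the destination modules in place.
@torch.no_grad()
def sync_qparams_sketch(qmodels: dict, forward_modes, src_mode: str):
    src_state = qmodels[src_mode].state_dict()
    for mode in forward_modes:
        if mode == src_mode:
            continue
        dst_state = qmodels[mode].state_dict()
        for name, tensor in src_state.items():
            if name in dst_state and ('scale' in name or 'zero_point' in name):
                dst_state[name].copy_(tensor)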
@@ -84,7 +88,7 @@ def sync_qparams(self, src_mode):
quantized graph generated from different `forward_modes`.
This is because We have different mode ('tensor', 'predict',
'loss') in OpenMMLab architecture which have different graph
in some subtle ways, so we need to sync them here.
"""

def traverse(module, prefix):
@@ -117,8 +121,6 @@ def traverse(module, prefix):

src_state_dict = self.qmodels[src_mode].state_dict()
for mode in self.forward_modes:
import pdb
pdb.set_trace()
if mode == src_mode:
continue
traverse(self.qmodels[mode], '')
@@ -128,6 +130,26 @@ def _build_qmodels(self, model):
Args:
model (dict | :obj:`BaseModel`): the given fp model.
Example:
The main body of the graph is the same in all modes, but the last
one or two ops differ, as shown below.
self.qmodels['tensor'].graph.print_tabular()
opcode target args
call_module head.fc (activation_post_process_38,)
output output (head_fc,)
self.qmodels['loss'].graph.print_tabular()
opcode target args
call_method _get_loss (head, head_fc, data_samples) {}
output output (_get_loss,)
self.qmodels['predict'].graph.print_tabular()
opcode target args
call_method _get_predictions (head, head_fc, data_samples) {}
output output (_get_predictions,)
"""

qmodels = nn.ModuleDict()
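A dump like the ones quoted in the docstring can be reproduced for any traceable model with plain torch.fx; a minimal sketch on a toy model (note that print_tabular() requires the tabulate package to be installed):

import torch.nn as nn
from torch.fx import symbolic_trace

# Trace a toy model and dump its graph; the qmodels built above are fx
# GraphModules, so they expose the same graph.print_tabular() method.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten())
traced = symbolic_trace(model)
traced.graph.print_tabular()   # columns: opcode, name, target, args, kwargs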
38 changes: 28 additions & 10 deletions mmrazor/models/quantizers/native_quantizer.py
@@ -108,14 +108,23 @@ def support_a_modes(self):
return ['per_tensor']

def prepare(self, model, graph_module):
"""prepare graph to ObserverdGraphModule.
"""prepare graph to ObservedGraphModule.
Args:
model (_type_): _description_
graph_module (_type_): Graph modules before fuse.
graph_module (_type_): GraphModules before fuse.
Returns:
ObserverdGraphModule: Graph module after fuse and observerd
ObservedGraphModule: GraphModule after fusion, with observers inserted.
Notes:
    `graph_module` after `_fuse_fx()` has conv, BN and ReLU fused into
    the modules in `SUPPORT_QAT_MODULES`.
    `graph_module` after `prepare()` becomes an observed module.

    We keep `is_qat=True` because in PyTorch, when `is_qat` is False,
    `_fuse_fx()` only fuses modules into `nn.Sequential`, but we need
    them fused into the `SUPPORT_QAT_MODULES` types.
"""

graph_module = _fuse_fx(
@@ -129,17 +138,23 @@ def prepare(self, model, graph_module):
node_name_to_scope=self.tracer.node_name_to_scope,
example_inputs=self.example_inputs,
backend_config=self.backend_config)

return prepared
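The same fuse-then-observe flow can be sketched with the public torch.ao.quantization.quantize_fx API; this stands in for mmrazor's internal _fuse_fx/prepare wrappers and their backend_config plumbing, not their exact behavior:

import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx

# prepare_qat_fx fuses Conv+BN+ReLU into a single QAT module (the is_qat=True
# path) and inserts observer/fake-quant nodes, returning an observed
# GraphModule; the model must be in train mode for QAT preparation.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).train()
qconfig_mapping = get_default_qat_qconfig_mapping('fbgemm')
example_inputs = (torch.randn(1, 3, 32, 32),)
prepared = prepare_qat_fx(model, qconfig_mapping, example_inputs)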

def post_process_weight_fakequant(self,
observed_module,
keep_fake_quant: bool = False):
"""weight fakequant for supported QAT modules.
"""weight fake-quant for supported QAT modules.
Args:
observed_module (_type_): _description_
keep_fake_quant (bool, optional): _description_. Defaults to False.
observed_module (ObservedGraphModule): The module after fusion,
    with observers inserted.
keep_fake_quant (bool, optional): Whether to keep the fake-quant
    ops, depending on the backend. Defaults to False.
Note:
    `post_process_weight_fakequant()` is necessary so that the
    `SUPPORT_QAT_MODULES` are converted back to normal modules and
    BN is really folded into the conv layers.
"""

def traverse(module):
@@ -153,11 +168,14 @@ def traverse(module):
weight_fakequant = child.weight_fake_quant
child.weight.data = weight_fakequant(child.weight.data)

# `to_float()` function fuse BN into conv or conv_relu.
# `to_float()` function fuses BN into conv or conv_relu, and
# also converts a qat module back to a normal module.
# source: torch.nn.intrinsic.qat.modules.conv_fused.py
float_child = child.to_float()

# This is decided by the backend type: some backends need to
# explicitly keep the fake-quant structure, others don't.
# TODO add deploy doc link
if keep_fake_quant:
for m in float_child.modules():
setattr(m, 'qconfig', self.qconfig.convert())
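As an illustration of what the `to_float()` call above folds away, here is a hand-written sketch of BN folding (the helper name is ours; the arithmetic is the standard fold used by the fused QAT modules):

import torch
import torch.nn as nn

# BN folding: w' = w * gamma / sqrt(var + eps)
#             b' = (b - mean) * gamma / sqrt(var + eps) + beta
# so that, in eval mode, fused(x) == bn(conv(x)).
@torch.no_grad()
def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding, bias=True)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    bias = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
    fused.bias.copy_((bias - bn.running_mean) * scale + bn.bias)
    return fused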
