[Dygraph QAT] Save all scales to target ops and Move quant layers to paddle.nn.quant (#33871)

* Save all scales to target ops
* Move quant layers to paddle.nn.quant
juncaipeng authored Jul 5, 2021
1 parent ea1a0d4 commit 00c85a7
Showing 8 changed files with 313 additions and 257 deletions.
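
For users of the dygraph QAT API, the second change means the quantized wrapper layers formerly reached through quant_nn now live in paddle.nn.quant. A minimal sketch of the updated import path (the wrapped conv layer here is hypothetical, and default quantization settings are assumed):

import paddle
import paddle.nn.quant.quant_layers as quant_layers

conv = paddle.nn.Conv2D(3, 8, 3)
# Before this commit: quant_nn.QuantizedConv2D(conv)
quant_conv = quant_layers.QuantizedConv2D(conv)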
@@ -14,9 +14,6 @@

from __future__ import print_function

from . import quant_nn
from .quant_nn import *

from . import qat
from .qat import *

@@ -33,7 +30,6 @@
from .ptq_registry import *

__all__ = []
__all__ += quant_nn.__all__
__all__ += qat.__all__
__all__ += ptq.__all__
__all__ += ptq_config.__all__
101 changes: 72 additions & 29 deletions python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
@@ -20,6 +20,7 @@
import warnings

import paddle
import paddle.nn.quant.quant_layers as quant_layers
from paddle.fluid import dygraph, core, framework, unique_name
from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.param_attr import ParamAttr
@@ -28,7 +29,6 @@
from paddle.fluid.io import load_inference_model, save_inference_model
from paddle.fluid.log_helper import get_logger
from .. import quantization_pass
from . import quant_nn
from . import utils

__all__ = ['ImperativeQuantAware']
@@ -39,7 +39,7 @@

class ImperativeQuantAware(object):
"""
Applying quantization aware training (QAT) to dygraph model.
Applying quantization aware training (QAT) to the dygraph model.
"""

def __init__(self,
@@ -329,12 +329,12 @@ def _get_input_quantized_layer(self, layer):
"The layer %s is unsupported to be quantized." \
% layer.full_name()

return quant_nn.__dict__[quant_layer_name](layer, **self._kwargs)
return quant_layers.__dict__[quant_layer_name](layer, **self._kwargs)


class ImperativeQuantizeOutputs(object):
"""
Calculate the output scales for some layers.
Calculate the output scales for target layers.
"""

def __init__(self, moving_rate=0.9):
@@ -371,11 +371,11 @@ def apply(self, model):
utils.find_parent_layer_and_sub_name(model, cur_name)

if isinstance(cur_layer, tuple(utils.fake_quant_output_layers)):
cur_quant_layer = quant_nn.FakeQuantMAOutputScaleLayer(
cur_quant_layer = quant_layers.FakeQuantMAOutputScaleLayer(
cur_layer, self._moving_rate)
else:
cur_quant_layer = quant_nn.MAOutputScaleLayer(cur_layer,
self._moving_rate)
cur_quant_layer = quant_layers.MAOutputScaleLayer(
cur_layer, self._moving_rate)

setattr(parent_layer, sub_name, cur_quant_layer)

Expand Down Expand Up @@ -433,7 +433,7 @@ def save_quantized_model(self, layer, path, input_spec=None, **config):
model_filename=model_filename,
params_filename=params_filename))

self._save_output_scale(infer_program, scope)
self._gather_scales(infer_program, scope)

self._set_skip_quant_attr(infer_program)

@@ -455,36 +455,79 @@ def _is_target_layer(self, layer):
"""
flag = False
if isinstance(layer, dygraph.Layer):
# exclude fake_quant ops in quant_nn file
# exclude fake_quant ops in quant_layers file
if utils.is_leaf_layer(layer) and \
not isinstance(layer, tuple(utils.fake_quant_leaf_layers)):
flag = True
# consider QuantizedConv2D and QuantizedLinear ops

if isinstance(layer, tuple(utils.fake_quant_wrap_layers)):
flag = True
if isinstance(layer, paddle.nn.quant.FloatFunctionalLayer):
flag = True

if isinstance(layer, paddle.nn.quant.FloatFunctionalLayer):
flag = True

return flag

def _save_output_scale(self, program, scope):
def _gather_scales(self, program, scope):
"""
Save all output scales to the corresponding ops in static
inference program and delete 'moving_average_abs_max_scale' ops.
Get all scales from fake ops, save them into the corresponding ops
and delete all moving_average_abs_max_scale ops.
"""
for block in program.blocks:
for op in block.ops:
if op.type == "moving_average_abs_max_scale":
in_var_name = op.input('X')[0]
out_var_name = op.output('Out')[0]
out_scale_name = op.output('OutScale')[0]

out_scale = utils.load_variable_data(scope, out_scale_name)
previous_op = utils.find_previous_op(block, in_var_name)
previous_op._set_attr("out_threshold", float(out_scale))

next_ops = utils.find_next_ops(block, out_var_name)
for next_op in next_ops:
next_op._rename_input(out_var_name, in_var_name)

def _gather_input_scale():
target_ops = []
skip_ops = utils.fake_quantize_dequantize_op_types + \
["moving_average_abs_max_scale"]
for block in program.blocks:
for op in block.ops:
if op.type not in skip_ops:
target_ops.append(op)

for op in target_ops:
for in_var_name in utils._get_op_input_var_names(op):
previous_op = utils.find_previous_op(op.block, in_var_name)

if previous_op is not None and \
("quantize_dequantize" in previous_op.type or \
previous_op.type == "moving_average_abs_max_scale"):
scale_name = previous_op.output('OutScale')[0]
in_scale = utils.load_variable_data(scope, scale_name)
in_scale = utils.fp_numpy_to_naive(in_scale)
argname, index = utils._get_input_name_index(
op, in_var_name)
op._set_attr(argname + str(index) + "_threshold",
in_scale)

def _gather_output_scale():
target_ops = []
for block in program.blocks:
for op in block.ops:
if op.type == "moving_average_abs_max_scale":
target_ops.append(op)

for op in target_ops:
in_var_name = op.input('X')[0]
out_var_name = op.output('Out')[0]
block = op.block
previous_op = utils.find_previous_op(block, in_var_name)
next_ops = utils.find_next_ops(block, out_var_name)

out_scale_name = op.output('OutScale')[0]
out_scale = utils.load_variable_data(scope, out_scale_name)
out_scale = utils.fp_numpy_to_naive(out_scale)

if previous_op.type != "feed":
argname, index = utils._get_output_name_index(previous_op,
in_var_name)
previous_op._set_attr(argname + str(index) + "_threshold",
out_scale)
previous_op._set_attr("out_threshold", out_scale)

for next_op in next_ops:
next_op._rename_input(out_var_name, in_var_name)

_gather_input_scale()
_gather_output_scale()

def _set_skip_quant_attr(self, program):
"""
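
The net effect of _gather_scales is that each target op in the saved inference program carries its scales as attributes: out_threshold for outputs, plus per-argument attributes such as X0_threshold built from argname + str(index) + "_threshold". An illustrative sketch of inspecting them, assuming a model was saved by save_quantized_model under the hypothetical prefix "qat_model":

import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
# Load the saved inference program; "qat_model" is a placeholder prefix.
program, feed_names, fetch_targets = paddle.static.load_inference_model(
    "qat_model", exe)

for block in program.blocks:
    for op in block.ops:
        if op.has_attr("out_threshold"):
            print(op.type, op.attr("out_threshold"))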
26 changes: 20 additions & 6 deletions python/paddle/fluid/contrib/slim/quantization/imperative/utils.py
@@ -16,8 +16,12 @@
import numpy as np

import paddle
import paddle.nn.quant.quant_layers as quant_layers

from . import quant_nn
from ..quantization_pass import _get_op_input_var_names
from ..quantization_pass import _get_op_output_var_names
from ..quantization_pass import _get_output_name_index
from ..quantization_pass import _get_input_name_index

layer_name_map = {
'Conv2D': paddle.nn.Conv2D,
@@ -54,13 +58,15 @@
]

fake_quant_leaf_layers = [
quant_nn.FakeQuantAbsMax,
quant_nn.FakeQuantChannelWiseAbsMax,
quant_nn.FakeQuantMovingAverageAbsMax,
quant_nn.MovingAverageAbsMaxScale,
quant_layers.FakeQuantAbsMax,
quant_layers.FakeQuantChannelWiseAbsMax,
quant_layers.FakeQuantMovingAverageAbsMax,
quant_layers.MovingAverageAbsMaxScale,
]

fake_quant_wrap_layers = [quant_nn.QuantizedConv2D, quant_nn.QuantizedLinear]
fake_quant_wrap_layers = [
quant_layers.QuantizedConv2D, quant_layers.QuantizedLinear
]

# The weight format of these layers is Cin * Cout * H * W
spec_channel_axis_layers = [paddle.nn.Conv2D, paddle.nn.Conv2DTranspose]
@@ -94,6 +100,7 @@ def find_previous_op(block, var_name):
for op in block.ops:
if var_name in op.output_arg_names:
return op
return None


def find_next_ops(block, var_name):
@@ -244,3 +251,10 @@ def cal_kl_scaling_factor(hist, abs_max, bits):
break
min_kl_index = starting_iter
return (min_kl_index + 0.5) * bin_width


def fp_numpy_to_naive(x_np):
if x_np.size == 1:
return float(x_np)
else:
return x_np.tolist()
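
A quick sketch of what this helper returns for the two scale shapes it handles:

import numpy as np

fp_numpy_to_naive(np.array([0.25]))      # size-1 tensor -> 0.25 (a plain float)
fp_numpy_to_naive(np.array([0.1, 0.2]))  # per-channel scales -> [0.1, 0.2] (a list)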
@@ -141,12 +141,21 @@


def _get_op_input_var_names(op):
""" """
"""
Get the input var names of the op.
Args:
op(IrNode, Operator): the input op.
Returns:
input_var_names or None.
"""
assert isinstance(op, (IrNode, Operator)), \
"The input op should be IrNode or Operator."
var_names = []
op_name = op.name() if isinstance(op, IrNode) \
else op.type
if op_name not in _op_real_in_out_name:
return []

name_list = _op_real_in_out_name[op_name][0]
for name in name_list:
var_name = op.input(name)
@@ -163,6 +172,9 @@ def _get_input_name_index(op, input_var_name):
"The input op should be IrNode or Operator."
op_name = op.name() if isinstance(op, IrNode) \
else op.type
if op_name not in _op_real_in_out_name:
return None

res = None
for argname in _op_real_in_out_name[op_name][0]:
var_names = op.input(argname)
@@ -179,6 +191,9 @@ def _get_op_output_var_names(op):
var_names = []
op_name = op.name() if isinstance(op, IrNode) \
else op.type
if op_name not in _op_real_in_out_name:
return []

name_list = _op_real_in_out_name[op_name][1]
for name in name_list:
var_name = op.output(name)
@@ -195,6 +210,9 @@ def _get_output_name_index(op, output_var_name):
"The input op should be IrNode or Operator."
op_name = op.name() if isinstance(op, IrNode) \
else op.type
if op_name not in _op_real_in_out_name:
return None

name_list = _op_real_in_out_name[op_name][1]
res = None
for name in name_list:
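
The guards added above change the helpers' contract for op types missing from _op_real_in_out_name: the name lookups now return an empty list and the index lookups return None, instead of raising a KeyError. A hedged sketch of calling code that relies on this new contract:

# `op` is any Operator; types absent from _op_real_in_out_name are now safe.
in_names = _get_op_input_var_names(op)   # [] when the op type is unregistered
for var_name in in_names:
    res = _get_input_name_index(op, var_name)
    if res is not None:
        argname, index = res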
@@ -31,7 +31,7 @@
from paddle.nn import Linear, Conv2D, Softmax
from paddle.fluid.log_helper import get_logger
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.fluid.contrib.slim.quantization.imperative.quant_nn import QuantizedConv2D
from paddle.nn.quant.quant_layers import QuantizedConv2D

from imperative_test_utils import fix_model_dict, ImperativeLenet

@@ -20,7 +20,7 @@
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.contrib.slim.quantization.imperative import quant_nn
import paddle.nn.quant.quant_layers as quant_layers

paddle.enable_static()

@@ -45,7 +45,7 @@ def check_backward(self, use_cuda):
name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
fc_tmp = fluid.layers.fc(image, size=10, act='softmax')
out_scale = quant_nn.MovingAverageAbsMaxScale(
out_scale = quant_layers.MovingAverageAbsMaxScale(
name=fc_tmp.name, dtype=fc_tmp.dtype)
fc_tmp_1 = out_scale(fc_tmp)
cross_entropy = fluid.layers.softmax_with_cross_entropy(fc_tmp,
1 change: 1 addition & 0 deletions python/paddle/nn/quant/__init__.py
@@ -21,5 +21,6 @@
from .functional_layers import transpose # noqa: F401
from .functional_layers import concat # noqa: F401
from .functional_layers import flatten # noqa: F401
from .quant_layers import QuantStub # noqa: F401

__all__ = []
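
With QuantStub re-exported here, it can be dropped into a model to mark where activation quantization should begin, analogous to quant stubs in other frameworks. A minimal sketch (the network is hypothetical and the no-argument construction is assumed):

import paddle
from paddle.nn.quant import QuantStub

class DemoNet(paddle.nn.Layer):
    def __init__(self):
        super(DemoNet, self).__init__()
        self.quant = QuantStub()   # assumed pass-through marker for the quantizer
        self.conv = paddle.nn.Conv2D(3, 8, 3)

    def forward(self, x):
        return self.conv(self.quant(x))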