diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/__init__.py b/python/paddle/fluid/contrib/slim/quantization/imperative/__init__.py
index 77872e88a0733..7210da93f7bf5 100644
--- a/python/paddle/fluid/contrib/slim/quantization/imperative/__init__.py
+++ b/python/paddle/fluid/contrib/slim/quantization/imperative/__init__.py
@@ -14,9 +14,6 @@
 
 from __future__ import print_function
 
-from . import quant_nn
-from .quant_nn import *
-
 from . import qat
 from .qat import *
 
@@ -33,7 +30,6 @@ from .ptq_registry import *
 
 __all__ = []
-__all__ += quant_nn.__all__
 __all__ += qat.__all__
 __all__ += ptq.__all__
 __all__ += ptq_config.__all__
diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
index 600ce6397e1af..3b4f9a757437a 100644
--- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
+++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
@@ -20,6 +20,7 @@ import warnings
 
 import paddle
+import paddle.nn.quant.quant_layers as quant_layers
 from paddle.fluid import dygraph, core, framework, unique_name
 from paddle.fluid.executor import Executor, global_scope
 from paddle.fluid.param_attr import ParamAttr
@@ -28,7 +29,6 @@ from paddle.fluid.io import load_inference_model, save_inference_model
 from paddle.fluid.log_helper import get_logger
 from .. import quantization_pass
-from . import quant_nn
 from . import utils
 
 __all__ = ['ImperativeQuantAware']
@@ -39,7 +39,7 @@ class ImperativeQuantAware(object):
     """
-    Applying quantization aware training (QAT) to dgraph model.
+    Applying quantization aware training (QAT) to the dygraph model.
     """
 
     def __init__(self,
@@ -329,12 +329,12 @@ def _get_input_quantized_layer(self, layer):
             "The layer %s is unsupported to be quantized." \
             % layer.full_name()
 
-        return quant_nn.__dict__[quant_layer_name](layer, **self._kwargs)
+        return quant_layers.__dict__[quant_layer_name](layer, **self._kwargs)
 
 
 class ImperativeQuantizeOutputs(object):
     """
-    Calculate the output scales for some layers.
+    Calculate the output scales for target layers.
""" def __init__(self, moving_rate=0.9): @@ -371,11 +371,11 @@ def apply(self, model): utils.find_parent_layer_and_sub_name(model, cur_name) if isinstance(cur_layer, tuple(utils.fake_quant_output_layers)): - cur_quant_layer = quant_nn.FakeQuantMAOutputScaleLayer( + cur_quant_layer = quant_layers.FakeQuantMAOutputScaleLayer( cur_layer, self._moving_rate) else: - cur_quant_layer = quant_nn.MAOutputScaleLayer(cur_layer, - self._moving_rate) + cur_quant_layer = quant_layers.MAOutputScaleLayer( + cur_layer, self._moving_rate) setattr(parent_layer, sub_name, cur_quant_layer) @@ -433,7 +433,7 @@ def save_quantized_model(self, layer, path, input_spec=None, **config): model_filename=model_filename, params_filename=params_filename)) - self._save_output_scale(infer_program, scope) + self._gather_scales(infer_program, scope) self._set_skip_quant_attr(infer_program) @@ -455,36 +455,79 @@ def _is_target_layer(self, layer): """ flag = False if isinstance(layer, dygraph.Layer): - # exclude fake_quant ops in quant_nn file + # exclude fake_quant ops in quant_layers file if utils.is_leaf_layer(layer) and \ not isinstance(layer, tuple(utils.fake_quant_leaf_layers)): flag = True - # consider QuantizedConv2D and QuantizedLinear ops + if isinstance(layer, tuple(utils.fake_quant_wrap_layers)): flag = True - if isinstance(layer, paddle.nn.quant.FloatFunctionalLayer): - flag = True + + if isinstance(layer, paddle.nn.quant.FloatFunctionalLayer): + flag = True + return flag - def _save_output_scale(self, program, scope): + def _gather_scales(self, program, scope): """ - Save all output scales to the corresponding ops in static - inference program and delete 'moving_average_abs_max_scale' ops. + Get all scales from fake ops, save them into the corresponding ops + and delete all moving_average_abs_max_scale ops. 
""" - for block in program.blocks: - for op in block.ops: - if op.type == "moving_average_abs_max_scale": - in_var_name = op.input('X')[0] - out_var_name = op.output('Out')[0] - out_scale_name = op.output('OutScale')[0] - - out_scale = utils.load_variable_data(scope, out_scale_name) - previous_op = utils.find_previous_op(block, in_var_name) - previous_op._set_attr("out_threshold", float(out_scale)) - - next_ops = utils.find_next_ops(block, out_var_name) - for next_op in next_ops: - next_op._rename_input(out_var_name, in_var_name) + + def _gather_input_scale(): + target_ops = [] + skip_ops = utils.fake_quantize_dequantize_op_types + \ + ["moving_average_abs_max_scale"] + for block in program.blocks: + for op in block.ops: + if op.type not in skip_ops: + target_ops.append(op) + + for op in target_ops: + for in_var_name in utils._get_op_input_var_names(op): + previous_op = utils.find_previous_op(op.block, in_var_name) + + if previous_op is not None and \ + ("quantize_dequantize" in previous_op.type or \ + previous_op.type == "moving_average_abs_max_scale"): + scale_name = previous_op.output('OutScale')[0] + in_scale = utils.load_variable_data(scope, scale_name) + in_scale = utils.fp_numpy_to_naive(in_scale) + argname, index = utils._get_input_name_index( + op, in_var_name) + op._set_attr(argname + str(index) + "_threshold", + in_scale) + + def _gather_output_scale(): + target_ops = [] + for block in program.blocks: + for op in block.ops: + if op.type == "moving_average_abs_max_scale": + target_ops.append(op) + + for op in target_ops: + in_var_name = op.input('X')[0] + out_var_name = op.output('Out')[0] + block = op.block + previous_op = utils.find_previous_op(block, in_var_name) + next_ops = utils.find_next_ops(block, out_var_name) + + out_scale_name = op.output('OutScale')[0] + out_scale = utils.load_variable_data(scope, out_scale_name) + out_scale = utils.fp_numpy_to_naive(out_scale) + + if previous_op.type != "feed": + argname, index = utils._get_output_name_index(previous_op, + in_var_name) + previous_op._set_attr(argname + str(index) + "_threshold", + out_scale) + previous_op._set_attr("out_threshold", out_scale) + + for next_op in next_ops: + next_op._rename_input(out_var_name, in_var_name) + + _gather_input_scale() + _gather_output_scale() def _set_skip_quant_attr(self, program): """ diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py index 98eefc7360812..4158c52d5ae25 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py @@ -16,8 +16,12 @@ import numpy as np import paddle +import paddle.nn.quant.quant_layers as quant_layers -from . 
diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
index 010c6a67a3a38..b3b12a477e2a0 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
@@ -141,12 +141,21 @@
 
 def _get_op_input_var_names(op):
-    """ """
+    """
+    Get the input var names of the op.
+
+    Args:
+        op(IrNode, Operator): the input op.
+    Returns:
+        input_var_names, or an empty list if the op is not supported.
+    """
     assert isinstance(op, (IrNode, Operator)), \
         "The input op should be IrNode or Operator."
     var_names = []
     op_name = op.name() if isinstance(op, IrNode) \
         else op.type
+    if op_name not in _op_real_in_out_name:
+        return []
+
     name_list = _op_real_in_out_name[op_name][0]
     for name in name_list:
         var_name = op.input(name)
@@ -163,6 +172,9 @@ def _get_input_name_index(op, input_var_name):
         "The input op should be IrNode or Operator."
     op_name = op.name() if isinstance(op, IrNode) \
         else op.type
+    if op_name not in _op_real_in_out_name:
+        return None
+
     res = None
     for argname in _op_real_in_out_name[op_name][0]:
         var_names = op.input(argname)
@@ -179,6 +191,9 @@ def _get_op_output_var_names(op):
     var_names = []
     op_name = op.name() if isinstance(op, IrNode) \
         else op.type
+    if op_name not in _op_real_in_out_name:
+        return []
+
     name_list = _op_real_in_out_name[op_name][1]
     for name in name_list:
         var_name = op.output(name)
@@ -195,6 +210,9 @@ def _get_output_name_index(op, output_var_name):
         "The input op should be IrNode or Operator."
     op_name = op.name() if isinstance(op, IrNode) \
         else op.type
+    if op_name not in _op_real_in_out_name:
+        return None
+
     name_list = _op_real_in_out_name[op_name][1]
     res = None
     for name in name_list:
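The guards added to these four lookup helpers make them total over op types: an op missing from _op_real_in_out_name now yields an empty list (or None for the name/index helpers) instead of raising a KeyError, which is what lets _gather_input_scale() iterate over arbitrary inference ops. A toy mirror of the new behavior (the registry entry below is hypothetical):

```python
# Hypothetical registry in the same shape as _op_real_in_out_name:
# op type -> ([real input arg names], [real output arg names])
_registry = {"conv2d": (["Input", "Filter"], ["Output"])}

def get_input_arg_names(op_type):
    # mirrors the new early return in _get_op_input_var_names
    if op_type not in _registry:
        return []
    return _registry[op_type][0]

print(get_input_arg_names("conv2d"))     # ['Input', 'Filter']
print(get_input_arg_names("my_custom"))  # [] instead of KeyError
```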
diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py
index 3cc61ce8c5808..39d44060abfb3 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py
@@ -31,7 +31,7 @@
 from paddle.nn import Linear, Conv2D, Softmax
 from paddle.fluid.log_helper import get_logger
 from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
-from paddle.fluid.contrib.slim.quantization.imperative.quant_nn import QuantizedConv2D
+from paddle.nn.quant.quant_layers import QuantizedConv2D
 
 from imperative_test_utils import fix_model_dict, ImperativeLenet
diff --git a/python/paddle/fluid/contrib/slim/tests/test_moving_average_abs_max_scale_op.py b/python/paddle/fluid/contrib/slim/tests/test_moving_average_abs_max_scale_op.py
index 10c01566d05ee..656fb1dda3bd1 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_moving_average_abs_max_scale_op.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_moving_average_abs_max_scale_op.py
@@ -20,7 +20,7 @@
 import paddle
 import paddle.fluid as fluid
 from paddle.fluid import core
-from paddle.fluid.contrib.slim.quantization.imperative import quant_nn
+import paddle.nn.quant.quant_layers as quant_layers
 
 paddle.enable_static()
 
@@ -45,7 +45,7 @@ def check_backward(self, use_cuda):
             name='image', shape=[784], dtype='float32')
         label = fluid.layers.data(name='label', shape=[1], dtype='int64')
         fc_tmp = fluid.layers.fc(image, size=10, act='softmax')
-        out_scale = quant_nn.MovingAverageAbsMaxScale(
+        out_scale = quant_layers.MovingAverageAbsMaxScale(
             name=fc_tmp.name, dtype=fc_tmp.dtype)
         fc_tmp_1 = out_scale(fc_tmp)
         cross_entropy = fluid.layers.softmax_with_cross_entropy(fc_tmp,
diff --git a/python/paddle/nn/quant/__init__.py b/python/paddle/nn/quant/__init__.py
index c7f9a5073def8..8973761ab6944 100644
--- a/python/paddle/nn/quant/__init__.py
+++ b/python/paddle/nn/quant/__init__.py
@@ -21,5 +21,6 @@
 from .functional_layers import transpose  # noqa: F401
 from .functional_layers import concat  # noqa: F401
 from .functional_layers import flatten  # noqa: F401
+from .quant_layers import QuantStub  # noqa: F401
 
 __all__ = []
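With the re-export above, the QuantStub alias introduced by the rename below is reachable as paddle.nn.quant.QuantStub. A minimal dygraph sketch of what it does (assuming a Paddle 2.x environment): QuantStub is MovingAverageAbsMaxScale, so it tracks the moving-average abs-max scale of whatever passes through it and returns the data unchanged.

```python
import paddle

stub = paddle.nn.quant.QuantStub()  # alias of MovingAverageAbsMaxScale
x = paddle.rand([4, 16])
y = stub(x)  # y has the same values as x; the scale is updated internally
```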
diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py b/python/paddle/nn/quant/quant_layers.py
similarity index 95%
rename from python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py
rename to python/paddle/nn/quant/quant_layers.py
index fd1f7f423ff8f..c069b3147115e 100644
--- a/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py
+++ b/python/paddle/nn/quant/quant_layers.py
@@ -26,21 +26,103 @@
 from paddle.fluid.log_helper import get_logger
 
 __all__ = [
-    'FakeQuantMovingAverageAbsMax',
     'FakeQuantAbsMax',
+    'FakeQuantMovingAverageAbsMax',
     'FakeQuantChannelWiseAbsMax',
     'QuantizedConv2D',
     'QuantizedLinear',
-    'QuantizedNoweightLayer',
     'MovingAverageAbsMaxScale',
     'MAOutputScaleLayer',
     'FakeQuantMAOutputScaleLayer',
+    'QuantStub',
 ]
 
 _logger = get_logger(
     __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
 
 
+class FakeQuantAbsMax(layers.Layer):
+    r"""
+    FakeQuantAbsMax layer does the abs_max quant and then dequant.
+    Its computational formula is described as below:
+
+    :math:`scale = max(abs(X))`
+    :math:`range = 2^{bit\_length - 1} - 1`
+    :math:`Out = round(X / scale * range) * scale / range`
+    """
+
+    def __init__(self,
+                 name=None,
+                 quant_bits=8,
+                 dtype='float32',
+                 quant_on_weight=False):
+        super(FakeQuantAbsMax, self).__init__()
+        self._quant_bits = quant_bits
+        self._name = name
+        scale_prefix = "{}.scale".format(
+            name) if name else 'quant_dequant.scale'
+        self._scale_name = unique_name.generate(scale_prefix)
+        if quant_on_weight:
+            scale_attr = ParamAttr(
+                name=self._scale_name,
+                initializer=Constant(0.0),
+                trainable=False)
+            self._scale = self.create_parameter(
+                shape=[1], attr=scale_attr, dtype=self._dtype)
+            self._scale.stop_gradient = True
+        else:
+            self._scale = None
+
+    def forward(self, input):
+        if in_dygraph_mode():
+            attrs = ('bit_length', self._quant_bits)
+            quant_out = _varbase_creator(
+                type=input.type,
+                name="{}.quantized.dequantized".format(input.name),
+                shape=input.shape,
+                dtype=input.dtype,
+                persistable=False)
+            out_scale = self._scale
+            if not out_scale:
+                out_scale = _varbase_creator(
+                    type=core.VarDesc.VarType.LOD_TENSOR,
+                    name=self._scale_name,
+                    shape=[1],
+                    dtype=self._dtype,
+                    persistable=False)
+                out_scale.stop_gradient = True
+            out, _, = core.ops.fake_quantize_dequantize_abs_max(
+                input, quant_out, out_scale, *attrs)
+            return out
+
+        check_variable_and_dtype(input, 'input', ['float32'], "FakeQuantAbsMax")
+        attrs = {'bit_length': self._quant_bits}
+        inputs = {"X": [input]}
+        quant_out = self._helper.create_variable(
+            name="{}.quantized.dequantized".format(input.name),
+            dtype=input.dtype,
+            type=core.VarDesc.VarType.LOD_TENSOR,
+            persistable=False,
+            stop_gradient=False)
+        out_scale = self._scale
+        if not out_scale:
+            out_scale = self._helper.create_variable(
+                name=self._scale_name,
+                dtype=self._dtype,
+                type=core.VarDesc.VarType.LOD_TENSOR,
+                persistable=False,
+                stop_gradient=True)
+        outputs = {"Out": [quant_out], "OutScale": [out_scale]}
+
+        self._helper.append_op(
+            type="fake_quantize_dequantize_abs_max",
+            inputs=inputs,
+            outputs=outputs,
+            attrs=attrs)
+
+        return quant_out
+
+
 class FakeQuantMovingAverageAbsMax(layers.Layer):
     r"""
     FakeQuantMovingAverageAbsMax layer does the moving_average_abs_max quant and then dequant.
@@ -64,7 +146,7 @@ def __init__(self,
             name) if name else 'quant_dequant.scale'
         scale_attr = ParamAttr(
             name=unique_name.generate(scale_prefix),
-            initializer=Constant(0.001),
+            initializer=Constant(0.),
             trainable=False)
         self._scale = self.create_parameter(
             shape=[1], attr=scale_attr, dtype=dtype)
@@ -74,7 +156,7 @@ def __init__(self,
             name) if name else 'quant_dequant.state'
         state_attr = ParamAttr(
             name=unique_name.generate(state_prefix),
-            initializer=Constant(1),
+            initializer=Constant(0),
             trainable=False)
         self._state = self.create_parameter(
             shape=[1], attr=state_attr, dtype=dtype)
@@ -84,7 +166,7 @@ def __init__(self,
             name) if name else 'quant_dequant.accum'
         accum_attr = ParamAttr(
             name=unique_name.generate(accum_prefix),
-            initializer=Constant(1),
+            initializer=Constant(0),
             trainable=False)
         self._accum = self.create_parameter(
             shape=[1], attr=accum_attr, dtype=dtype)
@@ -139,24 +221,21 @@ def forward(self, input):
         return quant_out
 
 
-class FakeQuantAbsMax(layers.Layer):
-    r"""
-    FakeQuantAbsMax layer does the abs_max quant and then dequant.
-    Its computational formula is described as below:
-
-    :math:`scale = max(abs(X))`
-    :math:`range = 2^{bit\_length - 1} - 1`
-    :math:`Out = round(X / scale * range) * scale / range`
-    """
-
+class FakeQuantChannelWiseAbsMax(layers.Layer):
     def __init__(self,
                  name=None,
+                 channel_num=None,
                  quant_bits=8,
+                 quant_axis=0,
                  dtype='float32',
                  quant_on_weight=False):
-        super(FakeQuantAbsMax, self).__init__()
+        assert quant_on_weight == True, "Channel-wise quantization can only be used for weights."
+        super(FakeQuantChannelWiseAbsMax, self).__init__()
         self._quant_bits = quant_bits
+        self._quant_axis = quant_axis
+        self._dtype = dtype
         self._name = name
+        self._channel_num = channel_num
         scale_prefix = "{}.scale".format(
             name) if name else 'quant_dequant.scale'
         self._scale_name = unique_name.generate(scale_prefix)
@@ -166,35 +245,39 @@ def __init__(self,
                 initializer=Constant(0.0),
                 trainable=False)
             self._scale = self.create_parameter(
-                shape=[1], attr=scale_attr, dtype=self._dtype)
+                shape=[self._channel_num], attr=scale_attr, dtype=self._dtype)
             self._scale.stop_gradient = True
         else:
             self._scale = None
 
     def forward(self, input):
         if in_dygraph_mode():
-            attrs = ('bit_length', self._quant_bits)
+            attrs = ('bit_length', self._quant_bits, 'quant_axis',
+                     self._quant_axis)
             quant_out = _varbase_creator(
                 type=input.type,
                 name="{}.quantized.dequantized".format(input.name),
                 shape=input.shape,
                 dtype=input.dtype,
                 persistable=False)
+
             out_scale = self._scale
-            if not out_scale:
+            if out_scale is None:
                 out_scale = _varbase_creator(
                     type=core.VarDesc.VarType.LOD_TENSOR,
                     name=self._scale_name,
-                    shape=[1],
+                    shape=[self._channel_num],
                     dtype=self._dtype,
                     persistable=False)
                 out_scale.stop_gradient = True
-            out, _, = core.ops.fake_quantize_dequantize_abs_max(
+
+            out, _, = core.ops.fake_channel_wise_quantize_dequantize_abs_max(
                 input, quant_out, out_scale, *attrs)
             return out
 
-        check_variable_and_dtype(input, 'input', ['float32'], "FakeQuantAbsMax")
-        attrs = {'bit_length': self._quant_bits}
+        check_variable_and_dtype(input, 'input', ['float32'],
+                                 "FakeQuantChannelWiseAbsMax")
+        attrs = {'bit_length': self._quant_bits, 'quant_axis': self._quant_axis}
         inputs = {"X": [input]}
         quant_out = self._helper.create_variable(
             name="{}.quantized.dequantized".format(input.name),
             dtype=input.dtype,
@@ -213,7 +296,7 @@ def forward(self, input):
         outputs = {"Out": [quant_out], "OutScale": [out_scale]}
 
         self._helper.append_op(
-            type="fake_quantize_dequantize_abs_max",
+            type="fake_channel_wise_quantize_dequantize_abs_max",
             inputs=inputs,
             outputs=outputs,
             attrs=attrs)
@@ -221,82 +304,83 @@ def forward(self, input):
         return quant_out
 
 
-class FakeQuantChannelWiseAbsMax(layers.Layer):
-    def __init__(self,
-                 name=None,
-                 channel_num=None,
-                 quant_bits=8,
-                 quant_axis=0,
-                 dtype='float32',
-                 quant_on_weight=False):
-        assert quant_on_weight == True, "Channel_wise only can be used on weight quantization."
-        super(FakeQuantChannelWiseAbsMax, self).__init__()
-        self._quant_bits = quant_bits
-        self._quant_axis = quant_axis
-        self._dtype = dtype
-        self._name = name
-        self._channel_num = channel_num
-        scale_prefix = "{}.scale".format(
-            name) if name else 'quant_dequant.scale'
-        self._scale_name = unique_name.generate(scale_prefix)
-        if quant_on_weight:
-            scale_attr = ParamAttr(
-                name=self._scale_name,
-                initializer=Constant(0.0),
-                trainable=False)
-            self._scale = self.create_parameter(
-                shape=[self._channel_num], attr=scale_attr, dtype=self._dtype)
-            self._scale.stop_gradient = True
-        else:
-            self._scale = None
+class MovingAverageAbsMaxScale(layers.Layer):
+    def __init__(self, name=None, moving_rate=0.9, dtype='float32'):
+        r"""
+        MovingAverageAbsMaxScale layer is used to calculate the output quantization
+        scale of a layer. Its computational formula is described as below:
+
+        :math:`scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)`
+        :math:`Out = X`
+        """
+        super(MovingAverageAbsMaxScale, self).__init__()
+        self._moving_rate = moving_rate
+
+        scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
+        scale_name = unique_name.generate(scale_prefix)
+        scale_attr = ParamAttr(
+            name=scale_name, initializer=Constant(0), trainable=False)
+        self._scale = self.create_parameter(
+            shape=[1], attr=scale_attr, dtype=dtype)
+        self._scale.stop_gradient = True
+
+        state_prefix = "{}.state".format(name) if name else 'outscale.state'
+        state_attr = ParamAttr(
+            name=unique_name.generate(state_prefix),
+            initializer=Constant(0),
+            trainable=False)
+        self._state = self.create_parameter(
+            shape=[1], attr=state_attr, dtype=dtype)
+        self._state.stop_gradient = True
+
+        accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
+        accum_attr = ParamAttr(
+            name=unique_name.generate(accum_prefix),
+            initializer=Constant(0),
+            trainable=False)
+        self._accum = self.create_parameter(
+            shape=[1], attr=accum_attr, dtype=dtype)
+        self._accum.stop_gradient = True
 
     def forward(self, input):
         if in_dygraph_mode():
-            attrs = ('bit_length', self._quant_bits, 'quant_axis',
-                     self._quant_axis)
+            attrs = ('moving_rate', self._moving_rate, 'is_test',
+                     not self.training)
+            state = self._state if self.training else None
+            accum = self._accum if self.training else None
             quant_out = _varbase_creator(
                 type=input.type,
-                name="{}.quantized.dequantized".format(input.name),
+                name="{}.tmp".format(input.name),
                 shape=input.shape,
                 dtype=input.dtype,
                 persistable=False)
-            out_scale = self._scale
-            if out_scale is None:
-                out_scale = _varbase_creator(
-                    type=core.VarDesc.VarType.LOD_TENSOR,
-                    name=self._scale_name,
-                    shape=[self._channel_num],
-                    dtype=self._dtype,
-                    persistable=False)
-                out_scale.stop_gradient = True
-
-            out, _, = core.ops.fake_channel_wise_quantize_dequantize_abs_max(
-                input, quant_out, out_scale, *attrs)
+            out, _, _, _ = core.ops.moving_average_abs_max_scale(
+                input, accum, state, quant_out, self._scale, state, accum,
+                *attrs)
             return out
 
-        check_variable_and_dtype(input, 'input', ['float32'],
-                                 "FakeQuantChannelWiseAbsMax")
-        attrs = {'bit_length': self._quant_bits, 'quant_axis': self._quant_axis}
+        check_variable_and_dtype(input, 'input', ['float32', 'float64'],
+                                 'MovingAverageAbsMaxScale')
+
+        attrs = {'moving_rate': self._moving_rate, 'is_test': not self.training}
         inputs = {"X": [input]}
         quant_out = self._helper.create_variable(
-            name="{}.quantized.dequantized".format(input.name),
+            name="{}.tmp".format(input.name),
             dtype=input.dtype,
             type=core.VarDesc.VarType.LOD_TENSOR,
             persistable=False,
             stop_gradient=False)
-        out_scale = self._scale
-        if not out_scale:
-            out_scale = self._helper.create_variable(
-                name=self._scale_name,
-                dtype=self._dtype,
-                type=core.VarDesc.VarType.LOD_TENSOR,
-                persistable=False,
-                stop_gradient=True)
-        outputs = {"Out": [quant_out], "OutScale": [out_scale]}
+        outputs = {"Out": [quant_out], "OutScale": [self._scale]}
+
+        if self.training:
+            inputs['InState'] = [self._state]
+            inputs['InAccum'] = [self._accum]
+            outputs['OutState'] = [self._state]
+            outputs['OutAccum'] = [self._accum]
 
         self._helper.append_op(
-            type="fake_channel_wise_quantize_dequantize_abs_max",
+            type="moving_average_abs_max_scale",
             inputs=inputs,
             outputs=outputs,
             attrs=attrs)
@@ -304,31 +388,7 @@ def forward(self, input):
         return quant_out
 
 
-def _get_fake_quant_type(quant_type, **kwargs):
-    call_args = {
-        "name": kwargs.get("name", None),
-        "quant_bits": kwargs.get("quant_bits", 8),
-        "dtype": kwargs.get("dtype", "float32")
-    }
-
-    if quant_type == 'abs_max':
-        call_args["quant_on_weight"] = kwargs.get("quant_on_weight", False)
-    elif quant_type == 'moving_average_abs_max':
-        call_args["moving_rate"] = kwargs.get("moving_rate", 0.9)
-    elif quant_type == 'channel_wise_abs_max':
-        call_args["quant_on_weight"] = kwargs.get("quant_on_weight", False)
-        call_args["channel_num"] = kwargs.get("channel_num", None)
-        call_args["quant_axis"] = kwargs.get("quant_axis", 0)
-        assert call_args["channel_num"] is not None, (
-            "You need to input channel_num"
-            "when you use channel_wise_abs_max strategy.")
-    fake_quant_map = {
-        'abs_max': FakeQuantAbsMax,
-        'moving_average_abs_max': FakeQuantMovingAverageAbsMax,
-        'channel_wise_abs_max': FakeQuantChannelWiseAbsMax
-    }
-
-    return fake_quant_map[quant_type](**call_args)
+QuantStub = MovingAverageAbsMaxScale
 
 
 class QuantizedConv2D(layers.Layer):
@@ -489,117 +549,10 @@ def forward(self, input):
         return out
 
 
-class QuantizedNoweightLayer(layers.Layer):
-    def __init__(self,
-                 layer,
-                 weight_bits=8,
-                 activation_bits=8,
-                 moving_rate=0.9,
-                 *args,
-                 **kwargs):
-
-        super(QuantizedNoweightLayer, self).__init__()
-        self._layer = layer
-        self._fake_quant_input = _get_fake_quant_type(
-            'moving_average_abs_max',
-            name=layer.full_name(),
-            moving_rate=moving_rate,
-            quant_bits=activation_bits,
-            dtype=self._dtype,
-            quant_on_weight=False)
-
-    def forward(self, input):
-        return self._layer.forward(self._fake_quant_input(input))
-
-
-class MovingAverageAbsMaxScale(layers.Layer):
-    def __init__(self, name=None, moving_rate=0.9, dtype='float32'):
-        r"""
-        MovingAverageMaxScale layer is used to calculating the output quantization
-        scale of Layer.
-        Its computational formula is described as below:
-
-        :math:`scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)`
-        :math:`Out = X`
-        """
-        super(MovingAverageAbsMaxScale, self).__init__()
-        self._moving_rate = moving_rate
-
-        scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
-        scale_name = unique_name.generate(scale_prefix)
-        scale_attr = ParamAttr(
-            name=scale_name, initializer=Constant(1), trainable=False)
-        self._scale = self.create_parameter(
-            shape=[1], attr=scale_attr, dtype=dtype)
-        self._scale.stop_gradient = True
-
-        state_prefix = "{}.state".format(name) if name else 'outscale.state'
-        state_attr = ParamAttr(
-            name=unique_name.generate(state_prefix),
-            initializer=Constant(1),
-            trainable=False)
-        self._state = self.create_parameter(
-            shape=[1], attr=state_attr, dtype=dtype)
-        self._state.stop_gradient = True
-
-        accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
-        accum_attr = ParamAttr(
-            name=unique_name.generate(accum_prefix),
-            initializer=Constant(1),
-            trainable=False)
-        self._accum = self.create_parameter(
-            shape=[1], attr=accum_attr, dtype=dtype)
-        self._accum.stop_gradient = True
-
-    def forward(self, input):
-        if in_dygraph_mode():
-            attrs = ('moving_rate', self._moving_rate, 'is_test',
-                     not self.training)
-            state = self._state if self.training else None
-            accum = self._accum if self.training else None
-            quant_out = _varbase_creator(
-                type=input.type,
-                name="{}.tmp".format(input.name),
-                shape=input.shape,
-                dtype=input.dtype,
-                persistable=False)
-
-            out, _, _, _ = core.ops.moving_average_abs_max_scale(
-                input, accum, state, quant_out, self._scale, state, accum,
-                *attrs)
-            return out
-
-        check_variable_and_dtype(input, 'input', ['float32', 'float64'],
-                                 'MovingAverageAbsMaxScale')
-
-        attrs = {'moving_rate': self._moving_rate, 'is_test': not self.training}
-        inputs = {"X": [input]}
-        quant_out = self._helper.create_variable(
-            name="{}.tmp".format(input.name),
-            dtype=input.dtype,
-            type=core.VarDesc.VarType.LOD_TENSOR,
-            persistable=False,
-            stop_gradient=False)
-        outputs = {"Out": [quant_out], "OutScale": [self._scale]}
-
-        if self.training:
-            inputs['InState'] = [self._state]
-            inputs['InAccum'] = [self._accum]
-            outputs['OutState'] = [self._state]
-            outputs['OutAccum'] = [self._accum]
-
-        self._helper.append_op(
-            type="moving_average_abs_max_scale",
-            inputs=inputs,
-            outputs=outputs,
-            attrs=attrs)
-
-        return quant_out
-
-
 class MAOutputScaleLayer(layers.Layer):
     """
-    Calculate the scale (moving average abs max) for the output of the input layer.
     Add MovingAverageMaxScale layer to the behind of the input layer.
+    Calculate the scale (moving average abs max) for the output of the input layer.
     """
 
     def __init__(self, layer=None, moving_rate=0.9, name=None, dtype='float32'):
@@ -623,6 +576,10 @@ def forward(self, *inputs, **kwargs):
 
 
 class FakeQuantMAOutputScaleLayer(layers.Layer):
+    """
+    Add a FakeQuantMovingAverageAbsMax layer after the input layer.
+ """ + def __init__(self, layer, weight_bits=8, @@ -649,3 +606,30 @@ def forward(self, *inputs, **kwargs): return out else: return self._fake_quant_output(out) + + +def _get_fake_quant_type(quant_type, **kwargs): + call_args = { + "name": kwargs.get("name", None), + "quant_bits": kwargs.get("quant_bits", 8), + "dtype": kwargs.get("dtype", "float32") + } + + if quant_type == 'abs_max': + call_args["quant_on_weight"] = kwargs.get("quant_on_weight", False) + elif quant_type == 'moving_average_abs_max': + call_args["moving_rate"] = kwargs.get("moving_rate", 0.9) + elif quant_type == 'channel_wise_abs_max': + call_args["quant_on_weight"] = kwargs.get("quant_on_weight", False) + call_args["channel_num"] = kwargs.get("channel_num", None) + call_args["quant_axis"] = kwargs.get("quant_axis", 0) + assert call_args["channel_num"] is not None, ( + "You need to input channel_num" + "when you use channel_wise_abs_max strategy.") + fake_quant_map = { + 'abs_max': FakeQuantAbsMax, + 'moving_average_abs_max': FakeQuantMovingAverageAbsMax, + 'channel_wise_abs_max': FakeQuantChannelWiseAbsMax + } + + return fake_quant_map[quant_type](**call_args)