From ce748b0ee8377b443653fae926104a0cc0d3d902 Mon Sep 17 00:00:00 2001 From: baiyfbupt Date: Thu, 12 Nov 2020 11:08:24 +0800 Subject: [PATCH 1/5] support user-defined quant and preprocess --- .../slim/quantization/imperative/qat.py | 48 +++- .../slim/quantization/imperative/quant_nn.py | 122 +++++++--- .../tests/test_imperative_qat_user_defined.py | 213 ++++++++++++++++++ 3 files changed, 344 insertions(+), 39 deletions(-) create mode 100644 python/paddle/fluid/contrib/slim/tests/test_imperative_qat_user_defined.py diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index 7fc177e7ad765..1a76c91e365ee 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -16,6 +16,7 @@ import numpy as np import sys import os +import types import paddle from paddle.fluid import dygraph, core, framework from paddle.fluid.executor import Executor @@ -59,7 +60,11 @@ def __init__(self, weight_quantize_type='abs_max', activation_quantize_type='moving_average_abs_max', moving_rate=0.9, - quantizable_layer_type=['Conv2D', 'Linear']): + quantizable_layer_type=['Conv2D', 'Linear'], + weight_preprocess=None, + act_preprocess=None, + weight_quantize=None, + act_quantize=None): """ The constructor for ImperativeQuantAware. @@ -81,7 +86,28 @@ def __init__(self, quantizable_op_type(list[str]): List the type of layers that will be quantized. Default is ['Conv2D', 'Linear']. The quantizable_op_type in QuantizationFreezePass and ConvertToInt8Pass must be the same as this. - + weight_preprocess(paddle.nn.Layer): A paddle Layer that defines how to preprocess + weight before quantization. Using this can quickly test if user's + preprocess method works or not. The input is non-quantized + weight and function returns processed weight to be quantized. + If None, the weight will be quantized directly. Default is None. + act_preprocess(paddle.nn.Layer): A paddle Layer that defines how to preprocess + activation before quantization. Using this can quickly test if user's + preprocess method works or not. The input is non-quantized + activation and function returns processed activation to be quantized. + If None, the activation will be quantized directly. Default is None. + weight_quantize(paddle.nn.Layer): A paddle Layer that defines how to quantize weight. + Using this can quickly test if user's quantization method works or not. + In this layer, user should both define quantization method and + dequantization method, that is, the function's input is non-quantized + weight and returns dequantized weight. If None, will use + quantization op defined by 'weight_quantize_type'. Default is None. + act_quantize(paddle.nn.Layer): A paddle Layer that defines how to quantize activation. + Using this can quickly test if user's quantization method works or not. + In this layer, user should both define quantization method and + dequantization method, that is, the function's input is non-quantized + activation and returns dequantized activation. If None, will use + quantization op defined by 'activation_quantize_type'. Default is None. Examples: .. 
code-block:: python @@ -118,6 +144,20 @@ def __init__(self, self._activation_bits = activation_bits self._moving_rate = moving_rate + self._weight_preprocess = weight_preprocess + self._act_preprocess = act_preprocess + self._weight_quantize = weight_quantize + self._act_quantize = act_quantize + + t_check = lambda method: method is None or issubclass(method, dygraph.layers.Layer) + assert t_check( + self._weight_preprocess), "weight_preprocess should be nn.Layer" + assert t_check( + self._act_preprocess), "act_preprocess should be nn.Layer" + assert t_check( + self._weight_quantize), "weight_quantize should be nn.Layer" + assert t_check(self._act_quantize), "act_quantize should be nn.Layer" + quant_type = { 'abs_max', 'moving_average_abs_max', 'channel_wise_abs_max' } @@ -189,7 +229,9 @@ def _get_quantized_counterpart(self, layer): quantized_layer = quant_nn.__dict__[quantized_counterpart[index]]( layer, self._weight_bits, self._activation_bits, self._moving_rate, - self._weight_quantize_type, self._activation_quantize_type) + self._weight_quantize_type, self._activation_quantize_type, + self._weight_preprocess, self._act_preprocess, + self._weight_quantize, self._act_quantize) return quantized_layer diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py b/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py index bbaae56439eb6..697367e1f2f79 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py @@ -332,7 +332,11 @@ def __init__(self, activation_bits=8, moving_rate=0.9, weight_quantize_type='abs_max', - activation_quantize_type='abs_max'): + activation_quantize_type='abs_max', + weight_preprocess=None, + act_preprocess=None, + weight_quantize=None, + act_quantize=None): super(QuantizedConv2D, self).__init__() # For Conv2D self._groups = getattr(layer, '_groups') @@ -347,26 +351,46 @@ def __init__(self, self.bias = getattr(layer, 'bias') # For FakeQuant self._conv2d_quant_axis = 0 - self._fake_quant_weight = _get_fake_quant_type( - weight_quantize_type, - name=self.weight.name, - moving_rate=moving_rate, - quant_bits=weight_bits, - dtype=self._dtype, - quant_on_weight=True, - channel_num=self.weight.shape[self._conv2d_quant_axis], - quant_axis=self._conv2d_quant_axis) - self._fake_quant_input = _get_fake_quant_type( - activation_quantize_type, - name=layer.full_name(), - moving_rate=moving_rate, - quant_bits=activation_bits, - dtype=self._dtype, - quant_on_weight=False) + + if weight_quantize is not None: + self._fake_quant_weight = weight_quantize() + else: + self._fake_quant_weight = _get_fake_quant_type( + weight_quantize_type, + name=self.weight.name, + moving_rate=moving_rate, + quant_bits=weight_bits, + dtype=self._dtype, + quant_on_weight=True, + channel_num=self.weight.shape[self._conv2d_quant_axis], + quant_axis=self._conv2d_quant_axis) + if act_quantize is not None: + self._fake_quant_input = act_quantize() + else: + self._fake_quant_input = _get_fake_quant_type( + activation_quantize_type, + name=layer.full_name(), + moving_rate=moving_rate, + quant_bits=activation_bits, + dtype=self._dtype, + quant_on_weight=False) + + self.do_act_preprocess = True if act_preprocess is not None else False + self.do_weight_preprocess = True if weight_preprocess is not None else False + if self.do_act_preprocess: + self._act_preprocess = act_preprocess() + if self.do_weight_preprocess: + self._weight_preprocess = weight_preprocess() def forward(self, 
input): + if self.do_act_preprocess: + input = self._act_preprocess(input) quant_input = self._fake_quant_input(input) - quant_weight = self._fake_quant_weight(self.weight) + + weight = self.weight + if self.do_weight_preprocess: + weight = self._weight_preprocess(self.weight) + quant_weight = self._fake_quant_weight(weight) if in_dygraph_mode() and self._l_type == 'conv2d': attrs = ('strides', self._stride, 'paddings', self._padding, @@ -428,7 +452,11 @@ def __init__(self, activation_bits=8, moving_rate=0.9, weight_quantize_type='abs_max', - activation_quantize_type='abs_max'): + activation_quantize_type='abs_max', + weight_preprocess=None, + act_preprocess=None, + weight_quantize=None, + act_quantize=None): super(QuantizedLinear, self).__init__() # For Linear self._act = getattr(layer, '_act') @@ -437,26 +465,48 @@ def __init__(self, self.bias = getattr(layer, 'bias') # For FakeQuant self._linear_quant_axis = 1 - self._fake_quant_weight = _get_fake_quant_type( - weight_quantize_type, - name=self.weight.name, - moving_rate=moving_rate, - quant_bits=weight_bits, - dtype=self._dtype, - quant_on_weight=True, - channel_num=self.weight.shape[self._linear_quant_axis], - quant_axis=self._linear_quant_axis) - self._fake_quant_input = _get_fake_quant_type( - activation_quantize_type, - name=layer.full_name(), - moving_rate=moving_rate, - quant_bits=activation_bits, - dtype=self._dtype, - quant_on_weight=False) + + if weight_quantize is not None: + self._fake_quant_weight = weight_quantize() + else: + self._fake_quant_weight = _get_fake_quant_type( + weight_quantize_type, + name=self.weight.name, + moving_rate=moving_rate, + quant_bits=weight_bits, + dtype=self._dtype, + quant_on_weight=True, + channel_num=self.weight.shape[self._linear_quant_axis], + quant_axis=self._linear_quant_axis) + + if act_quantize is not None: + self._fake_quant_input = act_quantize() + else: + self._fake_quant_input = _get_fake_quant_type( + activation_quantize_type, + name=layer.full_name(), + moving_rate=moving_rate, + quant_bits=activation_bits, + dtype=self._dtype, + quant_on_weight=False) + + self.do_act_preprocess = True if act_preprocess is not None else False + self.do_weight_preprocess = True if weight_preprocess is not None else False + if self.do_act_preprocess: + self._act_preprocess = act_preprocess() + if self.do_weight_preprocess: + self._weight_preprocess = weight_preprocess() def forward(self, input): + if self.do_act_preprocess: + input = self._act_preprocess(input) quant_input = self._fake_quant_input(input) - quant_weight = self._fake_quant_weight(self.weight) + + weight = self.weight + if self.do_weight_preprocess: + weight = self._weight_preprocess(self.weight) + quant_weight = self._fake_quant_weight(weight) + if in_dygraph_mode(): pre_bias = _varbase_creator(dtype=input.dtype) core.ops.matmul(quant_input, quant_weight, pre_bias, 'transpose_X', diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_user_defined.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_user_defined.py new file mode 100644 index 0000000000000..dfc13b5f39aee --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_user_defined.py @@ -0,0 +1,213 @@ +# copyright (c) 2020 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. 
+# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. + +from __future__ import print_function + +import os +import numpy as np +import random +import unittest +import logging +import paddle +import paddle.nn as nn +import paddle.fluid as fluid +from paddle.fluid import core +from paddle.fluid.optimizer import AdamOptimizer +from paddle.fluid.framework import IrGraph +from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware +from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass +from paddle.fluid.dygraph.container import Sequential +from paddle.fluid.dygraph.nn import Conv2D +from paddle.fluid.dygraph.nn import Pool2D +from paddle.fluid.dygraph.nn import Linear +from paddle.fluid.log_helper import get_logger +from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX + +os.environ["CPU_NUM"] = "1" + +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') + + +class PACT(nn.Layer): + def __init__(self, init_value=20): + super(PACT, self).__init__() + alpha_attr = paddle.ParamAttr( + name=self.full_name() + ".pact", + initializer=paddle.nn.initializer.Constant(value=init_value)) + self.alpha = self.create_parameter( + shape=[1], attr=alpha_attr, dtype='float32') + + def forward(self, x): + out_left = paddle.nn.functional.relu(x - self.alpha) + out_right = paddle.nn.functional.relu(-self.alpha - x) + x = x - out_left + out_right + return x + + +class ImperativeLenet(fluid.dygraph.Layer): + def __init__(self, num_classes=10, classifier_activation='softmax'): + super(ImperativeLenet, self).__init__() + self.features = Sequential( + Conv2D( + num_channels=1, + num_filters=6, + filter_size=3, + stride=1, + padding=1), + Pool2D( + pool_size=2, pool_type='max', pool_stride=2), + Conv2D( + num_channels=6, + num_filters=16, + filter_size=5, + stride=1, + padding=0), + Pool2D( + pool_size=2, pool_type='max', pool_stride=2)) + + self.fc = Sequential( + Linear( + input_dim=400, output_dim=120), + Linear( + input_dim=120, output_dim=84), + Linear( + input_dim=84, output_dim=num_classes, + act=classifier_activation)) + + def forward(self, inputs): + x = self.features(inputs) + + x = fluid.layers.flatten(x, 1) + x = self.fc(x) + return x + + +class TestUserDefinedActPreprocess(unittest.TestCase): + def setUp(self): + _logger.info("test act_preprocess") + self.imperative_qat = ImperativeQuantAware(act_preprocess=PACT) + + def test_quant_aware_training(self): + imperative_qat = self.imperative_qat + seed = 1 + np.random.seed(seed) + fluid.default_main_program().random_seed = seed + fluid.default_startup_program().random_seed = seed + lenet = ImperativeLenet() + fixed_state = {} + param_init_map = {} + for name, param in lenet.named_parameters(): + p_shape = param.numpy().shape + p_value = param.numpy() + if name.endswith("bias"): + value = np.zeros_like(p_value).astype('float32') + else: + value = np.random.normal( + loc=0.0, scale=0.01, + size=np.product(p_shape)).reshape(p_shape).astype('float32') + fixed_state[name] = value + param_init_map[param.name] = value + lenet.set_dict(fixed_state) + + imperative_qat.quantize(lenet) + adam = AdamOptimizer( + 
learning_rate=0.001, parameter_list=lenet.parameters()) + dynamic_loss_rec = [] + + def train(model): + adam = AdamOptimizer( + learning_rate=0.001, parameter_list=model.parameters()) + epoch_num = 1 + for epoch in range(epoch_num): + model.train() + for batch_id, data in enumerate(train_reader()): + x_data = np.array([x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(-1, 1) + + img = fluid.dygraph.to_variable(x_data) + label = fluid.dygraph.to_variable(y_data) + out = model(img) + acc = fluid.layers.accuracy(out, label) + loss = fluid.layers.cross_entropy(out, label) + avg_loss = fluid.layers.mean(loss) + avg_loss.backward() + adam.minimize(avg_loss) + model.clear_gradients() + if batch_id % 100 == 0: + _logger.info( + "Train | At epoch {} step {}: loss = {:}, acc= {:}". + format(epoch, batch_id, + avg_loss.numpy(), acc.numpy())) + + def test(model): + model.eval() + avg_acc = [[], []] + for batch_id, data in enumerate(test_reader()): + x_data = np.array([x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(-1, 1) + + img = fluid.dygraph.to_variable(x_data) + label = fluid.dygraph.to_variable(y_data) + + out = model(img) + acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) + acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5) + avg_acc[0].append(acc_top1.numpy()) + avg_acc[1].append(acc_top5.numpy()) + if batch_id % 100 == 0: + _logger.info( + "Test | step {}: acc1 = {:}, acc5 = {:}".format( + batch_id, acc_top1.numpy(), acc_top5.numpy())) + + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=64, drop_last=True) + test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=64) + train(lenet) + test(lenet) + print(paddle.summary(lenet, (1, 1, 28, 28))) + + paddle.jit.save( + layer=lenet, + path="./dynamic_quant_user_defined/model", + input_spec=[ + paddle.static.InputSpec( + shape=[None, 1, 28, 28], dtype='float32') + ]) + + +class TestUserDefinedWeightPreprocess(TestUserDefinedActPreprocess): + def setUp(self): + _logger.info("test weight_preprocess") + self.imperative_qat = ImperativeQuantAware(weight_preprocess=PACT) + + +class TestUserDefinedActQuantize(TestUserDefinedActPreprocess): + def setUp(self): + _logger.info("test act_quantize") + self.imperative_qat = ImperativeQuantAware(act_quantize=PACT) + + +class TestUserDefinedWeightQuantize(TestUserDefinedActPreprocess): + def setUp(self): + _logger.info("test weight_quantize") + self.imperative_qat = ImperativeQuantAware(weight_quantize=PACT) + + +if __name__ == '__main__': + unittest.main() From 55266ec2423824a882d2669697622e145aa3575e Mon Sep 17 00:00:00 2001 From: baiyfbupt Date: Thu, 12 Nov 2020 12:00:19 +0800 Subject: [PATCH 2/5] code clean --- .../tests/test_imperative_qat_user_defined.py | 46 ++++++++----------- 1 file changed, 20 insertions(+), 26 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_user_defined.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_user_defined.py index dfc13b5f39aee..733cf06d985b6 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_user_defined.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_user_defined.py @@ -21,18 +21,14 @@ import logging import paddle import paddle.nn as nn -import paddle.fluid as fluid -from paddle.fluid import core -from paddle.fluid.optimizer import AdamOptimizer -from 
paddle.fluid.framework import IrGraph +from paddle.optimizer import Adam from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass -from paddle.fluid.dygraph.container import Sequential -from paddle.fluid.dygraph.nn import Conv2D -from paddle.fluid.dygraph.nn import Pool2D -from paddle.fluid.dygraph.nn import Linear +from paddle.nn import Sequential +from paddle.fluid.dygraph import Conv2D +from paddle.nn import Pool2D +from paddle.fluid.dygraph import Linear from paddle.fluid.log_helper import get_logger -from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX os.environ["CPU_NUM"] = "1" @@ -56,7 +52,7 @@ def forward(self, x): return x -class ImperativeLenet(fluid.dygraph.Layer): +class ImperativeLenet(paddle.nn.Layer): def __init__(self, num_classes=10, classifier_activation='softmax'): super(ImperativeLenet, self).__init__() self.features = Sequential( @@ -89,7 +85,7 @@ def __init__(self, num_classes=10, classifier_activation='softmax'): def forward(self, inputs): x = self.features(inputs) - x = fluid.layers.flatten(x, 1) + x = paddle.flatten(x, 1) x = self.fc(x) return x @@ -103,8 +99,8 @@ def test_quant_aware_training(self): imperative_qat = self.imperative_qat seed = 1 np.random.seed(seed) - fluid.default_main_program().random_seed = seed - fluid.default_startup_program().random_seed = seed + paddle.static.default_main_program().random_seed = seed + paddle.static.default_startup_program().random_seed = seed lenet = ImperativeLenet() fixed_state = {} param_init_map = {} @@ -122,13 +118,11 @@ def test_quant_aware_training(self): lenet.set_dict(fixed_state) imperative_qat.quantize(lenet) - adam = AdamOptimizer( - learning_rate=0.001, parameter_list=lenet.parameters()) + adam = Adam(learning_rate=0.001, parameters=lenet.parameters()) dynamic_loss_rec = [] def train(model): - adam = AdamOptimizer( - learning_rate=0.001, parameter_list=model.parameters()) + adam = Adam(learning_rate=0.001, parameters=model.parameters()) epoch_num = 1 for epoch in range(epoch_num): model.train() @@ -138,12 +132,12 @@ def train(model): y_data = np.array( [x[1] for x in data]).astype('int64').reshape(-1, 1) - img = fluid.dygraph.to_variable(x_data) - label = fluid.dygraph.to_variable(y_data) + img = paddle.to_tensor(x_data) + label = paddle.to_tensor(y_data) out = model(img) - acc = fluid.layers.accuracy(out, label) - loss = fluid.layers.cross_entropy(out, label) - avg_loss = fluid.layers.mean(loss) + acc = paddle.metric.accuracy(out, label, k=1) + loss = nn.functional.loss.cross_entropy(out, label) + avg_loss = paddle.mean(loss) avg_loss.backward() adam.minimize(avg_loss) model.clear_gradients() @@ -162,12 +156,12 @@ def test(model): y_data = np.array( [x[1] for x in data]).astype('int64').reshape(-1, 1) - img = fluid.dygraph.to_variable(x_data) - label = fluid.dygraph.to_variable(y_data) + img = paddle.to_tensor(x_data) + label = paddle.to_tensor(y_data) out = model(img) - acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) - acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5) + acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1) + acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5) avg_acc[0].append(acc_top1.numpy()) avg_acc[1].append(acc_top5.numpy()) if batch_id % 100 == 0: From a81328850aff3a3818d091ebbb2bb8ef4b53d133 Mon Sep 17 00:00:00 2001 From: baiyfbupt Date: Thu, 12 Nov 2020 13:11:02 +0800 Subject: [PATCH 3/5] code clean --- 
.../slim/tests/test_imperative_qat_user_defined.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_user_defined.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_user_defined.py index 733cf06d985b6..c0c2021490364 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_user_defined.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_user_defined.py @@ -141,11 +141,12 @@ def train(model): avg_loss.backward() adam.minimize(avg_loss) model.clear_gradients() - if batch_id % 100 == 0: + if batch_id % 50 == 0: _logger.info( "Train | At epoch {} step {}: loss = {:}, acc= {:}". format(epoch, batch_id, avg_loss.numpy(), acc.numpy())) + break def test(model): model.eval() @@ -170,11 +171,10 @@ def test(model): batch_id, acc_top1.numpy(), acc_top5.numpy())) train_reader = paddle.batch( - paddle.dataset.mnist.train(), batch_size=64, drop_last=True) - test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=64) + paddle.dataset.mnist.train(), batch_size=512, drop_last=True) + test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=512) train(lenet) test(lenet) - print(paddle.summary(lenet, (1, 1, 28, 28))) paddle.jit.save( layer=lenet, From cf2741ea33c3fb744b959dab1ad06d0e74ca0337 Mon Sep 17 00:00:00 2001 From: baiyfbupt Date: Tue, 17 Nov 2020 17:07:27 +0800 Subject: [PATCH 4/5] code clean --- .../slim/quantization/imperative/qat.py | 38 ++++++------ .../slim/quantization/imperative/quant_nn.py | 60 +++++++++---------- 2 files changed, 46 insertions(+), 52 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index 1a76c91e365ee..cae2417723267 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -16,7 +16,6 @@ import numpy as np import sys import os -import types import paddle from paddle.fluid import dygraph, core, framework from paddle.fluid.executor import Executor @@ -61,10 +60,10 @@ def __init__(self, activation_quantize_type='moving_average_abs_max', moving_rate=0.9, quantizable_layer_type=['Conv2D', 'Linear'], - weight_preprocess=None, - act_preprocess=None, - weight_quantize=None, - act_quantize=None): + weight_preprocess_layer=None, + act_preprocess_layer=None, + weight_quantize_layer=None, + act_quantize_layer=None): """ The constructor for ImperativeQuantAware. @@ -86,23 +85,23 @@ def __init__(self, quantizable_op_type(list[str]): List the type of layers that will be quantized. Default is ['Conv2D', 'Linear']. The quantizable_op_type in QuantizationFreezePass and ConvertToInt8Pass must be the same as this. - weight_preprocess(paddle.nn.Layer): A paddle Layer that defines how to preprocess + weight_preprocess_layer(paddle.nn.Layer, optional): A paddle Layer that defines how to preprocess weight before quantization. Using this can quickly test if user's preprocess method works or not. The input is non-quantized weight and function returns processed weight to be quantized. If None, the weight will be quantized directly. Default is None. - act_preprocess(paddle.nn.Layer): A paddle Layer that defines how to preprocess + act_preprocess_layer(paddle.nn.Layer, optional): A paddle Layer that defines how to preprocess activation before quantization. Using this can quickly test if user's preprocess method works or not. 
The input is non-quantized activation and function returns processed activation to be quantized. If None, the activation will be quantized directly. Default is None. - weight_quantize(paddle.nn.Layer): A paddle Layer that defines how to quantize weight. + weight_quantize_layer(paddle.nn.Layer, optional): A paddle Layer that defines how to quantize weight. Using this can quickly test if user's quantization method works or not. In this layer, user should both define quantization method and dequantization method, that is, the function's input is non-quantized weight and returns dequantized weight. If None, will use quantization op defined by 'weight_quantize_type'. Default is None. - act_quantize(paddle.nn.Layer): A paddle Layer that defines how to quantize activation. + act_quantize_layer(paddle.nn.Layer, optional): A paddle Layer that defines how to quantize activation. Using this can quickly test if user's quantization method works or not. In this layer, user should both define quantization method and dequantization method, that is, the function's input is non-quantized @@ -144,19 +143,18 @@ def __init__(self, self._activation_bits = activation_bits self._moving_rate = moving_rate - self._weight_preprocess = weight_preprocess - self._act_preprocess = act_preprocess - self._weight_quantize = weight_quantize - self._act_quantize = act_quantize + self._weight_pre_layer = weight_preprocess_layer + self._act_pre_layer = act_preprocess_layer + self._weight_quant_layer = weight_quantize_layer + self._act_quant_layer = act_quantize_layer t_check = lambda method: method is None or issubclass(method, dygraph.layers.Layer) assert t_check( - self._weight_preprocess), "weight_preprocess should be nn.Layer" + self._weight_pre_layer), "weight_preprocess should be nn.Layer" + assert t_check(self._act_pre_layer), "act_preprocess should be nn.Layer" assert t_check( - self._act_preprocess), "act_preprocess should be nn.Layer" - assert t_check( - self._weight_quantize), "weight_quantize should be nn.Layer" - assert t_check(self._act_quantize), "act_quantize should be nn.Layer" + self._weight_quant_layer), "weight_quantize should be nn.Layer" + assert t_check(self._act_quant_layer), "act_quantize should be nn.Layer" quant_type = { 'abs_max', 'moving_average_abs_max', 'channel_wise_abs_max' @@ -230,8 +228,8 @@ def _get_quantized_counterpart(self, layer): quantized_layer = quant_nn.__dict__[quantized_counterpart[index]]( layer, self._weight_bits, self._activation_bits, self._moving_rate, self._weight_quantize_type, self._activation_quantize_type, - self._weight_preprocess, self._act_preprocess, - self._weight_quantize, self._act_quantize) + self._weight_pre_layer, self._act_pre_layer, + self._weight_quant_layer, self._act_quant_layer) return quantized_layer diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py b/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py index 697367e1f2f79..79138febd0ce8 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py @@ -333,10 +333,10 @@ def __init__(self, moving_rate=0.9, weight_quantize_type='abs_max', activation_quantize_type='abs_max', - weight_preprocess=None, - act_preprocess=None, - weight_quantize=None, - act_quantize=None): + weight_pre_layer=None, + act_pre_layer=None, + weight_quant_layer=None, + act_quant_layer=None): super(QuantizedConv2D, self).__init__() # For Conv2D self._groups = getattr(layer, '_groups') @@ -352,8 
+352,8 @@ def __init__(self, # For FakeQuant self._conv2d_quant_axis = 0 - if weight_quantize is not None: - self._fake_quant_weight = weight_quantize() + if weight_quant_layer is not None: + self._fake_quant_weight = weight_quant_layer() else: self._fake_quant_weight = _get_fake_quant_type( weight_quantize_type, @@ -364,8 +364,8 @@ def __init__(self, quant_on_weight=True, channel_num=self.weight.shape[self._conv2d_quant_axis], quant_axis=self._conv2d_quant_axis) - if act_quantize is not None: - self._fake_quant_input = act_quantize() + if act_quant_layer is not None: + self._fake_quant_input = act_quant_layer() else: self._fake_quant_input = _get_fake_quant_type( activation_quantize_type, @@ -375,20 +375,18 @@ def __init__(self, dtype=self._dtype, quant_on_weight=False) - self.do_act_preprocess = True if act_preprocess is not None else False - self.do_weight_preprocess = True if weight_preprocess is not None else False - if self.do_act_preprocess: - self._act_preprocess = act_preprocess() - if self.do_weight_preprocess: - self._weight_preprocess = weight_preprocess() + self._act_preprocess = act_pre_layer( + ) if act_pre_layer is not None else None + self._weight_preprocess = weight_pre_layer( + ) if weight_pre_layer is not None else None def forward(self, input): - if self.do_act_preprocess: + if self._act_preprocess is not None: input = self._act_preprocess(input) quant_input = self._fake_quant_input(input) weight = self.weight - if self.do_weight_preprocess: + if self._weight_preprocess is not None: weight = self._weight_preprocess(self.weight) quant_weight = self._fake_quant_weight(weight) @@ -453,10 +451,10 @@ def __init__(self, moving_rate=0.9, weight_quantize_type='abs_max', activation_quantize_type='abs_max', - weight_preprocess=None, - act_preprocess=None, - weight_quantize=None, - act_quantize=None): + weight_pre_layer=None, + act_pre_layer=None, + weight_quant_layer=None, + act_quant_layer=None): super(QuantizedLinear, self).__init__() # For Linear self._act = getattr(layer, '_act') @@ -466,8 +464,8 @@ def __init__(self, # For FakeQuant self._linear_quant_axis = 1 - if weight_quantize is not None: - self._fake_quant_weight = weight_quantize() + if weight_quant_layer is not None: + self._fake_quant_weight = weight_quant_layer() else: self._fake_quant_weight = _get_fake_quant_type( weight_quantize_type, @@ -479,8 +477,8 @@ def __init__(self, channel_num=self.weight.shape[self._linear_quant_axis], quant_axis=self._linear_quant_axis) - if act_quantize is not None: - self._fake_quant_input = act_quantize() + if act_quant_layer is not None: + self._fake_quant_input = act_quant_layer() else: self._fake_quant_input = _get_fake_quant_type( activation_quantize_type, @@ -490,20 +488,18 @@ def __init__(self, dtype=self._dtype, quant_on_weight=False) - self.do_act_preprocess = True if act_preprocess is not None else False - self.do_weight_preprocess = True if weight_preprocess is not None else False - if self.do_act_preprocess: - self._act_preprocess = act_preprocess() - if self.do_weight_preprocess: - self._weight_preprocess = weight_preprocess() + self._act_preprocess = act_pre_layer( + ) if act_pre_layer is not None else None + self._weight_preprocess = weight_pre_layer( + ) if weight_pre_layer is not None else None def forward(self, input): - if self.do_act_preprocess: + if self._act_preprocess is not None: input = self._act_preprocess(input) quant_input = self._fake_quant_input(input) weight = self.weight - if self.do_weight_preprocess: + if self._weight_preprocess is not None: weight 
= self._weight_preprocess(self.weight) quant_weight = self._fake_quant_weight(weight) From 609026b2cfe0a2efae1ceb63153c05bde45be7a6 Mon Sep 17 00:00:00 2001 From: baiyfbupt Date: Tue, 17 Nov 2020 18:52:17 +0800 Subject: [PATCH 5/5] code clean --- .../tests/test_imperative_qat_user_defined.py | 65 +++++++++++++++---- 1 file changed, 53 insertions(+), 12 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_user_defined.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_user_defined.py index c0c2021490364..29b69bbe0f8ea 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_user_defined.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_user_defined.py @@ -52,6 +52,54 @@ def forward(self, x): return x +class CustomQAT(nn.Layer): + def __init__(self): + super(CustomQAT, self).__init__() + attr = paddle.ParamAttr( + initializer=paddle.nn.initializer.Constant(value=1.0)) + self.u_param = self.create_parameter( + shape=[1], attr=attr, dtype='float32') + self.l_param = self.create_parameter( + shape=[1], attr=attr, dtype='float32') + self.alpha_param = self.create_parameter( + shape=[1], attr=attr, dtype='float32') + self.upper = self.create_parameter( + shape=[1], attr=attr, dtype='float32') + self.upper.stop_gradient = True + self.lower = self.create_parameter( + shape=[1], attr=attr, dtype='float32') + self.lower.stop_gradient = True + + def forward(self, x): + def clip(x, upper, lower): + x = x + paddle.nn.functional.relu(lower - x) + x = x - paddle.nn.functional.relu(x - upper) + return x + + def phi_function(x, mi, alpha, delta): + s = 1 / (1 - alpha) + k = paddle.log(2 / alpha - 1) * (1 / delta) + x = (paddle.tanh((x - mi) * k)) * s + return x + + def dequantize(x, lower_bound, delta, interval): + x = ((x + 1) / 2 + interval) * delta + lower_bound + return x + + bit = 8 + bit_range = 2**bit - 1 + + paddle.assign(self.upper * 0.9 + self.u_param * 0.1, self.upper) + paddle.assign(self.lower * 0.9 + self.l_param * 0.1, self.lower) + x = clip(x, self.upper, self.lower) + delta = (self.upper - self.lower) / bit_range + interval = (x - self.lower) / delta + mi = (interval + 0.5) * delta + self.l_param + x = phi_function(x, mi, self.alpha_param, delta) + x = dequantize(x, self.l_param, delta, interval) + return x + + class ImperativeLenet(paddle.nn.Layer): def __init__(self, num_classes=10, classifier_activation='softmax'): super(ImperativeLenet, self).__init__() @@ -93,7 +141,7 @@ def forward(self, inputs): class TestUserDefinedActPreprocess(unittest.TestCase): def setUp(self): _logger.info("test act_preprocess") - self.imperative_qat = ImperativeQuantAware(act_preprocess=PACT) + self.imperative_qat = ImperativeQuantAware(act_preprocess_layer=PACT) def test_quant_aware_training(self): imperative_qat = self.imperative_qat @@ -176,31 +224,24 @@ def test(model): train(lenet) test(lenet) - paddle.jit.save( - layer=lenet, - path="./dynamic_quant_user_defined/model", - input_spec=[ - paddle.static.InputSpec( - shape=[None, 1, 28, 28], dtype='float32') - ]) - class TestUserDefinedWeightPreprocess(TestUserDefinedActPreprocess): def setUp(self): _logger.info("test weight_preprocess") - self.imperative_qat = ImperativeQuantAware(weight_preprocess=PACT) + self.imperative_qat = ImperativeQuantAware(weight_preprocess_layer=PACT) class TestUserDefinedActQuantize(TestUserDefinedActPreprocess): def setUp(self): _logger.info("test act_quantize") - self.imperative_qat = ImperativeQuantAware(act_quantize=PACT) + self.imperative_qat = 
ImperativeQuantAware(act_quantize_layer=CustomQAT) class TestUserDefinedWeightQuantize(TestUserDefinedActPreprocess): def setUp(self): _logger.info("test weight_quantize") - self.imperative_qat = ImperativeQuantAware(weight_quantize=PACT) + self.imperative_qat = ImperativeQuantAware( + weight_quantize_layer=CustomQAT) if __name__ == '__main__':
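
A minimal usage sketch of the user-defined hooks this series adds, assuming the final parameter names from PATCH 4/5 (`weight_preprocess_layer`, `act_preprocess_layer`, `weight_quantize_layer`, `act_quantize_layer`). `ClipPreprocess` and `SimpleNet` below are illustrative stand-ins, not code from the patches; only the `ImperativeQuantAware` keyword argument and the `quantize()` call mirror what the new unit test exercises with its `PACT` and `CustomQAT` layers:

    # Sketch only: ClipPreprocess and SimpleNet are made-up examples; the hook names
    # follow the final state of this patch series (PATCH 4/5).
    import paddle
    import paddle.nn as nn
    from paddle.fluid.dygraph import Conv2D, Linear
    from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware


    class ClipPreprocess(nn.Layer):
        """Toy activation preprocess: bound the tensor before the fake-quant op."""

        def forward(self, x):
            return paddle.clip(x, min=-6.0, max=6.0)


    class SimpleNet(nn.Layer):
        def __init__(self):
            super(SimpleNet, self).__init__()
            self.conv = Conv2D(num_channels=1, num_filters=6, filter_size=3)
            # 28x28 input, 3x3 kernel, no padding -> 6 x 26 x 26 feature map.
            self.fc = Linear(input_dim=6 * 26 * 26, output_dim=10)

        def forward(self, x):
            x = self.conv(x)
            x = paddle.flatten(x, 1)
            return self.fc(x)


    # Pass the Layer *class*, not an instance; each quantized layer builds its own copy.
    qat = ImperativeQuantAware(act_preprocess_layer=ClipPreprocess)
    model = SimpleNet()
    qat.quantize(model)  # Conv2D/Linear children are swapped for quantized counterparts in place.
    out = model(paddle.rand([1, 1, 28, 28]))

Passing the class rather than an instance is what the `issubclass(method, dygraph.layers.Layer)` check in qat.py expects; QuantizedConv2D and QuantizedLinear instantiate the hook per layer, so learnable state such as PACT's alpha parameter is kept separately for each quantized layer.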