From e70dc0be6ae10ff9180bfcf737763e941ad5d674 Mon Sep 17 00:00:00 2001 From: pengjuncai <13006307475@163.com> Date: Mon, 5 Jul 2021 07:27:29 +0000 Subject: [PATCH 1/3] PTQ save quantized model --- .../slim/quantization/imperative/ptq.py | 165 ++++++++++++++++-- .../quantization/imperative/ptq_config.py | 3 +- .../slim/quantization/imperative/ptq_hooks.py | 1 - .../quantization/imperative/ptq_quantizer.py | 11 +- .../quantization/imperative/ptq_registry.py | 23 +++ .../slim/quantization/imperative/qat.py | 8 +- 6 files changed, 189 insertions(+), 22 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py index 13ca44d7f2a11..75d8509b57413 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py @@ -14,10 +14,13 @@ import logging import copy +import os import numpy as np import paddle +import paddle.nn.quant.quant_layers as quant_layers from paddle.fluid.log_helper import get_logger +from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX from . import utils from . import ptq_hooks @@ -70,19 +73,103 @@ def quantize(self, model, inplace=False): for name, layer in new_model.named_sublayers(): if PTQRegistry.is_supported_layer(layer) \ - and utils.is_leaf_layer(layer): + and utils.is_leaf_layer(layer) \ + and not self._is_skip_layer(layer): + # Add quant config quant_config = copy.deepcopy(self._quant_config) layer._quant_config = quant_config + # register hook hook = ptq_hooks.quant_forward_post_hook quant_hook_handle = layer.register_forward_post_hook(hook) quant_config.quant_hook_handle = quant_hook_handle layer._forward_post_hooks.move_to_end( quant_hook_handle._hook_id, last=False) + # TODO(jc): fake quantize the weights + return new_model - def convert(self, model): + def save_quantized_model(self, model, path, input_spec=None, **config): + """ + Save the quantized model for the inference. + + Args: + model (Layer): The model to be saved. + path (str): The path prefix to save model. The format is + ``dirname/file_prefix`` or ``file_prefix``. + input_spec (list[InputSpec|Tensor], optional): Describes the input + of the saved model's forward method, which can be described by + InputSpec or example Tensor. If None, all input variables of + the original Layer's forward method would be the inputs of + the saved model. Default None. + **configs (dict, optional): Other save configuration options for + compatibility. We do not recommend using these configurations, + they may be removed in the future. If not necessary, DO NOT use + them. Default None. + The following options are currently supported: + (1) output_spec (list[Tensor]): Selects the output targets of + the saved model. By default, all return variables of original + Layer's forward method are kept as the output of the saved model. + If the provided ``output_spec`` list is not all output variables, + the saved model will be pruned according to the given + ``output_spec`` list. + + Returns: + None + """ + + assert isinstance(model, paddle.nn.Layer), \ + "The model must be the instance of paddle.nn.Layer." 
+ + model = self._post_process_scales(model) + model = self._wrap_layers(model) + + paddle.jit.save(layer=model, path=path, input_spec=input_spec, **config) + + is_dynamic_mode = False + if paddle.in_dynamic_mode(): + is_dynamic_mode = True + paddle.enable_static() + + place = paddle.CPUPlace() + scope = paddle.static.global_scope() + exe = paddle.static.Executor(place) + + dirname = os.path.dirname(path) + basename = os.path.basename(path) + model_filename = basename + INFER_MODEL_SUFFIX + params_filename = basename + INFER_PARAMS_SUFFIX + + [infer_program, feed_target_names, fetch_targets] = ( + paddle.fluid.io.load_inference_model( + dirname=dirname, + executor=exe, + model_filename=model_filename, + params_filename=params_filename)) + + # TODO(jc): + # process the first moving_average_abs_max_scale layer + # fuse conv + bn + # propagate the threshold + # support skip_quant + + # self._save_output_scale(infer_program, scope) + # self._set_skip_quant_attr(infer_program) + + paddle.fluid.io.save_inference_model( + dirname=dirname, + feeded_var_names=feed_target_names, + target_vars=fetch_targets, + executor=exe, + main_program=infer_program.clone(), + model_filename=model_filename, + params_filename=params_filename) + + if is_dynamic_mode: + paddle.disable_static() + + def _post_process_scales(self, model): """ Process the scales and remove the hooks. @@ -94,23 +181,79 @@ def convert(self, model): assert isinstance(model, paddle.nn.Layer), \ "The input model must be the instance of paddle.nn.Layer." + # remove hook and calculate thresholds for name, sub_layer in model.named_sublayers(): - if PTQRegistry.is_supported_layer(sub_layer) \ - and utils.is_leaf_layer(sub_layer): - - assert hasattr(sub_layer, "_quant_config") + if self._is_quant_layer(sub_layer): quant_config = sub_layer._quant_config quant_config.quant_hook_handle.remove() - quant_config.in_act_quantizer.cal_thresholds() quant_config.out_act_quantizer.cal_thresholds() - # get weight thresholds - if isinstance(sub_layer, tuple(utils.fake_quant_input_layers)): + if PTQRegistry.is_simulated_quant_layer(sub_layer): weights = (sub_layer.weight, ) quant_config.wt_quantizer.sample_data(sub_layer, weights) + quant_config.wt_quantizer.cal_thresholds() + + # save output activation and weight thresholds + for name, sub_layer in model.named_sublayers(): + if self._is_quant_layer(sub_layer): + quant_config = sub_layer._quant_config + layer_info = PTQRegistry.layer_info(sub_layer) + + output_names = layer_info.output_names + output_thresholds = quant_config.out_act_quantizer.thresholds + assert len(output_names) == 1 + assert len(output_thresholds) == 1 + save_name = output_names[0] + str(0) + "_threshold" + sub_layer._set_op_attrs({save_name: output_thresholds[0]}) + sub_layer._set_op_attrs({"out_threshold": output_thresholds[0]}) - # TODO (jc): - # save input activation threshold and quant bits + if PTQRegistry.is_simulated_quant_layer(sub_layer): + weight_names = layer_info.weight_names + weight_thresholds = quant_config.wt_quantizer.thresholds + assert len(weight_names) == 1 + assert len(weight_thresholds) == 1 + save_name = weight_names[0] + str(0) + "_threshold" + sub_layer._set_op_attrs({save_name: weight_thresholds[0]}) return model + + def _wrap_layers(self, model): + """ + Replace conv2d and linear with the quantized layers, and save + thresholds into the fake layers. + Args: + model(paddle.nn.Layer): The model to be quantized. + Returns: + modified_model(paddle.nn.Layer): The modified model. 
+ """ + assert isinstance(model, paddle.nn.Layer), \ + "The input model must be the instance of paddle.nn.Layer." + + # wrap conv2d and linear, save thresholds and quant bits to fake ops + for name, sub_layer in model.named_sublayers(): + if self._is_quant_layer(sub_layer) \ + and PTQRegistry.is_simulated_quant_layer(sub_layer): + parent_layer, sub_name = \ + utils.find_parent_layer_and_sub_name(model, name) + + quant_layer_name = None + for key, value in utils.layer_name_map.items(): + if isinstance(sub_layer, value): + quant_layer_name = 'Quantized' + key + break + assert quant_layer_name is not None + + # TODO(jc): + # quant_layer = quant_layers.__dict__[quant_layer_name](sub_layer, **self._kwargs) + # setattr(parent_layer, sub_name, quant_layer) + + return model + + @staticmethod + def _is_skip_layer(layer): + return hasattr(layer, "skip_quant") and layer.skip_quant == True + + @staticmethod + def _is_quant_layer(layer): + return hasattr(layer, "_quant_config") diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_config.py b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_config.py index 4db311567a734..bbf6c41c06fbd 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_config.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_config.py @@ -43,7 +43,8 @@ def __init__(self, activation_quantizer, weight_quantizer): assert isinstance(activation_quantizer, BaseQuantizer) assert isinstance(weight_quantizer, BaseQuantizer) - self.in_act_quantizer = copy.deepcopy(activation_quantizer) + assert activation_quantizer.quant_bits == weight_quantizer.quant_bits + self.out_act_quantizer = copy.deepcopy(activation_quantizer) self.wt_quantizer = copy.deepcopy(weight_quantizer) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_hooks.py b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_hooks.py index 82a277ad28e3b..001d4eafd6eb5 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_hooks.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_hooks.py @@ -24,5 +24,4 @@ def quant_forward_post_hook(layer, inputs, outputs): """ assert hasattr(layer, '_quant_config'), \ "The layer should have _quant_config attr" - layer._quant_config.in_act_quantizer.sample_data(layer, inputs) layer._quant_config.out_act_quantizer.sample_data(layer, (outputs, )) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_quantizer.py b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_quantizer.py index 9999de6bd0fda..68e9c7fe54b1e 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_quantizer.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_quantizer.py @@ -110,6 +110,7 @@ def __init__(self, quant_bits=8): self.quant_bits = quant_bits + self.abs_max_vals = [] self.thresholds = [] @abc.abstractmethod @@ -133,10 +134,10 @@ def sample_data(self, layer, tensors): assert isinstance(tensors, tuple) abs_max_vals = [abs_max_value(t) for t in tensors] - self.thresholds = merge_max_value(self.thresholds, abs_max_vals) + self.abs_max_vals = merge_max_value(self.abs_max_vals, abs_max_vals) def cal_thresholds(self): - pass + self.thresholds = self.abs_max_vals class PerChannelAbsmaxQuantizer(BaseQuantizer): @@ -164,10 +165,11 @@ def sample_data(self, layer, tensors): ] abs_max_vals_list.append(abs_max_vals) - self.thresholds = merge_max_value(self.thresholds, abs_max_vals_list) + self.abs_max_vals = 
merge_max_value(self.abs_max_vals, + abs_max_vals_list) def cal_thresholds(self): - pass + self.thresholds = self.abs_max_vals @six.add_metaclass(abc.ABCMeta) @@ -180,7 +182,6 @@ def __init__(self, quant_bits=8, bins=1024, upsample_bins=64): self.bins = bins self.upsample_bins = upsample_bins - self.abs_max_vals = [] self.hists = [] def sample_data(self, layer, tensors): diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_registry.py b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_registry.py index 973d66303ece9..8fe2bc42b7686 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_registry.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_registry.py @@ -47,6 +47,8 @@ def __init__(self, layer, input_names, weight_names, output_names): LayerInfo(paddle.nn.quant.add, ['X', 'Y'], [], ['Out']), ] +SIMULATED_LAYERS = [paddle.nn.Conv2D, paddle.nn.Linear] + class PTQRegistry(object): """ @@ -69,14 +71,35 @@ def _init(cls): def is_supported_layer(cls, layer): """ Analyze whether the layer supports quantization. + Args: + layer(Layer): The input layer can be a python class or an instance. + Returns: + flag(bool): Whther the layer is supported. """ cls._init() return layer in cls.supported_layers_map or \ isinstance(layer, tuple(cls.supported_layers_map.keys())) + @classmethod + def is_simulated_quant_layer(cls, layer): + """ + Analyze whether the layer is simulated quant layer. + Args: + layer(Layer): The input layer can be a python class or an instance. + Returns: + flag(bool): Whther the layer is supported. + """ + return layer in SIMULATED_LAYERS or \ + isinstance(layer, tuple(SIMULATED_LAYERS)) + + @classmethod def layer_info(cls, layer): """ Get the infomation for the supported layer. + Args: + layer(Layer): The input layer can be a python class or an instance. + Returns: + layer_info(LayerInfo): The layer info of the input layer. """ assert cls.is_supported_layer( layer), "The input layer is not supported." diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index 3b4f9a757437a..b8c0e47e9bbc2 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -379,12 +379,12 @@ def apply(self, model): setattr(parent_layer, sub_name, cur_quant_layer) - def save_quantized_model(self, layer, path, input_spec=None, **config): + def save_quantized_model(self, model, path, input_spec=None, **config): """ Save the quantized model for the inference. Args: - layer (Layer): The Layer to be saved. + model (Layer): The model to be saved. path (str): The path prefix to save model. The format is ``dirname/file_prefix`` or ``file_prefix``. input_spec (list[InputSpec|Tensor], optional): Describes the input @@ -407,10 +407,10 @@ def save_quantized_model(self, layer, path, input_spec=None, **config): Returns: None """ - assert isinstance(layer, dygraph.Layer), \ + assert isinstance(model, dygraph.Layer), \ "The model must be the instance of dygraph.Layer." 
- paddle.jit.save(layer=layer, path=path, input_spec=input_spec, **config) + paddle.jit.save(layer=model, path=path, input_spec=input_spec, **config) is_dynamic_mode = False if paddle.in_dynamic_mode(): From 8c169656d49c25fd75154dae4178e7a67e790d3f Mon Sep 17 00:00:00 2001 From: pengjuncai <13006307475@163.com> Date: Tue, 6 Jul 2021 06:58:30 +0000 Subject: [PATCH 2/3] Wrap simulated layer --- .../slim/quantization/imperative/ptq.py | 190 ++++++++++++++---- .../quantization/imperative/ptq_config.py | 12 +- .../slim/quantization/imperative/ptq_hooks.py | 7 +- .../quantization/imperative/ptq_quantizer.py | 12 +- .../quantization/imperative/ptq_registry.py | 33 ++- .../slim/quantization/imperative/utils.py | 2 +- .../slim/tests/imperative_test_utils.py | 4 +- .../contrib/slim/tests/test_imperative_ptq.py | 56 +++--- 8 files changed, 233 insertions(+), 83 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py index 75d8509b57413..acaebddb85018 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py @@ -25,6 +25,7 @@ from . import utils from . import ptq_hooks from . import ptq_config +from . import ptq_quantizer from .ptq_registry import PTQRegistry __all__ = ['ImperativePTQ'] @@ -56,7 +57,7 @@ def __init__(self, quant_config=ptq_config.default_ptq_config): def quantize(self, model, inplace=False): """ - Add hook to the leaf layer to calculate the threshold of inputs and outputs. + Add quant config and hook to the target layer. Args: model(paddle.nn.Layer): The model to be quantized. @@ -75,8 +76,11 @@ def quantize(self, model, inplace=False): if PTQRegistry.is_supported_layer(layer) \ and utils.is_leaf_layer(layer) \ and not self._is_skip_layer(layer): + # Add quant config quant_config = copy.deepcopy(self._quant_config) + if PTQRegistry.is_simulated_quant_layer(layer): + quant_config.enable_in_act_quantizer = True layer._quant_config = quant_config # register hook @@ -86,8 +90,6 @@ def quantize(self, model, inplace=False): layer._forward_post_hooks.move_to_end( quant_hook_handle._hook_id, last=False) - # TODO(jc): fake quantize the weights - return new_model def save_quantized_model(self, model, path, input_spec=None, **config): @@ -122,8 +124,7 @@ def save_quantized_model(self, model, path, input_spec=None, **config): assert isinstance(model, paddle.nn.Layer), \ "The model must be the instance of paddle.nn.Layer." - model = self._post_process_scales(model) - model = self._wrap_layers(model) + self._convert(model) paddle.jit.save(layer=model, path=path, input_spec=input_spec, **config) @@ -154,7 +155,7 @@ def save_quantized_model(self, model, path, input_spec=None, **config): # propagate the threshold # support skip_quant - # self._save_output_scale(infer_program, scope) + # self._post_process_thresholds(infer_program, scope) # self._set_skip_quant_attr(infer_program) paddle.fluid.io.save_inference_model( @@ -169,24 +170,48 @@ def save_quantized_model(self, model, path, input_spec=None, **config): if is_dynamic_mode: paddle.disable_static() - def _post_process_scales(self, model): + def _convert(self, model): """ - Process the scales and remove the hooks. + Convert the quantized model. Args: - model(paddle.nn.Layer): The model to be quantized. + model(paddle.nn.Layer): The quantized model. + inplace(bool): Whether apply conversion to the input model. + Default: False. 
Returns: - converted_model(paddle.nn.Layer): The converted model. + None + """ + + for name, sub_layer in model.named_sublayers(): + if self._is_quant_layer(sub_layer): + sub_layer._quant_config.quant_hook_handle.remove() + + self._cal_thresholds(model) + + for name, sub_layer in model.named_sublayers(): + if self._is_quant_layer(sub_layer): + self._save_output_thresholds(sub_layer, sub_layer._quant_config) + + self._wrap_simulated_layers(model) + + def _cal_thresholds(self, model): + """ + Calculate the thresholds of inputs and outputs. + + Args: + model(paddle.nn.Layer): The quantized model. + Returns: + None """ assert isinstance(model, paddle.nn.Layer), \ "The input model must be the instance of paddle.nn.Layer." - # remove hook and calculate thresholds for name, sub_layer in model.named_sublayers(): if self._is_quant_layer(sub_layer): quant_config = sub_layer._quant_config - quant_config.quant_hook_handle.remove() + if quant_config.enable_in_act_quantizer: + quant_config.in_act_quantizer.cal_thresholds() quant_config.out_act_quantizer.cal_thresholds() if PTQRegistry.is_simulated_quant_layer(sub_layer): @@ -194,49 +219,97 @@ def _post_process_scales(self, model): quant_config.wt_quantizer.sample_data(sub_layer, weights) quant_config.wt_quantizer.cal_thresholds() - # save output activation and weight thresholds + def _save_thresholds(self, model): + """ + For all layers in the model, save output activation and weight thresholds. + + Args: + model(paddle.nn.Layer): The quantized model. + Returns: + None + """ + assert isinstance(model, paddle.nn.Layer), \ + "The input model must be the instance of paddle.nn.Layer." + for name, sub_layer in model.named_sublayers(): if self._is_quant_layer(sub_layer): - quant_config = sub_layer._quant_config - layer_info = PTQRegistry.layer_info(sub_layer) + self._save_output_thresholds(sub_layer, sub_layer._quant_config) - output_names = layer_info.output_names - output_thresholds = quant_config.out_act_quantizer.thresholds - assert len(output_names) == 1 - assert len(output_thresholds) == 1 - save_name = output_names[0] + str(0) + "_threshold" - sub_layer._set_op_attrs({save_name: output_thresholds[0]}) - sub_layer._set_op_attrs({"out_threshold": output_thresholds[0]}) + def _save_output_thresholds(self, sub_layer, quant_config): + """ + Save the output thresholds to the layer. - if PTQRegistry.is_simulated_quant_layer(sub_layer): - weight_names = layer_info.weight_names - weight_thresholds = quant_config.wt_quantizer.thresholds - assert len(weight_names) == 1 - assert len(weight_thresholds) == 1 - save_name = weight_names[0] + str(0) + "_threshold" - sub_layer._set_op_attrs({save_name: weight_thresholds[0]}) + Args: + sub_layer(paddle.nn.Layer): The quantized layer. + quant_config(PTQConfig): the quant config for the layer. + Returns: + None + """ + assert isinstance(sub_layer, paddle.nn.Layer), \ + "The input model must be the instance of paddle.nn.Layer." + + layer_info = PTQRegistry.layer_info(sub_layer) + + output_names = layer_info.output_names + output_thresholds = quant_config.out_act_quantizer.thresholds + assert len(output_names) == 1 + assert len(output_thresholds) == 1 + save_name = output_names[0] + str(0) + "_threshold" + sub_layer._set_op_attrs({save_name: output_thresholds[0]}) + sub_layer._set_op_attrs({"out_threshold": output_thresholds[0]}) + + def _save_input_thresholds(self, sub_layer, quant_config): + """ + Save the input thresholds to the layer. + + Args: + sub_layer(paddle.nn.Layer): The quantized layer. 
+ quant_config(PTQConfig): the quant config for the layer. + Returns: + None + """ + assert isinstance(sub_layer, paddle.nn.Layer), \ + "The input model must be the instance of paddle.nn.Layer." + assert quant_config.enable_in_act_quantizer == True - return model + layer_info = PTQRegistry.layer_info(sub_layer) - def _wrap_layers(self, model): + input_names = layer_info.input_names + input_thresholds = quant_config.in_act_quantizer.thresholds + assert len(input_names) == 1 + assert len(input_thresholds) == 1 + save_name = input_names[0] + str(0) + "_threshold" + sub_layer._set_op_attrs({save_name: input_thresholds[0]}) + + weight_names = layer_info.weight_names + weight_thresholds = quant_config.wt_quantizer.thresholds + assert len(weight_names) == 1 + assert len(weight_thresholds) == 1 + save_name = weight_names[0] + str(0) + "_threshold" + sub_layer._set_op_attrs({save_name: weight_thresholds[0]}) + + def _wrap_simulated_layers(self, model): """ Replace conv2d and linear with the quantized layers, and save thresholds into the fake layers. Args: model(paddle.nn.Layer): The model to be quantized. Returns: - modified_model(paddle.nn.Layer): The modified model. + None """ assert isinstance(model, paddle.nn.Layer), \ "The input model must be the instance of paddle.nn.Layer." - # wrap conv2d and linear, save thresholds and quant bits to fake ops for name, sub_layer in model.named_sublayers(): if self._is_quant_layer(sub_layer) \ and PTQRegistry.is_simulated_quant_layer(sub_layer): - parent_layer, sub_name = \ - utils.find_parent_layer_and_sub_name(model, name) + quant_config = sub_layer._quant_config + assert quant_config.enable_in_act_quantizer == True + wt_quantizer = quant_config.wt_quantizer + in_act_quantizer = quant_config.in_act_quantizer + + # create layer quant_layer_name = None for key, value in utils.layer_name_map.items(): if isinstance(sub_layer, value): @@ -244,11 +317,48 @@ def _wrap_layers(self, model): break assert quant_layer_name is not None - # TODO(jc): - # quant_layer = quant_layers.__dict__[quant_layer_name](sub_layer, **self._kwargs) - # setattr(parent_layer, sub_name, quant_layer) - - return model + if isinstance(wt_quantizer, ptq_quantizer.AbsmaxQuantizer): + weight_quantize_type = "abs_max" + else: + weight_quantize_type = "channel_wise_abs_max" + kwargs = { + "weight_quantize_type": weight_quantize_type, + "activation_quantize_type": "moving_average_abs_max", + "weight_bits": wt_quantizer.quant_bits, + "activation_bits": in_act_quantizer.quant_bits, + } + + quant_layer = quant_layers.__dict__[quant_layer_name](sub_layer, + **kwargs) + + # save the input thresholds + assert hasattr(quant_layer, "_fake_quant_input") + assert hasattr(quant_layer._fake_quant_input, "_scale") + assert len(in_act_quantizer.thresholds) == 1 + input_threshold = np.array( + [in_act_quantizer.thresholds[0]], dtype=np.float32) + quant_layer._fake_quant_input._scale.set_value(input_threshold) + + assert hasattr(quant_layer, "_fake_quant_weight") + assert hasattr(quant_layer._fake_quant_weight, "_scale") + assert len(wt_quantizer.thresholds) == 1 + weight_threshold = wt_quantizer.thresholds[0] + if isinstance(weight_threshold, list): + weight_threshold = np.array( + weight_threshold, dtype=np.float32) + else: + weight_threshold = np.array( + [weight_threshold], dtype=np.float32) + quant_layer._fake_quant_weight._scale.set_value( + weight_threshold) + + # save the output thresholds + self._save_output_thresholds(quant_layer, quant_config) + + # replace the layer + parent_layer, sub_name = \ + 
utils.find_parent_layer_and_sub_name(model, name) + setattr(parent_layer, sub_name, quant_layer) @staticmethod def _is_skip_layer(layer): diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_config.py b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_config.py index bbf6c41c06fbd..1d089b32181d0 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_config.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_config.py @@ -39,16 +39,18 @@ def __init__(self, activation_quantizer, weight_quantizer): It should be the instance of BaseQuantizer. """ super(PTQConfig, self).__init__() + assert isinstance(activation_quantizer, tuple(SUPPORT_ACT_QUANTIZERS)) + assert isinstance(weight_quantizer, tuple(SUPPORT_WT_QUANTIZERS)) - assert isinstance(activation_quantizer, BaseQuantizer) - assert isinstance(weight_quantizer, BaseQuantizer) - - assert activation_quantizer.quant_bits == weight_quantizer.quant_bits - + self.in_act_quantizer = copy.deepcopy(activation_quantizer) self.out_act_quantizer = copy.deepcopy(activation_quantizer) self.wt_quantizer = copy.deepcopy(weight_quantizer) self.quant_hook_handle = None + # In order to wrap simulated layers, use in_act_quantizer + # to calculate the input thresholds for conv2d, linear and etc. + self.enable_in_act_quantizer = False + default_ptq_config = PTQConfig(AbsmaxQuantizer(), AbsmaxQuantizer()) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_hooks.py b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_hooks.py index 001d4eafd6eb5..41c9b07195aef 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_hooks.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_hooks.py @@ -16,6 +16,7 @@ import math import numpy as np from . 
import ptq_config +from .ptq_registry import PTQRegistry def quant_forward_post_hook(layer, inputs, outputs): @@ -24,4 +25,8 @@ def quant_forward_post_hook(layer, inputs, outputs): """ assert hasattr(layer, '_quant_config'), \ "The layer should have _quant_config attr" - layer._quant_config.out_act_quantizer.sample_data(layer, (outputs, )) + + qc = layer._quant_config + if qc.enable_in_act_quantizer: + qc.in_act_quantizer.sample_data(layer, inputs) + qc.out_act_quantizer.sample_data(layer, (outputs, )) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_quantizer.py b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_quantizer.py index 68e9c7fe54b1e..63b3578871710 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_quantizer.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_quantizer.py @@ -24,11 +24,9 @@ from ..cal_kl_threshold import cal_kl_threshold __all__ = [ - 'BaseQuantizer', - 'AbsmaxQuantizer', - 'PerChannelAbsmaxQuantizer', - 'KLQuantizer', - 'HistQuantizer', + 'BaseQuantizer', 'AbsmaxQuantizer', 'PerChannelAbsmaxQuantizer', + 'KLQuantizer', 'HistQuantizer', 'SUPPORT_ACT_QUANTIZERS', + 'SUPPORT_WT_QUANTIZERS' ] @@ -263,3 +261,7 @@ def cal_thresholds(self): bin_width = abs_max_val / hist.shape[0] threshold = cal_kl_threshold(hist, bin_width, self.quant_bits) self.thresholds.append(threshold) + + +SUPPORT_ACT_QUANTIZERS = [AbsmaxQuantizer, HistQuantizer, KLQuantizer] +SUPPORT_WT_QUANTIZERS = [AbsmaxQuantizer, PerChannelAbsmaxQuantizer] diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_registry.py b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_registry.py index 8fe2bc42b7686..a6b8033bc78c9 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_registry.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_registry.py @@ -47,6 +47,13 @@ def __init__(self, layer, input_names, weight_names, output_names): LayerInfo(paddle.nn.quant.add, ['X', 'Y'], [], ['Out']), ] +QUANT_LAYERS_INFO = [ + LayerInfo(paddle.nn.quant.quant_layers.QuantizedConv2D, ['Input'], + ['Filter'], ['Output']), + LayerInfo(paddle.nn.quant.quant_layers.QuantizedLinear, ['X'], ['Y'], + ['Out']), +] + SIMULATED_LAYERS = [paddle.nn.Conv2D, paddle.nn.Linear] @@ -55,6 +62,7 @@ class PTQRegistry(object): Register the supported layers for PTQ and provide layers info. """ supported_layers_map = {} + registered_layers_map = {} is_inited = False def __init__(self): @@ -65,6 +73,10 @@ def _init(cls): if not cls.is_inited: for layer_info in PTQ_LAYERS_INFO: cls.supported_layers_map[layer_info.layer] = layer_info + + all_layers_info = PTQ_LAYERS_INFO + QUANT_LAYERS_INFO + for layer_info in all_layers_info: + cls.registered_layers_map[layer_info.layer] = layer_info cls.is_inited = True @classmethod @@ -80,6 +92,19 @@ def is_supported_layer(cls, layer): return layer in cls.supported_layers_map or \ isinstance(layer, tuple(cls.supported_layers_map.keys())) + @classmethod + def is_registered_layer(cls, layer): + """ + Analyze whether the layer is register layer_info. + Args: + layer(Layer): The input layer can be a python class or an instance. + Returns: + flag(bool): Wether the layer is register layer_info. 
+ """ + cls._init() + return layer in cls.registered_layers_map or \ + isinstance(layer, tuple(cls.registered_layers_map.keys())) + @classmethod def is_simulated_quant_layer(cls, layer): """ @@ -95,15 +120,15 @@ def is_simulated_quant_layer(cls, layer): @classmethod def layer_info(cls, layer): """ - Get the infomation for the supported layer. + Get the infomation for the layer. Args: layer(Layer): The input layer can be a python class or an instance. Returns: layer_info(LayerInfo): The layer info of the input layer. """ - assert cls.is_supported_layer( - layer), "The input layer is not supported." + assert cls.is_registered_layer(layer), \ + "The input layer is not register." - for layer_key, layer_info in cls.supported_layers_map.items(): + for layer_key, layer_info in cls.registered_layers_map.items(): if layer == layer_key or isinstance(layer, layer_key): return layer_info diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py index cae26a6dbd307..706a904f690dc 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py @@ -69,7 +69,7 @@ ] # The weight format of these layers is Cin * Cout * H * W -spec_channel_axis_layers = [paddle.nn.Conv2D, paddle.nn.Conv2DTranspose] +spec_channel_axis_layers = [paddle.nn.Conv2DTranspose, paddle.nn.Linear] weight_op_types = [ "conv2d", "depthwise_conv2d", "matmul", "conv2d_transpose", diff --git a/python/paddle/fluid/contrib/slim/tests/imperative_test_utils.py b/python/paddle/fluid/contrib/slim/tests/imperative_test_utils.py index cc26f6a88f2e0..5c91f01d0bdda 100644 --- a/python/paddle/fluid/contrib/slim/tests/imperative_test_utils.py +++ b/python/paddle/fluid/contrib/slim/tests/imperative_test_utils.py @@ -128,9 +128,11 @@ def __init__(self, num_classes=10): bias_attr=fc_b3_attr), Softmax()) self.add = paddle.nn.quant.add() + self.quant_stub = paddle.nn.quant.QuantStub() def forward(self, inputs): - x = self.features(inputs) + x = self.quant_stub(inputs) + x = self.features(x) x = fluid.layers.flatten(x, 1) x = self.add(x, paddle.to_tensor(0.0)) # For CI diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py index 236e4a823d7f2..65aa33eb36030 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py @@ -20,6 +20,7 @@ import shutil import time import unittest +import copy import logging import paddle @@ -59,7 +60,8 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): try: - shutil.rmtree(cls.root_path) + pass + # shutil.rmtree(cls.root_path) except Exception as e: print("Failed to delete {} due to {}".format(cls.root_path, str(e))) @@ -86,6 +88,7 @@ def set_vars(self): self.batch_size = 10 self.eval_acc_top1 = 0.99 + # the input, output and weight thresholds of quantized op self.gt_thresholds = { 'conv2d_0': [[1.0], [0.37673383951187134], [0.10933732241392136]], 'batch_norm2d_0': [[0.37673383951187134], [0.44249194860458374]], @@ -167,17 +170,10 @@ def check_thresholds(self, model): assert hasattr(layer, '_quant_config') quant_config = layer._quant_config - in_val = quant_config.in_act_quantizer.thresholds out_val = quant_config.out_act_quantizer.thresholds wt_val = quant_config.wt_quantizer.thresholds check_num += 1 - self.assertTrue( - np.allclose( - ref_val[0], in_val, atol=1e-3), - "%s | The 
thresholds(%s) is different " - "from the ground truth(%s)." % - (layer_name, str(in_val), str(ref_val[0]))) self.assertTrue( np.allclose( ref_val[1], out_val, atol=1e-3), @@ -208,30 +204,38 @@ def test_ptq(self): model_state_dict = paddle.load(params_path) model.set_state_dict(model_state_dict) + # Quantize, calibrate and save quant_model = self.ptq.quantize(model) - acc_top1 = self.model_test(quant_model, self.batch_num, - self.batch_size) - print('acc_top1: %s' % acc_top1) - self.assertTrue( - acc_top1 > self.eval_acc_top1, - msg="The test acc {%f} is less than {%f}." % - (acc_top1, self.eval_acc_top1)) + before_acc_top1 = self.model_test(quant_model, self.batch_num, + self.batch_size) - final_model = self.ptq.convert(quant_model) + #self.check_thresholds(final_model) - self.check_thresholds(final_model) + input_spec = [ + paddle.static.InputSpec( + shape=[None, 1, 28, 28], dtype='float32') + ] + self.ptq.save_quantized_model( + model=quant_model, path=self.save_path, input_spec=input_spec) + print('Quantized model saved in {%s}' % self.save_path) - input_spec = [ - paddle.static.InputSpec( - shape=[None, 1, 28, 28], dtype='float32') - ] - paddle.jit.save( - layer=final_model, path=self.save_path, input_spec=input_spec) - print('Quantized model saved in {%s}' % self.save_path) + # Check + after_acc_top1 = self.model_test(quant_model, self.batch_num, + self.batch_size) + print('Before converted acc_top1: %s' % before_acc_top1) + print('After converted acc_top1: %s' % after_acc_top1) + + self.assertTrue( + after_acc_top1 > self.eval_acc_top1, + msg="The test acc {%f} is less than {%f}." % + (after_acc_top1, self.eval_acc_top1)) + self.assertTrue( + np.allclose(before_acc_top1, after_acc_top1), + msg='The acc is lower after converting model.') - end_time = time.time() - print("total time: %ss" % (end_time - start_time)) + end_time = time.time() + print("total time: %ss" % (end_time - start_time)) class TestImperativePTQHist(TestImperativePTQ): From bbda7db3f45ee96c61745aae3366ce4ff7164a82 Mon Sep 17 00:00:00 2001 From: pengjuncai <13006307475@163.com> Date: Tue, 6 Jul 2021 11:36:30 +0000 Subject: [PATCH 3/3] post process the inference model --- .../slim/quantization/imperative/ptq.py | 162 +++++++++++------ .../slim/quantization/imperative/utils.py | 11 ++ .../contrib/slim/tests/test_imperative_ptq.py | 163 ++++++++---------- 3 files changed, 188 insertions(+), 148 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py index acaebddb85018..b85a4b6637545 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py @@ -94,7 +94,9 @@ def quantize(self, model, inplace=False): def save_quantized_model(self, model, path, input_spec=None, **config): """ - Save the quantized model for the inference. + 1. Convert the quantized model + 2. Call jit.save to save the inference model + 3. Load and postprocess the inference model. Args: model (Layer): The model to be saved. @@ -124,10 +126,12 @@ def save_quantized_model(self, model, path, input_spec=None, **config): assert isinstance(model, paddle.nn.Layer), \ "The model must be the instance of paddle.nn.Layer." 
+ # Convert and save dygraph quantized model self._convert(model) paddle.jit.save(layer=model, path=path, input_spec=input_spec, **config) + # Load inference program is_dynamic_mode = False if paddle.in_dynamic_mode(): is_dynamic_mode = True @@ -149,15 +153,12 @@ def save_quantized_model(self, model, path, input_spec=None, **config): model_filename=model_filename, params_filename=params_filename)) - # TODO(jc): - # process the first moving_average_abs_max_scale layer - # fuse conv + bn - # propagate the threshold - # support skip_quant - - # self._post_process_thresholds(infer_program, scope) - # self._set_skip_quant_attr(infer_program) + # Process inference program + self._clean_up(infer_program) + self._gather_input_thresholds(infer_program, scope) + self._remove_scale_op(infer_program) + # Save final program paddle.fluid.io.save_inference_model( dirname=dirname, feeded_var_names=feed_target_names, @@ -219,22 +220,6 @@ def _cal_thresholds(self, model): quant_config.wt_quantizer.sample_data(sub_layer, weights) quant_config.wt_quantizer.cal_thresholds() - def _save_thresholds(self, model): - """ - For all layers in the model, save output activation and weight thresholds. - - Args: - model(paddle.nn.Layer): The quantized model. - Returns: - None - """ - assert isinstance(model, paddle.nn.Layer), \ - "The input model must be the instance of paddle.nn.Layer." - - for name, sub_layer in model.named_sublayers(): - if self._is_quant_layer(sub_layer): - self._save_output_thresholds(sub_layer, sub_layer._quant_config) - def _save_output_thresholds(self, sub_layer, quant_config): """ Save the output thresholds to the layer. @@ -258,36 +243,6 @@ def _save_output_thresholds(self, sub_layer, quant_config): sub_layer._set_op_attrs({save_name: output_thresholds[0]}) sub_layer._set_op_attrs({"out_threshold": output_thresholds[0]}) - def _save_input_thresholds(self, sub_layer, quant_config): - """ - Save the input thresholds to the layer. - - Args: - sub_layer(paddle.nn.Layer): The quantized layer. - quant_config(PTQConfig): the quant config for the layer. - Returns: - None - """ - assert isinstance(sub_layer, paddle.nn.Layer), \ - "The input model must be the instance of paddle.nn.Layer." - assert quant_config.enable_in_act_quantizer == True - - layer_info = PTQRegistry.layer_info(sub_layer) - - input_names = layer_info.input_names - input_thresholds = quant_config.in_act_quantizer.thresholds - assert len(input_names) == 1 - assert len(input_thresholds) == 1 - save_name = input_names[0] + str(0) + "_threshold" - sub_layer._set_op_attrs({save_name: input_thresholds[0]}) - - weight_names = layer_info.weight_names - weight_thresholds = quant_config.wt_quantizer.thresholds - assert len(weight_names) == 1 - assert len(weight_thresholds) == 1 - save_name = weight_names[0] + str(0) + "_threshold" - sub_layer._set_op_attrs({save_name: weight_thresholds[0]}) - def _wrap_simulated_layers(self, model): """ Replace conv2d and linear with the quantized layers, and save @@ -360,6 +315,103 @@ def _wrap_simulated_layers(self, model): utils.find_parent_layer_and_sub_name(model, name) setattr(parent_layer, sub_name, quant_layer) + def _gather_input_thresholds(self, program, scope): + """ + Get and save input thresholds from the front ops. + + Args: + program(Program): the input infer program. + scope(Scope): the corresponding scope for the program. 
+ Returns: + None + """ + for op in utils.program_all_ops(program): + for in_var_name in utils._get_op_input_var_names(op): + previous_op = utils.find_previous_op(op.block, in_var_name) + if previous_op is None: + continue + + if "quantize_dequantize" in previous_op.type or \ + previous_op.type == "moving_average_abs_max_scale": + attr_name = previous_op.output('OutScale')[0] + in_threshold = utils.load_variable_data(scope, attr_name) + in_threshold = utils.fp_numpy_to_naive(in_threshold) + argname, index = utils._get_input_name_index(op, + in_var_name) + op._set_attr(argname + str(index) + "_threshold", + in_threshold) + else: + for out_var_name in utils._get_op_output_var_names( + previous_op): + if out_var_name != in_var_name: + continue + argname, index = utils._get_output_name_index( + previous_op, out_var_name) + attr_name = argname + str(index) + "_threshold" + if not previous_op.has_attr(attr_name): + continue + threshold = previous_op.attr(attr_name) + + argname, index = utils._get_input_name_index( + op, in_var_name) + attr_name = argname + str(index) + "_threshold" + op._set_attr(attr_name, threshold) + + def _clean_up(self, program): + """ + Remove useless thresholds which are added in jit.save. + + Args: + program(Program): the input infer program. + Returns: + None + """ + + def _helper(op, next_op, old_attr_name, new_attr_name): + if op.has_attr(old_attr_name) and next_op.has_attr(old_attr_name) \ + and op.attr(old_attr_name) == next_op.attr(old_attr_name): + threshold = op.attr(old_attr_name) + op._remove_attr(old_attr_name) + next_op._remove_attr(old_attr_name) + next_op._set_attr(new_attr_name, threshold) + + for op in utils.program_all_ops(program): + if "quantize_dequantize" in op.type: + # remove the thresholds in fake ops + for attr_name in op.attr_names: + if "_threshold" in attr_name: + op._remove_attr(attr_name) + elif op.type in ["conv2d", "matmul"]: + # change the thresholds in conv2d/matmul + eleadd + arg_name = "Output" if op.type == "conv2d" else "Out" + out_var_name = op.output(arg_name)[0] + next_ops = utils.find_next_ops(op.block, out_var_name) + if len(next_ops) > 1 or next_ops[0].type != "elementwise_add": + continue + next_op = next_ops[0] + + argname, index = utils._get_output_name_index(op, out_var_name) + old_attr_name = argname + str(index) + "_threshold" + + argname, index = utils._get_output_name_index( + next_op, next_op.output("Out")[0]) + new_attr_name = argname + str(index) + "_threshold" + + _helper(op, next_op, old_attr_name, new_attr_name) + _helper(op, next_op, "out_threshold", "out_threshold") + + def _remove_scale_op(self, program): + """ + Remove the moving_average_abs_max_scale op. 
+ """ + for op in utils.program_all_ops(program): + if op.type == "moving_average_abs_max_scale": + in_var_name = op.input("X")[0] + out_var_name = op.output("Out")[0] + next_ops = utils.find_next_ops(op.block, out_var_name) + for next_op in next_ops: + next_op._rename_input(out_var_name, in_var_name) + @staticmethod def _is_skip_layer(layer): return hasattr(layer, "skip_quant") and layer.skip_quant == True diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py index 706a904f690dc..a9d52c5a87ad3 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py @@ -139,6 +139,17 @@ def find_parent_layer_and_sub_name(model, name): return parent_layer, sub_name +def program_all_ops(program): + """ + Return all ops for the input program. + """ + all_ops = [] + for block in program.blocks: + for op in block.ops: + all_ops.append(op) + return all_ops + + def is_leaf_layer(layer): """ Whether the layer is leaf layer. diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py index 65aa33eb36030..24ae75456a014 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py @@ -86,7 +86,7 @@ def set_vars(self): self.batch_num = 10 self.batch_size = 10 - self.eval_acc_top1 = 0.99 + self.eval_acc_top1 = 0.95 # the input, output and weight thresholds of quantized op self.gt_thresholds = { @@ -99,36 +99,6 @@ def set_vars(self): 'add_0': [[1.7058950662612915, 0.0], [1.7058950662612915]], } - def model_train(self, model, train_reader, max_step=-1): - model.train() - adam = paddle.optimizer.Adam( - learning_rate=0.001, parameters=model.parameters()) - - for batch_id, data in enumerate(train_reader()): - x_data = np.array([x[0].reshape(1, 28, 28) - for x in data]).astype('float32') - y_data = np.array( - [x[1] for x in data]).astype('int64').reshape(-1, 1) - - img = paddle.to_tensor(x_data) - label = paddle.to_tensor(y_data) - - out = model(img) - acc = fluid.layers.accuracy(out, label) - loss = fluid.layers.cross_entropy(out, label) - avg_loss = fluid.layers.mean(loss) - avg_loss.backward() - - adam.minimize(avg_loss) - model.clear_gradients() - - if batch_id % 100 == 0: - _logger.info("Train | step {}: loss = {:}, acc= {:}".format( - batch_id, avg_loss.numpy(), acc.numpy())) - - if max_step > 0 and batch_id > max_step: # For shortening CI time - break - def model_test(self, model, batch_num=-1, batch_size=8): model.eval() @@ -148,9 +118,9 @@ def model_test(self, model, batch_num=-1, batch_size=8): out = model(img) acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5) + eval_acc_top1_list.append(float(acc_top1.numpy())) - if batch_id % 100 == 0: - eval_acc_top1_list.append(float(acc_top1.numpy())) + if batch_id % 50 == 0: _logger.info("Test | At step {}: acc1 = {:}, acc5 = {:}".format( batch_id, acc_top1.numpy(), acc_top5.numpy())) @@ -161,81 +131,88 @@ def model_test(self, model, batch_num=-1, batch_size=8): return eval_acc_top1 - def check_thresholds(self, model): - check_num = 0 - for name, layer in model.named_sublayers(): - layer_name = layer.full_name() - if layer_name in self.gt_thresholds: - ref_val = self.gt_thresholds[layer_name] - assert hasattr(layer, '_quant_config') - - quant_config = 
layer._quant_config - out_val = quant_config.out_act_quantizer.thresholds - wt_val = quant_config.wt_quantizer.thresholds - check_num += 1 - - self.assertTrue( - np.allclose( - ref_val[1], out_val, atol=1e-3), - "%s | The thresholds(%s) is different " - "from the ground truth(%s)." % - (layer_name, str(out_val), str(ref_val[1]))) - if len(ref_val) > 2 and ref_val[2] != []: - self.assertTrue( - np.allclose( - ref_val[2], wt_val, atol=1e-3), - "%s | The thresholds(%s) is different " - "from the ground truth(%s)." % - (layer_name, str(wt_val), str(ref_val[2]))) - - self.assertTrue(check_num == len(self.gt_thresholds)) + def program_test(self, program_path, batch_num=-1, batch_size=8): + exe = paddle.static.Executor(paddle.CPUPlace()) + [inference_program, feed_target_names, fetch_targets] = ( + paddle.static.load_inference_model(program_path, exe)) + + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size) + + top1_correct_num = 0. + total_num = 0. + for batch_id, data in enumerate(test_reader()): + img = np.array([x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + label = np.array([x[1] for x in data]).astype('int64') + + feed = {feed_target_names[0]: img} + results = exe.run(inference_program, + feed=feed, + fetch_list=fetch_targets) + + pred = np.argmax(results[0], axis=1) + top1_correct_num += np.sum(np.equal(pred, label)) + total_num += len(img) + + if total_num % 50 == 49: + _logger.info("Test | Test num {}: acc1 = {:}".format( + total_num, top1_correct_num / total_num)) + + if batch_num > 0 and batch_id + 1 >= batch_num: + break + return top1_correct_num / total_num def test_ptq(self): start_time = time.time() self.set_vars() + # Load model params_path = self.download_model(self.lenet_url, self.lenet_md5, "lenet") params_path += "/lenet_pretrained/lenet.pdparams" - with fluid.dygraph.guard(): - model = ImperativeLenet() - model_state_dict = paddle.load(params_path) - model.set_state_dict(model_state_dict) + model = ImperativeLenet() + model_state_dict = paddle.load(params_path) + model.set_state_dict(model_state_dict) - # Quantize, calibrate and save - quant_model = self.ptq.quantize(model) + # Quantize, calibrate and save + quant_model = self.ptq.quantize(model) + before_acc_top1 = self.model_test(quant_model, self.batch_num, + self.batch_size) - before_acc_top1 = self.model_test(quant_model, self.batch_num, - self.batch_size) + input_spec = [ + paddle.static.InputSpec( + shape=[None, 1, 28, 28], dtype='float32') + ] + self.ptq.save_quantized_model( + model=quant_model, path=self.save_path, input_spec=input_spec) + print('Quantized model saved in {%s}' % self.save_path) - #self.check_thresholds(final_model) + after_acc_top1 = self.model_test(quant_model, self.batch_num, + self.batch_size) - input_spec = [ - paddle.static.InputSpec( - shape=[None, 1, 28, 28], dtype='float32') - ] - self.ptq.save_quantized_model( - model=quant_model, path=self.save_path, input_spec=input_spec) - print('Quantized model saved in {%s}' % self.save_path) + paddle.enable_static() + infer_acc_top1 = self.program_test(self.save_path, self.batch_num, + self.batch_size) + paddle.disable_static() - # Check - after_acc_top1 = self.model_test(quant_model, self.batch_num, - self.batch_size) - print('Before converted acc_top1: %s' % before_acc_top1) - print('After converted acc_top1: %s' % after_acc_top1) + # Check + print('Before converted acc_top1: %s' % before_acc_top1) + print('After converted acc_top1: %s' % after_acc_top1) + print('Infer acc_top1: %s' % infer_acc_top1) - 
self.assertTrue( - after_acc_top1 > self.eval_acc_top1, - msg="The test acc {%f} is less than {%f}." % - (after_acc_top1, self.eval_acc_top1)) - self.assertTrue( - np.allclose(before_acc_top1, after_acc_top1), - msg='The acc is lower after converting model.') + self.assertTrue( + after_acc_top1 >= self.eval_acc_top1, + msg="The test acc {%f} is less than {%f}." % + (after_acc_top1, self.eval_acc_top1)) + self.assertTrue( + infer_acc_top1 >= after_acc_top1, + msg='The acc is lower after converting model.') - end_time = time.time() - print("total time: %ss" % (end_time - start_time)) + end_time = time.time() + print("total time: %ss \n" % (end_time - start_time)) class TestImperativePTQHist(TestImperativePTQ): @@ -245,7 +222,7 @@ def set_vars(self): self.batch_num = 10 self.batch_size = 10 - self.eval_acc_top1 = 0.99 + self.eval_acc_top1 = 0.98 self.gt_thresholds = { 'conv2d_0': @@ -266,7 +243,7 @@ def set_vars(self): self.batch_num = 10 self.batch_size = 10 - self.eval_acc_top1 = 0.99 + self.eval_acc_top1 = 1.0 conv2d_1_wt_thresholds = [ 0.18116560578346252, 0.17079241573810577, 0.1702047884464264,
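
Usage sketch for the API introduced in this series (it mirrors the flow exercised in
test_imperative_ptq.py: quantize, calibrate with forward passes, then
save_quantized_model). The import path, the toy Sequential model, and the random
calibration data are assumptions for illustration; only ImperativePTQ, quantize,
and save_quantized_model come from these patches:

    import numpy as np
    import paddle
    # Assumed export location; adjust if ImperativePTQ is exposed elsewhere.
    from paddle.fluid.contrib.slim.quantization import ImperativePTQ

    # A small stand-in model; Conv2D and Linear are the layers the series
    # replaces with simulated quant layers (see SIMULATED_LAYERS in ptq_registry.py).
    model = paddle.nn.Sequential(
        paddle.nn.Conv2D(1, 6, 3, padding=1),
        paddle.nn.ReLU(),
        paddle.nn.Flatten(),
        paddle.nn.Linear(6 * 28 * 28, 10))

    # Default config uses AbsmaxQuantizer for both activations and weights.
    ptq = ImperativePTQ()

    # 1. Add quant configs and forward-post hooks to the supported layers.
    quant_model = ptq.quantize(model)

    # 2. Calibrate: run a few forward passes so the hooks can sample
    #    activation ranges (random data here; real calibration data in practice).
    for _ in range(8):
        x = paddle.to_tensor(np.random.rand(4, 1, 28, 28).astype('float32'))
        quant_model(x)

    # 3. Convert (wrap conv2d/linear, attach thresholds) and save the
    #    post-processed inference model under the given path prefix.
    input_spec = [
        paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype='float32')
    ]
    ptq.save_quantized_model(
        model=quant_model, path='./ptq_out/model', input_spec=input_spec)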