From f346c60287b50950275e20db9e6d84b3fc568a00 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 9 Mar 2020 13:14:58 -0700 Subject: [PATCH] Revert "[Torch, QNN] Add support for quantized models via QNN (#4977)" (#5013) This reverts commit fc7f0783940c362bf48cd46817956381196201e2. --- python/tvm/relay/frontend/pytorch.py | 88 +-- python/tvm/relay/frontend/qnn_torch.py | 692 ------------------ tests/python/frontend/pytorch/qnn_test.py | 455 ------------ tests/python/frontend/pytorch/test_forward.py | 6 - 4 files changed, 9 insertions(+), 1232 deletions(-) delete mode 100644 python/tvm/relay/frontend/qnn_torch.py delete mode 100644 tests/python/frontend/pytorch/qnn_test.py diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index ff37f823f28b..e284e481d272 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -19,7 +19,6 @@ # pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension """PT: PyTorch frontend.""" import itertools -import logging import numpy as np @@ -33,8 +32,6 @@ from .common import infer_shape as _infer_shape from .common import infer_value as _infer_value -from . import qnn_torch - __all__ = ["from_pytorch"] # operator implementation @@ -149,10 +146,6 @@ def _impl(inputs, input_types): def _relu(): def _impl(inputs, input_types): data = inputs[0] - if input_types[0] == "quint8": - assert len(inputs) == 3, "Input quant param not found in op inputs" - input_zero_point = _expr.const(inputs[2], dtype="int32") - return qnn_torch.quantized_relu(data, input_zero_point) return _op.nn.relu(data) return _impl @@ -161,14 +154,9 @@ def _impl(inputs, input_types): data = inputs[0] output_size = _infer_shape(inputs[1]) - def func(x): - return _op.nn.adaptive_avg_pool2d(x, output_size=output_size) - - if input_types[0] == "quint8": - return qnn_torch.quantized_adaptive_avg_2d(data, func) - - return func(data) - + return _op.nn.adaptive_avg_pool2d( + data, + output_size=output_size) return _impl def _adaptive_max_2d(): @@ -518,18 +506,7 @@ def _impl(inputs, input_types): else: exclude = False - def func(x): - return _op.mean(x, axis, keepdims, exclude) - - if input_types[0] == "quint8": - assert len(inputs) == 6, "Input quant param not found in op inputs" - input_scale = _expr.const(inputs[4]) - input_zero_point = _expr.const(inputs[5]) - return qnn_torch.quantized_mean(data, input_scale, - input_zero_point, func) - - return func(data) - + return _op.mean(data, axis, keepdims, exclude) return _impl def _chunk(): @@ -691,40 +668,10 @@ def _impl(inputs, input_types): else: coord_trans = "half_pixel" - def func(x): - return _op.image.resize(x, out_size, "NCHW", method, coord_trans) - - if input_types[0] == "quint8": - import torch - from packaging import version - - # Torch version > 1.4 changed upsampling API - if version.parse(torch.__version__) > version.parse("1.4.0"): - num_inputs = 7 - else: - num_inputs = 5 - - assert len(inputs) == num_inputs, "Input quant param not found in op inputs" - - input_scale = _expr.const(inputs[-2]) - input_zero_point = _expr.const(inputs[-1]) - return qnn_torch.quantized_upsample(data, input_scale, - input_zero_point, func) - return func(data) + return _op.image.resize(data, out_size, "NCHW", method, coord_trans) return _impl - -def _expand_as(): - def _impl(inputs, input_types): - # TODO: maybe fix this - # This assumes expand_as can be removed because TVM has broadcast op - msg = "aten::expand_as(...) 
found, assume it is part of broadcast op" - logging.warning(msg) - return inputs[0] - return _impl - - # Helper functions for operator implementation def _convert_data_type(input_type): @@ -845,7 +792,6 @@ def _convert_elemwise_input(data, input_type): "aten::detach" : _identity(), "aten::upsample_bilinear2d" : _upsample("bilinear"), "aten::upsample_nearest2d" : _upsample("nearest_neighbor"), - "aten::expand_as" : _expand_as() } @@ -896,7 +842,6 @@ def _report_missing_conversion(op_names): "prim::ListConstruct", "prim::ListUnpack", "prim::TupleConstruct", "prim::TupleUnpack"] known_ops += list(_convert_map.keys()) - known_ops += list(qnn_torch.convert_map.keys()) missing = [op_name for op_name in op_names if op_name not in known_ops] @@ -1063,7 +1008,6 @@ def parse_params(graph, state_dict): getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True) params = {} param_tensors = {} - packed_param_map = {} seen = set() for node in getattr_nodes: @@ -1076,18 +1020,14 @@ def parse_params(graph, state_dict): full_attr = _getattr_full_name(getattrs) full_attr_node_name = _get_output_name(getattrs[-1]) - if full_attr.endswith("_packed_params"): # for quantized models - err_msg = "parameter %s not found in state dict" % full_attr - assert full_attr in state_dict, err_msg - packed_param_map[full_attr_node_name] = full_attr - elif full_attr in state_dict: + if full_attr in state_dict: torch_tensor = state_dict[full_attr] tensor, var = _get_tensor_and_var(torch_tensor, full_attr_node_name) param_tensors[full_attr_node_name] = tensor params[full_attr_node_name] = var - return params, param_tensors, packed_param_map + return params, param_tensors def parse_operators(operators, outputs, output_index_map, ret_name): @@ -1168,26 +1108,16 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): params = script_module.state_dict() input_vars = parse_inputs(graph.inputs(), input_shapes) - param_vars, tensors, packed_param_map = parse_params(graph, params) - tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()} + param_vars, tensors = parse_params(graph, params) input_vars.update(param_vars) outputs = list(input_vars.values()) output_index_map = dict(zip(input_vars.keys(), range(len(outputs)))) ret_name = _get_input_names(graph.return_node())[0] - # For quantized models - if "aten::quantize_per_tensor" in op_names: - weight_quant_params = qnn_torch.get_weight_quant_params(script_module) - qnn_torch.add_input_quant_params_to_op_inputs(graph) - qnn_torch.add_quant_params_to_outputs(outputs, output_index_map, - packed_param_map, - weight_quant_params) - qnn_torch.add_quant_params(tvm_params, weight_quant_params) - _convert_map.update(qnn_torch.convert_map) - body = parse_operators(_get_operator_nodes(graph.nodes()), outputs, output_index_map, ret_name) func = tvm.relay.Function(_analysis.free_vars(body), body) + tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()} return _module.IRModule.from_expr(func), tvm_params diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py deleted file mode 100644 index 0704e34b77ef..000000000000 --- a/python/tvm/relay/frontend/qnn_torch.py +++ /dev/null @@ -1,692 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name, import-outside-toplevel -""" Functions to convert quantized torch models to QNN """ - -import numpy as np - -import tvm -from tvm import relay -from tvm.relay import expr as _expr -from tvm.relay import op as _op -from tvm.relay.frontend.common import infer_shape - - -class QNNParam: - """ A placeholder for weight quantization parameters """ - - def __init__(self, weight, bias, scale, zero_point, param_key): - param_prefix = param_key[:-len("._packed_params")] - self.weight_var = _expr.var(param_prefix + "_weight", - shape=weight.shape) - self.weight = weight - - if bias is not None: - self.bias_var = _expr.var(param_prefix + "_bias", - shape=bias.shape) - self.bias = bias.detach().numpy() - else: - self.bias_var = None - self.bias = None - - self.scale = _expr.const(scale) - self.zero_point = _expr.const(zero_point, dtype="int32") - - -def _unpack_quant_params(param_name, packed_params, unpack_func): - # Torch stores quantized params in a custom packed format, - # need to unpack and retrieve them as numpy arrays - qweight, bias = unpack_func(packed_params) - weight_np = qweight.dequantize().numpy() - - import torch - if qweight.qscheme() == torch.per_tensor_affine: - param = QNNParam(weight_np, bias, qweight.q_scale(), - int(qweight.q_zero_point()), param_name) - else: - scales = qweight.q_per_channel_scales().numpy() - zero_points = qweight.q_per_channel_zero_points().numpy() - # This is an assumption posed by QNN - msg = "The values of zero points should be all zero for per channel" - assert np.all(zero_points == 0), msg - param = QNNParam(weight_np, bias, scales, 0, param_name) - - return param - - -def get_weight_quant_params(script_module): - """ Retrive and unpack weight parameters from quantized modules """ - conv_packed_params = [] - linear_packed_params = [] - - import torch - # conv and linear requires different unpacking function - # extract all conv and linear parameters separately to distinguish them - for name, m in script_module.named_modules(): - if isinstance(m, torch.jit.RecursiveScriptModule): - if "Conv" in m.original_name: - conv_packed_params.append((name, m.state_dict())) - elif m.original_name == "LinearPackedParams": - linear_packed_params.append((name, m.state_dict())) - - pairs = [(torch.ops.quantized.conv2d_unpack, conv_packed_params), - (torch.ops.quantized.linear_unpack, linear_packed_params)] - - quant_params = {} - param_name = "_packed_params" - for unpack_func, params in pairs: - for name, state_dict in params: - assert len(state_dict) == 1 - assert param_name in state_dict - key = name + "." + param_name - packed_param = state_dict[param_name] - quant_params[key] = _unpack_quant_params(key, packed_param, - unpack_func) - - return quant_params - - -def add_quant_params_to_outputs(outputs, output_index_map, - packed_param_map, quant_params): - """ - Add quant params to outputs so that they can be referenced by other - ops later. Weights are quantized here. 
- """ - for node_name, packed_param_name in packed_param_map.items(): - qparam = quant_params[packed_param_name] - output_index_map[node_name] = len(outputs) - qweight = relay.qnn.op.quantize(qparam.weight_var, qparam.scale, - qparam.zero_point, out_dtype="int8", - axis=0) - param_tup = (qweight, qparam.scale, qparam.zero_point, qparam.bias_var) - outputs.append(param_tup) - - -def _get_quant_param_for_input(input_value): - """ - We want to know the input scale and zp of this input_value, since - input quant params are not explicitly passed around in torch (they - are embeded in a QTensor data structure, not visible statically). - We know that it is quantized using output scale and zp - of some previous quantized op. The purpose of this function - is to find that pair of parameters. - """ - # Indices for output scale and zp - # For example, in quantized::conv2d(%input, %1, %2, %3, %4, %5, %6, %7), - # 6th and 7th arg are output scale and zp respectively. - output_quant_param_indices = { - "aten::quantize_per_tensor": (1, 2), - "quantized::conv2d": (6, 7), - "quantized::conv2d_relu": (6, 7), - "quantized::linear": (2, 3), - "quantized::linear_relu": (2, 3), - "quantized::add_relu": (2, 3), - "quantized::add": (2, 3), - "quantized::mul_relu": (2, 3), - "quantized::mul": (2, 3), - "quantized::cat": (2, 3), - "quantized::mul_scalar": (2, 3), - "quantized::add_scalar": (2, 3) - } - - def dfs(current_node): - # trace back to find the producer of this input value - current_op = current_node.kind() - if current_op in output_quant_param_indices: - indices = output_quant_param_indices[current_op] - scale = current_node.inputsAt(indices[0]) - zp = current_node.inputsAt(indices[1]) - return scale, zp - - # Trace back eariler nodes, dfs order - # Assume quantized tensor comes earlier in the args - for arg in current_node.inputs(): - return dfs(arg.node()) - - # shouldn't happen - assert False, "No producer for %s" % (str(current_node)) - - return dfs(input_value.node()) - - -def _get_add_scalar_output_quant_param(input_scale, input_zero_point, - scalar): - """ - Determine the output scale and zp of quantized::add_scalar op - This is used for mobilenet v3 - Refer to aten/src/ATen/native/quantized/cpu/qadd.cpp - The names of variables are the same as torch impl - """ - q_min = 0 - q_max = 255 - s = input_scale - z = input_zero_point - c = scalar - c_q = round(c / s) - - if q_min > z - c_q: - s_prime = (float(q_max) - (z - c_q)) / (float(q_max) - q_min) * s - z_prime = q_min - elif q_max < z - c_q: - s_prime = (float(z - c_q) - q_min) / (float(q_max) - q_min) * s - z_prime = q_max - else: - s_prime = s - z_prime = z - c_q - - return s_prime, z_prime - - -def _get_mul_scalar_output_quant_param(input_scale, input_zero_point, - scalar): - """ - Determine the output scale and zp of quantized::mul_scalar op - This is used for mobilenet v3 - Refer to aten/src/ATen/native/quantized/cpu/qmul.cpp - The names of variables are the same as torch impl - """ - q_min = 0 - q_max = 255 - self_scale = input_scale - self_zero_point = input_zero_point - other_val = scalar - - if other_val > 0.0: - s_prime = other_val * self_scale - z_prime = self_zero_point - elif other_val == 0.0: - s_prime = 1.0 - z_prime = 0 - else: - s_prime = abs(other_val) * self_scale - z_prime = q_max - (self_zero_point - q_min) - - return s_prime, z_prime - - -def _add_output_quant_params_to_scalar_op(node, graph, - input_scale, input_zero_point, - scalar): - """ - The output scale and zp of {add,mul}_scalar op are not explicit in the IR - They are 
required for _get_quant_param_for_input above to work correctly - So calculate these params using the same way torch does, and make new - constant nodes in the input IR. Also add these params to the inputs of - scalar op. - - For example, - %6 : float = prim::Constant[value=3.]() - %input : QUInt8(1, 3, 224, 224) = quantized::add_scalar(%x.1, %6) - becomes - %6 : float = prim::Constant[value=3.]() - %7 : float = prim::Constant[value=0.015686161816120148]() - %8 : int = prim::Constant[value=0]() - %input : UInt8(1, 3, 224, 224) = quantized::add_scalar(%x.1, %6, %7, %8) - - %7 and %8 are newly created output scale and zp constant nodes - """ - import torch - operator = node.kind() - - if operator == "quantized::mul_scalar": - out_scale, out_zero_point = \ - _get_mul_scalar_output_quant_param(input_scale, input_zero_point, - scalar) - elif operator == "quantized::add_scalar": - out_scale, out_zero_point = \ - _get_add_scalar_output_quant_param(input_scale, input_zero_point, - scalar) - else: - raise NotImplementedError("unsupported scalar op: %s" % operator) - - # create new constant nodes and add them to graph - out_scale_node = graph.create("prim::Constant") - out_zero_point_node = graph.create("prim::Constant") - out_scale_node.insertBefore(node) - out_zero_point_node.insertBefore(node) - out_scale_node.f_("value", out_scale) - out_zero_point_node.i_("value", out_zero_point) - out_scale_node.output().setType(torch._C.FloatType.get()) - out_zero_point_node.output().setType(torch._C.IntType.get()) - node.addInput(out_scale_node.output()) - node.addInput(out_zero_point_node.output()) - - -def add_input_quant_params_to_op_inputs(graph): - """ - In Torch, input quant params are not explicitly passed around - Instead, they are stored in QTensor data structure, and retrieved - at runtime by each quantized ops. - However, they need to be known statically for QNN translation. - To workaround and simplify the translation of inputs, we manually add - input quant params to inputs of Torch quantized operators listed below. - See _quantized_conv2d() below for example of why this is helpful. - - For example, - %input : QUInt8(1, 512, 7, 7) = quantized::add(%x.8, %x.9, %434, %435) - becomes - %395 : float = prim::Constant[value=0.036212071776390076]() - %396 : int = prim::Constant[value=0]() - %430 : float = prim::Constant[value=0.16080744564533234]() - %431 : int = prim::Constant[value=42]() - %input : QUInt8(1, 512, 7, 7) = quantized::add(%x.8, %x.9, %434, %435, - %430, %431, %395, %396) - - %434, %435 are output scale and zp of quantized::add op - %430, %431, %395, %396 are two pairs of input (scale, zp) for two tensors - added by this function - """ - # How many quantized tensors each op takes as inputs? 
- # A pair of (scale, zp) for each input quantized tensor will be added - # to the input nodes - num_quantized_inputs = {"quantized::conv2d": 1, - "quantized::conv2d_relu": 1, - "quantized::linear": 1, - "quantized::linear_relu": 1, - "quantized::add_relu": 2, - "quantized::add": 2, - "quantized::mul_relu": 2, - "quantized::mul": 2, - "aten::dequantize": 1, - "aten::mean": 1, - "aten::upsample_bilinear2d": 1, - "aten::relu_": 1, - "aten::relu": 1, - "quantized::add_scalar": 1, - "quantized::mul_scalar": 1, - 'quantized::relu6': 1} - - need_input_quant_param = set(num_quantized_inputs.keys()) - need_input_quant_param.add("quantized::cat") - - for node in graph.nodes(): - operator = node.kind() - if operator not in need_input_quant_param: - continue - - input_scales = [] - input_zero_points = [] - - if operator == "quantized::cat": - # the number of inputs to concat is not constant - # so handle it separately - inputs = node.inputsAt(0).node().inputs() - for inp in inputs: - scale, zp = _get_quant_param_for_input(inp) - input_scales.append(scale) - input_zero_points.append(zp) - else: - for i in range(num_quantized_inputs[operator]): - scale, zp = _get_quant_param_for_input(node.inputsAt(i)) - input_scales.append(scale) - input_zero_points.append(zp) - - if operator in ["quantized::add_scalar", "quantized::mul_scalar"]: - scalar = node.inputsAt(1).node().f("value") - inp_scale = input_scales[0].node().f("value") - inp_zero_point = input_zero_points[0].node().i("value") - - # see the comments in this function above - _add_output_quant_params_to_scalar_op(node, graph, - inp_scale, inp_zero_point, - scalar) - - for scale, zp in zip(input_scales, input_zero_points): - node.addInput(scale) - node.addInput(zp) - - -def add_quant_params(params, quant_params): - """ Add quant parameters to TVM param map """ - for qparam in quant_params.values(): - params[qparam.weight_var.name_hint] = tvm.nd.array(qparam.weight) - if qparam.bias is not None: - params[qparam.bias_var.name_hint] = tvm.nd.array(qparam.bias) - - -def quantized_adaptive_avg_2d(data, func_fp32): - # this follows tflite impl - inp = _op.cast(data, dtype="int32") - out = func_fp32(inp) - return _op.cast(out, "uint8") - - -def quantized_mean(data, input_scale, input_zero_point, func_fp32): - # refer to aten/src/ATen/native/quantized/cpu/qreduction.cpp - dequantized = relay.qnn.op.dequantize(data, input_scale, input_zero_point) - out = func_fp32(dequantized) - return relay.qnn.op.quantize(out, input_scale, input_zero_point, - out_dtype="uint8", axis=1) - - -def quantized_upsample(data, input_scale, input_zero_point, func_fp32): - # currently piggy backs to fp32, it gets identical output as torch - data = relay.qnn.op.dequantize(data, input_scale, input_zero_point) - out = func_fp32(data) - return relay.qnn.op.quantize(out, input_scale, input_zero_point, - out_dtype="uint8", axis=1) - - -def quantized_relu(data, input_zero_point): - # refer to aten/src/ATen/native/quantized/cpu/qrelu.cpp - zp = _op.cast(input_zero_point, dtype="uint8") - return _op.tensor.maximum(data, zp) - - -def _quantize_per_tensor(): - def _impl(inputs, _): - return relay.qnn.op.quantize(inputs[0], _expr.const(inputs[1]), - _expr.const(inputs[2]), out_dtype="uint8", - axis=1) - return _impl - - -def _dequantize(): - def _impl(inputs, _): - assert len(inputs) == 3, "Input quant params not found in op inputs" - inp_scale = _expr.const(inputs[1]) - inp_zero_point = _expr.const(inputs[2]) - return relay.qnn.op.dequantize(inputs[0], inp_scale, inp_zero_point) - return _impl - 
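# ----------------------------------------------------------------------
# A standalone numeric sketch of the quantized::add_scalar rule encoded in
# _get_add_scalar_output_quant_param above.  The scale, zero point and
# scalar below are made-up illustration values, not taken from any model.
def _example_add_scalar_quant_params(s=0.02, z=128, c=0.5):
    q_min, q_max = 0, 255        # uint8 range assumed by the torch kernel
    c_q = round(c / s)           # the scalar expressed in quantized steps
    if q_min > z - c_q:
        # shifted zero point would underflow: widen the scale, pin zp to q_min
        return (float(q_max) - (z - c_q)) / (float(q_max) - q_min) * s, q_min
    if q_max < z - c_q:
        # shifted zero point would overflow: widen the scale, pin zp to q_max
        return (float(z - c_q) - q_min) / (float(q_max) - q_min) * s, q_max
    # common case: only the zero point shifts, the scale is unchanged
    return s, z - c_q            # -> (0.02, 103) for the defaults above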
- -def _get_numpy(relay_const_scalar): - return relay_const_scalar.data.asnumpy() - - -def _get_scalar(relay_const_scalar): - return np.asscalar(_get_numpy(relay_const_scalar)) - - -def _do_bias_and_requantize(output, bias, input_scale, weight_scale, - output_scale, output_zero_point, - with_relu): - """ Output processing for conv and linear """ - # this is a vector for per channel case - requant_input_scale = _expr.const(_get_numpy(input_scale) * - _get_numpy(weight_scale)) - # Torch does bias add and requanize scale in fp32 - # refer to third_party/fbgemm/include/fbgemm/OutputProcessing-inl.h - # Instead, we do bias add in int32 and use qnn requantize, which needs - # integer input. - # We observed no loss in accuracy in doing this way, and it is better - # for tvm because bias quantization can be done at compile time - # Instead, the torch way requires rounding of activation at runtime - - if bias is not None: - qbias = relay.qnn.op.quantize(bias, requant_input_scale, - _expr.const(0, "int32"), - out_dtype="int32", axis=0) - requantize_input = _op.nn.bias_add(output, qbias) - else: - requantize_input = output - - requantized = relay.qnn.op.requantize(requantize_input, - requant_input_scale, - relay.const(0, 'int32'), - output_scale, output_zero_point, - out_dtype="int32", axis=1) - clip_min = 0 - if with_relu: - clip_min = _get_scalar(output_zero_point) - - clip = _op.tensor.clip(requantized, clip_min, 255.) - return _op.cast(clip, dtype="uint8") - - -def _quantized_conv2d(with_relu=False): - def _impl(inputs, _): - # refer to src/ATen/native/quantized/cpu/qconv.cpp - # inputs[0]: input tensor - # inputs[1]: (weight, scale, zero_point, bias) - # inputs[2-5]: stride, padding, dilation, groups - # inputs[6]: output_scale - # inputs[7]: output_zero_point - # inputs[8]: input_scale (added manually by frontend) - # inputs[9]: input_zero_point (added manually by frontend) - weight = inputs[1][0] - weight_scale = inputs[1][1] - weight_zero_point = inputs[1][2] - - output_scale = _expr.const(inputs[6]) - output_zero_point = _expr.const(inputs[7]) - - assert len(inputs) == 10, "Input quant params not found in op inputs" - # These are manually added by add_input_quant_params_to_op_inputs above - # In torch, they are retrieved from QTensor data structure at runtime - input_scale = _expr.const(inputs[8]) - input_zero_point = _expr.const(inputs[9]) - - strides, padding, dilation = inputs[2], inputs[3], inputs[4] - strides = infer_shape(inputs[2]) - padding = infer_shape(inputs[3]) - dilation = infer_shape(inputs[4]) - groups = inputs[5] - - weight_shape = infer_shape(weight) - kernel_size = (weight_shape[2], weight_shape[3]) - out_channels = weight_shape[0] - - if padding[0] != 0 or padding[1] != 0: - pad_val = _get_scalar(input_zero_point) - inp = _op.nn.pad(inputs[0], pad_width=((0, 0), - (0, 0), - (padding[0], padding[0]), - (padding[1], padding[1])), - pad_value=float(pad_val)) - else: - inp = inputs[0] - - # padding is (0, 0) because we did explicit pad op with - # pad value being zero point above - conv_out = relay.qnn.op.conv2d(inp, weight, - input_zero_point, weight_zero_point, - input_scale, weight_scale, - kernel_size=kernel_size, - dilation=dilation, strides=strides, - padding=(0, 0), groups=groups, - channels=out_channels) - bias_var = inputs[1][3] - - return _do_bias_and_requantize(conv_out, bias_var, input_scale, - weight_scale, output_scale, - output_zero_point, with_relu) - - return _impl - - -def _linear(with_relu=False): - # similar to conv - def _impl(inputs, _): - weight = 
inputs[1][0] - weight_scale = inputs[1][1] - weight_zero_point = inputs[1][2] - output_scale = _expr.const(inputs[2]) - output_zero_point = _expr.const(inputs[3]) - assert len(inputs) == 6, "Input quant params not found in op inputs" - # Manually added by add_input_quant_params_to_op_inputs above - input_scale = _expr.const(inputs[4]) - input_zero_point = _expr.const(inputs[5]) - - weight_shape = infer_shape(weight) - dense = relay.qnn.op.dense(inputs[0], weight, - input_zero_point, weight_zero_point, - input_scale, weight_scale, - units=weight_shape[0]) - bias_var = inputs[1][3] - - return _do_bias_and_requantize(dense, bias_var, input_scale, - weight_scale, output_scale, - output_zero_point, with_relu) - - return _impl - - -def _binop(relay_op, with_relu=False): - # refer to aten/src/ATen/native/quantized/cpu/{qadd, qmul}.cpp - # they piggy backs to fp32 math by dequantize -> fp32 math -> quantize - def _impl(inputs, _): - output_scale = _expr.const(inputs[2]) - output_zero_point = _expr.const(inputs[3]) - assert len(inputs) == 8, "Input quant params not found in op inputs" - # Manually added by add_input_quant_params_to_op_inputs above - input_scale_lhs = _expr.const(inputs[4]) - input_zero_point_lhs = _expr.const(inputs[5]) - input_scale_rhs = _expr.const(inputs[6]) - input_zero_point_rhs = _expr.const(inputs[7]) - lhs = inputs[0] - rhs = inputs[1] - - if isinstance(lhs, _expr.Call) and lhs.op.name == 'qnn.quantize': - lhs = lhs.args[0] - else: - lhs = relay.qnn.op.dequantize(lhs, - input_scale_lhs, - input_zero_point_lhs) - - if isinstance(rhs, _expr.Call) and rhs.op.name == 'qnn.quantize': - rhs = rhs.args[0] - else: - rhs = relay.qnn.op.dequantize(rhs, - input_scale_rhs, - input_zero_point_rhs) - fp32_out = relay_op(lhs, rhs) - - if with_relu: - fp32_out = _op.nn.relu(fp32_out) - - return relay.qnn.op.quantize(fp32_out, - output_scale, - output_zero_point, - axis=-1, - out_dtype="uint8") - return _impl - - -def _cat(): - # refer to aten/src/ATen/native/quantized/cpu/qconcat.cpp - # for concat they also piggy backs to fp32(!) - # dequantize -> fp32 math -> quantize - # we can also use QNN concat op. 
we observed no change in accuracy - def _impl(inputs, _): - axis = inputs[1] - output_scale = _expr.const(inputs[2]) - output_zero_point = _expr.const(inputs[3]) - num_inputs = (len(inputs) - 4) // 2 - dequantized = [] - - for i in range(0, num_inputs): - inp_scale = _expr.const(inputs[4+i*2]) - inp_zp = _expr.const(inputs[4+i*2+1]) - dequantized.append(relay.qnn.op.dequantize(inputs[0][i], - inp_scale, inp_zp)) - - concat = _op.tensor.concatenate(dequantized, axis=axis) - return relay.qnn.op.quantize(concat, output_scale, output_zero_point, - axis=1, out_dtype="uint8") - - return _impl - - -def _add_scalar(): - # this is used for mobilenet v3 - def _impl(inputs, _): - # refer to aten/src/ATen/native/quantized/cpu/qadd.cpp - assert len(inputs) == 6, "Input quant params not found in op inputs" - s = inputs[4] - z = inputs[5] - c = inputs[1] - c_q = round(c / s) - q_min = 0 - q_max = 255 - - # math for calculating output scale and zp are already done - # during _add_output_quant_params_to_scalar_op above - out_scale = _expr.const(inputs[2]) - out_zp = _expr.const(inputs[3]) - - if q_min > z - c_q or q_max < z - c_q: - dequant = relay.qnn.op.dequantize(inputs[0], - _expr.const(s), _expr.const(z)) - dequantized_add = _op.tensor.add(dequant, _expr.const(c_q * s)) - return relay.qnn.op.quantize(dequantized_add, out_scale, out_zp, - axis=1, out_dtype="uint8") - # only scale change - return inputs[0] - - return _impl - - -def quantize_scalar(data, scale, zero_point): - # used to quantize 6., in mobilenet v3 - transformed = zero_point + data / scale - return max(0, min(round(transformed), 255)) - - -def _relu6(): - # refer to src/ATen/native/quantized/cpu/qrelu.cpp - def _impl(inputs, _): - assert len(inputs) == 4, "Input quant params not found in op inputs" - input_scale = inputs[2] - input_zero_point = inputs[3] - six = quantize_scalar(6., input_scale, input_zero_point) - return _op.tensor.clip(inputs[0], input_zero_point, six) - return _impl - - -def _mul_scalar(): - # this is used for mobilenet v3 - def _impl(inputs, _): - # refer to aten/src/ATen/native/quantized/cpu/qmul.cpp - # math for calculating output scale and zp are already done - # during _add_output_quant_params_to_scalar_op above - assert len(inputs) == 6, "Input quant params not found in op inputs" - other_val = inputs[1] # scalar - - if other_val > 0.0: - # only scale change - return inputs[0] - if other_val == 0.0: - shape = infer_shape(inputs[0]) - return _op.full(_expr.const(0), shape, dtype="uint8") - - # negative scale case - q_min = 0 - q_max = 255 - bias = _expr.const(q_max + q_min, dtype="int8") - int8 = bias - _op.cast(inputs[0], "int8") - return _op.cast(int8, "uint8") - - return _impl - - -convert_map = { - 'aten::quantize_per_tensor': _quantize_per_tensor(), - 'quantized::conv2d_relu': _quantized_conv2d(True), - 'aten::dequantize': _dequantize(), - 'quantized::conv2d': _quantized_conv2d(), - 'quantized::add_relu': _binop(relay.add, True), - 'quantized::add': _binop(relay.add), - 'quantized::mul_relu': _binop(relay.multiply, True), - 'quantized::mul': _binop(relay.multiply), - 'quantized::linear': _linear(), - 'quantized::linear_relu': _linear(True), - 'quantized::cat': _cat(), - 'quantized::add_scalar': _add_scalar(), - 'quantized::mul_scalar': _mul_scalar(), - 'quantized::relu6': _relu6() -} diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py deleted file mode 100644 index e3a876c79591..000000000000 --- a/tests/python/frontend/pytorch/qnn_test.py +++ /dev/null @@ -1,455 
+0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" Tests on quantized torch model conversion """ -import os - -from PIL import Image - -import numpy as np - -import torch -from torch import nn -from torch.quantization import QuantStub, DeQuantStub -from torch.quantization import fuse_modules, QuantWrapper - -import tvm -from tvm import relay -from tvm.relay.frontend.pytorch import get_graph_input_names -from tvm.contrib.download import download_testdata - - -def torch_version_check(): - from packaging import version - return version.parse(torch.__version__) > version.parse("1.4.0") - - -def get_tvm_runtime(script_module, input_name, ishape): - - input_shapes = {input_name: ishape} - mod, params = relay.frontend.from_pytorch(script_module, input_shapes) - - with relay.build_config(opt_level=3): - # test on only cpu for now, torch cannot run quant models on cuda - # also not to make CI too slow - json, lib, params = relay.build(mod, target="llvm", params=params) - - runtime = tvm.contrib.graph_runtime.create(json, lib, tvm.cpu(0)) - runtime.set_input(**params) - return runtime - - -def get_qconfig(per_channel): - from torch.quantization.observer import MovingAverageMinMaxObserver - from torch.quantization.observer import default_weight_observer - - if per_channel: - return torch.quantization.get_default_qconfig('fbgemm') - else: - act = MovingAverageMinMaxObserver.with_args(reduce_range=False) - return torch.quantization.QConfig(activation=act, - weight=default_weight_observer) - - -def quantize_model(model, inp, per_channel=False, dummy=True): - model.fuse_model() - model.qconfig = get_qconfig(per_channel) - torch.quantization.prepare(model, inplace=True) - model(inp) - torch.quantization.convert(model, inplace=True) - - -class ConvBn(nn.Module): - def __init__(self, with_relu=False): - super().__init__() - layers = [nn.Conv2d(3, 32, 3, bias=True), - nn.BatchNorm2d(32)] - if with_relu: - layers.append(nn.ReLU()) - self.conv = nn.Sequential(*layers) - self.quant_wrap = QuantWrapper(self.conv) - self.with_relu = with_relu - - def forward(self, x): - return self.quant_wrap(x) - - def fuse_model(self): - indices = ["0", "1"] - if self.with_relu: - indices.append("2") - fuse_modules(self.conv, indices, inplace=True) - - -class Linear(nn.Module): - def __init__(self, with_relu=False): - super().__init__() - layers = [nn.Linear(16, 32)] - if with_relu: - layers.append(nn.ReLU()) - self.fc = nn.Sequential(*layers) - self.quant_wrap = QuantWrapper(self.fc) - self.with_relu = with_relu - - def forward(self, x): - return self.quant_wrap(x) - - def fuse_model(self): - if self.with_relu: - fuse_modules(self.fc, ["0", "1"], inplace=True) - - -class ReLU(nn.Module): - def __init__(self): - super().__init__() - self.relu = QuantWrapper(nn.ReLU()) - - def 
forward(self, x): - return self.relu(x) - - def fuse_model(self): - pass - - -# Mobilenet V3 related modules -class Hsigmoid(nn.Module): - def __init__(self, inplace=True, add_stub=False): - super().__init__() - self.float_op = nn.quantized.FloatFunctional() - self.relu6 = nn.ReLU6(inplace=inplace) - self.quant = QuantStub() - self.dequant = DeQuantStub() - self.add_stub = add_stub - - def forward(self, x): - if self.add_stub: - x = self.quant(x) - relu6 = self.relu6(self.float_op.add_scalar(x, 3.)) - mul = self.float_op.mul_scalar(relu6, 1/6.) - if self.add_stub: - mul = self.dequant(mul) - return mul - - def fuse_model(self): - pass - - -class Hswish(nn.Module): - def __init__(self, inplace=True, add_stub=False): - super(Hswish, self).__init__() - self.float_op = nn.quantized.FloatFunctional() - self.hsigmoid = Hsigmoid(inplace, add_stub=False) - self.quant = QuantStub() - self.dequant = DeQuantStub() - self.add_stub = add_stub - - def forward(self, x): - if self.add_stub: - x = self.quant(x) - mul = self.float_op.mul(x, self.hsigmoid(x)) - if self.add_stub: - mul = self.dequant(mul) - return mul - - def fuse_model(self): - pass - - -class SqueezeExcite(nn.Module): - def __init__(self, channel, reduction=4, add_stub=False): - super(SqueezeExcite, self).__init__() - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.fc = nn.Sequential( - nn.Linear(channel, channel // reduction, bias=False), - nn.ReLU(inplace=True), - nn.Linear(channel // reduction, channel, bias=False), - Hsigmoid(add_stub=False) - ) - self.fmul = nn.quantized.FloatFunctional() - self.quant = QuantStub() - self.dequant = DeQuantStub() - self.add_stub = add_stub - - def forward(self, x): - b, c, _, _ = x.size() - if self.add_stub: - x = self.quant(x) - y = self.avg_pool(x).view(b, c) - y = self.fc(y).view(b, c, 1, 1) - out = self.fmul.mul(x, y.expand_as(x)) - if self.add_stub: - return self.dequant(out) - else: - return out - - def fuse_model(self): - fuse_modules(self.fc, ["0", "1"], inplace=True) - - -# test on quantized::mul_scalar with negative scale -class MulScalarNegative(nn.Module): - def __init__(self, ): - super().__init__() - self.float_op = nn.quantized.FloatFunctional() - self.quant = QuantStub() - self.dequant = DeQuantStub() - - def forward(self, x): - x = self.quant(x) - mul = self.float_op.mul_scalar(x, -0.3) - return self.dequant(mul) - - def fuse_model(self): - pass - - -class UpsamplingBilinear(nn.Module): - def __init__(self): - super().__init__() - self.relu = QuantWrapper(nn.ReLU()) - self.quant = QuantStub() - self.dequant = DeQuantStub() - - def forward(self, x): - x = self.quant(x) - upsample = nn.functional.interpolate(x, scale_factor=2, - mode='bilinear', - align_corners=True) - return self.dequant(upsample) - - def fuse_model(self): - pass - - -def test_quantized_modules(): - imagenet_ishape = (1, 3, 224, 224) - - qmodules = [ - ("relu", imagenet_ishape, ReLU(), False), - ("upsample bilinear", (1, 3, 64, 64), UpsamplingBilinear(), False), - ] - - for per_channel in [False, True]: - if per_channel: - postfix = ", per_channel" - else: - postfix = "" - - qmodules += [ - ("conv_bn" + postfix, imagenet_ishape, ConvBn(), per_channel), - ("conv_bn_relu" + postfix, imagenet_ishape, ConvBn(with_relu=True), per_channel), - ("linear" + postfix, (16, 16), Linear(), per_channel), - ("linear_relu" + postfix, (16, 16), Linear(with_relu=True), per_channel) - ] - - if torch_version_check(): - qmodules += [ - ("hsigmoid", imagenet_ishape, Hsigmoid(add_stub=True), False), - ("hswish", imagenet_ishape, 
Hswish(add_stub=True), False), - ("semodule", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), False), - ("semodule, per_channel", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), True), - ("mul_scalar negative", imagenet_ishape, MulScalarNegative(), False) - ] - else: - print("Skipping tests that require torch > 1.4") - - for (module_name, ishape, raw_module, per_channel) in qmodules: - raw_module.eval() - inp = torch.rand(ishape) - - quantize_model(raw_module, inp, per_channel=per_channel, dummy=True) - script_module = torch.jit.trace(raw_module, inp).eval() - - with torch.no_grad(): - pt_result = script_module(inp.clone()).numpy() - - input_name = get_graph_input_names(script_module)[0] - - runtime = get_tvm_runtime(script_module, input_name, ishape) - runtime.set_input(input_name, inp.numpy().copy()) - runtime.run() - tvm_result = runtime.get_output(0).asnumpy() - - max_abs_diff = np.max(np.abs(tvm_result - pt_result)) - mean_abs_diff = np.mean(np.abs(tvm_result - pt_result)) - num_identical = np.sum(tvm_result == pt_result) - match_ratio = num_identical / float(np.prod(tvm_result.shape)) - - print(module_name, max_abs_diff, mean_abs_diff, match_ratio) - - # sample outputs - """ - relu 0.0039215684 2.6052087e-08 0.9999933567176871 - upsample bilinear 0.0 0.0 1.0 - conv_bn 0.22062653 0.011478779 0.6909348115006899 - conv_bn_relu 0.3700896 0.010921672 0.7489366477964451 - linear 0.15987062 0.009231662 0.794921875 - linear_relu 0.14180502 0.0053220326 0.8828125 - conv_bn, per_channel 0.01654929 2.9486866e-06 0.9998218235127019 - conv_bn_relu, per_channel 0.009089053 1.4926576e-06 0.9998357732732732 - linear, per_channel 0.0 0.0 1.0 - linear_relu, per_channel 0.0 0.0 1.0 - hsigmoid 0.002614379 0.00020525524 0.9214896896258503 - hswish 0.0052286386 0.00063522335 0.7587359162414966 - semodule, per_channel 0.0039885044 0.0008620687 0.7838592529296875 - mul_scalar negative 0.0011764616 7.815566e-09 0.9999933567176871 - """ - - # we cannot make any guarantee on how close the raw output is to torch - # tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-1, atol=1e-1) - - -def test_quantized_imagenet(): - def get_transform(): - import torchvision.transforms as transforms - normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) - return transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - normalize, - ]) - - def get_real_image(im_height, im_width): - repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/' - img_name = 'elephant-299.jpg' - image_url = os.path.join(repo_base, img_name) - img_path = download_testdata(image_url, img_name, module='data') - return Image.open(img_path).resize((im_height, im_width)) - - def get_imagenet_input(): - im = get_real_image(224, 224) - preprocess = get_transform() - pt_tensor = preprocess(im) - return np.expand_dims(pt_tensor.numpy(), 0) - - from torchvision.models.quantization import resnet as qresnet - from torchvision.models.quantization import mobilenet as qmobilenet - from torchvision.models.quantization import inception as qinception - from torchvision.models.quantization import googlenet as qgooglenet - - qmodels = [] - - for per_channel in [False, True]: - qmodels += [ - ("resnet18", qresnet.resnet18(pretrained=True), per_channel), - ("mobilenet_v2", qmobilenet.mobilenet_v2(pretrained=True), per_channel), - ("inception_v3", qinception.inception_v3(pretrained=True), per_channel), - ("googlenet", qgooglenet(pretrained=True), 
per_channel), - ] - - results = [] - - for (model_name, raw_model, per_channel) in qmodels: - raw_model.eval() - - if per_channel: - model_name += ", per channel quantization" - else: - model_name += ", per tensor quantization" - - inp = get_imagenet_input() - pt_inp = torch.from_numpy(inp) - - quantize_model(raw_model, pt_inp, per_channel=per_channel, dummy=False) - script_module = torch.jit.trace(raw_model, pt_inp).eval() - - with torch.no_grad(): - pt_result = script_module(pt_inp).numpy() - - input_name = get_graph_input_names(script_module)[0] - runtime = get_tvm_runtime(script_module, input_name, (1, 3, 224, 224)) - runtime.set_input(input_name, inp) - runtime.run() - - tvm_result = runtime.get_output(0).asnumpy() - - results.append((model_name, pt_result[0], tvm_result[0])) - - for (model_name, pt_result, tvm_result) in results: - max_abs_diff = np.max(np.abs(tvm_result - pt_result)) - mean_abs_diff = np.mean(np.abs(tvm_result - pt_result)) - num_identical = np.sum(tvm_result == pt_result) - pt_top3_labels = np.argsort(pt_result)[::-1][:3] - tvm_top3_labels = np.argsort(pt_result)[::-1][:3] - - print("\nModel name: %s" % model_name) - print("PyTorch top3 label:", pt_top3_labels) - print("TVM top3 label:", tvm_top3_labels) - print("max abs diff:", max_abs_diff) - print("mean abs_diff:", mean_abs_diff) - print("%d in 1000 raw outputs identical." % num_identical) - - assert set(pt_top3_labels) == set(tvm_top3_labels) - - # sample outputs - """ - Model name: resnet18, per tensor quantization - PyTorch top3 label: [386 101 385] - TVM top3 label: [386 101 385] - max abs diff: 0.65681696 - mean abs_diff: 0.14055882 - 236 in 1000 raw outputs identical. - - Model name: mobilenet_v2, per tensor quantization - PyTorch top3 label: [101 386 385] - TVM top3 label: [101 386 385] - max abs diff: 2.1262953 - mean abs_diff: 0.41025686 - 101 in 1000 raw outputs identical. - - Model name: inception_v3, per tensor quantization - PyTorch top3 label: [101 386 385] - TVM top3 label: [101 386 385] - max abs diff: 0.9994669 - mean abs_diff: 0.098697364 - 272 in 1000 raw outputs identical. - - Model name: googlenet, per tensor quantization - PyTorch top3 label: [101 386 385] - TVM top3 label: [101 386 385] - max abs diff: 0.28248847 - mean abs_diff: 0.0634469 - 274 in 1000 raw outputs identical. - - Model name: resnet18, per channel quantization - PyTorch top3 label: [101 386 385] - TVM top3 label: [101 386 385] - max abs diff: 0.65908074 - mean abs_diff: 0.1274223 - 469 in 1000 raw outputs identical. - - Model name: mobilenet_v2, per channel quantization - PyTorch top3 label: [101 386 385] - TVM top3 label: [101 386 385] - max abs diff: 0.71120834 - mean abs_diff: 0.15883648 - 423 in 1000 raw outputs identical. - - Model name: inception_v3, per channel quantization - PyTorch top3 label: [386 101 385] - TVM top3 label: [386 101 385] - max abs diff: 1.3372154 - mean abs_diff: 0.1225224 - 401 in 1000 raw outputs identical. - - Model name: googlenet, per channel quantization - PyTorch top3 label: [101 386 385] - TVM top3 label: [101 386 385] - max abs diff: 0.34015465 - mean abs_diff: 0.054197952 - 558 in 1000 raw outputs identical. 
- """ diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index eed47ea8ad5a..e60c1fd88183 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -854,9 +854,3 @@ def forward(self, inp): test_custom_conversion_map() test_segmentaton_models() - - # Quantization test - from qnn_test import test_quantized_imagenet, test_quantized_modules - - test_quantized_modules() - test_quantized_imagenet()
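For reference, after this revert the PyTorch frontend handles only float32 modules; quantized ops are reported by _report_missing_conversion instead of being translated. Below is a minimal usage sketch of that remaining fp32 path, assuming the same APIs the removed tests exercised (relay.frontend.from_pytorch, get_graph_input_names, graph_runtime); the toy model, shapes and target are illustrative only.

import torch
import tvm
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.frontend.pytorch import get_graph_input_names

# Trace an ordinary fp32 model (hypothetical toy network).
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
inp = torch.rand(1, 3, 224, 224)
script_module = torch.jit.trace(model, inp).eval()

# from_pytorch now returns just (module, params); no packed_param_map.
input_name = get_graph_input_names(script_module)[0]
mod, params = relay.frontend.from_pytorch(script_module,
                                           {input_name: (1, 3, 224, 224)})

with relay.build_config(opt_level=3):
    graph_json, lib, params = relay.build(mod, target="llvm", params=params)

runtime = graph_runtime.create(graph_json, lib, tvm.cpu(0))
runtime.set_input(**params)
runtime.set_input(input_name, inp.numpy())
runtime.run()
tvm_result = runtime.get_output(0).asnumpy()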