From c334a2a35fd04bf17cc951e2a9c08d322ff45300 Mon Sep 17 00:00:00 2001
From: masahi
Date: Sat, 14 Mar 2020 07:11:32 +0900
Subject: [PATCH] [Torch, QNN] Remove FP32 piggy back and use QNN
 add/mul/concatenate (#5061)

* use qnn add/mul/concatenate

* remove logging
---
 python/tvm/relay/frontend/qnn_torch.py | 99 ++++++++++++++++++--------
 1 file changed, 70 insertions(+), 29 deletions(-)

diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py
index 0704e34b77ef..70178be52bb0 100644
--- a/python/tvm/relay/frontend/qnn_torch.py
+++ b/python/tvm/relay/frontend/qnn_torch.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name, import-outside-toplevel
 """ Functions to convert quantized torch models to QNN """
+import logging
 
 import numpy as np
 
@@ -536,21 +537,23 @@ def _impl(inputs, _):
     return _impl
 
 
-def _binop(relay_op, with_relu=False):
+def _binop(relay_op, with_relu=False, fp32_piggy_back=False):
+    def qnn_impl(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
+                 input_scale_rhs, input_zero_point_rhs,
+                 output_scale, output_zero_point):
+        qnn_out = relay_op(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
+                           input_scale_rhs, input_zero_point_rhs,
+                           output_scale, output_zero_point)
+        if with_relu:
+            clip_min = _get_scalar(output_zero_point)
+            return _op.tensor.clip(qnn_out, clip_min, 255)
+        return qnn_out
+
     # refer to aten/src/ATen/native/quantized/cpu/{qadd, qmul}.cpp
     # they piggy backs to fp32 math by dequantize -> fp32 math -> quantize
-    def _impl(inputs, _):
-        output_scale = _expr.const(inputs[2])
-        output_zero_point = _expr.const(inputs[3])
-        assert len(inputs) == 8, "Input quant params not found in op inputs"
-        # Manually added by add_input_quant_params_to_op_inputs above
-        input_scale_lhs = _expr.const(inputs[4])
-        input_zero_point_lhs = _expr.const(inputs[5])
-        input_scale_rhs = _expr.const(inputs[6])
-        input_zero_point_rhs = _expr.const(inputs[7])
-        lhs = inputs[0]
-        rhs = inputs[1]
-
+    def torch_impl(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
+                   input_scale_rhs, input_zero_point_rhs,
+                   output_scale, output_zero_point):
         if isinstance(lhs, _expr.Call) and lhs.op.name == 'qnn.quantize':
             lhs = lhs.args[0]
         else:
@@ -574,30 +577,68 @@ def _impl(inputs, _):
                                      output_zero_point,
                                      axis=-1,
                                      out_dtype="uint8")
+
+    def _impl(inputs, _):
+        lhs = inputs[0]
+        rhs = inputs[1]
+        output_scale = _expr.const(inputs[2])
+        output_zero_point = _expr.const(inputs[3])
+        assert len(inputs) == 8, "Input quant params not found in op inputs"
+        # Manually added by add_input_quant_params_to_op_inputs above
+        input_scale_lhs = _expr.const(inputs[4])
+        input_zero_point_lhs = _expr.const(inputs[5])
+        input_scale_rhs = _expr.const(inputs[6])
+        input_zero_point_rhs = _expr.const(inputs[7])
+
+        if fp32_piggy_back:
+            logging.info("Piggy backing to FP32 op (PyTorch way)")
+            return torch_impl(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
+                              input_scale_rhs, input_zero_point_rhs,
+                              output_scale, output_zero_point)
+
+        return qnn_impl(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
+                        input_scale_rhs, input_zero_point_rhs,
+                        output_scale, output_zero_point)
+
     return _impl
 
 
-def _cat():
+def _cat(fp32_piggy_back=False):
     # refer to aten/src/ATen/native/quantized/cpu/qconcat.cpp
     # for concat they also piggy backs to fp32(!)
     # dequantize -> fp32 math -> quantize
-    # we can also use QNN concat op. we observed no change in accuracy
+    def torch_impl(inputs, input_scales, input_zero_points,
+                   output_scale, output_zero_point, axis):
+        dequantized = []
+        for inp, inp_scale, inp_zp in zip(inputs, input_scales,
+                                          input_zero_points):
+            dequantized.append(relay.qnn.op.dequantize(inp, inp_scale, inp_zp))
+
+        concat = _op.tensor.concatenate(dequantized, axis=axis)
+        return relay.qnn.op.quantize(concat, output_scale, output_zero_point,
+                                     axis=axis, out_dtype="uint8")
+
     def _impl(inputs, _):
         axis = inputs[1]
         output_scale = _expr.const(inputs[2])
         output_zero_point = _expr.const(inputs[3])
         num_inputs = (len(inputs) - 4) // 2
-        dequantized = []
+
+        input_scales = []
+        input_zero_points = []
 
         for i in range(0, num_inputs):
-            inp_scale = _expr.const(inputs[4+i*2])
-            inp_zp = _expr.const(inputs[4+i*2+1])
-            dequantized.append(relay.qnn.op.dequantize(inputs[0][i],
-                                                       inp_scale, inp_zp))
+            input_scales.append(_expr.const(inputs[4+i*2]))
+            input_zero_points.append(_expr.const(inputs[4+i*2+1]))
 
-        concat = _op.tensor.concatenate(dequantized, axis=axis)
-        return relay.qnn.op.quantize(concat, output_scale, output_zero_point,
-                                     axis=1, out_dtype="uint8")
+        if fp32_piggy_back:
+            return torch_impl(inputs[0], input_scales, input_zero_points,
+                              output_scale, output_zero_point, axis)
+
+        return relay.qnn.op.concatenate(inputs[0],
+                                        input_scales, input_zero_points,
+                                        output_scale, output_zero_point,
+                                        axis)
     return _impl
 
 
@@ -676,15 +717,15 @@ def _impl(inputs, _):
 
 convert_map = {
     'aten::quantize_per_tensor': _quantize_per_tensor(),
-    'quantized::conv2d_relu': _quantized_conv2d(True),
+    'quantized::conv2d_relu': _quantized_conv2d(with_relu=True),
     'aten::dequantize': _dequantize(),
     'quantized::conv2d': _quantized_conv2d(),
-    'quantized::add_relu': _binop(relay.add, True),
-    'quantized::add': _binop(relay.add),
-    'quantized::mul_relu': _binop(relay.multiply, True),
-    'quantized::mul': _binop(relay.multiply),
+    'quantized::add_relu': _binop(relay.qnn.op.add, with_relu=True),
+    'quantized::add': _binop(relay.qnn.op.add),
+    'quantized::mul_relu': _binop(relay.qnn.op.mul, with_relu=True),
+    'quantized::mul': _binop(relay.qnn.op.mul),
    'quantized::linear': _linear(),
-    'quantized::linear_relu': _linear(True),
+    'quantized::linear_relu': _linear(with_relu=True),
     'quantized::cat': _cat(),
     'quantized::add_scalar': _add_scalar(),
     'quantized::mul_scalar': _mul_scalar(),
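
The patch lowers quantized::add, quantized::mul and quantized::cat directly to relay.qnn.op.add / relay.qnn.op.mul / relay.qnn.op.concatenate, while the old dequantize -> FP32 math -> quantize lowering (PyTorch's qadd.cpp/qmul.cpp/qconcat.cpp approach) stays reachable through the new fp32_piggy_back flag. Below is a minimal Python sketch contrasting the two lowerings for a quantized add; the shapes, scales and zero points are made-up illustration values, and the variable names (lhs, qnn_add, piggy_back, ...) are not taken from the patch.

    # Minimal sketch: build both lowerings of a quantized add as Relay
    # expressions. All quantization parameters below are illustrative only.
    from tvm import relay

    lhs = relay.var("lhs", shape=(1, 8), dtype="uint8")
    rhs = relay.var("rhs", shape=(1, 8), dtype="uint8")
    s_lhs, zp_lhs = relay.const(0.1), relay.const(3)
    s_rhs, zp_rhs = relay.const(0.2), relay.const(4)
    s_out, zp_out = relay.const(0.15), relay.const(2)

    # New default path: a single qnn.add carries all quantization params.
    qnn_add = relay.qnn.op.add(lhs, rhs, s_lhs, zp_lhs, s_rhs, zp_rhs,
                               s_out, zp_out)

    # Old path, kept behind fp32_piggy_back=True: dequantize both inputs,
    # add in FP32, then requantize to uint8.
    fp32_sum = relay.add(relay.qnn.op.dequantize(lhs, s_lhs, zp_lhs),
                         relay.qnn.op.dequantize(rhs, s_rhs, zp_rhs))
    piggy_back = relay.qnn.op.quantize(fp32_sum, s_out, zp_out,
                                       out_dtype="uint8")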