[Torch, QNN] Remove FP32 piggy back and use QNN add/mul/concatenate (#5061)

* use qnn add/mul/concatenate

* remove logging
masahi authored Mar 13, 2020
1 parent d7a7483 commit 4fbc2fb
Showing 1 changed file with 70 additions and 29 deletions.
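
Before this change, quantized::add, quantized::mul, and quantized::cat were lowered the way PyTorch's ATen kernels implement them: dequantize the inputs, run the FP32 op, then requantize. This commit lowers them directly to the QNN ops (relay.qnn.op.add, relay.qnn.op.mul, relay.qnn.op.concatenate), keeping the old behaviour behind an fp32_piggy_back flag that defaults to False. A minimal sketch of the two lowerings for an add (editor's illustration, not part of the commit; the shapes, scales, and zero points are made up):

from tvm import relay

lhs = relay.var("lhs", shape=(1, 8), dtype="uint8")
rhs = relay.var("rhs", shape=(1, 8), dtype="uint8")
s_lhs, zp_lhs = relay.const(0.1), relay.const(0)
s_rhs, zp_rhs = relay.const(0.2), relay.const(0)
s_out, zp_out = relay.const(0.15), relay.const(0)

# Old lowering: piggy back on FP32 math, mirroring ATen's qadd/qmul kernels
fp32_sum = relay.add(relay.qnn.op.dequantize(lhs, s_lhs, zp_lhs),
                     relay.qnn.op.dequantize(rhs, s_rhs, zp_rhs))
old_out = relay.qnn.op.quantize(fp32_sum, s_out, zp_out,
                                axis=-1, out_dtype="uint8")

# New lowering: a single QNN op that carries all quantization parameters
new_out = relay.qnn.op.add(lhs, rhs, s_lhs, zp_lhs, s_rhs, zp_rhs,
                           s_out, zp_out)
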
python/tvm/relay/frontend/qnn_torch.py (70 additions, 29 deletions)
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name, import-outside-toplevel
 """ Functions to convert quantized torch models to QNN """
+import logging
 
 import numpy as np
 
@@ -536,21 +537,23 @@ def _impl(inputs, _):
     return _impl
 
 
-def _binop(relay_op, with_relu=False):
+def _binop(relay_op, with_relu=False, fp32_piggy_back=False):
+    def qnn_impl(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
+                 input_scale_rhs, input_zero_point_rhs,
+                 output_scale, output_zero_point):
+        qnn_out = relay_op(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
+                           input_scale_rhs, input_zero_point_rhs,
+                           output_scale, output_zero_point)
+        if with_relu:
+            clip_min = _get_scalar(output_zero_point)
+            return _op.tensor.clip(qnn_out, clip_min, 255)
+        return qnn_out
+
     # refer to aten/src/ATen/native/quantized/cpu/{qadd, qmul}.cpp
     # they piggy backs to fp32 math by dequantize -> fp32 math -> quantize
-    def _impl(inputs, _):
-        output_scale = _expr.const(inputs[2])
-        output_zero_point = _expr.const(inputs[3])
-        assert len(inputs) == 8, "Input quant params not found in op inputs"
-        # Manually added by add_input_quant_params_to_op_inputs above
-        input_scale_lhs = _expr.const(inputs[4])
-        input_zero_point_lhs = _expr.const(inputs[5])
-        input_scale_rhs = _expr.const(inputs[6])
-        input_zero_point_rhs = _expr.const(inputs[7])
-        lhs = inputs[0]
-        rhs = inputs[1]
-
+    def torch_impl(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
+                   input_scale_rhs, input_zero_point_rhs,
+                   output_scale, output_zero_point):
         if isinstance(lhs, _expr.Call) and lhs.op.name == 'qnn.quantize':
             lhs = lhs.args[0]
         else:
@@ -574,30 +577,68 @@ def _impl(inputs, _):
                                      output_zero_point,
                                      axis=-1,
                                      out_dtype="uint8")
 
+    def _impl(inputs, _):
+        lhs = inputs[0]
+        rhs = inputs[1]
+        output_scale = _expr.const(inputs[2])
+        output_zero_point = _expr.const(inputs[3])
+        assert len(inputs) == 8, "Input quant params not found in op inputs"
+        # Manually added by add_input_quant_params_to_op_inputs above
+        input_scale_lhs = _expr.const(inputs[4])
+        input_zero_point_lhs = _expr.const(inputs[5])
+        input_scale_rhs = _expr.const(inputs[6])
+        input_zero_point_rhs = _expr.const(inputs[7])
+
+        if fp32_piggy_back:
+            logging.info("Piggy backing to FP32 op (PyTorch way)")
+            return torch_impl(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
+                              input_scale_rhs, input_zero_point_rhs,
+                              output_scale, output_zero_point)
+
+        return qnn_impl(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
+                        input_scale_rhs, input_zero_point_rhs,
+                        output_scale, output_zero_point)
+
     return _impl
 
 
-def _cat():
+def _cat(fp32_piggy_back=False):
     # refer to aten/src/ATen/native/quantized/cpu/qconcat.cpp
     # for concat they also piggy backs to fp32(!)
     # dequantize -> fp32 math -> quantize
+    # we can also use QNN concat op. we observed no change in accuracy
+    def torch_impl(inputs, input_scales, input_zero_points,
+                   output_scale, output_zero_point, axis):
+        dequantized = []
+        for inp, inp_scale, inp_zp in zip(inputs, input_scales,
+                                          input_zero_points):
+            dequantized.append(relay.qnn.op.dequantize(inp, inp_scale, inp_zp))
+
+        concat = _op.tensor.concatenate(dequantized, axis=axis)
+        return relay.qnn.op.quantize(concat, output_scale, output_zero_point,
+                                     axis=axis, out_dtype="uint8")
+
     def _impl(inputs, _):
         axis = inputs[1]
         output_scale = _expr.const(inputs[2])
         output_zero_point = _expr.const(inputs[3])
         num_inputs = (len(inputs) - 4) // 2
-        dequantized = []
 
+        input_scales = []
+        input_zero_points = []
+
         for i in range(0, num_inputs):
-            inp_scale = _expr.const(inputs[4+i*2])
-            inp_zp = _expr.const(inputs[4+i*2+1])
-            dequantized.append(relay.qnn.op.dequantize(inputs[0][i],
-                                                       inp_scale, inp_zp))
+            input_scales.append(_expr.const(inputs[4+i*2]))
+            input_zero_points.append(_expr.const(inputs[4+i*2+1]))
 
-        concat = _op.tensor.concatenate(dequantized, axis=axis)
-        return relay.qnn.op.quantize(concat, output_scale, output_zero_point,
-                                     axis=1, out_dtype="uint8")
+        if fp32_piggy_back:
+            return torch_impl(inputs[0], input_scales, input_zero_points,
+                              output_scale, output_zero_point, axis)
+
+        return relay.qnn.op.concatenate(inputs[0],
+                                        input_scales, input_zero_points,
+                                        output_scale, output_zero_point,
+                                        axis)
 
     return _impl
 
@@ -676,15 +717,15 @@ def _impl(inputs, _):

 convert_map = {
     'aten::quantize_per_tensor': _quantize_per_tensor(),
-    'quantized::conv2d_relu': _quantized_conv2d(True),
+    'quantized::conv2d_relu': _quantized_conv2d(with_relu=True),
     'aten::dequantize': _dequantize(),
     'quantized::conv2d': _quantized_conv2d(),
-    'quantized::add_relu': _binop(relay.add, True),
-    'quantized::add': _binop(relay.add),
-    'quantized::mul_relu': _binop(relay.multiply, True),
-    'quantized::mul': _binop(relay.multiply),
+    'quantized::add_relu': _binop(relay.qnn.op.add, with_relu=True),
+    'quantized::add': _binop(relay.qnn.op.add),
+    'quantized::mul_relu': _binop(relay.qnn.op.mul, with_relu=True),
+    'quantized::mul': _binop(relay.qnn.op.mul),
     'quantized::linear': _linear(),
-    'quantized::linear_relu': _linear(True),
+    'quantized::linear_relu': _linear(with_relu=True),
     'quantized::cat': _cat(),
     'quantized::add_scalar': _add_scalar(),
     'quantized::mul_scalar': _mul_scalar(),

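With the same pattern applied to quantized::cat, the converter now emits relay.qnn.op.concatenate, which takes the per-input scales and zero points together with the output quantization parameters, instead of dequantizing every input, concatenating in FP32, and requantizing. An illustrative sketch (editor's example, not from the commit; the shapes and quantization parameters are made up):

from tvm import relay

a = relay.var("a", shape=(1, 4), dtype="uint8")
b = relay.var("b", shape=(1, 4), dtype="uint8")
input_scales = [relay.const(0.1), relay.const(0.2)]
input_zero_points = [relay.const(0), relay.const(0)]
output_scale, output_zero_point = relay.const(0.15), relay.const(0)

# One QNN op requantizes each input to the output scale/zero point and concatenates
out = relay.qnn.op.concatenate((a, b),
                               input_scales, input_zero_points,
                               output_scale, output_zero_point,
                               axis=1)

Because fp32_piggy_back defaults to False and the convert_map entries do not pass it, converted models take the QNN path; the FP32 path remains available only as an opt-in, presumably for comparing against PyTorch's own kernels, and the in-code comment notes that no change in accuracy was observed with the QNN concat.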