diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py
index 3f23d6895b43..9a8f22bfb9bc 100644
--- a/python/tvm/relay/qnn/op/qnn.py
+++ b/python/tvm/relay/qnn/op/qnn.py
@@ -21,6 +21,8 @@
 from tvm.relay.expr import Tuple, TupleWrapper
 from tvm.relay.op.nn.utils import get_pad_tuple2d
 from . import _make
+from ... import op as reg
+from ...op import OpPattern
 
 
 def requantize(
@@ -496,3 +498,8 @@ def subtract(
         output_scale,
         output_zero_point,
     )
+
+
+# register fuse pattern for qnn ops
+reg.register_pattern("qnn.quantize", OpPattern.OPAQUE)
+reg.register_pattern("qnn.dequantize", OpPattern.OPAQUE)
diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc
index 2fe075c7e64b..724441e0c523 100644
--- a/src/relay/qnn/op/dequantize.cc
+++ b/src/relay/qnn/op/dequantize.cc
@@ -79,20 +79,27 @@ Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis
 }
 
 Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
-                     const Expr& input_zero_point, const Array<IndexExpr>& input_shape,
+                     const Expr& input_zero_point, const Array<tvm::relay::Type>& types,
                      const DequantizeAttrs* attrs) {
   const auto axis = attrs->axis;
 
+  ICHECK_EQ(types.size(), 4);
+  auto in_type = types[0];
+  auto in_tensor_type = in_type.as<TensorTypeNode>();
+  ICHECK(in_tensor_type != nullptr) << "Type information missing"
+                                    << " Please run infer_type pass.";
+  Array<IndexExpr> input_shape = in_tensor_type->shape;
+
   size_t n_dim = input_shape.size();
 
   // Expand scale and zero point if the input tensor is channel quantized
   auto expanded_input_scale = input_scale;
-  if (!IsConstScalar(input_scale)) {
+  if (!IsConstScalar(input_scale) && !IsScalarType(types[1])) {
     expanded_input_scale = ExpandBiasToMatchAxis(input_scale, n_dim, {axis});
   }
 
   auto expanded_input_zero_point = input_zero_point;
-  if (!IsConstScalar(input_zero_point)) {
+  if (!IsConstScalar(input_zero_point) && !IsScalarType(types[2])) {
     expanded_input_zero_point = ExpandBiasToMatchAxis(input_zero_point, n_dim, {axis});
   }
 
@@ -113,15 +120,7 @@ Expr DequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
   ICHECK(dequantize_attrs != nullptr);
 
-  // Find input shape.
-  ICHECK_EQ(types.size(), 4);
-  auto in_type = types[0];
-  auto in_tensor_type = in_type.as<TensorTypeNode>();
-  ICHECK(in_tensor_type != nullptr) << "Type information missing."
-                                    << " Please run infer_type pass.";
-  Array<IndexExpr> input_shape = in_tensor_type->shape;
-
-  return DequantizeLower(data, input_scale, input_zero_point, input_shape, dequantize_attrs);
+  return DequantizeLower(data, input_scale, input_zero_point, types, dequantize_attrs);
 }
 
 RELAY_REGISTER_OP("qnn.dequantize")
diff --git a/tests/python/relay/test_op_qnn_dequantize.py b/tests/python/relay/test_op_qnn_dequantize.py
index e1416622c236..e7fb161a13cb 100644
--- a/tests/python/relay/test_op_qnn_dequantize.py
+++ b/tests/python/relay/test_op_qnn_dequantize.py
@@ -20,6 +20,7 @@
 import numpy as np
 from tvm import relay
 from tvm.contrib import graph_runtime
+from tvm.relay.testing import run_infer_type
 
 
 def dequantize_test_driver(in_dtype, quant_args, in_data, verify_output_data, axis):
@@ -118,9 +119,36 @@
     )
 
+
+def test_dynamic_dequantize():
+    x = relay.var("x", shape=(1, 2, 3, 4), dtype="int8")
+    scale_var = relay.var("scale", shape=(), dtype="float32")
+    zp_var = relay.var("zp", shape=(), dtype="int32")
+
+    deq_x = relay.qnn.op.dequantize(x, scale_var * scale_var, zp_var + zp_var)
+    tt = run_infer_type(deq_x)
+
+    assert tt.checked_type == relay.TensorType((1, 2, 3, 4), "float32")
+    func = relay.Function([x, scale_var, zp_var], deq_x)
+    data = np.random.uniform(size=(1, 2, 3, 4)).astype("int8")
+    scale = np.array(1).astype("float32")
+    zp = np.array(0).astype("int32")
+
+    mod = tvm.ir.IRModule.from_expr(func)
+
+    for target, ctx in tvm.testing.enabled_targets():
+        # TODO: (electriclilies) enable AlterOpLayout when it is fixed
+        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            lib = relay.build(mod, target=target)
+
+        module = graph_runtime.GraphModule(lib["default"](ctx))
+        module.set_input(**{"x": data, "scale": scale, "zp": zp})
+        module.run()
+
 
 if __name__ == "__main__":
     test_uint8_to_float32()
     test_int8_to_float32()
     test_int32_to_float32()
     test_channelwise_axis_1()
     test_channelwise_axis_0()
+    test_dynamic_dequantize()
diff --git a/tests/python/relay/test_op_qnn_quantize.py b/tests/python/relay/test_op_qnn_quantize.py
index a22c25f5b97f..2ef298679904 100644
--- a/tests/python/relay/test_op_qnn_quantize.py
+++ b/tests/python/relay/test_op_qnn_quantize.py
@@ -20,6 +20,7 @@
 import numpy as np
 from tvm import relay
 from tvm.contrib import graph_runtime
+from tvm.relay.testing import run_infer_type
 
 
 def quantize_test_driver(in_dtype, quant_args, axis, out_dtype, in_data, verify_output_data):
@@ -133,8 +134,35 @@
     )
 
+
+def test_dynamic_quantize():
+    x = relay.var("x", shape=(1, 2, 3, 4), dtype="float32")
+    scale_var = relay.var("scale", shape=(), dtype="float32")
+    zp_var = relay.var("zp", shape=(), dtype="int32")
+
+    q_x = relay.qnn.op.quantize(x, scale_var * scale_var, zp_var + zp_var)
+    tt = run_infer_type(q_x)
+
+    assert tt.checked_type == relay.TensorType((1, 2, 3, 4), "int8")
+    func = relay.Function([x, scale_var, zp_var], q_x)
+    data = np.random.uniform(size=(1, 2, 3, 4)).astype("float32")
+    scale = np.array(1).astype("float32")
+    zp = np.array(0).astype("int32")
+
+    mod = tvm.ir.IRModule.from_expr(func)
+
+    for target, ctx in tvm.testing.enabled_targets():
+        # TODO: (electriclilies) enable AlterOpLayout when it is fixed
+        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            lib = relay.build(mod, target=target)
+
+        module = graph_runtime.GraphModule(lib["default"](ctx))
+        module.set_input(**{"x": data, "scale": scale, "zp": zp})
+        module.run()
+
 
 if __name__ == "__main__":
     test_float32_to_uint8()
     test_float32_to_int8()
     test_channelwise_axis_0()
     test_channelwise_axis_1()
+    test_dynamic_quantize()