[QNN] Dynamic scale, zero point in qnn.op.dequantize (apache#6849)
* add dynamic dequantize

* register quantize and dequantize as opaque

* make tests better

* black

* remove main fn

* fix black again

* move tests

* fix import

* fix import again

* try again

* fix import
electriclilies authored and Trevor Morris committed Dec 4, 2020
1 parent 90ddf10 commit b9c2ed1
Showing 4 changed files with 74 additions and 12 deletions.
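
The change lets qnn.dequantize (and qnn.quantize) accept the scale and zero point as ordinary Relay expressions rather than compile-time constants. A minimal sketch of the new capability, mirroring the tests added below (variable names are illustrative):

from tvm import relay

# Scale and zero point are runtime-resolved Relay vars, not constants.
x = relay.var("x", shape=(1, 2, 3, 4), dtype="int8")
scale = relay.var("scale", shape=(), dtype="float32")
zp = relay.var("zp", shape=(), dtype="int32")
y = relay.qnn.op.dequantize(x, scale, zp)  # float32 result, same shape as x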
python/tvm/relay/qnn/op/qnn.py (7 additions, 0 deletions)

@@ -21,6 +21,8 @@
 from tvm.relay.expr import Tuple, TupleWrapper
 from tvm.relay.op.nn.utils import get_pad_tuple2d
 from . import _make
+from ... import op as reg
+from ...op import OpPattern


 def requantize(
@@ -496,3 +498,8 @@ def subtract(
         output_scale,
         output_zero_point,
     )
+
+
+# register fuse pattern for qnn ops
+reg.register_pattern("qnn.quantize", OpPattern.OPAQUE)
+reg.register_pattern("qnn.dequantize", OpPattern.OPAQUE)
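
Marking these ops OPAQUE keeps the fusion pass from merging them into neighboring kernels, which matters now that their scale and zero point can be arbitrary expressions. A quick sanity check of the registration (a sketch; TOpPattern is the attribute key under which register_pattern stores the value, and OPAQUE equals 8):

import tvm
from tvm.relay.op import OpPattern

op = tvm.ir.Op.get("qnn.dequantize")
assert int(op.get_attr("TOpPattern")) == OpPattern.OPAQUE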
src/relay/qnn/op/dequantize.cc (11 additions, 12 deletions)

@@ -79,20 +79,27 @@ Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis
 }

 Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
-                     const Expr& input_zero_point, const Array<IndexExpr>& input_shape,
+                     const Expr& input_zero_point, const Array<tvm::relay::Type>& types,
                      const DequantizeAttrs* attrs) {
   const auto axis = attrs->axis;

+  ICHECK_EQ(types.size(), 4);
+  auto in_type = types[0];
+  auto in_tensor_type = in_type.as<TensorTypeNode>();
+  ICHECK(in_tensor_type != nullptr) << "Type information missing"
+                                    << " Please run infer_type pass.";
+  Array<IndexExpr> input_shape = in_tensor_type->shape;
+
   size_t n_dim = input_shape.size();

   // Expand scale and zero point if the input tensor is channel quantized
   auto expanded_input_scale = input_scale;
-  if (!IsConstScalar(input_scale)) {
+  if (!IsConstScalar(input_scale) && !IsScalarType(types[1])) {
     expanded_input_scale = ExpandBiasToMatchAxis(input_scale, n_dim, {axis});
   }

   auto expanded_input_zero_point = input_zero_point;
-  if (!IsConstScalar(input_zero_point)) {
+  if (!IsConstScalar(input_zero_point) && !IsScalarType(types[2])) {
     expanded_input_zero_point = ExpandBiasToMatchAxis(input_zero_point, n_dim, {axis});
   }

@@ -113,15 +120,7 @@ Expr DequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
   ICHECK(dequantize_attrs != nullptr);

-  // Find input shape.
-  ICHECK_EQ(types.size(), 4);
-  auto in_type = types[0];
-  auto in_tensor_type = in_type.as<TensorTypeNode>();
-  ICHECK(in_tensor_type != nullptr) << "Type information missing."
-                                    << " Please run infer_type pass.";
-  Array<IndexExpr> input_shape = in_tensor_type->shape;
-
-  return DequantizeLower(data, input_scale, input_zero_point, input_shape, dequantize_attrs);
+  return DequantizeLower(data, input_scale, input_zero_point, types, dequantize_attrs);
 }

 RELAY_REGISTER_OP("qnn.dequantize")
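
The computation itself is unchanged: dequantize produces (data - zero_point) * scale, with a 1-D scale or zero point broadcast along attrs->axis via ExpandBiasToMatchAxis when the tensor is channel-quantized. A NumPy reference of those semantics (illustrative only, not TVM API):

import numpy as np

def dequantize_ref(data, scale, zero_point, axis=-1):
    # Scalars apply tensor-wide; 1-D values broadcast along `axis`,
    # mirroring ExpandBiasToMatchAxis in the C++ lowering above.
    scale = np.asarray(scale, dtype="float32")
    zero_point = np.asarray(zero_point, dtype="int32")
    if scale.ndim == 1:
        shape = [1] * data.ndim
        shape[axis] = -1
        scale = scale.reshape(shape)
    if zero_point.ndim == 1:
        shape = [1] * data.ndim
        shape[axis] = -1
        zero_point = zero_point.reshape(shape)
    return (data.astype("int32") - zero_point).astype("float32") * scale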
tests/python/relay/test_op_qnn_dequantize.py (28 additions, 0 deletions)

@@ -20,6 +20,7 @@
 import numpy as np
 from tvm import relay
 from tvm.contrib import graph_runtime
+from tvm.relay.testing import run_infer_type


 def dequantize_test_driver(in_dtype, quant_args, in_data, verify_output_data, axis):
@@ -118,9 +119,36 @@ def test_channelwise_axis_0():
     )


+def test_dynamic_dequantize():
+    x = relay.var("x", shape=(1, 2, 3, 4), dtype="int8")
+    scale_var = relay.var("scale", shape=(), dtype="float32")
+    zp_var = relay.var("zp", shape=(), dtype="int32")
+
+    deq_x = relay.qnn.op.dequantize(x, scale_var * scale_var, zp_var + zp_var)
+    tt = run_infer_type(deq_x)
+
+    assert tt.checked_type == relay.TensorType((1, 2, 3, 4), "float32")
+    func = relay.Function([x, scale_var, zp_var], deq_x)
+    data = np.random.uniform(size=(1, 2, 3, 4)).astype("int8")
+    scale = np.array(1).astype("float32")
+    zp = np.array(0).astype("int32")
+
+    mod = tvm.ir.IRModule.from_expr(func)
+
+    for target, ctx in tvm.testing.enabled_targets():
+        # TODO: (electriclilies) enable AlterOpLayout when it is fixed
+        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            lib = relay.build(mod, target=target)
+
+        module = graph_runtime.GraphModule(lib["default"](ctx))
+        module.set_input(**{"x": data, "scale": scale, "zp": zp})
+        module.run()
+
+
 if __name__ == "__main__":
     test_uint8_to_float32()
     test_int8_to_float32()
     test_int32_to_float32()
     test_channelwise_axis_1()
     test_channelwise_axis_0()
+    test_dynamic_dequantize()
tests/python/relay/test_op_qnn_quantize.py (28 additions, 0 deletions)

@@ -20,6 +20,7 @@
 import numpy as np
 from tvm import relay
 from tvm.contrib import graph_runtime
+from tvm.relay.testing import run_infer_type


 def quantize_test_driver(in_dtype, quant_args, axis, out_dtype, in_data, verify_output_data):
@@ -133,8 +134,35 @@ def test_channelwise_axis_1():
     )


+def test_dynamic_quantize():
+    x = relay.var("x", shape=(1, 2, 3, 4), dtype="float32")
+    scale_var = relay.var("scale", shape=(), dtype="float32")
+    zp_var = relay.var("zp", shape=(), dtype="int32")
+
+    q_x = relay.qnn.op.quantize(x, scale_var * scale_var, zp_var + zp_var)
+    tt = run_infer_type(q_x)
+
+    assert tt.checked_type == relay.TensorType((1, 2, 3, 4), "int8")
+    func = relay.Function([x, scale_var, zp_var], q_x)
+    data = np.random.uniform(size=(1, 2, 3, 4)).astype("float32")
+    scale = np.array(1).astype("float32")
+    zp = np.array(0).astype("int32")
+
+    mod = tvm.ir.IRModule.from_expr(func)
+
+    for target, ctx in tvm.testing.enabled_targets():
+        # TODO: (electriclilies) enable AlterOpLayout when it is fixed
+        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            lib = relay.build(mod, target=target)
+
+        module = graph_runtime.GraphModule(lib["default"](ctx))
+        module.set_input(**{"x": data, "scale": scale, "zp": zp})
+        module.run()
+
+
 if __name__ == "__main__":
     test_float32_to_uint8()
     test_float32_to_int8()
     test_channelwise_axis_0()
     test_channelwise_axis_1()
+    test_dynamic_quantize()
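
For reference, the quantize direction exercised here is round(data / scale) + zero_point, clamped to the range of the output dtype. A NumPy sketch under that definition (int8 output, as in the test):

import numpy as np

def quantize_ref(data, scale, zero_point, out_dtype="int8"):
    info = np.iinfo(out_dtype)
    q = np.round(data / scale) + zero_point
    return np.clip(q, info.min, info.max).astype(out_dtype)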
