diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 13505fd0f738..c8855b2ea2be 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -2465,6 +2465,12 @@ def _impl_v1(cls, inputs, attr, params):
 
     @classmethod
     def _impl_v11(cls, inputs, attr, params):
+        if len(inputs) == 3 and isinstance(inputs[2], _expr.Constant):
+            attr["max"] = inputs[2].data.asnumpy().item()
+            inputs = inputs[0:2]
+        if len(inputs) >= 2 and isinstance(inputs[1], _expr.Constant):
+            attr["min"] = inputs[1].data.asnumpy().item()
+            inputs = inputs[0:1]
         if "min" in attr and "max" in attr:
             return Clip.convert_attributes(inputs, attr, params)
diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py
index 1f267abedc1a..4c693fe64ee0 100644
--- a/python/tvm/relay/op/__init__.py
+++ b/python/tvm/relay/op/__init__.py
@@ -28,6 +28,7 @@
     OpStrategy,
     debug,
     register_external_compiler,
+    register_fake_quantization_to_integer,
 )
 from . import strategy
diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py
index 33cb46d67f34..ccf011819a97 100644
--- a/python/tvm/relay/op/op.py
+++ b/python/tvm/relay/op/op.py
@@ -436,6 +436,27 @@ def register_external_compiler(op_name, fexternal=None, level=10):
     return tvm.ir.register_op_attr(op_name, "FTVMExternalCompiler", fexternal, level)
 
 
+def register_fake_quantization_to_integer(op_name, func=None, level=10):
+    """Register a quantize function for an op.
+
+    Given an op and the affine types of its inputs, the registered function
+    should return the op rewritten in affine space/integer operators along with
+    the new affine type of the output, where affine denotes the transformation
+    x_real = (x_affine - zero_point) * scale
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the operator
+
+    func : function (expr: Expr, map: Map) -> new_expr: Expr
+        The function for translating the op into affine space and integer operators
+
+    level : int
+        The priority level
+    """
+    return tvm.ir.register_op_attr(op_name, "FTVMFakeQuantizationToInteger", func, level)
+
+
 @tvm._ffi.register_func("relay.op.compiler._lower")
 def _lower(name, schedule, inputs, outputs):
     return lower(schedule, list(inputs) + list(outputs), name=name)
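For context on how this hook is meant to be used (the new file below registers the actual rules), here is a minimal sketch of a rule for an op this patch does not cover; `squeeze` is used purely as a hypothetical example:

```python
from tvm import relay
from tvm.relay.op import register_fake_quantization_to_integer


@register_fake_quantization_to_integer("squeeze")
def squeeze(expr, type_map):
    # squeeze moves no values, so the output simply inherits the
    # input's scale, zero point, and dtype.
    arg = expr.args[0]
    t = type_map[arg]
    out = relay.op.squeeze(arg, **expr.attrs)
    return [out, t.scale, t.zero_point, t.dtype]
```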
diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py
index ca9996aeaaae..9ed40f85c3bc 100644
--- a/python/tvm/relay/transform/__init__.py
+++ b/python/tvm/relay/transform/__init__.py
@@ -19,3 +19,4 @@
 # transformation passes
 from .transform import *
 from .recast import recast
+from . import fake_quantization_to_integer
diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py
new file mode 100644
index 000000000000..5f4c53772eec
--- /dev/null
+++ b/python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -0,0 +1,166 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Relay functions for rewriting fake quantized ops."""
+import tvm
+from tvm import relay
+from ..op import register_fake_quantization_to_integer
+
+
+def fold_constant(expr):
+    mod = tvm.IRModule.from_expr(expr)
+    mod = relay.transform.FoldConstant()(mod)
+    return mod["main"].body
+
+
+@register_fake_quantization_to_integer("qnn.dequantize")
+def dequantize(expr, type_map):
+    """Remove dequantize op"""
+    out = expr.args[0]
+    t = type_map[expr]
+    return [out, t.scale, t.zero_point, t.dtype]
+
+
+@register_fake_quantization_to_integer("qnn.quantize")
+def quantize(expr, type_map):
+    """Turn a quantize op into requantize or remove it"""
+    out = expr.args[0]
+    t = type_map[out]
+    in_scale = fold_constant(t.scale)
+    in_zero_point = fold_constant(t.zero_point)
+    if not (
+        tvm.ir.structural_equal(in_scale, expr.args[1])
+        and tvm.ir.structural_equal(in_zero_point, expr.args[2])
+        and tvm.ir.structural_equal(t.dtype, expr.attrs.out_dtype)
+    ):
+        out = relay.qnn.op.requantize(
+            out,
+            in_scale,
+            in_zero_point,
+            expr.args[1],
+            expr.args[2],
+            out_dtype=expr.attrs.out_dtype,
+        )
+    return [out, expr.args[1], expr.args[2], expr.attrs.out_dtype]
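As a back-of-the-envelope illustration (plain NumPy, not TVM code) of why a mismatched dequantize/quantize pair must become a requantize rather than simply cancelling:

```python
import numpy as np

in_scale, in_zp = 2.0, 0    # affine type flowing into the quantize op
out_scale, out_zp = 1.0, 0  # parameters of the quantize op itself

x_affine = np.int8(3)
x_real = (x_affine - in_zp) * in_scale            # 6.0, by x_real = (x_affine - zp) * scale
requantized = int(np.round(x_real / out_scale)) + out_zp
assert requantized == 6  # not 3, so the pair cannot just be dropped
```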
+
+
+def register_unary_identity(op_name, op):
+    def identity(expr, type_map):
+        assert len(expr.args) == 1
+        arg = expr.args[0]
+        t = type_map[arg]
+        out = op(arg, **expr.attrs)
+        return [out, t.scale, t.zero_point, t.dtype]
+
+    return register_fake_quantization_to_integer(op_name, identity)
+
+
+register_unary_identity("reshape", relay.op.reshape)
+register_unary_identity("transpose", relay.op.transpose)
+register_unary_identity("nn.max_pool2d", relay.op.nn.max_pool2d)
+
+
+@register_fake_quantization_to_integer("nn.avg_pool2d")
+def avgpool2d(expr, type_map):
+    """Rewrite an avg_pool2d op"""
+    arg = expr.args[0]
+    t = type_map[arg]
+    arg = relay.op.cast(arg, "int32")
+    out = relay.op.nn.avg_pool2d(arg, **expr.attrs)
+    out = relay.op.cast(out, t.dtype)
+    return [out, t.scale, t.zero_point, t.dtype]
+
+
+@register_fake_quantization_to_integer("nn.bias_add")
+def bias_add(expr, type_map):
+    """Rewrite a bias_add op"""
+    x, b = expr.args
+    x_t = type_map[x]
+    b_t = type_map[b]
+    in_scale = fold_constant(x_t.scale)
+    in_zero_point = fold_constant(x_t.zero_point)
+    if not tvm.ir.structural_equal(x_t, b_t):
+        b = relay.qnn.op.requantize(
+            b,
+            b_t.scale,
+            b_t.zero_point,
+            in_scale,
+            in_zero_point,
+            out_dtype=x_t.dtype,
+        )
+    out = relay.op.nn.bias_add(x, b, **expr.attrs)
+    return [out, x_t.scale, x_t.zero_point, x_t.dtype]
+
+
+@register_fake_quantization_to_integer("nn.conv2d")
+def conv2d(expr, type_map):
+    """Rewrite a conv2d op"""
+    attrs = {**expr.attrs}
+    attrs.pop("out_dtype")
+    x, weight = expr.args
+    x_t = type_map[x]
+    w_t = type_map[weight]
+    conv_scale = fold_constant(x_t.scale * w_t.scale)
+    conv_zp = relay.const(0)
+    out = relay.qnn.op.conv2d(
+        x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs
+    )
+    return [out, conv_scale, conv_zp, out.attrs.out_dtype]
+
+
+@register_fake_quantization_to_integer("concatenate")
+def concat(expr, type_map):
+    """Rewrite a concat op"""
+    scales = []
+    zps = []
+    for arg in expr.args[0].fields:
+        t = type_map[arg]
+        scales.append(t.scale)
+        zps.append(t.zero_point)
+
+    out_type = type_map[expr]
+
+    out = relay.qnn.op.concatenate(
+        expr.args[0],
+        relay.Tuple(scales),
+        relay.Tuple(zps),
+        out_type.scale,
+        out_type.zero_point,
+        **expr.attrs,
+    )
+    return [out, out_type.scale, out_type.zero_point, out_type.dtype]
+
+
+@register_fake_quantization_to_integer("clip")
+def clip(expr, type_map):
+    """Rewrite a clip op"""
+    arg = expr.args[0]
+    t = type_map[arg]
+    amin = expr.attrs.a_min
+    amax = expr.attrs.a_max
+    scale = fold_constant(t.scale)
+    z_p = fold_constant(t.zero_point)
+    if isinstance(scale, relay.expr.Constant) and isinstance(z_p, relay.expr.Constant):
+        scale = scale.data.numpy().item()
+        z_p = z_p.data.numpy().item()
+        new_min = int(amin / scale + z_p)
+        new_max = int(amax / scale + z_p)
+        out = relay.op.clip(arg, new_min, new_max)
+    else:
+        amin = relay.op.round(relay.op.const(amin) / scale + z_p)
+        amax = relay.op.round(relay.op.const(amax) / scale + z_p)
+        out = relay.op.minimum(relay.op.maximum(arg, amin), amax)
+    return [out, t.scale, t.zero_point, t.dtype]
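The integer clip bounds fall directly out of the affine map; for example, with the uint8 ReLU6 pattern exercised by the tests below (scale=2.0, zero_point=114):

```python
scale, zero_point = 2.0, 114  # affine type of the clip input
a_min, a_max = 0.0, 6.0       # float clip bounds (ReLU6)

new_min = int(a_min / scale + zero_point)
new_max = int(a_max / scale + zero_point)
assert (new_min, new_max) == (114, 117)  # clip now runs on uint8 values
```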
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 20e8bb94c501..20e045abab6c 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -1171,3 +1171,31 @@ def AnnotateSpans():
         The regsistered AnnotateSpans pass.
     """
     return _ffi_api.AnnotateSpans()
+
+
+def FakeQuantizationToInteger():
+    # pylint: disable=anomalous-backslash-in-string
+    """
+    Find regions of the graph of the form
+
+    x    w
+    |    |
+    dq   dq
+     \   /
+      op1
+       |
+      op2
+       |
+       q
+
+    where q == qnn.quantize and dq == qnn.dequantize,
+    and rewrite them into integer versions of op1 and op2.
+
+    Rules for rewriting individual ops are in fake_quantization_to_integer.py
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered FakeQuantizationToInteger pass.
+    """
+    return _ffi_api.FakeQuantizationToInteger()
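Driving the pass looks like any other module-level transform; a minimal usage sketch mirroring the tests added below:

```python
import tvm
from tvm import relay

x = relay.var("x", shape=[1, 3, 8, 8], dtype="int8")
zero = relay.const(0)
op = relay.qnn.op.dequantize(x, relay.const(2.0), zero)
op = relay.op.nn.max_pool2d(op, [2, 2])
op = relay.qnn.op.quantize(op, relay.const(2.0), zero)

mod = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(op))
mod = tvm.relay.transform.FakeQuantizationToInteger()(mod)
# The dequantize/quantize pair is gone and max_pool2d now operates
# directly on the int8 tensor.
print(mod["main"])
```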
diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc
new file mode 100644
index 000000000000..1a3c459967bc
--- /dev/null
+++ b/src/relay/transforms/fake_quantization_to_integer.cc
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/fake_quantization_to_integer.cc
+ * \brief A pass for taking fake quantized graphs and converting them
+ *   to actual integer operations.
+ */
+
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+/* Description of FakeQuantizationToInteger
+ *
+ * The purpose of this pass is to find regions of the graph that follow
+ * the general pattern:
+ *
+ *   x    w
+ *   |    |
+ *   dq   dq
+ *    \   /
+ *     op1
+ *      |
+ *     op2
+ *      |
+ *      q
+ *
+ * and convert them into subgraphs with actual integer operations on x and w
+ *
+ * The pass does this via a multi-pass approach:
+ *
+ * The main pass is a MixedModeMutator that traverses the full graph searching for
+ * quantize operations
+ *
+ * The second pass is an ExprVisitor that recursively searches for the subgraph leading
+ * to the quantize, bounded by dequantize operations. This pass extracts the affine
+ * types of the inputs for later processing, where affine denotes the transformation
+ * x_real = (x_affine - zero_point) * scale
+ *
+ * The third pass is an ExprMutator that recursively rewrites the subgraphs using packed funcs
+ * registered with the FTVMFakeQuantizationToInteger attribute. These packed funcs rewrite
+ * the ops based on the affine types of their inputs and then return the affine types of the
+ * newly rewritten ops to pass that information down the stack during rewrite.
+ *
+ * After the second and third passes run, the first pass replaces the quantize with the
+ * rewritten subgraph and the processing continues
+ */
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief AffineType representation
+ * \sa AffineType
+ */
+class AffineTypeNode : public Object {
+ public:
+  /*! \brief The scale of this type */
+  Expr scale;
+  /*! \brief The zero point of this type */
+  Expr zero_point;
+  /*! \brief The data type of this type */
+  DataType dtype;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("scale", &scale);
+    v->Visit("zero_point", &zero_point);
+    v->Visit("dtype", &dtype);
+  }
+
+  bool SEqualReduce(const AffineTypeNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
+    return equal(scale, other->scale) && equal(zero_point, other->zero_point) &&
+           equal(dtype, other->dtype);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce->MarkGraphNode();
+    hash_reduce(scale);
+    hash_reduce(zero_point);
+    hash_reduce(dtype);
+  }
+
+  static constexpr const bool _type_has_method_sequal_reduce = true;
+  static constexpr const bool _type_has_method_shash_reduce = true;
+  static constexpr const char* _type_key = "AffineTypeNode";
+  TVM_DECLARE_BASE_OBJECT_INFO(AffineTypeNode, Object);
+};
+
+/*!
+ * \brief Managed reference to AffineTypes.
+ * \sa AffineTypeNode
+ */
+class AffineType : public ObjectRef {
+ public:
+  TVM_DLL AffineType(Expr scale, Expr zero_point, DataType dtype) {
+    ObjectPtr<AffineTypeNode> n = make_object<AffineTypeNode>();
+    n->scale = std::move(scale);
+    n->zero_point = std::move(zero_point);
+    n->dtype = std::move(dtype);
+    data_ = std::move(n);
+  }
+  TVM_DEFINE_OBJECT_REF_METHODS(AffineType, ObjectRef, AffineTypeNode);
+};
+
+TVM_REGISTER_NODE_TYPE(AffineTypeNode);
+
+using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;
+using ExprMap = std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual>;
+using AffineTypeMap = Map<Expr, AffineType>;
+
+using FTVMFakeQuantizationToInteger =
+    runtime::TypedPackedFunc<Array<ObjectRef>(const Expr& expr, const AffineTypeMap& map)>;
+
+class SubgraphExtractor : public ExprVisitor {
+ public:
+  const ExprSet GetSubgraph(const Expr& expr) {
+    VisitExpr(expr);
+    ExprSet subgraph;
+    if (is_fake_quantized_) {
+      for (auto kv : this->visit_counter_) {
+        if (auto call_node = GetRef<ObjectRef>(kv.first).as<CallNode>()) {
+          if (call_node->op != quantize_op_) {
+            subgraph.insert(Downcast<Expr>(GetRef<ObjectRef>(kv.first)));
+          }
+        }
+      }
+    }
+    return subgraph;
+  }
+  const AffineTypeMap GetAffineTypes() { return affine_types_; }
+  void VisitExpr(const Expr& expr) {
+    if (expr.as<CallNode>() == nullptr && expr.as<OpNode>() == nullptr &&
+        expr.as<TupleNode>() == nullptr) {
+      is_fake_quantized_ = false;
+    } else {
+      ExprVisitor::VisitExpr(expr);
+    }
+  }
+
+ protected:
+  void VisitExpr_(const CallNode* call_node) override {
+    if (call_node->op == quantize_op_) {
+      // Only look at arg0 for quantize
+      VisitExpr(call_node->args[0]);
+      // Collect type of quantize ops
+      affine_types_.Set(GetRef<Expr>(call_node),
+                        AffineType(call_node->args[1], call_node->args[2],
+                                   call_node->checked_type().as<TensorTypeNode>()->dtype));
+    } else if (call_node->op == dequantize_op_) {
+      // Collect type of dequantize ops
+      affine_types_.Set(GetRef<Expr>(call_node),
+                        AffineType(call_node->args[1], call_node->args[2],
+                                   call_node->args[0]->checked_type().as<TensorTypeNode>()->dtype));
+    } else {
+      // Run normally on everything else.
+      ExprVisitor::VisitExpr_(call_node);
+    }
+  }
+
+  const Op quantize_op_ = Op::Get("qnn.quantize");
+  const Op dequantize_op_ = Op::Get("qnn.dequantize");
+  bool is_fake_quantized_ = true;
+  AffineTypeMap affine_types_;
+};
+
+class SubgraphMutator : public ExprMutator {
+ public:
+  SubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types)
+      : subgraph_(subgraph), affine_types_(affine_types) {}
+
+  Expr MutateSubgraph(const Expr& expr) {
+    if (subgraph_.size() == 0) {
+      return expr;
+    }
+    const CallNode* quantize_node = expr.as<CallNode>();
+    ICHECK(quantize_node);
+    ICHECK(quantize_node->op == quantize_op_);
+    out_type_ = affine_types_[expr];
+    static auto fqfq =
+        Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
+    for (auto node : subgraph_) {
+      if (!fqfq.count(Downcast<Op>(node.as<CallNode>()->op))) {
+        // Only modify the subgraph if we have translation
+        // rules for every op
+        return expr;
+      }
+    }
+    return Mutate(expr);
+  }
+
+ protected:
+  Expr VisitExpr_(const CallNode* call_node) {
+    Expr out;
+
+    static auto fqfq =
+        Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
+    Op op = Downcast<Op>(call_node->op);
+    if (fqfq.count(op)) {
+      Expr expr;
+      if (op == dequantize_op_) {
+        expr = GetRef<Expr>(call_node);
+      } else {
+        expr = ExprMutator::VisitExpr_(call_node);
+        // Set the current op to the output type, useful if we can't deduce output parameters
+        // from input parameters
+        affine_types_.Set(expr, out_type_);
+      }
+      // Call the rewrite
+      Array<ObjectRef> vals = fqfq[op](expr, affine_types_);
+      // Save the outputs of the rewrite
+      ICHECK(vals.size() == 4)
+          << "got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for "
+          << AsText(op, false);
+      out = Downcast<Expr>(vals[0]);
+      affine_types_.Set(out, AffineType(Downcast<Expr>(vals[1]), Downcast<Expr>(vals[2]),
+                                        DataType(String2DLDataType(Downcast<String>(vals[3])))));
+    } else {
+      ICHECK(false) << "When rewriting a fake quantized graph, found an invalid node "
+                    << AsText(GetRef<Expr>(call_node), false);
+    }
+    return out;
+  }
+  ExprSet subgraph_;
+  AffineTypeMap affine_types_;
+  AffineType out_type_;
+  const Op quantize_op_ = Op::Get("qnn.quantize");
+  const Op dequantize_op_ = Op::Get("qnn.dequantize");
+};
+
+class FakeQuantizationRewriter : public MixedModeMutator {
+ protected:
+  Expr Rewrite_(const CallNode* pre, const Expr& post) override {
+    if (const CallNode* call_node = post.as<CallNode>()) {
+      if (call_node->op == quantize_op_) {
+        SubgraphExtractor extractor;
+        ExprSet subgraph = extractor.GetSubgraph(GetRef<Expr>(pre));
+        AffineTypeMap affine_types = extractor.GetAffineTypes();
+
+        ExprSet post_subgraph;
+        AffineTypeMap post_affine_types;
+
+        for (auto kv : affine_types) {
+          if (pre == kv.first.as<CallNode>()) {
+            // we haven't memoized the current op yet
+            post_affine_types.Set(post, kv.second);
+          } else {
+            post_affine_types.Set(memo_.at(kv.first), kv.second);
+          }
+        }
+        for (auto expr : subgraph) {
+          post_subgraph.insert(memo_[expr]);
+        }
+        Expr out = SubgraphMutator(post_subgraph, post_affine_types).MutateSubgraph(post);
+        return out;
+      }
+    }
+    return post;
+  }
+  const Op quantize_op_ = Op::Get("qnn.quantize");
+};
+
+Expr FakeQuantizationToInteger(const Expr& expr, const IRModule& mod) {
+  return FakeQuantizationRewriter().Mutate(expr);
+}
+
+namespace transform {
+
+Pass FakeQuantizationToInteger() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(FakeQuantizationToInteger(f, m));
+      };
+  return CreateFunctionPass(pass_func, 0, "FakeQuantizationToInteger", {"InferType"});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.FakeQuantizationToInteger")
+    .set_body_typed(FakeQuantizationToInteger);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py
new file mode 100644
index 000000000000..3271379cf3ef
--- /dev/null
+++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py
@@ -0,0 +1,279 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-wildcard-import
+import numpy as np
+import pytest
+
+import tvm
+from tvm import relay
+
+
+def test_fake_quantize_conv():
+    for out_dtype in ["int8", "uint8"]:
+        x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
+        w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8")
+        one = relay.const(1.0)
+        zero = relay.const(0)
+
+        op = relay.op.nn.conv2d(
+            relay.qnn.op.dequantize(x, relay.const(2.0), zero),
+            relay.qnn.op.dequantize(w, relay.const(0.5), zero),
+        )
+        op = relay.qnn.op.quantize(op, one, zero, out_dtype=out_dtype)
+
+        mod = tvm.IRModule.from_expr(op)
+        mod = tvm.relay.transform.InferType()(mod)
+
+        x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8")
+        w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8")
+
+        mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod)
+        assert not tvm.ir.structural_equal(mod, mod2)
+        mod2 = tvm.relay.transform.FoldConstant()(mod2)
+
+        ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(x_np, w_np).asnumpy()
+
+        ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm")
+        result2 = ex.evaluate()(x_np, w_np).asnumpy()
+
+        assert np.array_equal(result, result2)
+
+
+def test_fake_transpose_quantize_conv():
+    x = relay.var("x", shape=[1, 224, 224, 3], dtype="int8")
+    w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8")
+    one = relay.const(1.0)
+    zero = relay.const(0)
+
+    x = relay.qnn.op.dequantize(x, relay.const(2.0), zero)
+    x = relay.transpose(x, [0, 3, 1, 2])
+    op = relay.op.nn.conv2d(x, relay.qnn.op.dequantize(w, relay.const(0.5), zero))
+    op = relay.qnn.op.quantize(op, one, zero)
+
+    mod = tvm.IRModule.from_expr(op)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8")
+    w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8")
+
+    mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod)
+    assert not tvm.ir.structural_equal(mod, mod2)
+    mod2 = tvm.relay.transform.FoldConstant()(mod2)
+
+    ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm")
+    result = ex.evaluate()(x_np, w_np).asnumpy()
+
+    ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm")
+    result2 = ex.evaluate()(x_np, w_np).asnumpy()
+
+    assert np.array_equal(result, result2)
+
+
+def test_fake_transpose_quantize_conv_bias_add():
+    x = relay.var("x", shape=[1, 224, 224, 3], dtype="int8")
+    w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8")
+    bias = relay.var("bias", shape=[16], dtype="int32")
+    one = relay.const(1.0)
+    zero = relay.const(0)
+
+    x = relay.qnn.op.dequantize(x, relay.const(2.0), zero)
+    x = relay.transpose(x, [0, 3, 1, 2])
+    op = relay.op.nn.conv2d(x, relay.qnn.op.dequantize(w, relay.const(0.5), zero))
+    op = relay.op.nn.bias_add(op, relay.qnn.op.dequantize(bias, one, zero))
+    op = relay.qnn.op.quantize(op, one, zero)
+
+    mod = tvm.IRModule.from_expr(op)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8")
+    w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8")
+    bias_np = np.random.randint(-32768, 32767, size=[16], dtype="int32")
+
+    mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod)
+    assert not tvm.ir.structural_equal(mod, mod2)
+    mod2 = tvm.relay.transform.FoldConstant()(mod2)
+
+    ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm")
+    result = ex.evaluate()(x_np, w_np, bias_np).asnumpy()
+
+    ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm")
+    result2 = ex.evaluate()(x_np, w_np, bias_np).asnumpy()
+
+    assert np.array_equal(result, result2)
+
+
+def test_fake_quantize_maxpool():
+    x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
+
+    zero = relay.const(0)
+    x = relay.qnn.op.dequantize(x, relay.const(2.0), zero)
+    op = relay.op.nn.max_pool2d(x, [3, 3])
+    op = relay.qnn.op.quantize(op, relay.const(2.0), zero)
+
+    mod = tvm.IRModule.from_expr(op)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8")
+
+    mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod)
+    assert not tvm.ir.structural_equal(mod, mod2)
+    mod2 = tvm.relay.transform.FoldConstant()(mod2)
+
+    ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm")
+    result = ex.evaluate()(x_np).asnumpy()
+
+    ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm")
+    result2 = ex.evaluate()(x_np).asnumpy()
+
+    assert np.array_equal(result, result2)
+
+
+def test_fake_quantize_avgpool():
+    x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
+
+    zero = relay.const(0)
+    x = relay.qnn.op.dequantize(x, relay.const(2.0), zero)
+    op = relay.op.nn.avg_pool2d(x, [3, 3])
+    op = relay.qnn.op.quantize(op, relay.const(2.0), zero)
+
+    mod = tvm.IRModule.from_expr(op)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8")
+
+    mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod)
+    assert not tvm.ir.structural_equal(mod, mod2)
+    mod2 = tvm.relay.transform.FoldConstant()(mod2)
+
+    ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm")
+    result = ex.evaluate()(x_np).asnumpy()
+
+    ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm")
+    result2 = ex.evaluate()(x_np).asnumpy()
+
+    assert np.all(np.abs(result - result2) <= 1)
+
+
+def test_fake_quantize_reshape():
+    x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
+
+    zero = relay.const(0)
+    x = relay.qnn.op.dequantize(x, relay.const(2.0), zero)
+    op = relay.op.reshape(x, [1, 3, -1])
+    op = relay.qnn.op.quantize(op, relay.const(2.0), zero)
+
+    mod = tvm.IRModule.from_expr(op)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8")
+
+    mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod)
+    assert not tvm.ir.structural_equal(mod, mod2)
+    mod2 = tvm.relay.transform.FoldConstant()(mod2)
+
+    ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm")
+    result = ex.evaluate()(x_np).asnumpy()
+
+    ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm")
+    result2 = ex.evaluate()(x_np).asnumpy()
+
+    assert np.array_equal(result, result2)
+
+
+def test_fake_quantize_transpose_reshape():
+    x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
+
+    zero = relay.const(0)
+    x = relay.qnn.op.dequantize(x, relay.const(2.0), zero)
+    op = relay.op.transpose(x, [1, 0, 2, 3])
+    op = relay.op.reshape(op, [3, -1])
+    op = relay.qnn.op.quantize(op, relay.const(2.0), zero)
+
+    mod = tvm.IRModule.from_expr(op)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8")
+
+    mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod)
+    assert not tvm.ir.structural_equal(mod, mod2)
+    mod2 = tvm.relay.transform.FoldConstant()(mod2)
+
+    ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm")
+    result = ex.evaluate()(x_np).asnumpy()
+
+    ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm")
+    result2 = ex.evaluate()(x_np).asnumpy()
+
+    assert np.array_equal(result, result2)
+
+
+def test_fake_quantize_concat():
+    zero = relay.const(0)
+    inputs = []
+    for i in range(4):
+        inputs.append(
+            relay.qnn.op.dequantize(
+                relay.var("x%d" % i, shape=[1, 4], dtype="int8"), relay.const(i + 0.5), zero
+            )
+        )
+    concat = relay.op.concatenate(inputs, axis=1)
+    out = relay.qnn.op.quantize(concat, relay.const(3.5), zero)
+
+    mod = tvm.IRModule.from_expr(out)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    inputs_np = []
+    for i in range(4):
+        inputs_np.append(np.random.randint(-128, 127, size=[1, 4], dtype="int8"))
+
+    mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod)
+    assert not tvm.ir.structural_equal(mod, mod2)
+    mod2 = tvm.relay.transform.FoldConstant()(mod2)
+
+    ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm")
+    result = ex.evaluate()(*inputs_np).asnumpy()
+
+    ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm")
+    result2 = ex.evaluate()(*inputs_np).asnumpy()
+
+    assert np.array_equal(result, result2)
+
+
+def test_fake_quantize_clip():
+    x = relay.var("x", shape=[1, 3, 224, 224], dtype="uint8")
+
+    x = relay.qnn.op.dequantize(x, relay.const(2.0), relay.const(114))
+    op = relay.op.clip(x, 0, 6)
+    op = relay.qnn.op.quantize(op, relay.const(2.0), relay.const(114), out_dtype="uint8")
+
+    mod = tvm.IRModule.from_expr(op)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    x_np = np.random.randint(0, 255, size=[1, 3, 224, 224], dtype="uint8")
+
+    mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod)
+    assert not tvm.ir.structural_equal(mod, mod2)
+    mod2 = tvm.relay.transform.FoldConstant()(mod2)
+
+    ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm")
+    result = ex.evaluate()(x_np).asnumpy()
+
+    ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm")
+    result2 = ex.evaluate()(x_np).asnumpy()
+
+    assert np.array_equal(result, result2)
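One property worth keeping in mind, sketched here as a hypothetical follow-up test rather than part of this patch: because `MutateSubgraph` bails out when any op in the region lacks a registered rule, fake-quantized regions containing unsupported ops such as `nn.softmax` should be left untouched.

```python
import tvm
from tvm import relay


def test_fake_quantize_unsupported_op_left_alone():
    x = relay.var("x", shape=[1, 16], dtype="int8")
    zero = relay.const(0)
    op = relay.qnn.op.dequantize(x, relay.const(2.0), zero)
    op = relay.op.nn.softmax(op)  # no FTVMFakeQuantizationToInteger rule in this patch
    op = relay.qnn.op.quantize(op, relay.const(2.0), zero)

    mod = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(op))
    mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod)
    # MutateSubgraph returns the expression unchanged when any op in the
    # region has no registered rewrite rule.
    assert tvm.ir.structural_equal(mod, mod2)
```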