From db2550ce3e52540e8b22c9b6b5938201f213b539 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 8 Jul 2019 05:58:36 +0000 Subject: [PATCH 1/4] [Relay][Quantization] Fix issue introduced in #3135 --- python/tvm/relay/quantize/_annotate.py | 1 - src/relay/pass/quantize.cc | 15 ++++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index 90bb2d08a8ed..0ab5a9594141 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -260,7 +260,6 @@ def add_rewrite(ref_call, new_args, ctx): if isinstance(rhs_expr, _expr.Constant): # quantize rhs to WEIGHT field if it is Constant rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) - assert lhs_kind == QAnnotateKind.ACTIVATION expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) else: diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index 8220ca6b3bab..8842421d4885 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -413,14 +413,15 @@ Array UnifyDTypeScale(const Array& ref_args, // unify the data type CHECK_EQ(ref_args.size(), args.size()); DataType dtype; - if (nptrs[0]->dtype == cfg->dtype_activation) { - DataType dtype = cfg->dtype_activation; - ret.Set(1, Cast(ret[1], dtype)); - } else if (nptrs[1]->dtype == cfg->dtype_input) { - DataType dtype = cfg->dtype_input; - ret.Set(0, Cast(ret[0], dtype)); + if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) { + dtype = cfg->dtype_input; } else { - LOG(FATAL) << "should not touch here."; + dtype = cfg->dtype_activation; + } + for (size_t i = 0; i < ret.size(); ++i) { + if (nptrs[i]->dtype != dtype) { + ret.Set(i, Cast(ret[i], dtype)); + } } // unify the dom_scale From 7db9e2dec520ac21a4ef37cdef50295ac03a7641 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 11 Jul 2019 05:06:00 +0000 Subject: [PATCH 2/4] Recover StopFusion --- src/relay/pass/quantize.cc | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index 8842421d4885..699410a2a36b 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -395,10 +395,9 @@ float ChooseDomScale(const std::vector& nptrs) { /* \brief Unify the dom scale of arguments */ -Array UnifyDTypeScale(const Array& ref_args, - const Array& args, - DataType* dtype_ptr, - Expr* scale_ptr) { +Array UnifyDTypeScale(const Array& ref_args, const Array& args, + DataType* dtype_ptr, Expr* scale_ptr) { + static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize"); const QConfig& cfg = QConfig::Current(); std::vector nptrs; @@ -419,8 +418,14 @@ Array UnifyDTypeScale(const Array& ref_args, dtype = cfg->dtype_activation; } for (size_t i = 0; i < ret.size(); ++i) { + auto ref_arg = ref_args[i].as(); if (nptrs[i]->dtype != dtype) { ret.Set(i, Cast(ret[i], dtype)); + } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) && + ref_arg->attrs.as()->kind == kQInput) { + auto new_arg = Cast(ret[i], cfg->dtype_input); + new_arg = StopFusion(new_arg); + ret.Set(i, Cast(new_arg, dtype)); } } From 173d8d35cbdf19cb4bcd8ded497e116a95048901 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 11 Jul 2019 08:08:02 +0000 Subject: [PATCH 3/4] Fix fmultiref --- python/tvm/relay/quantize/_annotate.py | 7 ++++++- src/relay/pass/quantize.cc | 19 ++----------------- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index 0ab5a9594141..8aadedc59913 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -265,7 +265,8 @@ def add_rewrite(ref_call, new_args, ctx): else: # quantize rhs to INPUT field if it is not Constant rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) - raise ValueError + expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) + return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) if lhs_kind is not None and rhs_kind is not None: if lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.INPUT: @@ -276,6 +277,10 @@ def add_rewrite(ref_call, new_args, ctx): rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) + if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT: + expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) + return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) + raise ValueError() @register_annotate_function("stop_fusion") diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index 699410a2a36b..83d9220ccf79 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -135,22 +135,6 @@ TVM_REGISTER_API("relay._quantize.make_annotate_expr") }); -TVM_REGISTER_API("relay._quantize.annotate") -.set_body_typed([] (const Expr& expr) { - std::function fmulti_ref = [](const Expr& e) { - if (e->derived_from()) { - const auto* n = e.as(); - CHECK(n); - const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize"); - Expr ret = (*f)(n->expr, static_cast(kQInput)); - return static_cast(QAnnotateExprNode::make(ret, kQInput)); - } - return e; - }; - return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, nullptr); -}); - - // ============= // realize pass @@ -453,6 +437,7 @@ Expr AddRealize(const Call& ref_call, Expr ret = ForwardOp(ref_call, ret_args); return QRealizeIntExprNode::make(ret, dom_scale, dtype); } + CHECK(!new_args[0]->derived_from() && !new_args[1]->derived_from()); return Expr(nullptr); } @@ -680,7 +665,7 @@ Pass QuantizeAnnotate() { runtime::TypedPackedFunc pass_func = [=](Function f, Module m, PassContext pc) { - auto func = Downcast(ForwardRewrite(f, "FQAnnotateRewrite", fmulti_ref)); + auto func = Downcast(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref)); auto new_params = func->params; for (const auto& x : FreeVars(func)) { new_params.push_back(x); From f5bf1de37731706073eb629cb4a2e3bbc4e92a9a Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 11 Jul 2019 12:54:51 +0000 Subject: [PATCH 4/4] Fix lint --- python/tvm/relay/quantize/_annotate.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index 8aadedc59913..7b7f9c42f2f1 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -260,13 +260,11 @@ def add_rewrite(ref_call, new_args, ctx): if isinstance(rhs_expr, _expr.Constant): # quantize rhs to WEIGHT field if it is Constant rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) - expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) - return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) else: # quantize rhs to INPUT field if it is not Constant rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) - expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) - return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) + expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) + return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) if lhs_kind is not None and rhs_kind is not None: if lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.INPUT: