Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Quantization] Fix add_rewrite and UnifyDTypeScale #3534

Merged
merged 4 commits into from
Jul 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions python/tvm/relay/quantize/_annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,13 +260,11 @@ def add_rewrite(ref_call, new_args, ctx):
if isinstance(rhs_expr, _expr.Constant):
# quantize rhs to WEIGHT field if it is Constant
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
assert lhs_kind == QAnnotateKind.ACTIVATION
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
else:
# quantize rhs to INPUT field if it is not Constant
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
raise ValueError
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)

if lhs_kind is not None and rhs_kind is not None:
if lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.INPUT:
Expand All @@ -277,6 +275,10 @@ def add_rewrite(ref_call, new_args, ctx):
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT:
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
raise ValueError()


@register_annotate_function("stop_fusion")
Expand Down
47 changes: 19 additions & 28 deletions src/relay/pass/quantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,22 +135,6 @@ TVM_REGISTER_API("relay._quantize.make_annotate_expr")
});


// Registers the packed function "relay._quantize.annotate": takes an Expr and
// returns the result of ForwardRewrite driven by the "FQAnnotateRewrite" attribute.
TVM_REGISTER_API("relay._quantize.annotate")
.set_body_typed<Expr(Expr)>([] (const Expr& expr) {
// Callback that re-annotates a temporary annotated expression as kQInput.
// NOTE(review): fmulti_ref is constructed here but is not passed to
// ForwardRewrite below (both trailing arguments are nullptr), so it has no
// effect in this body.
std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
if (e->derived_from<TempExprNode>()) {
// Any TempExprNode reaching this point is required to be a QAnnotateExprNode.
const auto* n = e.as<QAnnotateExprNode>();
CHECK(n);
// Look up the globally registered attach_simulated_quantize function and
// apply it to the wrapped expression with kind kQInput.
const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
}
// Expressions not derived from TempExprNode are returned unchanged.
return e;
};
return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, nullptr);
});


// =============
// realize pass

Expand Down Expand Up @@ -395,10 +379,9 @@ float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {


/* \brief Unify the dom scale of arguments */
Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args,
const Array<Expr>& args,
DataType* dtype_ptr,
Expr* scale_ptr) {
Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args,
DataType* dtype_ptr, Expr* scale_ptr) {
static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
const QConfig& cfg = QConfig::Current();

std::vector<const QRealizeIntExprNode*> nptrs;
Expand All @@ -413,14 +396,21 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args,
// unify the data type
CHECK_EQ(ref_args.size(), args.size());
DataType dtype;
if (nptrs[0]->dtype == cfg->dtype_activation) {
DataType dtype = cfg->dtype_activation;
ret.Set(1, Cast(ret[1], dtype));
} else if (nptrs[1]->dtype == cfg->dtype_input) {
DataType dtype = cfg->dtype_input;
ret.Set(0, Cast(ret[0], dtype));
if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) {
dtype = cfg->dtype_input;
} else {
LOG(FATAL) << "should not touch here.";
dtype = cfg->dtype_activation;
}
for (size_t i = 0; i < ret.size(); ++i) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @vinx13, I guess we can remove this part now that we have rewrite_for_vta, which will insert cast and stop_fusion at the end of the residual block.

Copy link
Member Author

@vinx13 vinx13 Jul 18, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ZihengJiang This part is needed when rewrite_for_vta is not enabled (store_lowbit_output=False). Or can we remove this part in rewrite_for_vta?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vinx13 The current Realize pass is used to achieve multiple goals: inserting stop_fusion, deciding the datatype, unifying scales for some operators, etc. It would be better to decouple these concerns into separate passes. I would also like to make rewrite_for_vta and store_lowbit_output enabled by default.

auto ref_arg = ref_args[i].as<CallNode>();
if (nptrs[i]->dtype != dtype) {
ret.Set(i, Cast(ret[i], dtype));
} else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
auto new_arg = Cast(ret[i], cfg->dtype_input);
new_arg = StopFusion(new_arg);
ret.Set(i, Cast(new_arg, dtype));
}
}

// unify the dom_scale
Expand All @@ -447,6 +437,7 @@ Expr AddRealize(const Call& ref_call,
Expr ret = ForwardOp(ref_call, ret_args);
return QRealizeIntExprNode::make(ret, dom_scale, dtype);
}

CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
return Expr(nullptr);
}
Expand Down Expand Up @@ -674,7 +665,7 @@ Pass QuantizeAnnotate() {

runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", fmulti_ref));
auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref));
auto new_params = func->params;
for (const auto& x : FreeVars(func)) {
new_params.push_back(x);
Expand Down