[VTA] quant support for alu-only op #6191

Merged (2 commits) on Oct 29, 2020
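For context, the ops this PR touches (add, nn.relu, nn.global_avg_pool2d, reshape) can be exercised end to end with a sketch like the following. It is illustrative only and not taken from the PR; the shapes, constants, and qconfig values are made up.

```python
import numpy as np
import tvm
from tvm import relay

# A small graph whose tail is ALU-style ops only:
# conv2d -> add -> relu -> global_avg_pool2d -> reshape
data = relay.var("data", shape=(1, 16, 14, 14), dtype="float32")
weight = relay.const(np.random.uniform(-1, 1, (16, 16, 3, 3)).astype("float32"))
bias = relay.const(np.random.uniform(-1, 1, (1, 16, 1, 1)).astype("float32"))

conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=16)
body = relay.nn.relu(relay.add(conv, bias))
body = relay.nn.global_avg_pool2d(body)
body = relay.reshape(body, (1, 16))
mod = tvm.IRModule.from_expr(relay.Function([data], body))

# Quantize with a global scale; a real VTA flow would set dtype_input,
# dtype_weight and dtype_activation in qconfig explicitly.
with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[]):
    qmod = relay.quantize.quantize(mod)
print(qmod["main"])
```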
1 change: 1 addition & 0 deletions python/tvm/relay/quantize/_annotate.py
@@ -284,6 +284,7 @@ def identity_rewrite(ref_call, new_args, ctx):
    return QAnnotateExpr(ret_expr, x_kind)


register_annotate_function("reshape", identity_rewrite)
register_annotate_function("clip", identity_rewrite)
register_annotate_function("nn.relu", identity_rewrite)
register_annotate_function("strided_slice", identity_rewrite)
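For reference, `identity_rewrite` (whose tail is visible in the hunk above) simply forwards the call and keeps the annotation kind of its input, so registering `reshape` against it lets annotation pass straight through reshape. A sketch of its shape, reconstructed from the surrounding file (names as defined in `_annotate.py`; details may differ from the actual code):

```python
from tvm.relay.quantize._annotate import QAnnotateExpr, _get_expr_kind
from tvm.relay.quantize.quantize import _forward_op, quantize_context


def identity_rewrite(ref_call, new_args, ctx):
    """Forward the op unchanged, preserving the input's annotation kind."""
    if quantize_context().check_to_skip(ref_call):
        return None

    # an un-annotated input means there is nothing to propagate
    x_expr, x_kind = _get_expr_kind(new_args[0])
    if x_kind is None:
        return None

    ret_expr = _forward_op(ref_call, [x_expr])
    return QAnnotateExpr(ret_expr, x_kind)
```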
15 changes: 14 additions & 1 deletion python/tvm/relay/quantize/_partition.py
@@ -82,7 +82,7 @@ def add_partition_generic(ref_call, new_args, ctx):
        # ...
        lhs = new_args[0].realize()
        rhs = new_args[1].realize()
-        return _forward_op(ref_call, [lhs, rhs])
+        return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
    if not lhs_cond and rhs_cond:
        # - introduced by residual connection in ResNet
        # ...
@@ -130,6 +130,7 @@ def mul_partition_generic(ref_call, new_args, ctx):

    if lhs_cond:
        # introduced by bn: multiply(out, scale)
        lhs = new_args[0].realize()
        return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))

    if not lhs_cond and not rhs_cond:
@@ -155,3 +156,15 @@ def add_partition_function(ref_call, new_args, ctx):
def multiply_partition_function(ref_call, new_args, ctx):
    """Rewrite function for ewise multiply for partition"""
    return mul_partition_generic(ref_call, new_args, ctx)


# add cast after the relu op to make it run on vta
@register_partition_function("nn.global_avg_pool2d")
def global_avg_pool2d_partition_function(ref_call, new_args, ctx):
    """Rewrite function for nn.global_avg_pool2d for partition"""
    cond, expr = partition_expr_check(new_args[0])
    if cond:
        expr = new_args[0].realize()
    else:
        expr = QPartitionExpr(new_args[0]).realize()

    return _forward_op(ref_call, [expr])
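The new partition function above forces a partition boundary at the input of `nn.global_avg_pool2d`: the argument is realized (or wrapped in `QPartitionExpr` first and then realized), which roughly corresponds to inserting the cast-hint/stop-fusion pair that ends the quantized region, so the ALU ops feeding the pool stay inside the VTA partition. The same boilerplate could in principle be reused for other ops that should sit on such a boundary; the sketch below is purely hypothetical (`nn.max_pool2d` is not registered this way by the PR):

```python
from tvm.relay.quantize._partition import (
    QPartitionExpr,
    partition_expr_check,
    register_partition_function,
)
from tvm.relay.quantize.quantize import _forward_op


# Hypothetical: put a partition boundary before nn.max_pool2d as well.
@register_partition_function("nn.max_pool2d")
def max_pool2d_partition_function(ref_call, new_args, ctx):
    """Realize the pool input so the ops feeding it end the quantized region."""
    cond, expr = partition_expr_check(new_args[0])
    if cond:
        # the input is already a partition expression: just realize it
        expr = new_args[0].realize()
    else:
        # otherwise wrap it first so realize() inserts the boundary annotations
        expr = QPartitionExpr(new_args[0]).realize()
    return _forward_op(ref_call, [expr])
```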
28 changes: 21 additions & 7 deletions src/relay/quantize/realize.cc
@@ -309,7 +309,8 @@ float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {

/* \brief Unify the dom scale of arguments */
Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args,
-                            DataType* dtype_ptr, Expr* scale_ptr) {
+                            DataType* dtype_ptr, Expr* scale_ptr,
+                            DataType dtype = DataType::Void()) {
  static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
  const QConfig& cfg = QConfig::Current();

@@ -324,13 +325,15 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args

  // unify the data type
  ICHECK_EQ(ref_args.size(), args.size());
-  DataType dtype;
-  if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) {
-    dtype = cfg->dtype_input;
-  } else {
-    dtype = cfg->dtype_activation;
-  }
+
+  if (dtype.is_void()) {
+    if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) {
+      dtype = cfg->dtype_input;
+    } else {
+      dtype = cfg->dtype_activation;
+    }
+  }

  for (size_t i = 0; i < ret.size(); ++i) {
    auto ref_arg = ref_args[i].as<CallNode>();
    if (nptrs[i]->dtype != dtype) {
@@ -361,7 +364,16 @@ Expr AddRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectR
  if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
    DataType dtype;
    Expr dom_scale;
-    Array<Expr> ret_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale);
+    // execute the operation with activation data type.
+    const QConfig& cfg = QConfig::Current();
+    Array<Expr> ret_args =
+        UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale, cfg->dtype_activation);
+    for (size_t i = 0; i < ret_args.size(); ++i) {
+      // do not fuse float32 arg
+      if (new_args[i].as<QRealizeIntExprNode>()->dtype == DataType::Float(32)) {
+        ret_args.Set(i, StopFusion(ret_args[i]));
+      }
+    }
    Expr ret = ForwardOp(ref_call, ret_args);
    return QRealizeIntExpr(ret, dom_scale, dtype);
  }
@@ -430,6 +442,8 @@ Expr IdentityRealize(const Call& ref_call, const Array<Expr>& new_args, const Ob

RELAY_REGISTER_OP("nn.relu").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);

RELAY_REGISTER_OP("reshape").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);

RELAY_REGISTER_OP("strided_slice").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);

RELAY_REGISTER_OP("nn.batch_flatten")
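The realize.cc changes make `AddRealize` pass `cfg->dtype_activation` into `UnifyDTypeScale`, so an elementwise add between realized quantized values is computed in the activation dtype (int32 with the default qconfig) and any float32 argument is excluded from fusion. A rough way to observe this from Python, reusing the `qmod` built in the sketch near the top of this page (`collect_add_dtypes` is a helper written here, not a TVM API):

```python
import tvm
from tvm import relay


def collect_add_dtypes(mod):
    """Return the checked dtypes of every `add` call in the main function."""
    mod = relay.transform.InferType()(mod)
    dtypes = []

    def visit(node):
        if (
            isinstance(node, relay.Call)
            and isinstance(node.op, tvm.ir.Op)
            and node.op.name == "add"
        ):
            dtypes.append(node.checked_type.dtype)

    relay.analysis.post_order_visit(mod["main"].body, visit)
    return dtypes


# With this change, the add inside the quantized region is expected to show up
# in the activation dtype (int32 with the default qconfig).
print(collect_add_dtypes(qmod))
```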