Skip to content

Commit

Permalink
[VTA] quant support for alu-only op (apache#6191)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghaohit authored and trevor-m committed Dec 4, 2020
1 parent d875587 commit e34fa32
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 8 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/quantize/_annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ def identity_rewrite(ref_call, new_args, ctx):
return QAnnotateExpr(ret_expr, x_kind)


register_annotate_function("reshape", identity_rewrite)
register_annotate_function("clip", identity_rewrite)
register_annotate_function("nn.relu", identity_rewrite)
register_annotate_function("strided_slice", identity_rewrite)
Expand Down
15 changes: 14 additions & 1 deletion python/tvm/relay/quantize/_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def add_partition_generic(ref_call, new_args, ctx):
# ...
lhs = new_args[0].realize()
rhs = new_args[1].realize()
return _forward_op(ref_call, [lhs, rhs])
return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
if not lhs_cond and rhs_cond:
# - introduced by residual connection in ResNet
# ...
Expand Down Expand Up @@ -130,6 +130,7 @@ def mul_partition_generic(ref_call, new_args, ctx):

if lhs_cond:
# introduced by bn: multiply(out, scale)
lhs = new_args[0].realize()
return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))

if not lhs_cond and not rhs_cond:
Expand All @@ -155,3 +156,15 @@ def add_partition_function(ref_call, new_args, ctx):
def multiply_partition_function(ref_call, new_args, ctx):
"""Rewrite function for ewise multiply for partition"""
return mul_partition_generic(ref_call, new_args, ctx)


# add cast after the relu op to make it run on vta
@register_partition_function("nn.global_avg_pool2d")
def global_avg_pool2d_partition_function(ref_call, new_args, ctx):
cond, expr = partition_expr_check(new_args[0])
if cond:
expr = new_args[0].realize()
else:
expr = QPartitionExpr(new_args[0]).realize()

return _forward_op(ref_call, [expr])
28 changes: 21 additions & 7 deletions src/relay/quantize/realize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,8 @@ float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {

/* \brief Unify the dom scale of arguments */
Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args,
DataType* dtype_ptr, Expr* scale_ptr) {
DataType* dtype_ptr, Expr* scale_ptr,
DataType dtype = DataType::Void()) {
static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
const QConfig& cfg = QConfig::Current();

Expand All @@ -324,13 +325,15 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args

// unify the data type
ICHECK_EQ(ref_args.size(), args.size());
DataType dtype;

if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) {
dtype = cfg->dtype_input;
} else {
dtype = cfg->dtype_activation;
if (dtype.is_void()) {
if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) {
dtype = cfg->dtype_input;
} else {
dtype = cfg->dtype_activation;
}
}

for (size_t i = 0; i < ret.size(); ++i) {
auto ref_arg = ref_args[i].as<CallNode>();
if (nptrs[i]->dtype != dtype) {
Expand Down Expand Up @@ -361,7 +364,16 @@ Expr AddRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectR
if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
DataType dtype;
Expr dom_scale;
Array<Expr> ret_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale);
// execute the operation with activation data type.
const QConfig& cfg = QConfig::Current();
Array<Expr> ret_args =
UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale, cfg->dtype_activation);
for (size_t i = 0; i < ret_args.size(); ++i) {
// do not fuse float32 arg
if (new_args[i].as<QRealizeIntExprNode>()->dtype == DataType::Float(32)) {
ret_args.Set(i, StopFusion(ret_args[i]));
}
}
Expr ret = ForwardOp(ref_call, ret_args);
return QRealizeIntExpr(ret, dom_scale, dtype);
}
Expand Down Expand Up @@ -430,6 +442,8 @@ Expr IdentityRealize(const Call& ref_call, const Array<Expr>& new_args, const Ob

RELAY_REGISTER_OP("nn.relu").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);

RELAY_REGISTER_OP("reshape").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);

RELAY_REGISTER_OP("strided_slice").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);

RELAY_REGISTER_OP("nn.batch_flatten")
Expand Down

0 comments on commit e34fa32

Please sign in to comment.