Skip to content

Commit

Permalink
making quantization tweaks (#6731)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmoreau89 authored Nov 7, 2020
1 parent 89ce1ed commit ff9c480
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 0 deletions.
43 changes: 43 additions & 0 deletions python/tvm/relay/quantize/_annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,28 @@ def conv2d_rewrite(ref_call, new_args, ctx):
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)


@register_annotate_function("nn.conv1d")
def conv1d_rewrite(ref_call, new_args, ctx):
"""Rewrite function for conv1d. Lhs of conv will be quantized to
input field, and rhs of conv will be quantized to weight field.
Output would be in activation field"""
if quantize_context().check_to_skip(ref_call):
return None

lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)

assert rhs_kind is None
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)

expr = _forward_op(ref_call, [lhs_expr, rhs_expr])

return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)


@register_annotate_function("nn.dense")
def dense_rewrite(ref_call, new_args, ctx):
"""Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of
Expand Down Expand Up @@ -289,6 +311,8 @@ def identity_rewrite(ref_call, new_args, ctx):
register_annotate_function("nn.relu", identity_rewrite)
register_annotate_function("strided_slice", identity_rewrite)
register_annotate_function("nn.avg_pool2d", identity_rewrite)
register_annotate_function("nn.batch_flatten", identity_rewrite)
register_annotate_function("transpose", identity_rewrite)
register_annotate_function("annotation.stop_fusion", identity_rewrite)


Expand All @@ -311,6 +335,25 @@ def pool2d_rewrite(ref_call, new_args, ctx):
register_annotate_function("nn.max_pool2d", pool2d_rewrite)


def pool1d_rewrite(ref_call, new_args, ctx):
"""Rewrite function for max pool1d"""
if quantize_context().check_to_skip(ref_call):
return None

expr, x_kind = _get_expr_kind(new_args[0])

if x_kind is None:
return None
if x_kind == QAnnotateKind.ACTIVATION:
expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT)

expr = _forward_op(ref_call, [expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)


register_annotate_function("nn.max_pool1d", pool1d_rewrite)


@register_annotate_function("annotation.cast_hint")
def cast_hint_rewrite(ref_call, new_args, ctx):
"""Rewrite function to force cast"""
Expand Down
36 changes: 36 additions & 0 deletions src/relay/quantize/realize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,37 @@ Expr Conv2dRealize(const Call& ref_call, const Array<Expr>& new_args, const Obje

RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);

Expr Conv1dRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
const QConfig& cfg = QConfig::Current();
CHECK_EQ(new_args.size(), 2);
if (!new_args[0]->IsInstance<TempExprNode>() && !new_args[1]->IsInstance<TempExprNode>()) {
return Expr(nullptr);
}
const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
CHECK(lhs);
const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
CHECK(rhs);

Expr ldata = lhs->data;
if (lhs->dtype != cfg->dtype_input) {
ldata = Cast(ldata, cfg->dtype_input);
}
Expr rdata = Cast(rhs->data, cfg->dtype_weight);

const auto ref_attrs = ref_call->attrs.as<Conv1DAttrs>();
auto attrs = make_object<Conv1DAttrs>();
*attrs = *ref_attrs;
DataType out_dtype = cfg->dtype_activation;
attrs->out_dtype = out_dtype;

Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args);
Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
Expr dom_scale = FoldConstantOpt(mul);
return QRealizeIntExpr(ret, dom_scale, out_dtype);
}

RELAY_REGISTER_OP("nn.conv1d").set_attr<FForwardRewrite>("FQRealizeRewrite", Conv1dRealize);

Expr DenseRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
const QConfig& cfg = QConfig::Current();
ICHECK_EQ(new_args.size(), 2);
Expand Down Expand Up @@ -449,6 +480,8 @@ RELAY_REGISTER_OP("strided_slice").set_attr<FForwardRewrite>("FQRealizeRewrite",
RELAY_REGISTER_OP("nn.batch_flatten")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);

RELAY_REGISTER_OP("transpose").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);

RELAY_REGISTER_OP("annotation.stop_fusion")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);

Expand All @@ -469,6 +502,9 @@ Expr CastDtypeInputRealize(const Call& ref_call, const Array<Expr>& new_args,
RELAY_REGISTER_OP("nn.max_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);

RELAY_REGISTER_OP("nn.max_pool1d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);

Expr AvgPoolRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
const QConfig& cfg = QConfig::Current();
ICHECK_EQ(new_args.size(), 1);
Expand Down

0 comments on commit ff9c480

Please sign in to comment.