Skip to content

Commit

Permalink
[Quantization] Fix annotation for multiply op (apache#4458)
Browse files Browse the repository at this point in the history
* fix mul rewrite

* register Realize Rewrite for global avg pool and add test

* remove unnecessary check

* improve the test case
  • Loading branch information
masahi authored and Xingyu Zhou committed Dec 13, 2019
1 parent 9a399ef commit 879a0e8
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 6 deletions.
6 changes: 4 additions & 2 deletions python/tvm/relay/quantize/_annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,10 @@ def multiply_rewrite(ref_call, new_args, ctx):
# quantize lhs to INPUT field
if lhs_kind == QAnnotateKind.ACTIVATION:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
# quantize rhs to WEIGHT field
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
if _analysis.check_constant(rhs_expr):
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
else:
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)

Expand Down
7 changes: 3 additions & 4 deletions src/relay/pass/quantize/realize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -278,13 +278,9 @@ Expr MulRealize(const Call& ref_call,
DataType dtype = cfg->dtype_activation;
if (lhs->dtype != dtype) {
ldata = Cast(ldata, dtype);
} else {
CHECK_EQ(lhs->dtype, dtype);
}
if (rhs->dtype != dtype) {
rdata = Cast(rdata, dtype);
} else {
CHECK_EQ(rhs->dtype, dtype);
}

Expr ret = ForwardOp(ref_call, {ldata, rdata});
Expand Down Expand Up @@ -499,6 +495,9 @@ Expr AvgPoolRealize(const Call& ref_call,
RELAY_REGISTER_OP("nn.avg_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);

RELAY_REGISTER_OP("nn.global_avg_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);

Expr CastHintRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
Expand Down
49 changes: 49 additions & 0 deletions tests/python/relay/test_pass_auto_quantize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import relay
from tvm.relay import testing


def quantize_and_build(out):
f = relay.Function(relay.analysis.free_vars(out), out)
mod, params = testing.create_workload(f)

with relay.quantize.qconfig(skip_conv_layers=[]):
qmod = relay.quantize.quantize(mod, params)

relay.build(qmod, "llvm", params=params)


def test_mul_rewrite():
"""a test case where rhs of mul is not constant"""
data = relay.var("data", shape=(1, 16, 64, 64))
multiplier = relay.sigmoid(relay.var("data", shape=(1, 16, 1, 1)))
conv = relay.nn.conv2d(data, relay.var("weight"),
kernel_size=(3, 3),
padding=(1, 1),
channels=16)
act = relay.nn.relu(data=conv)

quantize_and_build(act * multiplier)

pool = relay.nn.global_avg_pool2d(data=act)

quantize_and_build(act * pool)

if __name__ == "__main__":
test_mul_rewrite()

0 comments on commit 879a0e8

Please sign in to comment.