Skip to content

Commit

Permalink
[QNN] Requantize operator (#3531)
Browse files Browse the repository at this point in the history
* [Relay] [Quantization] WIP - Common files for the qauntization work.

* [Relay] [Quantization] WIP - Prototyping requantize op.

* Requantize operator implementation.

Requantize converts one quantized tensor representation to another quantized
representation. The PR has following implementation features

- Requantize operator defined in qnn namespace - relay.qnn.requantize
- Lowering of the requantize to exisiting Relay operators
- Integer fixed point implementation of requantize
    - Two rounding modes - FE_UPWARDS (round towards infinity) and
    FE_AWAY_FROM_ZERO (std::round behavior)
- Floating point implementation as well, that can act as reference or can be
used for devices when FP32 computation is not used.
- Unit test cases

Relevant Issue - #2351

Credit to TFLite and GemmLowp to provide reference implementations.

* Typo and lint fixes.

* Doc fix.

* Uncommenting the lint script (fixing mistake).

* Modifying the unit tests.

* Moving C++ files into src/relay/qnn

* Moving python files to python/tvm/relay/qnn. Some minor fixes.

* Moving the attrs.h inside the include directory.

* Pushing files that I forgot earlier. Changing util location.

* Incorporating comments. API change. Lint fixes.

* Modifying the GetFixedPointMultiplierShift API as per comments.

* Forgot the dialect change.

* Changing rewrite to qnn_lower.

* Renaming Quantize to Qnn for clarity.

* Remove use_int_domain.

* Incorportaing review comments.

* Adding API doc for QNN dialect.

* Move the qnn_lower pass to transform namespace.

* Moving from expr to module. Adding namespace in C++.

* Minor sentence rewrites. Added qnn namespace.

* Added the API doc.

* Chanding default out_dtype to int8. Adding a test with in/out_dtype as uint8.

* Style fixes. Better error messages.

* Adding documentation.

* More documentation fixes.

* Adding out dtype check for requantize.

* Adding corner case for FP32 to fixed point conversion.

* Adding extra line.

* Documentation fix.

* Adding static inline.

* Incorporating jackwish comment. Removed idtype from requantize lowering.

* Removing Quantize/Dequantize code. Restricting Requantize to (u)int8/int32.

* Style fixes.

* Fix the docs.

* Move to Legalize API.
  • Loading branch information
anijain2305 authored and tqchen committed Aug 8, 2019
1 parent 60607ef commit a78adbd
Show file tree
Hide file tree
Showing 11 changed files with 854 additions and 0 deletions.
15 changes: 15 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,16 @@ This level support backpropagation of broadcast operators. It is temporary.
tvm.relay.contrib.adaptive_avg_pool2d


**Level 11: Dialect Operators**

This level supports dialect operators.

.. autosummary::
:nosignatures:

tvm.relay.qnn.op.requantize


Level 1 Definitions
-------------------
.. autofunction:: tvm.relay.log
Expand Down Expand Up @@ -340,3 +350,8 @@ Level 10 Definitions
.. autofunction:: tvm.relay.nn.batch_matmul
.. autofunction:: tvm.relay.contrib.adaptive_max_pool2d
.. autofunction:: tvm.relay.contrib.adaptive_avg_pool2d


Level 11 Definitions
--------------------
.. autofunction:: tvm.relay.qnn.op.requantize
71 changes: 71 additions & 0 deletions include/tvm/relay/qnn/attrs.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relay/qnn/attrs.h
* \brief Auxiliary attributes for qnn operators.
*/
#ifndef TVM_RELAY_QNN_ATTRS_H_
#define TVM_RELAY_QNN_ATTRS_H_

#include <tvm/attrs.h>
#include <string>

namespace tvm {
namespace relay {
namespace qnn {

/*! \brief Attribute for requantize operator */
struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
double input_scale;
int32_t input_zero_point;
double output_scale;
int32_t output_zero_point;
std::string rounding;
DataType out_dtype;

TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") {
TVM_ATTR_FIELD(input_scale)
.describe("The scale of the input tensor.");
TVM_ATTR_FIELD(input_zero_point)
.describe("The zero point of the input tensor.");
TVM_ATTR_FIELD(output_scale)
.describe("The scale of the output tensor.");
TVM_ATTR_FIELD(output_zero_point)
.describe("The zero point of the output tensor.");
TVM_ATTR_FIELD(rounding).set_default("TONEAREST")
.describe("Defines the rounding direction when the value is midway between"
"two representable values. There are two supported modes - UPWARD"
"or TONEAREST. Both modes behave exactly same except at the"
"midpoints between the two representable values. At the midpoint,"
"UPWARD rounds towards positive infinity (for example -1.5 will be"
"rounded to -1). TONEAREST is the standard rounding where the"
"value is rounded away from zero at midpoints (for example, -1.5"
"rounds to -2). More context can be found at following gblic manual"
"https://www.gnu.org/software/libc/manual/html_node/Rounding.html.");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};

} // namespace qnn
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_QNN_ATTRS_H_
3 changes: 3 additions & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@
from . import backend
from . import quantize

# Dialects
from . import qnn

from .scope_builder import ScopeBuilder

# Span
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/relay/qnn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""QNN dialect operators and IR passes."""
from __future__ import absolute_import as _abs
from . import op
20 changes: 20 additions & 0 deletions python/tvm/relay/qnn/op/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""Neural network related operators."""
from __future__ import absolute_import as _abs
from .qnn import *
20 changes: 20 additions & 0 deletions python/tvm/relay/qnn/op/_make.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
from ...._ffi.function import _init_api

_init_api("relay.qnn.op._make", __name__)
74 changes: 74 additions & 0 deletions python/tvm/relay/qnn/op/qnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=invalid-name
"""QNN dialect operators."""

from __future__ import absolute_import as _abs
from . import _make

def requantize(data,
input_scale,
input_zero_point,
output_scale,
output_zero_point,
rounding="TONEAREST",
out_dtype="int8"):
r"""Requantized operator.
The requantize operator converts one quantized tensor representation to
another quantized tensor representation. For the output tensor, we are
provided with output scale and zero point. The computation is as follows
Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
input_scale: float
The quantization scale for the input tensor.
input_zero_point: int
The zero point of the input tensor.
output_scale: float
The quantization scale for the output tensor.
output_zero_point: int
The zero point of the output tensor.
rounding : string, optional
Defines the rounding direction when the value is midway between two
representable values.
out_dtype : str, optional
Specifies the output data type.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""

return _make.requantize(data,
input_scale,
input_zero_point,
output_scale,
output_zero_point,
rounding,
out_dtype)
20 changes: 20 additions & 0 deletions src/relay/pass/pattern_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,26 @@ inline Expr Variance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, b
}


static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) {
static const Op& op = Op::Get("where");
return CallNode::make(op, {condition, x, y});
}

static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) {
static const Op& op = Op::Get("greater_equal");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}

static inline Expr Full(Expr fill_value,
Array<IndexExpr> shape,
DataType dtype) {
auto attrs = make_node<InitOpAttrs>();
attrs->shape = std::move(shape);
attrs->dtype = std::move(dtype);
static const Op& op = Op::Get("full");
return CallNode::make(op, {fill_value}, Attrs(attrs), {});
}

Expr MakeConcatenate(Expr data, int axis);

Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
Expand Down
Loading

0 comments on commit a78adbd

Please sign in to comment.