Skip to content

Commit

Permalink
[Relay] Convert a fake quantized or QAT graph into QNN ops (#8126)
Browse files Browse the repository at this point in the history
* Convert a fake quantized or QAT graph into qnn ops

* fix pylint

* fix typos

* use an identity function for some ops

* rename the pass from quantize_fake_quantization to fake_quantization_to_integer

* add definition for affine
  • Loading branch information
Matthew Brookhart authored Jun 8, 2021
1 parent 64a8e81 commit 9be0f4f
Show file tree
Hide file tree
Showing 8 changed files with 802 additions and 0 deletions.
6 changes: 6 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2465,6 +2465,12 @@ def _impl_v1(cls, inputs, attr, params):

@classmethod
def _impl_v11(cls, inputs, attr, params):
if len(inputs) == 3 and isinstance(inputs[2], _expr.Constant):
attr["max"] = inputs[2].data.asnumpy().item()
inputs = inputs[0:2]
if len(inputs) >= 2 and isinstance(inputs[1], _expr.Constant):
attr["min"] = inputs[1].data.asnumpy().item()
inputs = inputs[0:1]
if "min" in attr and "max" in attr:
return Clip.convert_attributes(inputs, attr, params)

Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
OpStrategy,
debug,
register_external_compiler,
register_fake_quantization_to_integer,
)
from . import strategy

Expand Down
21 changes: 21 additions & 0 deletions python/tvm/relay/op/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,27 @@ def register_external_compiler(op_name, fexternal=None, level=10):
return tvm.ir.register_op_attr(op_name, "FTVMExternalCompiler", fexternal, level)


def register_fake_quantization_to_integer(op_name, func=None, level=10):
    """Register a fake-quantization-to-integer rewrite function for an op.

    Given an op and the Affine Types on its inputs, the registered function
    should return the op expressed in affine space / integer operators along
    with the new type of the output, where "affine" denotes the transformation
    ``x_real = (x_affine - zero_point) * scale``.

    Parameters
    ----------
    op_name : str
        The name of the operator

    func: function (expr: Expr, map: Map<Expr, AffineType>) -> new_expr: Expr
        The function for translating the op into affine space and integer operators

    level : int
        The priority level

    Returns
    -------
    Whatever ``tvm.ir.register_op_attr`` returns for the
    "FTVMFakeQuantizationToInteger" attribute.
    """
    attr_key = "FTVMFakeQuantizationToInteger"
    return tvm.ir.register_op_attr(op_name, attr_key, func, level)


@tvm._ffi.register_func("relay.op.compiler._lower")
def _lower(name, schedule, inputs, outputs):
    """Lower ``schedule`` under ``name``, passing the inputs followed by the outputs.

    Exposed through the global function registry as "relay.op.compiler._lower".
    """
    # Arguments are ordered inputs first, then outputs.
    arg_tensors = [*inputs, *outputs]
    return lower(schedule, arg_tensors, name=name)
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
# transformation passes
from .transform import *
from .recast import recast
from . import fake_quantization_to_integer
166 changes: 166 additions & 0 deletions python/tvm/relay/transform/fake_quantization_to_integer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Relay functions for rewriting fake quantized ops."""
import tvm
from tvm import relay
from ..op import register_fake_quantization_to_integer


def fold_constant(expr):
    """Run the FoldConstant pass on ``expr`` and return the resulting body.

    The expression is wrapped in a temporary IRModule, the pass is applied to
    that module, and the body of its "main" function is returned.
    """
    wrapped = tvm.IRModule.from_expr(expr)
    folded = relay.transform.FoldConstant()(wrapped)
    main_fn = folded["main"]
    return main_fn.body


@register_fake_quantization_to_integer("qnn.dequantize")
def dequantize(expr, type_map):
    """Remove a dequantize op.

    Returns the op's first argument unchanged, together with the scale, zero
    point and dtype recorded in ``type_map`` for the dequantize expression.
    """
    data = expr.args[0]
    affine = type_map[expr]
    return [data, affine.scale, affine.zero_point, affine.dtype]


@register_fake_quantization_to_integer("qnn.quantize")
def quantize(expr, type_map):
    """Turn a quantize op into a requantize, or remove it entirely.

    The input's (constant-folded) scale and zero point and its dtype are
    compared against the quantize op's own scale, zero point and output dtype.
    If all three match structurally the input is passed through as is;
    otherwise a ``qnn.requantize`` is inserted to convert between the two.
    """
    data = expr.args[0]
    in_type = type_map[data]
    in_scale = fold_constant(in_type.scale)
    in_zero_point = fold_constant(in_type.zero_point)

    out_scale = expr.args[1]
    out_zero_point = expr.args[2]
    out_dtype = expr.attrs.out_dtype

    already_matches = (
        tvm.ir.structural_equal(in_scale, out_scale)
        and tvm.ir.structural_equal(in_zero_point, out_zero_point)
        and tvm.ir.structural_equal(in_type.dtype, out_dtype)
    )
    if not already_matches:
        data = relay.qnn.op.requantize(
            data,
            in_scale,
            in_zero_point,
            out_scale,
            out_zero_point,
            out_dtype=out_dtype,
        )
    return [data, out_scale, out_zero_point, out_dtype]


def register_unary_identity(op_name, op):
    """Register a pass-through rewrite for a single-argument op.

    Parameters
    ----------
    op_name : str
        The name of the operator to register the rewrite for.

    op : callable
        Called as ``op(arg, **expr.attrs)`` to rebuild the operator on the
        rewritten argument.

    Returns
    -------
    The result of ``register_fake_quantization_to_integer`` for the rewrite.
    """

    def identity(expr, type_map):
        assert len(expr.args) == 1
        data = expr.args[0]
        affine = type_map[data]
        # Re-apply the same op with the same attributes; the output keeps the
        # input's scale, zero point and dtype.
        rebuilt = op(data, **expr.attrs)
        return [rebuilt, affine.scale, affine.zero_point, affine.dtype]

    return register_fake_quantization_to_integer(op_name, identity)


# Register the identity rewrite for these single-input ops: each op is re-applied
# to its argument with the original attributes, and the argument's scale, zero
# point and dtype are propagated to the output (see register_unary_identity).
register_unary_identity("reshape", relay.op.reshape)
register_unary_identity("transpose", relay.op.transpose)
register_unary_identity("nn.max_pool2d", relay.op.nn.max_pool2d)


@register_fake_quantization_to_integer("nn.avg_pool2d")
def avgpool2d(expr, type_map):
    """Rewrite an avg_pool2d op.

    The input is cast to int32, pooled with the original attributes, and the
    result is cast back to the input's dtype. The input's scale, zero point
    and dtype are carried through to the output.
    """
    data = expr.args[0]
    affine = type_map[data]
    # NOTE(review): presumably the int32 cast gives the pooling headroom in
    # narrow integer types; confirm.
    widened = relay.op.cast(data, "int32")
    pooled = relay.op.nn.avg_pool2d(widened, **expr.attrs)
    narrowed = relay.op.cast(pooled, affine.dtype)
    return [narrowed, affine.scale, affine.zero_point, affine.dtype]


@register_fake_quantization_to_integer("nn.bias_add")
def bias_add(expr, type_map):
    """Rewrite a bias_add op.

    If the bias's affine type is not structurally equal to the data's affine
    type, the bias is first requantized to the data's (constant-folded) scale
    and zero point and to the data's dtype. The bias_add is then re-applied
    with the original attributes, and the output carries the data's scale,
    zero point and dtype.
    """
    x, b = expr.args
    x_t = type_map[x]
    b_t = type_map[b]
    in_scale = fold_constant(x_t.scale)
    in_zero_point = fold_constant(x_t.zero_point)
    if not tvm.ir.structural_equal(x_t, b_t):
        b = relay.qnn.op.requantize(
            b,
            b_t.scale,
            b_t.zero_point,
            in_scale,
            in_zero_point,
            # Fixed: was the undefined name `xt`, which raised NameError
            # whenever this branch ran.
            out_dtype=x_t.dtype,
        )
    out = relay.op.nn.bias_add(x, b, **expr.attrs)
    return [out, x_t.scale, x_t.zero_point, x_t.dtype]


@register_fake_quantization_to_integer("nn.conv2d")
def conv2d(expr, type_map):
    """Rewrite a conv2d op into ``qnn.conv2d``.

    The output scale is the constant-folded product of the data and weight
    scales, the output zero point is the constant 0, and the output dtype is
    read from the attributes of the new ``qnn.conv2d`` call.
    """
    conv_attrs = {**expr.attrs}
    # out_dtype from the original op is not forwarded to qnn.conv2d.
    del conv_attrs["out_dtype"]

    data, kernel = expr.args
    data_t = type_map[data]
    kernel_t = type_map[kernel]

    out_scale = fold_constant(data_t.scale * kernel_t.scale)
    out_zero_point = relay.const(0)
    qconv = relay.qnn.op.conv2d(
        data,
        kernel,
        data_t.zero_point,
        kernel_t.zero_point,
        data_t.scale,
        kernel_t.scale,
        **conv_attrs,
    )
    return [qconv, out_scale, out_zero_point, qconv.attrs.out_dtype]


@register_fake_quantization_to_integer("concatenate")
def concat(expr, type_map):
    """Rewrite a concatenate op into ``qnn.concatenate``.

    The scale and zero point of every tuple field are gathered from
    ``type_map``; the output scale, zero point and dtype are taken from the
    entry ``type_map`` holds for the concatenate expression itself.
    """
    inputs = expr.args[0]
    field_types = [type_map[field] for field in inputs.fields]
    in_scales = [ftype.scale for ftype in field_types]
    in_zero_points = [ftype.zero_point for ftype in field_types]

    result_type = type_map[expr]

    qconcat = relay.qnn.op.concatenate(
        inputs,
        relay.Tuple(in_scales),
        relay.Tuple(in_zero_points),
        result_type.scale,
        result_type.zero_point,
        **expr.attrs,
    )
    return [qconcat, result_type.scale, result_type.zero_point, result_type.dtype]


@register_fake_quantization_to_integer("clip")
def clip(expr, type_map):
    """Rewrite a clip op.

    The clip bounds are mapped into the input's affine space as
    ``bound / scale + zero_point``. When both the folded scale and zero point
    are constants, the bounds are computed in Python and a plain ``clip`` is
    emitted; otherwise they are built as Relay expressions and applied with
    ``maximum`` / ``minimum``. The output keeps the input's scale, zero point
    and dtype.
    """
    data = expr.args[0]
    affine = type_map[data]
    lower_bound = expr.attrs.a_min
    upper_bound = expr.attrs.a_max
    folded_scale = fold_constant(affine.scale)
    folded_zp = fold_constant(affine.zero_point)

    params_are_const = isinstance(folded_scale, relay.expr.Constant) and isinstance(
        folded_zp, relay.expr.Constant
    )
    if not params_are_const:
        q_lower = relay.op.round(relay.op.const(lower_bound) / folded_scale + folded_zp)
        q_upper = relay.op.round(relay.op.const(upper_bound) / folded_scale + folded_zp)
        out = relay.op.minimum(relay.op.maximum(data, q_lower), q_upper)
    else:
        scale_value = folded_scale.data.numpy().item()
        zp_value = folded_zp.data.numpy().item()
        # NOTE(review): int() truncates toward zero here, whereas the
        # non-constant path rounds; confirm the difference is intended.
        q_lower = int(lower_bound / scale_value + zp_value)
        q_upper = int(upper_bound / scale_value + zp_value)
        out = relay.op.clip(data, q_lower, q_upper)
    return [out, affine.scale, affine.zero_point, affine.dtype]
28 changes: 28 additions & 0 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,3 +1171,31 @@ def AnnotateSpans():
The regsistered AnnotateSpans pass.
"""
return _ffi_api.AnnotateSpans()


def FakeQuantizationToInteger():
    # pylint: disable=anomalous-backslash-in-string
    """
    Find regions of the graph of the form

        x    w
        |    |
        dq   dq
         \   /
          op1
           |
          op2
           |
           q

    where ``q == qnn.quantize`` and ``dq = qnn.dequantize``
    and rewrite them into integer versions of ``op1`` and ``op2``

    Rules for rewriting individual ops are in fake_quantization_to_integer.py

    Returns
    -------
    ret : tvm.transform.Pass
        The registered FakeQuantizationToInteger pass.
    """
    fq2i_pass = _ffi_api.FakeQuantizationToInteger()
    return fq2i_pass
Loading

0 comments on commit 9be0f4f

Please sign in to comment.