Skip to content

Commit

Permalink
rename the pass from quantize_fake_quantization to fake_quantization_…
Browse files Browse the repository at this point in the history
…to_integer
  • Loading branch information
Matthew committed Jun 4, 2021
1 parent c761562 commit 92952bb
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 52 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
OpStrategy,
debug,
register_external_compiler,
register_quantize_fake_quantization,
register_fake_quantization_to_integer,
)
from . import strategy

Expand Down
10 changes: 5 additions & 5 deletions python/tvm/relay/op/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,24 +436,24 @@ def register_external_compiler(op_name, fexternal=None, level=10):
return tvm.ir.register_op_attr(op_name, "FTVMExternalCompiler", fexternal, level)


def register_quantize_fake_quantization(op_name, qfq=None, level=10):
def register_fake_quantization_to_integer(op_name, func=None, level=10):
"""Register quantize function for an op
Given an op and Affine Types on its inputs, this function should return the op
in affine space and the new type of the output
in affine space/integer operators and the new type of the output
Parameters
----------
op_name : str
The name of the operator
qfq: function (expr: Expr, map: Map<Expr, AffineType>) -> new_expr: Expr
The function for translating the op into affine space
func: function (expr: Expr, map: Map<Expr, AffineType>) -> new_expr: Expr
The function for translating the op into affine space and integer operators
level : int
The priority level
"""
return tvm.ir.register_op_attr(op_name, "FTVMQuantizeFakeQuantization", qfq, level)
return tvm.ir.register_op_attr(op_name, "FTVMFakeQuantizationToInteger", func, level)


@tvm._ffi.register_func("relay.op.compiler._lower")
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
# transformation passes
from .transform import *
from .recast import recast
from . import quantize_fake_quantization
from . import fake_quantization_to_integer
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""Relay functions for rewriting fake quantized ops."""
import tvm
from tvm import relay
from ..op import register_quantize_fake_quantization
from ..op import register_fake_quantization_to_integer


def fold_constant(expr):
Expand All @@ -26,16 +26,16 @@ def fold_constant(expr):
return mod["main"].body


@register_quantize_fake_quantization("qnn.dequantize")
def dequantize_qfq(expr, type_map):
@register_fake_quantization_to_integer("qnn.dequantize")
def dequantize(expr, type_map):
"""Remove dequantize op"""
out = expr.args[0]
t = type_map[expr]
return [out, t.scale, t.zero_point, t.dtype]


@register_quantize_fake_quantization("qnn.quantize")
def quantize_qfq(expr, type_map):
@register_fake_quantization_to_integer("qnn.quantize")
def quantize(expr, type_map):
"""Turn a quantize op into requantize or remove it"""
out = expr.args[0]
t = type_map[out]
Expand All @@ -57,23 +57,24 @@ def quantize_qfq(expr, type_map):
return [out, expr.args[1], expr.args[2], expr.attrs.out_dtype]


def register_qfq_identity(op_name, op):
def register_unary_identity(op_name, op):
def identity(expr, type_map):
assert len(expr.args) == 1
arg = expr.args[0]
t = type_map[arg]
out = op(arg, **expr.attrs)
return [out, t.scale, t.zero_point, t.dtype]

return register_quantize_fake_quantization(op_name, identity)
return register_fake_quantization_to_integer(op_name, identity)


register_qfq_identity("reshape", relay.op.reshape)
register_qfq_identity("transpose", relay.op.transpose)
register_qfq_identity("nn.max_pool2d", relay.op.nn.max_pool2d)
register_unary_identity("reshape", relay.op.reshape)
register_unary_identity("transpose", relay.op.transpose)
register_unary_identity("nn.max_pool2d", relay.op.nn.max_pool2d)


@register_quantize_fake_quantization("nn.avg_pool2d")
def avgpool_qfq(expr, type_map):
@register_fake_quantization_to_integer("nn.avg_pool2d")
def avgpool2d(expr, type_map):
"""Rewrite a avgpool op"""
arg = expr.args[0]
t = type_map[arg]
Expand All @@ -83,8 +84,8 @@ def avgpool_qfq(expr, type_map):
return [out, t.scale, t.zero_point, t.dtype]


@register_quantize_fake_quantization("nn.bias_add")
def bias_add_qfq(expr, type_map):
@register_fake_quantization_to_integer("nn.bias_add")
def bias_add(expr, type_map):
"""Rewrite a bias_add op"""
x, b = expr.args
x_t = type_map[x]
Expand All @@ -104,8 +105,8 @@ def bias_add_qfq(expr, type_map):
return [out, x_t.scale, x_t.zero_point, x_t.dtype]


@register_quantize_fake_quantization("nn.conv2d")
def conv2d_qfq(expr, type_map):
@register_fake_quantization_to_integer("nn.conv2d")
def conv2d(expr, type_map):
"""Rewrite a conv2d op"""
attrs = {**expr.attrs}
attrs.pop("out_dtype")
Expand All @@ -120,8 +121,8 @@ def conv2d_qfq(expr, type_map):
return [out, conv_scale, conv_zp, out.attrs.out_dtype]


@register_quantize_fake_quantization("concatenate")
def concat_qfq(expr, type_map):
@register_fake_quantization_to_integer("concatenate")
def concat(expr, type_map):
"""Rewrite a concat op"""
scales = []
zps = []
Expand All @@ -143,8 +144,8 @@ def concat_qfq(expr, type_map):
return [out, out_type.scale, out_type.zero_point, out_type.dtype]


@register_quantize_fake_quantization("clip")
def clip_qfq(expr, type_map):
@register_fake_quantization_to_integer("clip")
def clip(expr, type_map):
"""Rewrite a clip op"""
arg = expr.args[0]
t = type_map[arg]
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,7 +1173,7 @@ def AnnotateSpans():
return _ffi_api.AnnotateSpans()


def QuantizeFakeQuantization():
def FakeQuantizationToInteger():
# pylint: disable=anomalous-backslash-in-string
"""
Find regions of the graph of the form
Expand All @@ -1191,11 +1191,11 @@ def QuantizeFakeQuantization():
where q == qnn.quantize and dq = qnn.dequantize
and rewrite them into integer versions of op1 and op2
    Rules for rewriting individual ops are in quantize_fake_quantization.py
    Rules for rewriting individual ops are in fake_quantization_to_integer.py
Returns
-------
ret : tvm.transform.Pass
        The registered FakeQuantizationToInteger pass.
"""
return _ffi_api.QuantizeFakeQuantization()
return _ffi_api.FakeQuantizationToInteger()
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>

/* Description of QuantizeFakeQuantization
/* Description of FakeQuantizationToInteger
*
* The purpose of this pass is to find regions of the graph that follow
* the general pattern:
Expand All @@ -54,7 +54,7 @@
* types of the inputs for later processing
*
* The third pass is an ExprMutator that recursively rewrites the subgraphs using packed funcs
* registered with the FTVMQuantizeFakeQuantization attribute. These packed funcs rewrite
* registered with the FTVMFakeQuantizationToInteger attribute. These packed funcs rewrite
* the ops based on the affine types of their inputs and then return the affine types of the
* new rewriten ops to pass that information down the stack during rewrite.
*
Expand Down Expand Up @@ -125,7 +125,7 @@ using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;
using ExprMap = std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual>;
using AffineTypeMap = Map<Expr, AffineType>;

using FTVMQuantizeFakeQuantization =
using FTVMFakeQuantizationToInteger =
runtime::TypedPackedFunc<Array<ObjectRef>(const Expr& expr, const AffineTypeMap& map)>;

class SubgraphExtractor : public ExprVisitor {
Expand Down Expand Up @@ -193,7 +193,8 @@ class SubgraphMutator : public ExprMutator {
ICHECK(quantize_node);
ICHECK(quantize_node->op == quantize_op_);
out_type_ = affine_types_[expr];
static auto fqfq = Op::GetAttrMap<FTVMQuantizeFakeQuantization>("FTVMQuantizeFakeQuantization");
static auto fqfq =
Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
for (auto node : subgraph_) {
if (!fqfq.count(Downcast<Op>(node.as<CallNode>()->op))) {
// Only modify the subgraph if we have translation
Expand All @@ -208,7 +209,8 @@ class SubgraphMutator : public ExprMutator {
Expr VisitExpr_(const CallNode* call_node) {
Expr out;

static auto fqfq = Op::GetAttrMap<FTVMQuantizeFakeQuantization>("FTVMQuantizeFakeQuantization");
static auto fqfq =
Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
Op op = Downcast<Op>(call_node->op);
if (fqfq.count(op)) {
Expr expr;
Expand All @@ -224,7 +226,7 @@ class SubgraphMutator : public ExprMutator {
Array<ObjectRef> vals = fqfq[op](expr, affine_types_);
      // Save the outputs of the rewrite
ICHECK(vals.size() == 4)
<< "got the wrong number of returned arguments from FTWMQuantizeFakeQuantization for "
<< "got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for "
<< AsText(op, false);
out = Downcast<Expr>(vals[0]);
affine_types_.Set(out, AffineType(Downcast<Expr>(vals[1]), Downcast<Expr>(vals[2]),
Expand Down Expand Up @@ -274,22 +276,22 @@ class FakeQuantizationRewriter : public MixedModeMutator {
const Op quantize_op_ = Op::Get("qnn.quantize");
};

Expr QuantizeFakeQuantization(const Expr& expr, const IRModule& mod) {
Expr FakeQuantizationToInteger(const Expr& expr, const IRModule& mod) {
return FakeQuantizationRewriter().Mutate(expr);
}

namespace transform {

Pass QuantizeFakeQuantization() {
Pass FakeQuantizationToInteger() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(QuantizeFakeQuantization(f, m));
return Downcast<Function>(FakeQuantizationToInteger(f, m));
};
return CreateFunctionPass(pass_func, 0, "QuantizeFakeQuantization", {"InferType"});
return CreateFunctionPass(pass_func, 0, "FakeQuantizationToInteger", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.QuantizeFakeQuantization")
.set_body_typed(QuantizeFakeQuantization);
TVM_REGISTER_GLOBAL("relay._transform.FakeQuantizationToInteger")
.set_body_typed(FakeQuantizationToInteger);

} // namespace transform

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *


def test_fake_quantize_conv():
Expand All @@ -42,7 +41,7 @@ def test_fake_quantize_conv():
x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8")
w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8")

mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod)
mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod)
assert not tvm.ir.structural_equal(mod, mod2)
mod2 = tvm.relay.transform.FoldConstant()(mod2)

Expand Down Expand Up @@ -72,7 +71,7 @@ def test_fake_transpose_quantize_conv():
x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8")
w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8")

mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod)
mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod)
assert not tvm.ir.structural_equal(mod, mod2)
mod2 = tvm.relay.transform.FoldConstant()(mod2)

Expand Down Expand Up @@ -105,7 +104,7 @@ def test_fake_transpose_quantize_conv_bias_add():
w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8")
bias_np = np.random.randint(-32768, 32767, size=[16], dtype="int32")

mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod)
mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod)
assert not tvm.ir.structural_equal(mod, mod2)
mod2 = tvm.relay.transform.FoldConstant()(mod2)

Expand All @@ -131,7 +130,7 @@ def test_fake_quantize_maxpool():

x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8")

mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod)
mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod)
assert not tvm.ir.structural_equal(mod, mod2)
mod2 = tvm.relay.transform.FoldConstant()(mod2)

Expand All @@ -157,7 +156,7 @@ def test_fake_quantize_avgpool():

x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8")

mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod)
mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod)
assert not tvm.ir.structural_equal(mod, mod2)
mod2 = tvm.relay.transform.FoldConstant()(mod2)

Expand All @@ -183,7 +182,7 @@ def test_fake_quantize_reshape():

x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8")

mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod)
mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod)
assert not tvm.ir.structural_equal(mod, mod2)
mod2 = tvm.relay.transform.FoldConstant()(mod2)

Expand All @@ -210,7 +209,7 @@ def test_fake_quantize_transpose_reshape():

x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8")

mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod)
mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod)
assert not tvm.ir.structural_equal(mod, mod2)
mod2 = tvm.relay.transform.FoldConstant()(mod2)

Expand Down Expand Up @@ -242,7 +241,7 @@ def test_fake_quantize_concat():
for i in range(4):
inputs_np.append(np.random.randint(-128, 127, size=[1, 4], dtype="int8"))

mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod)
mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod)
assert not tvm.ir.structural_equal(mod, mod2)
mod2 = tvm.relay.transform.FoldConstant()(mod2)

Expand All @@ -267,7 +266,7 @@ def test_fake_quantize_clip():

x_np = np.random.randint(0, 255, size=[1, 3, 224, 224], dtype="uint8")

mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod)
mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod)
assert not tvm.ir.structural_equal(mod, mod2)
mod2 = tvm.relay.transform.FoldConstant()(mod2)

Expand Down

0 comments on commit 92952bb

Please sign in to comment.