[QNN] Enable constant folding for QNN operations. (apache#11228)
* [QNN] Enable constant folding for QNN operations.

This commit enables constant folding for QNN operations.
This functionality is disabled by default; use fold_qnn=True to enable it.

Co-authored-by: Alexander Peskov <peskovnn@gmail.com>

* [NFC] Fixed comments

* Added more unit tests for QNN operations in the constant folding pass.

* Address PR feedback

Co-authored-by: Alexander Peskov <peskovnn@gmail.com>
2 people authored and Yuanjing Shi committed May 17, 2022
1 parent 1fd3375 commit 77a59c3
Showing 5 changed files with 229 additions and 15 deletions.
10 changes: 9 additions & 1 deletion include/tvm/relay/transform.h
@@ -105,9 +105,17 @@ TVM_DLL Pass LazyGradientInit();
/*!
* \brief Fold constant expressions.
*
* For backward-compatibility reasons, QNN primitives are skipped from folding by default.
* Some transformation passes, such as FakeQuantizationToInteger, require QNN primitives to be
* kept for constant subgraphs, and uncontrolled constant folding of QNN primitives may make
* FakeQuantizationToInteger inapplicable. We suggest using FoldConstant with the non-default
* fold_qnn=True value only after all other QNN-sensitive passes have already been applied.
*
* \param fold_qnn Whether to fold constants for QNN operations.
*
* \return The pass.
*/
TVM_DLL Pass FoldConstant();
TVM_DLL Pass FoldConstant(bool fold_qnn = false);

/*!
* \brief Split function with huge number of arguments to smaller pieces.
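The new doc comment above pins down an ordering constraint: fold QNN constants only after QNN-sensitive passes have run. A minimal sketch of that ordering from the Python side (the pass list and the `mod` module are illustrative, not part of this commit):

```python
import tvm
from tvm import relay

# Illustrative ordering: QNN-sensitive passes first, then opt in to QNN folding.
seq = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        relay.transform.FakeQuantizationToInteger(),
        relay.transform.FoldConstant(fold_qnn=True),
    ]
)
# with tvm.transform.PassContext(opt_level=3):
#     mod = seq(mod)  # `mod` is a user-provided tvm.IRModule
```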
21 changes: 17 additions & 4 deletions python/tvm/relay/transform/transform.py
@@ -261,32 +261,45 @@ def LazyGradientInit():
return _ffi_api.LazyGradientInit()


def FoldConstantExpr(expr, mod):
def FoldConstantExpr(expr, mod, fold_qnn=False):
"""Fold the constant expressions in a Relay program.
Parameters
----------
expr: Expr
The expression to fold
mod: IRModule
The module the expr lives in (for global calls)
fold_qnn: bool
Whether to fold constants for QNN operations.
Returns
-------
new_expr: Expr
The expr after Constant Folding
"""
return _ffi_api.FoldConstantExpr(expr, mod)
return _ffi_api.FoldConstantExpr(expr, mod, fold_qnn)


def FoldConstant():
def FoldConstant(fold_qnn=False):
"""Fold the constant expressions in a Relay program.
For backward-compatibility reasons, QNN primitives are skipped from folding by default.
Some transformation passes, such as FakeQuantizationToInteger, require QNN primitives to be
kept for constant subgraphs, and uncontrolled constant folding of QNN primitives may make
FakeQuantizationToInteger inapplicable. We suggest using FoldConstant with the non-default
fold_qnn=True value only after all other QNN-sensitive passes have already been applied.
Parameters
----------
fold_qnn: bool
Whether to fold constants for QNN operations.
Returns
-------
ret : tvm.transform.Pass
The registered pass for constant folding.
"""
return _ffi_api.FoldConstant()
return _ffi_api.FoldConstant(fold_qnn)


def FuseOps(fuse_opt_level=-1):
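A small expression-level sketch of the updated FoldConstantExpr (the quantize expression mirrors the test_fold_quantize case added in this commit; the empty IRModule is only there to satisfy the mod argument):

```python
import numpy as np
import tvm
from tvm import relay

const_fp = relay.const(np.array([1.0, 2.0, 3.0], dtype="float32"))
expr = relay.qnn.op.quantize(
    const_fp, output_scale=relay.const(0.5), output_zero_point=relay.const(0)
)

# Default: the qnn.quantize call is kept as-is.
unchanged = relay.transform.FoldConstantExpr(expr, tvm.IRModule())

# With fold_qnn=True the expression folds down to a single int8 constant.
folded = relay.transform.FoldConstantExpr(expr, tvm.IRModule(), fold_qnn=True)
print(folded)
```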
3 changes: 2 additions & 1 deletion src/relay/backend/interpreter.cc
@@ -31,6 +31,7 @@
#include <tvm/relay/feature.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/qnn/transform.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/container/map.h>
#include <tvm/runtime/device_api.h>
@@ -948,7 +949,7 @@ IRModule Prepare(IRModule mod, CompilationConfig config) {
VirtualDevice host_virtual_device = config->host_virtual_device;
// Run minimal transforms on module to establish invariants needed by interpreter.
transform::Sequential seq(
{transform::SimplifyInference(),
{transform::SimplifyInference(), qnn::transform::Legalize(),
// Figure out which devices should be used to execute.
// TODO(mbs): Should ignore all existing annotations when constant folding
transform::PlanDevices(std::move(config)),
27 changes: 18 additions & 9 deletions src/relay/transforms/fold_constant.cc
@@ -67,8 +67,9 @@ bool IsComplexConstant(const Expr& expr) {
// or make a more powerful partial evaluator.
class ConstantFolder : public MixedModeMutator {
public:
explicit ConstantFolder(IRModule module)
explicit ConstantFolder(IRModule module, bool fold_qnn)
: module_(std::move(module)),
fold_qnn_(fold_qnn),
device_copy_op_(Op::Get("device_copy")),
shape_of_op_(Op::Get("shape_of")),
vm_shape_of_op_(Op::Get("vm.shape_of")),
@@ -158,8 +159,6 @@ class ConstantFolder : public MixedModeMutator {
return std::move(pre_call);
}

static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational");

const auto* op_node = post_call->op.as<OpNode>();
if (op_node == nullptr) {
// Only evaluate primitives.
@@ -182,8 +181,15 @@
if (Optional<Expr> opt_result = EvaluateNdarraySize(pre_call)) {
return opt_result.value();
}
if ((fnoncomputational.count(op) && fnoncomputational[op]) || op == device_copy_op_ ||
op == shape_of_op_ || op == vm_shape_of_op_ || op == ndarray_size_op_) {
static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational");
static auto qnn_canonicalize = Op::GetAttrMap<FTVMLegalize>("FTVMQnnCanonicalize");
bool is_no_qnn_canonicalized = !qnn_canonicalize.count(op);
bool is_no_computational = fnoncomputational.count(op) && fnoncomputational[op];
if (is_no_computational && (is_no_qnn_canonicalized || !fold_qnn_)) {
return std::move(post_call);
}
if (op == device_copy_op_ || op == shape_of_op_ || op == vm_shape_of_op_ ||
op == ndarray_size_op_) {
// We should think about potentially constant evaluation over these ops too.
return std::move(post_call);
}
@@ -387,6 +393,9 @@ class ConstantFolder : public MixedModeMutator {
// Module
IRModule module_;

// Whether to fold constants for QNN operations.
bool fold_qnn_;

// The kDLCPU device assumed to be available to the compiler. Used only when evaluating
// sub-expressions.
Device eval_cpu_dev_{kDLCPU, /*device_id=*/0};
@@ -417,20 +426,20 @@ TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(IsComplexCon
* from their p.o.v. Furthermore, this function can be called before conversion to ANF so
* we must avoid all recursion.
*/
Expr FoldConstantExpr(const Expr& expr, const IRModule& mod) {
Expr FoldConstantExpr(const Expr& expr, const IRModule& mod, bool fold_qnn) {
VLOG_CONTEXT << "FoldConstantExpr";
VLOG(1) << "folding:" << std::endl << PrettyPrint(expr);
Expr result = ConstantFolder(mod).VisitExpr(expr);
Expr result = ConstantFolder(mod, fold_qnn).VisitExpr(expr);
VLOG(1) << "folded to:" << std::endl << PrettyPrint(result);
return result;
}

TVM_REGISTER_GLOBAL("relay._transform.FoldConstantExpr").set_body_typed(FoldConstantExpr);

Pass FoldConstant() {
Pass FoldConstant(bool fold_qnn) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(FoldConstantExpr(f, m));
return Downcast<Function>(FoldConstantExpr(f, m, fold_qnn));
};
return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}
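A pass-level sketch of the two behaviours implemented above, mirroring the test_fold_qnn_const unit test added below (the module construction is illustrative):

```python
import numpy as np
import tvm
from tvm import relay

# qnn.add over two constant tensors: marked non-computational, but foldable
# once QNN canonicalization is allowed via fold_qnn=True.
add = relay.qnn.op.add(
    relay.const(np.ones((2, 3), dtype="uint8"), dtype="uint8"),
    relay.const(np.ones((2, 3), dtype="uint8"), dtype="uint8"),
    lhs_scale=relay.const(2.0),
    lhs_zero_point=relay.const(0),
    rhs_scale=relay.const(2.0),
    rhs_zero_point=relay.const(0),
    output_scale=relay.const(1.0),
    output_zero_point=relay.const(0),
)
mod = relay.transform.InferType()(tvm.IRModule.from_expr(relay.Function([], add)))

# Default: the QNN subgraph is left untouched.
kept = relay.transform.FoldConstant()(mod)

# Opt-in: the QNN subgraph is evaluated down to a single constant.
folded = relay.transform.FoldConstant(fold_qnn=True)(mod)
print(folded["main"])
```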
183 changes: 183 additions & 0 deletions tests/python/relay/test_pass_fold_constant.py
@@ -370,6 +370,189 @@ def before():
tvm.ir.assert_structural_equal(run_infer_type(before_mod["main"]), after_mod["main"])


def test_fold_qnn_const():
def before():
# QNN op with 2 constant arguments.
add = relay.qnn.op.add(
relay.const(np.ones((2, 3), dtype="uint8"), dtype="uint8"),
relay.const(np.ones((2, 3), dtype="uint8"), dtype="uint8"),
lhs_scale=relay.const(2.0),
lhs_zero_point=relay.const(0),
rhs_scale=relay.const(2.0),
rhs_zero_point=relay.const(0),
output_scale=relay.const(1.0),
output_zero_point=relay.const(0),
)
# QNN op with 1 constant and 1 non-constant arguments.
a = relay.var("a", shape=[2, 3], dtype="float32")
dense = relay.qnn.op.dense(
relay.qnn.op.quantize(a, relay.const(1.0), relay.const(0)),
add,
input_zero_point=relay.const(0),
kernel_zero_point=relay.const(0),
input_scale=relay.const(2.0),
kernel_scale=relay.const(2.0),
units=None,
)
# QNN op with 2 non-constant arguments.
b = relay.var("b", shape=[2], dtype="float32")
bias = relay.qnn.op.add(
dense,
relay.qnn.op.quantize(b, relay.const(1.0), relay.const(0), out_dtype="int32"),
lhs_scale=relay.const(2.0),
lhs_zero_point=relay.const(0),
rhs_scale=relay.const(2.0),
rhs_zero_point=relay.const(0),
output_scale=relay.const(1.0),
output_zero_point=relay.const(0),
)
return relay.Function([a, b], bias)

def expected():
a = relay.var("a", shape=[2, 3], dtype="float32")
dense = relay.qnn.op.dense(
relay.qnn.op.quantize(a, relay.const(1.0), relay.const(0)),
relay.const(np.array([[4, 4, 4], [4, 4, 4]], dtype="uint8"), dtype="uint8"),
input_zero_point=relay.const(0),
kernel_zero_point=relay.const(0),
input_scale=relay.const(2.0),
kernel_scale=relay.const(2.0),
units=None,
)
b = relay.var("b", shape=[2], dtype="float32")
bias = relay.qnn.op.add(
dense,
relay.qnn.op.quantize(b, relay.const(1.0), relay.const(0), out_dtype="int32"),
lhs_scale=relay.const(2.0),
lhs_zero_point=relay.const(0),
rhs_scale=relay.const(2.0),
rhs_zero_point=relay.const(0),
output_scale=relay.const(1.0),
output_zero_point=relay.const(0),
)
return relay.Function([a, b], bias)

# Nothing changed after applying FoldConstant
a = run_opt_pass(before(), transform.FoldConstant())
b = run_opt_pass(before(), transform.InferType())
tvm.ir.assert_structural_equal(a, b)

# Fold QNN constants
a = run_opt_pass(before(), transform.FoldConstant(fold_qnn=True))
b = run_opt_pass(expected(), transform.InferType())
tvm.ir.assert_structural_equal(a, b)


def test_fold_quantize():
t = relay.TensorType([1, 2, 3], "int8")

def before():
data = tvm.nd.array(np.array([1.0, 2.0, 3.0], dtype="float32"))
const_fp = relay.const(data, dtype="float32")
const_i8 = relay.qnn.op.quantize(
const_fp, output_scale=relay.const(0.5), output_zero_point=relay.const(0)
)
x = relay.var("x", t)
sub = relay.op.subtract(x, const_i8)
func = relay.Function([x], sub)
return func

def expected():
data = tvm.nd.array(np.array([2, 4, 6], dtype="int8"))
const_i8 = relay.const(data, dtype="int8")
x = relay.var("x", t)
sub = relay.op.subtract(x, const_i8)
func = relay.Function([x], sub)
return func

# Nothing changed after applying FoldConstant
a = run_opt_pass(before(), transform.FoldConstant())
b = run_opt_pass(before(), transform.InferType())
tvm.ir.assert_structural_equal(a, b)

# Fold QNN constants
a = run_opt_pass(before(), transform.FoldConstant(fold_qnn=True))
b = run_opt_pass(expected(), transform.InferType())
tvm.ir.assert_structural_equal(a, b)


def test_fold_qnn_conv2d_qnn_mul():
def before():
dtype = "uint8"
op0 = relay.qnn.op.conv2d(
relay.const(np.ones((1, 1, 2, 2), dtype=dtype), dtype=dtype),
relay.const(np.ones((1, 1, 2, 2), dtype=dtype), dtype=dtype),
input_zero_point=relay.const(0, "int32"),
kernel_zero_point=relay.const(0, "int32"),
input_scale=relay.const(1.0, "float32"),
kernel_scale=relay.const(1.0, "float32"),
kernel_size=(2, 2),
channels=1,
)
op = relay.qnn.op.mul(
op0,
relay.const(np.array([10], dtype="int32"), dtype="int32"),
relay.const(1.0, dtype="float32"),
relay.const(0, dtype="int32"),
relay.const(1.0, dtype="float32"),
relay.const(0, dtype="int32"),
relay.const(1.0, dtype="float32"),
relay.const(0, dtype="int32"),
)
func = relay.Function([], op)
return func

def expected():
data = relay.const(np.array([[[[40]]]], dtype="int32"), dtype="int32")
func = relay.Function([], data)
return func

# Nothing changed after applying FoldConstant
a = run_opt_pass(before(), transform.FoldConstant())
b = run_opt_pass(before(), transform.InferType())
tvm.ir.assert_structural_equal(a, b)

# Fold QNN constants
a = run_opt_pass(before(), transform.FoldConstant(fold_qnn=True))
b = run_opt_pass(expected(), transform.InferType())
tvm.ir.assert_structural_equal(a, b)


def test_fold_requantize():
def before():
data = tvm.nd.array(np.array([1, 2, 3], dtype="int8"))
const_i8 = relay.const(data, dtype="int8")
op = relay.qnn.op.requantize(
const_i8,
input_scale=relay.const(2.0, dtype="float32"),
input_zero_point=relay.const(1, dtype="int32"),
output_scale=relay.const(1.0, dtype="float32"),
output_zero_point=relay.const(1, dtype="int32"),
)
x = relay.var("x", relay.TensorType([3], "int8"))
add = relay.op.add(op, x)
func = relay.Function([x], add)
return func

def expected():
data = tvm.nd.array(np.array([1, 3, 5], dtype="int8"))
const_i8 = relay.const(data, dtype="int8")
x = relay.var("x", relay.TensorType([3], "int8"))
add = relay.op.add(const_i8, x)
func = relay.Function([x], add)
return func

# Nothing changed after applying FoldConstant
a = run_opt_pass(before(), transform.FoldConstant())
b = run_opt_pass(before(), transform.InferType())
tvm.ir.assert_structural_equal(a, b)

# Fold QNN constants
a = run_opt_pass(before(), transform.FoldConstant(fold_qnn=True))
b = run_opt_pass(expected(), transform.InferType())
tvm.ir.assert_structural_equal(a, b)


def test_pass_link_params():
"""
This test ensures that the proper executor is passed to the interpreter instance
