code review: tests for new passes, clean up of relay_to_tir for cmsis-nn
Change-Id: Icd9ae4d456a75761f476f8ae73bff64d48e59dd5
ashutosh-arm committed Oct 25, 2021
1 parent 66eed5c commit 49847bd
Showing 9 changed files with 604 additions and 166 deletions.
9 changes: 4 additions & 5 deletions python/tvm/relay/op/contrib/cmsisnn.py
@@ -23,6 +23,8 @@
from ...dataflow_pattern import is_constant, is_op, wildcard
from .register import register_pattern_table

tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__)


def enabled():
return bool(tvm.get_global_func("relay.ext.cmsisnn", True))
@@ -47,8 +49,6 @@ def partition_for_cmsisnn(mod, params=None, **opts):
if params:
mod["main"] = bind_params_by_name(mod["main"], params)

tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__)

seq = tvm.transform.Sequential(
[
transform.InferType(),
@@ -91,7 +91,7 @@ def qnn_conv2d_pattern():
"""Create pattern for qnn.conv2D with optional fused relu."""
qnn_conv2d = is_op("qnn.conv2d")(
wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
).has_attr({"kernel_layout": "HWIO"})
)
bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant())
req = is_op("qnn.requantize")(
qnn_conv2d | bias_add, is_constant(), is_constant(), is_constant(), is_constant()
@@ -123,8 +123,7 @@ def check_qnn_conv2d(pattern):
kernel_zp = [kernel_zp] if kernel_zp.ndim == 0 else kernel_zp

return (
conv2d.attrs.kernel_layout == "HWIO"
and conv2d.attrs.out_dtype == "int32"
conv2d.attrs.out_dtype == "int32"
and conv2d.attrs.padding[2] == 0
and conv2d.attrs.padding[3] == 0
and conv2d_input.checked_type.dtype == "int8"
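
Reviewer aside: with tvm._ffi._init_api() now run at module import time rather than inside partition_for_cmsisnn(), the CMSIS-NN transform wrappers are available as soon as the module loads. A minimal usage sketch, assuming a quantized Relay module mod with bindable params and a TVM build with USE_CMSISNN=ON (mod and params here are hypothetical placeholders):

import tvm
from tvm.relay.op.contrib import cmsisnn

# enabled() checks that the CMSIS-NN codegen is compiled into this build.
if cmsisnn.enabled():
    mod = cmsisnn.partition_for_cmsisnn(mod, params)
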
56 changes: 33 additions & 23 deletions src/relay/backend/contrib/cmsisnn/extract_constants.cc
@@ -17,6 +17,11 @@
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file extract_constants.cc
* \brief Pushes out constants within partitioned functions all the way up to main()
*/

#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
@@ -30,44 +35,47 @@ namespace relay {
namespace contrib {
namespace cmsisnn {

/*!
* \brief This Mutator finds all functions with constants. Constants are replaced with function
* parameter variables. Constants are pushed all the way up to main().
*/
class ExtractConstantsMutator : public MixedModeMutator {
public:
explicit ExtractConstantsMutator(IRModule& mod) : mod_(mod) {}
explicit ExtractConstantsMutator(const IRModule& mod) : mod_(mod) {}

private:
String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); }

Expr VisitExpr_(const FunctionNode* func) final {
Function final_func = GetRef<Function>(func);
++func_nesting_level_;
Expr VisitExpr_(const FunctionNode* function) final {
Function func = GetRef<Function>(function);
function_to_constants_.Set(func, Array<Constant>{});
functions_.push_back(func);
auto new_body = VisitExpr(func->body);
--func_nesting_level_;
if (!new_body.same_as(func->body)) {
final_func = Function(FreeVars(new_body), new_body, func->ret_type,
FreeTypeVars(new_body, mod_), func->attrs);
function_to_constants_.Set(GetRef<Function>(func), constants_within_function_);
constants_within_function_.clear();
functions_.pop_back();
if (function_to_constants_[func].size()) {
func = Function(FreeVars(new_body), new_body, func->ret_type, FreeTypeVars(new_body, mod_),
func->attrs);
}
return final_func;
return func;
}

Expr Rewrite_(const CallNode* call, const Expr& post) final {
Expr final_call = post;
auto* post_call = post.as<CallNode>();
if (post_call == nullptr) {
return final_call;
}

// Replace Constant arguments with Vars for ML Operators
// Perform this for non-main Call Nodes only
if (func_nesting_level_ && call->op.as<OpNode>()) {
if (!functions_.empty() && call->op.as<OpNode>()) {
Array<Expr> new_args;
for (auto& arg : post_call->args) {
auto* const_arg = arg.as<ConstantNode>();
if (const_arg && !const_arg->is_scalar()) {
Var var_arg = Var(gen_var_name(), const_arg->tensor_type());
new_args.push_back(var_arg);
constants_within_function_.push_back(GetRef<Constant>(const_arg));
const Function& last_func = functions_.back();
Array<Constant> fconstants(function_to_constants_[last_func]);
fconstants.push_back(GetRef<Constant>(const_arg));
function_to_constants_.Set(last_func, fconstants);
} else {
new_args.push_back(arg);
}
@@ -94,17 +102,21 @@ class ExtractConstantsMutator : public MixedModeMutator {

// Since the constants are kicked out of the local partitioned functions
// a new call to the local function is needed
// Also, pass on the constants to the callee of this function to support nested functions
if (auto* func_node = call->op.as<FunctionNode>()) {
Function func = GetRef<Function>(func_node);
auto new_func = VisitExpr(func);
if (!new_func.same_as(func)) {
Array<Expr> new_args = post_call->args;
ICHECK(function_to_constants_.find(func) != function_to_constants_.end());
const Function& last_func = functions_.back();
Array<Constant> fconstants(function_to_constants_[last_func]);
for (auto constant : function_to_constants_.at(func)) {
constants_within_function_.push_back(constant);
fconstants.push_back(constant);
Var var_arg = Var(gen_var_name(), constant->tensor_type());
new_args.push_back(var_arg);
}
function_to_constants_.Set(last_func, fconstants);
final_call = Call(new_func, new_args);
}
}
@@ -117,16 +129,14 @@ class ExtractConstantsMutator : public MixedModeMutator {
IRModule mod_;
/* \brief Maintains mapping of original function to the replaced constants */
Map<Function, Array<Constant>> function_to_constants_;
/* \brief Constants being kicked out of a function during the function visit */
Array<Constant> constants_within_function_;
/* \brief Stack of functions to determine scope while filling up function_to_constants_ */
Array<Function> functions_;
/* \brief Keeps track of variables being created */
int var_count_ = 0;
/* \brief Keeps track of function scope */
int func_nesting_level_ = 0;
};

/*! * \brief Kicks all constants out of the partitioned function into main() */
IRModule ExtractConstants(IRModule mod) {
IRModule ExtractConstants(const IRModule& mod) {
String func_name;
Function func;

@@ -150,7 +160,7 @@ transform::Pass ExtractConstantsFromPartitionedFunction() {
}

TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.ExtractConstantsFromPartitionedFunction")
.set_body_typed([]() { return ExtractConstantsFromPartitionedFunction(); });
.set_body_typed(ExtractConstantsFromPartitionedFunction);

} // namespace cmsisnn
} // namespace contrib
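
To see the effect of this pass, here is a minimal sketch (assuming a TVM build with USE_CMSISNN=ON; the Python wrapper name comes from the _init_api() registration in cmsisnn.py shown above, and the module construction is illustrative). The non-scalar constant inside the local function becomes a fresh parameter (tvm_var_extract_const_0), supplied by the caller at the call site in main():

import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib import cmsisnn

# A local function closing over a non-scalar constant, called from main().
x = relay.var("x", shape=(8,), dtype="float32")
const = relay.const(np.ones((8,), dtype="float32"))
local_func = relay.Function([x], relay.add(x, const))
inp = relay.var("inp", shape=(8,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([inp], local_func(inp)))

# After the pass, the constant appears as an argument at the call site
# instead of being embedded in the local function's body.
mod = cmsisnn.ExtractConstantsFromPartitionedFunction()(mod)
print(mod)
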
34 changes: 24 additions & 10 deletions src/relay/backend/contrib/cmsisnn/generate_constants.cc
@@ -17,24 +17,35 @@
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file generate_constants.cc
* \brief Generates quantization parameters needed by CMSIS-NN
*/

#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/ndarray.h>

#include "../../../op/make_op.h"
#include "../../../qnn/utils.h"
#include "../../../transforms/pattern_utils.h"

namespace tvm {
namespace relay {
Expr MakeTranspose(Expr data, Array<Integer> axes);
namespace contrib {
namespace cmsisnn {

/*!
* \brief This Mutator will find all partitioned functions meant for CMSIS-NN Conv2D.
* It will substitute original Conv2D's weight zero point and original Requantize's input zero point
* with CMSIS-NN's quantization parameters.
* https://github.com/tensorflow/tflite-micro/blob/0f40100fc60276e9f345c23282de3baf19a78059/tensorflow/lite/kernels/internal/quantization_util.cc#L53
*/
class GenerateConstantsMutator : public MixedModeMutator {
public:
explicit GenerateConstantsMutator(IRModule& mod) : mod_(mod) {}
explicit GenerateConstantsMutator(const IRModule& mod) : mod_(mod) {}

private:
/*! * \brief Converts Kernel layout from HWIO to OHWI to align to CMSIS-NN requirements */
@@ -52,8 +63,15 @@ class GenerateConstantsMutator : public MixedModeMutator {
attrs->out_dtype = std::move(conv2d_attrs->out_dtype);
*new_attrs = tvm::Attrs{attrs};

std::string kernel_layout = conv2d_attrs->kernel_layout.c_str();
int pos_o = kernel_layout.find("O");
int pos_h = kernel_layout.find("H");
int pos_w = kernel_layout.find("W");
int pos_i = kernel_layout.find("I");

IRModule kernel_module;
auto func_body = MakeTranspose(kernel_expr, {Integer(3), Integer(0), Integer(1), Integer(2)});
auto func_body = MakeTranspose(
kernel_expr, {Integer(pos_o), Integer(pos_h), Integer(pos_w), Integer(pos_i)});
auto kernel_func =
Function(FreeVars(func_body), func_body, Type(), FreeTypeVars(func_body, kernel_module));
GlobalVar kernel_var("main");
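
Aside: the find()-based axis computation above generalizes the previously hardcoded transpose axes {3, 0, 1, 2}, which assumed an HWIO kernel; this pairs with dropping the kernel_layout checks in cmsisnn.py earlier in this diff. A quick sanity check of the index arithmetic, in illustrative plain Python:

# Transposing by the source positions of O, H, W, I yields an OHWI kernel.
kernel_layout = "HWIO"
axes = tuple(kernel_layout.index(c) for c in "OHWI")
assert axes == (3, 0, 1, 2)  # matches the previously hardcoded order
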
@@ -158,9 +176,6 @@ class GenerateConstantsMutator : public MixedModeMutator {
Expr Rewrite_(const CallNode* call, const Expr& post) final {
Expr final_call = post;
auto* post_call = post.as<CallNode>();
if (post_call == nullptr) {
return final_call;
}

auto* global_var = call->op.as<GlobalVarNode>();
if (global_var) {
@@ -196,7 +211,7 @@ class GenerateConstantsMutator : public MixedModeMutator {
IRModule mod_;
};

IRModule GenerateConstants(IRModule mod) {
IRModule GenerateConstants(const IRModule& mod) {
String func_name;
Function func;

@@ -220,9 +235,8 @@ transform::Pass GenerateCMSISNNConstants() {
return tvm::transform::CreateModulePass(pass_func, 0, "GenerateCMSISNNConstants", {});
}

TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.GenerateCMSISNNConstants").set_body_typed([]() {
return GenerateCMSISNNConstants();
});
TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.GenerateCMSISNNConstants")
.set_body_typed(GenerateCMSISNNConstants);

} // namespace cmsisnn
} // namespace contrib
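
For reviewers unfamiliar with the TFLite reference linked in the class comment: CMSIS-NN consumes each per-channel requantize scale, input_scale * kernel_scale[c] / output_scale, as a fixed-point significand plus a power-of-two shift. A hedged sketch of that decomposition in plain Python (the function name and rounding are illustrative, not the pass's actual symbols):

import math

def quantize_scale(effective_scale):
    # Split the real-valued scale into a Q31 significand in [0.5, 1) and a
    # base-2 exponent, mirroring TFLite's QuantizeMultiplier().
    significand, shift = math.frexp(effective_scale)
    return int(round(significand * (1 << 31))), shift

multiplier, shift = quantize_scale(0.00784313)  # ~2/255, gives shift == -6
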
(6 more changed files not shown)
