Code generation for Conv2D via CMSIS-NN
Change-Id: I0a2279965a0b505f809ffcf8b955f64db8f4aff0
ashutosh-arm committed Oct 20, 2021
1 parent 151696f commit 66eed5c
Showing 11 changed files with 1,198 additions and 45 deletions.
84 changes: 64 additions & 20 deletions python/tvm/relay/op/contrib/cmsisnn.py
@@ -47,42 +47,93 @@ def partition_for_cmsisnn(mod, params=None, **opts):
if params:
mod["main"] = bind_params_by_name(mod["main"], params)

tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__)

seq = tvm.transform.Sequential(
[
transform.InferType(),
transform.MergeComposite(pattern_table()),
transform.AnnotateTarget("cmsisnn"),
transform.MergeCompilerRegions(),
transform.PartitionGraph(),
GenerateCMSISNNConstants(),
ExtractConstantsFromPartitionedFunction(),
transform.InferType(),
]
)

return seq(mod)


@register_pattern_table("cmsisnn")
def pattern_table():
"""Get the cmsisnn compiler pattern table."""

- def softmax_pattern():
+ def qnn_softmax_pattern():
"""Create pattern for quantized softmax"""
pattern = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
pattern = is_op("nn.softmax")(pattern)
pattern = is_op("qnn.quantize")(pattern, is_constant(), is_constant())
return pattern

- def check_quantized_softmax(extract):
+ def check_qnn_softmax(pattern):
"""Check if softmax is supported by CMSIS-NN."""
- dequantize_call = extract.args[0].args[0]
- scale = extract.args[1].data.numpy().item(0)
- zero_point = extract.args[2].data.numpy().item(0)
+ dequantize_call = pattern.args[0].args[0]
+ scale = pattern.args[1].data.numpy().item(0)
+ zero_point = pattern.args[2].data.numpy().item(0)

# check for dtypes of quantize and dequantize
return (
(scale == 1.0 / 256 and zero_point == -128)
- and extract.attrs.out_dtype == "int8"
+ and pattern.attrs.out_dtype == "int8"
and dequantize_call.args[0].checked_type.dtype == "int8"
)

def qnn_conv2d_pattern():
"""Create pattern for qnn.conv2D with optional fused relu."""
qnn_conv2d = is_op("qnn.conv2d")(
wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
).has_attr({"kernel_layout": "HWIO"})
bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant())
req = is_op("qnn.requantize")(
qnn_conv2d | bias_add, is_constant(), is_constant(), is_constant(), is_constant()
)
clip_or_req = req.optional(is_op("clip"))
return clip_or_req

def check_qnn_conv2d(pattern):
"""Check if the Conv2D is supported by CMSIS-NN."""
if str(pattern.op.name) == "clip":
relu = pattern
requantize = relu.args[0]
else:
requantize = pattern
requantize_input = requantize.args[0]
bias_add = None
bias_dtype = "int32"
if str(requantize_input.op.name) == "nn.bias_add":
bias_add = requantize_input
conv2d = bias_add.args[0]
bias_dtype = bias_add.args[1].checked_type.dtype
else:
conv2d = requantize_input
conv2d_input = conv2d.args[0]
conv2d_weight = conv2d.args[1]

# kernel zero_point should be 0
kernel_zp = conv2d.args[3].data.numpy()
kernel_zp = [kernel_zp] if kernel_zp.ndim == 0 else kernel_zp

return (
conv2d.attrs.kernel_layout == "HWIO"
and conv2d.attrs.out_dtype == "int32"
and conv2d.attrs.padding[2] == 0
and conv2d.attrs.padding[3] == 0
and conv2d_input.checked_type.dtype == "int8"
and conv2d_weight.checked_type.dtype == "int8"
and pattern.checked_type.dtype == "int8"
and bias_dtype == "int32"
and all([zp == 0 for zp in kernel_zp])
)

def binary_op_pattern(op):
"""Matches QNN binary operation"""
return is_op(f"qnn.{op}")(
@@ -96,23 +147,16 @@ def binary_op_pattern(op):
is_constant(),
)

- def check_quantized_binary_op(extract):
+ def check_qnn_binary_op(extract):
"""Check if multiply is supported by CMSIS-NN."""
return (
extract.args[0].checked_type.dtype == "int8"
and extract.args[1].checked_type.dtype == "int8"
)

return [
("cmsisnn.quantized_softmax", softmax_pattern(), check_quantized_softmax),
(
"cmsisnn.quantized_mul",
binary_op_pattern("mul"),
check_quantized_binary_op,
),
(
"cmsisnn.quantized_add",
binary_op_pattern("add"),
check_quantized_binary_op,
),
("cmsisnn.qnn_softmax", qnn_softmax_pattern(), check_qnn_softmax),
("cmsisnn.qnn_conv2d", qnn_conv2d_pattern(), check_qnn_conv2d),
("cmsisnn.qnn_mul", binary_op_pattern("mul"), check_qnn_binary_op),
("cmsisnn.qnn_add", binary_op_pattern("add"), check_qnn_binary_op),
]
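
Below is a minimal usage sketch (not part of this commit) of how the new qnn_conv2d pattern could be exercised through partition_for_cmsisnn; the shapes, scales, and zero points are illustrative assumptions chosen to satisfy check_qnn_conv2d above.

import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib import cmsisnn

# int8 NHWC input and HWIO weights, as required by check_qnn_conv2d
data = relay.var("data", shape=(1, 28, 28, 3), dtype="int8")
weight = relay.const(np.ones((3, 3, 3, 8), dtype="int8"))

conv = relay.qnn.op.conv2d(
    data,
    weight,
    input_zero_point=relay.const(0, "int32"),
    kernel_zero_point=relay.const(0, "int32"),  # must be 0 for CMSIS-NN
    input_scale=relay.const(0.5, "float32"),
    kernel_scale=relay.const(0.5, "float32"),
    kernel_size=(3, 3),
    channels=8,
    padding=(0, 0, 0, 0),
    data_layout="NHWC",
    kernel_layout="HWIO",
    out_dtype="int32",
)
req = relay.qnn.op.requantize(
    conv,
    input_scale=relay.const(0.25, "float32"),
    input_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.5, "float32"),
    output_zero_point=relay.const(-128, "int32"),
    out_dtype="int8",
)

mod = tvm.IRModule.from_expr(relay.Function([data], req))
# qnn.conv2d -> qnn.requantize should now be merged into a cmsisnn.qnn_conv2d
# composite function and partitioned out for the "cmsisnn" target.
partitioned = cmsisnn.partition_for_cmsisnn(mod)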
9 changes: 5 additions & 4 deletions src/relay/backend/contrib/cmsisnn/codegen_cmsisnn.cc
@@ -21,6 +21,9 @@
#include <tvm/runtime/registry.h>

namespace tvm {
namespace codegen {
runtime::Module CMSISNNModuleNodeCreate(IRModule mod);
} // namespace codegen
namespace relay {
namespace contrib {
namespace cmsisnn {
@@ -33,14 +36,12 @@ runtime::Module CompileCMSISNN(const ObjectRef& ref) {
auto func_name = relay_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
GlobalVar var = GlobalVar(func_name.value());
relay_mod->Add(var, relay_func);
- relay_mod = transform::InferType()(relay_mod);

- Array<transform::Pass> pass_seqs{transform::InferType(), RelayToTIR()};
+ Array<transform::Pass> pass_seqs{RelayToTIR()};
transform::Sequential seq(pass_seqs);
IRModule tir_mod = seq(relay_mod);

- const auto* pf = runtime::Registry::Get("runtime.CMSISNNModuleNodeCreate");
- return (*pf)(tir_mod);
+ return tvm::codegen::CMSISNNModuleNodeCreate(tir_mod);
}

TVM_REGISTER_GLOBAL("relay.ext.cmsisnn").set_body_typed(CompileCMSISNN);
158 changes: 158 additions & 0 deletions src/relay/backend/contrib/cmsisnn/extract_constants.cc
@@ -0,0 +1,158 @@

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/ndarray.h>

#include "../../../qnn/utils.h"
#include "../../../transforms/pattern_utils.h"

namespace tvm {
namespace relay {
namespace contrib {
namespace cmsisnn {

class ExtractConstantsMutator : public MixedModeMutator {
public:
explicit ExtractConstantsMutator(IRModule& mod) : mod_(mod) {}

private:
String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); }

Expr VisitExpr_(const FunctionNode* func) final {
Function final_func = GetRef<Function>(func);
++func_nesting_level_;
auto new_body = VisitExpr(func->body);
--func_nesting_level_;
if (!new_body.same_as(func->body)) {
final_func = Function(FreeVars(new_body), new_body, func->ret_type,
FreeTypeVars(new_body, mod_), func->attrs);
function_to_constants_.Set(GetRef<Function>(func), constants_within_function_);
constants_within_function_.clear();
}
return final_func;
}

Expr Rewrite_(const CallNode* call, const Expr& post) final {
Expr final_call = post;
auto* post_call = post.as<CallNode>();
if (post_call == nullptr) {
return final_call;
}

// Replace Constant arguments with Vars for ML Operators
// Perform this for non-main Call Nodes only
if (func_nesting_level_ && call->op.as<OpNode>()) {
Array<Expr> new_args;
for (auto& arg : post_call->args) {
auto* const_arg = arg.as<ConstantNode>();
if (const_arg && !const_arg->is_scalar()) {
Var var_arg = Var(gen_var_name(), const_arg->tensor_type());
new_args.push_back(var_arg);
constants_within_function_.push_back(GetRef<Constant>(const_arg));
} else {
new_args.push_back(arg);
}
}
final_call = Call(call->op, new_args, call->attrs, {});
}

// Since the constants are kicked out of partitioned functions
// a new call to global function is needed
if (auto* glob_var_node = post_call->op.as<GlobalVarNode>()) {
auto glob_var = GetRef<GlobalVar>(glob_var_node);
auto glob_func = Downcast<Function>(mod_->Lookup(glob_var));
auto new_glob_func = VisitExpr(glob_func);
if (!new_glob_func.same_as(glob_func)) {
mod_->Update(glob_var, Downcast<Function>(new_glob_func));
Array<Expr> new_args = post_call->args;
ICHECK(function_to_constants_.find(glob_func) != function_to_constants_.end());
for (auto constant : function_to_constants_.at(glob_func)) {
new_args.push_back(constant);
}
final_call = Call(glob_var, new_args);
}
}

// Since the constants are kicked out of the local partitioned functions
// a new call to local function is needed
if (auto* func_node = call->op.as<FunctionNode>()) {
Function func = GetRef<Function>(func_node);
auto new_func = VisitExpr(func);
if (!new_func.same_as(func)) {
Array<Expr> new_args = post_call->args;
ICHECK(function_to_constants_.find(func) != function_to_constants_.end());
for (auto constant : function_to_constants_.at(func)) {
constants_within_function_.push_back(constant);
Var var_arg = Var(gen_var_name(), constant->tensor_type());
new_args.push_back(var_arg);
}
final_call = Call(new_func, new_args);
}
}

return final_call;
}

private:
/* \brief Updated module where all calls have replaced constants with new variables */
IRModule mod_;
/* \brief Maintains mapping of original function to the replaced constants */
Map<Function, Array<Constant>> function_to_constants_;
/* \brief Constants being kicked out of a function during the function visit */
Array<Constant> constants_within_function_;
/* \brief Keeps track of variables being created */
int var_count_ = 0;
/* \brief Keeps track of function scope */
int func_nesting_level_ = 0;
};

/*! \brief Kicks all constants out of the partitioned function into main() */
IRModule ExtractConstants(IRModule mod) {
String func_name;
Function func;

auto extract_constants = ExtractConstantsMutator(mod);
Function main_func = Downcast<Function>(mod->Lookup("main"));
auto new_main_body = extract_constants.VisitExpr(main_func->body);
if (!new_main_body.same_as(main_func->body)) {
auto main_var = mod->GetGlobalVar("main");
auto new_main_func = Function(main_func->params, new_main_body, main_func->ret_type,
main_func->type_params, main_func->attrs);
mod->Update(main_var, new_main_func);
}
return mod;
}

transform::Pass ExtractConstantsFromPartitionedFunction() {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
[=](IRModule m, transform::PassContext pc) { return ExtractConstants(m); };
return tvm::transform::CreateModulePass(pass_func, 0, "ExtractConstantsFromPartitionedFunction",
{});
}

TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.ExtractConstantsFromPartitionedFunction")
.set_body_typed([]() { return ExtractConstantsFromPartitionedFunction(); });

} // namespace cmsisnn
} // namespace contrib
} // namespace relay
} // namespace tvm
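
For illustration, a small sketch (an assumption, not code from this commit) of driving the new pass directly from Python through the global registered above; this is the same symbol that the tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__) call in cmsisnn.py exposes as ExtractConstantsFromPartitionedFunction.

import tvm

# Look up the pass constructor registered by TVM_REGISTER_GLOBAL above.
make_pass = tvm.get_global_func(
    "relay.ext.cmsisnn.transform.ExtractConstantsFromPartitionedFunction"
)
extract_constants_pass = make_pass()  # returns a module-level transform pass

# Given a module already partitioned for "cmsisnn" (e.g. `partitioned` from the
# earlier sketch), the pass pulls non-scalar constants out of the partitioned
# functions and passes them in as extra call arguments from main().
partitioned = extract_constants_pass(partitioned)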