Skip to content

Commit

Permalink
[CMSIS-NN] Convert CMSIS-NN to use Target Hooks (apache#9397)
Browse files Browse the repository at this point in the history
* [CMSIS-NN] Convert CMSIS-NN to use Target Hooks

This migrates CMSIS-NN to use Target Hooks instead of fully BYOC, which
means it will now go through any central passes the Driver API.

* Mutated PrimFunc arguments in LowerTE so all functions are correctly lowered
* Made Target `cmsis-nn` to match external code generator `cmsis-nn` to connect the Target with the external code generator
* Modified Partition Graph to sanitise compiler names to generate them properly in C
* Port tvmc fixes for hybrid targets
* Update NPU tests with new sanitisation
  • Loading branch information
Mousius authored and mehrdadh committed Dec 1, 2021
1 parent a9f1eb3 commit 692ecc7
Show file tree
Hide file tree
Showing 16 changed files with 227 additions and 196 deletions.
4 changes: 3 additions & 1 deletion python/tvm/driver/tvmc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import tvm

from tvm.driver import tvmc
from tvm import relay
from tvm import transform
from tvm._ffi import registry
Expand Down Expand Up @@ -206,6 +207,7 @@ def parse_target(target):
a key-value for all options passed via CLI; 'raw',
containing the plain string for this codegen
"""
codegen_names = tvmc.composite_target.get_codegen_names()
codegens = []

tvm_target_kinds = tvm.target.Target.list_kinds()
Expand All @@ -232,7 +234,7 @@ def parse_target(target):
for codegen_def in split_codegens:
# the first is expected to be the name
name = codegen_def[0]
is_tvm_target = name in tvm_target_kinds
is_tvm_target = name in tvm_target_kinds and name not in codegen_names
raw_target = " ".join(codegen_def)
all_opts = codegen_def[1:] if len(codegen_def) > 1 else []
opts = {}
Expand Down
14 changes: 9 additions & 5 deletions python/tvm/driver/tvmc/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
This file contains functions for processing target inputs for the TVMC CLI
"""

from tvm.driver import tvmc
from tvm.target import Target

# We can't tell the type inside an Array but all current options are strings so
Expand All @@ -27,6 +28,11 @@
INTERNAL_TO_HELP = {"runtime.String": " string", "IntImm": "", "Array": " options"}


def _valid_target_kinds():
codegen_names = tvmc.composite_target.get_codegen_names()
return filter(lambda target: target not in codegen_names, Target.list_kinds())


def _generate_target_kind_args(parser, kind):
target_group = parser.add_argument_group(f"target {kind.name}")
for target_option, target_type in kind.options.items():
Expand All @@ -45,8 +51,7 @@ def generate_target_args(parser):
help="compilation target as plain string, inline JSON or path to a JSON file",
required=True,
)
target_kinds = Target.list_kinds()
for target_kind in target_kinds:
for target_kind in _valid_target_kinds():
target = Target(target_kind)
_generate_target_kind_args(parser, target.kind)

Expand All @@ -55,7 +60,7 @@ def _reconstruct_target_kind_args(args, kind):
kind_options = {}
for target_option, target_type in kind.options.items():
if target_type in INTERNAL_TO_NATIVE_TYPE:
var_name = f"target_{kind.name}_{target_option.replace('-', '_')}"
var_name = f"target_{kind.name.replace('-', '_')}_{target_option.replace('-', '_')}"
option_value = getattr(args, var_name)
if option_value is not None:
kind_options[target_option] = getattr(args, var_name)
Expand All @@ -64,9 +69,8 @@ def _reconstruct_target_kind_args(args, kind):

def reconstruct_target_args(args):
"""Reconstructs the target options from the arguments"""
target_kinds = Target.list_kinds()
reconstructed = {}
for target_kind in target_kinds:
for target_kind in _valid_target_kinds():
target = Target(target_kind)
kind_options = _reconstruct_target_kind_args(args, target.kind)
if kind_options:
Expand Down
15 changes: 8 additions & 7 deletions python/tvm/relay/op/contrib/cmsisnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=invalid-name, unused-argument
"""Arm(R) CMSIS-NN supported operators for Cortex-M."""
import tvm.ir
from tvm.target import Target
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name

Expand All @@ -25,7 +26,7 @@


def enabled():
return bool(tvm.get_global_func("relay.ext.cmsisnn", True))
return "cmsis-nn" in Target.list_kinds()


def partition_for_cmsisnn(mod, params=None, **opts):
Expand All @@ -51,7 +52,7 @@ def partition_for_cmsisnn(mod, params=None, **opts):
[
transform.InferType(),
transform.MergeComposite(pattern_table()),
transform.AnnotateTarget("cmsisnn"),
transform.AnnotateTarget("cmsis-nn"),
transform.MergeCompilerRegions(),
transform.PartitionGraph(),
]
Expand All @@ -60,9 +61,9 @@ def partition_for_cmsisnn(mod, params=None, **opts):
return seq(mod)


@register_pattern_table("cmsisnn")
@register_pattern_table("cmsis-nn")
def pattern_table():
"""Get the cmsisnn compiler pattern table."""
"""Get the CMSIS-NN compiler pattern table."""

def softmax_pattern():
pattern = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
Expand Down Expand Up @@ -104,14 +105,14 @@ def check_quantized_binary_op(extract):
)

return [
("cmsisnn.quantized_softmax", softmax_pattern(), check_quantized_softmax),
("cmsis-nn.quantized_softmax", softmax_pattern(), check_quantized_softmax),
(
"cmsisnn.quantized_mul",
"cmsis-nn.quantized_mul",
binary_op_pattern("mul"),
check_quantized_binary_op,
),
(
"cmsisnn.quantized_add",
"cmsis-nn.quantized_add",
binary_op_pattern("add"),
check_quantized_binary_op,
),
Expand Down
5 changes: 4 additions & 1 deletion src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,10 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::InferFragment());
mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());

if (target->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {
// The host Target contains these parameters at the moment rather than
// the specific Target
// TODO(Mousius) - Move these to the Executor object rather than Target
if (target->GetHost().value()->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {
mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI());
} else {
mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1));
Expand Down
127 changes: 73 additions & 54 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/ir/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/tir/builtin.h>
Expand All @@ -33,29 +34,46 @@ namespace relay {
namespace contrib {
namespace cmsisnn {

class RelayToTIRVisitor : public MixedModeVisitor {
class RelayToTIRVisitor : public MixedModeMutator {
public:
explicit RelayToTIRVisitor(String func_name) : func_name_(func_name) {}
explicit RelayToTIRVisitor(IRModule ir_module, Target target)
: ir_module_(ir_module), target_(target) {}

tir::PrimFunc GetReplacementPrimFunc() { return primfunc_; }
IRModule Mutate() {
GlobalVar main_global_var = ir_module_->GetGlobalVar("main");
BaseFunc main = ir_module_->Lookup(main_global_var);
Function main_func = GetRef<Function>(main.as<FunctionNode>());

// Copy everything across and mutate the body
Function mutated_main =
Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type,
main_func->type_params, main_func->attrs, main_func->span);

ir_module_->Update(main_global_var, mutated_main);

return ir_module_;
}

private:
inline IntImm ToArg(int32_t value) { return IntImm(DataType::Int(32), value); }

void CreatePrimFuncForExtern(Array<tir::Var> func_signature,
void CreatePrimFuncForExtern(const GlobalVar& global_var, Array<tir::Var> func_signature,
tvm::Array<PrimExpr> call_extern_args) {
Map<String, ObjectRef> dict_attrs;
dict_attrs.Set("global_symbol", func_name_);
dict_attrs.Set(tvm::attr::kGlobalSymbol, global_var->name_hint);
dict_attrs.Set(tvm::attr::kTarget, target_);
dict_attrs.Set("tir.noalias", Bool(true));

tir::Stmt body = tir::Evaluate(
tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), call_extern_args));

primfunc_ = tir::PrimFunc(func_signature, body, VoidType(), Map<tir::Var, tir::Buffer>(),
DictAttrs(dict_attrs));
tir::PrimFunc replacement_func(func_signature, body, VoidType(), Map<tir::Var, tir::Buffer>(),
DictAttrs(dict_attrs));

ir_module_->Add(global_var, replacement_func);
}

void EmitSoftMax(const Expr& expr) {
void EmitSoftMax(const GlobalVar& global_var, const Expr& expr) {
auto* quantize_call = expr.as<CallNode>();
auto* softmax_call = quantize_call->args[0].as<CallNode>();
auto* dequant_call = softmax_call->args[0].as<CallNode>();
Expand Down Expand Up @@ -102,10 +120,10 @@ class RelayToTIRVisitor : public MixedModeVisitor {
out_var,
};

CreatePrimFuncForExtern(func_signature, args);
CreatePrimFuncForExtern(global_var, func_signature, args);
}

void EmitMul(const Expr& expr) {
void EmitMul(const GlobalVar& global_var, const Expr& expr) {
auto* mul_call = expr.as<CallNode>();

const float input_0_scale = GetScalarFromConstant<float>(mul_call->args[2]);
Expand Down Expand Up @@ -145,10 +163,10 @@ class RelayToTIRVisitor : public MixedModeVisitor {
tensor_size,
};

CreatePrimFuncForExtern(func_signature, args);
CreatePrimFuncForExtern(global_var, func_signature, args);
}

void EmitAdd(const Expr& expr) {
void EmitAdd(const GlobalVar& global_var, const Expr& expr) {
auto* add_call = expr.as<CallNode>();

const float input_0_scale = GetScalarFromConstant<float>(add_call->args[2]);
Expand Down Expand Up @@ -212,58 +230,59 @@ class RelayToTIRVisitor : public MixedModeVisitor {
tensor_size,
};

CreatePrimFuncForExtern(func_signature, args);
CreatePrimFuncForExtern(global_var, func_signature, args);
}

void VisitExpr_(const CallNode* call) final {
auto* func = call->op.as<FunctionNode>();
if (func == nullptr) {
return;
}

auto comp_name = func->GetAttr<String>(attr::kComposite);
if (comp_name.defined()) {
if (comp_name == "cmsisnn.quantized_softmax") {
EmitSoftMax(func->body);
Expr Rewrite_(const CallNode* pre, const Expr& post) override {
if (const CallNode* call = post.as<CallNode>()) {
auto* func = call->op.as<FunctionNode>();
if (func == nullptr) {
return post;
}
if (comp_name == "cmsisnn.quantized_mul") {
EmitMul(func->body);
}
if (comp_name == "cmsisnn.quantized_add") {
EmitAdd(func->body);

auto codegen_name = func->GetAttr<String>(attr::kCompiler);
if (codegen_name.defined() && codegen_name == "cmsis-nn") {
const CallNode* inner_call = func->body.as<CallNode>();
const FunctionNode* composite_func = inner_call->op.as<FunctionNode>();
auto comp_name = composite_func->GetAttr<String>(attr::kComposite);
auto func_name = func->GetAttr<String>(::tvm::attr::kGlobalSymbol);

GlobalVar new_global_var(func_name.value());
new_global_var->checked_type_ = composite_func->checked_type();

if (comp_name == "cmsis-nn.quantized_softmax") {
EmitSoftMax(new_global_var, composite_func->body);
}
if (comp_name == "cmsis-nn.quantized_mul") {
EmitMul(new_global_var, composite_func->body);
}
if (comp_name == "cmsis-nn.quantized_add") {
EmitAdd(new_global_var, composite_func->body);
}

Array<Expr> args;
for (const auto& arg : call->args) {
args.push_back(VisitExpr(arg));
}

return Call(new_global_var, args, call->attrs, call->type_args, call->span);
}
}
}

public:
String func_name_;
tir::PrimFunc primfunc_;
};

IRModule GenerateTIR(IRModule mod) {
String func_name;
Function func;

// Obtain external Relay Function that needs to be translated into TIR
ICHECK(mod->functions.size() == 1) << "Supports modules with single external Relay function.";
for (auto kv : mod->functions) {
func = Downcast<Function>(kv.second);
func_name = func->GetAttr<String>(tvm::attr::kGlobalSymbol).value();
return post;
}

// Prepare PrimFunc from Relay Function
auto relay_to_tir = RelayToTIRVisitor(func_name);
relay_to_tir.VisitExpr(func->body);

// Build the TIR IRModule from the generated PrimFunc
Map<GlobalVar, BaseFunc> var_func_map;
var_func_map.Set(GlobalVar(func_name), relay_to_tir.GetReplacementPrimFunc());
return IRModule(var_func_map);
}
private:
IRModule ir_module_;
Target target_;
};

transform::Pass RelayToTIR() {
tvm::transform::Pass RelayToTIR() {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
[=](IRModule m, transform::PassContext pc) { return GenerateTIR(m); };
[=](IRModule ir_module, transform::PassContext pass_context) {
auto relay_to_tir = RelayToTIRVisitor(ir_module, Target("cmsis-nn"));
return relay_to_tir.Mutate();
};
return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIR", {});
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
Expand All @@ -16,34 +17,22 @@
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/relay/transform.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>

namespace tvm {

namespace relay {
namespace contrib {
namespace cmsisnn {

transform::Pass RelayToTIR();

runtime::Module CompileCMSISNN(const ObjectRef& ref) {
IRModule relay_mod;
Function relay_func = Downcast<Function>(ref);
auto func_name = relay_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
GlobalVar var = GlobalVar(func_name.value());
relay_mod->Add(var, relay_func);
relay_mod = transform::InferType()(relay_mod);

Array<transform::Pass> pass_seqs{transform::InferType(), RelayToTIR()};
transform::Sequential seq(pass_seqs);
IRModule tir_mod = seq(relay_mod);

const auto* pf = runtime::Registry::Get("runtime.CMSISNNModuleNodeCreate");
return (*pf)(tir_mod);
}
tvm::transform::Pass RelayToTIR();
runtime::Module TIRToRuntime(IRModule mod, Target target);

TVM_REGISTER_GLOBAL("relay.ext.cmsisnn").set_body_typed(CompileCMSISNN);
TVM_REGISTER_TARGET_KIND("cmsis-nn", kDLCPU)
.set_attr<FTVMRelayToTIR>("RelayToTIR", RelayToTIR())
.set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime);

} // namespace cmsisnn
} // namespace contrib
Expand Down
Loading

0 comments on commit 692ecc7

Please sign in to comment.