From 8c711d98132243b3bd5a62429f06fdcfc007f8c5 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Tue, 15 Feb 2022 14:56:20 +0000 Subject: [PATCH] [CMSIS-NN] enable USMP with CMSIS-NN "NULL" should be supplied as an extern arg if the context buffer is not allocated. Change-Id: I58866df69a27ce976ac45a57e2f261094a48e48e --- src/relay/backend/contrib/cmsisnn/relay_to_tir.cc | 14 ++++++-------- .../backend/contrib/cmsisnn/tir_to_runtime.cc | 9 ++++++++- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index b50eadbd55c09..9e300b7c3abfd 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -92,7 +92,7 @@ class RelayToTIRVisitor : public MixedModeMutator { void CreatePrimFuncForExtern(const GlobalVar& global_var, Array func_signature, const Map& buffer_map, tvm::Array call_extern_args, - tir::Var context_buffer_var = tir::Var{"NULL"}, + PrimExpr context_buffer_var = PrimExpr(), int context_buffer_size = 0) { Map dict_attrs; dict_attrs.Set(tvm::attr::kGlobalSymbol, global_var->name_hint); @@ -103,8 +103,8 @@ class RelayToTIRVisitor : public MixedModeMutator { tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), call_extern_args)); if (context_buffer_size) { - body = tir::Allocate(context_buffer_var, DataType::Int(8), {context_buffer_size}, - tir::const_true(), body); + body = tir::Allocate(Downcast(context_buffer_var), DataType::Int(8), + {context_buffer_size}, tir::const_true(), body); } tir::PrimFunc replacement_func(func_signature, body, VoidType(), buffer_map, @@ -233,7 +233,7 @@ class RelayToTIRVisitor : public MixedModeMutator { call_ext_args.push_back(shift); call_ext_args.push_back(output); - tir::Var buffer_var{"NULL"}; + PrimExpr buffer_var = tir::StringImm("NULL"); CMSISNNFlags flags = GetCompilerFlags(transform::PassContext::Current()); size_t context_buffer_size; if (is_depthwise) { @@ -350,9 +350,7 @@ class RelayToTIRVisitor : public MixedModeMutator { call_ext_args.push_back(output); int context_buffer_size = 0; - std::string context_buffer_name = "NULL"; - tir::Var buffer_var = - tir::Var(context_buffer_name, PointerType(PrimType(DataType::Int(8)), "global.workspace")); + PrimExpr buffer_var = tir::StringImm("NULL"); tvm::Array context_buffer_args = {buffer_var, ToArg(context_buffer_size)}; scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args); @@ -435,7 +433,7 @@ class RelayToTIRVisitor : public MixedModeMutator { tvm::Array call_ext_args = {tir::StringImm(cmsisnn_api), input, output}; int context_buffer_size = 0; - tir::Var context_buffer_var{"NULL"}; + PrimExpr context_buffer_var = tir::StringImm("NULL"); if (pool_name == "cmsisnn.qnn_avg_pool2d") { CMSISNNFlags flags = GetCompilerFlags(transform::PassContext::Current()); int32_t input_c = qnn::get_const_int(input_shape[3]); diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc index 3a6731c5a040d..fbd0ff1707c0f 100644 --- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc +++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc @@ -228,7 +228,14 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { /*! * \brief extracts CMSIS-NN context buffer information */ CMSISNNContextBuffer extract_context_buffer_info(const CallNode* op, int base_pos) { CMSISNNContextBuffer context_buffer; - context_buffer.name = op->args[base_pos].as()->name_hint; + + // The argument could be a Var if it is allocated to hold the + // context buffer OR it will be a StringImm with "NULL" + if (op->args[base_pos]->IsInstance()) { + context_buffer.name = op->args[base_pos].as()->name_hint; + } else { + context_buffer.name = op->args[base_pos].as()->value; + } context_buffer.size = ValueFromArg(op, base_pos + 1); return context_buffer; }