From 9f5fec23bfc0c11240832913f11c40833caaf19a Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Fri, 11 Feb 2022 14:03:33 +0000 Subject: [PATCH 1/3] [CMSIS-NN] enable USMP with CMSIS-NN This commit mainly enables the USMP with CMSIS-NN codegen. In order to do that, CMSIS-NN functions needed to contain BufferMaps. This commit adds the necessary BufferMaps as well. All the tests are modified to run with USMP while the networks tests run with and without USMP. Change-Id: I18c7958addfff90b8243e9a6c70b7411158462fa --- .../backend/contrib/cmsisnn/relay_to_tir.cc | 165 ++++++++++-------- .../backend/contrib/cmsisnn/tir_to_runtime.cc | 2 +- .../contrib/test_cmsisnn/test_binary_ops.py | 5 +- .../contrib/test_cmsisnn/test_conv2d.py | 7 +- .../test_cmsisnn/test_fully_connected.py | 3 +- .../contrib/test_cmsisnn/test_networks.py | 5 +- .../contrib/test_cmsisnn/test_pooling.py | 3 +- .../contrib/test_cmsisnn/test_softmax.py | 3 +- tests/python/relay/aot/aot_test_utils.py | 14 ++ 9 files changed, 127 insertions(+), 80 deletions(-) diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index f366e4ab2635..b50eadbd55c0 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -37,6 +37,38 @@ namespace relay { namespace contrib { namespace cmsisnn { +/*! + * \brief This is a helper class to generate tir.Buffers and BufferMap + * + * The PrimFuncs generated needs Buffers produced to attach information + * about the inputs and output tir::Vars. This is helper class to generate + * them in the relay to TIR lowering + */ +class BufferCreator { + public: + /*! \brief Creates a tir::Var and tir::Buffer then returns the buffer var to be used by the body + */ + tir::Var CreateBufferVar(String name_hint, DataType dtype) { + tir::Var var = tir::Var(name_hint, dtype); + tir::Buffer buffer = tir::decl_buffer({}, DataType::Int(dtype.bits()), name_hint + "_"); + _primfunc_params_.push_back(var); + _buffer_map_.Set(var, buffer); + _buffer_vars_.Set(name_hint, buffer->data); + return buffer->data; + } + /*! \brief Access already created buffer_var by associated tir::Var name */ + tir::Var GetBufferVar(String name_hint) { return _buffer_vars_[name_hint]; } + /*! \brief Get the BufferMap that maps tir::Var to tir::Buffer */ + Map GetBufferMap() { return _buffer_map_; } + /*! \brief Get the PrimFunc params that is a collection of tir::Vars created in the process */ + Array GetPrimFuncParams() { return _primfunc_params_; } + + private: + Map _buffer_vars_; + Map _buffer_map_; + Array _primfunc_params_; +}; + class RelayToTIRVisitor : public MixedModeMutator { public: explicit RelayToTIRVisitor(IRModule ir_module, Target target) @@ -58,8 +90,9 @@ class RelayToTIRVisitor : public MixedModeMutator { inline IntImm ToArg(int32_t value) { return IntImm(DataType::Int(32), value); } void CreatePrimFuncForExtern(const GlobalVar& global_var, Array func_signature, + const Map& buffer_map, tvm::Array call_extern_args, - std::string context_buffer_name = "NULL", + tir::Var context_buffer_var = tir::Var{"NULL"}, int context_buffer_size = 0) { Map dict_attrs; dict_attrs.Set(tvm::attr::kGlobalSymbol, global_var->name_hint); @@ -70,16 +103,11 @@ class RelayToTIRVisitor : public MixedModeMutator { tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), call_extern_args)); if (context_buffer_size) { - tir::Var buffer_var(context_buffer_name, - PointerType(PrimType(DataType::Int(8)), "global.workspace")); - body = tir::Allocate(buffer_var, DataType::Int(8), {context_buffer_size}, tir::const_true(), - body); - body = - tir::AttrStmt(PrimExpr(), tvm::tir::attr::device_type, target_->kind->device_type, body); - body = tir::AttrStmt(PrimExpr(), tvm::tir::attr::device_id, 0, body); + body = tir::Allocate(context_buffer_var, DataType::Int(8), {context_buffer_size}, + tir::const_true(), body); } - tir::PrimFunc replacement_func(func_signature, body, VoidType(), Map(), + tir::PrimFunc replacement_func(func_signature, body, VoidType(), buffer_map, DictAttrs(dict_attrs)); ir_module_->Add(global_var, replacement_func); } @@ -113,14 +141,17 @@ class RelayToTIRVisitor : public MixedModeMutator { // %3 = qnn.requantize(%2, %input_scale_const_4, %cmsisnn_shift_const_5, // %output_scale_scalar, %output_zero_point_scalar) // clip(%3, a_min=%min_scalar, a_max=%max_scalar) - tir::Var input("input", DataType::Handle(8)); - tir::Var filter("filter", DataType::Handle(8)); - tir::Var multiplier("multiplier", DataType::Handle(32)); - tir::Var filter_scale("filter_scale", DataType::Handle(32)); - tir::Var bias("bias", DataType::Handle(32)); - tir::Var input_scale("input_scale", DataType::Handle(32)); - tir::Var shift("shift", DataType::Handle(32)); - tir::Var output("output", DataType::Handle(8)); + BufferCreator buffer_creator_; + tir::Var input = buffer_creator_.CreateBufferVar("input", DataType::Handle(8)); + tir::Var filter = buffer_creator_.CreateBufferVar("filter", DataType::Handle(8)); + tir::Var multiplier = buffer_creator_.CreateBufferVar("multiplier", DataType::Handle(32)); + tir::Var filter_scale = buffer_creator_.CreateBufferVar("filter_scale", DataType::Handle(32)); + if (bias_add_call) { + buffer_creator_.CreateBufferVar("bias", DataType::Handle(32)); + } + tir::Var input_scale = buffer_creator_.CreateBufferVar("input_scale", DataType::Handle(32)); + tir::Var shift = buffer_creator_.CreateBufferVar("shift", DataType::Handle(32)); + tir::Var output = buffer_creator_.CreateBufferVar("output", DataType::Handle(8)); // Individual arguments to the structs arguments of the CMSIS-NN API are filled into call_extern // https://github.com/ARM-software/CMSIS_5/blob/def6f800f95661eb3451d317f7d0dde504f6020d/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L50 @@ -196,12 +227,13 @@ class RelayToTIRVisitor : public MixedModeMutator { tvm::Array call_ext_args = {tir::StringImm(cmsisnn_api), input, filter, multiplier}; if (bias_add_call) { + tir::Var bias = buffer_creator_.GetBufferVar("bias"); call_ext_args.push_back(bias); } call_ext_args.push_back(shift); call_ext_args.push_back(output); - std::string context_buffer_name = "NULL"; + tir::Var buffer_var{"NULL"}; CMSISNNFlags flags = GetCompilerFlags(transform::PassContext::Current()); size_t context_buffer_size; if (is_depthwise) { @@ -214,10 +246,11 @@ class RelayToTIRVisitor : public MixedModeMutator { } if (context_buffer_size) { - context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++); + String context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++); + buffer_var = tir::Var(context_buffer_name, + PointerType(PrimType(DataType::Int(8)), "global.workspace")); } - tvm::Array context_buffer_args = {tir::StringImm(context_buffer_name), - ToArg(context_buffer_size)}; + tvm::Array context_buffer_args = {buffer_var, ToArg(context_buffer_size)}; scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args); scalar_args = tvm::runtime::Concat(scalar_args, input_shape); @@ -226,15 +259,8 @@ class RelayToTIRVisitor : public MixedModeMutator { scalar_args = tvm::runtime::Concat(scalar_args, output_shape); call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args); - Array func_signature{input, filter, multiplier, filter_scale}; - if (bias_add_call) { - func_signature.push_back(bias); - } - func_signature.push_back(input_scale); - func_signature.push_back(shift); - func_signature.push_back(output); - - CreatePrimFuncForExtern(global_var, func_signature, call_ext_args, context_buffer_name, + CreatePrimFuncForExtern(global_var, buffer_creator_.GetPrimFuncParams(), + buffer_creator_.GetBufferMap(), call_ext_args, buffer_var, context_buffer_size); } @@ -267,10 +293,13 @@ class RelayToTIRVisitor : public MixedModeMutator { // %3 = qnn.requantize(%2, %req_input_scale_scalar, %req_input_zero_point_scalar, // %output_scale_scalar, %output_zero_point_scalar) // clip(%3, a_min=%min_scalar, a_max=%max_scalar) - tir::Var input("input", DataType::Handle(8)); - tir::Var filter("filter", DataType::Handle(8)); - tir::Var bias("bias", DataType::Handle(32)); - tir::Var output("output", DataType::Handle(8)); + BufferCreator buffer_creator_; + tir::Var input = buffer_creator_.CreateBufferVar("input", DataType::Handle(8)); + tir::Var filter = buffer_creator_.CreateBufferVar("filter", DataType::Handle(8)); + if (bias_add_call) { + buffer_creator_.CreateBufferVar("bias", DataType::Handle(32)); + } + tir::Var output = buffer_creator_.CreateBufferVar("output", DataType::Handle(8)); // Individual arguments to the structs arguments of the CMSIS-NN API are filled into call_extern // https://github.com/ARM-software/CMSIS_5/blob/def6f800f95661eb3451d317f7d0dde504f6020d/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L50 @@ -316,14 +345,15 @@ class RelayToTIRVisitor : public MixedModeMutator { tvm::Array call_ext_args = {tir::StringImm("arm_fully_connected_s8"), input, filter}; if (bias_add_call) { - call_ext_args.push_back(bias); + call_ext_args.push_back(buffer_creator_.GetBufferVar("bias")); } call_ext_args.push_back(output); int context_buffer_size = 0; std::string context_buffer_name = "NULL"; - tvm::Array context_buffer_args = {tir::StringImm(context_buffer_name), - ToArg(context_buffer_size)}; + tir::Var buffer_var = + tir::Var(context_buffer_name, PointerType(PrimType(DataType::Int(8)), "global.workspace")); + tvm::Array context_buffer_args = {buffer_var, ToArg(context_buffer_size)}; scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args); scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_input_shape); @@ -332,12 +362,8 @@ class RelayToTIRVisitor : public MixedModeMutator { scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_output_shape); call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args); - Array func_signature{input, filter}; - if (bias_add_call) { - func_signature.push_back(bias); - } - func_signature.push_back(output); - CreatePrimFuncForExtern(global_var, func_signature, call_ext_args, context_buffer_name, + CreatePrimFuncForExtern(global_var, buffer_creator_.GetPrimFuncParams(), + buffer_creator_.GetBufferMap(), call_ext_args, buffer_var, context_buffer_size); } @@ -403,20 +429,22 @@ class RelayToTIRVisitor : public MixedModeMutator { Array output_shape = pool->type_as()->shape; Array cmsisnn_output_shape{1, output_shape[1], output_shape[2], output_shape[3]}; - tir::Var input("input", DataType::Handle(8)); - tir::Var output("output", DataType::Handle(8)); + BufferCreator buffer_creator_; + tir::Var input = buffer_creator_.CreateBufferVar("input", DataType::Handle(8)); + tir::Var output = buffer_creator_.CreateBufferVar("output", DataType::Handle(8)); tvm::Array call_ext_args = {tir::StringImm(cmsisnn_api), input, output}; int context_buffer_size = 0; - std::string context_buffer_name = "NULL"; + tir::Var context_buffer_var{"NULL"}; if (pool_name == "cmsisnn.qnn_avg_pool2d") { CMSISNNFlags flags = GetCompilerFlags(transform::PassContext::Current()); int32_t input_c = qnn::get_const_int(input_shape[3]); context_buffer_size = AvgPoolBufferSize(flags, input_c); - context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++); + std::string context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++); + context_buffer_var = tir::Var(context_buffer_name, + PointerType(PrimType(DataType::Int(8)), "global.workspace")); } - tvm::Array context_buffer_args = {tir::StringImm(context_buffer_name), - ToArg(context_buffer_size)}; + tvm::Array context_buffer_args = {context_buffer_var, ToArg(context_buffer_size)}; scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args); scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_input_shape); @@ -424,9 +452,8 @@ class RelayToTIRVisitor : public MixedModeMutator { scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_output_shape); call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args); - Array func_signature{input, output}; - - CreatePrimFuncForExtern(global_var, func_signature, call_ext_args, context_buffer_name, + CreatePrimFuncForExtern(global_var, buffer_creator_.GetPrimFuncParams(), + buffer_creator_.GetBufferMap(), call_ext_args, context_buffer_var, context_buffer_size); } @@ -460,10 +487,9 @@ class RelayToTIRVisitor : public MixedModeMutator { diff_min >>= shift; diff_min *= -1; - auto in_var = tir::Var("input", DataType::Handle(8)); - auto out_var = tir::Var("output", DataType::Handle(8)); - - Array func_signature{in_var, out_var}; + BufferCreator buffer_creator_; + tir::Var in_var = buffer_creator_.CreateBufferVar("input", DataType::Handle(8)); + tir::Var out_var = buffer_creator_.CreateBufferVar("output", DataType::Handle(8)); tvm::Array args = { tir::StringImm("arm_softmax_s8"), @@ -476,7 +502,8 @@ class RelayToTIRVisitor : public MixedModeMutator { out_var, }; - CreatePrimFuncForExtern(global_var, func_signature, args); + CreatePrimFuncForExtern(global_var, buffer_creator_.GetPrimFuncParams(), + buffer_creator_.GetBufferMap(), args); } void EmitMul(const GlobalVar& global_var, const Expr& expr) { @@ -498,11 +525,10 @@ class RelayToTIRVisitor : public MixedModeMutator { PrimExpr tensor_size = mul_call->type_as()->Size(); - tir::Var input_0("input_0", DataType::Handle(8)); - tir::Var input_1("input_1", DataType::Handle(8)); - tir::Var output("output", DataType::Handle(8)); - - Array func_signature{input_0, input_1, output}; + BufferCreator buffer_creator_; + tir::Var input_0 = buffer_creator_.CreateBufferVar("input_0", DataType::Handle(8)); + tir::Var input_1 = buffer_creator_.CreateBufferVar("input_1", DataType::Handle(8)); + tir::Var output = buffer_creator_.CreateBufferVar("output", DataType::Handle(8)); tvm::Array args = { tir::StringImm("arm_elementwise_mul_s8"), @@ -519,7 +545,8 @@ class RelayToTIRVisitor : public MixedModeMutator { tensor_size, }; - CreatePrimFuncForExtern(global_var, func_signature, args); + CreatePrimFuncForExtern(global_var, buffer_creator_.GetPrimFuncParams(), + buffer_creator_.GetBufferMap(), args); } void EmitAdd(const GlobalVar& global_var, const Expr& expr) { @@ -560,11 +587,10 @@ class RelayToTIRVisitor : public MixedModeMutator { PrimExpr tensor_size = add_call->type_as()->Size(); - tir::Var input_0("input_0", DataType::Handle(8)); - tir::Var input_1("input_1", DataType::Handle(8)); - tir::Var output("output", DataType::Handle(8)); - - Array func_signature{input_0, input_1, output}; + BufferCreator buffer_creator_; + tir::Var input_0 = buffer_creator_.CreateBufferVar("input_0", DataType::Handle(8)); + tir::Var input_1 = buffer_creator_.CreateBufferVar("input_1", DataType::Handle(8)); + tir::Var output = buffer_creator_.CreateBufferVar("output", DataType::Handle(8)); tvm::Array args = { tir::StringImm("arm_elementwise_add_s8"), @@ -586,7 +612,8 @@ class RelayToTIRVisitor : public MixedModeMutator { tensor_size, }; - CreatePrimFuncForExtern(global_var, func_signature, args); + CreatePrimFuncForExtern(global_var, buffer_creator_.GetPrimFuncParams(), + buffer_creator_.GetBufferMap(), args); } Expr Rewrite_(const CallNode* pre, const Expr& post) override { diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc index e0e5aa962239..3a6731c5a040 100644 --- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc +++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc @@ -228,7 +228,7 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { /*! * \brief extracts CMSIS-NN context buffer information */ CMSISNNContextBuffer extract_context_buffer_info(const CallNode* op, int base_pos) { CMSISNNContextBuffer context_buffer; - context_buffer.name = op->args[base_pos].as()->value; + context_buffer.name = op->args[base_pos].as()->name_hint; context_buffer.size = ValueFromArg(op, base_pos + 1); return context_buffer; } diff --git a/tests/python/contrib/test_cmsisnn/test_binary_ops.py b/tests/python/contrib/test_cmsisnn/test_binary_ops.py index d08a88201d2e..3180ffc726da 100644 --- a/tests/python/contrib/test_cmsisnn/test_binary_ops.py +++ b/tests/python/contrib/test_cmsisnn/test_binary_ops.py @@ -38,6 +38,7 @@ from tests.python.relay.aot.aot_test_utils import ( AOTTestModel, AOT_CORSTONE300_RUNNER, + AOT_USMP_CORSTONE300_RUNNER, generate_ref_data, compile_and_run, ) @@ -101,7 +102,7 @@ def make_model( def test_op_int8(op, input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point): interface_api = "c" use_unpacked_api = True - test_runner = AOT_CORSTONE300_RUNNER + test_runner = AOT_USMP_CORSTONE300_RUNNER dtype = "int8" shape = [1, 16, 16, 3] @@ -174,7 +175,7 @@ def parameterize_for_constant_inputs(test): def test_constant_input_int8(op, input_0, input_1): interface_api = "c" use_unpacked_api = True - test_runner = AOT_CORSTONE300_RUNNER + test_runner = AOT_USMP_CORSTONE300_RUNNER dtype = "int8" shape = [1, 16, 16, 3] diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py index 641fd517a4e7..16c37e21607a 100644 --- a/tests/python/contrib/test_cmsisnn/test_conv2d.py +++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py @@ -27,6 +27,7 @@ from tests.python.relay.aot.aot_test_utils import ( AOTTestModel, AOT_CORSTONE300_RUNNER, + AOT_USMP_CORSTONE300_RUNNER, AOT_DEFAULT_RUNNER, generate_ref_data, compile_and_run, @@ -143,7 +144,7 @@ def test_conv2d_symmetric_padding_int8( ): interface_api = "c" use_unpacked_api = True - test_runner = AOT_CORSTONE300_RUNNER + test_runner = AOT_USMP_CORSTONE300_RUNNER ifm_shape = (1, 64, 100, 4) kernel_size = (3, 3) @@ -233,7 +234,7 @@ def test_conv2d_asymmetric_padding_int8( ): interface_api = "c" use_unpacked_api = True - test_runner = AOT_CORSTONE300_RUNNER + test_runner = AOT_USMP_CORSTONE300_RUNNER ifm_shape = (1, 25, 25, 12) kernel_size = (5, 5) @@ -334,7 +335,7 @@ def test_depthwise_int8( ): interface_api = "c" use_unpacked_api = True - test_runner = AOT_CORSTONE300_RUNNER + test_runner = AOT_USMP_CORSTONE300_RUNNER dtype = "int8" groups = 1 diff --git a/tests/python/contrib/test_cmsisnn/test_fully_connected.py b/tests/python/contrib/test_cmsisnn/test_fully_connected.py index 17670fb84cfc..bf452952f188 100644 --- a/tests/python/contrib/test_cmsisnn/test_fully_connected.py +++ b/tests/python/contrib/test_cmsisnn/test_fully_connected.py @@ -27,6 +27,7 @@ from tests.python.relay.aot.aot_test_utils import ( AOTTestModel, AOT_CORSTONE300_RUNNER, + AOT_USMP_CORSTONE300_RUNNER, AOT_DEFAULT_RUNNER, generate_ref_data, compile_and_run, @@ -119,7 +120,7 @@ def test_op_int8( ): interface_api = "c" use_unpacked_api = True - test_runner = AOT_CORSTONE300_RUNNER + test_runner = AOT_USMP_CORSTONE300_RUNNER dtype = "int8" kernel_zero_point = 0 diff --git a/tests/python/contrib/test_cmsisnn/test_networks.py b/tests/python/contrib/test_cmsisnn/test_networks.py index d74782e57567..a6e77515859e 100644 --- a/tests/python/contrib/test_cmsisnn/test_networks.py +++ b/tests/python/contrib/test_cmsisnn/test_networks.py @@ -31,6 +31,7 @@ from tests.python.relay.aot.aot_test_utils import ( AOTTestModel, AOT_CORSTONE300_RUNNER, + AOT_USMP_CORSTONE300_RUNNER, generate_ref_data, compile_and_run, ) @@ -77,7 +78,8 @@ def convert_to_list(x): @skip_if_no_reference_system @tvm.testing.requires_package("tflite") @tvm.testing.requires_cmsisnn -def test_cnn_small(): +@pytest.mark.parametrize("test_runner", [AOT_CORSTONE300_RUNNER, AOT_USMP_CORSTONE300_RUNNER]) +def test_cnn_small(test_runner): # download the model base_url = "https://github.com/ARM-software/ML-zoo/raw/48a22ee22325d15d2371a6df24eb7d67e21dcc97/models/keyword_spotting/cnn_small/tflite_int8" file_to_download = "cnn_s_quantized.tflite" @@ -99,7 +101,6 @@ def test_cnn_small(): # validate CMSIS-NN output against CPU output interface_api = "c" use_unpacked_api = True - test_runner = AOT_CORSTONE300_RUNNER inputs = {"input": input_data} params = {} output_list = generate_ref_data(orig_mod["main"], inputs, params) diff --git a/tests/python/contrib/test_cmsisnn/test_pooling.py b/tests/python/contrib/test_cmsisnn/test_pooling.py index ee4f5c4aea4d..732fd9bb82ec 100644 --- a/tests/python/contrib/test_cmsisnn/test_pooling.py +++ b/tests/python/contrib/test_cmsisnn/test_pooling.py @@ -27,6 +27,7 @@ from tests.python.relay.aot.aot_test_utils import ( AOTTestModel, AOT_CORSTONE300_RUNNER, + AOT_USMP_CORSTONE300_RUNNER, AOT_DEFAULT_RUNNER, generate_ref_data, compile_and_run, @@ -87,7 +88,7 @@ def test_op_int8( ): interface_api = "c" use_unpacked_api = True - test_runner = AOT_CORSTONE300_RUNNER + test_runner = AOT_USMP_CORSTONE300_RUNNER dtype = "int8" diff --git a/tests/python/contrib/test_cmsisnn/test_softmax.py b/tests/python/contrib/test_cmsisnn/test_softmax.py index c3617cce15d4..6eac76d841b4 100644 --- a/tests/python/contrib/test_cmsisnn/test_softmax.py +++ b/tests/python/contrib/test_cmsisnn/test_softmax.py @@ -37,6 +37,7 @@ from tests.python.relay.aot.aot_test_utils import ( AOTTestModel, AOT_CORSTONE300_RUNNER, + AOT_USMP_CORSTONE300_RUNNER, generate_ref_data, compile_and_run, ) @@ -68,7 +69,7 @@ def make_model( def test_op_int8(zero_point, scale): interface_api = "c" use_unpacked_api = True - test_runner = AOT_CORSTONE300_RUNNER + test_runner = AOT_USMP_CORSTONE300_RUNNER dtype = "int8" shape = [1, 16, 16, 3] diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index 63817fc4b965..fd897f095127 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -153,6 +153,20 @@ class AOTTestRunner(NamedTuple): }, ) +AOT_USMP_CORSTONE300_RUNNER = AOTTestRunner( + makefile="corstone300", + prologue=""" + uart_init(); + """, + includes=["uart.h"], + pass_config={ + "relay.ext.cmsisnn.options": { + "mcpu": "cortex-m55", + }, + "tir.usmp.enable": True, + }, +) + def mangle_name(mod_name, name): mod_name = mangle_module_name(mod_name) From 3c4d71c80ed0d0156fabe3afc2bc919616c7bff5 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Tue, 15 Feb 2022 14:56:20 +0000 Subject: [PATCH 2/3] [CMSIS-NN] enable USMP with CMSIS-NN "NULL" should be supplied as an extern arg if the context buffer is not allocated. Change-Id: I58866df69a27ce976ac45a57e2f261094a48e48e --- src/relay/backend/contrib/cmsisnn/relay_to_tir.cc | 14 ++++++-------- .../backend/contrib/cmsisnn/tir_to_runtime.cc | 9 ++++++++- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index b50eadbd55c0..9e300b7c3abf 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -92,7 +92,7 @@ class RelayToTIRVisitor : public MixedModeMutator { void CreatePrimFuncForExtern(const GlobalVar& global_var, Array func_signature, const Map& buffer_map, tvm::Array call_extern_args, - tir::Var context_buffer_var = tir::Var{"NULL"}, + PrimExpr context_buffer_var = PrimExpr(), int context_buffer_size = 0) { Map dict_attrs; dict_attrs.Set(tvm::attr::kGlobalSymbol, global_var->name_hint); @@ -103,8 +103,8 @@ class RelayToTIRVisitor : public MixedModeMutator { tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), call_extern_args)); if (context_buffer_size) { - body = tir::Allocate(context_buffer_var, DataType::Int(8), {context_buffer_size}, - tir::const_true(), body); + body = tir::Allocate(Downcast(context_buffer_var), DataType::Int(8), + {context_buffer_size}, tir::const_true(), body); } tir::PrimFunc replacement_func(func_signature, body, VoidType(), buffer_map, @@ -233,7 +233,7 @@ class RelayToTIRVisitor : public MixedModeMutator { call_ext_args.push_back(shift); call_ext_args.push_back(output); - tir::Var buffer_var{"NULL"}; + PrimExpr buffer_var = tir::StringImm("NULL"); CMSISNNFlags flags = GetCompilerFlags(transform::PassContext::Current()); size_t context_buffer_size; if (is_depthwise) { @@ -350,9 +350,7 @@ class RelayToTIRVisitor : public MixedModeMutator { call_ext_args.push_back(output); int context_buffer_size = 0; - std::string context_buffer_name = "NULL"; - tir::Var buffer_var = - tir::Var(context_buffer_name, PointerType(PrimType(DataType::Int(8)), "global.workspace")); + PrimExpr buffer_var = tir::StringImm("NULL"); tvm::Array context_buffer_args = {buffer_var, ToArg(context_buffer_size)}; scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args); @@ -435,7 +433,7 @@ class RelayToTIRVisitor : public MixedModeMutator { tvm::Array call_ext_args = {tir::StringImm(cmsisnn_api), input, output}; int context_buffer_size = 0; - tir::Var context_buffer_var{"NULL"}; + PrimExpr context_buffer_var = tir::StringImm("NULL"); if (pool_name == "cmsisnn.qnn_avg_pool2d") { CMSISNNFlags flags = GetCompilerFlags(transform::PassContext::Current()); int32_t input_c = qnn::get_const_int(input_shape[3]); diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc index 3a6731c5a040..fbd0ff1707c0 100644 --- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc +++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc @@ -228,7 +228,14 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { /*! * \brief extracts CMSIS-NN context buffer information */ CMSISNNContextBuffer extract_context_buffer_info(const CallNode* op, int base_pos) { CMSISNNContextBuffer context_buffer; - context_buffer.name = op->args[base_pos].as()->name_hint; + + // The argument could be a Var if it is allocated to hold the + // context buffer OR it will be a StringImm with "NULL" + if (op->args[base_pos]->IsInstance()) { + context_buffer.name = op->args[base_pos].as()->name_hint; + } else { + context_buffer.name = op->args[base_pos].as()->value; + } context_buffer.size = ValueFromArg(op, base_pos + 1); return context_buffer; } From ebb81900a91081a1e4b2a44540bf0cc688926ce2 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Tue, 1 Mar 2022 17:20:52 +0000 Subject: [PATCH 3/3] [CMSIS-NN] enable USMP with CMSIS-NN * renaming buffer_var to be context_buffer_var * removing trailing underscore from local variable : buffer_creator Change-Id: I874164aa4944087c0e13cef80ead04c0de2be149 --- .../backend/contrib/cmsisnn/relay_to_tir.cc | 96 +++++++++---------- 1 file changed, 48 insertions(+), 48 deletions(-) diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index 9e300b7c3abf..530d6495adb2 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -141,17 +141,17 @@ class RelayToTIRVisitor : public MixedModeMutator { // %3 = qnn.requantize(%2, %input_scale_const_4, %cmsisnn_shift_const_5, // %output_scale_scalar, %output_zero_point_scalar) // clip(%3, a_min=%min_scalar, a_max=%max_scalar) - BufferCreator buffer_creator_; - tir::Var input = buffer_creator_.CreateBufferVar("input", DataType::Handle(8)); - tir::Var filter = buffer_creator_.CreateBufferVar("filter", DataType::Handle(8)); - tir::Var multiplier = buffer_creator_.CreateBufferVar("multiplier", DataType::Handle(32)); - tir::Var filter_scale = buffer_creator_.CreateBufferVar("filter_scale", DataType::Handle(32)); + BufferCreator buffer_creator; + tir::Var input = buffer_creator.CreateBufferVar("input", DataType::Handle(8)); + tir::Var filter = buffer_creator.CreateBufferVar("filter", DataType::Handle(8)); + tir::Var multiplier = buffer_creator.CreateBufferVar("multiplier", DataType::Handle(32)); + tir::Var filter_scale = buffer_creator.CreateBufferVar("filter_scale", DataType::Handle(32)); if (bias_add_call) { - buffer_creator_.CreateBufferVar("bias", DataType::Handle(32)); + buffer_creator.CreateBufferVar("bias", DataType::Handle(32)); } - tir::Var input_scale = buffer_creator_.CreateBufferVar("input_scale", DataType::Handle(32)); - tir::Var shift = buffer_creator_.CreateBufferVar("shift", DataType::Handle(32)); - tir::Var output = buffer_creator_.CreateBufferVar("output", DataType::Handle(8)); + tir::Var input_scale = buffer_creator.CreateBufferVar("input_scale", DataType::Handle(32)); + tir::Var shift = buffer_creator.CreateBufferVar("shift", DataType::Handle(32)); + tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8)); // Individual arguments to the structs arguments of the CMSIS-NN API are filled into call_extern // https://github.com/ARM-software/CMSIS_5/blob/def6f800f95661eb3451d317f7d0dde504f6020d/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L50 @@ -227,13 +227,13 @@ class RelayToTIRVisitor : public MixedModeMutator { tvm::Array call_ext_args = {tir::StringImm(cmsisnn_api), input, filter, multiplier}; if (bias_add_call) { - tir::Var bias = buffer_creator_.GetBufferVar("bias"); + tir::Var bias = buffer_creator.GetBufferVar("bias"); call_ext_args.push_back(bias); } call_ext_args.push_back(shift); call_ext_args.push_back(output); - PrimExpr buffer_var = tir::StringImm("NULL"); + PrimExpr context_buffer_var = tir::StringImm("NULL"); CMSISNNFlags flags = GetCompilerFlags(transform::PassContext::Current()); size_t context_buffer_size; if (is_depthwise) { @@ -247,10 +247,10 @@ class RelayToTIRVisitor : public MixedModeMutator { if (context_buffer_size) { String context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++); - buffer_var = tir::Var(context_buffer_name, - PointerType(PrimType(DataType::Int(8)), "global.workspace")); + context_buffer_var = tir::Var(context_buffer_name, + PointerType(PrimType(DataType::Int(8)), "global.workspace")); } - tvm::Array context_buffer_args = {buffer_var, ToArg(context_buffer_size)}; + tvm::Array context_buffer_args = {context_buffer_var, ToArg(context_buffer_size)}; scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args); scalar_args = tvm::runtime::Concat(scalar_args, input_shape); @@ -259,8 +259,8 @@ class RelayToTIRVisitor : public MixedModeMutator { scalar_args = tvm::runtime::Concat(scalar_args, output_shape); call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args); - CreatePrimFuncForExtern(global_var, buffer_creator_.GetPrimFuncParams(), - buffer_creator_.GetBufferMap(), call_ext_args, buffer_var, + CreatePrimFuncForExtern(global_var, buffer_creator.GetPrimFuncParams(), + buffer_creator.GetBufferMap(), call_ext_args, context_buffer_var, context_buffer_size); } @@ -293,13 +293,13 @@ class RelayToTIRVisitor : public MixedModeMutator { // %3 = qnn.requantize(%2, %req_input_scale_scalar, %req_input_zero_point_scalar, // %output_scale_scalar, %output_zero_point_scalar) // clip(%3, a_min=%min_scalar, a_max=%max_scalar) - BufferCreator buffer_creator_; - tir::Var input = buffer_creator_.CreateBufferVar("input", DataType::Handle(8)); - tir::Var filter = buffer_creator_.CreateBufferVar("filter", DataType::Handle(8)); + BufferCreator buffer_creator; + tir::Var input = buffer_creator.CreateBufferVar("input", DataType::Handle(8)); + tir::Var filter = buffer_creator.CreateBufferVar("filter", DataType::Handle(8)); if (bias_add_call) { - buffer_creator_.CreateBufferVar("bias", DataType::Handle(32)); + buffer_creator.CreateBufferVar("bias", DataType::Handle(32)); } - tir::Var output = buffer_creator_.CreateBufferVar("output", DataType::Handle(8)); + tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8)); // Individual arguments to the structs arguments of the CMSIS-NN API are filled into call_extern // https://github.com/ARM-software/CMSIS_5/blob/def6f800f95661eb3451d317f7d0dde504f6020d/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L50 @@ -345,13 +345,13 @@ class RelayToTIRVisitor : public MixedModeMutator { tvm::Array call_ext_args = {tir::StringImm("arm_fully_connected_s8"), input, filter}; if (bias_add_call) { - call_ext_args.push_back(buffer_creator_.GetBufferVar("bias")); + call_ext_args.push_back(buffer_creator.GetBufferVar("bias")); } call_ext_args.push_back(output); int context_buffer_size = 0; - PrimExpr buffer_var = tir::StringImm("NULL"); - tvm::Array context_buffer_args = {buffer_var, ToArg(context_buffer_size)}; + PrimExpr context_buffer_var = tir::StringImm("NULL"); + tvm::Array context_buffer_args = {context_buffer_var, ToArg(context_buffer_size)}; scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args); scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_input_shape); @@ -360,8 +360,8 @@ class RelayToTIRVisitor : public MixedModeMutator { scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_output_shape); call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args); - CreatePrimFuncForExtern(global_var, buffer_creator_.GetPrimFuncParams(), - buffer_creator_.GetBufferMap(), call_ext_args, buffer_var, + CreatePrimFuncForExtern(global_var, buffer_creator.GetPrimFuncParams(), + buffer_creator.GetBufferMap(), call_ext_args, context_buffer_var, context_buffer_size); } @@ -427,9 +427,9 @@ class RelayToTIRVisitor : public MixedModeMutator { Array output_shape = pool->type_as()->shape; Array cmsisnn_output_shape{1, output_shape[1], output_shape[2], output_shape[3]}; - BufferCreator buffer_creator_; - tir::Var input = buffer_creator_.CreateBufferVar("input", DataType::Handle(8)); - tir::Var output = buffer_creator_.CreateBufferVar("output", DataType::Handle(8)); + BufferCreator buffer_creator; + tir::Var input = buffer_creator.CreateBufferVar("input", DataType::Handle(8)); + tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8)); tvm::Array call_ext_args = {tir::StringImm(cmsisnn_api), input, output}; int context_buffer_size = 0; @@ -450,8 +450,8 @@ class RelayToTIRVisitor : public MixedModeMutator { scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_output_shape); call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args); - CreatePrimFuncForExtern(global_var, buffer_creator_.GetPrimFuncParams(), - buffer_creator_.GetBufferMap(), call_ext_args, context_buffer_var, + CreatePrimFuncForExtern(global_var, buffer_creator.GetPrimFuncParams(), + buffer_creator.GetBufferMap(), call_ext_args, context_buffer_var, context_buffer_size); } @@ -485,9 +485,9 @@ class RelayToTIRVisitor : public MixedModeMutator { diff_min >>= shift; diff_min *= -1; - BufferCreator buffer_creator_; - tir::Var in_var = buffer_creator_.CreateBufferVar("input", DataType::Handle(8)); - tir::Var out_var = buffer_creator_.CreateBufferVar("output", DataType::Handle(8)); + BufferCreator buffer_creator; + tir::Var in_var = buffer_creator.CreateBufferVar("input", DataType::Handle(8)); + tir::Var out_var = buffer_creator.CreateBufferVar("output", DataType::Handle(8)); tvm::Array args = { tir::StringImm("arm_softmax_s8"), @@ -500,8 +500,8 @@ class RelayToTIRVisitor : public MixedModeMutator { out_var, }; - CreatePrimFuncForExtern(global_var, buffer_creator_.GetPrimFuncParams(), - buffer_creator_.GetBufferMap(), args); + CreatePrimFuncForExtern(global_var, buffer_creator.GetPrimFuncParams(), + buffer_creator.GetBufferMap(), args); } void EmitMul(const GlobalVar& global_var, const Expr& expr) { @@ -523,10 +523,10 @@ class RelayToTIRVisitor : public MixedModeMutator { PrimExpr tensor_size = mul_call->type_as()->Size(); - BufferCreator buffer_creator_; - tir::Var input_0 = buffer_creator_.CreateBufferVar("input_0", DataType::Handle(8)); - tir::Var input_1 = buffer_creator_.CreateBufferVar("input_1", DataType::Handle(8)); - tir::Var output = buffer_creator_.CreateBufferVar("output", DataType::Handle(8)); + BufferCreator buffer_creator; + tir::Var input_0 = buffer_creator.CreateBufferVar("input_0", DataType::Handle(8)); + tir::Var input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8)); + tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8)); tvm::Array args = { tir::StringImm("arm_elementwise_mul_s8"), @@ -543,8 +543,8 @@ class RelayToTIRVisitor : public MixedModeMutator { tensor_size, }; - CreatePrimFuncForExtern(global_var, buffer_creator_.GetPrimFuncParams(), - buffer_creator_.GetBufferMap(), args); + CreatePrimFuncForExtern(global_var, buffer_creator.GetPrimFuncParams(), + buffer_creator.GetBufferMap(), args); } void EmitAdd(const GlobalVar& global_var, const Expr& expr) { @@ -585,10 +585,10 @@ class RelayToTIRVisitor : public MixedModeMutator { PrimExpr tensor_size = add_call->type_as()->Size(); - BufferCreator buffer_creator_; - tir::Var input_0 = buffer_creator_.CreateBufferVar("input_0", DataType::Handle(8)); - tir::Var input_1 = buffer_creator_.CreateBufferVar("input_1", DataType::Handle(8)); - tir::Var output = buffer_creator_.CreateBufferVar("output", DataType::Handle(8)); + BufferCreator buffer_creator; + tir::Var input_0 = buffer_creator.CreateBufferVar("input_0", DataType::Handle(8)); + tir::Var input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8)); + tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8)); tvm::Array args = { tir::StringImm("arm_elementwise_add_s8"), @@ -610,8 +610,8 @@ class RelayToTIRVisitor : public MixedModeMutator { tensor_size, }; - CreatePrimFuncForExtern(global_var, buffer_creator_.GetPrimFuncParams(), - buffer_creator_.GetBufferMap(), args); + CreatePrimFuncForExtern(global_var, buffer_creator.GetPrimFuncParams(), + buffer_creator.GetBufferMap(), args); } Expr Rewrite_(const CallNode* pre, const Expr& post) override {