Skip to content

Commit

Permalink
[CMSIS-NN] enable USMP with CMSIS-NN
Browse files Browse the repository at this point in the history
This commit mainly enables the USMP
with CMSIS-NN codegen.

In order to do that, CMSIS-NN functions needed
to contain BufferMaps. This commit adds the necessary
BufferMaps as well.

All the tests are modified to run with USMP
while the networks tests run with and without
USMP.

Change-Id: I18c7958addfff90b8243e9a6c70b7411158462fa
  • Loading branch information
manupak committed Feb 24, 2022
1 parent cb7f773 commit 0e4b0d0
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 80 deletions.
165 changes: 96 additions & 69 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,38 @@ namespace relay {
namespace contrib {
namespace cmsisnn {

/*!
* \brief This is a helper class to generate tir.Buffers and BufferMap
*
* The PrimFuncs generated needs Buffers produced to attach information
* about the inputs and output tir::Vars. This is helper class to generate
* them in the relay to TIR lowering
*/
class BufferCreator {
public:
/*! \brief Creates a tir::Var and tir::Buffer then returns the buffer var to be used by the body
*/
tir::Var CreateBufferVar(String name_hint, DataType dtype) {
tir::Var var = tir::Var(name_hint, dtype);
tir::Buffer buffer = tir::decl_buffer({}, DataType::Int(dtype.bits()), name_hint + "_");
_primfunc_params_.push_back(var);
_buffer_map_.Set(var, buffer);
_buffer_vars_.Set(name_hint, buffer->data);
return buffer->data;
}
/*! \brief Access already created buffer_var by associated tir::Var name */
tir::Var GetBufferVar(String name_hint) { return _buffer_vars_[name_hint]; }
/*! \brief Get the BufferMap that maps tir::Var to tir::Buffer */
Map<tir::Var, tir::Buffer> GetBufferMap() { return _buffer_map_; }
/*! \brief Get the PrimFunc params that is a collection of tir::Vars created in the process */
Array<tir::Var> GetPrimFuncParams() { return _primfunc_params_; }

private:
Map<String, tir::Var> _buffer_vars_;
Map<tir::Var, tir::Buffer> _buffer_map_;
Array<tir::Var> _primfunc_params_;
};

class RelayToTIRVisitor : public MixedModeMutator {
public:
explicit RelayToTIRVisitor(IRModule ir_module, Target target)
Expand All @@ -58,8 +90,9 @@ class RelayToTIRVisitor : public MixedModeMutator {
inline IntImm ToArg(int32_t value) { return IntImm(DataType::Int(32), value); }

void CreatePrimFuncForExtern(const GlobalVar& global_var, Array<tir::Var> func_signature,
const Map<tir::Var, tir::Buffer>& buffer_map,
tvm::Array<PrimExpr> call_extern_args,
std::string context_buffer_name = "NULL",
tir::Var context_buffer_var = tir::Var{"NULL"},
int context_buffer_size = 0) {
Map<String, ObjectRef> dict_attrs;
dict_attrs.Set(tvm::attr::kGlobalSymbol, global_var->name_hint);
Expand All @@ -70,16 +103,11 @@ class RelayToTIRVisitor : public MixedModeMutator {
tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), call_extern_args));

if (context_buffer_size) {
tir::Var buffer_var(context_buffer_name,
PointerType(PrimType(DataType::Int(8)), "global.workspace"));
body = tir::Allocate(buffer_var, DataType::Int(8), {context_buffer_size}, tir::const_true(),
body);
body =
tir::AttrStmt(PrimExpr(), tvm::tir::attr::device_type, target_->kind->device_type, body);
body = tir::AttrStmt(PrimExpr(), tvm::tir::attr::device_id, 0, body);
body = tir::Allocate(context_buffer_var, DataType::Int(8), {context_buffer_size},
tir::const_true(), body);
}

tir::PrimFunc replacement_func(func_signature, body, VoidType(), Map<tir::Var, tir::Buffer>(),
tir::PrimFunc replacement_func(func_signature, body, VoidType(), buffer_map,
DictAttrs(dict_attrs));
ir_module_->Add(global_var, replacement_func);
}
Expand Down Expand Up @@ -113,14 +141,17 @@ class RelayToTIRVisitor : public MixedModeMutator {
// %3 = qnn.requantize(%2, %input_scale_const_4, %cmsisnn_shift_const_5,
// %output_scale_scalar, %output_zero_point_scalar)
// clip(%3, a_min=%min_scalar, a_max=%max_scalar)
tir::Var input("input", DataType::Handle(8));
tir::Var filter("filter", DataType::Handle(8));
tir::Var multiplier("multiplier", DataType::Handle(32));
tir::Var filter_scale("filter_scale", DataType::Handle(32));
tir::Var bias("bias", DataType::Handle(32));
tir::Var input_scale("input_scale", DataType::Handle(32));
tir::Var shift("shift", DataType::Handle(32));
tir::Var output("output", DataType::Handle(8));
BufferCreator buffer_creator_;
tir::Var input = buffer_creator_.CreateBufferVar("input", DataType::Handle(8));
tir::Var filter = buffer_creator_.CreateBufferVar("filter", DataType::Handle(8));
tir::Var multiplier = buffer_creator_.CreateBufferVar("multiplier", DataType::Handle(32));
tir::Var filter_scale = buffer_creator_.CreateBufferVar("filter_scale", DataType::Handle(32));
if (bias_add_call) {
buffer_creator_.CreateBufferVar("bias", DataType::Handle(32));
}
tir::Var input_scale = buffer_creator_.CreateBufferVar("input_scale", DataType::Handle(32));
tir::Var shift = buffer_creator_.CreateBufferVar("shift", DataType::Handle(32));
tir::Var output = buffer_creator_.CreateBufferVar("output", DataType::Handle(8));

// Individual arguments to the structs arguments of the CMSIS-NN API are filled into call_extern
// https://github.com/ARM-software/CMSIS_5/blob/def6f800f95661eb3451d317f7d0dde504f6020d/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L50
Expand Down Expand Up @@ -196,12 +227,13 @@ class RelayToTIRVisitor : public MixedModeMutator {

tvm::Array<PrimExpr> call_ext_args = {tir::StringImm(cmsisnn_api), input, filter, multiplier};
if (bias_add_call) {
tir::Var bias = buffer_creator_.GetBufferVar("bias");
call_ext_args.push_back(bias);
}
call_ext_args.push_back(shift);
call_ext_args.push_back(output);

std::string context_buffer_name = "NULL";
tir::Var buffer_var{"NULL"};
CMSISNNFlags flags = GetCompilerFlags(transform::PassContext::Current());
size_t context_buffer_size;
if (is_depthwise) {
Expand All @@ -214,10 +246,11 @@ class RelayToTIRVisitor : public MixedModeMutator {
}

if (context_buffer_size) {
context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++);
String context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++);
buffer_var = tir::Var(context_buffer_name,
PointerType(PrimType(DataType::Int(8)), "global.workspace"));
}
tvm::Array<PrimExpr> context_buffer_args = {tir::StringImm(context_buffer_name),
ToArg(context_buffer_size)};
tvm::Array<PrimExpr> context_buffer_args = {buffer_var, ToArg(context_buffer_size)};

scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args);
scalar_args = tvm::runtime::Concat(scalar_args, input_shape);
Expand All @@ -226,15 +259,8 @@ class RelayToTIRVisitor : public MixedModeMutator {
scalar_args = tvm::runtime::Concat(scalar_args, output_shape);
call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args);

Array<tir::Var> func_signature{input, filter, multiplier, filter_scale};
if (bias_add_call) {
func_signature.push_back(bias);
}
func_signature.push_back(input_scale);
func_signature.push_back(shift);
func_signature.push_back(output);

CreatePrimFuncForExtern(global_var, func_signature, call_ext_args, context_buffer_name,
CreatePrimFuncForExtern(global_var, buffer_creator_.GetPrimFuncParams(),
buffer_creator_.GetBufferMap(), call_ext_args, buffer_var,
context_buffer_size);
}

Expand Down Expand Up @@ -267,10 +293,13 @@ class RelayToTIRVisitor : public MixedModeMutator {
// %3 = qnn.requantize(%2, %req_input_scale_scalar, %req_input_zero_point_scalar,
// %output_scale_scalar, %output_zero_point_scalar)
// clip(%3, a_min=%min_scalar, a_max=%max_scalar)
tir::Var input("input", DataType::Handle(8));
tir::Var filter("filter", DataType::Handle(8));
tir::Var bias("bias", DataType::Handle(32));
tir::Var output("output", DataType::Handle(8));
BufferCreator buffer_creator_;
tir::Var input = buffer_creator_.CreateBufferVar("input", DataType::Handle(8));
tir::Var filter = buffer_creator_.CreateBufferVar("filter", DataType::Handle(8));
if (bias_add_call) {
buffer_creator_.CreateBufferVar("bias", DataType::Handle(32));
}
tir::Var output = buffer_creator_.CreateBufferVar("output", DataType::Handle(8));

// Individual arguments to the structs arguments of the CMSIS-NN API are filled into call_extern
// https://github.com/ARM-software/CMSIS_5/blob/def6f800f95661eb3451d317f7d0dde504f6020d/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L50
Expand Down Expand Up @@ -316,14 +345,15 @@ class RelayToTIRVisitor : public MixedModeMutator {

tvm::Array<PrimExpr> call_ext_args = {tir::StringImm("arm_fully_connected_s8"), input, filter};
if (bias_add_call) {
call_ext_args.push_back(bias);
call_ext_args.push_back(buffer_creator_.GetBufferVar("bias"));
}
call_ext_args.push_back(output);

int context_buffer_size = 0;
std::string context_buffer_name = "NULL";
tvm::Array<PrimExpr> context_buffer_args = {tir::StringImm(context_buffer_name),
ToArg(context_buffer_size)};
tir::Var buffer_var =
tir::Var(context_buffer_name, PointerType(PrimType(DataType::Int(8)), "global.workspace"));
tvm::Array<PrimExpr> context_buffer_args = {buffer_var, ToArg(context_buffer_size)};

scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args);
scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_input_shape);
Expand All @@ -332,12 +362,8 @@ class RelayToTIRVisitor : public MixedModeMutator {
scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_output_shape);
call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args);

Array<tir::Var> func_signature{input, filter};
if (bias_add_call) {
func_signature.push_back(bias);
}
func_signature.push_back(output);
CreatePrimFuncForExtern(global_var, func_signature, call_ext_args, context_buffer_name,
CreatePrimFuncForExtern(global_var, buffer_creator_.GetPrimFuncParams(),
buffer_creator_.GetBufferMap(), call_ext_args, buffer_var,
context_buffer_size);
}

Expand Down Expand Up @@ -403,30 +429,31 @@ class RelayToTIRVisitor : public MixedModeMutator {
Array<PrimExpr> output_shape = pool->type_as<TensorTypeNode>()->shape;
Array<PrimExpr> cmsisnn_output_shape{1, output_shape[1], output_shape[2], output_shape[3]};

tir::Var input("input", DataType::Handle(8));
tir::Var output("output", DataType::Handle(8));
BufferCreator buffer_creator_;
tir::Var input = buffer_creator_.CreateBufferVar("input", DataType::Handle(8));
tir::Var output = buffer_creator_.CreateBufferVar("output", DataType::Handle(8));
tvm::Array<PrimExpr> call_ext_args = {tir::StringImm(cmsisnn_api), input, output};

int context_buffer_size = 0;
std::string context_buffer_name = "NULL";
tir::Var context_buffer_var{"NULL"};
if (pool_name == "cmsisnn.qnn_avg_pool2d") {
CMSISNNFlags flags = GetCompilerFlags(transform::PassContext::Current());
int32_t input_c = qnn::get_const_int(input_shape[3]);
context_buffer_size = AvgPoolBufferSize(flags, input_c);
context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++);
std::string context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++);
context_buffer_var = tir::Var(context_buffer_name,
PointerType(PrimType(DataType::Int(8)), "global.workspace"));
}
tvm::Array<PrimExpr> context_buffer_args = {tir::StringImm(context_buffer_name),
ToArg(context_buffer_size)};
tvm::Array<PrimExpr> context_buffer_args = {context_buffer_var, ToArg(context_buffer_size)};

scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args);
scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_input_shape);
scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_filter_shape);
scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_output_shape);
call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args);

Array<tir::Var> func_signature{input, output};

CreatePrimFuncForExtern(global_var, func_signature, call_ext_args, context_buffer_name,
CreatePrimFuncForExtern(global_var, buffer_creator_.GetPrimFuncParams(),
buffer_creator_.GetBufferMap(), call_ext_args, context_buffer_var,
context_buffer_size);
}

Expand Down Expand Up @@ -460,10 +487,9 @@ class RelayToTIRVisitor : public MixedModeMutator {
diff_min >>= shift;
diff_min *= -1;

auto in_var = tir::Var("input", DataType::Handle(8));
auto out_var = tir::Var("output", DataType::Handle(8));

Array<tir::Var> func_signature{in_var, out_var};
BufferCreator buffer_creator_;
tir::Var in_var = buffer_creator_.CreateBufferVar("input", DataType::Handle(8));
tir::Var out_var = buffer_creator_.CreateBufferVar("output", DataType::Handle(8));

tvm::Array<PrimExpr> args = {
tir::StringImm("arm_softmax_s8"),
Expand All @@ -476,7 +502,8 @@ class RelayToTIRVisitor : public MixedModeMutator {
out_var,
};

CreatePrimFuncForExtern(global_var, func_signature, args);
CreatePrimFuncForExtern(global_var, buffer_creator_.GetPrimFuncParams(),
buffer_creator_.GetBufferMap(), args);
}

void EmitMul(const GlobalVar& global_var, const Expr& expr) {
Expand All @@ -498,11 +525,10 @@ class RelayToTIRVisitor : public MixedModeMutator {

PrimExpr tensor_size = mul_call->type_as<TensorTypeNode>()->Size();

tir::Var input_0("input_0", DataType::Handle(8));
tir::Var input_1("input_1", DataType::Handle(8));
tir::Var output("output", DataType::Handle(8));

Array<tir::Var> func_signature{input_0, input_1, output};
BufferCreator buffer_creator_;
tir::Var input_0 = buffer_creator_.CreateBufferVar("input_0", DataType::Handle(8));
tir::Var input_1 = buffer_creator_.CreateBufferVar("input_1", DataType::Handle(8));
tir::Var output = buffer_creator_.CreateBufferVar("output", DataType::Handle(8));

tvm::Array<PrimExpr> args = {
tir::StringImm("arm_elementwise_mul_s8"),
Expand All @@ -519,7 +545,8 @@ class RelayToTIRVisitor : public MixedModeMutator {
tensor_size,
};

CreatePrimFuncForExtern(global_var, func_signature, args);
CreatePrimFuncForExtern(global_var, buffer_creator_.GetPrimFuncParams(),
buffer_creator_.GetBufferMap(), args);
}

void EmitAdd(const GlobalVar& global_var, const Expr& expr) {
Expand Down Expand Up @@ -560,11 +587,10 @@ class RelayToTIRVisitor : public MixedModeMutator {

PrimExpr tensor_size = add_call->type_as<TensorTypeNode>()->Size();

tir::Var input_0("input_0", DataType::Handle(8));
tir::Var input_1("input_1", DataType::Handle(8));
tir::Var output("output", DataType::Handle(8));

Array<tir::Var> func_signature{input_0, input_1, output};
BufferCreator buffer_creator_;
tir::Var input_0 = buffer_creator_.CreateBufferVar("input_0", DataType::Handle(8));
tir::Var input_1 = buffer_creator_.CreateBufferVar("input_1", DataType::Handle(8));
tir::Var output = buffer_creator_.CreateBufferVar("output", DataType::Handle(8));

tvm::Array<PrimExpr> args = {
tir::StringImm("arm_elementwise_add_s8"),
Expand All @@ -586,7 +612,8 @@ class RelayToTIRVisitor : public MixedModeMutator {
tensor_size,
};

CreatePrimFuncForExtern(global_var, func_signature, args);
CreatePrimFuncForExtern(global_var, buffer_creator_.GetPrimFuncParams(),
buffer_creator_.GetBufferMap(), args);
}

Expr Rewrite_(const CallNode* pre, const Expr& post) override {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
/*! * \brief extracts CMSIS-NN context buffer information */
CMSISNNContextBuffer extract_context_buffer_info(const CallNode* op, int base_pos) {
CMSISNNContextBuffer context_buffer;
context_buffer.name = op->args[base_pos].as<StringImmNode>()->value;
context_buffer.name = op->args[base_pos].as<VarNode>()->name_hint;
context_buffer.size = ValueFromArg(op, base_pos + 1);
return context_buffer;
}
Expand Down
5 changes: 3 additions & 2 deletions tests/python/contrib/test_cmsisnn/test_binary_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from tests.python.relay.aot.aot_test_utils import (
AOTTestModel,
AOT_CORSTONE300_RUNNER,
AOT_USMP_CORSTONE300_RUNNER,
generate_ref_data,
compile_and_run,
)
Expand Down Expand Up @@ -101,7 +102,7 @@ def make_model(
def test_op_int8(op, input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point):
interface_api = "c"
use_unpacked_api = True
test_runner = AOT_CORSTONE300_RUNNER
test_runner = AOT_USMP_CORSTONE300_RUNNER

dtype = "int8"
shape = [1, 16, 16, 3]
Expand Down Expand Up @@ -174,7 +175,7 @@ def parameterize_for_constant_inputs(test):
def test_constant_input_int8(op, input_0, input_1):
interface_api = "c"
use_unpacked_api = True
test_runner = AOT_CORSTONE300_RUNNER
test_runner = AOT_USMP_CORSTONE300_RUNNER

dtype = "int8"
shape = [1, 16, 16, 3]
Expand Down
7 changes: 4 additions & 3 deletions tests/python/contrib/test_cmsisnn/test_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from tests.python.relay.aot.aot_test_utils import (
AOTTestModel,
AOT_CORSTONE300_RUNNER,
AOT_USMP_CORSTONE300_RUNNER,
AOT_DEFAULT_RUNNER,
generate_ref_data,
compile_and_run,
Expand Down Expand Up @@ -143,7 +144,7 @@ def test_conv2d_symmetric_padding_int8(
):
interface_api = "c"
use_unpacked_api = True
test_runner = AOT_CORSTONE300_RUNNER
test_runner = AOT_USMP_CORSTONE300_RUNNER

ifm_shape = (1, 64, 100, 4)
kernel_size = (3, 3)
Expand Down Expand Up @@ -233,7 +234,7 @@ def test_conv2d_asymmetric_padding_int8(
):
interface_api = "c"
use_unpacked_api = True
test_runner = AOT_CORSTONE300_RUNNER
test_runner = AOT_USMP_CORSTONE300_RUNNER

ifm_shape = (1, 25, 25, 12)
kernel_size = (5, 5)
Expand Down Expand Up @@ -334,7 +335,7 @@ def test_depthwise_int8(
):
interface_api = "c"
use_unpacked_api = True
test_runner = AOT_CORSTONE300_RUNNER
test_runner = AOT_USMP_CORSTONE300_RUNNER

dtype = "int8"
groups = 1
Expand Down
3 changes: 2 additions & 1 deletion tests/python/contrib/test_cmsisnn/test_fully_connected.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from tests.python.relay.aot.aot_test_utils import (
AOTTestModel,
AOT_CORSTONE300_RUNNER,
AOT_USMP_CORSTONE300_RUNNER,
AOT_DEFAULT_RUNNER,
generate_ref_data,
compile_and_run,
Expand Down Expand Up @@ -119,7 +120,7 @@ def test_op_int8(
):
interface_api = "c"
use_unpacked_api = True
test_runner = AOT_CORSTONE300_RUNNER
test_runner = AOT_USMP_CORSTONE300_RUNNER

dtype = "int8"
kernel_zero_point = 0
Expand Down
Loading

0 comments on commit 0e4b0d0

Please sign in to comment.