
Commit

Code generation for Pooling layers via CMSIS-NN
Change-Id: Ibf22250d961a683208faee362d1960ea266347e8
ashutosh-arm committed Nov 29, 2021
1 parent dbf40a0 commit 28bc3dd
Showing 4 changed files with 386 additions and 1 deletion.
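For context, the patterns added below are picked up when a quantized Relay module is partitioned for CMSIS-NN. A minimal usage sketch (assuming a quantized int8 Relay module `mod` with its `params`; the exact signature of the partitioning helper may differ across TVM versions):

# Sketch only: group supported operators, now including avg/max pooling,
# into "cmsis-nn" composite functions for the external code generator.
from tvm.relay.op.contrib.cmsisnn import partition_for_cmsisnn

partitioned_mod = partition_for_cmsisnn(mod, params)
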
28 changes: 27 additions & 1 deletion python/tvm/relay/op/contrib/cmsisnn.py
@@ -146,7 +146,7 @@ def check_qnn_conv2d(pattern):
)

def qnn_fully_connected_pattern():
"""Create pattern for qnn.dense with optional relu."""
"""Create pattern for qnn.dense with optional Relu."""
qnn_fc = is_op("qnn.dense")(
wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
)
@@ -188,6 +188,30 @@ def check_qnn_fully_connected(pattern):
and kernel_zp == 0
)

def qnn_avg_pool2d_pattern():
"""Matches average pooling with optional Relu"""
pattern = is_op("cast")(wildcard())
pattern = is_op("nn.avg_pool2d")(pattern)
pattern = is_op("cast")(pattern)
pattern = pattern.optional(is_op("clip"))
return pattern

def check_qnn_avg_pool2d(pattern):
"""Check if avg pool2d is supported by CMSIS-NN."""
in_cast = pattern
out_cast = in_cast.args[0].args[0]
return in_cast.checked_type.dtype == "int8" and out_cast.checked_type.dtype == "int32"

def qnn_max_pool2d_pattern():
"""Matches max pooling with optional Relu"""
pattern = is_op("nn.max_pool2d")(wildcard())
pattern = pattern.optional(is_op("clip"))
return pattern

def check_qnn_max_pool2d(pattern):
"""Check if max pool2d is supported by CMSIS-NN."""
return True

def binary_op_pattern(op):
"""Matches QNN binary operation"""
return is_op(f"qnn.{op}")(
@@ -211,6 +235,8 @@ def check_qnn_binary_op(extract):
return [
("cmsis-nn.qnn_conv2d", qnn_conv2d_pattern(), check_qnn_conv2d),
("cmsis-nn.qnn_fully_connected", qnn_fully_connected_pattern(), check_qnn_fully_connected),
("cmsis-nn.qnn_avg_pool2d", qnn_avg_pool2d_pattern(), check_qnn_avg_pool2d),
("cmsis-nn.qnn_max_pool2d", qnn_max_pool2d_pattern(), check_qnn_max_pool2d),
("cmsis-nn.qnn_mul", binary_op_pattern("mul"), check_qnn_binary_op),
("cmsis-nn.qnn_add", binary_op_pattern("add"), check_qnn_binary_op),
("cmsis-nn.qnn_softmax", qnn_softmax_pattern(), check_qnn_softmax),
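To illustrate what qnn_avg_pool2d_pattern and check_qnn_avg_pool2d above are meant to match, a hypothetical Relay snippet with the cast -> avg_pool2d -> cast (-> optional clip) sequence could look like this (input shape, pool size and strides are made-up example values):

# Hypothetical int8 average-pooling sequence: accumulate in int32, cast back to int8.
from tvm import relay

x = relay.var("x", shape=(1, 28, 28, 8), dtype="int8")   # example NHWC input
y = relay.cast(x, "int32")                                # inner cast: int32 accumulation
y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2), layout="NHWC")
y = relay.cast(y, "int8")                                 # outer cast checked for int8
y = relay.clip(y, a_min=-128, a_max=127)                  # optional fused activation (Relu-style clamp)
func = relay.Function([x], y)
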
94 changes: 94 additions & 0 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -330,6 +330,97 @@ class RelayToTIRVisitor : public MixedModeMutator {
context_buffer_size);
}

void EmitPool2D(const GlobalVar& global_var, const Expr& expr, const String pool_name) {
Call clip, pool;
Call final_call = GetRef<Call>(expr.as<CallNode>());
Op final_op = GetRef<Op>(final_call->op.as<OpNode>());
if (final_op->name == "clip") {
clip = final_call;
Call clip_input = GetRef<Call>(clip->args[0].as<CallNode>());
Op clip_input_op = GetRef<Op>(clip_input->op.as<OpNode>());
if (clip_input_op->name == "cast") {
pool = GetRef<Call>(clip_input->args[0].as<CallNode>());
} else { // max_pool2d
pool = clip_input;
}
} else if (final_op->name == "cast") {
pool = GetRef<Call>(final_call->args[0].as<CallNode>());
} else { // max_pool2d
pool = final_call;
}

// prepare cmsis_nn_pool_params
int32_t stride_h, stride_w, padding_h, padding_w, pool_size_h, pool_size_w;
int32_t clip_min, clip_max;
std::string cmsisnn_api;
if (pool_name == "cmsis-nn.qnn_avg_pool2d") {
cmsisnn_api = "arm_avgpool_s8";
const AvgPool2DAttrs* attrs = pool->attrs.as<AvgPool2DAttrs>();
stride_h = qnn::get_const_int(attrs->strides[0]);
stride_w = qnn::get_const_int(attrs->strides[1]);
padding_h = qnn::get_const_int(attrs->padding[0]);
padding_w = qnn::get_const_int(attrs->padding[1]);
pool_size_h = qnn::get_const_int(attrs->pool_size[0]);
pool_size_w = qnn::get_const_int(attrs->pool_size[1]);
} else {
cmsisnn_api = "arm_max_pool_s8";
const MaxPool2DAttrs* attrs = pool->attrs.as<MaxPool2DAttrs>();
stride_h = qnn::get_const_int(attrs->strides[0]);
stride_w = qnn::get_const_int(attrs->strides[1]);
padding_h = qnn::get_const_int(attrs->padding[0]);
padding_w = qnn::get_const_int(attrs->padding[1]);
pool_size_h = qnn::get_const_int(attrs->pool_size[0]);
pool_size_w = qnn::get_const_int(attrs->pool_size[1]);
}
if (clip.defined()) {
const ClipAttrs* clip_attrs = clip->attrs.as<ClipAttrs>();
clip_min = clip_attrs->a_min;
clip_max = clip_attrs->a_max;
} else {
clip_min = -128;
clip_max = 127;
}

tvm::Array<PrimExpr> scalar_args = {ToArg(stride_h), ToArg(stride_w), ToArg(padding_h),
ToArg(padding_w), ToArg(clip_min), ToArg(clip_max)};

// cmsis_nn_dims *input_dims
Array<PrimExpr> input_shape = pool->args[0]->type_as<TensorTypeNode>()->shape;
Array<PrimExpr> cmsisnn_input_shape{1, input_shape[1], input_shape[2], input_shape[3]};

// cmsis_nn_dims *filter_dims
Array<PrimExpr> cmsisnn_filter_shape{1, pool_size_h, pool_size_w, 1};

// cmsis_nn_dims *output_dims
Array<PrimExpr> output_shape = pool->type_as<TensorTypeNode>()->shape;
Array<PrimExpr> cmsisnn_output_shape{1, output_shape[1], output_shape[2], output_shape[3]};

tir::Var input("input", DataType::Handle(8));
tir::Var output("output", DataType::Handle(8));
tvm::Array<PrimExpr> call_ext_args = {tir::StringImm(cmsisnn_api), input, output};

int context_buffer_size = 0;
std::string context_buffer_name = "NULL";
if (pool_name == "cmsisnn.qnn_avg_pool2d") {
// TODO(@Mousius): Need to move this into buffer_size calculations
context_buffer_size = qnn::get_const_int(input_shape[3]) * sizeof(int32_t);
context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++);
}
tvm::Array<PrimExpr> context_buffer_args = {tir::StringImm(context_buffer_name),
ToArg(context_buffer_size)};

scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args);
scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_input_shape);
scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_filter_shape);
scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_output_shape);
call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args);

Array<tir::Var> func_signature{input, output};

CreatePrimFuncForExtern(global_var, func_signature, call_ext_args, context_buffer_name,
context_buffer_size);
}

void EmitSoftMax(const GlobalVar& global_var, const Expr& expr) {
const CallNode* quantize_call = expr.as<CallNode>();
const CallNode* softmax_call = quantize_call->args[0].as<CallNode>();
@@ -521,6 +612,9 @@ class RelayToTIRVisitor : public MixedModeMutator {
if (comp_name == "cmsis-nn.qnn_fully_connected") {
EmitFullyConnected(new_global_var, composite_func->body);
}
if (comp_name == "cmsis-nn.qnn_avg_pool2d" || comp_name == "cmsis-nn.qnn_max_pool2d") {
EmitPool2D(new_global_var, composite_func->body, comp_name.value());
}

Array<Expr> args;
for (const auto& arg : call->args) {
93 changes: 93 additions & 0 deletions src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
@@ -92,6 +92,15 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
int shift;
};

struct PoolParams {
int stride_h;
int stride_w;
int padding_h;
int padding_w;
int clip_min;
int clip_max;
};

/*! * \brief Emits CMSIS-NN APIs for every call_extern */
void VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
if (!op->op.same_as(builtin::call_extern())) {
@@ -107,6 +116,8 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
EmitConv2D(op);
} else if (cmsis_func_name == "arm_fully_connected_s8") {
EmitFullyConnected(op);
} else if (cmsis_func_name == "arm_avgpool_s8" || cmsis_func_name == "arm_max_pool_s8") {
EmitPool2D(op);
}
return;
}
@@ -160,6 +171,22 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
return instance_name;
}

/*! * \brief Emits cmsis_nn_pool_params struct */
std::string EmitCMSISNNPoolParams(std::ostream& os, PoolParams params) {
std::string struct_name = "cmsis_nn_pool_params";
std::string instance_name = "pool_params";
PrintIndent();
os << "cmsis_nn_tile stride = {" << params.stride_w << "," << params.stride_h << "};\n";
PrintIndent();
os << "cmsis_nn_tile padding = {" << params.padding_w << "," << params.padding_h << "};\n";
PrintIndent();
os << "cmsis_nn_activation activation = {" << params.clip_min << "," << params.clip_max
<< "};\n";
PrintIndent();
os << struct_name << " " << instance_name << " = {stride, padding, activation};\n";
return instance_name;
}

/*! * \brief Emits cmsis_nn_per_channel_quant_params struct */
std::string EmitCMSISNNPerChannelQuantParams(std::ostream& os, std::string multiplier,
std::string shift) {
@@ -234,6 +261,18 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
return fc_params;
}

/*! * \brief extracts CMSIS-NN Pooling parameters from call_extern */
PoolParams extract_pool_params(const CallNode* op, int base_pos) {
PoolParams pool_params;
pool_params.stride_h = ValueFromArg(op, base_pos);
pool_params.stride_w = ValueFromArg(op, ++base_pos);
pool_params.padding_h = ValueFromArg(op, ++base_pos);
pool_params.padding_w = ValueFromArg(op, ++base_pos);
pool_params.clip_min = ValueFromArg(op, ++base_pos);
pool_params.clip_max = ValueFromArg(op, ++base_pos);
return pool_params;
}

/*! * \brief extracts CMSIS-NN buffer dimensions from call_extern */
CMSISNNDims extract_buffer_dims(const CallNode* op, int base_pos) {
CMSISNNDims dims;
@@ -383,6 +422,60 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
PrintIndent();
stream << "}\n";
}

/*! * \brief Emits CMSIS-NN APIs for every call_extern comprising pooling ops */
void EmitPool2D(const CallNode* op) {
// Position of various arguments relative to buffers in the call_extern
enum CallExternArgPos {
CONTEXT_BUFFER_POS = 1,
POOL_PARAMS_POS = 3,
INPUT_DIM_POS = 9,
FILTER_DIM_POS = 13,
OUTPUT_DIM_POS = 17,
MAX_NUM_ARGS = 23
};
std::string cmsis_func_name = op->args[0].as<StringImmNode>()->value;

// extract buffer names from call_extern
int arg_id = 0;
std::string input_data = VarNameFromArg(op, ++arg_id);
std::string output_data = VarNameFromArg(op, ++arg_id);

// extract CMSIS-NN API parameters
int context_buffer_pos = arg_id + CallExternArgPos::CONTEXT_BUFFER_POS;
int pool_params_pos = arg_id + CallExternArgPos::POOL_PARAMS_POS;
int input_dim_pos = arg_id + CallExternArgPos::INPUT_DIM_POS;
int filter_dim_pos = arg_id + CallExternArgPos::FILTER_DIM_POS;
int output_dim_pos = arg_id + CallExternArgPos::OUTPUT_DIM_POS;

CMSISNNContextBuffer context_buffer = extract_context_buffer_info(op, context_buffer_pos);
PoolParams pool_params = extract_pool_params(op, pool_params_pos);
CMSISNNDims input_dims = extract_buffer_dims(op, input_dim_pos);
CMSISNNDims filter_dims = extract_buffer_dims(op, filter_dim_pos);
CMSISNNDims output_dims = extract_buffer_dims(op, output_dim_pos);

std::string context = EmitCMSISNNContext(stream, context_buffer);
std::string cmsisnn_pool_params = EmitCMSISNNPoolParams(stream, pool_params);
std::string input_dim = EmitCMSISNNDims(stream, "input", input_dims);
std::string filter_dim = EmitCMSISNNDims(stream, "filter", filter_dims);
std::string output_dim = EmitCMSISNNDims(stream, "output", output_dims);

PrintIndent();
stream << "arm_status status = ";
stream << cmsis_func_name << "(";
stream << "&" << context << ", ";
stream << "&" << cmsisnn_pool_params << ", ";
stream << "&" << input_dim << ", " << input_data << ", ";
stream << "&" << filter_dim << ", ";
stream << "&" << output_dim << ", " << output_data << ");\n";
PrintIndent();
stream << "if (status != ARM_MATH_SUCCESS) {\n";
PrintIndent();
PrintIndent();
stream << "return -1;\n";
PrintIndent();
stream << "}\n";
}
};

runtime::Module TIRToRuntime(IRModule mod, Target target) {
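Finally, a rough end-to-end sketch of compiling the partitioned module so the offloaded pooling partitions flow through the new EmitPool2D lowering and C code generation (assumes partitioned_mod and params from the partitioning step above; real Cortex-M builds typically also configure the AOT executor and C runtime, which is omitted here):

# Sketch only: lower the "cmsis-nn" composite functions (including the new
# pooling partitions) through relay_to_tir.cc and emit C via tir_to_runtime.cc.
import tvm
from tvm import relay

with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
    lib = relay.build(partitioned_mod, target="c", params=params)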