[7/10] Code generation for Pooling and Fully Connected via CMSIS-NN (apache#9531)

Adds support for code generation of MaxPool, AvgPool, and Fully Connected layers via CMSIS-NN.
ashutosh-arm authored and baoxinqi committed Dec 27, 2021
1 parent dfa2fad commit de361b8
Showing 6 changed files with 919 additions and 34 deletions.
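
For orientation, here is a minimal sketch (an editor's illustration, not part of this commit) of how the new fully-connected support is expected to be exercised from Python. It builds a small int8 qnn.dense -> bias_add -> requantize -> clip graph of the shape matched by qnn_fully_connected_pattern() below and partitions it with the partition_for_cmsisnn entry point assumed to exist from earlier patches in this series; the scales, shapes, and variable names are illustrative only.

import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib.cmsisnn import partition_for_cmsisnn

in_features, out_features = 16, 8
x = relay.var("x", shape=(1, in_features), dtype="int8")
w = relay.const(np.random.randint(-128, 128, (out_features, in_features)).astype("int8"))
b = relay.const(np.zeros((out_features,), dtype="int32"))

# int8 fully connected with int32 accumulation
fc = relay.qnn.op.dense(
    x, w,
    input_zero_point=relay.const(0, "int32"),
    kernel_zero_point=relay.const(0, "int32"),
    input_scale=relay.const(0.02, "float32"),
    kernel_scale=relay.const(0.01, "float32"),
    units=out_features,
    out_dtype="int32",
)
biased = relay.nn.bias_add(fc, b, axis=1)
requant = relay.qnn.op.requantize(
    biased,
    input_scale=relay.const(0.0002, "float32"),
    input_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.05, "float32"),
    output_zero_point=relay.const(0, "int32"),
    out_dtype="int8",
)
out = relay.clip(requant, a_min=-128, a_max=127)

mod = tvm.IRModule.from_expr(relay.Function([x], out))
cmsisnn_mod = partition_for_cmsisnn(mod)
print(cmsisnn_mod)  # the dense chain should now sit inside a "cmsis-nn" composite function
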
72 changes: 71 additions & 1 deletion python/tvm/relay/op/contrib/cmsisnn.py
@@ -145,6 +145,73 @@ def check_qnn_conv2d(pattern):
            and (not is_depthwise or bias_add is not None)
        )

    def qnn_fully_connected_pattern():
        """Create pattern for qnn.dense with optional Relu."""
        qnn_fc = is_op("qnn.dense")(
            wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
        )
        bias_add = is_op("nn.bias_add")(qnn_fc, is_constant())
        req = is_op("qnn.requantize")(
            qnn_fc | bias_add, is_constant(), is_constant(), is_constant(), is_constant()
        )
        clip_or_req = req.optional(is_op("clip"))
        return clip_or_req

    def check_qnn_fully_connected(pattern):
        """Check if the fully connected is supported by CMSIS-NN."""
        if str(pattern.op.name) == "clip":
            relu = pattern
            requantize = relu.args[0]
        else:
            requantize = pattern
        requantize_input = requantize.args[0]
        bias_add = None
        bias_dtype = "int32"
        if str(requantize_input.op.name) == "nn.bias_add":
            bias_add = requantize_input
            fc = bias_add.args[0]
            bias_dtype = bias_add.args[1].checked_type.dtype
        else:
            fc = requantize_input
        fc_input = fc.args[0]
        fc_weight = fc.args[1]

        # kernel zero_point should be 0
        kernel_zp = fc.args[3].data.numpy().item(0)

        return (
            fc.attrs.out_dtype == "int32"
            and fc_input.checked_type.dtype == "int8"
            and fc_weight.checked_type.dtype == "int8"
            and pattern.checked_type.dtype == "int8"
            and bias_dtype == "int32"
            and kernel_zp == 0
        )

    def qnn_avg_pool2d_pattern():
        """Matches average pooling with optional Relu"""
        pattern = is_op("cast")(wildcard())
        pattern = is_op("nn.avg_pool2d")(pattern)
        pattern = is_op("cast")(pattern)
        pattern = pattern.optional(is_op("clip"))
        return pattern

    def check_qnn_avg_pool2d(pattern):
        """Check if avg pool2d is supported by CMSIS-NN."""
        in_cast = pattern
        out_cast = in_cast.args[0].args[0]
        return in_cast.checked_type.dtype == "int8" and out_cast.checked_type.dtype == "int32"

    def qnn_max_pool2d_pattern():
        """Matches max pool2d with optional Relu"""
        pattern = is_op("nn.max_pool2d")(wildcard())
        pattern = pattern.optional(is_op("clip"))
        return pattern

    def check_qnn_max_pool2d(pattern):
        """Check if max pool2d is supported by CMSIS-NN."""
        return True

    def binary_op_pattern(op):
        """Matches QNN binary operation"""
        return is_op(f"qnn.{op}")(
@@ -166,8 +233,11 @@ def check_qnn_binary_op(extract):
        )

    return [
-       ("cmsis-nn.qnn_softmax", qnn_softmax_pattern(), check_qnn_softmax),
        ("cmsis-nn.qnn_conv2d", qnn_conv2d_pattern(), check_qnn_conv2d),
        ("cmsis-nn.qnn_fully_connected", qnn_fully_connected_pattern(), check_qnn_fully_connected),
        ("cmsis-nn.qnn_avg_pool2d", qnn_avg_pool2d_pattern(), check_qnn_avg_pool2d),
        ("cmsis-nn.qnn_max_pool2d", qnn_max_pool2d_pattern(), check_qnn_max_pool2d),
        ("cmsis-nn.qnn_mul", binary_op_pattern("mul"), check_qnn_binary_op),
        ("cmsis-nn.qnn_add", binary_op_pattern("add"), check_qnn_binary_op),
        ("cmsis-nn.qnn_softmax", qnn_softmax_pattern(), check_qnn_softmax),
    ]
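
As a companion illustration (again not part of the commit), the average-pool pattern registered above is meant to match the cast -> nn.avg_pool2d -> cast -> clip sequence that int8 average pooling is expressed as in Relay. A hypothetical graph of that shape, with made-up sizes, could be built as:

from tvm import relay

x = relay.var("x", shape=(1, 8, 8, 3), dtype="int8")
p = relay.cast(x, "int32")                                                  # accumulate in int32
p = relay.nn.avg_pool2d(p, pool_size=(2, 2), strides=(2, 2), layout="NHWC")
p = relay.cast(p, "int8")                                                   # back to int8
p = relay.clip(p, a_min=-128, a_max=127)                                    # optional relu-style clamp
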
4 changes: 2 additions & 2 deletions src/relay/backend/contrib/cmsisnn/generate_constants.cc
@@ -134,10 +134,10 @@ class GenerateConstantsMutator : public MixedModeMutator {
int32_t* multiplier = static_cast<int32_t*>(multiplier_nda->data);
int32_t* shift = static_cast<int32_t*>(shift_nda->data);
for (int i = 0; i < out_channels; ++i) {
- double effective_output_scale =
double quantized_multiplier =
static_cast<double>(input_scales[i]) / static_cast<double>(output_scale);
std::tie(*(multiplier + i), *(shift + i)) =
- tvm::relay::qnn::GetFixedPointMultiplierShift(effective_output_scale);
tvm::relay::qnn::GetFixedPointMultiplierShift(quantized_multiplier);
}

// Create constants from requantization multiplier and shift
225 changes: 204 additions & 21 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -84,17 +84,9 @@ class RelayToTIRVisitor : public MixedModeMutator {

tir::PrimFunc replacement_func(func_signature, body, VoidType(), Map<tir::Var, tir::Buffer>(),
DictAttrs(dict_attrs));

ir_module_->Add(global_var, replacement_func);
}

- Array<PrimExpr> CMSISNNDimensions(const Array<PrimExpr>& shape) {
- ICHECK(shape.size() == 4) << "Supports only CMSIS-NN shapes of dimension 4.";
- return Array<PrimExpr>{ToArg(qnn::get_const_int(shape[0])), ToArg(qnn::get_const_int(shape[1])),
- ToArg(qnn::get_const_int(shape[2])),
- ToArg(qnn::get_const_int(shape[3]))};
- }

void EmitConv2D(const GlobalVar& global_var, const Expr& expr) {
const CallNode* clip_call = nullptr;
const CallNode* requantize_call = nullptr;
@@ -164,21 +156,15 @@ class RelayToTIRVisitor : public MixedModeMutator {
ToArg(dilation_w), ToArg(dilation_h), ToArg(clip_min),
ToArg(clip_max)};

- // cmsis_nn_dims *input_dims (NHWC)
// layout NHWC
Array<PrimExpr> input_shape = conv2d_call->args[0]->type_as<TensorTypeNode>()->shape;
- Array<PrimExpr> input_dims = CMSISNNDimensions(input_shape);

- // cmsis_nn_dims *filter_dims (OHWI for Conv2D and IHWO for depthwise)
// OHWI for Conv2D and IHWO for depthwise
Array<PrimExpr> filter_shape = conv2d_call->args[1]->type_as<TensorTypeNode>()->shape;
- Array<PrimExpr> filter_dims = CMSISNNDimensions(filter_shape);

- // cmsis_nn_dims *bias_dims
Array<PrimExpr> bias_shape{1, 1, 1, out_channels};
- Array<PrimExpr> bias_dims = CMSISNNDimensions(bias_shape);

- // cmsis_nn_dims *output_dims (same order as input_dims)
Array<PrimExpr> output_shape = conv2d_call->type_as<TensorTypeNode>()->shape;
- Array<PrimExpr> output_dims = CMSISNNDimensions(output_shape);

int32_t depth_multiplier = -1;
int kernel_pos_o = kernel_layout.find("O");
@@ -194,7 +180,7 @@
if (depth_multiplier != -1) {
cmsisnn_api = "arm_depthwise_conv_wrapper_s8";
Array<PrimExpr> depthwise_filter_shape{1, filter_shape[0], filter_shape[1], out_channels};
- filter_dims = CMSISNNDimensions(depthwise_filter_shape);
filter_shape = depthwise_filter_shape;
}

tvm::Array<PrimExpr> call_ext_args = {tir::StringImm(cmsisnn_api), input, filter, multiplier};
@@ -216,10 +202,10 @@
ToArg(context_buffer_size)};

scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args);
- scalar_args = tvm::runtime::Concat(scalar_args, input_dims);
- scalar_args = tvm::runtime::Concat(scalar_args, filter_dims);
- scalar_args = tvm::runtime::Concat(scalar_args, bias_dims);
- scalar_args = tvm::runtime::Concat(scalar_args, output_dims);
scalar_args = tvm::runtime::Concat(scalar_args, input_shape);
scalar_args = tvm::runtime::Concat(scalar_args, filter_shape);
scalar_args = tvm::runtime::Concat(scalar_args, bias_shape);
scalar_args = tvm::runtime::Concat(scalar_args, output_shape);
call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args);

Array<tir::Var> func_signature{input, filter, multiplier, filter_scale};
@@ -234,6 +220,197 @@
context_buffer_size);
}

void EmitFullyConnected(const GlobalVar& global_var, const Expr& expr) {
const CallNode* clip_call = nullptr;
const CallNode* requantize_call = nullptr;
const CallNode* bias_add_call = nullptr;
const CallNode* fc_call = nullptr;
const CallNode* final_call = expr.as<CallNode>();
const OpNode* final_op = final_call->op.as<OpNode>();
if (final_op->name == "clip") {
clip_call = final_call;
requantize_call = clip_call->args[0].as<CallNode>();
} else {
requantize_call = final_call;
}
const CallNode* requantize_input = requantize_call->args[0].as<CallNode>();
const OpNode* requantize_input_op = requantize_input->op.as<OpNode>();
if (requantize_input_op->name == "nn.bias_add") {
bias_add_call = requantize_input;
fc_call = bias_add_call->args[0].as<CallNode>();
} else {
fc_call = requantize_input;
}

// TIR variables are created in the order they appear in the Relay partitioned function
// %1 = qnn.dense(%input, %weight_const_0, input_zero_point_scalar, kernel_zero_point_scalar,
// %input_scale_scalar, %kernel_scale_scalar)
// %2 = nn.bias_add(%1, %bias_const_1, axis=1)
// %3 = qnn.requantize(%2, %req_input_scale_scalar, %req_input_zero_point_scalar,
// %output_scale_scalar, %output_zero_point_scalar)
// clip(%3, a_min=%min_scalar, a_max=%max_scalar)
tir::Var input("input", DataType::Handle(8));
tir::Var filter("filter", DataType::Handle(8));
tir::Var bias("bias", DataType::Handle(32));
tir::Var output("output", DataType::Handle(8));

// Individual arguments to the structs arguments of the CMSIS-NN API are filled into call_extern
// https://github.com/ARM-software/CMSIS_5/blob/def6f800f95661eb3451d317f7d0dde504f6020d/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L50

// prepare cmsis_nn_fc_params
const DenseAttrs* dense_attrs = fc_call->attrs.as<DenseAttrs>();
int32_t input_offset = -GetScalarFromConstant<int32_t>(fc_call->args[2]);
int32_t filter_offset = -GetScalarFromConstant<int32_t>(fc_call->args[3]);
int32_t output_offset = GetScalarFromConstant<int32_t>(requantize_call->args[4]);
float input_scale = GetScalarFromConstant<float>(requantize_call->args[1]);
float output_scale = GetScalarFromConstant<float>(requantize_call->args[3]);
int32_t out_channels = qnn::get_const_int(dense_attrs->units);
int32_t clip_min, clip_max;
if (clip_call) {
const ClipAttrs* clip_attrs = clip_call->attrs.as<ClipAttrs>();
clip_min = clip_attrs->a_min;
clip_max = clip_attrs->a_max;
} else {
clip_min = -128;
clip_max = 127;
}

double quantized_multiplier =
static_cast<double>(input_scale) / static_cast<double>(output_scale);
auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(quantized_multiplier);
int32_t multiplier = std::get<0>(mult_shift_pair);
int32_t shift = std::get<1>(mult_shift_pair);

tvm::Array<PrimExpr> scalar_args = {
ToArg(input_offset), ToArg(filter_offset), ToArg(output_offset), ToArg(clip_min),
ToArg(clip_max), ToArg(multiplier), ToArg(shift)};

Array<PrimExpr> input_shape = fc_call->args[0]->type_as<TensorTypeNode>()->shape;
int32_t batch_size = qnn::get_const_int(input_shape[0]);
int32_t in_channels = qnn::get_const_int(input_shape[1]);
Array<PrimExpr> cmsisnn_input_shape{input_shape[0], 1, 1, input_shape[1]};

Array<PrimExpr> cmsisnn_filter_shape{in_channels, 1, 1, out_channels};

Array<PrimExpr> bias_shape{1, 1, 1, out_channels};

Array<PrimExpr> cmsisnn_output_shape{batch_size, 1, 1, out_channels};

tvm::Array<PrimExpr> call_ext_args = {tir::StringImm("arm_fully_connected_s8"), input, filter};
if (bias_add_call) {
call_ext_args.push_back(bias);
}
call_ext_args.push_back(output);

int context_buffer_size = 0;
std::string context_buffer_name = "NULL";
tvm::Array<PrimExpr> context_buffer_args = {tir::StringImm(context_buffer_name),
ToArg(context_buffer_size)};

scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args);
scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_input_shape);
scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_filter_shape);
scalar_args = tvm::runtime::Concat(scalar_args, bias_shape);
scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_output_shape);
call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args);

Array<tir::Var> func_signature{input, filter};
if (bias_add_call) {
func_signature.push_back(bias);
}
func_signature.push_back(output);
CreatePrimFuncForExtern(global_var, func_signature, call_ext_args, context_buffer_name,
context_buffer_size);
}

void EmitPool2D(const GlobalVar& global_var, const Expr& expr, const String pool_name) {
Call clip, pool;
Call final_call = GetRef<Call>(expr.as<CallNode>());
Op final_op = GetRef<Op>(final_call->op.as<OpNode>());
if (final_op->name == "clip") {
clip = final_call;
Call clip_input = GetRef<Call>(clip->args[0].as<CallNode>());
Op clip_input_op = GetRef<Op>(clip_input->op.as<OpNode>());
if (clip_input_op->name == "cast") {
pool = GetRef<Call>(clip_input->args[0].as<CallNode>());
} else { // max_pool2d
pool = clip_input;
}
} else if (final_op->name == "cast") {
pool = GetRef<Call>(final_call->args[0].as<CallNode>());
} else { // max_pool2d
pool = final_call;
}

// prepare cmsis_nn_pool_params
int32_t stride_h, stride_w, padding_h, padding_w, pool_size_h, pool_size_w;
int32_t clip_min, clip_max;
std::string cmsisnn_api;
if (pool_name == "cmsis-nn.qnn_avg_pool2d") {
cmsisnn_api = "arm_avgpool_s8";
const AvgPool2DAttrs* attrs = pool->attrs.as<AvgPool2DAttrs>();
stride_h = qnn::get_const_int(attrs->strides[0]);
stride_w = qnn::get_const_int(attrs->strides[1]);
padding_h = qnn::get_const_int(attrs->padding[0]);
padding_w = qnn::get_const_int(attrs->padding[1]);
pool_size_h = qnn::get_const_int(attrs->pool_size[0]);
pool_size_w = qnn::get_const_int(attrs->pool_size[1]);
} else {
cmsisnn_api = "arm_max_pool_s8";
const MaxPool2DAttrs* attrs = pool->attrs.as<MaxPool2DAttrs>();
stride_h = qnn::get_const_int(attrs->strides[0]);
stride_w = qnn::get_const_int(attrs->strides[1]);
padding_h = qnn::get_const_int(attrs->padding[0]);
padding_w = qnn::get_const_int(attrs->padding[1]);
pool_size_h = qnn::get_const_int(attrs->pool_size[0]);
pool_size_w = qnn::get_const_int(attrs->pool_size[1]);
}
if (clip.defined()) {
const ClipAttrs* clip_attrs = clip->attrs.as<ClipAttrs>();
clip_min = clip_attrs->a_min;
clip_max = clip_attrs->a_max;
} else {
clip_min = -128;
clip_max = 127;
}

tvm::Array<PrimExpr> scalar_args = {ToArg(stride_h), ToArg(stride_w), ToArg(padding_h),
ToArg(padding_w), ToArg(clip_min), ToArg(clip_max)};

Array<PrimExpr> input_shape = pool->args[0]->type_as<TensorTypeNode>()->shape;
Array<PrimExpr> cmsisnn_input_shape{1, input_shape[1], input_shape[2], input_shape[3]};

Array<PrimExpr> cmsisnn_filter_shape{1, pool_size_h, pool_size_w, 1};

Array<PrimExpr> output_shape = pool->type_as<TensorTypeNode>()->shape;
Array<PrimExpr> cmsisnn_output_shape{1, output_shape[1], output_shape[2], output_shape[3]};

tir::Var input("input", DataType::Handle(8));
tir::Var output("output", DataType::Handle(8));
tvm::Array<PrimExpr> call_ext_args = {tir::StringImm(cmsisnn_api), input, output};

int context_buffer_size = 0;
std::string context_buffer_name = "NULL";
if (pool_name == "cmsisnn.qnn_avg_pool2d") {
// TODO(@Mousius): Need to move this into buffer_size calculations
context_buffer_size = qnn::get_const_int(input_shape[3]) * sizeof(int32_t);
context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++);
}
tvm::Array<PrimExpr> context_buffer_args = {tir::StringImm(context_buffer_name),
ToArg(context_buffer_size)};

scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args);
scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_input_shape);
scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_filter_shape);
scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_output_shape);
call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args);

Array<tir::Var> func_signature{input, output};

CreatePrimFuncForExtern(global_var, func_signature, call_ext_args, context_buffer_name,
context_buffer_size);
}

void EmitSoftMax(const GlobalVar& global_var, const Expr& expr) {
const CallNode* quantize_call = expr.as<CallNode>();
const CallNode* softmax_call = quantize_call->args[0].as<CallNode>();
Expand Down Expand Up @@ -422,6 +599,12 @@ class RelayToTIRVisitor : public MixedModeMutator {
if (comp_name == "cmsis-nn.qnn_conv2d") {
EmitConv2D(new_global_var, composite_func->body);
}
if (comp_name == "cmsis-nn.qnn_fully_connected") {
EmitFullyConnected(new_global_var, composite_func->body);
}
if (comp_name == "cmsis-nn.qnn_avg_pool2d" || comp_name == "cmsis-nn.qnn_max_pool2d") {
EmitPool2D(new_global_var, composite_func->body, comp_name.value());
}

Array<Expr> args;
for (const auto& arg : call->args) {
