Skip to content

Commit

Permalink
Code generation for Depthwise Conv2D via CMSIS-NN
Browse files Browse the repository at this point in the history
Change-Id: Iac6acf1c310b20b9e99bf0b9c0310423c93de9be
  • Loading branch information
ashutosh-arm committed Nov 24, 2021
1 parent d5eb75a commit 3ad4ba9
Show file tree
Hide file tree
Showing 6 changed files with 327 additions and 114 deletions.
11 changes: 11 additions & 0 deletions python/tvm/relay/op/contrib/cmsisnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,16 @@ def check_qnn_conv2d(pattern):
kernel_zp = conv2d.args[3].data.numpy()
kernel_zp = [kernel_zp] if kernel_zp.ndim == 0 else kernel_zp

# check if depthwise Conv2D
kernel_layout = conv2d.attrs.kernel_layout
pos_o = kernel_layout.index("O")
groups = conv2d.attrs.groups
is_depthwise = False
if groups == int(conv2d_input.checked_type.shape[3]) and groups == int(
conv2d_weight.checked_type.shape[pos_o]
):
is_depthwise = True

return (
conv2d.attrs.out_dtype == "int32"
and conv2d.attrs.padding[2] == 0
Expand All @@ -132,6 +142,7 @@ def check_qnn_conv2d(pattern):
and pattern.checked_type.dtype == "int8"
and bias_dtype == "int32"
and all([zp == 0 for zp in kernel_zp])
and (not is_depthwise or bias_add is not None)
)

def binary_op_pattern(op):
Expand Down
21 changes: 15 additions & 6 deletions src/relay/backend/contrib/cmsisnn/generate_constants.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,20 @@ class GenerateConstantsMutator : public MixedModeMutator {
conv2d_call = requantize_input;
}

// Transpose weights: HWIO -> OHWI
auto* conv2d_attrs = conv2d_call->attrs.as<Conv2DAttrs>();
tvm::Attrs new_conv2d_attrs;
Expr transposed_kernel =
ConvertKernelLayout(conv2d_call->args[1], conv2d_attrs, &new_conv2d_attrs);
tvm::Attrs new_conv2d_attrs = conv2d_call->attrs;
Expr conv2d_kernel = conv2d_call->args[1];

Array<PrimExpr> input_shape = conv2d_call->args[0]->type_as<TensorTypeNode>()->shape;
Array<PrimExpr> kernel_shape = conv2d_call->args[1]->type_as<TensorTypeNode>()->shape;
std::string kernel_layout = conv2d_attrs->kernel_layout.c_str();
int kernel_pos_o = kernel_layout.find("O");
int groups = conv2d_attrs->groups;
if (groups != qnn::get_const_int(input_shape[3]) ||
groups != qnn::get_const_int(kernel_shape[kernel_pos_o])) {
// Transpose weights: HWIO -> OHWI for Conv2D
conv2d_kernel = ConvertKernelLayout(conv2d_call->args[1], conv2d_attrs, &new_conv2d_attrs);
}

// Obtain input and output scales from Relay's Requantization
int64_t out_channels = conv2d_attrs->channels.as<IntImmNode>()->value;
Expand Down Expand Up @@ -153,11 +162,11 @@ class GenerateConstantsMutator : public MixedModeMutator {
req_inp_scale = Constant(req_inp_scale_nda);
}

// Replace existing weights (HWIO) with the transposed ones (OHWI)
// Replace existing weights (HWIO) with the transposed ones (OHWI) for Conv2D
// Substitute Conv2D weight_zero_point with the CMSIS-NN multiplier
// Substitute Requantize input_zero_point with CMSIS-NN shift
// Conv2D arguments: data, weight, input_zp, weight_zp, input_sc, weight_sc
Array<Expr> conv2d_args = {conv2d_call->args[0], transposed_kernel, conv2d_call->args[2],
Array<Expr> conv2d_args = {conv2d_call->args[0], conv2d_kernel, conv2d_call->args[2],
multiplier_const, conv2d_call->args[4], weight_scale};
Call ret_call = Call(conv2d_call->op, conv2d_args, new_conv2d_attrs, {});
if (bias_add_call) {
Expand Down
43 changes: 31 additions & 12 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ class RelayToTIRVisitor : public MixedModeMutator {
int32_t padding_h = qnn::get_const_int(conv2d_attrs->padding[0]);
int32_t dilation_w = qnn::get_const_int(conv2d_attrs->dilation[1]);
int32_t dilation_h = qnn::get_const_int(conv2d_attrs->dilation[0]);
int32_t out_channels = qnn::get_const_int(conv2d_attrs->channels);
int32_t groups = conv2d_attrs->groups;
std::string kernel_layout = conv2d_attrs->kernel_layout.c_str();
int32_t clip_min, clip_max;
if (clip_call) {
const ClipAttrs* clip_attrs = clip_call->attrs.as<ClipAttrs>();
Expand All @@ -156,14 +159,6 @@ class RelayToTIRVisitor : public MixedModeMutator {
clip_max = 127;
}

tvm::Array<PrimExpr> call_ext_args = {tir::StringImm("arm_convolve_wrapper_s8"), input, filter,
multiplier};
if (bias_add_call) {
call_ext_args.push_back(bias);
}
call_ext_args.push_back(shift);
call_ext_args.push_back(output);

tvm::Array<PrimExpr> scalar_args = {ToArg(input_offset), ToArg(output_offset), ToArg(stride_w),
ToArg(stride_h), ToArg(padding_w), ToArg(padding_h),
ToArg(dilation_w), ToArg(dilation_h), ToArg(clip_min),
Expand All @@ -173,18 +168,42 @@ class RelayToTIRVisitor : public MixedModeMutator {
Array<PrimExpr> input_shape = conv2d_call->args[0]->type_as<TensorTypeNode>()->shape;
Array<PrimExpr> input_dims = CMSISNNDimensions(input_shape);

// cmsis_nn_dims *filter_dims (OHWI)
// cmsis_nn_dims *filter_dims (OHWI for Conv2D and IHWO for depthwise)
Array<PrimExpr> filter_shape = conv2d_call->args[1]->type_as<TensorTypeNode>()->shape;
Array<PrimExpr> filter_dims = CMSISNNDimensions(filter_shape);

// cmsis_nn_dims *bias_dims (1,1,1,output_channels)
Array<PrimExpr> bias_shape{1, 1, 1, filter_shape[0]};
// cmsis_nn_dims *bias_dims
Array<PrimExpr> bias_shape{1, 1, 1, out_channels};
Array<PrimExpr> bias_dims = CMSISNNDimensions(bias_shape);

// cmsis_nn_dims *output_dims (NHWC)
// cmsis_nn_dims *output_dims (same order as input_dims)
Array<PrimExpr> output_shape = conv2d_call->type_as<TensorTypeNode>()->shape;
Array<PrimExpr> output_dims = CMSISNNDimensions(output_shape);

int32_t depth_multiplier = -1;
int kernel_pos_o = kernel_layout.find("O");
if (groups == qnn::get_const_int(input_shape[3]) &&
groups == qnn::get_const_int(filter_shape[kernel_pos_o])) {
int kernel_pos_i = kernel_layout.find("I");
depth_multiplier = qnn::get_const_int(filter_shape[kernel_pos_i]);
}
scalar_args.push_back(ToArg(depth_multiplier));

// original filter_layout for depthwise is HWOI
std::string cmsisnn_api = "arm_convolve_wrapper_s8";
if (depth_multiplier != -1) {
cmsisnn_api = "arm_depthwise_conv_wrapper_s8";
Array<PrimExpr> depthwise_filter_shape{1, filter_shape[0], filter_shape[1], out_channels};
filter_dims = CMSISNNDimensions(depthwise_filter_shape);
}

tvm::Array<PrimExpr> call_ext_args = {tir::StringImm(cmsisnn_api), input, filter, multiplier};
if (bias_add_call) {
call_ext_args.push_back(bias);
}
call_ext_args.push_back(shift);
call_ext_args.push_back(output);

// https://github.com/ARM-software/CMSIS_5/blob/d788fd583984388553391de18afd8b4d2a146868/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_s8.c#L367
std::string context_buffer_name = "NULL";
size_t context_buffer_size =
Expand Down
Loading

0 comments on commit 3ad4ba9

Please sign in to comment.