diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py index 34efb1d7a162..5e0ad2796a16 100644 --- a/python/tvm/relay/op/contrib/cmsisnn.py +++ b/python/tvm/relay/op/contrib/cmsisnn.py @@ -123,6 +123,16 @@ def check_qnn_conv2d(pattern): kernel_zp = conv2d.args[3].data.numpy() kernel_zp = [kernel_zp] if kernel_zp.ndim == 0 else kernel_zp + # check if depthwise Conv2D + kernel_layout = conv2d.attrs.kernel_layout + pos_o = kernel_layout.index("O") + groups = conv2d.attrs.groups + is_depthwise = False + if groups == int(conv2d_input.checked_type.shape[3]) and groups == int( + conv2d_weight.checked_type.shape[pos_o] + ): + is_depthwise = True + return ( conv2d.attrs.out_dtype == "int32" and conv2d.attrs.padding[2] == 0 @@ -132,6 +142,7 @@ def check_qnn_conv2d(pattern): and pattern.checked_type.dtype == "int8" and bias_dtype == "int32" and all([zp == 0 for zp in kernel_zp]) + and (not is_depthwise or bias_add is not None) ) def binary_op_pattern(op): diff --git a/src/relay/backend/contrib/cmsisnn/generate_constants.cc b/src/relay/backend/contrib/cmsisnn/generate_constants.cc index 0231e8b52117..2e12697f36f1 100644 --- a/src/relay/backend/contrib/cmsisnn/generate_constants.cc +++ b/src/relay/backend/contrib/cmsisnn/generate_constants.cc @@ -105,11 +105,20 @@ class GenerateConstantsMutator : public MixedModeMutator { conv2d_call = requantize_input; } - // Transpose weights: HWIO -> OHWI auto* conv2d_attrs = conv2d_call->attrs.as(); - tvm::Attrs new_conv2d_attrs; - Expr transposed_kernel = - ConvertKernelLayout(conv2d_call->args[1], conv2d_attrs, &new_conv2d_attrs); + tvm::Attrs new_conv2d_attrs = conv2d_call->attrs; + Expr conv2d_kernel = conv2d_call->args[1]; + + Array input_shape = conv2d_call->args[0]->type_as()->shape; + Array kernel_shape = conv2d_call->args[1]->type_as()->shape; + std::string kernel_layout = conv2d_attrs->kernel_layout.c_str(); + int kernel_pos_o = kernel_layout.find("O"); + int groups = conv2d_attrs->groups; + if (groups != qnn::get_const_int(input_shape[3]) || + groups != qnn::get_const_int(kernel_shape[kernel_pos_o])) { + // Transpose weights: HWIO -> OHWI for Conv2D + conv2d_kernel = ConvertKernelLayout(conv2d_call->args[1], conv2d_attrs, &new_conv2d_attrs); + } // Obtain input and output scales from Relay's Requantization int64_t out_channels = conv2d_attrs->channels.as()->value; @@ -153,11 +162,11 @@ class GenerateConstantsMutator : public MixedModeMutator { req_inp_scale = Constant(req_inp_scale_nda); } - // Replace existing weights (HWIO) with the transposed ones (OHWI) + // Replace existing weights (HWIO) with the transposed ones (OHWI) for Conv2D // Substitute Conv2D weight_zero_point with the CMSIS-NN multiplier // Substitute Requantize input_zero_point with CMSIS-NN shift // Conv2D arguments: data, weight, input_zp, weight_zp, input_sc, weight_sc - Array conv2d_args = {conv2d_call->args[0], transposed_kernel, conv2d_call->args[2], + Array conv2d_args = {conv2d_call->args[0], conv2d_kernel, conv2d_call->args[2], multiplier_const, conv2d_call->args[4], weight_scale}; Call ret_call = Call(conv2d_call->op, conv2d_args, new_conv2d_attrs, {}); if (bias_add_call) { diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index 1b639dd36e9d..668352700805 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -146,6 +146,9 @@ class RelayToTIRVisitor : public MixedModeMutator { int32_t padding_h = qnn::get_const_int(conv2d_attrs->padding[0]); int32_t dilation_w = qnn::get_const_int(conv2d_attrs->dilation[1]); int32_t dilation_h = qnn::get_const_int(conv2d_attrs->dilation[0]); + int32_t out_channels = qnn::get_const_int(conv2d_attrs->channels); + int32_t groups = conv2d_attrs->groups; + std::string kernel_layout = conv2d_attrs->kernel_layout.c_str(); int32_t clip_min, clip_max; if (clip_call) { const ClipAttrs* clip_attrs = clip_call->attrs.as(); @@ -156,14 +159,6 @@ class RelayToTIRVisitor : public MixedModeMutator { clip_max = 127; } - tvm::Array call_ext_args = {tir::StringImm("arm_convolve_wrapper_s8"), input, filter, - multiplier}; - if (bias_add_call) { - call_ext_args.push_back(bias); - } - call_ext_args.push_back(shift); - call_ext_args.push_back(output); - tvm::Array scalar_args = {ToArg(input_offset), ToArg(output_offset), ToArg(stride_w), ToArg(stride_h), ToArg(padding_w), ToArg(padding_h), ToArg(dilation_w), ToArg(dilation_h), ToArg(clip_min), @@ -173,18 +168,42 @@ class RelayToTIRVisitor : public MixedModeMutator { Array input_shape = conv2d_call->args[0]->type_as()->shape; Array input_dims = CMSISNNDimensions(input_shape); - // cmsis_nn_dims *filter_dims (OHWI) + // cmsis_nn_dims *filter_dims (OHWI for Conv2D and IHWO for depthwise) Array filter_shape = conv2d_call->args[1]->type_as()->shape; Array filter_dims = CMSISNNDimensions(filter_shape); - // cmsis_nn_dims *bias_dims (1,1,1,output_channels) - Array bias_shape{1, 1, 1, filter_shape[0]}; + // cmsis_nn_dims *bias_dims + Array bias_shape{1, 1, 1, out_channels}; Array bias_dims = CMSISNNDimensions(bias_shape); - // cmsis_nn_dims *output_dims (NHWC) + // cmsis_nn_dims *output_dims (same order as input_dims) Array output_shape = conv2d_call->type_as()->shape; Array output_dims = CMSISNNDimensions(output_shape); + int32_t depth_multiplier = -1; + int kernel_pos_o = kernel_layout.find("O"); + if (groups == qnn::get_const_int(input_shape[3]) && + groups == qnn::get_const_int(filter_shape[kernel_pos_o])) { + int kernel_pos_i = kernel_layout.find("I"); + depth_multiplier = qnn::get_const_int(filter_shape[kernel_pos_i]); + } + scalar_args.push_back(ToArg(depth_multiplier)); + + // original filter_layout for depthwise is HWOI + std::string cmsisnn_api = "arm_convolve_wrapper_s8"; + if (depth_multiplier != -1) { + cmsisnn_api = "arm_depthwise_conv_wrapper_s8"; + Array depthwise_filter_shape{1, filter_shape[0], filter_shape[1], out_channels}; + filter_dims = CMSISNNDimensions(depthwise_filter_shape); + } + + tvm::Array call_ext_args = {tir::StringImm(cmsisnn_api), input, filter, multiplier}; + if (bias_add_call) { + call_ext_args.push_back(bias); + } + call_ext_args.push_back(shift); + call_ext_args.push_back(output); + // https://github.com/ARM-software/CMSIS_5/blob/d788fd583984388553391de18afd8b4d2a146868/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_s8.c#L367 std::string context_buffer_name = "NULL"; size_t context_buffer_size = diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc index b243af6c4d5f..85923b3ed08e 100644 --- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc +++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc @@ -39,7 +39,6 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { decl_stream << "#include \n"; decl_stream << "#include \n"; decl_stream << "#include \n"; - decl_stream << "#include \n"; decl_stream << "#include \n"; decl_stream << "#include \n"; CodeGenCHost::Init(output_ssa, emit_asserts, target_str); @@ -53,6 +52,35 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); } private: + /*! * \brief CMSIS-NN context buffer info */ + struct CMSISNNContextBuffer { + std::string name; + int size; + }; + + /*! * \brief CMSIS-NN buffer dimensions */ + struct CMSISNNDims { + int n; + int h; + int w; + int c; + }; + + /*! * \brief CMSIS-NN Conv2D and Depthwise parameters */ + struct Conv2DParams { + int input_offset; + int output_offset; + int stride_w; + int stride_h; + int padding_w; + int padding_h; + int dilation_w; + int dilation_h; + int clip_min; + int clip_max; + int depth_multiplier; + }; + /*! * \brief Emit the CMSIS-NN context buffer */ void VisitStmt_(const AllocateNode* op) { context_buffer_name_ = op->buffer_var->name_hint; @@ -70,38 +98,46 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { if (cmsis_func_name == "arm_softmax_s8" || cmsis_func_name == "arm_elementwise_mul_s8" || cmsis_func_name == "arm_elementwise_add_s8") { CodeGenC::VisitExpr_(op, os); - } else if (cmsis_func_name == "arm_convolve_wrapper_s8") { + } else if (cmsis_func_name == "arm_convolve_wrapper_s8" || + cmsis_func_name == "arm_depthwise_conv_wrapper_s8") { EmitConv2D(op); } return; } /*! * \brief Emits cmsis_nn_context struct */ - std::string EmitCMSISNNContext(std::ostream& os, std::string buf_name, int buf_size) { + std::string EmitCMSISNNContext(std::ostream& os, CMSISNNContextBuffer context_buffer) { std::string struct_name = "context"; PrintIndent(); - os << "cmsis_nn_context " << struct_name << "= {" << buf_name << "," << buf_size << "};\n"; + os << "cmsis_nn_context " << struct_name << "= {" << context_buffer.name << "," + << context_buffer.size << "};\n"; return struct_name; } /*! * \brief Emits cmsis_nn_conv_params struct */ - std::string EmitCMSISNNConvParams(std::ostream& os, int32_t input_offset, int32_t output_offset, - int32_t stride_w, int32_t stride_h, int32_t padding_w, - int32_t padding_h, int32_t dilation_w, int32_t dilation_h, - int32_t clip_min, int32_t clip_max) { - std::string struct_name = "conv_params"; + std::string EmitCMSISNNConvParams(std::ostream& os, Conv2DParams params) { + std::string struct_name = "cmsis_nn_conv_params"; + std::string instance_name = "conv_params"; + if (params.depth_multiplier != -1) { + struct_name = "cmsis_nn_dw_conv_params"; + } PrintIndent(); - os << "cmsis_nn_tile stride = {" << stride_w << "," << stride_h << "};\n"; + os << "cmsis_nn_tile stride = {" << params.stride_w << "," << params.stride_h << "};\n"; PrintIndent(); - os << "cmsis_nn_tile padding = {" << padding_w << "," << padding_h << "};\n"; + os << "cmsis_nn_tile padding = {" << params.padding_w << "," << params.padding_h << "};\n"; PrintIndent(); - os << "cmsis_nn_tile dilation = {" << dilation_w << "," << dilation_h << "};\n"; + os << "cmsis_nn_tile dilation = {" << params.dilation_w << "," << params.dilation_h << "};\n"; PrintIndent(); - os << "cmsis_nn_activation activation = {" << clip_min << "," << clip_max << "};\n"; + os << "cmsis_nn_activation activation = {" << params.clip_min << "," << params.clip_max + << "};\n"; PrintIndent(); - os << "cmsis_nn_conv_params " << struct_name << " = {" << input_offset << ", " << output_offset - << ", stride, padding, dilation, activation};\n"; - return struct_name; + os << struct_name << " " << instance_name << " = {" << params.input_offset << ", " + << params.output_offset; + if (params.depth_multiplier != -1) { + os << ", " << params.depth_multiplier; + } + os << ", stride, padding, dilation, activation};\n"; + return instance_name; } /*! * \brief Emits cmsis_nn_per_channel_quant_params struct */ @@ -115,83 +151,109 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { } /*! * \brief Emits cmsis_nn_dims struct */ - std::string EmitCMSISNNDims(std::ostream& os, std::string tensor_type, int32_t n, int32_t h, - int32_t w, int32_t c) { + std::string EmitCMSISNNDims(std::ostream& os, std::string tensor_type, CMSISNNDims dims) { std::string struct_name = tensor_type + "_dims"; PrintIndent(); - os << "cmsis_nn_dims " << struct_name << " = {" << n << "," << h << "," << w << "," << c - << "};\n"; + os << "cmsis_nn_dims " << struct_name << " = {" << dims.n << "," << dims.h << "," << dims.w + << "," << dims.c << "};\n"; return struct_name; } + /*! * \brief Deduces variable name from call_extern argument resting at id */ + std::string VarNameFromArg(const CallNode* op, int id) { + return op->args[id].as()->name_hint.c_str(); + } + + /*! * \brief Deduces value from call_extern argument resting at id */ + int ValueFromArg(const CallNode* op, int id) { return op->args[id].as()->value; } + + /*! * \brief extracts CMSIS-NN context buffer information */ + CMSISNNContextBuffer extract_context_buffer_info(const CallNode* op, int base_pos) { + CMSISNNContextBuffer context_buffer; + context_buffer.name = op->args[base_pos].as()->value; + context_buffer.size = ValueFromArg(op, base_pos + 1); + return context_buffer; + } + + /*! * \brief extracts CMSIS-NN conv2d parameters from call_extern */ + Conv2DParams extract_conv2d_params(const CallNode* op, int base_pos) { + Conv2DParams conv2d_params; + conv2d_params.input_offset = ValueFromArg(op, base_pos); + conv2d_params.output_offset = ValueFromArg(op, ++base_pos); + conv2d_params.stride_w = ValueFromArg(op, ++base_pos); + conv2d_params.stride_h = ValueFromArg(op, ++base_pos); + conv2d_params.padding_w = ValueFromArg(op, ++base_pos); + conv2d_params.padding_h = ValueFromArg(op, ++base_pos); + conv2d_params.dilation_w = ValueFromArg(op, ++base_pos); + conv2d_params.dilation_h = ValueFromArg(op, ++base_pos); + conv2d_params.clip_min = ValueFromArg(op, ++base_pos); + conv2d_params.clip_max = ValueFromArg(op, ++base_pos); + conv2d_params.depth_multiplier = ValueFromArg(op, ++base_pos); + return conv2d_params; + } + + /*! * \brief extracts CMSIS-NN buffer dimensions from call_extern */ + CMSISNNDims extract_buffer_dims(const CallNode* op, int base_pos) { + CMSISNNDims dims; + dims.n = ValueFromArg(op, base_pos); + dims.h = ValueFromArg(op, ++base_pos); + dims.w = ValueFromArg(op, ++base_pos); + dims.c = ValueFromArg(op, ++base_pos); + return dims; + } + /*! * \brief Emits CMSIS-NN APIs for every call_extern */ void EmitConv2D(const CallNode* op) { - static const int max_num_args = 35; - std::string cmsis_func_name = op->args[0].as()->value; + // Position of various arguments relative to buffers in the call_extern + enum CallExternArgPos { + CONTEXT_BUFFER_POS = 1, + CONV2D_PARAMS_POS = 3, + INPUT_DIM_POS = 14, + FILTER_DIM_POS = 18, + BIAS_DIM_POS = 22, + OUTPUT_DIM_POS = 26, + MAX_NUM_ARGS = 36 + }; - bool bias_enabled = false; - if (op->args.size() == max_num_args) { - bias_enabled = true; - } + std::string cmsis_func_name = op->args[0].as()->value; - auto get_var_name = [](const CallNode* op, int id) { - return op->args[id].as()->name_hint.c_str(); - }; - auto get_arg_value = [](const CallNode* op, int id) { - return op->args[id].as()->value; - }; + // extract buffer names from call_extern int arg_id = 0; - std::string input_data = get_var_name(op, ++arg_id); - std::string filter_data = get_var_name(op, ++arg_id); - std::string multiplier = get_var_name(op, ++arg_id); - std::string bias_data("0x0"); - if (bias_enabled) { - bias_data = get_var_name(op, ++arg_id); + std::string input_data = VarNameFromArg(op, ++arg_id); + std::string filter_data = VarNameFromArg(op, ++arg_id); + std::string multiplier = VarNameFromArg(op, ++arg_id); + std::string bias_data("NULL"); + if (op->args.size() == CallExternArgPos::MAX_NUM_ARGS) { + bias_data = VarNameFromArg(op, ++arg_id); } - std::string shift = get_var_name(op, ++arg_id); - std::string output_data = get_var_name(op, ++arg_id); - - std::string context_buffer_name = op->args[++arg_id].as()->value; - int context_buffer_size = get_arg_value(op, ++arg_id); - int input_offset = get_arg_value(op, ++arg_id); - int output_offset = get_arg_value(op, ++arg_id); - int stride_w = get_arg_value(op, ++arg_id); - int stride_h = get_arg_value(op, ++arg_id); - int padding_w = get_arg_value(op, ++arg_id); - int padding_h = get_arg_value(op, ++arg_id); - int dilation_w = get_arg_value(op, ++arg_id); - int dilation_h = get_arg_value(op, ++arg_id); - int clip_min = get_arg_value(op, ++arg_id); - int clip_max = get_arg_value(op, ++arg_id); - int input_n = get_arg_value(op, ++arg_id); - int input_h = get_arg_value(op, ++arg_id); - int input_w = get_arg_value(op, ++arg_id); - int input_c = get_arg_value(op, ++arg_id); - int filter_n = get_arg_value(op, ++arg_id); - int filter_h = get_arg_value(op, ++arg_id); - int filter_w = get_arg_value(op, ++arg_id); - int filter_c = get_arg_value(op, ++arg_id); - int bias_n = get_arg_value(op, ++arg_id); - int bias_h = get_arg_value(op, ++arg_id); - int bias_w = get_arg_value(op, ++arg_id); - int bias_c = get_arg_value(op, ++arg_id); - int output_n = get_arg_value(op, ++arg_id); - int output_h = get_arg_value(op, ++arg_id); - int output_w = get_arg_value(op, ++arg_id); - int output_c = get_arg_value(op, ++arg_id); - - std::string context = EmitCMSISNNContext(stream, context_buffer_name, context_buffer_size); - std::string conv_params = - EmitCMSISNNConvParams(stream, input_offset, output_offset, stride_w, stride_h, padding_w, - padding_h, dilation_w, dilation_h, clip_min, clip_max); + std::string shift = VarNameFromArg(op, ++arg_id); + std::string output_data = VarNameFromArg(op, ++arg_id); + + // extract CMSIS-NN API parameters + int context_buffer_pos = arg_id + CallExternArgPos::CONTEXT_BUFFER_POS; + int conv2d_params_pos = arg_id + CallExternArgPos::CONV2D_PARAMS_POS; + int input_dim_pos = arg_id + CallExternArgPos::INPUT_DIM_POS; + int filter_dim_pos = arg_id + CallExternArgPos::FILTER_DIM_POS; + int bias_dim_pos = arg_id + CallExternArgPos::BIAS_DIM_POS; + int output_dim_pos = arg_id + CallExternArgPos::OUTPUT_DIM_POS; + + CMSISNNContextBuffer context_buffer = extract_context_buffer_info(op, context_buffer_pos); + Conv2DParams conv2d_params = extract_conv2d_params(op, conv2d_params_pos); + CMSISNNDims input_dims = extract_buffer_dims(op, input_dim_pos); + CMSISNNDims filter_dims = extract_buffer_dims(op, filter_dim_pos); + CMSISNNDims bias_dims = extract_buffer_dims(op, bias_dim_pos); + CMSISNNDims output_dims = extract_buffer_dims(op, output_dim_pos); + + // Emit CMSIS-NN API arguments + std::string context = EmitCMSISNNContext(stream, context_buffer); + std::string conv_params = EmitCMSISNNConvParams(stream, conv2d_params); std::string quant_params = EmitCMSISNNPerChannelQuantParams(stream, multiplier, shift); - std::string input_dim = EmitCMSISNNDims(stream, "input", input_n, input_h, input_w, input_c); - std::string filter_dim = - EmitCMSISNNDims(stream, "filter", filter_n, filter_h, filter_w, filter_c); - std::string bias_dim = EmitCMSISNNDims(stream, "bias", bias_n, bias_h, bias_w, bias_c); - std::string output_dim = - EmitCMSISNNDims(stream, "output", output_n, output_h, output_w, output_c); + std::string input_dim = EmitCMSISNNDims(stream, "input", input_dims); + std::string filter_dim = EmitCMSISNNDims(stream, "filter", filter_dims); + std::string bias_dim = EmitCMSISNNDims(stream, "bias", bias_dims); + std::string output_dim = EmitCMSISNNDims(stream, "output", output_dims); + // Emit CMSIS-NN API PrintIndent(); stream << "arm_status status = "; stream << cmsis_func_name << "("; diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py index 243197e4eb3e..8d62763aec52 100644 --- a/tests/python/contrib/test_cmsisnn/test_conv2d.py +++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py @@ -67,31 +67,30 @@ def make_model( w_index = weight_format.index("W") kernel_h = kernel_shape[h_index] kernel_w = kernel_shape[w_index] - a = relay.var("input", shape=shape, dtype=dtype) + invar = relay.var("input", shape=shape, dtype=dtype) p = (0, 0, 0, 0) if padding == "SAME": p = get_same_padding((shape[1], shape[2]), (kernel_h, kernel_w), dilation, strides) - a = relay.nn.pad( - a, + invar = relay.nn.pad( + invar, pad_width=[(0, 0), (p[0], p[2]), (p[1], p[3]), (0, 0)], pad_value=input_zero_point, pad_mode="constant", ) shape = (shape[0], shape[1] + p[0] + p[2], shape[2] + p[1] + p[3], shape[3]) - weight_shape = (kernel_h, kernel_w, shape[3] // groups, out_channels) rng = np.random.default_rng(12321) w = tvm.nd.array( rng.integers( np.iinfo(kernel_dtype).min, high=np.iinfo(kernel_dtype).max, - size=weight_shape, + size=kernel_shape, dtype=kernel_dtype, ) ) weight_const = relay.const(w, kernel_dtype) conv = relay.qnn.op.conv2d( - a, + invar, weight_const, input_zero_point=relay.const(input_zero_point, "int32"), kernel_zero_point=relay.const(kernel_zero_point, "int32"), @@ -128,14 +127,14 @@ def make_model( @pytest.mark.parametrize("ifm_shape", [(1, 28, 28, 12), (1, 64, 100, 4)]) @pytest.mark.parametrize("kernel_size", [(3, 3)]) @pytest.mark.parametrize("padding", ["SAME", "VALID"]) -@pytest.mark.parametrize("strides, dilation", [((2, 2), (1, 1)), ((1, 1), (1, 1))]) +@pytest.mark.parametrize("strides, dilation", [((1, 1), (1, 1))]) +@pytest.mark.parametrize("relu_type", ["RELU"]) @pytest.mark.parametrize("enable_bias", [True, False]) -@pytest.mark.parametrize("relu_type", ["NONE", "RELU"]) @pytest.mark.parametrize( "input_zero_point, input_scale, kernel_scale, out_channels", [(10, 0.0128, [0.11, 0.22], 2), (-64, 1, [1, 0.0256, 1.37], 3)], ) -def test_op_int8( +def test_conv2d_int8( ifm_shape, kernel_size, padding, @@ -152,22 +151,134 @@ def test_op_int8( use_unpacked_api = True test_runner = AOT_CORSTONE300_RUNNER - kernel_zero_point = 0 + dtype = "int8" groups = 1 weight_format = "HWIO" kernel_h = kernel_size[0] kernel_w = kernel_size[1] + kernel_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels) + kernel_zero_point = 0 + in_min, in_max = get_range_for_dtype_str(dtype) + + output_scale, output_zero_point = get_conv2d_qnn_params( + kernel_shape, + input_scale, + input_zero_point, + kernel_scale, + kernel_zero_point, + dtype, + dtype, + dtype, + ) + + model, params = make_model( + ifm_shape, + kernel_shape, + input_zero_point, + input_scale, + kernel_zero_point, + kernel_scale, + output_zero_point, + output_scale, + padding, + strides, + dilation, + groups, + dtype, + dtype, + out_channels, + weight_format, + enable_bias, + relu_type, + ) + orig_mod = make_module(model) + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params) + + # validate pattern matching + attrs = [ + cmsisnn_mod[var.name_hint].attrs + for var in cmsisnn_mod.get_global_vars() + if cmsisnn_mod[var.name_hint].attrs + ] + assert any(attrs), "At least one function with external attributes was expected." + + compilers = [ + key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items() + ] + assert any(compilers), "Module does not contain function for cmsis-nn target." + + assert count_num_calls(orig_mod) == count_num_calls( + cmsisnn_mod + ), "Number of calls changed during partitioning" + + # validate the output + rng = np.random.default_rng(12345) + inputs = {"input": rng.integers(in_min, high=in_max, size=ifm_shape, dtype=dtype)} + output_list = generate_ref_data(orig_mod["main"], inputs, params) + compile_and_run( + AOTTestModel( + module=cmsisnn_mod, + inputs=inputs, + outputs=output_list, + params=params, + output_tolerance=1, + ), + test_runner, + interface_api, + use_unpacked_api, + ) + + +@tvm.testing.requires_cmsisnn +@pytest.mark.parametrize("ifm_shape", [(1, 28, 28, 12), (1, 64, 100, 4)]) +@pytest.mark.parametrize("kernel_size", [(3, 3)]) +@pytest.mark.parametrize("padding", ["SAME", "VALID"]) +@pytest.mark.parametrize("strides, dilation", [((1, 1), (1, 1))]) +@pytest.mark.parametrize("relu_type", ["RELU"]) +@pytest.mark.parametrize( + "depth_multiplier, enable_bias", + [(1, True), (3, True)], +) +@pytest.mark.parametrize( + "input_zero_point, input_scale, kernel_scale, out_channels", + [(10, 0.0128, [0.11, 0.22], 2), (-64, 1, [1, 0.0256, 1.37], 3)], +) +def test_depthwise_int8( + ifm_shape, + kernel_size, + padding, + strides, + dilation, + enable_bias, + relu_type, + input_zero_point, + input_scale, + kernel_scale, + out_channels, + depth_multiplier, +): + interface_api = "c" + use_unpacked_api = True + test_runner = AOT_CORSTONE300_RUNNER + dtype = "int8" + groups = 1 + weight_format = "HWIO" + kernel_h = kernel_size[0] + kernel_w = kernel_size[1] + kernel_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels) + kernel_zero_point = 0 in_min, in_max = get_range_for_dtype_str(dtype) - weight_shape = None - if weight_format == "HWIO": - weight_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels) - else: - weight_shape = (kernel_h, kernel_w, ifm_shape[3], out_channels) + groups = ifm_shape[3] + weight_format = "HWOI" + kernel_shape = (kernel_h, kernel_w, ifm_shape[3], depth_multiplier) + out_channels = ifm_shape[3] * depth_multiplier + ks_len = len(kernel_scale) + kernel_scale = [kernel_scale[i % ks_len] for i in range(out_channels)] output_scale, output_zero_point = get_conv2d_qnn_params( - weight_shape, + kernel_shape, input_scale, input_zero_point, kernel_scale, @@ -175,12 +286,12 @@ def test_op_int8( dtype, dtype, dtype, - False, + True, ) model, params = make_model( ifm_shape, - weight_shape, + kernel_shape, input_zero_point, input_scale, kernel_zero_point,