diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py index 5e0ad2796a16..b80da2c8ccd0 100644 --- a/python/tvm/relay/op/contrib/cmsisnn.py +++ b/python/tvm/relay/op/contrib/cmsisnn.py @@ -145,6 +145,73 @@ def check_qnn_conv2d(pattern): and (not is_depthwise or bias_add is not None) ) + def qnn_fully_connected_pattern(): + """Create pattern for qnn.dense with optional Relu.""" + qnn_fc = is_op("qnn.dense")( + wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant() + ) + bias_add = is_op("nn.bias_add")(qnn_fc, is_constant()) + req = is_op("qnn.requantize")( + qnn_fc | bias_add, is_constant(), is_constant(), is_constant(), is_constant() + ) + clip_or_req = req.optional(is_op("clip")) + return clip_or_req + + def check_qnn_fully_connected(pattern): + """Check if the fully connected is supported by CMSIS-NN.""" + if str(pattern.op.name) == "clip": + relu = pattern + requantize = relu.args[0] + else: + requantize = pattern + requantize_input = requantize.args[0] + bias_add = None + bias_dtype = "int32" + if str(requantize_input.op.name) == "nn.bias_add": + bias_add = requantize_input + fc = bias_add.args[0] + bias_dtype = bias_add.args[1].checked_type.dtype + else: + fc = requantize_input + fc_input = fc.args[0] + fc_weight = fc.args[1] + + # kernel zero_point should be 0 + kernel_zp = fc.args[3].data.numpy().item(0) + + return ( + fc.attrs.out_dtype == "int32" + and fc_input.checked_type.dtype == "int8" + and fc_weight.checked_type.dtype == "int8" + and pattern.checked_type.dtype == "int8" + and bias_dtype == "int32" + and kernel_zp == 0 + ) + + def qnn_avg_pool2d_pattern(): + """Matches average pooling with optional Relu""" + pattern = is_op("cast")(wildcard()) + pattern = is_op("nn.avg_pool2d")(pattern) + pattern = is_op("cast")(pattern) + pattern = pattern.optional(is_op("clip")) + return pattern + + def check_qnn_avg_pool2d(pattern): + """Check if avg pool2d is supported by CMSIS-NN.""" + in_cast = pattern + out_cast = in_cast.args[0].args[0] + return in_cast.checked_type.dtype == "int8" and out_cast.checked_type.dtype == "int32" + + def qnn_max_pool2d_pattern(): + """Matches max pool2d with optional Relu""" + pattern = is_op("nn.max_pool2d")(wildcard()) + pattern = pattern.optional(is_op("clip")) + return pattern + + def check_qnn_max_pool2d(pattern): + """Check if max pool2d is supported by CMSIS-NN.""" + return True + def binary_op_pattern(op): """Matches QNN binary operation""" return is_op(f"qnn.{op}")( @@ -166,8 +233,11 @@ def check_qnn_binary_op(extract): ) return [ - ("cmsis-nn.qnn_softmax", qnn_softmax_pattern(), check_qnn_softmax), ("cmsis-nn.qnn_conv2d", qnn_conv2d_pattern(), check_qnn_conv2d), + ("cmsis-nn.qnn_fully_connected", qnn_fully_connected_pattern(), check_qnn_fully_connected), + ("cmsis-nn.qnn_avg_pool2d", qnn_avg_pool2d_pattern(), check_qnn_avg_pool2d), + ("cmsis-nn.qnn_max_pool2d", qnn_max_pool2d_pattern(), check_qnn_max_pool2d), ("cmsis-nn.qnn_mul", binary_op_pattern("mul"), check_qnn_binary_op), ("cmsis-nn.qnn_add", binary_op_pattern("add"), check_qnn_binary_op), + ("cmsis-nn.qnn_softmax", qnn_softmax_pattern(), check_qnn_softmax), ] diff --git a/src/relay/backend/contrib/cmsisnn/generate_constants.cc b/src/relay/backend/contrib/cmsisnn/generate_constants.cc index 2e12697f36f1..056784b6675d 100644 --- a/src/relay/backend/contrib/cmsisnn/generate_constants.cc +++ b/src/relay/backend/contrib/cmsisnn/generate_constants.cc @@ -134,10 +134,10 @@ class GenerateConstantsMutator : public MixedModeMutator { int32_t* multiplier = static_cast(multiplier_nda->data); int32_t* shift = static_cast(shift_nda->data); for (int i = 0; i < out_channels; ++i) { - double effective_output_scale = + double quantized_multiplier = static_cast(input_scales[i]) / static_cast(output_scale); std::tie(*(multiplier + i), *(shift + i)) = - tvm::relay::qnn::GetFixedPointMultiplierShift(effective_output_scale); + tvm::relay::qnn::GetFixedPointMultiplierShift(quantized_multiplier); } // Create constants from requantization multiplier and shift diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index 668352700805..8afbd91ce37c 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -84,17 +84,9 @@ class RelayToTIRVisitor : public MixedModeMutator { tir::PrimFunc replacement_func(func_signature, body, VoidType(), Map(), DictAttrs(dict_attrs)); - ir_module_->Add(global_var, replacement_func); } - Array CMSISNNDimensions(const Array& shape) { - ICHECK(shape.size() == 4) << "Supports only CMSIS-NN shapes of dimension 4."; - return Array{ToArg(qnn::get_const_int(shape[0])), ToArg(qnn::get_const_int(shape[1])), - ToArg(qnn::get_const_int(shape[2])), - ToArg(qnn::get_const_int(shape[3]))}; - } - void EmitConv2D(const GlobalVar& global_var, const Expr& expr) { const CallNode* clip_call = nullptr; const CallNode* requantize_call = nullptr; @@ -164,21 +156,15 @@ class RelayToTIRVisitor : public MixedModeMutator { ToArg(dilation_w), ToArg(dilation_h), ToArg(clip_min), ToArg(clip_max)}; - // cmsis_nn_dims *input_dims (NHWC) + // layout NHWC Array input_shape = conv2d_call->args[0]->type_as()->shape; - Array input_dims = CMSISNNDimensions(input_shape); - // cmsis_nn_dims *filter_dims (OHWI for Conv2D and IHWO for depthwise) + // OHWI for Conv2D and IHWO for depthwise Array filter_shape = conv2d_call->args[1]->type_as()->shape; - Array filter_dims = CMSISNNDimensions(filter_shape); - // cmsis_nn_dims *bias_dims Array bias_shape{1, 1, 1, out_channels}; - Array bias_dims = CMSISNNDimensions(bias_shape); - // cmsis_nn_dims *output_dims (same order as input_dims) Array output_shape = conv2d_call->type_as()->shape; - Array output_dims = CMSISNNDimensions(output_shape); int32_t depth_multiplier = -1; int kernel_pos_o = kernel_layout.find("O"); @@ -194,7 +180,7 @@ class RelayToTIRVisitor : public MixedModeMutator { if (depth_multiplier != -1) { cmsisnn_api = "arm_depthwise_conv_wrapper_s8"; Array depthwise_filter_shape{1, filter_shape[0], filter_shape[1], out_channels}; - filter_dims = CMSISNNDimensions(depthwise_filter_shape); + filter_shape = depthwise_filter_shape; } tvm::Array call_ext_args = {tir::StringImm(cmsisnn_api), input, filter, multiplier}; @@ -216,10 +202,10 @@ class RelayToTIRVisitor : public MixedModeMutator { ToArg(context_buffer_size)}; scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args); - scalar_args = tvm::runtime::Concat(scalar_args, input_dims); - scalar_args = tvm::runtime::Concat(scalar_args, filter_dims); - scalar_args = tvm::runtime::Concat(scalar_args, bias_dims); - scalar_args = tvm::runtime::Concat(scalar_args, output_dims); + scalar_args = tvm::runtime::Concat(scalar_args, input_shape); + scalar_args = tvm::runtime::Concat(scalar_args, filter_shape); + scalar_args = tvm::runtime::Concat(scalar_args, bias_shape); + scalar_args = tvm::runtime::Concat(scalar_args, output_shape); call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args); Array func_signature{input, filter, multiplier, filter_scale}; @@ -234,6 +220,197 @@ class RelayToTIRVisitor : public MixedModeMutator { context_buffer_size); } + void EmitFullyConnected(const GlobalVar& global_var, const Expr& expr) { + const CallNode* clip_call = nullptr; + const CallNode* requantize_call = nullptr; + const CallNode* bias_add_call = nullptr; + const CallNode* fc_call = nullptr; + const CallNode* final_call = expr.as(); + const OpNode* final_op = final_call->op.as(); + if (final_op->name == "clip") { + clip_call = final_call; + requantize_call = clip_call->args[0].as(); + } else { + requantize_call = final_call; + } + const CallNode* requantize_input = requantize_call->args[0].as(); + const OpNode* requantize_input_op = requantize_input->op.as(); + if (requantize_input_op->name == "nn.bias_add") { + bias_add_call = requantize_input; + fc_call = bias_add_call->args[0].as(); + } else { + fc_call = requantize_input; + } + + // TIR variables are created in the order they appear in the Relay partitioned function + // %1 = qnn.dense(%input, %weight_const_0, input_zero_point_scalar, kernel_zero_point_scalar, + // %input_scale_scalar, %kernel_scale_scalar) + // %2 = nn.bias_add(%1, %bias_const_1, axis=1) + // %3 = qnn.requantize(%2, %req_input_scale_scalar, %req_input_zero_point_scalar, + // %output_scale_scalar, %output_zero_point_scalar) + // clip(%3, a_min=%min_scalar, a_max=%max_scalar) + tir::Var input("input", DataType::Handle(8)); + tir::Var filter("filter", DataType::Handle(8)); + tir::Var bias("bias", DataType::Handle(32)); + tir::Var output("output", DataType::Handle(8)); + + // Individual arguments to the structs arguments of the CMSIS-NN API are filled into call_extern + // https://github.com/ARM-software/CMSIS_5/blob/def6f800f95661eb3451d317f7d0dde504f6020d/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L50 + + // prepare cmsis_nn_fc_params + const DenseAttrs* dense_attrs = fc_call->attrs.as(); + int32_t input_offset = -GetScalarFromConstant(fc_call->args[2]); + int32_t filter_offset = -GetScalarFromConstant(fc_call->args[3]); + int32_t output_offset = GetScalarFromConstant(requantize_call->args[4]); + float input_scale = GetScalarFromConstant(requantize_call->args[1]); + float output_scale = GetScalarFromConstant(requantize_call->args[3]); + int32_t out_channels = qnn::get_const_int(dense_attrs->units); + int32_t clip_min, clip_max; + if (clip_call) { + const ClipAttrs* clip_attrs = clip_call->attrs.as(); + clip_min = clip_attrs->a_min; + clip_max = clip_attrs->a_max; + } else { + clip_min = -128; + clip_max = 127; + } + + double quantized_multiplier = + static_cast(input_scale) / static_cast(output_scale); + auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(quantized_multiplier); + int32_t multiplier = std::get<0>(mult_shift_pair); + int32_t shift = std::get<1>(mult_shift_pair); + + tvm::Array scalar_args = { + ToArg(input_offset), ToArg(filter_offset), ToArg(output_offset), ToArg(clip_min), + ToArg(clip_max), ToArg(multiplier), ToArg(shift)}; + + Array input_shape = fc_call->args[0]->type_as()->shape; + int32_t batch_size = qnn::get_const_int(input_shape[0]); + int32_t in_channels = qnn::get_const_int(input_shape[1]); + Array cmsisnn_input_shape{input_shape[0], 1, 1, input_shape[1]}; + + Array cmsisnn_filter_shape{in_channels, 1, 1, out_channels}; + + Array bias_shape{1, 1, 1, out_channels}; + + Array cmsisnn_output_shape{batch_size, 1, 1, out_channels}; + + tvm::Array call_ext_args = {tir::StringImm("arm_fully_connected_s8"), input, filter}; + if (bias_add_call) { + call_ext_args.push_back(bias); + } + call_ext_args.push_back(output); + + int context_buffer_size = 0; + std::string context_buffer_name = "NULL"; + tvm::Array context_buffer_args = {tir::StringImm(context_buffer_name), + ToArg(context_buffer_size)}; + + scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args); + scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_input_shape); + scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_filter_shape); + scalar_args = tvm::runtime::Concat(scalar_args, bias_shape); + scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_output_shape); + call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args); + + Array func_signature{input, filter}; + if (bias_add_call) { + func_signature.push_back(bias); + } + func_signature.push_back(output); + CreatePrimFuncForExtern(global_var, func_signature, call_ext_args, context_buffer_name, + context_buffer_size); + } + + void EmitPool2D(const GlobalVar& global_var, const Expr& expr, const String pool_name) { + Call clip, pool; + Call final_call = GetRef(expr.as()); + Op final_op = GetRef(final_call->op.as()); + if (final_op->name == "clip") { + clip = final_call; + Call clip_input = GetRef(clip->args[0].as()); + Op clip_input_op = GetRef(clip_input->op.as()); + if (clip_input_op->name == "cast") { + pool = GetRef(clip_input->args[0].as()); + } else { // max_pool2d + pool = clip_input; + } + } else if (final_op->name == "cast") { + pool = GetRef(final_call->args[0].as()); + } else { // max_pool2d + pool = final_call; + } + + // prepare cmsis_nn_pool_params + int32_t stride_h, stride_w, padding_h, padding_w, pool_size_h, pool_size_w; + int32_t clip_min, clip_max; + std::string cmsisnn_api; + if (pool_name == "cmsis-nn.qnn_avg_pool2d") { + cmsisnn_api = "arm_avgpool_s8"; + const AvgPool2DAttrs* attrs = pool->attrs.as(); + stride_h = qnn::get_const_int(attrs->strides[0]); + stride_w = qnn::get_const_int(attrs->strides[1]); + padding_h = qnn::get_const_int(attrs->padding[0]); + padding_w = qnn::get_const_int(attrs->padding[1]); + pool_size_h = qnn::get_const_int(attrs->pool_size[0]); + pool_size_w = qnn::get_const_int(attrs->pool_size[1]); + } else { + cmsisnn_api = "arm_max_pool_s8"; + const MaxPool2DAttrs* attrs = pool->attrs.as(); + stride_h = qnn::get_const_int(attrs->strides[0]); + stride_w = qnn::get_const_int(attrs->strides[1]); + padding_h = qnn::get_const_int(attrs->padding[0]); + padding_w = qnn::get_const_int(attrs->padding[1]); + pool_size_h = qnn::get_const_int(attrs->pool_size[0]); + pool_size_w = qnn::get_const_int(attrs->pool_size[1]); + } + if (clip.defined()) { + const ClipAttrs* clip_attrs = clip->attrs.as(); + clip_min = clip_attrs->a_min; + clip_max = clip_attrs->a_max; + } else { + clip_min = -128; + clip_max = 127; + } + + tvm::Array scalar_args = {ToArg(stride_h), ToArg(stride_w), ToArg(padding_h), + ToArg(padding_w), ToArg(clip_min), ToArg(clip_max)}; + + Array input_shape = pool->args[0]->type_as()->shape; + Array cmsisnn_input_shape{1, input_shape[1], input_shape[2], input_shape[3]}; + + Array cmsisnn_filter_shape{1, pool_size_h, pool_size_w, 1}; + + Array output_shape = pool->type_as()->shape; + Array cmsisnn_output_shape{1, output_shape[1], output_shape[2], output_shape[3]}; + + tir::Var input("input", DataType::Handle(8)); + tir::Var output("output", DataType::Handle(8)); + tvm::Array call_ext_args = {tir::StringImm(cmsisnn_api), input, output}; + + int context_buffer_size = 0; + std::string context_buffer_name = "NULL"; + if (pool_name == "cmsisnn.qnn_avg_pool2d") { + // TODO(@Mousius): Need to move this into buffer_size calculations + context_buffer_size = qnn::get_const_int(input_shape[3]) * sizeof(int32_t); + context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++); + } + tvm::Array context_buffer_args = {tir::StringImm(context_buffer_name), + ToArg(context_buffer_size)}; + + scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args); + scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_input_shape); + scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_filter_shape); + scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_output_shape); + call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args); + + Array func_signature{input, output}; + + CreatePrimFuncForExtern(global_var, func_signature, call_ext_args, context_buffer_name, + context_buffer_size); + } + void EmitSoftMax(const GlobalVar& global_var, const Expr& expr) { const CallNode* quantize_call = expr.as(); const CallNode* softmax_call = quantize_call->args[0].as(); @@ -422,6 +599,12 @@ class RelayToTIRVisitor : public MixedModeMutator { if (comp_name == "cmsis-nn.qnn_conv2d") { EmitConv2D(new_global_var, composite_func->body); } + if (comp_name == "cmsis-nn.qnn_fully_connected") { + EmitFullyConnected(new_global_var, composite_func->body); + } + if (comp_name == "cmsis-nn.qnn_avg_pool2d" || comp_name == "cmsis-nn.qnn_max_pool2d") { + EmitPool2D(new_global_var, composite_func->body, comp_name.value()); + } Array args; for (const auto& arg : call->args) { diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc index 85923b3ed08e..2a7d0ae21769 100644 --- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc +++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc @@ -81,12 +81,25 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { int depth_multiplier; }; - /*! * \brief Emit the CMSIS-NN context buffer */ - void VisitStmt_(const AllocateNode* op) { - context_buffer_name_ = op->buffer_var->name_hint; - context_buffer_size_ = op->constant_allocation_size(); - CodeGenC::VisitStmt_(op); - } + /*! * \brief CMSIS-NN Conv2D and Depthwise parameters */ + struct FCParams { + int input_offset; + int filter_offset; + int output_offset; + int clip_min; + int clip_max; + int multiplier; + int shift; + }; + + struct PoolParams { + int stride_h; + int stride_w; + int padding_h; + int padding_w; + int clip_min; + int clip_max; + }; /*! * \brief Emits CMSIS-NN APIs for every call_extern */ void VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) @@ -101,6 +114,10 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { } else if (cmsis_func_name == "arm_convolve_wrapper_s8" || cmsis_func_name == "arm_depthwise_conv_wrapper_s8") { EmitConv2D(op); + } else if (cmsis_func_name == "arm_fully_connected_s8") { + EmitFullyConnected(op); + } else if (cmsis_func_name == "arm_avgpool_s8" || cmsis_func_name == "arm_max_pool_s8") { + EmitPool2D(op); } return; } @@ -140,6 +157,36 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { return instance_name; } + /*! * \brief Emits cmsis_nn_fc_params struct */ + std::string EmitCMSISNNFCParams(std::ostream& os, FCParams params) { + std::string struct_name = "cmsis_nn_fc_params"; + std::string instance_name = "fc_params"; + PrintIndent(); + os << "cmsis_nn_activation activation = {" << params.clip_min << "," << params.clip_max + << "};\n"; + PrintIndent(); + os << struct_name << " " << instance_name << " = {" << params.input_offset << ", " + << params.filter_offset << ", " << params.output_offset; + os << ", activation};\n"; + return instance_name; + } + + /*! * \brief Emits cmsis_nn_pool_params struct */ + std::string EmitCMSISNNPoolParams(std::ostream& os, PoolParams params) { + std::string struct_name = "cmsis_nn_pool_params"; + std::string instance_name = "pool_params"; + PrintIndent(); + os << "cmsis_nn_tile stride = {" << params.stride_w << "," << params.stride_h << "};\n"; + PrintIndent(); + os << "cmsis_nn_tile padding = {" << params.padding_w << "," << params.padding_h << "};\n"; + PrintIndent(); + os << "cmsis_nn_activation activation = {" << params.clip_min << "," << params.clip_max + << "};\n"; + PrintIndent(); + os << struct_name << " " << instance_name << " = {stride, padding, activation};\n"; + return instance_name; + } + /*! * \brief Emits cmsis_nn_per_channel_quant_params struct */ std::string EmitCMSISNNPerChannelQuantParams(std::ostream& os, std::string multiplier, std::string shift) { @@ -150,6 +197,15 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { return struct_name; } + /*! * \brief Emits cmsis_nn_per_tensor_quant_params struct */ + std::string EmitCMSISNNPerTensorQuantParams(std::ostream& os, int multiplier, int shift) { + std::string struct_name = "quant_params"; + PrintIndent(); + os << "cmsis_nn_per_tensor_quant_params " << struct_name << " = {" << multiplier << ", " + << shift << "};\n"; + return struct_name; + } + /*! * \brief Emits cmsis_nn_dims struct */ std::string EmitCMSISNNDims(std::ostream& os, std::string tensor_type, CMSISNNDims dims) { std::string struct_name = tensor_type + "_dims"; @@ -192,6 +248,31 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { return conv2d_params; } + /*! * \brief extracts CMSIS-NN FC parameters from call_extern */ + FCParams extract_fc_params(const CallNode* op, int base_pos) { + FCParams fc_params; + fc_params.input_offset = ValueFromArg(op, base_pos); + fc_params.filter_offset = ValueFromArg(op, ++base_pos); + fc_params.output_offset = ValueFromArg(op, ++base_pos); + fc_params.clip_min = ValueFromArg(op, ++base_pos); + fc_params.clip_max = ValueFromArg(op, ++base_pos); + fc_params.multiplier = ValueFromArg(op, ++base_pos); + fc_params.shift = ValueFromArg(op, ++base_pos); + return fc_params; + } + + /*! * \brief extracts CMSIS-NN Pooling parameters from call_extern */ + PoolParams extract_pool_params(const CallNode* op, int base_pos) { + PoolParams pool_params; + pool_params.stride_h = ValueFromArg(op, base_pos); + pool_params.stride_w = ValueFromArg(op, ++base_pos); + pool_params.padding_h = ValueFromArg(op, ++base_pos); + pool_params.padding_w = ValueFromArg(op, ++base_pos); + pool_params.clip_min = ValueFromArg(op, ++base_pos); + pool_params.clip_max = ValueFromArg(op, ++base_pos); + return pool_params; + } + /*! * \brief extracts CMSIS-NN buffer dimensions from call_extern */ CMSISNNDims extract_buffer_dims(const CallNode* op, int base_pos) { CMSISNNDims dims; @@ -202,7 +283,7 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { return dims; } - /*! * \brief Emits CMSIS-NN APIs for every call_extern */ + /*! * \brief Emits CMSIS-NN APIs for every call_extern comprising convolution */ void EmitConv2D(const CallNode* op) { // Position of various arguments relative to buffers in the call_extern enum CallExternArgPos { @@ -273,9 +354,128 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { stream << "}\n"; } - private: - std::string context_buffer_name_ = "NULL"; - int context_buffer_size_ = 0; + /*! * \brief Emits CMSIS-NN APIs for every call_extern comprising fully connected */ + void EmitFullyConnected(const CallNode* op) { + // Position of various arguments relative to buffers in the call_extern + enum CallExternArgPos { + CONTEXT_BUFFER_POS = 1, + FC_PARAMS_POS = 3, + INPUT_DIM_POS = 10, + FILTER_DIM_POS = 14, + BIAS_DIM_POS = 18, + OUTPUT_DIM_POS = 22, + MAX_NUM_ARGS = 30 + }; + + std::string cmsis_func_name = op->args[0].as()->value; + + // extract buffer names from call_extern + int arg_id = 0; + std::string input_data = VarNameFromArg(op, ++arg_id); + std::string filter_data = VarNameFromArg(op, ++arg_id); + std::string bias_data("NULL"); + if (op->args.size() == CallExternArgPos::MAX_NUM_ARGS) { + bias_data = VarNameFromArg(op, ++arg_id); + } + std::string output_data = VarNameFromArg(op, ++arg_id); + + // extract CMSIS-NN API parameters + int context_buffer_pos = arg_id + CallExternArgPos::CONTEXT_BUFFER_POS; + int fc_params_pos = arg_id + CallExternArgPos::FC_PARAMS_POS; + int input_dim_pos = arg_id + CallExternArgPos::INPUT_DIM_POS; + int filter_dim_pos = arg_id + CallExternArgPos::FILTER_DIM_POS; + int bias_dim_pos = arg_id + CallExternArgPos::BIAS_DIM_POS; + int output_dim_pos = arg_id + CallExternArgPos::OUTPUT_DIM_POS; + + CMSISNNContextBuffer context_buffer = extract_context_buffer_info(op, context_buffer_pos); + FCParams fc_params = extract_fc_params(op, fc_params_pos); + CMSISNNDims input_dims = extract_buffer_dims(op, input_dim_pos); + CMSISNNDims filter_dims = extract_buffer_dims(op, filter_dim_pos); + CMSISNNDims bias_dims = extract_buffer_dims(op, bias_dim_pos); + CMSISNNDims output_dims = extract_buffer_dims(op, output_dim_pos); + + // Emit CMSIS-NN API arguments + std::string context = EmitCMSISNNContext(stream, context_buffer); + std::string cmsisnn_fc_params = EmitCMSISNNFCParams(stream, fc_params); + std::string quant_params = + EmitCMSISNNPerTensorQuantParams(stream, fc_params.multiplier, fc_params.shift); + std::string input_dim = EmitCMSISNNDims(stream, "input", input_dims); + std::string filter_dim = EmitCMSISNNDims(stream, "filter", filter_dims); + std::string bias_dim = EmitCMSISNNDims(stream, "bias", bias_dims); + std::string output_dim = EmitCMSISNNDims(stream, "output", output_dims); + + PrintIndent(); + stream << "arm_status status = "; + stream << cmsis_func_name << "("; + stream << "&" << context << ", "; + stream << "&" << cmsisnn_fc_params << ", "; + stream << "&" << quant_params << ", "; + stream << "&" << input_dim << ", " << input_data << ", "; + stream << "&" << filter_dim << ", " << filter_data << ", "; + stream << "&" << bias_dim << ", " << bias_data << ", "; + stream << "&" << output_dim << ", " << output_data << ");\n"; + PrintIndent(); + stream << "if (status != ARM_MATH_SUCCESS) {\n"; + PrintIndent(); + PrintIndent(); + stream << "return -1;\n"; + PrintIndent(); + stream << "}\n"; + } + + /*! * \brief Emits CMSIS-NN APIs for every call_extern comprising pooling ops */ + void EmitPool2D(const CallNode* op) { + // Position of various arguments relative to buffers in the call_extern + enum CallExternArgPos { + CONTEXT_BUFFER_POS = 1, + POOL_PARAMS_POS = 3, + INPUT_DIM_POS = 9, + FILTER_DIM_POS = 13, + OUTPUT_DIM_POS = 17, + MAX_NUM_ARGS = 23 + }; + std::string cmsis_func_name = op->args[0].as()->value; + + // extract buffer names from call_extern + int arg_id = 0; + std::string input_data = VarNameFromArg(op, ++arg_id); + std::string output_data = VarNameFromArg(op, ++arg_id); + + // extract CMSIS-NN API parameters + int context_buffer_pos = arg_id + CallExternArgPos::CONTEXT_BUFFER_POS; + int pool_params_pos = arg_id + CallExternArgPos::POOL_PARAMS_POS; + int input_dim_pos = arg_id + CallExternArgPos::INPUT_DIM_POS; + int filter_dim_pos = arg_id + CallExternArgPos::FILTER_DIM_POS; + int output_dim_pos = arg_id + CallExternArgPos::OUTPUT_DIM_POS; + + CMSISNNContextBuffer context_buffer = extract_context_buffer_info(op, context_buffer_pos); + PoolParams pool_params = extract_pool_params(op, pool_params_pos); + CMSISNNDims input_dims = extract_buffer_dims(op, input_dim_pos); + CMSISNNDims filter_dims = extract_buffer_dims(op, filter_dim_pos); + CMSISNNDims output_dims = extract_buffer_dims(op, output_dim_pos); + + std::string context = EmitCMSISNNContext(stream, context_buffer); + std::string cmsisnn_pool_params = EmitCMSISNNPoolParams(stream, pool_params); + std::string input_dim = EmitCMSISNNDims(stream, "input", input_dims); + std::string filter_dim = EmitCMSISNNDims(stream, "filter", filter_dims); + std::string output_dim = EmitCMSISNNDims(stream, "output", output_dims); + + PrintIndent(); + stream << "arm_status status = "; + stream << cmsis_func_name << "("; + stream << "&" << context << ", "; + stream << "&" << cmsisnn_pool_params << ", "; + stream << "&" << input_dim << ", " << input_data << ", "; + stream << "&" << filter_dim << ", "; + stream << "&" << output_dim << ", " << output_data << ");\n"; + PrintIndent(); + stream << "if (status != ARM_MATH_SUCCESS) {\n"; + PrintIndent(); + PrintIndent(); + stream << "return -1;\n"; + PrintIndent(); + stream << "}\n"; + } }; runtime::Module TIRToRuntime(IRModule mod, Target target) { diff --git a/tests/python/contrib/test_cmsisnn/test_fully_connected.py b/tests/python/contrib/test_cmsisnn/test_fully_connected.py new file mode 100644 index 000000000000..42b36a77b77f --- /dev/null +++ b/tests/python/contrib/test_cmsisnn/test_fully_connected.py @@ -0,0 +1,259 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""CMSIS-NN integration tests: Fully Connected""" +import itertools +import numpy as np +import pytest +import tvm +from tvm import relay +from tvm.relay.op.contrib import cmsisnn + + +from tests.python.relay.aot.aot_test_utils import ( + AOTTestModel, + AOT_CORSTONE300_RUNNER, + AOT_DEFAULT_RUNNER, + generate_ref_data, + compile_and_run, +) +from utils import ( + skip_if_no_reference_system, + make_module, + count_num_calls, + get_range_for_dtype_str, + get_same_padding, + get_conv2d_qnn_params, + make_qnn_relu, +) + + +def make_model( + in_shape, # [batchsize, in_channels] + kernel_shape, # [out_channels, num_inputs] + input_zero_point, + kernel_zero_point, + input_scale, + kernel_scale, + output_zero_point, + output_scale, + dtype, + kernel_dtype, + out_channels, + enable_bias, + relu_type="NONE", +): + """Return a model and any parameters it may have""" + a = relay.var("input", shape=in_shape, dtype=dtype) + rng = np.random.default_rng(12321) + w = tvm.nd.array( + rng.integers( + np.iinfo(kernel_dtype).min, + high=np.iinfo(kernel_dtype).max, + size=kernel_shape, + dtype=kernel_dtype, + ) + ) + weight_const = relay.const(w, kernel_dtype) + fc = relay.qnn.op.dense( + a, + weight_const, + input_zero_point=relay.const(input_zero_point, "int32"), + kernel_zero_point=relay.const(kernel_zero_point, "int32"), + input_scale=relay.const(input_scale, "float32"), + kernel_scale=relay.const(kernel_scale, "float32"), + units=out_channels, + out_dtype="int32", + ) + + b = tvm.nd.array(rng.integers(0, high=10, size=(out_channels,), dtype="int32")) + bias_const = relay.const(b, "int32") + last_op = relay.nn.bias_add(fc, bias_const) if enable_bias else fc + requant_input_sc = input_scale * kernel_scale + last_op = relay.qnn.op.requantize( + last_op, + relay.const(requant_input_sc, "float32"), + relay.const(0, "int32"), + relay.const(output_scale, "float32"), + relay.const(output_zero_point, "int32"), + out_dtype=dtype, + ) + last_op = make_qnn_relu(last_op, relu_type, output_scale, output_zero_point, dtype) + params = {"w": w, "b": b} + return last_op, params + + +@tvm.testing.requires_cmsisnn +@pytest.mark.parametrize("in_shape", [(2, 28), (1, 64)]) +@pytest.mark.parametrize("out_channels", [12, 128]) +@pytest.mark.parametrize("enable_bias", [False, True]) +@pytest.mark.parametrize("relu_type", ["RELU"]) +@pytest.mark.parametrize( + "input_zero_point, input_scale, kernel_scale", + [(10, 0.0128, 0.11), (-64, 0.0256, 1.37)], +) +def test_op_int8( + in_shape, + enable_bias, + input_zero_point, + input_scale, + kernel_scale, + out_channels, + relu_type, +): + interface_api = "c" + use_unpacked_api = True + test_runner = AOT_CORSTONE300_RUNNER + + dtype = "int8" + kernel_zero_point = 0 + kernel_shape = [out_channels, in_shape[1]] + conv2d_kernel_shape = (1, 1, kernel_shape[0], kernel_shape[1]) + in_min, in_max = get_range_for_dtype_str(dtype) + + output_scale, output_zero_point = get_conv2d_qnn_params( + conv2d_kernel_shape, + input_scale, + input_zero_point, + kernel_scale, + kernel_zero_point, + dtype, + ) + + model, params = make_model( + in_shape, + kernel_shape, + input_zero_point, + kernel_zero_point, + input_scale, + kernel_scale, + output_zero_point, + output_scale, + dtype, + dtype, + out_channels, + enable_bias, + ) + orig_mod = make_module(model) + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params) + + # validate pattern matching + attrs = [ + cmsisnn_mod[var.name_hint].attrs + for var in cmsisnn_mod.get_global_vars() + if cmsisnn_mod[var.name_hint].attrs + ] + assert any(attrs), "At least one function with external attributes was expected." + + compilers = [ + key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items() + ] + assert any(compilers), "Module does not contain function for cmsisnn target." + + assert count_num_calls(orig_mod) == count_num_calls( + cmsisnn_mod + ), "Number of calls changed during partitioning" + + # validate the output + rng = np.random.default_rng(12345) + inputs = {"input": rng.integers(in_min, high=in_max, size=in_shape, dtype=dtype)} + output_list = generate_ref_data(orig_mod["main"], inputs, params) + compile_and_run( + AOTTestModel( + module=cmsisnn_mod, + inputs=inputs, + outputs=output_list, + params=params, + output_tolerance=1, + ), + test_runner, + interface_api, + use_unpacked_api, + ) + + +def parameterize_for_invalid_model(test): + in_dtype = ["uint8", "int8"] + kernel_dtype = ["uint8", "int8"] + kernel_zero_point = [-33, 10, 0] + all_combinations = itertools.product(in_dtype, kernel_dtype, kernel_zero_point) + all_combinations = filter( + lambda parameters: not ( + parameters[0] == "int8" and parameters[1] == "int8" and parameters[2] == 0 + ), + all_combinations, + ) + return pytest.mark.parametrize( + ["in_dtype", "kernel_dtype", "kernel_zero_point"], + all_combinations, + )(test) + + +@tvm.testing.requires_cmsisnn +@parameterize_for_invalid_model +def test_invalid_parameters( + in_dtype, + kernel_dtype, + kernel_zero_point, +): + in_shape = (2, 28) + out_channels = 2 + input_scale = 1 + input_zero_point = 24 + kernel_scale = [0.11, 0.0237] + in_min, in_max = get_range_for_dtype_str(in_dtype) + + kernel_shape = [out_channels, in_shape[1]] + conv2d_kernel_shape = [1, 1, kernel_shape[0], kernel_shape[1]] + output_scale, output_zero_point = get_conv2d_qnn_params( + conv2d_kernel_shape, + input_scale, + input_zero_point, + kernel_scale, + kernel_zero_point, + in_dtype, + kernel_dtype, + in_dtype, + ) + model, params = make_model( + in_shape=in_shape, + kernel_shape=kernel_shape, + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, + input_scale=input_scale, + kernel_scale=kernel_scale, + output_zero_point=output_zero_point, + output_scale=output_scale, + dtype=in_dtype, + kernel_dtype=kernel_dtype, + out_channels=out_channels, + enable_bias=True, + ) + orig_mod = make_module(model) + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params) + + # validate pattern matching + attrs = [ + cmsisnn_mod[var.name_hint].attrs + for var in cmsisnn_mod.get_global_vars() + if cmsisnn_mod[var.name_hint].attrs + ] + assert not any(attrs), "No function should have an external attribute." + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/contrib/test_cmsisnn/test_pooling.py b/tests/python/contrib/test_cmsisnn/test_pooling.py new file mode 100644 index 000000000000..1c440b1e1de4 --- /dev/null +++ b/tests/python/contrib/test_cmsisnn/test_pooling.py @@ -0,0 +1,173 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""CMSIS-NN integration tests: Conv2D""" +import itertools +import numpy as np +import pytest +import tvm +from tvm import relay +from tvm.relay.op.contrib import cmsisnn + + +from tests.python.relay.aot.aot_test_utils import ( + AOTTestModel, + AOT_CORSTONE300_RUNNER, + AOT_DEFAULT_RUNNER, + generate_ref_data, + compile_and_run, +) +from utils import ( + skip_if_no_reference_system, + make_module, + count_num_calls, + get_range_for_dtype_str, + get_same_padding, + get_conv2d_qnn_params, + make_qnn_relu, +) + + +def make_model(pool_op, shape, pool_size, strides, padding, dtype, scale, zero_point, relu_type): + """Return a model and any parameters it may have""" + op = relay.var("input", shape=shape, dtype=dtype) + pad_ = (0, 0, 0, 0) + if padding == "SAME": + dilation = (1, 1) + pad_ = get_same_padding((shape[1], shape[2]), pool_size, dilation, strides) + op = relay.nn.pad( + op, + pad_width=[(0, 0), (pad_[0], pad_[2]), (pad_[1], pad_[3]), (0, 0)], + pad_value=zero_point, + pad_mode="constant", + ) + if pool_op == relay.nn.avg_pool2d: + op = relay.cast(op, "int32") + op = pool_op( + op, pool_size=pool_size, strides=strides, padding=pad_, ceil_mode=True, layout="NHWC" + ) + if pool_op == relay.nn.avg_pool2d: + op = relay.cast(op, dtype) + op = make_qnn_relu(op, relu_type, scale, zero_point, dtype) + return op + + +@tvm.testing.requires_cmsisnn +@pytest.mark.parametrize("in_shape", [(1, 28, 28, 12), (1, 64, 100, 4)]) +@pytest.mark.parametrize( + "pool_size, strides, padding", [((3, 3), (2, 2), "SAME"), ((2, 2), (1, 1), "VALID")] +) +@pytest.mark.parametrize("relu_type", ["RELU"]) +@pytest.mark.parametrize("pool_type", [relay.nn.max_pool2d, relay.nn.avg_pool2d]) +@pytest.mark.parametrize("zero_point, scale", [(-34, 0.0256)]) +def test_op_int8( + in_shape, + pool_size, + strides, + padding, + relu_type, + pool_type, + zero_point, + scale, +): + interface_api = "c" + use_unpacked_api = True + test_runner = AOT_CORSTONE300_RUNNER + + dtype = "int8" + + model = make_model( + pool_type, + in_shape, + pool_size, + strides, + padding, + dtype, + scale, + zero_point, + relu_type, + ) + orig_mod = make_module(model) + + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod) + + # validate pattern matching + attrs = [ + cmsisnn_mod[var.name_hint].attrs + for var in cmsisnn_mod.get_global_vars() + if cmsisnn_mod[var.name_hint].attrs + ] + assert any(attrs), "At least one function with external attributes was expected." + + compilers = [ + key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items() + ] + assert any(compilers), "Module does not contain function for cmsisnn target." + + assert count_num_calls(orig_mod) == count_num_calls( + cmsisnn_mod + ), "Number of calls changed during partitioning" + + # validate the output + in_min, in_max = get_range_for_dtype_str(dtype) + np.random.seed(0) + inputs = { + "input": np.random.randint(in_min, high=in_max, size=in_shape, dtype="int8"), + } + output_list = generate_ref_data(orig_mod["main"], inputs) + compile_and_run( + AOTTestModel( + module=cmsisnn_mod, + inputs=inputs, + outputs=output_list, + params=None, + output_tolerance=1, + ), + test_runner, + interface_api, + use_unpacked_api, + ) + + +@tvm.testing.requires_cmsisnn +def test_invalid_parameters(): + model = make_model( + pool_op=relay.nn.avg_pool2d, + shape=(1, 28, 28, 12), + pool_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype="uint8", + scale=1, + zero_point=-33, + relu_type="RELU", + ) + + orig_mod = make_module(model) + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod) + + # validate pattern matching + attrs = [ + cmsisnn_mod[var.name_hint].attrs + for var in cmsisnn_mod.get_global_vars() + if cmsisnn_mod[var.name_hint].attrs + ] + assert not any(attrs), "No function should have an external attribute." + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:]))