diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc index a2dca669c24f..c667aeeaa61f 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc @@ -1,9 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include "core/providers/common.h" #include "core/providers/shared/utils/utils.h" #include "core/framework/tensorprotoutils.h" +#include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/common/safeint.h" @@ -24,6 +26,11 @@ class LayerNormOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override final ORT_MUST_USE_RESULT; protected: + Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const override ORT_MUST_USE_RESULT; Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::vector&& input_names, @@ -55,6 +62,91 @@ Status LayerNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true); } +Status LayerNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const { + ORT_UNUSED_PARAMETER(do_op_validation); + + const auto& inputs = node_unit.Inputs(); + const auto input_count = inputs.size(); + constexpr size_t X_IDX = 0; + constexpr size_t SCALE_IDX = 1; + constexpr size_t BIAS_IDX = 2; + + // Input[0] (X, required) + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[X_IDX], logger, input_names)); + + // Input[1] (scale, required) + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[SCALE_IDX], logger, input_names)); + + // Input[2] (bias, optional) + const bool has_bias_input = input_count > BIAS_IDX && inputs[BIAS_IDX].node_arg.Exists(); + if (has_bias_input) { + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[BIAS_IDX], logger, input_names)); + } + +#if QNN_API_VERSION_MAJOR == 2 && QNN_API_VERSION_MINOR == 17 + if (!has_bias_input && IsNpuBackend(qnn_model_wrapper.GetQnnBackendType())) { + // Bias is implicit. QNN SDK 2.24 (QNN API version 2.17) has a validation bug for implicit bias inputs, so provide + // an explicit bias of all 0 (quantized int32). + TensorInfo x_input_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[X_IDX], x_input_info)); + + TensorInfo scale_input_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[SCALE_IDX], scale_input_info)); + + if (x_input_info.quant_param.IsPerTensor(/*include_bw*/ true) && scale_input_info.quant_param.IsQuantized()) { + const std::string bias_name = qnn::utils::GetNodeName(node_unit) + "_implicit_bias_ort_qnn_ep"; + + // Make dummy bias input have the same shape as the scale input. + std::vector bias_shape = scale_input_info.shape; + size_t num_bias_elems = 1; + for (size_t i = 0; i < bias_shape.size(); i++) { + num_bias_elems *= static_cast(bias_shape[i]); + } + + // Bias static input should be all zeros. + std::vector bias_bytes(num_bias_elems * sizeof(int32_t), 0); + + // Bias's quantization scale should be the product of the other inputs' quantization scales. + std::vector input0_quant_scales; + std::vector input1_quant_scales; + ORT_RETURN_IF_ERROR(x_input_info.quant_param.GetScales(input0_quant_scales)); + ORT_RETURN_IF_ERROR(scale_input_info.quant_param.GetScales(input1_quant_scales)); + + const size_t num_bias_scales_offsets = input1_quant_scales.size(); + assert(input0_quant_scales.size() == 1); // Expected for per-tensor. + ORT_RETURN_IF_NOT(num_bias_scales_offsets >= input0_quant_scales.size(), + "Input[1] should have >= 1 quantization scale values"); + + std::vector bias_scales(num_bias_scales_offsets); + for (size_t i = 0; i < num_bias_scales_offsets; i++) { + bias_scales[i] = input0_quant_scales[0] * input1_quant_scales[i]; + } + + std::vector bias_offsets(num_bias_scales_offsets, 0); // Bias's zero-points should be all zeros. + QnnQuantParamsWrapper bias_qparams; + + if (scale_input_info.quant_param.IsPerChannel()) { + bias_qparams = QnnQuantParamsWrapper(bias_scales, bias_offsets, /*axis*/ 0, /*is_int4*/ false); + } else { + bias_qparams = QnnQuantParamsWrapper(bias_scales[0], bias_offsets[0]); + } + + auto tensor_wrapper = QnnTensorWrapper(bias_name, QNN_TENSOR_TYPE_STATIC, QNN_DATATYPE_SFIXED_POINT_32, + std::move(bias_qparams), std::move(bias_shape), std::move(bias_bytes)); + + qnn_model_wrapper.AddTensorWrapper(std::move(tensor_wrapper)); + input_names.push_back(bias_name); + } + } +#endif + + return Status::OK(); +} + Status LayerNormOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::vector&& input_names, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index f85cdc401a15..c8537307ef3b 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -442,6 +442,16 @@ Status QnnModelWrapper::IsPerChannelQuantized(const onnxruntime::NodeUnitIODef& if (is_per_channel) { axis = io_def.quant_param->axis.value_or(1); // 1 is default axis for Q/DQ ops. + if (axis < 0) { + // Normalize negative axis by adding rank. + const auto* tensor_shape_proto = io_def.node_arg.Shape(); + ORT_RETURN_IF_NOT(tensor_shape_proto != nullptr, "NULL tensor shape proto"); + + const int rank = tensor_shape_proto->dim_size(); + ORT_RETURN_IF_NOT(rank > 0, "Per-channel quantized tensor should be of rank > 0"); + + axis += rank; + } } return Status::OK(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.cc index 2d22c3c1b822..da2d517f6569 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.cc @@ -30,6 +30,7 @@ QnnQuantParamsWrapper& QnnQuantParamsWrapper::operator=(const QnnQuantParamsWrap return *this; } +// Construct per-tensor quantization params. QnnQuantParamsWrapper::QnnQuantParamsWrapper(float scale, int32_t offset) { params_.encodingDefinition = QNN_DEFINITION_DEFINED; params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; @@ -37,6 +38,110 @@ QnnQuantParamsWrapper::QnnQuantParamsWrapper(float scale, int32_t offset) { params_.scaleOffsetEncoding.offset = offset; } +// Construct a per-channel quantization param. +QnnQuantParamsWrapper::QnnQuantParamsWrapper(gsl::span scales, gsl::span offsets, + int32_t axis, bool is_int4) { + assert(scales.size() == offsets.size()); // Logic error if sizes don't match. + const uint32_t num_elems = static_cast(scales.size()); + params_.encodingDefinition = QNN_DEFINITION_DEFINED; + + if (is_int4) { + params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET; + params_.bwAxisScaleOffsetEncoding.numElements = num_elems; + params_.bwAxisScaleOffsetEncoding.axis = axis; + params_.bwAxisScaleOffsetEncoding.bitwidth = 4; + + // Deep copy to the scales[] and offsets[] arrays + if (num_elems > 0) { + const size_t num_scale_bytes = num_elems * sizeof(float); + const size_t num_zp_bytes = num_elems * sizeof(int32_t); + const size_t num_bytes = num_scale_bytes + num_zp_bytes; + constexpr std::uintptr_t align = alignof(float); + static_assert(alignof(float) == alignof(int32_t)); + + per_channel_data_ = std::make_unique(num_bytes + align); + char* scales_begin = ALIGN_PTR_UP(per_channel_data_.get(), align, char*); + char* zps_begin = scales_begin + num_scale_bytes; + + std::memcpy(scales_begin, scales.data(), num_scale_bytes); + std::memcpy(zps_begin, offsets.data(), num_zp_bytes); + params_.bwAxisScaleOffsetEncoding.scales = reinterpret_cast(scales_begin); + params_.bwAxisScaleOffsetEncoding.offsets = reinterpret_cast(zps_begin); + } else { + params_.bwAxisScaleOffsetEncoding.scales = nullptr; + params_.bwAxisScaleOffsetEncoding.offsets = nullptr; + } + } else { + params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET; + params_.axisScaleOffsetEncoding.numScaleOffsets = num_elems; + params_.axisScaleOffsetEncoding.axis = axis; + + // Deep copy to the scaleOffset data. + if (num_elems > 0) { + const size_t num_bytes = num_elems * sizeof(Qnn_ScaleOffset_t); + constexpr std::uintptr_t align = alignof(Qnn_ScaleOffset_t); + per_channel_data_ = std::make_unique(num_bytes + align); + Qnn_ScaleOffset_t* aligned_dst = ALIGN_PTR_UP(per_channel_data_.get(), align, Qnn_ScaleOffset_t*); + + for (size_t i = 0; i < static_cast(num_elems); i++) { + aligned_dst[i].offset = offsets[i]; + aligned_dst[i].scale = scales[i]; + } + + params_.axisScaleOffsetEncoding.scaleOffset = aligned_dst; + } else { + params_.axisScaleOffsetEncoding.scaleOffset = nullptr; + } + } +} + +// Get a copy of scales. Works for both per-tensor and per-channel. +Status QnnQuantParamsWrapper::GetScales(/*out*/ std::vector& scales) const { + ORT_RETURN_IF_NOT(params_.encodingDefinition == QNN_DEFINITION_DEFINED, "Unquantized qparams does not have scales"); + + switch (params_.quantizationEncoding) { + case QNN_QUANTIZATION_ENCODING_SCALE_OFFSET: + scales.resize(1); + scales[0] = params_.scaleOffsetEncoding.scale; + break; + case QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET: + scales.resize(1); + scales[0] = params_.bwScaleOffsetEncoding.scale; + break; + case QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET: { + const uint32_t num_elems = params_.axisScaleOffsetEncoding.numScaleOffsets; + scales.resize(num_elems); + + if (num_elems > 0) { + gsl::span scale_offsets(params_.axisScaleOffsetEncoding.scaleOffset, num_elems); + + for (size_t i = 0; i < num_elems; i++) { + scales[i] = scale_offsets[i].scale; + } + } + break; + } + case QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET: { + const uint32_t num_elems = params_.bwAxisScaleOffsetEncoding.numElements; + scales.resize(num_elems); + + // Deep copy the scales[] and offsets[] arrays + if (num_elems > 0) { + gsl::span src_scales(params_.bwAxisScaleOffsetEncoding.scales, num_elems); + for (size_t i = 0; i < num_elems; i++) { + scales[i] = src_scales[i]; + } + } + break; + } + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported QNN quantization encoding: ", + params_.quantizationEncoding); + } + + return Status::OK(); +} + QnnQuantParamsWrapper QnnQuantParamsWrapper::Copy() const { return QnnQuantParamsWrapper(*this); } @@ -199,7 +304,7 @@ Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, con params_.encodingDefinition = QNN_DEFINITION_DEFINED; params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET; - params_.bwAxisScaleOffsetEncoding.axis = static_cast(*(ort_quant_params->axis)); + params_.bwAxisScaleOffsetEncoding.axis = static_cast(axis); params_.bwAxisScaleOffsetEncoding.bitwidth = 4; params_.bwAxisScaleOffsetEncoding.numElements = static_cast(num_elems); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h index d1f93e5a692b..23330f5616d7 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h @@ -3,6 +3,7 @@ #pragma once #include +#include #include "QnnTypes.h" #include "core/common/common.h" #include @@ -26,6 +27,9 @@ class QnnQuantParamsWrapper { // Construct a per-tensor quantization param (SCALE_OFFSET) QnnQuantParamsWrapper(float scale, int32_t offset); + // Construct a per-channel quantization param. + QnnQuantParamsWrapper(gsl::span scales, gsl::span offsets, int32_t axis, bool is_int4); + Qnn_QuantizeParams_t& Get() { return params_; } const Qnn_QuantizeParams_t& Get() const { return params_; } @@ -54,6 +58,9 @@ class QnnQuantParamsWrapper { (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET)); } + // Get a copy of scales. Works for both per-tensor and per-channel. + Status GetScales(/*out*/ std::vector& scales) const; + // Handle transposing of a per-channel quantized tensor. The quantization parameter's axis // must be transposed using the inverse permutation of the Transpose. template diff --git a/onnxruntime/test/providers/qnn/conv_test.cc b/onnxruntime/test/providers/qnn/conv_test.cc index b07951d2a2e6..99636976b9c0 100644 --- a/onnxruntime/test/providers/qnn/conv_test.cc +++ b/onnxruntime/test/providers/qnn/conv_test.cc @@ -178,10 +178,14 @@ static GetTestQDQModelFn BuildQDQPerChannelConvTestCase(const s ORT_ENFORCE(weights_def.IsInitializer() && weights_def.IsRawData()); std::vector weight_scales; std::vector weight_zero_points; + TensorShape weights_shape = weights_def.GetTensorShape(); + int64_t pos_weight_quant_axis = weight_quant_axis; + if (pos_weight_quant_axis < 0) { + pos_weight_quant_axis += static_cast(weights_shape.NumDimensions()); + } GetTestInputQuantParamsPerChannel(weights_def, weight_scales, weight_zero_points, - static_cast(weight_quant_axis), true); + static_cast(pos_weight_quant_axis), true); - TensorShape weights_shape = weights_def.GetTensorShape(); std::vector quantized_weights; size_t num_weight_storage_elems = weights_shape.Size(); if constexpr (std::is_same_v || std::is_same_v) { @@ -189,7 +193,7 @@ static GetTestQDQModelFn BuildQDQPerChannelConvTestCase(const s } quantized_weights.resize(num_weight_storage_elems); QuantizeValues(weights_def.GetRawData(), quantized_weights, weights_shape, - weight_scales, weight_zero_points, weight_quant_axis); + weight_scales, weight_zero_points, pos_weight_quant_axis); NodeArg* weights_initializer = builder.MakeInitializer(weights_def.GetShape(), quantized_weights); NodeArg* weights_dq = builder.MakeIntermediate(); @@ -760,6 +764,34 @@ TEST_F(QnnHTPBackendTests, ConvU16S4S32_PerChannel) { 21); // opset } +// Test per-channel QDQ Conv with INT4 weights and a negative weight quantization axis that still points to dimension 0. +TEST_F(QnnHTPBackendTests, ConvU16S4S32_PerChannel_NegativeWeightQuantAxis) { + std::vector input_shape = {1, 2, 4, 4}; + std::vector weight_shape = {3, 2, 2, 2}; + std::vector bias_shape = {3}; + + TestInputDef input_def(input_shape, false, + GetFloatDataInRange(0.0f, 1.0f, TensorShape(input_shape).Size())); + TestInputDef weight_def(weight_shape, true, + GetFloatDataInRange(-1.0f, 5.0f, TensorShape(weight_shape).Size())); + TestInputDef bias_def(bias_shape, true, + GetFloatDataInRange(-1.0f, 1.0f, TensorShape(bias_shape).Size())); + + RunHTPConvOpPerChannelTest("Conv", + input_def, + weight_def, + bias_def, + -4, // negative weight quant axis (same as 0) + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 21); // opset +} + // Test per-channel QDQ Conv with INT4 weights. in0: u16, in1 (weight): s4, in2 (bias): s32, out: u8 // TODO(adrianlizarraga): Investigate inaccuracy for QNN EP. // diff --git a/onnxruntime/test/providers/qnn/layer_norm_test.cc b/onnxruntime/test/providers/qnn/layer_norm_test.cc index 7d129dceca58..2af49a5e500d 100644 --- a/onnxruntime/test/providers/qnn/layer_norm_test.cc +++ b/onnxruntime/test/providers/qnn/layer_norm_test.cc @@ -79,25 +79,53 @@ TEST_F(QnnCPUBackendTests, LayerNorm3D) { template GetTestQDQModelFn BuildQDQLayerNormTestCase(const TestInputDef& input_def, const TestInputDef& scale_def, + const TestInputDef& bias_def, const std::vector& attrs, bool use_contrib_qdq_ops) { - return [input_def, scale_def, attrs, use_contrib_qdq_ops](ModelTestBuilder& builder, - std::vector>& output_qparams) { + return [input_def, scale_def, bias_def, attrs, + use_contrib_qdq_ops](ModelTestBuilder& builder, + std::vector>& output_qparams) { + std::vector layer_norm_inputs; + // input -> Q -> DQ -> NodeArg* input = MakeTestInput(builder, input_def); QuantParams input_qparams = GetTestInputQuantParams(input_def); NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, use_contrib_qdq_ops); + layer_norm_inputs.push_back(input_qdq); - // scale input -> Q -> DQ -> - NodeArg* scale = MakeTestInput(builder, scale_def); + NodeArg* scale_qdq = nullptr; QuantParams scale_qparams = GetTestInputQuantParams(scale_def); - NodeArg* scale_qdq = AddQDQNodePair(builder, scale, scale_qparams.scale, scale_qparams.zero_point, - use_contrib_qdq_ops); + + if (scale_def.IsInitializer() && scale_def.IsRawData()) { + // Quantized(scale weights) -> DQ -> + std::vector scale_scales = {scale_qparams.scale}; + std::vector scale_zps = {scale_qparams.zero_point}; + TensorShape scale_shape = scale_def.GetTensorShape(); + std::vector quantized_scales(scale_shape.Size()); + QuantizeValues(scale_def.GetRawData(), quantized_scales, scale_shape, + scale_scales, scale_zps, std::nullopt); + + NodeArg* scale_initzer = builder.MakeInitializer(scale_def.GetShape(), quantized_scales); + scale_qdq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(scale_initzer, scale_scales, scale_zps, scale_qdq, + nullptr, use_contrib_qdq_ops); + } else { + // scale input -> Q -> DQ -> + NodeArg* scale = MakeTestInput(builder, scale_def); + scale_qdq = AddQDQNodePair(builder, scale, scale_qparams.scale, scale_qparams.zero_point, + use_contrib_qdq_ops); + } + layer_norm_inputs.push_back(scale_qdq); + + if (!bias_def.GetShape().empty()) { + const float bias_scale = input_qparams.scale * scale_qparams.scale; + layer_norm_inputs.push_back(MakeTestQDQBiasInput(builder, bias_def, bias_scale, use_contrib_qdq_ops)); + } // LayerNormalization NodeArg* layer_norm_output = builder.MakeIntermediate(); - Node& layer_norm_node = builder.AddNode("LayerNormalization", {input_qdq, scale_qdq}, {layer_norm_output}); + Node& layer_norm_node = builder.AddNode("LayerNormalization", layer_norm_inputs, {layer_norm_output}); for (const auto& attr : attrs) { layer_norm_node.AddAttributeProto(attr); @@ -114,6 +142,7 @@ GetTestQDQModelFn BuildQDQLayerNormTestCase(const TestInputDef static void RunLayerNormQDQTest(const TestInputDef& input_def, const TestInputDef& scale_def, + const TestInputDef& bias_def, const std::vector& attrs, ExpectedEPNodeAssignment expected_ep_assignment, bool use_contrib_qdq_ops = false) { @@ -125,7 +154,7 @@ static void RunLayerNormQDQTest(const TestInputDef& input_def, #endif TestQDQModelAccuracy(BuildOpTestCase("LayerNormalization", {input_def, scale_def}, {}, attrs), - BuildQDQLayerNormTestCase(input_def, scale_def, attrs, + BuildQDQLayerNormTestCase(input_def, scale_def, bias_def, attrs, use_contrib_qdq_ops), provider_options, 17, // opset @@ -136,6 +165,7 @@ static void RunLayerNormQDQTest(const TestInputDef& input_def, TEST_F(QnnHTPBackendTests, LayerNorm1D_Axis0_Unsupported) { RunLayerNormQDQTest(TestInputDef({1, 2, 3}, false, 0.0f, 10.0f), TestInputDef({1, 2, 3}, true, 0.0f, 10.0f), + TestInputDef(), {utils::MakeAttribute("axis", static_cast(0))}, // Unsupported axis ExpectedEPNodeAssignment::None); } @@ -143,16 +173,40 @@ TEST_F(QnnHTPBackendTests, LayerNorm1D_Axis0_Unsupported) { // Test accuracy of 8-bit QDQ LayerNorm with a static scale input. TEST_F(QnnHTPBackendTests, LayerNorm1D_LastAxis_StaticScale_AU8_WU8) { RunLayerNormQDQTest(TestInputDef({1, 2, 3}, false, GetFloatDataInRange(0.0f, 10.0f, 6)), - TestInputDef({3}, true, GetFloatDataInRange(0.0f, 1.0f, 3)), // Static - {utils::MakeAttribute("axis", static_cast(-1))}, // Last axis + TestInputDef({3}, true, GetFloatDataInRange(0.0f, 1.0f, 3)), + TestInputDef(), // Implicit bias input + {utils::MakeAttribute("axis", static_cast(-1))}, ExpectedEPNodeAssignment::All); } +// Test accuracy of 8-bit QDQ LayerNorm with a static scale input and an explicit bias input (static). +TEST_F(QnnHTPBackendTests, LayerNorm1D_LastAxis_StaticScale_StaticBias_AU8_WU8_BU8) { + RunLayerNormQDQTest(TestInputDef({1, 2, 3}, false, GetFloatDataInRange(0.0f, 10.0f, 6)), + TestInputDef({3}, true, GetFloatDataInRange(0.0f, 1.0f, 3)), + TestInputDef({3}, true, GetFloatDataInRange(0.0f, 1.0f, 3)), + {utils::MakeAttribute("axis", static_cast(-1))}, + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LayerNorm1D_QNN2_24_ImplicitBias_ValidationBug) { + // QNN 2.24 LayerNorm fails validation (intermittent) if the bias input is not provided. QNN EP will provide an + // explicit bias of all zeros to get around this bug. + for (size_t i = 0; i < 15; i++) { // Run it multiple times since this is an intermittent bug. + RunLayerNormQDQTest(TestInputDef({1, 2, 3}, false, GetFloatDataInRange(0.0f, 1.0f, 6)), + TestInputDef({3}, true, GetFloatDataInRange(0.0f, 1.0f, 3)), + TestInputDef(), // Implicit bias input + {utils::MakeAttribute("axis", static_cast(-1))}, + ExpectedEPNodeAssignment::All, + true); + } +} + // Test accuracy of 16-bit QDQ LayerNorm with a static scale input. TEST_F(QnnHTPBackendTests, LayerNorm1D_LastAxis_StaticScale_AU16_WU8) { RunLayerNormQDQTest(TestInputDef({1, 2, 3}, false, GetFloatDataInRange(0.0f, 10.0f, 6)), TestInputDef({3}, true, GetFloatDataInRange(0.0f, 1.0f, 3)), // Static - {utils::MakeAttribute("axis", static_cast(-1))}, // Last axis + TestInputDef(), + {utils::MakeAttribute("axis", static_cast(-1))}, // Last axis ExpectedEPNodeAssignment::All, true); // Use 'com.microsoft' Q/DQ ops } @@ -174,7 +228,8 @@ TEST_F(QnnHTPBackendTests, LayerNorm1D_LastAxis_StaticScale_AU16_WU8) { TEST_F(QnnHTPBackendTests, DISABLED_LayerNorm1D_LastAxis_DynamicScale) { RunLayerNormQDQTest(TestInputDef({1, 2, 3}, false, GetFloatDataInRange(0.0f, 10.0f, 6)), TestInputDef({3}, false, GetFloatDataInRange(0.0f, 1.0f, 3)), // Dynamic - {utils::MakeAttribute("axis", static_cast(-1))}, // Last axis + TestInputDef(), + {utils::MakeAttribute("axis", static_cast(-1))}, // Last axis ExpectedEPNodeAssignment::All); } diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index ad54e644af3f..eb03270dc846 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -517,6 +517,9 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe ASSERT_STATUS_OK(f32_model.MainGraph().Resolve()); f32_model.ToProto().SerializeToString(&f32_model_data); + // Uncomment to save f32 model to disk for debugging. + // ASSERT_STATUS_OK(onnxruntime::Model::Save(f32_model, ToPathString("cmp_accuracy.f32.onnx"))); + // Run f32 model on CPU EP and collect outputs. std::vector cpu_f32_outputs; InferenceModel(f32_model_data, "f32_model_logger", {}, ExpectedEPNodeAssignment::All, @@ -556,6 +559,9 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe ASSERT_STATUS_OK(qdq_model.MainGraph().Resolve()); qdq_model.ToProto().SerializeToString(&qdq_model_data); + // Uncomment to save QDQ model to disk for debugging. + // ASSERT_STATUS_OK(onnxruntime::Model::Save(qdq_model, ToPathString("cmp_accuracy.qdq.onnx"))); + bool is_qnn_ep = true; TryEnableQNNSaver(qnn_options); std::vector qnn_qdq_outputs; diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index f7dc5779ec5d..2ebc2c6251b4 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -304,7 +304,7 @@ TEST_F(QnnHTPBackendTests, DISABLE_UnaryOp_Elu_U16) { // Expected val: 0 // QNN QDQ val: -10 (err 10) // CPU QDQ val: 0 (err 0) -TEST_F(QnnHTPBackendTests, DISABLED_UnaryOp_Relu) { +TEST_F(QnnHTPBackendTests, UnaryOp_Relu) { RunQDQOpTest("Relu", {TestInputDef({1, 2, 3}, false, GetFloatDataInRange(-10.0f, 10.0f, 6))}, {}, diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index a4a3d0e6b334..6649206c0d79 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -31,7 +31,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.23.0.240531 + default: 2.24.0.240626 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 700326fe9173..2eb7046d80e7 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -62,7 +62,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.23.0.240531 + default: 2.24.0.240626 resources: repositories: diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index 29ebf67dd3f9..0d67b0947be5 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.23.0.240531 + default: 2.24.0.240626 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index 8d1b6b7854e5..cd3966633d74 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.23.0.240531 + default: 2.24.0.240626 trigger: none diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index a8b12637b70f..7229bc5dbd11 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,7 +2,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.23.0.240531 + default: 2.24.0.240626 - name: build_config displayName: Build Configuration diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml index ada3603ae847..734ad43e0066 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.23.0.240531' + default: '2.24.0.240626' steps: - script: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml index 3a68803896ab..900adc969025 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.23.0.240531' + default: '2.24.0.240626' steps: - powershell: | diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 17e64a207be2..447e35244eb6 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -63,7 +63,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.23.0.240531 + default: 2.24.0.240626 stages: - ${{ if eq(parameters.enable_windows_cpu, true) }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index 70221976d978..40e8583141df 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.23.0.240531 + default: 2.24.0.240626 - name: PYTHON_VERSION type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index 1bf5db5ae6d9..33335bb2be2d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.23.0.240531 + default: 2.24.0.240626 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index b4c4f36c5dcc..944745b69ca6 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -1,5 +1,5 @@ parameters: - QnnSdk: '2.23.0.240531' + QnnSdk: '2.24.0.240626' build_config: 'RelWithDebInfo' IsReleaseBuild: false DoEsrp: false @@ -103,7 +103,7 @@ stages: - task: MSBuild@1 displayName: 'Restore NuGet Packages and create project.assets.json' inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' platform: 'Any CPU' configuration: ${{ parameters.build_config }} msbuildArguments: '-t:restore -p:OrtPackageId=$(OrtPackageId)' @@ -112,7 +112,7 @@ stages: - task: MSBuild@1 displayName: 'Build C# bindings' inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' platform: 'Any CPU' configuration: ${{ parameters.build_config }} msbuildArguments: '-p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }}' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 97745fd09fbf..e1b8b718e992 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.23.0.240531 + default: 2.24.0.240626 jobs: - job: 'build' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index 2ab81e16cd57..97c4ab15095c 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.23.0.240531 + default: 2.24.0.240626 jobs: - job: 'build'