diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc
index 813bba8a5952b..26d4db845705f 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc
@@ -3,6 +3,7 @@
 #include
 #include
 #include
+#include <cmath>
 #include
 #include
 #include
@@ -17,6 +18,50 @@
 namespace onnxruntime {
 namespace qnn {
 
+namespace {
+
+// 1. Clip can be removed iff the codomain of QuantizeLinear remains unchanged.
+// 2. To remain unchanged, y = QuantizeLinear(Clip(x)) must span the full range of values that can be represented by
+//    the integer type of y. We can use this precondition to evaluate QuantizeLinear backward to get its domain.
+// 3. It follows that the domain of QuantizeLinear must be a subset of the codomain of Clip. We can use this to test
+//    whether a removal is valid or not.
+// 4. Due to rounding, we can only get an upper bound or a lower bound of the domain's min and max.
+//    - Which one to use?
+//      - The upper bound of min and the lower bound of max.
+//    - Why?
+//      - We want the codomain to be unchanged.
+
+
+// The quantization formula is y = saturate(round(x / y_scale) + y_zero_point).
+// round(x / y_scale) rounds half to even, so the allowed quantization limits before saturation need to be taken
+// care of.
+//
+// The following struct computes the domain from the codomain (the Q output dtype).
+template <typename T>
+struct quantize_domain {
+  static float min_upper(float scale, int zp) {
+    int64_t codomain_min = std::numeric_limits<T>::lowest();
+    int64_t biased = codomain_min - zp;
+    float before_round = float(biased) + 0.5f;  // move to the upper bound
+    if (biased % 2 != 0) {  // cannot be an exact ?.5: it would round away to the even neighbor
+      before_round = std::nextafterf(before_round, float(biased));
+    }
+    return before_round * scale;
+  }
+
+  static float max_lower(float scale, int zp) {
+    int64_t codomain_max = std::numeric_limits<T>::max();
+    int64_t biased = codomain_max - zp;
+    float before_round = float(biased) - 0.5f;  // move to the lower bound
+    if (biased % 2 != 0) {  // cannot be an exact ?.5: it would round away to the even neighbor
+      before_round = std::nextafterf(before_round, float(biased));
+    }
+    return before_round * scale;
+  }
+};
+
+}  // namespace
+
 // Gets the scale, zero-point, and zero-point type for a QuantizeLinear node that uses per-tensor quantization.
 static bool GetQScalarScaleZeroPoint(const QnnModelWrapper& qnn_model_wrapper,
                                      const NodeUnit& q_node_unit,
@@ -53,7 +98,7 @@ static bool GetQScalarScaleZeroPoint(const QnnModelWrapper& qnn_model_wrapper,
   return true;
 }
 
-// Computes the floating point range (rmin, rmax) from a QuantizeLinear node's scale/zero-point.
+// Computes the floating point range (aka domain) (rmin, rmax) from a QuantizeLinear node's scale/zero-point.
 static bool GetQRminRmax(const QnnModelWrapper& qnn_model_wrapper,
                          const NodeUnit& q_node_unit,
                          /*out*/ float& rmin,
@@ -68,23 +113,23 @@ static bool GetQRminRmax(const QnnModelWrapper& qnn_model_wrapper,
 
   switch (zp_data_type) {
     case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
-      rmin = scale * (std::numeric_limits<int8_t>::lowest() - zero_point);
-      rmax = scale * (std::numeric_limits<int8_t>::max() - zero_point);
+      rmin = quantize_domain<int8_t>::min_upper(scale, zero_point);
+      rmax = quantize_domain<int8_t>::max_lower(scale, zero_point);
       break;
     }
     case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
-      rmin = scale * (std::numeric_limits<uint8_t>::lowest() - zero_point);
-      rmax = scale * (std::numeric_limits<uint8_t>::max() - zero_point);
+      rmin = quantize_domain<uint8_t>::min_upper(scale, zero_point);
+      rmax = quantize_domain<uint8_t>::max_lower(scale, zero_point);
       break;
     }
     case ONNX_NAMESPACE::TensorProto_DataType_INT16: {
-      rmin = scale * (std::numeric_limits<int16_t>::lowest() - zero_point);
-      rmax = scale * (std::numeric_limits<int16_t>::max() - zero_point);
+      rmin = quantize_domain<int16_t>::min_upper(scale, zero_point);
+      rmax = quantize_domain<int16_t>::max_lower(scale, zero_point);
       break;
     }
     case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
-      rmin = scale * (std::numeric_limits<uint16_t>::lowest() - zero_point);
-      rmax = scale * (std::numeric_limits<uint16_t>::max() - zero_point);
+      rmin = quantize_domain<uint16_t>::min_upper(scale, zero_point);
+      rmax = quantize_domain<uint16_t>::max_lower(scale, zero_point);
       break;
     }
     default:
@@ -115,15 +160,14 @@ static bool CanClipBeRemoved(const QnnModelWrapper& qnn_model_wrapper,
     return false;
   }
 
-  // The clip range must entirely overlap the quantization range (quantization can be smaller).
-  // Clip range:  [------------------]
-  // Quant range:   [-------------]
-  constexpr float epsilon = std::numeric_limits<float>::epsilon();
-  if ((epsilon < clip_min - rmin) || (epsilon < rmax - clip_max)) {
-    return false;
+  // The Clip codomain must entirely overlap the QuantizeLinear domain (quantization can be smaller).
+  // Clip codomain:          [------------------]
+  // QuantizeLinear domain:    [-------------]
+  if (clip_min <= rmin && rmax <= clip_max) {
+    return true;
   }
 
-  return true;
+  return false;
 }
 
 // Returns true if the Relu in the sequence (Relu -> Q) can be removed because it is made redundant by the Q.
@@ -137,16 +181,15 @@ static bool CanQRelaceRelu(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& q_node_unit) {
     return false;
   }
 
-  // Relu is redundant if the zero-point is set to the smallest quantized value.
   switch (zp_data_type) {
     case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_INT8:
-      return zero_point == static_cast<int32_t>(std::numeric_limits<int8_t>::lowest());
+      return quantize_domain<int8_t>::min_upper(scale, zero_point) >= 0;
    case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UINT8:
-      return zero_point == static_cast<int32_t>(std::numeric_limits<uint8_t>::lowest());
+      return quantize_domain<uint8_t>::min_upper(scale, zero_point) >= 0;
    case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_INT16:
-      return zero_point == static_cast<int32_t>(std::numeric_limits<int16_t>::lowest());
+      return quantize_domain<int16_t>::min_upper(scale, zero_point) >= 0;
    case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UINT16:
-      return zero_point == static_cast<int32_t>(std::numeric_limits<uint16_t>::lowest());
+      return quantize_domain<uint16_t>::min_upper(scale, zero_point) >= 0;
    default:
      return false;
  }
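
A quick way to sanity-check the rounding-aware bounds in this patch is to brute-force them against a reference
round-half-to-even quantizer. The standalone sketch below is not part of the patch; the reference quantize() helper
and the example parameters (uint8 output, scale = 1.0, zero_point = 0) are illustrative assumptions. For that
configuration the old code computed the range [0, 255], while quantize_domain<uint8_t> yields rmin = 0.5 and
rmax ~= 254.500015, so for example a preceding Clip(0.5, 254.6) is now provably removable.

// verify_quantize_domain.cc -- standalone sketch, not part of the patch.
#include <cfenv>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

// Same math as quantize_domain in the patch.
template <typename T>
struct quantize_domain {
  static float min_upper(float scale, int zp) {
    int64_t biased = int64_t{std::numeric_limits<T>::lowest()} - zp;
    float before_round = float(biased) + 0.5f;
    if (biased % 2 != 0) before_round = std::nextafterf(before_round, float(biased));
    return before_round * scale;
  }
  static float max_lower(float scale, int zp) {
    int64_t biased = int64_t{std::numeric_limits<T>::max()} - zp;
    float before_round = float(biased) - 0.5f;
    if (biased % 2 != 0) before_round = std::nextafterf(before_round, float(biased));
    return before_round * scale;
  }
};

// Reference QuantizeLinear: y = saturate(round_half_to_even(x / scale) + zp).
static int64_t quantize(float x, float scale, int zp, int64_t lo, int64_t hi) {
  std::fesetround(FE_TONEAREST);                         // ties round to even
  int64_t y = (int64_t)std::nearbyintf(x / scale) + zp;  // nearbyint honors the rounding mode
  return y < lo ? lo : (y > hi ? hi : y);
}

int main() {
  const float scale = 1.0f;  // assumed example parameters
  const int zp = 0;
  float rmin = quantize_domain<uint8_t>::min_upper(scale, zp);
  float rmax = quantize_domain<uint8_t>::max_lower(scale, zp);
  std::printf("rmin = %.9g  rmax = %.9g\n", rmin, rmax);  // 0.5  254.500015

  // rmin still maps to 0, but one float ulp above it does not.
  std::printf("%d %d\n", (int)quantize(rmin, scale, zp, 0, 255),
              (int)quantize(std::nextafterf(rmin, 1e9f), scale, zp, 0, 255));  // 0 1
  // rmax still maps to 255, but one float ulp below it does not.
  std::printf("%d %d\n", (int)quantize(rmax, scale, zp, 0, 255),
              (int)quantize(std::nextafterf(rmax, -1e9f), scale, zp, 0, 255));  // 255 254
  return 0;
}

Note the parity check in action: for uint8 with zero_point = 0, max_lower sees biased = 255, which is odd, so 254.5
itself would round to the even neighbor 254; the true lower bound of the values that map to 255 is one float ulp
above 254.5, and that is exactly what max_lower returns.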