From adcc858adca89463a8d5c3af05cbc98ae8aa4f70 Mon Sep 17 00:00:00 2001 From: James Date: Mon, 8 Sep 2025 19:24:42 +0800 Subject: [PATCH] [fix](constant fold)Do not do BE constant fold when float/double is NaN (#55425) Do not do BE constant fold when float/double is NaN --- .../rules/FoldConstantRuleOnBE.java | 49 +++++++++++++------ .../fold_constant/fold_constant_by_be.groovy | 10 ++++ 2 files changed, 43 insertions(+), 16 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java index 6d95fb8667a4e0..2da29bcfe065ce 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java @@ -408,25 +408,23 @@ public static List getResultExpression(DataType type, PValues resultCon int num = resultContent.getFloatValueCount(); for (int i = 0; i < num; ++i) { float value = resultContent.getFloatValue(i); - Literal literal = null; if (Float.isNaN(value)) { - literal = new NullLiteral(type); + return Collections.emptyList(); } else { - literal = new FloatLiteral(value); + Literal literal = new FloatLiteral(value); + res.add(literal); } - res.add(literal); } } else if (type.isDoubleType()) { int num = resultContent.getDoubleValueCount(); for (int i = 0; i < num; ++i) { double value = resultContent.getDoubleValue(i); - Literal literal = null; if (Double.isNaN(value)) { - literal = new NullLiteral(type); + return Collections.emptyList(); } else { - literal = new DoubleLiteral(value); + Literal literal = new DoubleLiteral(value); + res.add(literal); } - res.add(literal); } } else if (type.isDecimalV2Type()) { int num = resultContent.getBytesValueCount(); @@ -505,8 +503,13 @@ public static List getResultExpression(DataType type, PValues resultCon int childCount = resultContent.getChildElementCount(); List allLiterals = new ArrayList<>(); for (int i = 0; i < childCount; ++i) { - allLiterals.addAll(getResultExpression(arrayType.getItemType(), - resultContent.getChildElement(i))); + List resultExpression = getResultExpression(arrayType.getItemType(), + resultContent.getChildElement(i)); + // If any child element couldn't fold, stop folding this Array. + if (resultExpression.isEmpty()) { + return Collections.emptyList(); + } + allLiterals.addAll(resultExpression); } int offsetCount = resultContent.getChildOffsetCount(); if (offsetCount == 1) { @@ -530,10 +533,19 @@ public static List getResultExpression(DataType type, PValues resultCon List allKeys = new ArrayList<>(); List allValues = new ArrayList<>(); for (int i = 0; i < childCount; i = i + 2) { - allKeys.addAll(getResultExpression(mapType.getKeyType(), - resultContent.getChildElement(i))); - allValues.addAll(getResultExpression(mapType.getValueType(), - resultContent.getChildElement(i + 1))); + // If any key or value couldn't fold, stop folding this Map. + List key = getResultExpression(mapType.getKeyType(), + resultContent.getChildElement(i)); + if (key.isEmpty()) { + return Collections.emptyList(); + } + allKeys.addAll(key); + List value = getResultExpression(mapType.getValueType(), + resultContent.getChildElement(i + 1)); + if (value.isEmpty()) { + return Collections.emptyList(); + } + allValues.addAll(value); } int offsetCount = resultContent.getChildOffsetCount(); if (offsetCount == 1) { @@ -558,8 +570,13 @@ public static List getResultExpression(DataType type, PValues resultCon int childCount = resultContent.getChildElementCount(); List> allFields = new ArrayList<>(); for (int i = 0; i < childCount; ++i) { - allFields.add(getResultExpression(structType.getFields().get(i).getDataType(), - resultContent.getChildElement(i))); + List resultExpression = getResultExpression(structType.getFields().get(i).getDataType(), + resultContent.getChildElement(i)); + // If any field couldn't fold, stop folding this Struct. + if (resultExpression.isEmpty()) { + return Collections.emptyList(); + } + allFields.add(resultExpression); } for (int i = 0; i < allFields.get(0).size(); ++i) { List fields = new ArrayList<>(); diff --git a/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_by_be.groovy b/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_by_be.groovy index 9b1a902b5ec127..6c9a6d14e521af 100644 --- a/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_by_be.groovy +++ b/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_by_be.groovy @@ -96,4 +96,14 @@ suite("fold_constant_by_be") { sql "select IS_IPV4_MAPPED(NULLABLE(ipv6_string_to_num_or_default('192.168.1.1')));" contains "192.168.1.1" } + explain { + sql "select cosine_distance([0], [0]);" + contains "cosine_distance" + notContains("NULL") + } + + explain { + sql "select array(cosine_distance([1], [1]), cast(\"NaN\" as float));" + contains "array(cosine_distance" + } }