From 0a2f68339e55b8ab691d1417c5ca03fa40e8ba2c Mon Sep 17 00:00:00 2001 From: yujun Date: Fri, 17 Oct 2025 11:35:18 +0800 Subject: [PATCH 1/3] fold nullif --- .../rules/FoldConstantRuleOnFE.java | 48 +++++++++++++------ .../expressions/ExpressionEvaluator.java | 3 +- .../functions/agg/AggregateFunction.java | 5 ++ .../generator/TableGeneratingFunction.java | 5 ++ .../rules/expression/FoldConstantTest.java | 30 +++++++++++- 5 files changed, 74 insertions(+), 17 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java index d53976d8ae23cf..6c2e3c44a955f5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java @@ -74,6 +74,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.EncryptKeyRef; import org.apache.doris.nereids.trees.expressions.functions.scalar.If; import org.apache.doris.nereids.trees.expressions.functions.scalar.LastQueryId; +import org.apache.doris.nereids.trees.expressions.functions.scalar.NullIf; import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl; import org.apache.doris.nereids.trees.expressions.functions.scalar.Password; import org.apache.doris.nereids.trees.expressions.functions.scalar.SessionUser; @@ -186,6 +187,7 @@ public List> buildRules() { matches(SessionUser.class, this::visitSessionUser), matches(LastQueryId.class, this::visitLastQueryId), matches(Nvl.class, this::visitNvl), + matches(NullIf.class, this::visitNullIf), matches(Match.class, this::visitMatch) ); } @@ -551,9 +553,6 @@ public Expression visitTryCast(TryCast cast, ExpressionRewriteContext context) { @Override public Expression visitBoundFunction(BoundFunction boundFunction, ExpressionRewriteContext context) { - if (!boundFunction.foldable()) { - return boundFunction; - } boundFunction = rewriteChildren(boundFunction, context); Optional checkedExpr = preProcess(boundFunction); if (checkedExpr.isPresent()) { @@ -735,18 +734,39 @@ public Expression visitVersion(Version version, ExpressionRewriteContext context public Expression visitNvl(Nvl nvl, ExpressionRewriteContext context) { Nvl originNvl = nvl; nvl = rewriteChildren(nvl, context); - - for (Expression expr : nvl.children()) { - if (expr.isLiteral()) { - if (!expr.isNullLiteral()) { - return TypeCoercionUtils.ensureSameResultType(originNvl, expr, context); - } - } else { - return TypeCoercionUtils.ensureSameResultType(originNvl, nvl, context); + Expression first = nvl.left(); + Expression second = nvl.right(); + Expression result = nvl; + if (first.equals(second) || second.isNullLiteral() || (first.isLiteral() && !first.isNullLiteral())) { + result = first; + } else if (first.isNullLiteral()) { + result = second; + } + return TypeCoercionUtils.ensureSameResultType(originNvl, result, context); + } + + @Override + public Expression visitNullIf(NullIf nullIf, ExpressionRewriteContext context) { + NullIf originNullIf = nullIf; + nullIf = rewriteChildren(nullIf, context); + Expression first = nullIf.left(); + Expression second = nullIf.right(); + Expression result = nullIf; + // if first is null, then first = second will be null + if (first.isNullLiteral() || second.isNullLiteral()) { + result = first; + } else if (first.equals(second)) { + // even if first is null, then first = second will be null, then result is first, so the result is also null + result = new NullLiteral(originNullIf.getDataType()); + } else if (first.isLiteral() && second.isLiteral()) { + Expression isEqual = visitEqualTo(new EqualTo(first, second), context); + if (isEqual.equals(BooleanLiteral.TRUE)) { + result = new NullLiteral(originNullIf.getDataType()); + } else if (isEqual.equals(BooleanLiteral.FALSE) || isEqual.isNullLiteral()) { + result = first; } } - // all nulls - return TypeCoercionUtils.ensureSameResultType(originNvl, nvl.child(0), context); + return TypeCoercionUtils.ensureSameResultType(originNullIf, result, context); } private E rewriteChildren(E expr, ExpressionRewriteContext context) { @@ -787,7 +807,7 @@ private E rewriteChildren(E expr, ExpressionRewriteContex } private Optional preProcess(Expression expression) { - if (expression instanceof AggregateFunction || expression instanceof TableGeneratingFunction) { + if (!expression.foldable()) { return Optional.of(expression); } if (ExpressionUtils.hasNullLiteral(expression.getArguments()) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java index 0b77118eb1835f..96bdb3ba257314 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java @@ -19,7 +19,6 @@ import org.apache.doris.nereids.exceptions.NotSupportedException; import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; -import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeAcquire; import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeArithmetic; import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeExtractAndTransform; @@ -56,7 +55,7 @@ public enum ExpressionEvaluator { * Evaluate the value of the expression. */ public Expression eval(Expression expression) { - if (!(expression.isConstant() || expression.foldable()) || expression instanceof AggregateFunction) { + if (!(expression.isConstant() || expression.foldable())) { return expression; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java index 41ed2bbb2c68f6..e356811fdefc47 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java @@ -122,6 +122,11 @@ public int computeHashCode() { return Objects.hash(distinct, getName(), children); } + @Override + public boolean foldable() { + return false; + } + @Override public R accept(ExpressionVisitor visitor, C context) { return visitor.visitAggregateFunction(this, context); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/generator/TableGeneratingFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/generator/TableGeneratingFunction.java index b67f7c1df623c5..b5b03002bd1be6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/generator/TableGeneratingFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/generator/TableGeneratingFunction.java @@ -52,4 +52,9 @@ public R accept(ExpressionVisitor visitor, C context) { protected GeneratorFunctionParams getFunctionParams(List arguments) { return new GeneratorFunctionParams(this, getName(), arguments, isInferred()); } + + @Override + public boolean foldable() { + return false; + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java index f9d4782c685008..5665ff0435501f 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java @@ -1482,7 +1482,8 @@ void testFoldNvl() { assertRewriteExpression("nvl(NULL, 1)", "1"); assertRewriteExpression("nvl(NULL, NULL)", "NULL"); - assertRewriteAfterTypeCoercion("nvl(IA, NULL)", "ifnull(IA, NULL)"); + assertRewriteAfterTypeCoercion("nvl(IA, NULL)", "IA"); + assertRewriteAfterTypeCoercion("nvl(IA, IA)", "IA"); assertRewriteAfterTypeCoercion("nvl(IA, 1)", "ifnull(IA, 1)"); Expression foldNvl = executor.rewrite( @@ -1492,6 +1493,33 @@ void testFoldNvl() { Assertions.assertEquals(new DateTimeV2Literal(DateTimeV2Type.of(6), "2025-04-17"), foldNvl); } + @Test + void testFoldNullIf() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp( + FoldConstantRule.INSTANCE + ) + )); + assertRewriteAfterTypeCoercion("nullif(a, b)", "nullif(a, b)"); + assertRewriteAfterTypeCoercion("nullif(a, a)", "null"); + assertRewriteAfterTypeCoercion("nullif(a, null)", "a"); + assertRewriteAfterTypeCoercion("nullif(null, a)", "null"); + assertRewriteAfterTypeCoercion("nullif(1, 1)", "null"); + assertRewriteAfterTypeCoercion("nullif(1, 2)", "1"); + } + + @Test + void testNonFoldable() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp( + FoldConstantRule.INSTANCE + ) + )); + assertRewriteAfterTypeCoercion("random(0, 1)", "random(0, 1)"); + assertRewriteAfterTypeCoercion("sum(1 + 2)", "sum(3)"); + assertRewriteAfterTypeCoercion("explode([1, 2, 3])", "explode([1, 2, 3])"); + } + private void assertRewriteExpression(String actualExpression, String expectedExpression) { ExpressionRewriteContext context = new ExpressionRewriteContext( MemoTestUtils.createCascadesContext(new UnboundRelation(new RelationId(1), ImmutableList.of("test_table")))); From 03a2ff6df4adcb41f34a647f8e317b09fd1d6d3a Mon Sep 17 00:00:00 2001 From: yujun Date: Tue, 14 Oct 2025 12:29:21 +0800 Subject: [PATCH 2/3] fix check style --- .../nereids/rules/expression/rules/FoldConstantRuleOnFE.java | 1 - 1 file changed, 1 deletion(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java index 6c2e3c44a955f5..1560ab55370e64 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java @@ -64,7 +64,6 @@ import org.apache.doris.nereids.trees.expressions.functions.PropagateNullLiteral; import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; -import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction; import org.apache.doris.nereids.trees.expressions.functions.scalar.Array; import org.apache.doris.nereids.trees.expressions.functions.scalar.ConnectionId; import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentCatalog; From b7e19fcdfa34adf6e1c6ff4830ae3846e7b8fac4 Mon Sep 17 00:00:00 2001 From: yujun Date: Fri, 17 Oct 2025 12:50:08 +0800 Subject: [PATCH 3/3] fix test --- .../jdbc/test_mysql_jdbc_catalog.groovy | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/regression-test/suites/external_table_p0/jdbc/test_mysql_jdbc_catalog.groovy b/regression-test/suites/external_table_p0/jdbc/test_mysql_jdbc_catalog.groovy index 75915308fba36e..0b4a8dfb8efc09 100644 --- a/regression-test/suites/external_table_p0/jdbc/test_mysql_jdbc_catalog.groovy +++ b/regression-test/suites/external_table_p0/jdbc/test_mysql_jdbc_catalog.groovy @@ -394,18 +394,18 @@ suite("test_mysql_jdbc_catalog", "p0,external,mysql,external_docker,external_doc contains "QUERY: SELECT `timestamp0` FROM `doris_test`.`dt` WHERE (`timestamp0` > '2022-01-01 00:00:00')" } explain { - sql ("select k6, k8 from test1 where nvl(k6, null) = 1;") + sql ("select k6, k8 from test1 where nvl(k6, 1) = 1;") - contains "QUERY: SELECT `k6`, `k8` FROM `doris_test`.`test1` WHERE ((ifnull(`k6`, NULL) = 1))" + contains "QUERY: SELECT `k6`, `k8` FROM `doris_test`.`test1` WHERE ((ifnull(`k6`, 1) = 1))" } explain { - sql ("select k6, k8 from test1 where nvl(nvl(k6, null),null) = 1;") + sql ("select k6, k8 from test1 where nvl(nvl(k6, 1), 1) = 1;") - contains "QUERY: SELECT `k6`, `k8` FROM `doris_test`.`test1` WHERE ((ifnull(ifnull(`k6`, NULL), NULL) = 1))" + contains "QUERY: SELECT `k6`, `k8` FROM `doris_test`.`test1` WHERE ((ifnull(`k6`, 1) = 1))" } sql """ set enable_ext_func_pred_pushdown = "false"; """ explain { - sql ("select k6, k8 from test1 where nvl(k6, null) = 1 and k8 = 1;") + sql ("select k6, k8 from test1 where nvl(k6, 1) = 1 and k8 = 1;") contains "QUERY: SELECT `k6`, `k8` FROM `doris_test`.`test1` WHERE ((`k8` = 1))" }