diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java index 2b5b9f62fe913d..4b6e9b6e09b1ef 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java @@ -49,6 +49,7 @@ import org.apache.doris.nereids.trees.expressions.CaseWhen; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; +import org.apache.doris.nereids.trees.expressions.CompoundPredicate; import org.apache.doris.nereids.trees.expressions.Divide; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.ExprId; @@ -120,13 +121,19 @@ /** ExpressionAnalyzer */ public class ExpressionAnalyzer extends SubExprAnalyzer { + // This rule only used in unit test @VisibleForTesting public static final AbstractExpressionRewriteRule FUNCTION_ANALYZER_RULE = new AbstractExpressionRewriteRule() { @Override public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { - return new ExpressionAnalyzer( - null, new Scope(ImmutableList.of()), null, false, false - ).analyze(expr, ctx); + return new ExpressionAnalyzer(null, new Scope(ImmutableList.of()), null, false, false) { + @Override + protected Expression processCompoundNewChildren(CompoundPredicate cp, List newChildren) { + // ExpressionUtils.and/ExpressionUtils.or will remove duplicate children, and simplify FALSE / TRUE. + // But we don't want to simplify them in unit test. + return cp.withChildren(newChildren); + } + }.analyze(expr, ctx); } }; @@ -594,7 +601,7 @@ public Expression visitOr(Or or, ExpressionRewriteContext context) { newChildren.add(newChild); } if (hasNewChild) { - return ExpressionUtils.or(newChildren); + return processCompoundNewChildren(or, newChildren); } else { return or; } @@ -616,18 +623,26 @@ public Expression visitAnd(And and, ExpressionRewriteContext context) { newChild = TypeCoercionUtils.castIfNotSameType(newChild, BooleanType.INSTANCE); } - if (! child.equals(newChild)) { + if (!child.equals(newChild)) { hasNewChild = true; } newChildren.add(newChild); } if (hasNewChild) { - return ExpressionUtils.and(newChildren); + return processCompoundNewChildren(and, newChildren); } else { return and; } } + protected Expression processCompoundNewChildren(CompoundPredicate cp, List newChildren) { + if (cp instanceof And) { + return ExpressionUtils.and(newChildren); + } else { + return ExpressionUtils.or(newChildren); + } + } + @Override public Expression visitNot(Not not, ExpressionRewriteContext context) { // maybe is `not subquery`, we should bind it first diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java index a7e6f8055167c9..7c751700e98acf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java @@ -25,6 +25,7 @@ import org.apache.doris.nereids.rules.expression.rules.DistinctPredicatesRule; import org.apache.doris.nereids.rules.expression.rules.ExtractCommonFactorRule; import org.apache.doris.nereids.rules.expression.rules.LikeToEqualRewrite; +import org.apache.doris.nereids.rules.expression.rules.NestedCaseWhenCondToLiteral; import org.apache.doris.nereids.rules.expression.rules.NullSafeEqualToEqual; import org.apache.doris.nereids.rules.expression.rules.ReplaceNullWithFalseForCond; import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate; @@ -59,6 +60,7 @@ public class ExpressionOptimization extends ExpressionRewrite { DateFunctionRewrite.INSTANCE, ArrayContainToArrayOverlap.INSTANCE, ReplaceNullWithFalseForCond.INSTANCE, + NestedCaseWhenCondToLiteral.INSTANCE, CaseWhenToIf.INSTANCE, TopnToMax.INSTANCE, NullSafeEqualToEqual.INSTANCE, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java index 1cb43a3113dbc4..823dbd49b93bd5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java @@ -42,6 +42,7 @@ public enum ExpressionRuleType { LIKE_TO_EQUAL, MERGE_DATE_TRUNC, MEDIAN_CONVERT, + NESTED_CASE_WHEN_COND_TO_LITERAL, NORMALIZE_BINARY_PREDICATES, NULL_SAFE_EQUAL_TO_EQUAL, REPLACE_NULL_WITH_FALSE_FOR_COND, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java index bfc983915b37de..60f05f52f6270a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java @@ -106,12 +106,13 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.Lists; +import com.google.common.collect.Sets; import org.apache.commons.codec.digest.DigestUtils; -import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.function.BiFunction; import java.util.function.Predicate; @@ -576,57 +577,64 @@ public Expression visitBinaryArithmetic(BinaryArithmetic binaryArithmetic, Expre public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext context) { CaseWhen originCaseWhen = caseWhen; caseWhen = rewriteChildren(caseWhen, context); - Expression newDefault = null; - boolean foundNewDefault = false; - - List whenClauses = new ArrayList<>(); + final Expression oldDefault = caseWhen.getDefaultValue().orElse(null); + Expression newDefault = oldDefault; + ImmutableList.Builder whenClausesBuilder + = ImmutableList.builderWithExpectedSize(caseWhen.getWhenClauses().size()); + Set uniqueOperands = Sets.newHashSet(); for (WhenClause whenClause : caseWhen.getWhenClauses()) { Expression whenOperand = whenClause.getOperand(); - - if (!(whenOperand.isLiteral())) { - whenClauses.add(new WhenClause(whenOperand, whenClause.getResult())); + if (!whenOperand.isLiteral() && uniqueOperands.add(whenOperand)) { + whenClausesBuilder.add(new WhenClause(whenOperand, whenClause.getResult())); } else if (BooleanLiteral.TRUE.equals(whenOperand)) { - foundNewDefault = true; newDefault = whenClause.getResult(); break; } } - - Expression defaultResult = null; - if (caseWhen.getDefaultValue().isPresent()) { - defaultResult = caseWhen.getDefaultValue().get(); - } - if (foundNewDefault) { - defaultResult = newDefault; + List newWhenClauses = whenClausesBuilder.build(); + Expression realTypeCoercionDefault = newDefault != null ? newDefault : new NullLiteral(caseWhen.getDataType()); + boolean allThenEqualsDefault = true; + for (WhenClause whenClause : newWhenClauses) { + if (!whenClause.getResult().equals(realTypeCoercionDefault)) { + allThenEqualsDefault = false; + break; + } } - if (whenClauses.isEmpty()) { - return TypeCoercionUtils.ensureSameResultType( - originCaseWhen, defaultResult == null ? new NullLiteral(caseWhen.getDataType()) : defaultResult, - context - ); + if (allThenEqualsDefault) { + return realTypeCoercionDefault; } - if (defaultResult == null) { - if (caseWhen.getDataType().isNullType()) { - // if caseWhen's type is NULL_TYPE, means all possible return values are nulls - // it's safe to return null literal here - return new NullLiteral(); - } else { - return TypeCoercionUtils.ensureSameResultType(originCaseWhen, new CaseWhen(whenClauses), context); + boolean hasNewChildren = newWhenClauses.size() != caseWhen.getWhenClauses().size() + || newDefault != oldDefault; + if (newWhenClauses.size() == caseWhen.getWhenClauses().size()) { + for (int i = 0; i < newWhenClauses.size(); i++) { + if (newWhenClauses.get(i) != caseWhen.getWhenClauses().get(i)) { + hasNewChildren = true; + break; + } } } - return TypeCoercionUtils.ensureSameResultType( - originCaseWhen, new CaseWhen(whenClauses, defaultResult), context - ); + if (hasNewChildren) { + caseWhen = newDefault == null + ? new CaseWhen(newWhenClauses) : new CaseWhen(newWhenClauses, newDefault); + } + return TypeCoercionUtils.ensureSameResultType(originCaseWhen, caseWhen, context); } @Override public Expression visitIf(If ifExpr, ExpressionRewriteContext context) { If originIf = ifExpr; ifExpr = rewriteChildren(ifExpr, context); - if (ifExpr.child(0) instanceof NullLiteral || ifExpr.child(0).equals(BooleanLiteral.FALSE)) { - return TypeCoercionUtils.ensureSameResultType(originIf, ifExpr.child(2), context); - } else if (ifExpr.child(0).equals(BooleanLiteral.TRUE)) { - return TypeCoercionUtils.ensureSameResultType(originIf, ifExpr.child(1), context); + Expression condition = ifExpr.getCondition(); + Expression typeCoercionTrueValue + = TypeCoercionUtils.ensureSameResultType(originIf, ifExpr.getTrueValue(), context); + Expression typeCoercionFalseValue + = TypeCoercionUtils.ensureSameResultType(originIf, ifExpr.getFalseValue(), context); + if (condition.equals(BooleanLiteral.TRUE)) { + return typeCoercionTrueValue; + } else if (condition.equals(BooleanLiteral.FALSE) || condition.isNullLiteral()) { + return typeCoercionFalseValue; + } else if (typeCoercionTrueValue.equals(typeCoercionFalseValue)) { + return typeCoercionTrueValue; } return TypeCoercionUtils.ensureSameResultType(originIf, ifExpr, context); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NestedCaseWhenCondToLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NestedCaseWhenCondToLiteral.java new file mode 100644 index 00000000000000..29ddf97e786db0 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NestedCaseWhenCondToLiteral.java @@ -0,0 +1,226 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionRuleType; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.CompoundPredicate; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.WhenClause; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; + +import java.util.List; +import java.util.Map; + +/** + * For nested CaseWhen/IF expression, replace the inner CaseWhen/IF condition with TRUE/FALSE literal + * when the condition also exists in the outer CaseWhen/IF conditions. + * + * on the nested CASE/IF path, a condition may exist in multiple CASE/IF branches, + * for any inner case when or if condition, its boolean value is determined by the outermost CASE/IF branch, + * that is the first occurrence of the condition on the nested CASE/IF path. + * + *
+ * 1. if it exists in outer case's current branch condition, replace it with TRUE + * e.g. + * case when A then + * (case when A then 1 else 2 end) + * ... + * end + * then inner case condition A will replace with TRUE: + * case when A then + * (case when TRUE then 1 else 2 end) + * ... + * end + *
+ * 2. if it exists in outer case's previous branch condition, replace it with FALSE + * e.g. + * case when A then ... + * when B then + * (case when A then 1 else 2 end) + * ... + * end + * then inner case condition A will replace with FALSE: + * case when A then ... + * when B then + * (case when FALSE then 1 else 2 end) + * ... + * end + *
+ */ +public class NestedCaseWhenCondToLiteral implements ExpressionPatternRuleFactory { + + public static final NestedCaseWhenCondToLiteral INSTANCE = new NestedCaseWhenCondToLiteral(); + + @Override + public List> buildRules() { + return ImmutableList.of( + root(Expression.class) + .when(this::needRewrite) + .thenApply(ctx -> rewrite(ctx.expr, ctx.rewriteContext)) + .toRule(ExpressionRuleType.NESTED_CASE_WHEN_COND_TO_LITERAL) + ); + } + + private boolean needRewrite(Expression expression) { + return expression.containsType(CaseWhen.class, If.class); + } + + private Expression rewrite(Expression expression, ExpressionRewriteContext context) { + return expression.accept(new NestedCondReplacer(), null); + } + + /** NestedCondReplacer */ + @VisibleForTesting + public static class NestedCondReplacer extends DefaultExpressionRewriter { + + // condition literals is used to record the boolean literal for a condition expression, + // 1. if a condition, if it exists in outer case/if conditions, it will be replaced with the literal. + // 2. otherwise it's the first time occur, then: + // a) when enter a case/if branch, set this condition to TRUE literal + // b) when leave a case/if branch, set this condition to FALSE literal + // c) when leave the whole case/if statement, remove this condition literal + protected final Map conditionLiterals = Maps.newHashMap(); + + @Override + public Expression visit(Expression expr, Void context) { + if (INSTANCE.needRewrite(expr)) { + return super.visit(expr, context); + } else { + return expr; + } + } + + @Override + public Expression visitCaseWhen(CaseWhen caseWhen, Void context) { + ImmutableList.Builder newWhenClausesBuilder + = ImmutableList.builderWithExpectedSize(caseWhen.arity()); + List firstOccurConds = Lists.newArrayListWithExpectedSize(caseWhen.arity()); + for (WhenClause whenClause : caseWhen.getWhenClauses()) { + Expression oldCondition = whenClause.getOperand(); + Pair replaceResult = replaceCondition(oldCondition, context); + Expression newCondition = replaceResult.first; + boolean condFirstOccur = replaceResult.second; + if (condFirstOccur) { + firstOccurConds.add(oldCondition); + conditionLiterals.put(oldCondition, BooleanLiteral.TRUE); + } + Expression newResult = whenClause.getResult().accept(this, context); + if (condFirstOccur) { + conditionLiterals.put(oldCondition, BooleanLiteral.FALSE); + } + if (whenClause.getOperand() != newCondition || whenClause.getResult() != newResult) { + newWhenClausesBuilder.add(new WhenClause(newCondition, newResult)); + } else { + newWhenClausesBuilder.add(whenClause); + } + } + Expression oldDefaultValue = caseWhen.getDefaultValue().orElse(null); + Expression newDefaultValue = oldDefaultValue; + if (newDefaultValue != null) { + newDefaultValue = newDefaultValue.accept(this, context); + } + for (Expression cond : firstOccurConds) { + conditionLiterals.remove(cond); + } + List newWhenClauses = newWhenClausesBuilder.build(); + boolean hasNewChildren = false; + if (newWhenClauses.size() != caseWhen.getWhenClauses().size()) { + hasNewChildren = true; + } else { + for (int i = 0; i < newWhenClauses.size(); i++) { + if (newWhenClauses.get(i) != caseWhen.getWhenClauses().get(i)) { + hasNewChildren = true; + break; + } + } + } + if (newDefaultValue != oldDefaultValue) { + hasNewChildren = true; + } + if (hasNewChildren) { + return newDefaultValue != null + ? new CaseWhen(newWhenClauses, newDefaultValue) + : new CaseWhen(newWhenClauses); + } else { + return caseWhen; + } + } + + @Override + public Expression visitIf(If ifExpr, Void context) { + Expression oldCondition = ifExpr.getCondition(); + Pair replaceResult = replaceCondition(oldCondition, context); + Expression newCondition = replaceResult.first; + boolean condFirstOccur = replaceResult.second; + if (condFirstOccur) { + conditionLiterals.put(oldCondition, BooleanLiteral.TRUE); + } + Expression newTrueValue = ifExpr.getTrueValue().accept(this, context); + if (condFirstOccur) { + conditionLiterals.put(oldCondition, BooleanLiteral.FALSE); + } + Expression newFalseValue = ifExpr.getFalseValue().accept(this, context); + if (condFirstOccur) { + conditionLiterals.remove(oldCondition); + } + if (newCondition != oldCondition + || newTrueValue != ifExpr.getTrueValue() + || newFalseValue != ifExpr.getFalseValue()) { + return new If(newCondition, newTrueValue, newFalseValue); + } else { + return ifExpr; + } + } + + // return newCondition + condition first occur flag + private Pair replaceCondition(Expression condition, Void context) { + if (condition.isLiteral()) { + // literal condition do not need to replace, and do not record it + return Pair.of(condition, false); + } else if (conditionLiterals.containsKey(condition)) { + return Pair.of(conditionLiterals.get(condition), false); + } else if (condition instanceof CompoundPredicate) { + ImmutableList.Builder newChildrenBuilder + = ImmutableList.builderWithExpectedSize(condition.arity()); + boolean hasNewChildren = false; + for (Expression child : condition.children()) { + Expression newChild = replaceCondition(child, context).first; + hasNewChildren = hasNewChildren || newChild != child; + newChildrenBuilder.add(newChild); + } + Expression newCondition = hasNewChildren + ? condition.withChildren(newChildrenBuilder.build()) : condition; + return Pair.of(newCondition, true); + } else { + return Pair.of(condition.accept(this, context), true); + } + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/And.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/And.java index 683960cecaa3ff..0ab5ee108c1713 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/And.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/And.java @@ -35,9 +35,9 @@ public class And extends CompoundPredicate { * @param right right child of comparison predicate */ public And(Expression left, Expression right) { - super(ExpressionUtils.mergeList( + this(ExpressionUtils.mergeList( ExpressionUtils.extractConjunction(left), - ExpressionUtils.extractConjunction(right)), "AND"); + ExpressionUtils.extractConjunction(right))); } public And(List children) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Or.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Or.java index 20d2605f561127..235c1bc2f0a70b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Or.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Or.java @@ -36,9 +36,9 @@ public class Or extends CompoundPredicate { * @param right right child of comparison predicate */ public Or(Expression left, Expression right) { - super(ExpressionUtils.mergeList( + this(ExpressionUtils.mergeList( ExpressionUtils.extractDisjunction(left), - ExpressionUtils.extractDisjunction(right)), "OR"); + ExpressionUtils.extractDisjunction(right))); } public Or(List children) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java index 0ba1f18c0549f0..2a27f253f6d5d9 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java @@ -155,6 +155,18 @@ void testCaseWhenFold() { assertRewriteAfterTypeCoercion("case when null = 2 then 1 else 4 end", "4"); assertRewriteAfterTypeCoercion("case when null = 2 then 1 end", "null"); assertRewriteAfterTypeCoercion("case when TA = TB then 1 when TC is null then 2 end", "CASE WHEN (TA = TB) THEN 1 WHEN TC IS NULL THEN 2 END"); + assertRewriteAfterTypeCoercion("case when a > 1 then a + 1 when a > 1 then a + 10 when a > 2 then a + 2 else a + 100 end", + "case when a > 1 then a + 1 when a > 2 then a + 2 else a + 100 end"); + assertRewriteAfterTypeCoercion("case when a > 1 then a + 1 when a > 2 then a + 1 when a > 3 then a + 1 else a + 1 end", + "a + 1"); + assertRewriteAfterTypeCoercion("case when a > 1 then a + 1 when a > 2 then a + 1 when a > 3 then a + 1 end", + "case when a > 1 then a + 1 when a > 2 then a + 1 when a > 3 then a + 1 end"); + assertRewriteAfterTypeCoercion("case when null then 1 when false then 2 when a > 3 then 3 when a > 4 then 4 end", + "case when a > 3 then 3 when a > 4 then 4 end"); + assertRewriteAfterTypeCoercion("case when null then 1 when false then 2 when a > 3 then 3 when true then 0 when a > 4 then 4 end", + "case when a > 3 then 3 else 0 end"); + assertRewriteAfterTypeCoercion("case when true then 100 when a > 1 then a + 1 when a > 1 then a + 10 when a > 2 then a + 2 else a + 100 end", + "100"); // make sure the case when return datetime(6) Expression analyzedCaseWhen = ExpressionAnalyzer.analyzeFunction(null, null, PARSER.parseExpression( @@ -166,6 +178,16 @@ void testCaseWhenFold() { Assertions.assertEquals(new DateTimeV2Literal(DateTimeV2Type.of(6), "2025-04-17"), foldCaseWhen); } + @Test + void testIfFold() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE) + )); + assertRewriteAfterTypeCoercion("if(true, a + 1, a + 2)", "a + 1"); + assertRewriteAfterTypeCoercion("if(false, a + 1, a + 2)", "a + 2"); + assertRewriteAfterTypeCoercion("if(b > 0, a + 100, a + 100)", "a + 100"); + } + @Test void testInFold() { executor = new ExpressionRuleExecutor(ImmutableList.of( diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java index e6e856852d0c13..5fff6e51aceb61 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java @@ -230,7 +230,7 @@ public void testSimplify() { assertRewrite("TA + TC = 1 and TA + TC = 3", "(TA + TC) is null and null"); assertRewriteNotNull("TA + TC in (1) and TA + TC in (3)", "FALSE"); assertRewrite("TA + TC in (1) and TA + TC in (3)", "(TA + TC) is null and null"); - assertRewrite("TA + TC in (1) and TA + TC in (1)", "TA + TC in (1)"); + assertRewrite("TA + TC in (1) and TA + TC in (1)", "TA + TC = 1"); assertRewriteNotNull("(TA + TC > 3 and TA + TC < 1) and TB < 5", "FALSE"); assertRewrite("(TA + TC > 3 and TA + TC < 1) and TB < 5", "(TA + TC) is null and null and TB < 5"); assertRewrite("(TA + TC > 3 and TA + TC < 1) or TB < 5", "((TA + TC) is null and null) OR TB < 5"); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NestedCaseWhenCondToLiteralTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NestedCaseWhenCondToLiteralTest.java new file mode 100644 index 00000000000000..04af10abaa9e23 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NestedCaseWhenCondToLiteralTest.java @@ -0,0 +1,334 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; +import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.GreaterThan; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.WhenClause; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.types.IntegerType; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Maps; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +class NestedCaseWhenCondToLiteralTest extends ExpressionRewriteTestHelper { + + @Test + void testNestedCaseWhen() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(NestedCaseWhenCondToLiteral.INSTANCE) + )); + + assertRewriteAfterTypeCoercion( + "case when a > 1 then 1" + + " when a > 2 then" + + " (case when a > 1 then 2" + + " when a > 2 then 3" + + " when a > 1 and a > 1 and a > 2 and a > 2 and a > 3 then 100" + + " when a > 3 then (case when a > 1 then 4" + + " when a > 2 then 5" + + " when a > 3 then 6" + + " end)" + + " when a > 1 and a > 1 and a > 2 and a > 2 and a > 3 then 101" + + " end)" + + " when (case when a > 1 then a > 1" + + " when a > 2 then a > 2" + + " when a > 3 then a > 3" + + " when a > 1 then a > 1" + + " end) then 100" + + " when a > 3 then 7" + + " when a > 1 then 8" + + " else (case when a > 1 then 9" + + " when a > 2 then 10" + + " when a > 3 then 11" + + " when a > 4 then 12" + + " else (case when a > 1 then 13" + + " when a > 2 then 14" + + " when a > 3 then 15" + + " when a > 4 then 16" + + " when a > 5 then (case when a > 1 then 17 when a > 5 then 18 end)" + + " end)" + + " end)" + + " end", + "case when a > 1 then 1" + + " when a > 2 then" + + " (case when false then 2" + + " when true then 3" + + " when false and false and true and true and a > 3 then 100" + + " when a > 3 then (case when false then 4" + + " when true then 5" + + " when true then 6" + + " end)" + + " when false then 101" + + " end)" + + " when (case when false then a > 1" + + " when false then a > 2" + + " when a > 3 then a > 3" + + " when false then a > 1" + + " end) then 100" + + " when a > 3 then 7" + + " when false then 8" + + " else (case when false then 9" + + " when false then 10" + + " when false then 11" + + " when a > 4 then 12" + + " else (case when false then 13" + + " when false then 14" + + " when false then 15" + + " when false then 16" + + " when a > 5 then (case when false then 17 when true then 18 end)" + + " end)" + + " end)" + + " end" + ); + } + + @Test + void testNestedIf() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(NestedCaseWhenCondToLiteral.INSTANCE) + )); + assertRewriteAfterTypeCoercion( + "if(" + + " a > 1," + + " if(" + + " a > 1," + + " if(" + + " a > 2," + + " if(a > 2,a + 2,a + 3)," + + " if(" + + " a > 1," + + " if(a > 2,a + 3,a + 4)," + + " if(a > 2,a + 5,a + 6)" + + " )" + + " )," + + " if(a > 1,a + 1,a + 2)" + + " )," + + " if(" + + " a > 1," + + " a + 5," + + " if(a > 2,a + 6,a + 7)" + + " )" + + ")", + "if(" + + " a > 1," + + " if(" + + " true," + + " if(" + + " a > 2," + + " if(true,a + 2,a + 3)," + + " if(" + + " true," + + " if(false,a + 3,a + 4)," + + " if(false,a + 5,a + 6)" + + " )" + + " )," + + " if(true,a + 1,a + 2)" + + " )," + + " if(" + + " false," + + " a + 5," + + " if(a > 2,a + 6,a + 7)" + + " )" + + ")" + ); + } + + @Test + void testNestedCaseWhenReplacer() { + // case when a > 1 then 101 + // when a > 2 then (case when a > 1 then 102 + // when a > 2 then 103 + // when a > 3 then 104 + // when a > 4 then 105 + // else 106 + // end) + // when a > 3 then 107 + // when a > 4 then 108 + // when a > 4 then 109 + // else 110 + // end + SlotReference a = new SlotReference("a", IntegerType.INSTANCE); + Expression c1 = new GreaterThan(a, IntegerLiteral.of(1)); + Expression c2 = new GreaterThan(a, IntegerLiteral.of(2)); + Expression c3 = new GreaterThan(a, IntegerLiteral.of(3)); + Expression c4 = new GreaterThan(a, IntegerLiteral.of(4)); + Expression i101 = IntegerLiteral.of(101); + Expression i102 = IntegerLiteral.of(102); + Expression i103 = IntegerLiteral.of(103); + Expression i104 = IntegerLiteral.of(104); + Expression i105 = IntegerLiteral.of(105); + Expression i106 = IntegerLiteral.of(106); + Expression i107 = IntegerLiteral.of(107); + Expression i108 = IntegerLiteral.of(108); + Expression i109 = IntegerLiteral.of(109); + Expression i110 = IntegerLiteral.of(110); + Expression innerCaseWhen = new CaseWhen( + ImmutableList.of( + new WhenClause(c1, i102), + new WhenClause(c2, i103), + new WhenClause(c3, i104), + new WhenClause(c4, i105)), + i106); + Expression outerCaseWhen = new CaseWhen( + ImmutableList.of( + new WhenClause(c1, i101), + new WhenClause(c2, innerCaseWhen), + new WhenClause(c3, i107), + new WhenClause(c4, i108), + new WhenClause(c4, i109)), + i110); + TestNestedCondReplacer replacer = new TestNestedCondReplacer(); + outerCaseWhen.accept(replacer, null); + replacer.checkExpressionReplaceLiterals(outerCaseWhen, + ImmutableList.of(), + ImmutableList.of()); + replacer.checkExpressionReplaceLiterals(i101, + ImmutableList.of(c1), + ImmutableList.of()); + replacer.checkExpressionReplaceLiterals(i102, + ImmutableList.of(c2), + ImmutableList.of(c1)); + replacer.checkExpressionReplaceLiterals(i103, + ImmutableList.of(c2), + ImmutableList.of(c1)); + replacer.checkExpressionReplaceLiterals(i104, + ImmutableList.of(c2, c3), + ImmutableList.of(c1)); + replacer.checkExpressionReplaceLiterals(i105, + ImmutableList.of(c2, c4), + ImmutableList.of(c1, c3)); + replacer.checkExpressionReplaceLiterals(i106, + ImmutableList.of(c2), + ImmutableList.of(c1, c3, c4)); + replacer.checkExpressionReplaceLiterals(i107, + ImmutableList.of(c3), + ImmutableList.of(c1, c2)); + replacer.checkExpressionReplaceLiterals(i108, + ImmutableList.of(c4), + ImmutableList.of(c1, c2, c3)); + replacer.checkExpressionReplaceLiterals(i109, + ImmutableList.of(), + ImmutableList.of(c1, c2, c3, c4)); + replacer.checkExpressionReplaceLiterals(i110, + ImmutableList.of(), + ImmutableList.of(c1, c2, c3, c4)); + + // after rewrite, the condition literals should clear + Assertions.assertEquals(Maps.newHashMap(), replacer.conditionLiterals); + } + + @Test + void testNestedIfReplacer() { + // if(a > 1, + // if(a > 2, + // if(a > 3, 301, 302), + // if(a > 4, 303, 304) + // ), + // if(a > 5, 305, 306) + // ) + SlotReference a = new SlotReference("a", IntegerType.INSTANCE); + Expression c1 = new GreaterThan(a, IntegerLiteral.of(1)); + Expression c2 = new GreaterThan(a, IntegerLiteral.of(2)); + Expression c3 = new GreaterThan(a, IntegerLiteral.of(3)); + Expression c4 = new GreaterThan(a, IntegerLiteral.of(4)); + Expression c5 = new GreaterThan(a, IntegerLiteral.of(5)); + Expression i301 = IntegerLiteral.of(301); + Expression i302 = IntegerLiteral.of(302); + Expression i303 = IntegerLiteral.of(303); + Expression i304 = IntegerLiteral.of(304); + Expression i305 = IntegerLiteral.of(305); + Expression i306 = IntegerLiteral.of(306); + Expression innerIf1 = new If(c3, i301, i302); + Expression innerIf2 = new If(c4, i303, i304); + Expression innerIf = new If(c2, innerIf1, innerIf2); + Expression outerIf = new If(c1, innerIf, new If(c5, i305, i306)); + TestNestedCondReplacer replacer = new TestNestedCondReplacer(); + outerIf.accept(replacer, null); + replacer.checkExpressionReplaceLiterals(outerIf, + ImmutableList.of(), + ImmutableList.of()); + replacer.checkExpressionReplaceLiterals(innerIf, + ImmutableList.of(c1), + ImmutableList.of()); + replacer.checkExpressionReplaceLiterals(innerIf1, + ImmutableList.of(c1, c2), + ImmutableList.of()); + replacer.checkExpressionReplaceLiterals(i301, + ImmutableList.of(c1, c2, c3), + ImmutableList.of()); + replacer.checkExpressionReplaceLiterals(i302, + ImmutableList.of(c1, c2), + ImmutableList.of(c3)); + replacer.checkExpressionReplaceLiterals(innerIf2, + ImmutableList.of(c1), + ImmutableList.of(c2)); + + // after rewrite, the condition literals should clear + Assertions.assertEquals(Maps.newHashMap(), replacer.conditionLiterals); + } + + private static class TestNestedCondReplacer extends NestedCaseWhenCondToLiteral.NestedCondReplacer { + private final Map> expressionReplaceMap = Maps.newHashMap(); + + @Override + public Expression visit(Expression expr, Void context) { + recordReplaceLiteral(expr); + return super.visit(expr, context); + } + + @Override + public Expression visitCaseWhen(CaseWhen caseWhen, Void context) { + recordReplaceLiteral(caseWhen); + return super.visitCaseWhen(caseWhen, context); + } + + @Override + public Expression visitIf(If ifExpr, Void context) { + recordReplaceLiteral(ifExpr); + return super.visitIf(ifExpr, context); + } + + private void recordReplaceLiteral(Expression expr) { + expressionReplaceMap.put(expr, Maps.newHashMap(conditionLiterals)); + } + + private void checkExpressionReplaceLiterals(Expression expression, + List trueConditions, List falseConditions) { + Map expectedReplaceMap = Maps.newHashMap(); + for (Expression trueCondition : trueConditions) { + expectedReplaceMap.put(trueCondition, BooleanLiteral.TRUE); + } + for (Expression falseCondition : falseConditions) { + expectedReplaceMap.put(falseCondition, BooleanLiteral.FALSE); + } + Assertions.assertEquals(expectedReplaceMap, expressionReplaceMap.get(expression)); + } + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/ReplaceNullWithFalseForCondTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/ReplaceNullWithFalseForCondTest.java index 14953b21cfa05d..56be05c87816b8 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/ReplaceNullWithFalseForCondTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/ReplaceNullWithFalseForCondTest.java @@ -17,21 +17,13 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.analyzer.Scope; -import org.apache.doris.nereids.rules.analysis.ExpressionAnalyzer; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor; -import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.Or; -import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Test; -import java.util.function.Function; - class ReplaceNullWithFalseForCondTest extends ExpressionRewriteTestHelper { private final ReplaceNullWithFalseForCond replaceCaseThenInstance = new ReplaceNullWithFalseForCond() { @@ -65,7 +57,7 @@ void testCaseWhen() { + " then (case when false then null else null end) " + " else null end"; - assertRewrite(sql, expectedSql); + assertRewriteAfterTypeCoercion(sql, expectedSql); executor = new ExpressionRuleExecutor(ImmutableList.of( bottomUp(replaceCaseThenInstance) @@ -80,7 +72,7 @@ void testCaseWhen() { + " then (case when false then false else false end) " + " else false end"; - assertRewrite(sql, expectedSql); + assertRewriteAfterTypeCoercion(sql, expectedSql); } @Test @@ -90,57 +82,29 @@ void testIf() { )); String sql = "if(" - + " null and not(null) and if(null and not(null), null, null)," + + " null and not(null) and if(null and not(null), null and true, null)," + " null and not(null)," + " if(a = 1 and null, null, null)" + ")"; String expectedSql = "if(" - + " false and not(null) and if(false and not(null), false, false)," + + " false and not(null) and if(false and not(null), false and true, false)," + " null and not(null)," + " if(a = 1 and false, null, null)" + ")"; - assertRewrite(sql, expectedSql); + assertRewriteAfterTypeCoercion(sql, expectedSql); executor = new ExpressionRuleExecutor(ImmutableList.of( bottomUp(replaceCaseThenInstance, SimplifyCastRule.INSTANCE) )); expectedSql = "if(" - + " false and not(null) and if(false and not(null), false, false)," + + " false and not(null) and if(false and not(null), false and true, false)," + " false and not(null)," + " if(a = 1 and false, false, false)" + ")"; - assertRewrite(sql, expectedSql); - } - - @Override - protected void assertRewrite(String sql, String expectedSql) { - Function converter = expr -> new ExpressionAnalyzer( - null, new Scope(ImmutableList.of()), null, false, false - ) { - // ExpressionAnalyzer will rewrite 'false and xxx' to 'false', but we want to keep the structure of the expression, - @Override - public Expression visitAnd(And and, ExpressionRewriteContext context) { - return new And( - ExpressionUtils.extractConjunction(and) - .stream() - .map(e -> e.accept(this, context)) - .collect(ImmutableList.toImmutableList())); - } - - @Override - public Expression visitOr(Or or, ExpressionRewriteContext context) { - return new Or( - ExpressionUtils.extractDisjunction(or) - .stream() - .map(e -> e.accept(this, context)) - .collect(ImmutableList.toImmutableList())); - } - }.analyze(expr, null); - - assertRewriteAfterConvert(sql, expectedSql, converter); + assertRewriteAfterTypeCoercion(sql, expectedSql); } } diff --git a/regression-test/suites/nereids_rules_p0/partition_prune/one_col_list_partition.groovy b/regression-test/suites/nereids_rules_p0/partition_prune/one_col_list_partition.groovy index 4de512376d1378..a174d275705759 100644 --- a/regression-test/suites/nereids_rules_p0/partition_prune/one_col_list_partition.groovy +++ b/regression-test/suites/nereids_rules_p0/partition_prune/one_col_list_partition.groovy @@ -224,7 +224,7 @@ suite("one_col_list_partition") { contains("VEMPTYSET") } explain { - sql "SELECT * FROM one_col_list_partition_date WHERE if(a>1, dt<'2001-1-01 00:00:00', dt<'2001-1-01 00:00:00')" + sql "SELECT * FROM one_col_list_partition_date WHERE if(a>1, dt<'2001-1-01 00:00:00', dt<'2001-1-03 00:00:00')" contains("partitions=8/9 (p1,p2,p3,p4,p5,p6,p7,p8)") } explain { @@ -246,4 +246,4 @@ suite("one_col_list_partition") { else '2023-01-01 00:00:00' end <'2021-01-06 00:00:00' ;""" contains("partitions=3/9 (p1,p2,p3)") } -} \ No newline at end of file +}