From 87ce88cb0a819ab7dcfa0b9e0308f520803c9ea5 Mon Sep 17 00:00:00 2001 From: yujun Date: Tue, 14 Oct 2025 16:42:07 +0800 Subject: [PATCH] rewrite case when to compound predicate --- .../rules/analysis/ExpressionAnalyzer.java | 4 +- .../expression/ExpressionOptimization.java | 2 + .../rules/expression/ExpressionRuleType.java | 2 + .../rules/CaseWhenToCompoundPredicate.java | 110 ++++++++++++++++++ .../rules/FoldConstantRuleOnFE.java | 9 +- .../rules/OneListPartitionEvaluator.java | 5 +- .../expressions/literal/NullLiteral.java | 4 +- .../doris/nereids/util/ExpressionUtils.java | 11 +- .../CaseWhenToCompoundPredicateTest.java | 59 ++++++++++ .../extend_infer_equal_predicate.out | 18 +-- 10 files changed, 199 insertions(+), 25 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToCompoundPredicate.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToCompoundPredicateTest.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java index 4b6e9b6e09b1ef..4b691e1dda55df 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java @@ -590,7 +590,7 @@ public Expression visitOr(Or or, ExpressionRewriteContext context) { newChild = child; } if (newChild.getDataType().isNullType()) { - newChild = new NullLiteral(BooleanType.INSTANCE); + newChild = NullLiteral.BOOLEAN_INSTANCE; } else { newChild = TypeCoercionUtils.castIfNotSameType(newChild, BooleanType.INSTANCE); } @@ -618,7 +618,7 @@ public Expression visitAnd(And and, ExpressionRewriteContext context) { newChild = child; } if (newChild.getDataType().isNullType()) { - newChild = new NullLiteral(BooleanType.INSTANCE); + newChild = 
NullLiteral.BOOLEAN_INSTANCE; } else { newChild = TypeCoercionUtils.castIfNotSameType(newChild, BooleanType.INSTANCE); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java index 7c751700e98acf..9be1e868365bed 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java @@ -20,6 +20,7 @@ import org.apache.doris.nereids.rules.expression.rules.AddMinMax; import org.apache.doris.nereids.rules.expression.rules.ArrayContainToArrayOverlap; import org.apache.doris.nereids.rules.expression.rules.BetweenToEqual; +import org.apache.doris.nereids.rules.expression.rules.CaseWhenToCompoundPredicate; import org.apache.doris.nereids.rules.expression.rules.CaseWhenToIf; import org.apache.doris.nereids.rules.expression.rules.DateFunctionRewrite; import org.apache.doris.nereids.rules.expression.rules.DistinctPredicatesRule; @@ -62,6 +63,7 @@ public class ExpressionOptimization extends ExpressionRewrite { ReplaceNullWithFalseForCond.INSTANCE, NestedCaseWhenCondToLiteral.INSTANCE, CaseWhenToIf.INSTANCE, + CaseWhenToCompoundPredicate.INSTANCE, TopnToMax.INSTANCE, NullSafeEqualToEqual.INSTANCE, LikeToEqualRewrite.INSTANCE, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java index 823dbd49b93bd5..25ff76a3bc6f04 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java @@ -24,6 +24,7 @@ public enum ExpressionRuleType { ADD_MIN_MAX, ARRAY_CONTAIN_TO_ARRAY_OVERLAP, BETWEEN_TO_EQUAL, + 
CASE_WHEN_TO_COMPOUND_PREDICATE, CASE_WHEN_TO_IF, CHECK_CAST, CONVERT_AGG_STATE_CAST, @@ -36,6 +37,7 @@ public enum ExpressionRuleType { FOLD_CONSTANT_ON_BE, FOLD_CONSTANT_ON_FE, LOG_TO_LN, + IF_TO_COMPOUND_PREDICATE, IN_PREDICATE_DEDUP, IN_PREDICATE_EXTRACT_NON_CONSTANT, IN_PREDICATE_TO_EQUAL_TO, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToCompoundPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToCompoundPredicate.java new file mode 100644 index 00000000000000..c013aa0ecf30ba --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToCompoundPredicate.java @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; +import org.apache.doris.nereids.rules.expression.ExpressionRuleType; +import org.apache.doris.nereids.trees.expressions.And; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Not; +import org.apache.doris.nereids.trees.expressions.NullSafeEqual; +import org.apache.doris.nereids.trees.expressions.Or; +import org.apache.doris.nereids.trees.expressions.WhenClause; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Optional; + +/** + * If every WHEN branch of a CASE WHEN returns a true/false literal (the ELSE default value can be any expression), + * then the CASE WHEN (or the equivalent IF) can be rewritten to a compound predicate. + * + * For example: + * 1. case when c1 then true when c2 then false end => (c1 <=> true or (not (c2 <=> true) and null)) + * 2.
if (c1, true, false) => c1 <=> true or false + */ +public class CaseWhenToCompoundPredicate implements ExpressionPatternRuleFactory { + public static final CaseWhenToCompoundPredicate INSTANCE = new CaseWhenToCompoundPredicate(); + + @Override + public List<ExpressionPatternMatcher<? extends Expression>> buildRules() { + return ImmutableList.of( + matchesType(CaseWhen.class) + .when(this::checkBooleanType) + .then(this::rewriteCaseWhen) + .toRule(ExpressionRuleType.CASE_WHEN_TO_COMPOUND_PREDICATE), + matchesType(If.class) + .when(this::checkBooleanType) + .then(this::rewriteIf) + .toRule(ExpressionRuleType.IF_TO_COMPOUND_PREDICATE) + ); + } + + private boolean checkBooleanType(Expression expression) { + return expression.getDataType().isBooleanType(); + } + + private Expression rewriteCaseWhen(CaseWhen caseWhen) { + Expression defaultValue = caseWhen.getDefaultValue().orElse(NullLiteral.BOOLEAN_INSTANCE); + return rewrite(caseWhen.getWhenClauses(), defaultValue).orElse(caseWhen); + } + + private Expression rewriteIf(If ifExpr) { + List<WhenClause> whenClauses = ImmutableList.of(new WhenClause(ifExpr.getCondition(), ifExpr.getTrueValue())); + Expression defaultValue = ifExpr.getFalseValue(); + return rewrite(whenClauses, defaultValue).orElse(ifExpr); + } + + // Fold the branches from back to front. Let X be the rewritten form of all branches after the current one: + // 1. when c then true ... is rewritten to (c <=> true OR X), + // 2. when c then false ... is rewritten to (not(c <=> true) AND X). + // The ELSE branch is treated as `when true then defaultValue`, so X starts as the default value + // (a boolean NULL literal when a CASE WHEN has no ELSE); the last WHEN clause is folded first + // and the first WHEN clause is folded last. Returns empty when any THEN result is not a boolean literal.
+ private Optional<Expression> rewrite(List<WhenClause> whenClauses, Expression defaultValue) { + for (WhenClause whenClause : whenClauses) { + Expression result = whenClause.getResult(); + if (!(result instanceof BooleanLiteral)) { + return Optional.empty(); + } + } + Expression result = defaultValue; + try { + for (int i = whenClauses.size() - 1; i >= 0; i--) { + WhenClause whenClause = whenClauses.get(i); + // operand <=> true: null-safe, so a NULL operand counts as "not matched" + Expression condition = new NullSafeEqual(whenClause.getOperand(), BooleanLiteral.TRUE); + if (whenClause.getResult().equals(BooleanLiteral.TRUE)) { + result = new Or(condition, result); + } else { + result = new And(new Not(condition), result); + } + } + } catch (Exception e) { + // expression may exceed expression limit + return Optional.empty(); + } + return Optional.of(result); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java index 60f05f52f6270a..d53976d8ae23cf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java @@ -92,7 +92,6 @@ import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral; import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral; -import org.apache.doris.nereids.types.BooleanType; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.TypeCoercionUtils; @@ -445,7 +444,7 @@ public Expression visitAnd(And and, ExpressionRewriteContext context) { } } else { // null and null and null and ...
- return new NullLiteral(BooleanType.INSTANCE); + return NullLiteral.BOOLEAN_INSTANCE; } } @@ -491,7 +490,7 @@ public Expression visitOr(Or or, ExpressionRewriteContext context) { return or.withChildren(nonFalseLiteral); } else { // null or null - return new NullLiteral(BooleanType.INSTANCE); + return NullLiteral.BOOLEAN_INSTANCE; } } @@ -649,7 +648,7 @@ public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteCon // now the inPredicate contains literal only. Expression value = inPredicate.child(0); if (value.isNullLiteral()) { - return new NullLiteral(BooleanType.INSTANCE); + return NullLiteral.BOOLEAN_INSTANCE; } boolean isOptionContainsNull = false; for (Expression item : inPredicate.getOptions()) { @@ -660,7 +659,7 @@ public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteCon } } return isOptionContainsNull - ? new NullLiteral(BooleanType.INSTANCE) + ? NullLiteral.BOOLEAN_INSTANCE : BooleanLiteral.FALSE; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneListPartitionEvaluator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneListPartitionEvaluator.java index e0d2df9c0f25cd..1257e9840743a6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneListPartitionEvaluator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneListPartitionEvaluator.java @@ -30,7 +30,6 @@ import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; -import org.apache.doris.nereids.types.BooleanType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -115,7 +114,7 @@ public Expression visitInPredicate(InPredicate inPredicate, Map expressions) { } } - List exprList = 
Lists.newArrayList(distinctExpressions); + List exprList = ImmutableList.copyOf(distinctExpressions); if (exprList.isEmpty()) { return BooleanLiteral.TRUE; } else if (exprList.size() == 1) { @@ -266,7 +265,7 @@ public static Expression or(Collection expressions) { } } - List exprList = Lists.newArrayList(distinctExpressions); + List exprList = ImmutableList.copyOf(distinctExpressions); if (exprList.isEmpty()) { return BooleanLiteral.FALSE; } else if (exprList.size() == 1) { @@ -278,7 +277,7 @@ public static Expression or(Collection expressions) { public static Expression falseOrNull(Expression expression) { if (expression.nullable()) { - return new And(new IsNull(expression), new NullLiteral(BooleanType.INSTANCE)); + return new And(new IsNull(expression), NullLiteral.BOOLEAN_INSTANCE); } else { return BooleanLiteral.FALSE; } @@ -286,7 +285,7 @@ public static Expression falseOrNull(Expression expression) { public static Expression trueOrNull(Expression expression) { if (expression.nullable()) { - return new Or(new Not(new IsNull(expression)), new NullLiteral(BooleanType.INSTANCE)); + return new Or(new Not(new IsNull(expression)), NullLiteral.BOOLEAN_INSTANCE); } else { return BooleanLiteral.TRUE; } @@ -668,7 +667,7 @@ public static boolean canInferNotNullForMarkSlot(Expression predicate, Expressio * and in semi join, we can safely change the mark conjunct to hash conjunct */ ImmutableList literals = - ImmutableList.of(new NullLiteral(BooleanType.INSTANCE), BooleanLiteral.FALSE); + ImmutableList.of(NullLiteral.BOOLEAN_INSTANCE, BooleanLiteral.FALSE); List markJoinSlotReferenceList = new ArrayList<>((predicate.collect(MarkJoinSlotReference.class::isInstance))); int markSlotSize = markJoinSlotReferenceList.size(); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToCompoundPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToCompoundPredicateTest.java new file mode 100644 
index 00000000000000..7046c1cb2be67c --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToCompoundPredicateTest.java @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; +import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Test; + +class CaseWhenToCompoundPredicateTest extends ExpressionRewriteTestHelper { + + @Test + void testCaseWhen() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp( + CaseWhenToCompoundPredicate.INSTANCE + ) + )); + assertRewriteAfterTypeCoercion("case when a = 1 then true end", "(a = 1 <=> TRUE) or null"); + assertRewriteAfterTypeCoercion("case when a = 1 then true else null end", "(a = 1 <=> TRUE) or null"); + assertRewriteAfterTypeCoercion("case when a = 1 then true else false end", "(a = 1 <=> TRUE) or false"); + assertRewriteAfterTypeCoercion("case when a = 1 then true else true end", "(a = 1 <=> TRUE) or true"); + assertRewriteAfterTypeCoercion("case when a = 1 then 
true else b = 1 end", "(a = 1 <=> TRUE) or b = 1"); + assertRewriteAfterTypeCoercion("case when a = 1 then true when b = 1 then true when c = 1 then true end", + "(a = 1 <=> TRUE) or (b = 1 <=> TRUE) or (c = 1 <=> TRUE) or null"); + assertRewriteAfterTypeCoercion("case when a = 1 then false when b = 1 then false when c = 1 then false end", + "not(a = 1 <=> TRUE) and not (b = 1 <=> TRUE) and not(c = 1 <=> TRUE) and null"); + assertRewriteAfterTypeCoercion("case when a = 1 then true when b = 1 then false when c = 1 then true end", + "(a = 1 <=> TRUE) or (not (b = 1 <=> TRUE) and ((c = 1 <=> TRUE) or null))"); + } + + @Test + void testIf() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp( + CaseWhenToCompoundPredicate.INSTANCE + ) + )); + + assertRewriteAfterTypeCoercion("if(a = 1, true, a > b)", "(a = 1 <=> TRUE) or a > b"); + assertRewriteAfterTypeCoercion("if(a = 1, false, a > b)", "not (a = 1 <=> TRUE) and a > b"); + } +} diff --git a/regression-test/data/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.out b/regression-test/data/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.out index 8ad01db3edc791..aa5646fbc94b08 100644 --- a/regression-test/data/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.out +++ b/regression-test/data/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.out @@ -134,24 +134,25 @@ PhysicalResultSink -- !test_if_predicate -- PhysicalResultSink --hashJoin[INNER_JOIN] hashCondition=((t1.d_int = t2.d_int)) otherCondition=() -----PhysicalOlapScan[extend_infer_t1(t1)] -----filter(if(( not d_int IN (10, 20)), TRUE, FALSE)) +----filter((( not d_int IN (10, 20)) <=> TRUE)) +------PhysicalOlapScan[extend_infer_t1(t1)] +----filter((( not d_int IN (10, 20)) <=> TRUE)) ------PhysicalOlapScan[extend_infer_t1(t2)] -- !test_if_and_in_predicate -- PhysicalResultSink --hashJoin[INNER_JOIN] hashCondition=((t1.d_int = t2.d_int)) otherCondition=() -----filter(( not (if((d_int = 5), TRUE, 
FALSE) = FALSE))) +----filter(( not (((d_int = 5) <=> TRUE) = FALSE))) ------PhysicalOlapScan[extend_infer_t1(t1)] -----filter(( not (if((d_int = 5), TRUE, FALSE) = FALSE))) +----filter(( not (((d_int = 5) <=> TRUE) = FALSE))) ------PhysicalOlapScan[extend_infer_t1(t2)] -- !test_if_and_in_predicate_not -- PhysicalResultSink --hashJoin[INNER_JOIN] hashCondition=((t1.d_int = t2.d_int)) otherCondition=() -----filter(( not (if((d_int = 5), TRUE, FALSE) = FALSE))) +----filter(( not (((d_int = 5) <=> TRUE) = FALSE))) ------PhysicalOlapScan[extend_infer_t1(t1)] -----filter(( not (if((d_int = 5), TRUE, FALSE) = FALSE))) +----filter(( not (((d_int = 5) <=> TRUE) = FALSE))) ------PhysicalOlapScan[extend_infer_t1(t2)] -- !test_multi_slot_in_predicate1 -- @@ -172,8 +173,9 @@ PhysicalResultSink -- !test_case_when_predicate -- PhysicalResultSink --hashJoin[INNER_JOIN] hashCondition=((t1.d_int = t2.d_int)) otherCondition=() -----PhysicalOlapScan[extend_infer_t1(t1)] -----filter(CASE WHEN (d_int = 1) THEN TRUE WHEN (d_int = 2) THEN FALSE ELSE FALSE END) +----filter(((t1.d_int = 1) <=> TRUE)) +------PhysicalOlapScan[extend_infer_t1(t1)] +----filter(((t2.d_int = 1) <=> TRUE)) ------PhysicalOlapScan[extend_infer_t1(t2)] -- !test_datetimev2_predicate --