From 730944ac2e674526a5d92e3695995db4ed4126b1 Mon Sep 17 00:00:00 2001 From: yujun Date: Sat, 22 Nov 2025 16:44:17 +0800 Subject: [PATCH 01/10] extract case when branch --- .../doris/nereids/jobs/executor/Rewriter.java | 2 + .../apache/doris/nereids/rules/RuleType.java | 1 + .../rewrite/JoinExtractOrFromCaseWhen.java | 316 ++++++++++++++++++ .../nereids/rules/rewrite/OrExpansion.java | 18 +- .../doris/nereids/util/ExpressionUtils.java | 41 +++ 5 files changed, 374 insertions(+), 4 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index b7f948f00e1a1e..41d7843833ba25 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -98,6 +98,7 @@ import org.apache.doris.nereids.rules.rewrite.InferSetOperatorDistinct; import org.apache.doris.nereids.rules.rewrite.InitJoinOrder; import org.apache.doris.nereids.rules.rewrite.InlineLogicalView; +import org.apache.doris.nereids.rules.rewrite.JoinExtractOrFromCaseWhen; import org.apache.doris.nereids.rules.rewrite.LimitAggToTopNAgg; import org.apache.doris.nereids.rules.rewrite.LimitSortToTopN; import org.apache.doris.nereids.rules.rewrite.LogicalResultSinkToShortCircuitPointQuery; @@ -552,6 +553,7 @@ public class Rewriter extends AbstractBatchJobExecutor { new ReorderJoin(), new PushFilterInsideJoin(), new FindHashConditionForJoin(), + new JoinExtractOrFromCaseWhen(), new ConvertInnerOrCrossJoin(), new EliminateNullAwareLeftAntiJoin() ), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index b9e6313d65bdf5..a691009be415d4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -261,6 +261,7 @@ public enum RuleType { REWRITE_PARTITION_TOPN_EXPRESSION(RuleTypeClass.REWRITE), REWRITE_QUALIFY_EXPRESSION(RuleTypeClass.REWRITE), REWRITE_TOPN_EXPRESSION(RuleTypeClass.REWRITE), + JOIN_EXTRACT_OR_FROM_CASE_WHEN(RuleTypeClass.REWRITE), EXTRACT_FILTER_FROM_JOIN(RuleTypeClass.REWRITE), REORDER_JOIN(RuleTypeClass.REWRITE), INIT_JOIN_ORDER(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java new file mode 100644 index 00000000000000..771d27a566f32d --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java @@ -0,0 +1,316 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.EqualPredicate; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Or; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.JoinUtils; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.Set; + +/** + * Extract case when branches to OR expressions for join conditions. + * Latter can help to generate more join conditions. + * + * 1. extract conditions for one side only, latter can push down the one side condition: + * + * t1 join t2 on not (case when t1.a = 1 then t2.a else t2.b) + t2.b + t2.c > 10) + * => + * t1 join t2 on not (case when t1.a = 1 then t2.a else t2.b end) + t2.b + t2.c > 10) + * AND (not (t2.a + t2.b + t2.c > 10) or not (t2.b + t2.b + t2.c > 10)) + * + * + * 2. extract or expansion hash conditions for both table sides: + * the or expansion hash condition is OR expression and each disjunction need to be equal predicate, + * and one side contains only left side slots, another side contains only right side slots. + * + * t1 join t2 on (case when t1.a = 1 then t2.a else t2.b end) = t1.a + t1.b + * => + * t1 join t2 on (case when t1.a = 1 then t2.a else t2.b end) = t1.a + t1.b + * AND (t2.a = t1.a + t1.b or t2.b = t1.a + t1.b) + + * Notice we don't extract more than one case when like expressions. + * because it may generate expressions with combinatorial explosion. + * + * (((case c1 then p1 else p2 end) + (case when d1 then q1 else q2 end))) + a > 10 + * => (p1 + q1 + a > 10) + * or (p1 + q2 + a > 10) + * or (p2 + q1 + a > 10) + * or (p2 + q2 + a > 10) + * + * so we only extract at most one case when like expression for each condition. + */ +public class JoinExtractOrFromCaseWhen implements RewriteRuleFactory { + + @Override + public List buildRules() { + return ImmutableList.of(logicalJoin() + .when(this::needRewrite) + .then(this::rewrite) + .toRule(RuleType.JOIN_EXTRACT_OR_FROM_CASE_WHEN)); + } + + private boolean needRewrite(LogicalJoin join) { + Set leftSlots = join.left().getOutputSet(); + Set rightSlots = join.right().getOutputSet(); + for (Expression expr : join.getOtherJoinConjuncts()) { + if (isConditionNeedRewrite(expr, leftSlots, rightSlots)) { + return true; + } + } + return false; + } + + // 1. expr contains slots from both sides; + private boolean isConditionNeedRewrite(Expression expr, Set leftSlots, Set rightSlots) { + return getExtractChildIndexAndOtherChildSlotFromLeft(expr, leftSlots, rightSlots).isPresent(); + } + + private Plan rewrite(LogicalJoin join) { + Set newOtherConditions = Sets.newLinkedHashSetWithExpectedSize(join.getOtherJoinConjuncts().size()); + newOtherConditions.addAll(join.getOtherJoinConjuncts()); + int oldCondSize = newOtherConditions.size(); + boolean extractHashCondition = OrExpansion.INSTANCE.needRewriteJoin(join); + List orExpandConds = Lists.newArrayList(); + for (Expression expr : join.getOtherJoinConjuncts()) { + tryAddOrExpansionHashCondition(orExpandConds, expr, join); + } + for (Expression expr : join.getOtherJoinConjuncts()) { + extractExpression(join, expr, extractHashCondition, newOtherConditions, orExpandConds); + } + if (!orExpandConds.isEmpty()) { + newOtherConditions.addAll(orExpandConds); + } + if (newOtherConditions.size() == oldCondSize) { + return join; + } else { + return join.withJoinConjuncts(join.getHashJoinConjuncts(), ImmutableList.copyOf(newOtherConditions), + join.getJoinReorderContext()); + } + } + + private void extractExpression(LogicalJoin join, Expression expr, + boolean extractHashCondition, Set conditions, List orExpandConds) { + Set leftSlots = join.left().getOutputSet(); + Set rightSlots = join.right().getOutputSet(); + Optional> extractOpt + = getExtractChildIndexAndOtherChildSlotFromLeft(expr, leftSlots, rightSlots); + if (!extractOpt.isPresent()) { + return; + } + + int extractChildIndex = extractOpt.get().first; + Boolean otherChildrenFromLeft = extractOpt.get().second; + if (otherChildrenFromLeft == null) { + doExtractExpression(expr, extractChildIndex, true, leftSlots, rightSlots) + .ifPresent(conditions::add); + doExtractExpression(expr, extractChildIndex, false, leftSlots, rightSlots) + .ifPresent(conditions::add); + } else { + doExtractExpression(expr, extractChildIndex, otherChildrenFromLeft, leftSlots, rightSlots) + .ifPresent(conditions::add); + if (expr instanceof EqualPredicate && extractHashCondition) { + doExtractExpression(expr, extractChildIndex, !otherChildrenFromLeft, leftSlots, rightSlots) + .ifPresent(cond -> tryAddOrExpansionHashCondition(orExpandConds, cond, join)); + } + } + } + + // Or Expansion only use one condition, so we keep the one with least disjunctions. + private void tryAddOrExpansionHashCondition(List orExpandConds, + Expression condition, LogicalJoin join) { + // Or Expansion only works for all the disjunctions are equal predicates + if (!JoinUtils.extractExpressionForHashTable( + join.left().getOutput(), join.right().getOutput(), ExpressionUtils.extractDisjunction(condition) + ).second.isEmpty()) { + return; + } + + if (orExpandConds.isEmpty()) { + orExpandConds.add(condition); + } else { + int childNum = condition instanceof Or ? condition.children().size() : 1; + int otherChildNum = orExpandConds.get(0) instanceof Or ? orExpandConds.get(0).children().size() : 1; + if (childNum < otherChildNum) { + orExpandConds.clear(); + orExpandConds.add(condition); + } + } + } + + // one child contains both side slots, other children contains only one side slots. + private Optional> getExtractChildIndexAndOtherChildSlotFromLeft(Expression expr, + Set leftSlots, Set rightSlots) { + if (expr.containsUniqueFunction()) { + return Optional.empty(); + } + int extractChildIndex = -1; + Boolean otherChildSlotFromLeft = null; + for (int i = 0; i < expr.children().size(); i++) { + Expression child = expr.child(i); + Set childSlots = child.getInputSlots(); + if (childSlots.isEmpty()) { + continue; + } + boolean containsLeft = !Collections.disjoint(childSlots, leftSlots); + boolean containsRight = !Collections.disjoint(childSlots, rightSlots); + if (containsLeft && containsRight) { + if (extractChildIndex != -1 || !ExpressionUtils.containsCaseWhenLikeType(child)) { + // more than one child contains both side slots + return Optional.empty(); + } + extractChildIndex = i; + } else if (containsLeft) { + if (otherChildSlotFromLeft == null) { + otherChildSlotFromLeft = true; + } else if (!otherChildSlotFromLeft) { + // one child from left, another child from right + return Optional.empty(); + } + } else if (containsRight) { + if (otherChildSlotFromLeft == null) { + otherChildSlotFromLeft = false; + } else if (otherChildSlotFromLeft) { + // one child from left, another child from right + return Optional.empty(); + } + } else { + // should not be here + return Optional.empty(); + } + } + + if (extractChildIndex == -1) { + return Optional.empty(); + } + + return Optional.of(Pair.of(extractChildIndex, otherChildSlotFromLeft)); + } + + private Optional doExtractExpression(Expression expr, int extractChildIndex, boolean childSlotFromLeft, + Set leftSlots, Set rightSlots) { + Expression target = expr.child(extractChildIndex); + Optional> expandTargetOpt = tryExtractCaseWhen( + target, childSlotFromLeft, leftSlots, rightSlots); + if (!expandTargetOpt.isPresent()) { + return Optional.empty(); + } + + List expandTargetExpressions = expandTargetOpt.get(); + if (expandTargetExpressions.size() <= 1) { + return Optional.empty(); + } + + List newChildren = Lists.newArrayList(expr.children()); + List disjuncts = Lists.newArrayListWithExpectedSize(expandTargetExpressions.size()); + for (Expression expandTargetExpr : expandTargetExpressions) { + newChildren.set(extractChildIndex, expandTargetExpr); + disjuncts.add(expr.withChildren(newChildren)); + } + + Expression result = ExpressionUtils.or(disjuncts); + if (result.getInputSlots().isEmpty()) { + return Optional.empty(); + } + + return Optional.of(ExpressionUtils.or(result)); + } + + private Optional> tryExtractCaseWhen(Expression expr, boolean childSlotFromLeft, + Set leftSlots, Set rightSlots) { + if (isSlotsEmptyOrFrom(expr, childSlotFromLeft, leftSlots, rightSlots)) { + return Optional.of(ImmutableList.of(expr)); + } + + Optional> caseWhenLikeResults = ExpressionUtils.getCaseWhenLikeBranchResults(expr); + if (caseWhenLikeResults.isPresent()) { + for (Expression branchResult : caseWhenLikeResults.get()) { + if (!isSlotsEmptyOrFrom(branchResult, childSlotFromLeft, leftSlots, rightSlots)) { + return Optional.empty(); + } + } + return caseWhenLikeResults; + } + + if (!ExpressionUtils.containsCaseWhenLikeType(expr)) { + return Optional.empty(); + } + + int expandChildIndex = -1; + List expandChildExpressions = null; + List newChildren = Lists.newArrayListWithExpectedSize(expr.children().size()); + for (int i = 0; i < expr.children().size(); i++) { + Expression child = expr.child(i); + Optional> childExtractedOpt = tryExtractCaseWhen( + child, childSlotFromLeft, leftSlots, rightSlots); + if (!childExtractedOpt.isPresent()) { + return Optional.empty(); + } + List childExtracted = childExtractedOpt.get(); + if (childExtracted.size() == 1) { + Expression newChild = childExtracted.get(0); + newChildren.add(newChild); + } else { + // more than one child to expand + if (expandChildIndex != -1) { + return Optional.empty(); + } + expandChildIndex = i; + expandChildExpressions = childExtracted; + // will replace the child later, add a placeholder first + newChildren.add(child); + } + } + if (expandChildIndex == -1) { + return Optional.empty(); + } + List resultExpressions = Lists.newArrayListWithExpectedSize(expandChildExpressions.size()); + for (Expression expandChildExpr : expandChildExpressions) { + newChildren.set(expandChildIndex, expandChildExpr); + Expression newExpr = expr.withChildren(newChildren); + resultExpressions.add(newExpr); + } + return Optional.of(resultExpressions); + } + + private boolean isSlotsEmptyOrFrom(Expression expr, boolean slotFromLeft, + Set leftSlots, Set rightSlots) { + Set exprSlots = expr.getInputSlots(); + if (slotFromLeft) { + return Collections.disjoint(exprSlots, rightSlots); + } else { + return Collections.disjoint(exprSlots, leftSlots); + } + } + +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java index 4cb07fd44fd89b..1b01b6e8583c98 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java @@ -121,12 +121,10 @@ public Plan visitLogicalCTEAnchor( @Override public Plan visitLogicalJoin(LogicalJoin join, OrExpandsionContext ctx) { join = (LogicalJoin) this.visit(join, ctx); - if (join.isMarkJoin() || !JoinUtils.shouldNestedLoopJoin(join)) { - return join; - } - if (!supportJoinType.contains(join.getJoinType())) { + if (!needRewriteJoin(join)) { return join; } + Preconditions.checkArgument(join.getHashJoinConjuncts().isEmpty(), "Only Expansion nest loop join without hashCond"); @@ -207,6 +205,18 @@ public Plan visitLogicalJoin(LogicalJoin join, O return null; } + /** + * check whether need to rewrite the join + * @param join + * @return + */ + public boolean needRewriteJoin(LogicalJoin join) { + if (join.isMarkJoin() || !JoinUtils.shouldNestedLoopJoin(join)) { + return false; + } + return supportJoinType.contains(join.getJoinType()); + } + private Map constructReplaceMap(LogicalCTEConsumer leftConsumer, Map leftCloneToLeft, LogicalCTEConsumer rightConsumer, Map rightCloneToRight) { Map replaced = new HashMap<>(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 0efa98ddb0b770..12c68980224e74 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -37,6 +37,7 @@ import org.apache.doris.nereids.trees.TreeNode; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.And; +import org.apache.doris.nereids.trees.expressions.CaseWhen; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; @@ -51,11 +52,15 @@ import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.WhenClause; import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; import org.apache.doris.nereids.trees.expressions.functions.agg.Min; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.functions.scalar.NullIf; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl; import org.apache.doris.nereids.trees.expressions.functions.scalar.UniqueFunction; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral; @@ -1220,6 +1225,42 @@ public static String slotListShapeInfo(List materializedSlots) { return shapeBuilder.toString(); } + /** + * check whether the expression contains CaseWhen like type + */ + public static boolean containsCaseWhenLikeType(Expression expression) { + return expression.containsType(CaseWhen.class, If.class, NullIf.class, Nvl.class); + } + + /** + * get the results of each branch in CaseWhen like expression + * @param expression + * @return + */ + public static Optional> getCaseWhenLikeBranchResults(Expression expression) { + if (expression instanceof CaseWhen) { + CaseWhen caseWhen = (CaseWhen) expression; + ImmutableList.Builder builder + = ImmutableList.builderWithExpectedSize(caseWhen.getWhenClauses().size() + 1); + for (WhenClause whenClause : caseWhen.getWhenClauses()) { + builder.add(whenClause.getResult()); + } + builder.add(caseWhen.getDefaultValue().orElse(new NullLiteral(caseWhen.getDataType()))); + return Optional.of(builder.build()); + } else if (expression instanceof If) { + If ifExpr = (If) expression; + return Optional.of(ImmutableList.of(ifExpr.getTrueValue(), ifExpr.getFalseValue())); + } else if (expression instanceof NullIf) { + NullIf nullIf = (NullIf) expression; + return Optional.of(ImmutableList.of(new NullLiteral(nullIf.getDataType()), nullIf.left())); + } else if (expression instanceof Nvl) { + Nvl nvl = (Nvl) expression; + return Optional.of(ImmutableList.of(nvl.left(), nvl.right())); + } else { + return Optional.empty(); + } + } + /** * has aggregate function, exclude the window function */ From 9e7291d955770974dc0a96a391e913d27c049284 Mon Sep 17 00:00:00 2001 From: yujun Date: Thu, 27 Nov 2025 08:55:13 +0800 Subject: [PATCH 02/10] add test --- .../join_extract_or_from_case_when.out | 215 ++++++++++++++++++ .../join_extract_or_from_case_when.groovy | 153 +++++++++++++ 2 files changed, 368 insertions(+) create mode 100644 regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out create mode 100644 regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy diff --git a/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out b/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out new file mode 100644 index 00000000000000..ae85313e475050 --- /dev/null +++ b/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out @@ -0,0 +1,215 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !case_when_one_side_1 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN]( not (CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) END = 95)) +------PhysicalProject[(a + b) AS `(a + b)`, t1.a] +--------filter(OR[( not (a = 95)),( not ((a + b) = 95))]) +----------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +------PhysicalProject[(x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] + +-- !case_when_one_side_2 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN]( not (CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) END = ((a - b) - 5))) +------PhysicalProject[((a - b) - 5) AS `((a - b) - 5)`, (a + b) AS `(a + b)`, t1.a] +--------filter(OR[( not (a = ((a - b) - 5))),( not ((a + b) = ((a - b) - 5)))]) +----------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +------PhysicalProject[(x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] + +-- !case_when_one_side_3 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN]( not (CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) ELSE b END = 95)) +------PhysicalProject[(a + b) AS `(a + b)`, t1.a, t1.b] +--------filter(OR[( not (a = 95)),( not ((a + b) = 95)),( not (b = 95))]) +----------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +------PhysicalProject[(x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] + +-- !case_when_one_side_4 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN]( not (CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) ELSE b END = ((a - b) - 5))) +------PhysicalProject[((a - b) - 5) AS `((a - b) - 5)`, (a + b) AS `(a + b)`, t1.a, t1.b] +--------filter(OR[( not (a = ((a - b) - 5))),( not ((a + b) = ((a - b) - 5))),( not (b = ((a - b) - 5)))]) +----------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +------PhysicalProject[(x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] + +-- !case_when_one_side_5 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN]((CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) ELSE b END + random(1, 100)) = (a - b)) +------PhysicalProject[(a + b) AS `(a + b)`, (a - b) AS `(a - b)`, t1.a, t1.b] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +------PhysicalProject[(x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] + +-- !case_when_one_side_6 -- +PhysicalResultSink +--NestedLoopJoin[CROSS_JOIN] +----PhysicalProject[t1.a] +------filter(( not (CASE WHEN (a > 10) THEN a WHEN (b > 10) THEN (a + b) END = 95))) +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +----PhysicalProject[t2.x] +------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] + +-- !case_when_two_side_1 -- +PhysicalCteAnchor ( cteId=CTEId#0 ) +--PhysicalCteProducer ( cteId=CTEId#0 ) +----PhysicalProject[(a + b) AS `(a + b)`, t1.a, t1.b] +------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +--PhysicalCteAnchor ( cteId=CTEId#1 ) +----PhysicalCteProducer ( cteId=CTEId#1 ) +------PhysicalProject[((x + y) + -4) AS `((x + y) + -4)`, (x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] +----PhysicalResultSink +------PhysicalUnion +--------PhysicalProject[.a, .x] +----------hashJoin[INNER_JOIN] hashCondition=((.a = .((x + y) + -4))) otherCondition=((CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) ELSE b END = .((x + y) + -4))) +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------PhysicalCteConsumer ( cteId=CTEId#1 ) +--------PhysicalProject[.a, .x] +----------hashJoin[INNER_JOIN] hashCondition=((.(a + b) = .((x + y) + -4))) otherCondition=((CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) ELSE b END = .((x + y) + -4)) and OR[( not (a = ((x + y) + -4))),(a = ((x + y) + -4)) IS NULL]) +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------PhysicalCteConsumer ( cteId=CTEId#1 ) +--------PhysicalProject[.a, .x] +----------hashJoin[INNER_JOIN] hashCondition=((.b = .((x + y) + -4))) otherCondition=((CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) ELSE b END = .((x + y) + -4)) and OR[( not ((a + b) = ((x + y) + -4))),((a + b) = ((x + y) + -4)) IS NULL] and OR[( not (a = ((x + y) + -4))),(a = ((x + y) + -4)) IS NULL]) +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------PhysicalCteConsumer ( cteId=CTEId#1 ) + +-- !case_when_one_side_2 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN](CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) END = ((x + y) + -4)) +------PhysicalProject[(a + b) AS `(a + b)`, t1.a] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +------PhysicalProject[((x + y) + -4) AS `((x + y) + -4)`, (x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] + +-- !case_when_two_side_3 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN]((CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) ELSE b END + random(1, 10)) = ((x + y) + 1)) +------PhysicalProject[(a + b) AS `(a + b)`, t1.a, t1.b] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +------PhysicalProject[((x + y) + 1) AS `((x + y) + 1)`, (x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] + +-- !case_when_two_side_4 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN](CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) ELSE (x + b) END = ((x + y) + 1)) +------PhysicalProject[(a + b) AS `(a + b)`, t1.a, t1.b] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +------PhysicalProject[((x + y) + 1) AS `((x + y) + 1)`, (x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] + +-- !case_when_two_side_5 -- +PhysicalCteAnchor ( cteId=CTEId#0 ) +--PhysicalCteProducer ( cteId=CTEId#0 ) +----PhysicalProject[(a + b) AS `(a + b)`, t1.a, t1.b] +------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +--PhysicalCteAnchor ( cteId=CTEId#1 ) +----PhysicalCteProducer ( cteId=CTEId#1 ) +------PhysicalProject[((x + y) + -4) AS `((x + y) + -4)`, (x - y) AS `(x - y)`, (x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] +----PhysicalResultSink +------PhysicalUnion +--------PhysicalProject[.a, .x] +----------hashJoin[INNER_JOIN] hashCondition=((.a = .(x - y))) otherCondition=((CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) ELSE b END = .((x + y) + -4)) and (if((a > x), a, b) = .(x - y))) +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------PhysicalCteConsumer ( cteId=CTEId#1 ) +--------PhysicalProject[.a, .x] +----------hashJoin[INNER_JOIN] hashCondition=((.b = .(x - y))) otherCondition=((CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) ELSE b END = .((x + y) + -4)) and (if((a > x), a, b) = .(x - y)) and OR[( not (a = (x - y))),(a = (x - y)) IS NULL]) +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------PhysicalCteConsumer ( cteId=CTEId#1 ) + +-- !case_when_two_side_6 -- +PhysicalCteAnchor ( cteId=CTEId#0 ) +--PhysicalCteProducer ( cteId=CTEId#0 ) +----PhysicalProject[a IS NULL AS `a IS NULL`, coalesce(a, b) AS `coalesce(a, b)`, t1.a] +------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +--PhysicalCteAnchor ( cteId=CTEId#1 ) +----PhysicalCteProducer ( cteId=CTEId#1 ) +------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] +----PhysicalResultSink +------PhysicalUnion +--------PhysicalProject[.a, .x] +----------hashJoin[INNER_JOIN] hashCondition=((.x = .coalesce(a, b))) otherCondition=((if(a IS NULL, x, y) = .coalesce(a, b))) +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------PhysicalCteConsumer ( cteId=CTEId#1 ) +--------PhysicalProject[.a, .x] +----------hashJoin[INNER_JOIN] hashCondition=((.y = .coalesce(a, b))) otherCondition=((if(a IS NULL, x, y) = .coalesce(a, b)) and OR[( not (x = coalesce(a, b))),(x = coalesce(a, b)) IS NULL]) +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------PhysicalCteConsumer ( cteId=CTEId#1 ) + +-- !if_one_side_1 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN](if((a > x), a, b) = (a + b)) +------PhysicalProject[(a + b) AS `(a + b)`, t1.a, t1.b] +--------filter(OR[(t1.a = (t1.a + t1.b)),(t1.b = (t1.a + t1.b))]) +----------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +------PhysicalProject[t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] + +-- !if_two_side_1 -- +PhysicalCteAnchor ( cteId=CTEId#0 ) +--PhysicalCteProducer ( cteId=CTEId#0 ) +----PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +--PhysicalCteAnchor ( cteId=CTEId#1 ) +----PhysicalCteProducer ( cteId=CTEId#1 ) +------PhysicalProject[(x + y) AS `(x + y)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] +----PhysicalResultSink +------PhysicalUnion +--------PhysicalProject[.a, .x] +----------hashJoin[INNER_JOIN] hashCondition=((.a = .(x + y))) otherCondition=((if((a > x), a, b) = .(x + y))) +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------PhysicalCteConsumer ( cteId=CTEId#1 ) +--------PhysicalProject[.a, .x] +----------hashJoin[INNER_JOIN] hashCondition=((.b = .(x + y))) otherCondition=((if((a > x), a, b) = .(x + y)) and OR[( not (a = (x + y))),(a = (x + y)) IS NULL]) +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------PhysicalCteConsumer ( cteId=CTEId#1 ) + +-- !ifnull_one_side_1 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN](ifnull(a, x) = (a + b)) +------PhysicalProject[(a + b) AS `(a + b)`, t1.a] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +------PhysicalProject[t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] + +-- !ifnull_two_side_1 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN](ifnull(a, x) = (x + y)) +------PhysicalProject[t1.a] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +------PhysicalProject[(x + y) AS `(x + y)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] + +-- !nullif_one_side_1 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN](nullif(a, x) = (a + b)) +------PhysicalProject[(a + b) AS `(a + b)`, t1.a] +--------filter((t1.a = (t1.a + t1.b))) +----------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +------PhysicalProject[t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] + +-- !nullif_two_side_1 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN](nullif(a, x) = (x + y)) +------PhysicalProject[t1.a] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +------PhysicalProject[(x + y) AS `(x + y)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] + diff --git a/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy b/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy new file mode 100644 index 00000000000000..514af9b0152b4d --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy @@ -0,0 +1,153 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS +// OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite('join_extract_or_from_case_when') { + sql """ + SET disable_nereids_rules='PRUNE_EMPTY_PARTITION'; + SET detail_shape_nodes='PhysicalProject,PhysicalHashAggregate'; + SET ignore_shape_nodes='PhysicalDistribute'; + SET runtime_filter_type=2; + SET disable_join_reorder=true; + DROP TABLE IF EXISTS tbl_join_extract_or_from_case_when_1 FORCE; + DROP TABLE IF EXISTS tbl_join_extract_or_from_case_when_2 FORCE; + CREATE TABLE tbl_join_extract_or_from_case_when_1 (a bigint, b bigint) properties('replication_num' = '1'); + CREATE TABLE tbl_join_extract_or_from_case_when_2 (x bigint, y bigint) properties('replication_num' = '1'); + """ + + qt_case_when_one_side_1 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on not((case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b end) + 5 = 100); + """ + + qt_case_when_one_side_2 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on not((case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b end) + 5 = t1.a - t1.b); + """ + + qt_case_when_one_side_3 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on not((case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b else t1.b end) + 5 = 100); + """ + + qt_case_when_one_side_4 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on not((case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b else t1.b end) + 5 = t1.a - t1.b); + """ + + // random will not extract + qt_case_when_one_side_5 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b else t1.b end + random(1, 100) = t1.a - t1.b; + """ + + // the origin expression not contains two sides will not extract + qt_case_when_one_side_6 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on not((case when t1.a > 10 then t1.a when t1.b > 10 then t1.a + t1.b end) + 5 = 100); + """ + + qt_case_when_two_side_1 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on (case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b else t1.b end) + 5 = t2.x + t2.y + 1; + """ + + // any case when branch no contains slot from other side will not extract + qt_case_when_one_side_2 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on (case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b end) + 5 = t2.x + t2.y + 1; + """ + + // random will not extract + qt_case_when_two_side_3 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b else t1.b end + random(1, 10) = t2.x + t2.y + 1; + """ + + // any case when branch contains slot from both sides will not extract + qt_case_when_two_side_4 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b else t2.x + t1.b end = t2.x + t2.y + 1; + """ + + // extract the least OR EXPANSION hash condition + qt_case_when_two_side_5 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on if(t1.a > t2.x, t1.a, t1.b) = t2.x - t2.y and (case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b else t1.b end) + 5 = t2.x + t2.y + 1; + """ + + qt_case_when_two_side_6 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on case when t1.a is null then t2.x else t2.y end = COALESCE(t1.a, t1.b); + """ + + qt_if_one_side_1 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on if(t1.a > t2.x, t1.a, t1.b) = t1.a + t1.b; + """ + + qt_if_two_side_1 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on if(t1.a > t2.x, t1.a, t1.b) = t2.x + t2.y; + """ + + // in fact, IFNULL will nerver rewrite becase the rule require the + // case when expression contains both side slots and all the branch results contains only one side slots. + qt_ifnull_one_side_1 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on ifnull(t1.a, t2.x) = t1.a + t1.b; + """ + + // in fact, IFNULL will nerver rewrite becase the rule require the + // case when expression contains both side slots and all the branch results contains only one side slots. + qt_ifnull_two_side_1 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on ifnull(t1.a, t2.x) = t2.x + t2.y; + """ + + qt_nullif_one_side_1 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on nullif(t1.a, t2.x) = t1.a + t1.b; + """ + + qt_nullif_two_side_1 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on nullif(t1.a, t2.x) = t2.x + t2.y; + """ + + sql """ + DROP TABLE IF EXISTS tbl_join_extract_or_from_case_when_1 FORCE; + DROP TABLE IF EXISTS tbl_join_extract_or_from_case_when_2 FORCE; + """ +} From 711a66b92fd2bcb7d95eb85157db7ece73d0677d Mon Sep 17 00:00:00 2001 From: yujun Date: Thu, 27 Nov 2025 09:07:34 +0800 Subject: [PATCH 03/10] add annotation --- .../doris/nereids/jobs/executor/Rewriter.java | 7 ++++--- .../rules/rewrite/JoinExtractOrFromCaseWhen.java | 13 ++++++------- .../doris/nereids/rules/rewrite/OrExpansion.java | 4 +--- .../apache/doris/nereids/util/ExpressionUtils.java | 2 -- 4 files changed, 11 insertions(+), 15 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 41d7843833ba25..693ee3a2799918 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -324,7 +324,8 @@ public class Rewriter extends AbstractBatchJobExecutor { new PushFilterInsideJoin(), new FindHashConditionForJoin(), new ConvertInnerOrCrossJoin(), - new EliminateNullAwareLeftAntiJoin() + new EliminateNullAwareLeftAntiJoin(), + new JoinExtractOrFromCaseWhen() ), // push down SEMI Join bottomUp( @@ -553,9 +554,9 @@ public class Rewriter extends AbstractBatchJobExecutor { new ReorderJoin(), new PushFilterInsideJoin(), new FindHashConditionForJoin(), - new JoinExtractOrFromCaseWhen(), new ConvertInnerOrCrossJoin(), - new EliminateNullAwareLeftAntiJoin() + new EliminateNullAwareLeftAntiJoin(), + new JoinExtractOrFromCaseWhen() ), // push down SEMI Join bottomUp( diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java index 771d27a566f32d..85beed6051d4e2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java @@ -39,10 +39,9 @@ import java.util.Set; /** - * Extract case when branches to OR expressions for join conditions. - * Latter can help to generate more join conditions. + * Join extract OR expression from case when / if / nullif expressions. * - * 1. extract conditions for one side only, latter can push down the one side condition: + * 1. extract conditions for one side, latter can push down the one side condition: * * t1 join t2 on not (case when t1.a = 1 then t2.a else t2.b) + t2.b + t2.c > 10) * => @@ -50,9 +49,8 @@ * AND (not (t2.a + t2.b + t2.c > 10) or not (t2.b + t2.b + t2.c > 10)) * * - * 2. extract or expansion hash conditions for both table sides: - * the or expansion hash condition is OR expression and each disjunction need to be equal predicate, - * and one side contains only left side slots, another side contains only right side slots. + * 2. extract conditions for both sides, which use for OR EXPANSION rule: + * the OR EXPANSION is an OR expression which all its disjuncts are hash join conditions. * * t1 join t2 on (case when t1.a = 1 then t2.a else t2.b end) = t1.a + t1.b * => @@ -63,7 +61,8 @@ * because it may generate expressions with combinatorial explosion. * * (((case c1 then p1 else p2 end) + (case when d1 then q1 else q2 end))) + a > 10 - * => (p1 + q1 + a > 10) + * => + * (p1 + q1 + a > 10) * or (p1 + q2 + a > 10) * or (p2 + q1 + a > 10) * or (p2 + q2 + a > 10) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java index 1b01b6e8583c98..cf06ee608c14dd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java @@ -206,9 +206,7 @@ public Plan visitLogicalJoin(LogicalJoin join, O } /** - * check whether need to rewrite the join - * @param join - * @return + * check whether it need to rewrite the join */ public boolean needRewriteJoin(LogicalJoin join) { if (join.isMarkJoin() || !JoinUtils.shouldNestedLoopJoin(join)) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 12c68980224e74..8d74099b325f32 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -1234,8 +1234,6 @@ public static boolean containsCaseWhenLikeType(Expression expression) { /** * get the results of each branch in CaseWhen like expression - * @param expression - * @return */ public static Optional> getCaseWhenLikeBranchResults(Expression expression) { if (expression instanceof CaseWhen) { From 965b16ee75b04188a2380e5408c648e9599671bd Mon Sep 17 00:00:00 2001 From: yujun Date: Thu, 27 Nov 2025 10:38:01 +0800 Subject: [PATCH 04/10] add test --- .../join_extract_or_from_case_when.out | 19 +++++++++++++++++++ .../join_extract_or_from_case_when.groovy | 15 +++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out b/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out index ae85313e475050..302d7c3d20b2cd 100644 --- a/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out +++ b/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out @@ -57,6 +57,25 @@ PhysicalResultSink ----PhysicalProject[t2.x] ------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] +-- !case_when_one_side_7 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN]((CASE WHEN (x > 0) THEN a WHEN (x < 10) THEN (a + 1) END + CASE WHEN (x > 1) THEN (a + 1) WHEN (x < 10) THEN (a + 10) END) > (a + b)) +------PhysicalProject[(a + 1) AS `(a + 1)`, (a + 10) AS `(a + 10)`, (a + b) AS `(a + b)`, t1.a] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +------PhysicalProject[(x < 10) AS `(x < 10)`, (x > 0) AS `(x > 0)`, (x > 1) AS `(x > 1)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] + +-- !case_when_one_side_8 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN]((CASE WHEN (x > 0) THEN a WHEN (x < 10) THEN (a + 1) END + CASE WHEN (a > 1) THEN (a + 1) WHEN (a < 10) THEN (a + 10) END) > (a + b)) +------PhysicalProject[(a + 1) AS `(a + 1)`, (a + b) AS `(a + b)`, CASE WHEN (a > 1) THEN (a + 1) WHEN (a < 10) THEN (a + 10) END AS `CASE WHEN (a > 1) THEN (a + 1) WHEN (a < 10) THEN (a + 10) END`, t1.a] +--------filter(OR[((t1.a + CASE WHEN (a > 1) THEN (a + 1) WHEN (a < 10) THEN (a + 10) END) > (t1.a + t1.b)),((t1.a + CASE WHEN (a > 1) THEN (a + 1) WHEN (a < 10) THEN (a + 10) END) > ((t1.a + t1.b) - 1))]) +----------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +------PhysicalProject[(x < 10) AS `(x < 10)`, (x > 0) AS `(x > 0)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] + -- !case_when_two_side_1 -- PhysicalCteAnchor ( cteId=CTEId#0 ) --PhysicalCteProducer ( cteId=CTEId#0 ) diff --git a/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy b/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy index 514af9b0152b4d..d233342b010d63 100644 --- a/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy +++ b/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy @@ -66,6 +66,21 @@ suite('join_extract_or_from_case_when') { on not((case when t1.a > 10 then t1.a when t1.b > 10 then t1.a + t1.b end) + 5 = 100); """ + // two case when branch contains both side slots will not rewrite + qt_case_when_one_side_7 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on case when t2.x > 0 then t1.a when t2.x < 10 then t1.a + 1 end + + case when t2.x > 1 then t1.a + 1 when t2.x < 10 then t1.a + 10 end > t1.a + t1.b + """ + + qt_case_when_one_side_8 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on case when t2.x > 0 then t1.a when t2.x < 10 then t1.a + 1 end + + case when t1.a > 1 then t1.a + 1 when t1.a < 10 then t1.a + 10 end > t1.a + t1.b + """ + qt_case_when_two_side_1 """explain shape plan select t1.a, t2.x from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 From 2874d65c67d8835fa88bd823d97c7c1baad7fff6 Mon Sep 17 00:00:00 2001 From: yujun Date: Thu, 27 Nov 2025 15:55:24 +0800 Subject: [PATCH 05/10] fix test --- .../join_extract_or_from_case_when.groovy | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy b/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy index d233342b010d63..1641f39e1f308a 100644 --- a/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy +++ b/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy @@ -20,7 +20,7 @@ suite('join_extract_or_from_case_when') { SET disable_nereids_rules='PRUNE_EMPTY_PARTITION'; SET detail_shape_nodes='PhysicalProject,PhysicalHashAggregate'; SET ignore_shape_nodes='PhysicalDistribute'; - SET runtime_filter_type=2; + SET runtime_filter_mode=OFF; SET disable_join_reorder=true; DROP TABLE IF EXISTS tbl_join_extract_or_from_case_when_1 FORCE; DROP TABLE IF EXISTS tbl_join_extract_or_from_case_when_2 FORCE; From 4baab5204949d168830daad85411cffaa70c7741 Mon Sep 17 00:00:00 2001 From: yujun Date: Thu, 27 Nov 2025 18:15:35 +0800 Subject: [PATCH 06/10] add more for expansion slots --- .../rewrite/JoinExtractOrFromCaseWhen.java | 109 +++++++++++------- .../join_extract_or_from_case_when.out | 36 ++++-- .../join_extract_or_from_case_when.groovy | 10 +- 3 files changed, 105 insertions(+), 50 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java index 85beed6051d4e2..88c31a77399d22 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java @@ -20,9 +20,10 @@ import org.apache.doris.common.Pair; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE; import org.apache.doris.nereids.trees.expressions.EqualPredicate; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; @@ -37,6 +38,7 @@ import java.util.List; import java.util.Optional; import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; /** * Join extract OR expression from case when / if / nullif expressions. @@ -75,7 +77,7 @@ public class JoinExtractOrFromCaseWhen implements RewriteRuleFactory { public List buildRules() { return ImmutableList.of(logicalJoin() .when(this::needRewrite) - .then(this::rewrite) + .thenApply(ctx -> rewrite(ctx.root, new ExpressionRewriteContext(ctx.root, ctx.cascadesContext))) .toRule(RuleType.JOIN_EXTRACT_OR_FROM_CASE_WHEN)); } @@ -95,20 +97,20 @@ private boolean isConditionNeedRewrite(Expression expr, Set leftSlots, Set return getExtractChildIndexAndOtherChildSlotFromLeft(expr, leftSlots, rightSlots).isPresent(); } - private Plan rewrite(LogicalJoin join) { + private Plan rewrite(LogicalJoin join, ExpressionRewriteContext context) { Set newOtherConditions = Sets.newLinkedHashSetWithExpectedSize(join.getOtherJoinConjuncts().size()); newOtherConditions.addAll(join.getOtherJoinConjuncts()); int oldCondSize = newOtherConditions.size(); - boolean extractHashCondition = OrExpansion.INSTANCE.needRewriteJoin(join); - List orExpandConds = Lists.newArrayList(); + boolean extractOrExpansionCondition = OrExpansion.INSTANCE.needRewriteJoin(join); + AtomicReference> leastOrExpandCondRef = new AtomicReference<>(); for (Expression expr : join.getOtherJoinConjuncts()) { - tryAddOrExpansionHashCondition(orExpandConds, expr, join); + tryAddOrExpansionHashCondition(expr, join, context, leastOrExpandCondRef); } for (Expression expr : join.getOtherJoinConjuncts()) { - extractExpression(join, expr, extractHashCondition, newOtherConditions, orExpandConds); + extractExpression(expr, context, join, extractOrExpansionCondition, newOtherConditions, leastOrExpandCondRef); } - if (!orExpandConds.isEmpty()) { - newOtherConditions.addAll(orExpandConds); + if (leastOrExpandCondRef.get() != null) { + newOtherConditions.add(leastOrExpandCondRef.get().first); } if (newOtherConditions.size() == oldCondSize) { return join; @@ -118,8 +120,9 @@ private Plan rewrite(LogicalJoin join) { } } - private void extractExpression(LogicalJoin join, Expression expr, - boolean extractHashCondition, Set conditions, List orExpandConds) { + private void extractExpression(Expression expr, ExpressionRewriteContext context, + LogicalJoin join, boolean extractOrExpansionCondition, + Set conditions, AtomicReference> leastOrExpandCondRef) { Set leftSlots = join.left().getOutputSet(); Set rightSlots = join.right().getOutputSet(); Optional> extractOpt @@ -131,39 +134,54 @@ private void extractExpression(LogicalJoin join, int extractChildIndex = extractOpt.get().first; Boolean otherChildrenFromLeft = extractOpt.get().second; if (otherChildrenFromLeft == null) { + // extract expression for left side doExtractExpression(expr, extractChildIndex, true, leftSlots, rightSlots) .ifPresent(conditions::add); + // extract expression for right side doExtractExpression(expr, extractChildIndex, false, leftSlots, rightSlots) .ifPresent(conditions::add); } else { + // extract expression for one side, all child slots need from the same side doExtractExpression(expr, extractChildIndex, otherChildrenFromLeft, leftSlots, rightSlots) .ifPresent(conditions::add); - if (expr instanceof EqualPredicate && extractHashCondition) { + if (expr instanceof EqualPredicate && extractOrExpansionCondition) { + // extract expression for hash condition only when the expr is equal predicate doExtractExpression(expr, extractChildIndex, !otherChildrenFromLeft, leftSlots, rightSlots) - .ifPresent(cond -> tryAddOrExpansionHashCondition(orExpandConds, cond, join)); + .ifPresent(cond -> tryAddOrExpansionHashCondition(cond, join, context, leastOrExpandCondRef)); } } } // Or Expansion only use one condition, so we keep the one with least disjunctions. - private void tryAddOrExpansionHashCondition(List orExpandConds, - Expression condition, LogicalJoin join) { + private void tryAddOrExpansionHashCondition(Expression condition, LogicalJoin join, + ExpressionRewriteContext context, AtomicReference> leastOrExpandCondRef) { // Or Expansion only works for all the disjunctions are equal predicates - if (!JoinUtils.extractExpressionForHashTable( - join.left().getOutput(), join.right().getOutput(), ExpressionUtils.extractDisjunction(condition) - ).second.isEmpty()) { + List disjunctions = ExpressionUtils.extractDisjunction(condition); + List remainOtherConditions = JoinUtils.extractExpressionForHashTable( + join.left().getOutput(), join.right().getOutput(), disjunctions).second; + int hashCondLen = disjunctions.size() - remainOtherConditions.size(); + // no hash condition extracted, all are other conditions + if (hashCondLen == 0) { return; } - if (orExpandConds.isEmpty()) { - orExpandConds.add(condition); - } else { - int childNum = condition instanceof Or ? condition.children().size() : 1; - int otherChildNum = orExpandConds.get(0) instanceof Or ? orExpandConds.get(0).children().size() : 1; - if (childNum < otherChildNum) { - orExpandConds.clear(); - orExpandConds.add(condition); + for (Expression expr : remainOtherConditions) { + // for case when t1.a > t2.x then t1.a when t1.b > t2.y then t1. b else null end = t2.x + // then will extract E1 = (t1.a = t2.x or t1.b = t2.x or null = t2.x) + // but E1 can not use as OR Expansion condition, because null = t2.x is not a valid hash join condition. + // but after we fold null = t2.x to null, latter expression simplifier can simplify E1 + // to (t1.a = t2.x or t1.b = t2.x), then it becomes a valid OR Expansion condition. + Expression foldExpr = FoldConstantRuleOnFE.evaluate(expr, context); + if (!foldExpr.isLiteral()) { + return; } + // foldExpr should be NULL / TRUE /FALSE, later expression simplifier can handle them. + } + + Pair leastOrExpandCond = leastOrExpandCondRef.get(); + int oldHashCondLen = leastOrExpandCond == null ? -1 : leastOrExpandCond.second; + if (oldHashCondLen == -1 || hashCondLen < oldHashCondLen) { + leastOrExpandCondRef.set(Pair.of(condition, hashCondLen)); } } @@ -216,16 +234,19 @@ private Optional> getExtractChildIndexAndOtherChildSlotFr return Optional.of(Pair.of(extractChildIndex, otherChildSlotFromLeft)); } + // extract case when expression from `expr`'s child at `extractChildIndex`. + // after extraction, all slots in expr's child at `extractChildIndex` + // are from one side indicated by `childSlotFromLeft`. private Optional doExtractExpression(Expression expr, int extractChildIndex, boolean childSlotFromLeft, Set leftSlots, Set rightSlots) { - Expression target = expr.child(extractChildIndex); - Optional> expandTargetOpt = tryExtractCaseWhen( - target, childSlotFromLeft, leftSlots, rightSlots); - if (!expandTargetOpt.isPresent()) { + Expression expandChild = expr.child(extractChildIndex); + Optional> resultOpt = tryExtractCaseWhen( + expandChild, childSlotFromLeft, leftSlots, rightSlots); + if (!resultOpt.isPresent()) { return Optional.empty(); } - List expandTargetExpressions = expandTargetOpt.get(); + List expandTargetExpressions = resultOpt.get(); if (expandTargetExpressions.size() <= 1) { return Optional.empty(); } @@ -245,33 +266,38 @@ private Optional doExtractExpression(Expression expr, int extractChi return Optional.of(ExpressionUtils.or(result)); } - private Optional> tryExtractCaseWhen(Expression expr, boolean childSlotFromLeft, + // try to extract case when like expressions from expr. + // after extraction, all slots in expr are from one side indicated by `slotFromLeft`. + // if `expr`'s all slots are already from `slotFromLeft`, return expr itself, no need handle with its children. + // otherwise will recurse its children to extract case when like expressions. + private Optional> tryExtractCaseWhen(Expression expr, boolean slotFromLeft, Set leftSlots, Set rightSlots) { - if (isSlotsEmptyOrFrom(expr, childSlotFromLeft, leftSlots, rightSlots)) { + if (isAllSlotsFromLeftSide(expr, slotFromLeft, leftSlots, rightSlots)) { return Optional.of(ImmutableList.of(expr)); } + if (!ExpressionUtils.containsCaseWhenLikeType(expr)) { + return Optional.empty(); + } + + // process case when like expression. Optional> caseWhenLikeResults = ExpressionUtils.getCaseWhenLikeBranchResults(expr); if (caseWhenLikeResults.isPresent()) { for (Expression branchResult : caseWhenLikeResults.get()) { - if (!isSlotsEmptyOrFrom(branchResult, childSlotFromLeft, leftSlots, rightSlots)) { + if (!isAllSlotsFromLeftSide(branchResult, slotFromLeft, leftSlots, rightSlots)) { return Optional.empty(); } } return caseWhenLikeResults; } - if (!ExpressionUtils.containsCaseWhenLikeType(expr)) { - return Optional.empty(); - } - int expandChildIndex = -1; List expandChildExpressions = null; List newChildren = Lists.newArrayListWithExpectedSize(expr.children().size()); for (int i = 0; i < expr.children().size(); i++) { Expression child = expr.child(i); Optional> childExtractedOpt = tryExtractCaseWhen( - child, childSlotFromLeft, leftSlots, rightSlots); + child, slotFromLeft, leftSlots, rightSlots); if (!childExtractedOpt.isPresent()) { return Optional.empty(); } @@ -302,12 +328,15 @@ private Optional> tryExtractCaseWhen(Expression expr, boolean c return Optional.of(resultExpressions); } - private boolean isSlotsEmptyOrFrom(Expression expr, boolean slotFromLeft, + // check whether all slots in expr are from one side, allow empty slots. + private boolean isAllSlotsFromLeftSide(Expression expr, boolean slotFromLeft, Set leftSlots, Set rightSlots) { Set exprSlots = expr.getInputSlots(); if (slotFromLeft) { + // no slots from right return Collections.disjoint(exprSlots, rightSlots); } else { + // no slots from left return Collections.disjoint(exprSlots, leftSlots); } } diff --git a/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out b/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out index 302d7c3d20b2cd..21e76576a1fa00 100644 --- a/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out +++ b/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out @@ -100,14 +100,25 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------PhysicalCteConsumer ( cteId=CTEId#0 ) ------------PhysicalCteConsumer ( cteId=CTEId#1 ) --- !case_when_one_side_2 -- -PhysicalResultSink ---PhysicalProject[t1.a, t2.x] -----NestedLoopJoin[INNER_JOIN](CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) END = ((x + y) + -4)) -------PhysicalProject[(a + b) AS `(a + b)`, t1.a] ---------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +-- !case_when_two_side_2 -- +PhysicalCteAnchor ( cteId=CTEId#0 ) +--PhysicalCteProducer ( cteId=CTEId#0 ) +----PhysicalProject[(a + b) AS `(a + b)`, t1.a] +------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +--PhysicalCteAnchor ( cteId=CTEId#1 ) +----PhysicalCteProducer ( cteId=CTEId#1 ) ------PhysicalProject[((x + y) + -4) AS `((x + y) + -4)`, (x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] --------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] +----PhysicalResultSink +------PhysicalUnion +--------PhysicalProject[.a, .x] +----------hashJoin[INNER_JOIN] hashCondition=((.a = .((x + y) + -4))) otherCondition=((CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) END = .((x + y) + -4))) +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------PhysicalCteConsumer ( cteId=CTEId#1 ) +--------PhysicalProject[.a, .x] +----------hashJoin[INNER_JOIN] hashCondition=((.(a + b) = .((x + y) + -4))) otherCondition=((CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) END = .((x + y) + -4)) and OR[( not (a = ((x + y) + -4))),(a = ((x + y) + -4)) IS NULL]) +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------PhysicalCteConsumer ( cteId=CTEId#1 ) -- !case_when_two_side_3 -- PhysicalResultSink @@ -166,6 +177,15 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------PhysicalCteConsumer ( cteId=CTEId#0 ) ------------PhysicalCteConsumer ( cteId=CTEId#1 ) +-- !case_when_two_side_7 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN](CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) ELSE 100 END = ((x + y) + -4)) +------PhysicalProject[(a + b) AS `(a + b)`, t1.a] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +------PhysicalProject[((x + y) + -4) AS `((x + y) + -4)`, (x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] + -- !if_one_side_1 -- PhysicalResultSink --PhysicalProject[t1.a, t2.x] @@ -226,9 +246,9 @@ PhysicalResultSink -- !nullif_two_side_1 -- PhysicalResultSink --PhysicalProject[t1.a, t2.x] -----NestedLoopJoin[INNER_JOIN](nullif(a, x) = (x + y)) +----hashJoin[INNER_JOIN] hashCondition=((t1.a = expr_(x + y))) otherCondition=((nullif(a, x) = (t2.x + t2.y))) ------PhysicalProject[t1.a] --------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] -------PhysicalProject[(x + y) AS `(x + y)`, t2.x] +------PhysicalProject[(x + y) AS `expr_(x + y)`, tbl_join_extract_or_from_case_when_2.x, tbl_join_extract_or_from_case_when_2.y] --------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] diff --git a/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy b/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy index 1641f39e1f308a..72db9487ac6361 100644 --- a/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy +++ b/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy @@ -87,8 +87,7 @@ suite('join_extract_or_from_case_when') { on (case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b else t1.b end) + 5 = t2.x + t2.y + 1; """ - // any case when branch no contains slot from other side will not extract - qt_case_when_one_side_2 """explain shape plan + qt_case_when_two_side_2 """explain shape plan select t1.a, t2.x from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 on (case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b end) + 5 = t2.x + t2.y + 1; @@ -121,6 +120,13 @@ suite('join_extract_or_from_case_when') { on case when t1.a is null then t2.x else t2.y end = COALESCE(t1.a, t1.b); """ + // any case when branch no contains slot from other side will not extract + qt_case_when_two_side_7 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on (case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b else 100 end) + 5 = t2.x + t2.y + 1; + """ + qt_if_one_side_1 """explain shape plan select t1.a, t2.x from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 From 7e6fb335730def57680a2518c37565782fec5ae6 Mon Sep 17 00:00:00 2001 From: yujun Date: Thu, 27 Nov 2025 22:28:49 +0800 Subject: [PATCH 07/10] update --- .../rewrite/JoinExtractOrFromCaseWhen.java | 176 ++++++++++++------ .../join_extract_or_from_case_when.out | 10 + .../join_extract_or_from_case_when.groovy | 7 + 3 files changed, 131 insertions(+), 62 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java index 88c31a77399d22..6ad6e699adec5f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java @@ -73,6 +73,13 @@ */ public class JoinExtractOrFromCaseWhen implements RewriteRuleFactory { + private enum SlotFrom { + FROM_LEFT_SIDE_ONLY, + FROM_RIGHT_SIDE_ONLY, + FROM_BOTH_SIDE, + FROM_NONE, + } + @Override public List buildRules() { return ImmutableList.of(logicalJoin() @@ -94,7 +101,17 @@ private boolean needRewrite(LogicalJoin join) { // 1. expr contains slots from both sides; private boolean isConditionNeedRewrite(Expression expr, Set leftSlots, Set rightSlots) { - return getExtractChildIndexAndOtherChildSlotFromLeft(expr, leftSlots, rightSlots).isPresent(); + if (expr.containsUniqueFunction()) { + return false; + } + Set exprSlots = expr.getInputSlots(); + // all slots are from one side, no need to process it. + if (Collections.disjoint(exprSlots, leftSlots) || Collections.disjoint(exprSlots, rightSlots) + || !ExpressionUtils.containsCaseWhenLikeType(expr)) { + return false; + } + // can not rewrite none case when child + return getNoneCaseWhenChildSlotFrom(expr, leftSlots, rightSlots) != SlotFrom.FROM_BOTH_SIDE; } private Plan rewrite(LogicalJoin join, ExpressionRewriteContext context) { @@ -125,29 +142,49 @@ private void extractExpression(Expression expr, ExpressionRewriteContext context Set conditions, AtomicReference> leastOrExpandCondRef) { Set leftSlots = join.left().getOutputSet(); Set rightSlots = join.right().getOutputSet(); - Optional> extractOpt - = getExtractChildIndexAndOtherChildSlotFromLeft(expr, leftSlots, rightSlots); - if (!extractOpt.isPresent()) { + if (!isConditionNeedRewrite(expr, leftSlots, rightSlots)) { return; } + Boolean otherChildrenFromLeft = null; + switch (getNoneCaseWhenChildSlotFrom(expr, leftSlots, rightSlots)) { + case FROM_BOTH_SIDE: { + return; + } + case FROM_LEFT_SIDE_ONLY: { + otherChildrenFromLeft = true; + break; + } + case FROM_RIGHT_SIDE_ONLY: { + otherChildrenFromLeft = false; + break; + } + } - int extractChildIndex = extractOpt.get().first; - Boolean otherChildrenFromLeft = extractOpt.get().second; - if (otherChildrenFromLeft == null) { - // extract expression for left side - doExtractExpression(expr, extractChildIndex, true, leftSlots, rightSlots) - .ifPresent(conditions::add); - // extract expression for right side - doExtractExpression(expr, extractChildIndex, false, leftSlots, rightSlots) - .ifPresent(conditions::add); - } else { - // extract expression for one side, all child slots need from the same side - doExtractExpression(expr, extractChildIndex, otherChildrenFromLeft, leftSlots, rightSlots) - .ifPresent(conditions::add); - if (expr instanceof EqualPredicate && extractOrExpansionCondition) { - // extract expression for hash condition only when the expr is equal predicate - doExtractExpression(expr, extractChildIndex, !otherChildrenFromLeft, leftSlots, rightSlots) - .ifPresent(cond -> tryAddOrExpansionHashCondition(cond, join, context, leastOrExpandCondRef)); + List childrenSlotFrom = Lists.newArrayListWithExpectedSize(expr.children().size()); + for (Expression child : expr.children()) { + childrenSlotFrom.add(getExpressionSlotFrom(child, leftSlots, rightSlots)); + } + for (int i = 0; i < expr.children().size(); i++) { + if (!ExpressionUtils.containsCaseWhenLikeType(expr.child(i))) { + continue; + } + if (otherChildrenFromLeft == null) { + // extract expression for left side + doExtractExpression(expr, i, true, true, childrenSlotFrom, leftSlots, rightSlots) + .ifPresent(conditions::add); + // extract expression for right side + doExtractExpression(expr, i, false, false, childrenSlotFrom, leftSlots, rightSlots) + .ifPresent(conditions::add); + } else { + // extract expression for one side, all child slots need from the same side + doExtractExpression(expr, i, otherChildrenFromLeft, otherChildrenFromLeft, childrenSlotFrom, leftSlots, rightSlots) + .ifPresent(conditions::add); + if (expr instanceof EqualPredicate && extractOrExpansionCondition) { + // extract expression for hash condition only when the expr is equal predicate + doExtractExpression(expr, i, !otherChildrenFromLeft, otherChildrenFromLeft, childrenSlotFrom, + leftSlots, rightSlots).ifPresent( + cond -> tryAddOrExpansionHashCondition(cond, join, context, leastOrExpandCondRef)); + } } } } @@ -185,69 +222,69 @@ private void tryAddOrExpansionHashCondition(Expression condition, LogicalJoin> getExtractChildIndexAndOtherChildSlotFromLeft(Expression expr, - Set leftSlots, Set rightSlots) { - if (expr.containsUniqueFunction()) { - return Optional.empty(); - } - int extractChildIndex = -1; - Boolean otherChildSlotFromLeft = null; + private SlotFrom getNoneCaseWhenChildSlotFrom(Expression expr, Set leftSlots, Set rightSlots) { + SlotFrom mergeSlotFrom = SlotFrom.FROM_NONE; for (int i = 0; i < expr.children().size(); i++) { Expression child = expr.child(i); - Set childSlots = child.getInputSlots(); - if (childSlots.isEmpty()) { + if (ExpressionUtils.containsCaseWhenLikeType(child)) { continue; } - boolean containsLeft = !Collections.disjoint(childSlots, leftSlots); - boolean containsRight = !Collections.disjoint(childSlots, rightSlots); - if (containsLeft && containsRight) { - if (extractChildIndex != -1 || !ExpressionUtils.containsCaseWhenLikeType(child)) { - // more than one child contains both side slots - return Optional.empty(); + SlotFrom childSlotFrom = getExpressionSlotFrom(child, leftSlots, rightSlots); + switch (childSlotFrom) { + case FROM_LEFT_SIDE_ONLY: { + if (mergeSlotFrom == SlotFrom.FROM_RIGHT_SIDE_ONLY) { + return SlotFrom.FROM_BOTH_SIDE; + } + mergeSlotFrom = childSlotFrom; + break; } - extractChildIndex = i; - } else if (containsLeft) { - if (otherChildSlotFromLeft == null) { - otherChildSlotFromLeft = true; - } else if (!otherChildSlotFromLeft) { - // one child from left, another child from right - return Optional.empty(); + case FROM_RIGHT_SIDE_ONLY: { + if (mergeSlotFrom == SlotFrom.FROM_LEFT_SIDE_ONLY) { + return SlotFrom.FROM_BOTH_SIDE; + } + mergeSlotFrom = childSlotFrom; + break; } - } else if (containsRight) { - if (otherChildSlotFromLeft == null) { - otherChildSlotFromLeft = false; - } else if (otherChildSlotFromLeft) { - // one child from left, another child from right - return Optional.empty(); + case FROM_BOTH_SIDE: { + return childSlotFrom; } - } else { - // should not be here - return Optional.empty(); } } - if (extractChildIndex == -1) { - return Optional.empty(); - } - - return Optional.of(Pair.of(extractChildIndex, otherChildSlotFromLeft)); + return mergeSlotFrom; } // extract case when expression from `expr`'s child at `extractChildIndex`. - // after extraction, all slots in expr's child at `extractChildIndex` - // are from one side indicated by `childSlotFromLeft`. - private Optional doExtractExpression(Expression expr, int extractChildIndex, boolean childSlotFromLeft, + // after extraction, all slots of this child at `extractChildIndex` + // are from one side indicated by `extractChildSlotFromLeft`. + // for expr's other children, their slot from need met the otherChildSlotFromLeft + private Optional doExtractExpression(Expression expr, int extractChildIndex, + boolean extractChildSlotFromLeft, boolean otherChildSlotFromLeft, List childrenSlotFrom, Set leftSlots, Set rightSlots) { + for (int i = 0; i < expr.children().size(); i++) { + // we only rewrite extractChildIndex, + // so for other child, it must need the requirement of `otherChildSlotFromLeft` + if (i != extractChildIndex) { + // use childrenSlotFrom to avoid call Collection.disjoint too many, but maybe we can delete it. + SlotFrom slotFrom = childrenSlotFrom.get(i); + if (slotFrom == SlotFrom.FROM_BOTH_SIDE + || (otherChildSlotFromLeft && slotFrom == SlotFrom.FROM_RIGHT_SIDE_ONLY) + || (!otherChildSlotFromLeft && slotFrom == SlotFrom.FROM_LEFT_SIDE_ONLY)) { + return Optional.empty(); + } + } + } + Expression expandChild = expr.child(extractChildIndex); Optional> resultOpt = tryExtractCaseWhen( - expandChild, childSlotFromLeft, leftSlots, rightSlots); + expandChild, extractChildSlotFromLeft, leftSlots, rightSlots); if (!resultOpt.isPresent()) { return Optional.empty(); } List expandTargetExpressions = resultOpt.get(); if (expandTargetExpressions.size() <= 1) { + // if size = 1, then it don't expand, should be just the expr itself. return Optional.empty(); } @@ -341,4 +378,19 @@ private boolean isAllSlotsFromLeftSide(Expression expr, boolean slotFromLeft, } } + private SlotFrom getExpressionSlotFrom(Expression expr, Set leftSlots, Set rightSlots) { + Set exprSlots = expr.getInputSlots(); + boolean containsLeft = !Collections.disjoint(exprSlots, leftSlots); + boolean containsRight = !Collections.disjoint(exprSlots, rightSlots); + if (containsLeft && containsRight) { + return SlotFrom.FROM_BOTH_SIDE; + } else if (containsLeft) { + return SlotFrom.FROM_LEFT_SIDE_ONLY; + } else if (containsRight) { + return SlotFrom.FROM_RIGHT_SIDE_ONLY; + } else { + return SlotFrom.FROM_NONE; + } + } + } diff --git a/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out b/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out index 21e76576a1fa00..d20ad125aee2f6 100644 --- a/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out +++ b/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out @@ -76,6 +76,16 @@ PhysicalResultSink ------PhysicalProject[(x < 10) AS `(x < 10)`, (x > 0) AS `(x > 0)`, t2.x] --------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] +-- !case_when_one_side_9 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN](CASE WHEN (x > 0) THEN 105 WHEN (x < 10) THEN 10005 ELSE NULL END > (a + b)) +------PhysicalProject[(a + b) AS `(a + b)`, t1.a] +--------filter(((t1.a + t1.b) < 10005)) +----------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +------PhysicalProject[CASE WHEN (x > 0) THEN 105 WHEN (x < 10) THEN 10005 ELSE NULL END AS `CASE WHEN (x > 0) THEN 105 WHEN (x < 10) THEN 10005 ELSE NULL END`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] + -- !case_when_two_side_1 -- PhysicalCteAnchor ( cteId=CTEId#0 ) --PhysicalCteProducer ( cteId=CTEId#0 ) diff --git a/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy b/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy index 72db9487ac6361..fbec8fa6aa9fba 100644 --- a/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy +++ b/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy @@ -81,6 +81,13 @@ suite('join_extract_or_from_case_when') { + case when t1.a > 1 then t1.a + 1 when t1.a < 10 then t1.a + 10 end > t1.a + t1.b """ + // case when's result are all literal + qt_case_when_one_side_9 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on case when t2.x > 0 then 100 when t2.x < 10 then 10000 end + 5 > t1.a + t1.b + """ + qt_case_when_two_side_1 """explain shape plan select t1.a, t2.x from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 From 2f7564212b5798bea6cb7e9d7698269bd5a0b2a7 Mon Sep 17 00:00:00 2001 From: yujun Date: Thu, 27 Nov 2025 22:58:53 +0800 Subject: [PATCH 08/10] update --- .../rewrite/JoinExtractOrFromCaseWhen.java | 27 +++++++++++++------ .../rewrite/PushDownJoinOtherCondition.java | 24 ++++++++++++++--- .../join_extract_or_from_case_when.out | 18 +++++++++++++ .../join_extract_or_from_case_when.groovy | 14 ++++++++++ 4 files changed, 71 insertions(+), 12 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java index 6ad6e699adec5f..fb9287fe8c70cc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java @@ -88,7 +88,10 @@ public List buildRules() { .toRule(RuleType.JOIN_EXTRACT_OR_FROM_CASE_WHEN)); } - private boolean needRewrite(LogicalJoin join) { + private boolean needRewrite(LogicalJoin join) { + if (!PushDownJoinOtherCondition.needRewrite(join) && !OrExpansion.INSTANCE.needRewriteJoin(join)) { + return false; + } Set leftSlots = join.left().getOutputSet(); Set rightSlots = join.right().getOutputSet(); for (Expression expr : join.getOtherJoinConjuncts()) { @@ -160,6 +163,8 @@ private void extractExpression(Expression expr, ExpressionRewriteContext context } } + boolean canPushLeft = PushDownJoinOtherCondition.PUSH_DOWN_LEFT_VALID_TYPE.contains(join.getJoinType()); + boolean canPushRight = PushDownJoinOtherCondition.PUSH_DOWN_RIGHT_VALID_TYPE.contains(join.getJoinType()); List childrenSlotFrom = Lists.newArrayListWithExpectedSize(expr.children().size()); for (Expression child : expr.children()) { childrenSlotFrom.add(getExpressionSlotFrom(child, leftSlots, rightSlots)); @@ -170,15 +175,21 @@ private void extractExpression(Expression expr, ExpressionRewriteContext context } if (otherChildrenFromLeft == null) { // extract expression for left side - doExtractExpression(expr, i, true, true, childrenSlotFrom, leftSlots, rightSlots) - .ifPresent(conditions::add); - // extract expression for right side - doExtractExpression(expr, i, false, false, childrenSlotFrom, leftSlots, rightSlots) - .ifPresent(conditions::add); + if (canPushLeft) { + doExtractExpression(expr, i, true, true, childrenSlotFrom, leftSlots, rightSlots) + .ifPresent(conditions::add); + } + if (canPushRight) { + // extract expression for right side + doExtractExpression(expr, i, false, false, childrenSlotFrom, leftSlots, rightSlots) + .ifPresent(conditions::add); + } } else { // extract expression for one side, all child slots need from the same side - doExtractExpression(expr, i, otherChildrenFromLeft, otherChildrenFromLeft, childrenSlotFrom, leftSlots, rightSlots) - .ifPresent(conditions::add); + if ((otherChildrenFromLeft && canPushLeft) || (!otherChildrenFromLeft && canPushRight)) { + doExtractExpression(expr, i, otherChildrenFromLeft, otherChildrenFromLeft, + childrenSlotFrom, leftSlots, rightSlots).ifPresent(conditions::add); + } if (expr instanceof EqualPredicate && extractOrExpansionCondition) { // extract expression for hash condition only when the expr is equal predicate doExtractExpression(expr, i, !otherChildrenFromLeft, otherChildrenFromLeft, childrenSlotFrom, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOtherCondition.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOtherCondition.java index b622c51d27dad5..013b47f9abbfab 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOtherCondition.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOtherCondition.java @@ -37,7 +37,10 @@ * Push the other join conditions in LogicalJoin to children. */ public class PushDownJoinOtherCondition extends OneRewriteRuleFactory { - private static final ImmutableList PUSH_DOWN_LEFT_VALID_TYPE = ImmutableList.of( + /** + * left push support type + */ + public static final ImmutableList PUSH_DOWN_LEFT_VALID_TYPE = ImmutableList.of( JoinType.INNER_JOIN, JoinType.LEFT_SEMI_JOIN, JoinType.RIGHT_OUTER_JOIN, @@ -46,7 +49,10 @@ public class PushDownJoinOtherCondition extends OneRewriteRuleFactory { JoinType.CROSS_JOIN ); - private static final ImmutableList PUSH_DOWN_RIGHT_VALID_TYPE = ImmutableList.of( + /** + * right push support type + */ + public static final ImmutableList PUSH_DOWN_RIGHT_VALID_TYPE = ImmutableList.of( JoinType.INNER_JOIN, JoinType.LEFT_OUTER_JOIN, JoinType.LEFT_ANTI_JOIN, @@ -60,7 +66,7 @@ public class PushDownJoinOtherCondition extends OneRewriteRuleFactory { public Rule build() { return logicalJoin() // TODO: we may need another rule to handle on true or on false condition - .when(join -> !join.getOtherJoinConjuncts().isEmpty() && !join.isMarkJoin()) + .when(PushDownJoinOtherCondition::needRewrite) .then(join -> { List otherJoinConjuncts = join.getOtherJoinConjuncts(); List remainingOther = Lists.newArrayList(); @@ -93,7 +99,17 @@ && allCoveredBy(otherConjunct, join.right().getOutputSet())) { }).toRule(RuleType.PUSH_DOWN_JOIN_OTHER_CONDITION); } - private boolean allCoveredBy(Expression predicate, Set inputSlotSet) { + /** + * check need rewrite + */ + public static boolean needRewrite(LogicalJoin join) { + return !join.getOtherJoinConjuncts().isEmpty() + && !join.isMarkJoin() + && (PUSH_DOWN_LEFT_VALID_TYPE.contains(join.getJoinType()) + || PUSH_DOWN_RIGHT_VALID_TYPE.contains(join.getJoinType())); + } + + private static boolean allCoveredBy(Expression predicate, Set inputSlotSet) { return inputSlotSet.containsAll(predicate.getInputSlots()); } } diff --git a/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out b/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out index d20ad125aee2f6..c022f2fb688e29 100644 --- a/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out +++ b/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out @@ -86,6 +86,24 @@ PhysicalResultSink ------PhysicalProject[CASE WHEN (x > 0) THEN 105 WHEN (x < 10) THEN 10005 ELSE NULL END AS `CASE WHEN (x > 0) THEN 105 WHEN (x < 10) THEN 10005 ELSE NULL END`, t2.x] --------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] +-- !case_when_one_side_10 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[LEFT_OUTER_JOIN](CASE WHEN (x > 0) THEN 105 WHEN (x < 10) THEN 10005 ELSE NULL END > (a + b)) +------PhysicalProject[(a + b) AS `(a + b)`, t1.a] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +------PhysicalProject[CASE WHEN (x > 0) THEN 105 WHEN (x < 10) THEN 10005 ELSE NULL END AS `CASE WHEN (x > 0) THEN 105 WHEN (x < 10) THEN 10005 ELSE NULL END`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] + +-- !case_when_one_side_11 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[RIGHT_OUTER_JOIN](CASE WHEN (a > 0) THEN 105 WHEN (b < 10) THEN 10005 ELSE NULL END > (x + y)) +------PhysicalProject[CASE WHEN (a > 0) THEN 105 WHEN (b < 10) THEN 10005 ELSE NULL END AS `CASE WHEN (a > 0) THEN 105 WHEN (b < 10) THEN 10005 ELSE NULL END`, t1.a] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +------PhysicalProject[(x + y) AS `(x + y)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] + -- !case_when_two_side_1 -- PhysicalCteAnchor ( cteId=CTEId#0 ) --PhysicalCteProducer ( cteId=CTEId#0 ) diff --git a/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy b/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy index fbec8fa6aa9fba..ab74f5cf28967d 100644 --- a/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy +++ b/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy @@ -88,6 +88,20 @@ suite('join_extract_or_from_case_when') { on case when t2.x > 0 then 100 when t2.x < 10 then 10000 end + 5 > t1.a + t1.b """ + // not push down because not meet push down other requirement + qt_case_when_one_side_10 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 left join tbl_join_extract_or_from_case_when_2 t2 + on case when t2.x > 0 then 100 when t2.x < 10 then 10000 end + 5 > t1.a + t1.b + """ + + // not push down because not meet push down other requirement + qt_case_when_one_side_11 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 right join tbl_join_extract_or_from_case_when_2 t2 + on case when t1.a > 0 then 100 when t1.b < 10 then 10000 end + 5 > t2.x + t2.y + """ + qt_case_when_two_side_1 """explain shape plan select t1.a, t2.x from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 From 35c83c817430229493f6f5cd00478fee8d4bde6d Mon Sep 17 00:00:00 2001 From: yujun Date: Fri, 28 Nov 2025 09:34:41 +0800 Subject: [PATCH 09/10] update --- .../rewrite/JoinExtractOrFromCaseWhen.java | 179 +++++------------- 1 file changed, 52 insertions(+), 127 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java index fb9287fe8c70cc..da51907e8bc944 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java @@ -73,13 +73,6 @@ */ public class JoinExtractOrFromCaseWhen implements RewriteRuleFactory { - private enum SlotFrom { - FROM_LEFT_SIDE_ONLY, - FROM_RIGHT_SIDE_ONLY, - FROM_BOTH_SIDE, - FROM_NONE, - } - @Override public List buildRules() { return ImmutableList.of(logicalJoin() @@ -113,21 +106,19 @@ private boolean isConditionNeedRewrite(Expression expr, Set leftSlots, Set || !ExpressionUtils.containsCaseWhenLikeType(expr)) { return false; } - // can not rewrite none case when child - return getNoneCaseWhenChildSlotFrom(expr, leftSlots, rightSlots) != SlotFrom.FROM_BOTH_SIDE; + return true; } private Plan rewrite(LogicalJoin join, ExpressionRewriteContext context) { Set newOtherConditions = Sets.newLinkedHashSetWithExpectedSize(join.getOtherJoinConjuncts().size()); newOtherConditions.addAll(join.getOtherJoinConjuncts()); int oldCondSize = newOtherConditions.size(); - boolean extractOrExpansionCondition = OrExpansion.INSTANCE.needRewriteJoin(join); AtomicReference> leastOrExpandCondRef = new AtomicReference<>(); for (Expression expr : join.getOtherJoinConjuncts()) { tryAddOrExpansionHashCondition(expr, join, context, leastOrExpandCondRef); } for (Expression expr : join.getOtherJoinConjuncts()) { - extractExpression(expr, context, join, extractOrExpansionCondition, newOtherConditions, leastOrExpandCondRef); + extractExpression(expr, context, join, newOtherConditions, leastOrExpandCondRef); } if (leastOrExpandCondRef.get() != null) { newOtherConditions.add(leastOrExpandCondRef.get().first); @@ -141,62 +132,60 @@ private Plan rewrite(LogicalJoin join, Expressio } private void extractExpression(Expression expr, ExpressionRewriteContext context, - LogicalJoin join, boolean extractOrExpansionCondition, - Set conditions, AtomicReference> leastOrExpandCondRef) { + LogicalJoin join, Set newOtherConditions, + AtomicReference> leastOrExpandCondRef) { Set leftSlots = join.left().getOutputSet(); Set rightSlots = join.right().getOutputSet(); if (!isConditionNeedRewrite(expr, leftSlots, rightSlots)) { return; } - Boolean otherChildrenFromLeft = null; - switch (getNoneCaseWhenChildSlotFrom(expr, leftSlots, rightSlots)) { - case FROM_BOTH_SIDE: { - return; - } - case FROM_LEFT_SIDE_ONLY: { - otherChildrenFromLeft = true; - break; + List containsLeftSlotChildIndexes = Lists.newArrayList(); + List containsRightSlotChildIndexes = Lists.newArrayList(); + for (int i = 0; i < expr.children().size(); i++) { + Set childSlots = expr.child(i).getInputSlots(); + if (!Collections.disjoint(childSlots, leftSlots)) { + containsLeftSlotChildIndexes.add(i); } - case FROM_RIGHT_SIDE_ONLY: { - otherChildrenFromLeft = false; - break; + if (!Collections.disjoint(childSlots, rightSlots)) { + containsRightSlotChildIndexes.add(i); } } - - boolean canPushLeft = PushDownJoinOtherCondition.PUSH_DOWN_LEFT_VALID_TYPE.contains(join.getJoinType()); - boolean canPushRight = PushDownJoinOtherCondition.PUSH_DOWN_RIGHT_VALID_TYPE.contains(join.getJoinType()); - List childrenSlotFrom = Lists.newArrayListWithExpectedSize(expr.children().size()); - for (Expression child : expr.children()) { - childrenSlotFrom.add(getExpressionSlotFrom(child, leftSlots, rightSlots)); + // all slots are from one side, no need handle + if (containsLeftSlotChildIndexes.isEmpty() || containsRightSlotChildIndexes.isEmpty()) { + return; } - for (int i = 0; i < expr.children().size(); i++) { - if (!ExpressionUtils.containsCaseWhenLikeType(expr.child(i))) { - continue; - } - if (otherChildrenFromLeft == null) { - // extract expression for left side - if (canPushLeft) { - doExtractExpression(expr, i, true, true, childrenSlotFrom, leftSlots, rightSlots) - .ifPresent(conditions::add); - } - if (canPushRight) { - // extract expression for right side - doExtractExpression(expr, i, false, false, childrenSlotFrom, leftSlots, rightSlots) - .ifPresent(conditions::add); - } - } else { - // extract expression for one side, all child slots need from the same side - if ((otherChildrenFromLeft && canPushLeft) || (!otherChildrenFromLeft && canPushRight)) { - doExtractExpression(expr, i, otherChildrenFromLeft, otherChildrenFromLeft, - childrenSlotFrom, leftSlots, rightSlots).ifPresent(conditions::add); - } - if (expr instanceof EqualPredicate && extractOrExpansionCondition) { - // extract expression for hash condition only when the expr is equal predicate - doExtractExpression(expr, i, !otherChildrenFromLeft, otherChildrenFromLeft, childrenSlotFrom, - leftSlots, rightSlots).ifPresent( - cond -> tryAddOrExpansionHashCondition(cond, join, context, leastOrExpandCondRef)); - } + boolean extractedLeftSideCond = PushDownJoinOtherCondition.PUSH_DOWN_LEFT_VALID_TYPE + .contains(join.getJoinType()); + boolean extractedRightSideCond = PushDownJoinOtherCondition.PUSH_DOWN_RIGHT_VALID_TYPE + .contains(join.getJoinType()); + boolean extractOrExpansionCond = OrExpansion.INSTANCE.needRewriteJoin(join); + // eliminate all the right slots of all children, but we rewrite at most 1 case when expression, + // so require contains right slot child num not exceeds 1. + if (extractedLeftSideCond && containsRightSlotChildIndexes.size() == 1) { + doExtractExpression(expr, containsRightSlotChildIndexes.get(0), true, leftSlots, rightSlots) + .ifPresent(newOtherConditions::add); + } + // eliminate all the left slots of all children, but we rewrite at most 1 case when expression, + // so require contains left slot child num not exceeds 1. + if (extractedRightSideCond && containsLeftSlotChildIndexes.size() == 1) { + doExtractExpression(expr, containsLeftSlotChildIndexes.get(0), false, leftSlots, rightSlots) + .ifPresent(newOtherConditions::add); + } + if (extractOrExpansionCond && expr instanceof EqualPredicate) { + Optional orExpansionExpr = Optional.empty(); + if (containsLeftSlotChildIndexes.size() == 1 && containsRightSlotChildIndexes.size() == 2) { + // equal's two children all contain right slots, while only one child contains left slot, + // then we eliminate the right slot from the child which it contains left slots. + orExpansionExpr = doExtractExpression( + expr, containsLeftSlotChildIndexes.get(0), true, leftSlots, rightSlots); + } else if (containsLeftSlotChildIndexes.size() == 2 && containsRightSlotChildIndexes.size() == 1) { + // equal's two children all contain left slots, while one child contains left slot, + // then we eliminate the left slot from the child which it contains right slots. + orExpansionExpr = doExtractExpression( + expr, containsRightSlotChildIndexes.get(0), false, leftSlots, rightSlots); } + orExpansionExpr.ifPresent( + cond -> tryAddOrExpansionHashCondition(cond, join, context, leastOrExpandCondRef)); } } @@ -233,59 +222,11 @@ private void tryAddOrExpansionHashCondition(Expression condition, LogicalJoin leftSlots, Set rightSlots) { - SlotFrom mergeSlotFrom = SlotFrom.FROM_NONE; - for (int i = 0; i < expr.children().size(); i++) { - Expression child = expr.child(i); - if (ExpressionUtils.containsCaseWhenLikeType(child)) { - continue; - } - SlotFrom childSlotFrom = getExpressionSlotFrom(child, leftSlots, rightSlots); - switch (childSlotFrom) { - case FROM_LEFT_SIDE_ONLY: { - if (mergeSlotFrom == SlotFrom.FROM_RIGHT_SIDE_ONLY) { - return SlotFrom.FROM_BOTH_SIDE; - } - mergeSlotFrom = childSlotFrom; - break; - } - case FROM_RIGHT_SIDE_ONLY: { - if (mergeSlotFrom == SlotFrom.FROM_LEFT_SIDE_ONLY) { - return SlotFrom.FROM_BOTH_SIDE; - } - mergeSlotFrom = childSlotFrom; - break; - } - case FROM_BOTH_SIDE: { - return childSlotFrom; - } - } - } - - return mergeSlotFrom; - } - - // extract case when expression from `expr`'s child at `extractChildIndex`. - // after extraction, all slots of this child at `extractChildIndex` - // are from one side indicated by `extractChildSlotFromLeft`. - // for expr's other children, their slot from need met the otherChildSlotFromLeft + // extract case when expression from `expr`'s child C at `extractChildIndex`. + // after extraction, new C will contain only slots from one side indicated by `extractedChildSlotFromLeft`, + // but new C can allow no contains any slots. private Optional doExtractExpression(Expression expr, int extractChildIndex, - boolean extractChildSlotFromLeft, boolean otherChildSlotFromLeft, List childrenSlotFrom, - Set leftSlots, Set rightSlots) { - for (int i = 0; i < expr.children().size(); i++) { - // we only rewrite extractChildIndex, - // so for other child, it must need the requirement of `otherChildSlotFromLeft` - if (i != extractChildIndex) { - // use childrenSlotFrom to avoid call Collection.disjoint too many, but maybe we can delete it. - SlotFrom slotFrom = childrenSlotFrom.get(i); - if (slotFrom == SlotFrom.FROM_BOTH_SIDE - || (otherChildSlotFromLeft && slotFrom == SlotFrom.FROM_RIGHT_SIDE_ONLY) - || (!otherChildSlotFromLeft && slotFrom == SlotFrom.FROM_LEFT_SIDE_ONLY)) { - return Optional.empty(); - } - } - } - + boolean extractChildSlotFromLeft, Set leftSlots, Set rightSlots) { Expression expandChild = expr.child(extractChildIndex); Optional> resultOpt = tryExtractCaseWhen( expandChild, extractChildSlotFromLeft, leftSlots, rightSlots); @@ -295,7 +236,7 @@ private Optional doExtractExpression(Expression expr, int extractChi List expandTargetExpressions = resultOpt.get(); if (expandTargetExpressions.size() <= 1) { - // if size = 1, then it don't expand, should be just the expr itself. + // if size = 1, then C don't expand, should be just the expr itself. return Optional.empty(); } @@ -376,7 +317,7 @@ private Optional> tryExtractCaseWhen(Expression expr, boolean s return Optional.of(resultExpressions); } - // check whether all slots in expr are from one side, allow empty slots. + // check whether all slots in expr are from one side, allow contains no slots. private boolean isAllSlotsFromLeftSide(Expression expr, boolean slotFromLeft, Set leftSlots, Set rightSlots) { Set exprSlots = expr.getInputSlots(); @@ -388,20 +329,4 @@ private boolean isAllSlotsFromLeftSide(Expression expr, boolean slotFromLeft, return Collections.disjoint(exprSlots, leftSlots); } } - - private SlotFrom getExpressionSlotFrom(Expression expr, Set leftSlots, Set rightSlots) { - Set exprSlots = expr.getInputSlots(); - boolean containsLeft = !Collections.disjoint(exprSlots, leftSlots); - boolean containsRight = !Collections.disjoint(exprSlots, rightSlots); - if (containsLeft && containsRight) { - return SlotFrom.FROM_BOTH_SIDE; - } else if (containsLeft) { - return SlotFrom.FROM_LEFT_SIDE_ONLY; - } else if (containsRight) { - return SlotFrom.FROM_RIGHT_SIDE_ONLY; - } else { - return SlotFrom.FROM_NONE; - } - } - } From 88c442b80571f19408097d418c6c0fad81030bc5 Mon Sep 17 00:00:00 2001 From: yujun Date: Mon, 1 Dec 2025 13:31:07 +0800 Subject: [PATCH 10/10] extract hash condition --- .../rewrite/JoinExtractOrFromCaseWhen.java | 51 ++++++++++++------- .../rewrite/PushDownJoinOtherCondition.java | 6 +-- .../join_extract_or_from_case_when.out | 11 ++++ .../join_extract_or_from_case_when.groovy | 6 +++ 4 files changed, 52 insertions(+), 22 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java index da51907e8bc944..929ecbc272391f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java @@ -82,14 +82,22 @@ public List buildRules() { } private boolean needRewrite(LogicalJoin join) { - if (!PushDownJoinOtherCondition.needRewrite(join) && !OrExpansion.INSTANCE.needRewriteJoin(join)) { - return false; + if (PushDownJoinOtherCondition.needRewrite(join) || OrExpansion.INSTANCE.needRewriteJoin(join)) { + Set leftSlots = join.left().getOutputSet(); + Set rightSlots = join.right().getOutputSet(); + for (Expression expr : join.getOtherJoinConjuncts()) { + if (isConditionNeedRewrite(expr, leftSlots, rightSlots)) { + return true; + } + } } - Set leftSlots = join.left().getOutputSet(); - Set rightSlots = join.right().getOutputSet(); - for (Expression expr : join.getOtherJoinConjuncts()) { - if (isConditionNeedRewrite(expr, leftSlots, rightSlots)) { - return true; + if (!join.isMarkJoin()) { + Set leftSlots = join.left().getOutputSet(); + Set rightSlots = join.right().getOutputSet(); + for (Expression expr : join.getHashJoinConjuncts()) { + if (isConditionNeedRewrite(expr, leftSlots, rightSlots)) { + return true; + } } } return false; @@ -109,7 +117,7 @@ private boolean isConditionNeedRewrite(Expression expr, Set leftSlots, Set return true; } - private Plan rewrite(LogicalJoin join, ExpressionRewriteContext context) { + private Plan rewrite(LogicalJoin join, ExpressionRewriteContext context) { Set newOtherConditions = Sets.newLinkedHashSetWithExpectedSize(join.getOtherJoinConjuncts().size()); newOtherConditions.addAll(join.getOtherJoinConjuncts()); int oldCondSize = newOtherConditions.size(); @@ -120,6 +128,17 @@ private Plan rewrite(LogicalJoin join, Expressio for (Expression expr : join.getOtherJoinConjuncts()) { extractExpression(expr, context, join, newOtherConditions, leastOrExpandCondRef); } + if (!join.isMarkJoin()) { + // Notice: if join's hash conditions is not empty, then OrExpansion.needRewriteJoin will return fail + // so it will not extract OrExpansion from the hash condition, it only extract one side condition + // for hash condition: if(t1.a > 10, 1, 100) = if(t2.x > 10, 2, 200), + // we can still extract two one-side condition: + // 1) if(t1.a > 10, 1, 100) = 2 or if(t1.a > 10, 1, 100) = 200 + // 2) if(t2.x > 10, 2, 200) = 1 or if(t2.x > 10, 2, 200) = 100 + for (Expression expr : join.getHashJoinConjuncts()) { + extractExpression(expr, context, join, newOtherConditions, leastOrExpandCondRef); + } + } if (leastOrExpandCondRef.get() != null) { newOtherConditions.add(leastOrExpandCondRef.get().first); } @@ -132,7 +151,7 @@ private Plan rewrite(LogicalJoin join, Expressio } private void extractExpression(Expression expr, ExpressionRewriteContext context, - LogicalJoin join, Set newOtherConditions, + LogicalJoin join, Set newOtherConditions, AtomicReference> leastOrExpandCondRef) { Set leftSlots = join.left().getOutputSet(); Set rightSlots = join.right().getOutputSet(); @@ -154,9 +173,10 @@ private void extractExpression(Expression expr, ExpressionRewriteContext context if (containsLeftSlotChildIndexes.isEmpty() || containsRightSlotChildIndexes.isEmpty()) { return; } - boolean extractedLeftSideCond = PushDownJoinOtherCondition.PUSH_DOWN_LEFT_VALID_TYPE + boolean canPushDownOther = PushDownJoinOtherCondition.needRewrite(join); + boolean extractedLeftSideCond = canPushDownOther && PushDownJoinOtherCondition.PUSH_DOWN_LEFT_VALID_TYPE .contains(join.getJoinType()); - boolean extractedRightSideCond = PushDownJoinOtherCondition.PUSH_DOWN_RIGHT_VALID_TYPE + boolean extractedRightSideCond = canPushDownOther && PushDownJoinOtherCondition.PUSH_DOWN_RIGHT_VALID_TYPE .contains(join.getJoinType()); boolean extractOrExpansionCond = OrExpansion.INSTANCE.needRewriteJoin(join); // eliminate all the right slots of all children, but we rewrite at most 1 case when expression, @@ -190,7 +210,7 @@ private void extractExpression(Expression expr, ExpressionRewriteContext context } // Or Expansion only use one condition, so we keep the one with least disjunctions. - private void tryAddOrExpansionHashCondition(Expression condition, LogicalJoin join, + private void tryAddOrExpansionHashCondition(Expression condition, LogicalJoin join, ExpressionRewriteContext context, AtomicReference> leastOrExpandCondRef) { // Or Expansion only works for all the disjunctions are equal predicates List disjunctions = ExpressionUtils.extractDisjunction(condition); @@ -247,12 +267,7 @@ private Optional doExtractExpression(Expression expr, int extractChi disjuncts.add(expr.withChildren(newChildren)); } - Expression result = ExpressionUtils.or(disjuncts); - if (result.getInputSlots().isEmpty()) { - return Optional.empty(); - } - - return Optional.of(ExpressionUtils.or(result)); + return Optional.of(ExpressionUtils.or(disjuncts)); } // try to extract case when like expressions from expr. diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOtherCondition.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOtherCondition.java index 013b47f9abbfab..75c5e3085d310c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOtherCondition.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOtherCondition.java @@ -66,6 +66,7 @@ public class PushDownJoinOtherCondition extends OneRewriteRuleFactory { public Rule build() { return logicalJoin() // TODO: we may need another rule to handle on true or on false condition + .when(join -> !join.getOtherJoinConjuncts().isEmpty()) .when(PushDownJoinOtherCondition::needRewrite) .then(join -> { List otherJoinConjuncts = join.getOtherJoinConjuncts(); @@ -103,10 +104,7 @@ && allCoveredBy(otherConjunct, join.right().getOutputSet())) { * check need rewrite */ public static boolean needRewrite(LogicalJoin join) { - return !join.getOtherJoinConjuncts().isEmpty() - && !join.isMarkJoin() - && (PUSH_DOWN_LEFT_VALID_TYPE.contains(join.getJoinType()) - || PUSH_DOWN_RIGHT_VALID_TYPE.contains(join.getJoinType())); + return !join.isMarkJoin(); } private static boolean allCoveredBy(Expression predicate, Set inputSlotSet) { diff --git a/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out b/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out index c022f2fb688e29..84590c64610e4d 100644 --- a/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out +++ b/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out @@ -104,6 +104,17 @@ PhysicalResultSink ------PhysicalProject[(x + y) AS `(x + y)`, t2.x] --------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] +-- !hash_cond_for_one_side_1 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----hashJoin[INNER_JOIN] hashCondition=((expr_if((a > 1), 1, 100) = expr_if((x > 2), 1, 200))) otherCondition=() +------PhysicalProject[if((a > 1), 1, 100) AS `expr_if((a > 1), 1, 100)`, t1.a] +--------filter((t1.a > 1)) +----------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] +------PhysicalProject[if((x > 2), 1, 200) AS `expr_if((x > 2), 1, 200)`, t2.x] +--------filter((t2.x > 2)) +----------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2(t2)] + -- !case_when_two_side_1 -- PhysicalCteAnchor ( cteId=CTEId#0 ) --PhysicalCteProducer ( cteId=CTEId#0 ) diff --git a/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy b/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy index ab74f5cf28967d..708c7fa6eea8c7 100644 --- a/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy +++ b/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy @@ -102,6 +102,12 @@ suite('join_extract_or_from_case_when') { on case when t1.a > 0 then 100 when t1.b < 10 then 10000 end + 5 > t2.x + t2.y """ + qt_hash_cond_for_one_side_1 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on if(t1.a > 1, 1, 100) = if(t2.x > 2, 1, 200) + """ + qt_case_when_two_side_1 """explain shape plan select t1.a, t2.x from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2