diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplace.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplace.java index 7b052b8e870c00..c2ca99f0b61b9c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplace.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplace.java @@ -29,6 +29,7 @@ import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.Like; import org.apache.doris.nereids.trees.expressions.Not; +import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait; import org.apache.doris.nereids.trees.expressions.literal.Literal; @@ -126,6 +127,14 @@ public Void visitLike(Like like, Map> context) { return null; } + @Override + public Void visitOr(Or or, Map> context) { + for (Expression expr : getAllSubExpressions(or)) { + context.computeIfAbsent(expr, k -> new LinkedHashSet<>()).add(or); + } + return null; + } + private boolean validComparisonPredicate(ComparisonPredicate comparisonPredicate) { return comparisonPredicate.right() instanceof Literal; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java index d4c5af6e578629..f0d4deac1476a9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java @@ -18,9 +18,12 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.mysql.MysqlCommand; +import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; @@ -37,15 +40,20 @@ import org.apache.doris.nereids.util.PredicateInferUtils; import org.apache.doris.qe.ConnectContext; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; import com.google.common.collect.Sets; import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashSet; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.function.Supplier; /** * infer additional predicates for `LogicalFilter` and `LogicalJoin`. @@ -91,20 +99,23 @@ public Plan visitLogicalJoin(LogicalJoin join, J Plan right = join.right(); Set expressions = getAllExpressions(left, right, join.getOnClauseCondition()); switch (join.getJoinType()) { - case INNER_JOIN: case CROSS_JOIN: - case LEFT_SEMI_JOIN: - case RIGHT_SEMI_JOIN: left = inferNewPredicate(left, expressions); right = inferNewPredicate(right, expressions); break; + case INNER_JOIN: + case LEFT_SEMI_JOIN: + case RIGHT_SEMI_JOIN: + left = inferNewPredicateRemoveUselessIsNull(left, expressions, join, context.getCascadesContext()); + right = inferNewPredicateRemoveUselessIsNull(right, expressions, join, context.getCascadesContext()); + break; case LEFT_OUTER_JOIN: case LEFT_ANTI_JOIN: - right = inferNewPredicate(right, expressions); + right = inferNewPredicateRemoveUselessIsNull(right, expressions, join, context.getCascadesContext()); break; case RIGHT_OUTER_JOIN: case RIGHT_ANTI_JOIN: - left = inferNewPredicate(left, expressions); + left = inferNewPredicateRemoveUselessIsNull(left, expressions, join, context.getCascadesContext()); break; default: break; @@ -122,12 +133,16 @@ public Plan visitLogicalFilter(LogicalFilter filter, JobContext return new LogicalEmptyRelation(StatementScopeIdGenerator.newRelationId(), filter.getOutput()); } filter = visitChildren(this, filter, context); - Set filterPredicates = pullUpPredicates(filter); - filterPredicates.removeAll(pullUpAllPredicates(filter.child())); - if (filterPredicates.isEmpty()) { + Set inferredPredicates = pullUpPredicates(filter); + inferredPredicates.removeAll(pullUpAllPredicates(filter.child())); + if (inferredPredicates.isEmpty()) { return filter.child(); } - return new LogicalFilter<>(ImmutableSet.copyOf(filterPredicates), filter.child()); + if (inferredPredicates.equals(filter.getConjuncts())) { + return filter; + } else { + return new LogicalFilter<>(ImmutableSet.copyOf(inferredPredicates), filter.child()); + } } @Override @@ -139,15 +154,18 @@ public Plan visitLogicalExcept(LogicalExcept except, JobContext context) { } ImmutableList.Builder builder = ImmutableList.builder(); builder.add(except.child(0)); + boolean changed = false; for (int i = 1; i < except.arity(); ++i) { Map replaceMap = new HashMap<>(); for (int j = 0; j < except.getOutput().size(); ++j) { NamedExpression output = except.getOutput().get(j); replaceMap.put(output, except.getRegularChildOutput(i).get(j)); } - builder.add(inferNewPredicate(except.child(i), ExpressionUtils.replace(baseExpressions, replaceMap))); + Plan newChild = inferNewPredicate(except.child(i), ExpressionUtils.replace(baseExpressions, replaceMap)); + changed = changed || newChild != except.child(i); + builder.add(newChild); } - return except.withChildren(builder.build()); + return changed ? except.withChildren(builder.build()) : except; } @Override @@ -158,15 +176,18 @@ public Plan visitLogicalIntersect(LogicalIntersect intersect, JobContext context return intersect; } ImmutableList.Builder builder = ImmutableList.builder(); + boolean changed = false; for (int i = 0; i < intersect.arity(); ++i) { Map replaceMap = new HashMap<>(); for (int j = 0; j < intersect.getOutput().size(); ++j) { NamedExpression output = intersect.getOutput().get(j); replaceMap.put(output, intersect.getRegularChildOutput(i).get(j)); } - builder.add(inferNewPredicate(intersect.child(i), ExpressionUtils.replace(baseExpressions, replaceMap))); + Plan newChild = inferNewPredicate(intersect.child(i), ExpressionUtils.replace(baseExpressions, replaceMap)); + changed = changed || newChild != intersect.child(i); + builder.add(newChild); } - return intersect.withChildren(builder.build()); + return changed ? intersect.withChildren(builder.build()) : intersect; } private Set getAllExpressions(Plan left, Plan right, Optional condition) { @@ -196,4 +217,60 @@ private Plan inferNewPredicate(Plan plan, Set expressions) { predicates.removeAll(plan.accept(pullUpAllPredicates, null)); return PlanUtils.filterOrSelf(predicates, plan); } + + // Remove redundant "or is null" from expressions. + // For example, when we have a t2 left join t3 condition t2.a=t3.a, we can infer that t3.a is not null. + // If we find a predicate like "t3.a = 1 or t3.a is null" in expressions, we change it to "t3.a=1". + private Plan inferNewPredicateRemoveUselessIsNull(Plan plan, Set expressions, + LogicalJoin join, CascadesContext cascadesContext) { + Supplier> supplier = Suppliers.memoize(() -> { + Set all = new HashSet<>(); + all.addAll(join.getHashJoinConjuncts()); + all.addAll(join.getOtherJoinConjuncts()); + return ExpressionUtils.inferNotNullSlots(all, cascadesContext); + }); + + Set predicates = new LinkedHashSet<>(); + Set planOutputs = plan.getOutputSet(); + for (Expression expr : expressions) { + Set slots = expr.getInputSlots(); + if (slots.isEmpty() || !planOutputs.containsAll(slots)) { + continue; + } + if (expr instanceof Or && expr.isInferred()) { + List orChildren = ExpressionUtils.extractDisjunction(expr); + List newOrChildren = Lists.newArrayList(); + boolean changed = false; + for (Expression orChild : orChildren) { + if (orChild instanceof IsNull && orChild.child(0) instanceof Slot + && supplier.get().contains(orChild.child(0))) { + changed = true; + continue; + } + newOrChildren.add(orChild); + } + if (changed) { + if (newOrChildren.size() == 1) { + predicates.add(withInferredIfSupported(newOrChildren.get(0), expr)); + } else if (newOrChildren.size() > 1) { + predicates.add(ExpressionUtils.or(newOrChildren).withInferred(true)); + } + } else { + predicates.add(expr); + } + } else { + predicates.add(expr); + } + } + predicates.removeAll(plan.accept(pullUpAllPredicates, null)); + return PlanUtils.filterOrSelf(predicates, plan); + } + + private Expression withInferredIfSupported(Expression expression, Expression originExpr) { + try { + return expression.withInferred(true); + } catch (RuntimeException e) { + return originExpr; + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java index 8575ab9e5a21b5..8da895d34ab7fb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java @@ -260,6 +260,10 @@ public ImmutableSet visitLogicalJoin(LogicalJoin visitLogicalJoin(LogicalJoin generateNullTolerantPredicates(Set predicates, Set nullableSlots) { + if (predicates.isEmpty() || nullableSlots.isEmpty()) { + return predicates; + } + Set tolerant = Sets.newLinkedHashSetWithExpectedSize(predicates.size()); + for (Expression predicate : predicates) { + Set predicateSlots = predicate.getInputSlots(); + List orChildren = new ArrayList<>(); + if (predicateSlots.size() == 1) { + Slot slot = predicateSlots.iterator().next(); + if (nullableSlots.contains(slot)) { + orChildren.add(new IsNull(slot)); + } + } + if (!orChildren.isEmpty()) { + List expandedOr = new ArrayList<>(2); + expandedOr.add(predicate); + expandedOr.addAll(orChildren); + tolerant.add(ExpressionUtils.or(expandedOr)); + } + } + return tolerant; + } + private ImmutableSet getFiltersFromUnionChild(LogicalUnion union, Void context) { Set filters = new LinkedHashSet<>(); for (int i = 0; i < union.getArity(); ++i) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CompoundPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CompoundPredicate.java index 13f99839ffb119..f2d6a350b163f1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CompoundPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CompoundPredicate.java @@ -37,7 +37,11 @@ public abstract class CompoundPredicate extends Expression implements ExpectsInp private String symbol; public CompoundPredicate(List children, String symbol) { - super(children); + this(children, symbol, false); + } + + public CompoundPredicate(List children, String symbol, boolean inferred) { + super(children, inferred); this.symbol = symbol; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Or.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Or.java index 235c1bc2f0a70b..b62c0e76b40715 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Or.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Or.java @@ -36,20 +36,28 @@ public class Or extends CompoundPredicate { * @param right right child of comparison predicate */ public Or(Expression left, Expression right) { + this(left, right, false); + } + + public Or(Expression left, Expression right, boolean inferred) { this(ExpressionUtils.mergeList( ExpressionUtils.extractDisjunction(left), - ExpressionUtils.extractDisjunction(right))); + ExpressionUtils.extractDisjunction(right)), inferred); } public Or(List children) { - super(children, "OR"); + this(children, false); + } + + public Or(List children, boolean inferred) { + super(children, "OR", inferred); Preconditions.checkArgument(children.size() >= 2); } @Override public Expression withChildren(List children) { Preconditions.checkArgument(children.size() >= 2); - return new Or(children); + return new Or(children, this.isInferred()); } @Override @@ -90,4 +98,9 @@ public List children() { } return flattenChildren; } + + @Override + public Expression withInferred(boolean inferred) { + return new Or(children, inferred); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java index 07a2d1b7d97cda..9499dc601caec1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java @@ -354,7 +354,7 @@ public LogicalAggregate withChildGroupByAndOutput(List groupBy public LogicalAggregate withChildGroupByAndOutputAndSourceRepeat(List groupByExprList, List outputExpressionList, Plan newChild, - Optional> sourceRepeat) { + Optional> sourceRepeat) { return new LogicalAggregate<>(groupByExprList, outputExpressionList, normalized, ordinalIsResolved, generated, hasPushed, withInProjection, sourceRepeat, Optional.empty(), Optional.empty(), newChild); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplaceTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplaceTest.java index 98fbbfbec13f2e..ad35028d7b9467 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplaceTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplaceTest.java @@ -152,7 +152,7 @@ public void testInferWithOrPredicate() { inputs.add(equalTo); Set result = InferPredicateByReplace.infer(inputs); - Assertions.assertEquals(2, result.size()); + Assertions.assertEquals(3, result.size()); } @Test diff --git a/regression-test/data/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.out b/regression-test/data/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.out index 2cda3936670925..c646bf2e485b20 100644 --- a/regression-test/data/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.out +++ b/regression-test/data/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.out @@ -114,14 +114,16 @@ PhysicalResultSink --hashJoin[INNER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() ----filter(OR[(t1.a < 2),(t1.a > 10)]) ------PhysicalOlapScan[extend_infer_t3(t1)] -----PhysicalOlapScan[extend_infer_t4(t2)] +----filter(OR[(t2.a < 2),(t2.a > 10)]) +------PhysicalOlapScan[extend_infer_t4(t2)] -- !test_or2 -- PhysicalResultSink --hashJoin[INNER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() ----filter(OR[(t1.a < 2),(t1.a > 10)]) ------PhysicalOlapScan[extend_infer_t3(t1)] -----PhysicalOlapScan[extend_infer_t4(t2)] +----filter(OR[(t2.a < 2),(t2.a > 10)]) +------PhysicalOlapScan[extend_infer_t4(t2)] -- !test_sign_predicate -- PhysicalResultSink @@ -772,3 +774,116 @@ PhysicalResultSink -- !pull_up_from_agg -- 0 +-- !qt_leftjoin_right_pull_up_shape_shape -- +PhysicalResultSink +--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t2.a = t3.a)) otherCondition=() +----hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() +------filter((t1.a = 1)) +--------PhysicalOlapScan[extend_infer_t3(t1)] +------filter((t2.a = 1)) +--------PhysicalOlapScan[extend_infer_t4(t2)] +----filter((t3.a = 1)) +------PhysicalOlapScan[extend_infer_t5(t3)] + +-- !qt_leftjoin_right_pull_up_shape_result -- + +-- !qt_multi_leftjoin_right_pull_up_shape_shape -- +PhysicalResultSink +--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t2.a = t5.a)) otherCondition=() +----hashJoin[LEFT_OUTER_JOIN] hashCondition=((t4.a = t3.a)) otherCondition=() +------hashJoin[LEFT_OUTER_JOIN] hashCondition=((t2.a = t3.a)) otherCondition=() +--------hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() +----------filter((t1.a = 1)) +------------PhysicalOlapScan[extend_infer_t3(t1)] +----------filter((t2.a = 1)) +------------PhysicalOlapScan[extend_infer_t4(t2)] +--------filter((t3.a = 1)) +----------PhysicalOlapScan[extend_infer_t5(t3)] +------filter((t4.a = 1)) +--------PhysicalOlapScan[extend_infer_t5(t4)] +----filter((t5.a = 1)) +------PhysicalOlapScan[extend_infer_t5(t5)] + +-- !qt_multi_leftjoin_right_pull_up_shape_result -- + +-- !qt_leftjoin_right_pull_up_in_shape_shape -- +PhysicalResultSink +--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t2.a = t3.a)) otherCondition=() +----hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() +------filter(a IN (1, 2)) +--------PhysicalOlapScan[extend_infer_t3(t1)] +------filter(a IN (1, 2)) +--------PhysicalOlapScan[extend_infer_t4(t2)] +----filter(a IN (1, 2)) +------PhysicalOlapScan[extend_infer_t5(t3)] + +-- !qt_leftjoin_right_pull_up_in_shape_result -- + +-- !qt_leftjoin_right_pull_up_is_null_shape_shape -- +PhysicalResultSink +--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t2.a = t3.a)) otherCondition=() +----hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() +------filter(a IS NULL) +--------PhysicalOlapScan[extend_infer_t3(t1)] +------PhysicalOlapScan[extend_infer_t4(t2)] +----PhysicalOlapScan[extend_infer_t5(t3)] + +-- !qt_leftjoin_right_pull_up_is_null_shape_result -- +\N \N 9 3 \N \N \N \N \N \N \N \N +\N d2 3 55 \N \N \N \N \N \N \N \N + +-- !qt_leftjoin_right_pull_up_is_not_null_shape_shape -- +PhysicalResultSink +--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t2.a = t3.a)) otherCondition=() +----hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() +------filter(( not a IS NULL)) +--------PhysicalOlapScan[extend_infer_t3(t1)] +------PhysicalOlapScan[extend_infer_t4(t2)] +----PhysicalOlapScan[extend_infer_t5(t3)] + +-- !qt_leftjoin_right_pull_up_is_not_null_shape_result -- +0 d2 3 5 0 d2 2 2 \N \N \N \N +100 d2 3 5 100 d2 3 \N \N \N \N \N +12 \N 9 3 \N \N \N \N \N \N \N \N +33 d2 2 5 33 d2 23 5 \N \N \N \N +78 \N 9 3 78 d2 23 5 \N \N \N \N + +-- !qt_left_join_inner_shape -- +PhysicalResultSink +--NestedLoopJoin[INNER_JOIN] +----NestedLoopJoin[INNER_JOIN] +------filter((t1.a = 1)) +--------PhysicalOlapScan[extend_infer_t3(t1)] +------filter((t2.a = 1)) +--------PhysicalOlapScan[extend_infer_t4(t2)] +----filter((t3.a = 1)) +------PhysicalOlapScan[extend_infer_t5(t3)] + +-- !qt_left_join_inner_result -- + +-- !qt_left_join_semi_shape -- +PhysicalResultSink +--NestedLoopJoin[LEFT_SEMI_JOIN] +----NestedLoopJoin[INNER_JOIN] +------filter((t1.a = 1)) +--------PhysicalOlapScan[extend_infer_t3(t1)] +------filter((t2.a = 1)) +--------PhysicalOlapScan[extend_infer_t4(t2)] +----filter((t3.a = 1)) +------PhysicalOlapScan[extend_infer_t5(t3)] + +-- !qt_left_join_semi_result -- + +-- !qt_left_join_anti_shape -- +PhysicalResultSink +--hashJoin[LEFT_ANTI_JOIN] hashCondition=((t2.a = t3.a)) otherCondition=() +----hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() +------filter((t1.a = 1)) +--------PhysicalOlapScan[extend_infer_t3(t1)] +------filter((t2.a = 1)) +--------PhysicalOlapScan[extend_infer_t4(t2)] +----filter((t3.a = 1)) +------PhysicalOlapScan[extend_infer_t5(t3)] + +-- !qt_left_join_anti_result -- + diff --git a/regression-test/data/nereids_rules_p0/infer_predicate/pull_up_predicate_agg.out b/regression-test/data/nereids_rules_p0/infer_predicate/pull_up_predicate_agg.out index 330b2e93f2dbd5..81529b85b48b3b 100644 --- a/regression-test/data/nereids_rules_p0/infer_predicate/pull_up_predicate_agg.out +++ b/regression-test/data/nereids_rules_p0/infer_predicate/pull_up_predicate_agg.out @@ -46,11 +46,13 @@ PhysicalResultSink --PhysicalQuickSort[MERGE_SORT] ----PhysicalQuickSort[LOCAL_SORT] ------hashJoin[INNER_JOIN] hashCondition=((t.col1 = t2.a) and (t.col2 = t2.c)) otherCondition=() ---------hashAgg[GLOBAL] -----------hashAgg[LOCAL] -------------filter((test_pull_up_agg_t1.a <= 20) and (test_pull_up_agg_t1.c < 200)) ---------------PhysicalOlapScan[test_pull_up_agg_t1] ---------PhysicalOlapScan[test_pull_up_agg_t2(t2)] +--------filter((t.col1 <= 20) and (t.col2 < 200)) +----------hashAgg[GLOBAL] +------------hashAgg[LOCAL] +--------------filter((test_pull_up_agg_t1.a <= 20) and (test_pull_up_agg_t1.c < 200)) +----------------PhysicalOlapScan[test_pull_up_agg_t1] +--------filter((t2.a <= 20) and (t2.c < 200)) +----------PhysicalOlapScan[test_pull_up_agg_t2(t2)] -- !pull_up_from_agg_to_filter_with_same_cond_shape -- PhysicalResultSink diff --git a/regression-test/suites/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.groovy b/regression-test/suites/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.groovy index b7a6090e901ed3..de730c5b574a0e 100644 --- a/regression-test/suites/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.groovy +++ b/regression-test/suites/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.groovy @@ -377,4 +377,39 @@ suite("extend_infer_equal_predicate") { qt_pull_up_from_intersect """select a from(select a from (select t1.a from extend_infer_t3 t1 where t1.a<10 intersect select t2.a from extend_infer_t4 t2 where t2.a<10 ) tt limit 10) t where a<10 order by 1 ;""" qt_pull_up_from_agg """select a from (select a from extend_infer_t3 t1 where a<10 group by a limit 10) t where a<10 order by 1""" + + def explain_and_result = { tag, sql -> + "qt_${tag}_shape" "explain shape plan ${sql}" + "order_qt_${tag}_result" "${sql}" + } + + // test left join right table predicate pull up + explain_and_result 'qt_leftjoin_right_pull_up_shape', ''' + select * from extend_infer_t3 t1 left join extend_infer_t4 t2 on t1.a=t2.a left join extend_infer_t5 t3 on t2.a= t3.a where t1.a=1; + ''' + // test multi left join right table predicate pull up + explain_and_result "qt_multi_leftjoin_right_pull_up_shape", """ + select * from extend_infer_t3 t1 left join extend_infer_t4 t2 on t1.a=t2.a left join extend_infer_t5 t3 on t2.a= t3.a left join extend_infer_t5 t4 on t4.a=t3.a left join extend_infer_t5 t5 on t2.a=t5.a where t1.a=1; + """ + explain_and_result "qt_leftjoin_right_pull_up_in_shape", """ + select * from extend_infer_t3 t1 left join extend_infer_t4 t2 on t1.a=t2.a left join extend_infer_t5 t3 on t2.a= t3.a where t1.a in (1,2); + """ + // is null may be can be inferred but we do not infer it now + explain_and_result "qt_leftjoin_right_pull_up_is_null_shape", """ + select * from extend_infer_t3 t1 left join extend_infer_t4 t2 on t1.a=t2.a left join extend_infer_t5 t3 on t2.a= t3.a where t1.a is null; + """ + // is not null may be need not be innfered + explain_and_result "qt_leftjoin_right_pull_up_is_not_null_shape", """ + select * from extend_infer_t3 t1 left join extend_infer_t4 t2 on t1.a=t2.a left join extend_infer_t5 t3 on t2.a= t3.a where t1.a is not null; + """ + + explain_and_result 'qt_left_join_inner', ''' + select * from extend_infer_t3 t1 left join extend_infer_t4 t2 on t1.a=t2.a inner join extend_infer_t5 t3 on t2.a= t3.a where t1.a=1; + ''' + explain_and_result 'qt_left_join_semi', ''' + select * from extend_infer_t3 t1 left join extend_infer_t4 t2 on t1.a=t2.a left semi join extend_infer_t5 t3 on t2.a= t3.a where t1.a=1; + ''' + explain_and_result 'qt_left_join_anti', ''' + select * from extend_infer_t3 t1 left join extend_infer_t4 t2 on t1.a=t2.a left anti join extend_infer_t5 t3 on t2.a= t3.a where t1.a=1; + ''' } \ No newline at end of file