diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 60522e3da39d9f..4e58cd170121c8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -51,6 +51,7 @@ import org.apache.doris.nereids.rules.rewrite.CollectFilterAboveConsumer; import org.apache.doris.nereids.rules.rewrite.ColumnPruning; import org.apache.doris.nereids.rules.rewrite.ConvertInnerOrCrossJoin; +import org.apache.doris.nereids.rules.rewrite.ConvertOuterJoinToAntiJoin; import org.apache.doris.nereids.rules.rewrite.CountDistinctRewrite; import org.apache.doris.nereids.rules.rewrite.CountLiteralRewrite; import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow; @@ -446,8 +447,7 @@ public class Rewriter extends AbstractBatchJobExecutor { new CollectCteConsumerOutput() ) ), - topic("Collect used column", custom(RuleType.COLLECT_COLUMNS, QueryColumnCollector::new) - ) + topic("Collect used column", custom(RuleType.COLLECT_COLUMNS, QueryColumnCollector::new)) ) ); @@ -455,6 +455,7 @@ public class Rewriter extends AbstractBatchJobExecutor { ImmutableSet.of(LogicalCTEAnchor.class), () -> jobs( // after variant sub path pruning, we need do column pruning again + bottomUp(RuleSet.PUSH_DOWN_FILTERS), custom(RuleType.COLUMN_PRUNING, ColumnPruning::new), bottomUp(ImmutableList.of( new PushDownFilterThroughProject(), @@ -548,6 +549,8 @@ private static List getWholeTreeRewriteJobs( topic("rewrite cte sub-tree before sub path push down", custom(RuleType.REWRITE_CTE_CHILDREN, () -> new RewriteCteChildren(beforePushDownJobs)) ))); + rewriteJobs.addAll(jobs(topic("convert outer join to anti", + custom(RuleType.CONVERT_OUTER_JOIN_TO_ANTI, ConvertOuterJoinToAntiJoin::new)))); if (needOrExpansion) { rewriteJobs.addAll(jobs(topic("or expansion", custom(RuleType.OR_EXPANSION, () -> OrExpansion.INSTANCE)))); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java index bcd12ac17d2579..15943a25a90f86 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java @@ -86,7 +86,6 @@ import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN; import org.apache.doris.nereids.rules.implementation.LogicalUnionToPhysicalUnion; import org.apache.doris.nereids.rules.implementation.LogicalWindowToPhysicalWindow; -import org.apache.doris.nereids.rules.rewrite.ConvertOuterJoinToAntiJoin; import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow; import org.apache.doris.nereids.rules.rewrite.EliminateFilter; import org.apache.doris.nereids.rules.rewrite.EliminateOuterJoin; @@ -148,7 +147,6 @@ public class RuleSet { new PushDownFilterThroughGenerate(), new PushDownProjectThroughLimit(), new EliminateOuterJoin(), - new ConvertOuterJoinToAntiJoin(), new MergeProjects(), new MergeFilters(), new MergeGenerates(), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java index e69f93bf84551e..a5c77e2e47ef8a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java @@ -20,6 +20,7 @@ import org.apache.doris.common.Pair; import org.apache.doris.nereids.pattern.ExpressionPatternRules; import org.apache.doris.nereids.pattern.ExpressionPatternTraverseListeners; +import org.apache.doris.nereids.pattern.MatchingContext; import org.apache.doris.nereids.properties.OrderKey; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; @@ -28,25 +29,39 @@ import org.apache.doris.nereids.trees.expressions.EqualPredicate; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.OrderExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.functions.Function; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate; import org.apache.doris.nereids.trees.plans.logical.LogicalHaving; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalPartitionTopN; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; +import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation; +import org.apache.doris.nereids.trees.plans.logical.LogicalSink; import org.apache.doris.nereids.trees.plans.logical.LogicalSort; +import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; +import org.apache.doris.nereids.trees.plans.logical.LogicalWindow; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.Utils; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; +import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Set; @@ -79,7 +94,19 @@ public List buildRules() { new JoinExpressionRewrite().build(), new SortExpressionRewrite().build(), new LogicalRepeatRewrite().build(), - new HavingExpressionRewrite().build()); + new HavingExpressionRewrite().build(), + new LogicalPartitionTopNExpressionRewrite().build(), + new LogicalTopNExpressionRewrite().build(), + new LogicalSetOperationRewrite().build(), + new LogicalWindowRewrite().build(), + new LogicalCteConsumerRewrite().build(), + new LogicalResultSinkRewrite().build(), + new LogicalFileSinkRewrite().build(), + new LogicalHiveTableSinkRewrite().build(), + new LogicalIcebergTableSinkRewrite().build(), + new LogicalJdbcTableSinkRewrite().build(), + new LogicalOlapTableSinkRewrite().build(), + new LogicalDeferMaterializeResultSinkRewrite().build()); } private class GenerateExpressionRewrite extends OneRewriteRuleFactory { @@ -264,7 +291,166 @@ public Rule build() { } } - private class LogicalRepeatRewrite extends OneRewriteRuleFactory { + private class LogicalWindowRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalWindow().thenApply(ctx -> { + LogicalWindow window = ctx.root; + List windowExpressions = window.getWindowExpressions(); + ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + List result = rewriteAll(windowExpressions, rewriter, context); + return window.withExpressionsAndChild(result, window.child()); + }) + .toRule(RuleType.REWRITE_WINDOW_EXPRESSION); + } + } + + private class LogicalSetOperationRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalSetOperation().thenApply(ctx -> { + LogicalSetOperation setOperation = ctx.root; + List> slotsList = setOperation.getRegularChildrenOutputs(); + List> newSlotsList = new ArrayList<>(); + ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + for (List slots : slotsList) { + List result = rewriteAll(slots, rewriter, context); + newSlotsList.add(result); + } + return setOperation.withChildrenAndTheirOutputs(setOperation.children(), newSlotsList); + }) + .toRule(RuleType.REWRITE_SET_OPERATION_EXPRESSION); + } + } + + private class LogicalTopNExpressionRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalTopN().thenApply(ctx -> { + LogicalTopN topN = ctx.root; + List orderKeys = topN.getOrderKeys(); + ImmutableList.Builder rewrittenOrderKeys + = ImmutableList.builderWithExpectedSize(orderKeys.size()); + ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + boolean changed = false; + for (OrderKey k : orderKeys) { + Expression expression = rewriter.rewrite(k.getExpr(), context); + changed |= expression != k.getExpr(); + rewrittenOrderKeys.add(new OrderKey(expression, k.isAsc(), k.isNullFirst())); + } + return changed ? topN.withOrderKeys(rewrittenOrderKeys.build()) : topN; + }).toRule(RuleType.REWRITE_TOPN_EXPRESSION); + } + } + + private class LogicalPartitionTopNExpressionRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalPartitionTopN().thenApply(ctx -> { + LogicalPartitionTopN partitionTopN = ctx.root; + ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + List newOrderExpressions = new ArrayList<>(); + for (OrderExpression orderExpression : partitionTopN.getOrderKeys()) { + OrderKey orderKey = orderExpression.getOrderKey(); + Expression expr = rewriter.rewrite(orderKey.getExpr(), context); + OrderKey newOrderKey = new OrderKey(expr, orderKey.isAsc(), orderKey.isNullFirst()); + newOrderExpressions.add(new OrderExpression(newOrderKey)); + } + List result = rewriteAll(partitionTopN.getPartitionKeys(), rewriter, context); + return partitionTopN.withPartitionKeysAndOrderKeys(result, newOrderExpressions); + }).toRule(RuleType.REWRITE_PARTITION_TOPN_EXPRESSION); + } + } + + private class LogicalCteConsumerRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalCTEConsumer().thenApply(ctx -> { + LogicalCTEConsumer consumer = ctx.root; + boolean changed = false; + ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + ImmutableMap.Builder cToPBuilder = ImmutableMap.builder(); + ImmutableMultimap.Builder pToCBuilder = ImmutableMultimap.builder(); + for (Map.Entry entry : consumer.getConsumerToProducerOutputMap().entrySet()) { + Slot key = (Slot) rewriter.rewrite(entry.getKey(), context); + Slot value = (Slot) rewriter.rewrite(entry.getValue(), context); + cToPBuilder.put(key, value); + pToCBuilder.put(value, key); + if (!key.equals(entry.getKey()) || !value.equals(entry.getValue())) { + changed = true; + } + } + return changed ? consumer.withTwoMaps(cToPBuilder.build(), pToCBuilder.build()) : consumer; + }).toRule(RuleType.REWRITE_TOPN_EXPRESSION); + } + } + + private class LogicalResultSinkRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalResultSink().thenApply(ExpressionRewrite.this::applyRewriteToSink) + .toRule(RuleType.REWRITE_SINK_EXPRESSION); + } + } + + private class LogicalFileSinkRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalFileSink().thenApply(ExpressionRewrite.this::applyRewriteToSink) + .toRule(RuleType.REWRITE_SINK_EXPRESSION); + } + } + + private class LogicalHiveTableSinkRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalHiveTableSink().thenApply(ExpressionRewrite.this::applyRewriteToSink) + .toRule(RuleType.REWRITE_SINK_EXPRESSION); + } + } + + private class LogicalIcebergTableSinkRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalIcebergTableSink().thenApply(ExpressionRewrite.this::applyRewriteToSink) + .toRule(RuleType.REWRITE_SINK_EXPRESSION); + } + } + + private class LogicalJdbcTableSinkRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalJdbcTableSink().thenApply(ExpressionRewrite.this::applyRewriteToSink) + .toRule(RuleType.REWRITE_SINK_EXPRESSION); + } + } + + private class LogicalOlapTableSinkRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalOlapTableSink().thenApply(ExpressionRewrite.this::applyRewriteToSink) + .toRule(RuleType.REWRITE_SINK_EXPRESSION); + } + } + + private class LogicalDeferMaterializeResultSinkRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalDeferMaterializeResultSink().thenApply(ExpressionRewrite.this::applyRewriteToSink) + .toRule(RuleType.REWRITE_SINK_EXPRESSION); + } + } + + private LogicalSink applyRewriteToSink(MatchingContext> ctx) { + LogicalSink sink = ctx.root; + ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + List outputExprs = sink.getOutputExprs(); + List result = rewriteAll(outputExprs, rewriter, context); + return sink.withOutputExprs(result); + } + + /** LogicalRepeatRewrite */ + public class LogicalRepeatRewrite extends OneRewriteRuleFactory { @Override public Rule build() { return logicalRepeat().thenApply(ctx -> { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java index c9185fd1a3cfea..46445573055df2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java @@ -17,9 +17,9 @@ package org.apache.doris.nereids.rules.rewrite; -import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; @@ -28,9 +28,14 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; import org.apache.doris.nereids.util.TypeUtils; -import java.util.List; +import com.google.common.collect.ImmutableList; + +import java.util.HashMap; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -42,18 +47,41 @@ * project(A.*) * - LeftAntiJoin(A, B) */ -public class ConvertOuterJoinToAntiJoin extends OneRewriteRuleFactory { +public class ConvertOuterJoinToAntiJoin extends DefaultPlanRewriter> implements CustomRewriter { + private ExprIdRewriter exprIdReplacer; + + @Override + public Plan rewriteRoot(Plan plan, JobContext jobContext) { + if (!plan.containsType(LogicalJoin.class)) { + return plan; + } + Map replaceMap = new HashMap<>(); + ExprIdRewriter.ReplaceRule replaceRule = new ExprIdRewriter.ReplaceRule(replaceMap); + exprIdReplacer = new ExprIdRewriter(replaceRule, jobContext); + return plan.accept(this, replaceMap); + } @Override - public Rule build() { - return logicalFilter(logicalJoin() - .when(join -> join.getJoinType().isOuterJoin())) - .then(this::toAntiJoin) - .toRule(RuleType.CONVERT_OUTER_JOIN_TO_ANTI); + public Plan visit(Plan plan, Map replaceMap) { + plan = visitChildren(this, plan, replaceMap); + plan = exprIdReplacer.rewriteExpr(plan, replaceMap); + return plan; } - private Plan toAntiJoin(LogicalFilter> filter) { + @Override + public Plan visitLogicalFilter(LogicalFilter filter, Map replaceMap) { + filter = (LogicalFilter) visit(filter, replaceMap); + if (!(filter.child() instanceof LogicalJoin)) { + return filter; + } + return toAntiJoin((LogicalFilter>) filter, replaceMap); + } + + private Plan toAntiJoin(LogicalFilter> filter, Map replaceMap) { LogicalJoin join = filter.child(); + if (!join.getJoinType().isLeftOuterJoin() && !join.getJoinType().isRightOuterJoin()) { + return filter; + } Set alwaysNullSlots = filter.getConjuncts().stream() .filter(p -> TypeUtils.isNull(p).isPresent()) @@ -66,33 +94,37 @@ private Plan toAntiJoin(LogicalFilter> filter) { .filter(s -> alwaysNullSlots.contains(s) && !s.nullable()) .collect(Collectors.toSet()); - Plan newJoin = null; + Plan newChild = null; if (join.getJoinType().isLeftOuterJoin() && !rightAlwaysNullSlots.isEmpty()) { - newJoin = join.withJoinTypeAndContext(JoinType.LEFT_ANTI_JOIN, join.getJoinReorderContext()); + newChild = join.withJoinTypeAndContext(JoinType.LEFT_ANTI_JOIN, join.getJoinReorderContext()); } if (join.getJoinType().isRightOuterJoin() && !leftAlwaysNullSlots.isEmpty()) { - newJoin = join.withJoinTypeAndContext(JoinType.RIGHT_ANTI_JOIN, join.getJoinReorderContext()); + newChild = join.withJoinTypeAndContext(JoinType.RIGHT_ANTI_JOIN, join.getJoinReorderContext()); } - if (newJoin == null) { - return null; + if (newChild == null) { + return filter; } - if (!newJoin.getOutputSet().containsAll(filter.getInputSlots())) { + if (!newChild.getOutputSet().containsAll(filter.getInputSlots())) { // if there are slots that don't belong to join output, we use null alias to replace them // such as: // project(A.id, null as B.id) // - (A left anti join B) - Set joinOutput = newJoin.getOutputSet(); - List projects = filter.getOutput().stream() - .map(s -> { - if (joinOutput.contains(s)) { - return s; - } else { - return new Alias(s.getExprId(), new NullLiteral(s.getDataType()), s.getName()); - } - }).collect(Collectors.toList()); - newJoin = new LogicalProject<>(projects, newJoin); + Set joinOutputs = newChild.getOutputSet(); + ImmutableList.Builder projectsBuilder = ImmutableList.builder(); + for (NamedExpression e : filter.getOutput()) { + if (joinOutputs.contains(e)) { + projectsBuilder.add(e); + } else { + Alias newAlias = new Alias(new NullLiteral(e.getDataType()), e.getName(), e.getQualifier()); + replaceMap.put(e.getExprId(), newAlias.getExprId()); + projectsBuilder.add(newAlias); + } + } + newChild = new LogicalProject<>(projectsBuilder.build(), newChild); + return exprIdReplacer.rewriteExpr(filter.withChildren(newChild), replaceMap); + } else { + return filter.withChildren(newChild); } - return filter.withChildren(newJoin); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExprIdRewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExprIdRewriter.java index 60c9da4bc6eec5..5e065fa3724b08 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExprIdRewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExprIdRewriter.java @@ -18,32 +18,20 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.jobs.JobContext; -import org.apache.doris.nereids.pattern.MatchingContext; import org.apache.doris.nereids.pattern.Pattern; -import org.apache.doris.nereids.properties.OrderKey; import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.rules.expression.ExpressionRewrite; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.OrderExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalPartitionTopN; -import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation; -import org.apache.doris.nereids.trees.plans.logical.LogicalSink; -import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; -import org.apache.doris.nereids.trees.plans.logical.LogicalWindow; import com.google.common.collect.ImmutableList; -import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -58,26 +46,6 @@ public ExprIdRewriter(ReplaceRule replaceRule, JobContext jobContext) { this.jobContext = jobContext; } - @Override - public List buildRules() { - ImmutableList.Builder builder = ImmutableList.builder(); - builder.addAll(super.buildRules()); - builder.addAll(ImmutableList.of( - new LogicalPartitionTopNExpressionRewrite().build(), - new LogicalTopNExpressionRewrite().build(), - new LogicalSetOperationRewrite().build(), - new LogicalWindowRewrite().build(), - new LogicalResultSinkRewrite().build(), - new LogicalFileSinkRewrite().build(), - new LogicalHiveTableSinkRewrite().build(), - new LogicalIcebergTableSinkRewrite().build(), - new LogicalJdbcTableSinkRewrite().build(), - new LogicalOlapTableSinkRewrite().build(), - new LogicalDeferMaterializeResultSinkRewrite().build() - )); - return builder.build(); - } - /**rewriteExpr*/ public Plan rewriteExpr(Plan plan, Map replaceMap) { if (replaceMap.isEmpty()) { @@ -129,156 +97,4 @@ public List> buildRules() { ); } } - - private class LogicalResultSinkRewrite extends OneRewriteRuleFactory { - @Override - public Rule build() { - return logicalResultSink().thenApply(ExprIdRewriter.this::applyRewrite) - .toRule(RuleType.REWRITE_SINK_EXPRESSION); - } - } - - private class LogicalFileSinkRewrite extends OneRewriteRuleFactory { - @Override - public Rule build() { - return logicalFileSink().thenApply(ExprIdRewriter.this::applyRewrite) - .toRule(RuleType.REWRITE_SINK_EXPRESSION); - } - } - - private class LogicalHiveTableSinkRewrite extends OneRewriteRuleFactory { - @Override - public Rule build() { - return logicalHiveTableSink().thenApply(ExprIdRewriter.this::applyRewrite) - .toRule(RuleType.REWRITE_SINK_EXPRESSION); - } - } - - private class LogicalIcebergTableSinkRewrite extends OneRewriteRuleFactory { - @Override - public Rule build() { - return logicalIcebergTableSink().thenApply(ExprIdRewriter.this::applyRewrite) - .toRule(RuleType.REWRITE_SINK_EXPRESSION); - } - } - - private class LogicalJdbcTableSinkRewrite extends OneRewriteRuleFactory { - @Override - public Rule build() { - return logicalJdbcTableSink().thenApply(ExprIdRewriter.this::applyRewrite) - .toRule(RuleType.REWRITE_SINK_EXPRESSION); - } - } - - private class LogicalOlapTableSinkRewrite extends OneRewriteRuleFactory { - @Override - public Rule build() { - return logicalOlapTableSink().thenApply(ExprIdRewriter.this::applyRewrite) - .toRule(RuleType.REWRITE_SINK_EXPRESSION); - } - } - - private class LogicalDeferMaterializeResultSinkRewrite extends OneRewriteRuleFactory { - @Override - public Rule build() { - return logicalDeferMaterializeResultSink().thenApply(ExprIdRewriter.this::applyRewrite) - .toRule(RuleType.REWRITE_SINK_EXPRESSION); - } - } - - private class LogicalSetOperationRewrite extends OneRewriteRuleFactory { - @Override - public Rule build() { - return logicalSetOperation().thenApply(ctx -> { - LogicalSetOperation setOperation = ctx.root; - List> slotsList = setOperation.getRegularChildrenOutputs(); - List> newSlotsList = new ArrayList<>(); - ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); - for (List slots : slotsList) { - List newSlots = rewriteAll(slots, rewriter, context); - newSlotsList.add(newSlots); - } - if (newSlotsList.equals(slotsList)) { - return setOperation; - } - return setOperation.withChildrenAndTheirOutputs(setOperation.children(), newSlotsList); - }) - .toRule(RuleType.REWRITE_SET_OPERATION_EXPRESSION); - } - } - - private class LogicalWindowRewrite extends OneRewriteRuleFactory { - @Override - public Rule build() { - return logicalWindow().thenApply(ctx -> { - LogicalWindow window = ctx.root; - List windowExpressions = window.getWindowExpressions(); - ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); - List newWindowExpressions = rewriteAll(windowExpressions, rewriter, context); - if (newWindowExpressions.equals(windowExpressions)) { - return window; - } - return window.withExpressionsAndChild(newWindowExpressions, window.child()); - }) - .toRule(RuleType.REWRITE_WINDOW_EXPRESSION); - } - } - - private class LogicalTopNExpressionRewrite extends OneRewriteRuleFactory { - @Override - public Rule build() { - return logicalTopN().thenApply(ctx -> { - LogicalTopN topN = ctx.root; - List orderKeys = topN.getOrderKeys(); - ImmutableList.Builder rewrittenOrderKeys - = ImmutableList.builderWithExpectedSize(orderKeys.size()); - ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); - boolean changed = false; - for (OrderKey k : orderKeys) { - Expression expression = rewriter.rewrite(k.getExpr(), context); - changed |= expression != k.getExpr(); - rewrittenOrderKeys.add(new OrderKey(expression, k.isAsc(), k.isNullFirst())); - } - return changed ? topN.withOrderKeys(rewrittenOrderKeys.build()) : topN; - }).toRule(RuleType.REWRITE_TOPN_EXPRESSION); - } - } - - private class LogicalPartitionTopNExpressionRewrite extends OneRewriteRuleFactory { - @Override - public Rule build() { - return logicalPartitionTopN().thenApply(ctx -> { - LogicalPartitionTopN partitionTopN = ctx.root; - ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); - List newOrderExpressions = new ArrayList<>(); - boolean changed = false; - for (OrderExpression orderExpression : partitionTopN.getOrderKeys()) { - OrderKey orderKey = orderExpression.getOrderKey(); - Expression expr = rewriter.rewrite(orderKey.getExpr(), context); - changed |= expr != orderKey.getExpr(); - OrderKey newOrderKey = new OrderKey(expr, orderKey.isAsc(), orderKey.isNullFirst()); - newOrderExpressions.add(new OrderExpression(newOrderKey)); - } - List newPartitionKeys = rewriteAll(partitionTopN.getPartitionKeys(), rewriter, context); - if (!newPartitionKeys.equals(partitionTopN.getPartitionKeys())) { - changed = true; - } - if (!changed) { - return partitionTopN; - } - return partitionTopN.withPartitionKeysAndOrderKeys(newPartitionKeys, newOrderExpressions); - }).toRule(RuleType.REWRITE_PARTITION_TOPN_EXPRESSION); - } - } - - private LogicalSink applyRewrite(MatchingContext> ctx) { - LogicalSink sink = ctx.root; - ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); - List outputExprs = sink.getOutputExprs(); - List newOutputExprs = rewriteAll(outputExprs, rewriter, context); - if (outputExprs.equals(newOutputExprs)) { - return sink; - } - return sink.withOutputExprs(newOutputExprs); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/StatementScopeIdGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/StatementScopeIdGenerator.java index df7ef2ab69a100..cf0ecc3cb9b956 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/StatementScopeIdGenerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/StatementScopeIdGenerator.java @@ -81,6 +81,6 @@ public static void clear() throws Exception { if (ConnectContext.get() != null) { ConnectContext.get().setStatementContext(new StatementContext()); } - statementContext = new StatementContext(); + statementContext = new StatementContext(10000); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCTEConsumer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCTEConsumer.java index 415fdddf80b449..6148f62378e65a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCTEConsumer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCTEConsumer.java @@ -198,4 +198,21 @@ public String toString() { "relationId", relationId, "name", name); } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) { + return false; + } + if (!super.equals(o)) { + return false; + } + LogicalCTEConsumer that = (LogicalCTEConsumer) o; + return Objects.equals(consumerToProducerOutputMap, that.consumerToProducerOutputMap); + } + + @Override + public int hashCode() { + return super.hashCode(); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java index 1159fc2a7cec6d..b3166c2224052b 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java @@ -32,17 +32,21 @@ import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; +import org.apache.doris.qe.ConnectContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.Sets; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported { - private final LogicalOlapScan scan1; - private final LogicalOlapScan scan2; + private LogicalOlapScan scan1; + private LogicalOlapScan scan2; - public ConvertOuterJoinToAntiJoinTest() throws Exception { + @BeforeEach + void setUp() throws Exception { // clear id so that slot id keep consistent every running + ConnectContext.remove(); StatementScopeIdGenerator.clear(); scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); @@ -58,7 +62,7 @@ void testEliminateLeftWithProject() { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new InferFilterNotNull()) - .applyTopDown(new ConvertOuterJoinToAntiJoin()) + .applyCustom(new ConvertOuterJoinToAntiJoin()) .printlnTree() .matches(logicalJoin().when(join -> join.getJoinType().isLeftAntiJoin())); } @@ -73,7 +77,7 @@ void testEliminateRightWithProject() { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new InferFilterNotNull()) - .applyTopDown(new ConvertOuterJoinToAntiJoin()) + .applyCustom(new ConvertOuterJoinToAntiJoin()) .printlnTree() .matches(logicalJoin().when(join -> join.getJoinType().isRightAntiJoin())); } @@ -91,7 +95,7 @@ void testEliminateLeftWithLeftPredicate() { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new InferFilterNotNull()) - .applyTopDown(new ConvertOuterJoinToAntiJoin()) + .applyCustom(new ConvertOuterJoinToAntiJoin()) .printlnTree() .matches(logicalJoin().when(join -> join.getJoinType().isLeftAntiJoin())); } @@ -109,7 +113,7 @@ void testEliminateLeftWithRightPredicate() { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new InferFilterNotNull()) - .applyTopDown(new ConvertOuterJoinToAntiJoin()) + .applyCustom(new ConvertOuterJoinToAntiJoin()) .printlnTree() .matches(logicalJoin().when(join -> join.getJoinType().isLeftAntiJoin())); } @@ -127,7 +131,7 @@ void testEliminateLeftWithOrPredicate() { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new InferFilterNotNull()) - .applyTopDown(new ConvertOuterJoinToAntiJoin()) + .applyCustom(new ConvertOuterJoinToAntiJoin()) .printlnTree() .matches(logicalJoin().when(join -> join.getJoinType().isLeftOuterJoin())); } @@ -146,7 +150,7 @@ void testEliminateLeftWithAndPredicate() { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new InferFilterNotNull()) - .applyTopDown(new ConvertOuterJoinToAntiJoin()) + .applyCustom(new ConvertOuterJoinToAntiJoin()) .printlnTree() .matches(logicalJoin().when(join -> join.getJoinType().isLeftOuterJoin())); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoinTest.java index 255f1e82e0061c..f0034410163649 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoinTest.java @@ -21,6 +21,7 @@ import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.GreaterThan; import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; @@ -69,7 +70,7 @@ void testEliminateLeft() { void testEliminateRight() { LogicalPlan plan = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.RIGHT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id - .filter(new GreaterThan(scan1.getOutput().get(0), Literal.of(1))) + .filter(new GreaterThan(scan1.getOutput().get(0), new IntegerLiteral(1))) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) @@ -81,7 +82,7 @@ void testEliminateRight() { logicalFilter( logicalJoin().when(join -> join.getJoinType().isInnerJoin()) ).when(filter -> filter.getConjuncts().size() == 1) - .when(filter -> Objects.equals(filter.getConjuncts().toString(), "[(id#0 > 1)]")) + .when(filter -> Objects.equals(filter.getConjuncts().iterator().next(), new GreaterThan(scan1.getOutput().get(0), new IntegerLiteral(1)))) ); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java index 71d0f0101b0413..6962572d07a483 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java @@ -32,6 +32,7 @@ import org.apache.doris.nereids.jobs.executor.Optimizer; import org.apache.doris.nereids.jobs.executor.Rewriter; import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob; +import org.apache.doris.nereids.jobs.rewrite.CustomRewriteJob; import org.apache.doris.nereids.jobs.rewrite.PlanTreeRewriteBottomUpJob; import org.apache.doris.nereids.jobs.rewrite.PlanTreeRewriteTopDownJob; import org.apache.doris.nereids.jobs.rewrite.RootPlanTreeRewriteJob; @@ -202,6 +203,14 @@ public PlanChecker applyTopDown(List rule) { return this; } + public PlanChecker applyCustom(CustomRewriter customRewriter) { + CustomRewriteJob customRewriteJob = new CustomRewriteJob(() -> customRewriter, RuleType.TEST_REWRITE); + customRewriteJob.execute(cascadesContext.getCurrentJobContext()); + cascadesContext.toMemo(); + MemoValidator.validate(cascadesContext.getMemo()); + return this; + } + /** * apply a top down rewrite rule if you not care the ruleId * diff --git a/regression-test/suites/nereids_syntax_p0/transform_outer_join_to_anti.groovy b/regression-test/suites/nereids_syntax_p0/transform_outer_join_to_anti.groovy index ccbb8fd64a8f6f..f806f4ce5c7a5e 100644 --- a/regression-test/suites/nereids_syntax_p0/transform_outer_join_to_anti.groovy +++ b/regression-test/suites/nereids_syntax_p0/transform_outer_join_to_anti.groovy @@ -84,4 +84,12 @@ suite("transform_outer_join_to_anti") { sql("select * from eliminate_outer_join_A right outer join eliminate_outer_join_B on eliminate_outer_join_B.b = eliminate_outer_join_A.a where eliminate_outer_join_A.a is null and eliminate_outer_join_B.null_b is null and eliminate_outer_join_A.null_a is null") contains "ANTI JOIN" } + + explain { + sql """with temp as ( + select * from eliminate_outer_join_A left outer join eliminate_outer_join_B on eliminate_outer_join_B.b = eliminate_outer_join_A.a where eliminate_outer_join_B.b is null + ) + select * from temp t1 join temp t2""" + contains "ANTI JOIN" + } }