diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownProject.java index 9fbc9413b29c9c..832f9c25e776f2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownProject.java @@ -36,10 +36,12 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ImmutableCollection; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; import com.google.common.collect.Multimap; import java.util.ArrayList; @@ -198,101 +200,105 @@ private Plan pushThroughUnion(MatchingContext> ctx) if (!ctx.connectContext.getSessionVariable().enablePruneNestedColumns) { return ctx.root; } - LogicalProject project = ctx.root; + return pushThroughUnion(ctx.root, ctx.statementContext); + } + + @VisibleForTesting + static Plan pushThroughUnion(LogicalProject project, StatementContext statementContext) { LogicalUnion union = project.child(); PushdownProjectHelper pushdownProjectHelper - = new PushdownProjectHelper(ctx.statementContext, project); - + = new PushdownProjectHelper(statementContext, project); Pair> pushProjects = pushdownProjectHelper.pushDownExpressions(project.getProjects()); - if (pushProjects.first) { - List unionOutputs = union.getOutputs(); - Map slotToColumnIndex = new LinkedHashMap<>(); - for (int i = 0; i < unionOutputs.size(); i++) { - NamedExpression output = unionOutputs.get(i); - slotToColumnIndex.put(output.toSlot(), i); - } + if (!pushProjects.first) { + return project; + } + List unionOutputs = union.getOutputs(); + Map slotToColumnIndex = new 
LinkedHashMap<>(); + for (int i = 0; i < unionOutputs.size(); i++) { + NamedExpression output = unionOutputs.get(i); + slotToColumnIndex.put(output.toSlot(), i); + } - Collection pushDownProjections - = pushdownProjectHelper.childToPushDownProjects.values(); - List newChildren = new ArrayList<>(); - List> newChildrenOutputs = new ArrayList<>(); - for (Plan child : union.children()) { - List pushedOutput = replaceSlot( - ctx.statementContext, - pushDownProjections, - slot -> { - Integer sourceColumnIndex = slotToColumnIndex.get(slot); - if (sourceColumnIndex != null) { - return child.getOutput().get(sourceColumnIndex).toSlot(); - } - return slot; + List pushDownProjections + = Lists.newArrayList(pushdownProjectHelper.childToPushDownProjects.values()); + List newChildren = new ArrayList<>(); + List> newChildrenOutputs = new ArrayList<>(); + for (int i = 0; i < union.arity(); i++) { + List regulatorOutput = union.getRegularChildOutput(i); + List pushedOutput = replaceSlot( + statementContext, + pushDownProjections, + slot -> { + Integer sourceColumnIndex = slotToColumnIndex.get(slot); + if (sourceColumnIndex != null) { + return regulatorOutput.get(sourceColumnIndex).toSlot(); } - ); - - LogicalProject newChild = new LogicalProject<>( - ImmutableList.builder() - .addAll(child.getOutput()) - .addAll(pushedOutput) - .build(), - child - ); - - newChildrenOutputs.add((List) newChild.getOutput()); - newChildren.add(newChild); - } + return slot; + } + ); - for (List originConstantExprs : union.getConstantExprsList()) { - List pushedOutput = replaceSlot( - ctx.statementContext, - pushDownProjections, - slot -> { - Integer sourceColumnIndex = slotToColumnIndex.get(slot); - if (sourceColumnIndex != null) { - return originConstantExprs.get(sourceColumnIndex).toSlot(); - } - return slot; + LogicalProject newChild = new LogicalProject<>( + ImmutableList.builder() + .addAll(regulatorOutput) + .addAll(pushedOutput) + .build(), + union.child(i) + ); + + newChildrenOutputs.add((List) 
newChild.getOutput()); + newChildren.add(newChild); + } + + for (List originConstantExprs : union.getConstantExprsList()) { + List pushedOutput = replaceSlot( + statementContext, + pushDownProjections, + slot -> { + Integer sourceColumnIndex = slotToColumnIndex.get(slot); + if (sourceColumnIndex != null) { + return originConstantExprs.get(sourceColumnIndex).toSlot(); } - ); - - LogicalOneRowRelation originOneRowRelation = new LogicalOneRowRelation( - ctx.statementContext.getNextRelationId(), - originConstantExprs - ); - - LogicalProject newChild = new LogicalProject<>( - ImmutableList.builder() - .addAll(originOneRowRelation.getOutput()) - .addAll(pushedOutput) - .build(), - originOneRowRelation - ); - - newChildrenOutputs.add((List) newChild.getOutput()); - newChildren.add(newChild); - } + return slot; + } + ); - List newUnionOutputs = new ArrayList<>(union.getOutputs()); - for (NamedExpression projection : pushDownProjections) { - newUnionOutputs.add(projection.toSlot()); - } + LogicalOneRowRelation originOneRowRelation = new LogicalOneRowRelation( + statementContext.getNextRelationId(), + originConstantExprs + ); - return new LogicalProject<>( - pushProjects.second, - new LogicalUnion( - union.getQualifier(), - newUnionOutputs, - newChildrenOutputs, - ImmutableList.of(), - union.hasPushedFilter(), - newChildren - ) + LogicalProject newChild = new LogicalProject<>( + ImmutableList.builder() + .addAll(originOneRowRelation.getOutput()) + .addAll(pushedOutput) + .build(), + originOneRowRelation ); + + newChildrenOutputs.add((List) newChild.getOutput()); + newChildren.add(newChild); } - return project; + + List newUnionOutputs = new ArrayList<>(union.getOutputs()); + for (NamedExpression projection : pushDownProjections) { + newUnionOutputs.add(projection.toSlot()); + } + + return new LogicalProject<>( + pushProjects.second, + new LogicalUnion( + union.getQualifier(), + newUnionOutputs, + newChildrenOutputs, + ImmutableList.of(), + union.hasPushedFilter(), + 
newChildren + ) + ); } - private List replaceSlot( + private static List replaceSlot( StatementContext statementContext, Collection pushDownProjections, Function slotReplace) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownProjectTest.java new file mode 100644 index 00000000000000..47398e3ef9a458 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownProjectTest.java @@ -0,0 +1,192 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;

import com.google.common.collect.Lists;
import org.junit.jupiter.api.Test;

import java.util.List;

/**
 * Tests for {@code PushDownProject.pushThroughUnion(LogicalProject, StatementContext)}.
 *
 * <p>The fixtures deliberately make the union's regular child outputs differ from the
 * children's own outputs (reordered and duplicated columns), so the tests only pass when the
 * pushed expressions are remapped through the regular child outputs.
 *
 * <p>NOTE(review): generic type arguments were not recoverable from the excerpt this class
 * was taken from; they were reconstructed from usage and the import list -- confirm.
 */
public class PushDownProjectTest implements MemoPatternMatchSupported {

    // Own outputs of the two one-row relations used as union children.
    private final List<NamedExpression> rel1Output = Lists.newArrayList(slot(1, "c1"), slot(2, "c2"));
    private final List<NamedExpression> rel2Output
            = Lists.newArrayList(slot(3, "c3"), slot(4, "c4"), slot(5, "c5"));

    // Regular child outputs registered on the union: reordered / duplicated versus the above.
    private final List<SlotReference> regularRel1Output
            = Lists.newArrayList(slot(2, "c2"), slot(1, "c1"), slot(2, "c2"));
    private final List<SlotReference> regularRel2Output
            = Lists.newArrayList(slot(3, "c3"), slot(5, "c5"), slot(4, "c4"));

    private final List<NamedExpression> unionOutput
            = Lists.newArrayList(slot(10, "c10"), slot(11, "c11"), slot(12, "c12"));

    // Three element_at expressions over union columns 0, 0 and 2, plus a plain slot.
    private final List<NamedExpression> pushDownProjections = Lists.newArrayList(
            elementAtAlias(100, slot(10, "c10"), "a"),
            elementAtAlias(101, slot(10, "c10"), "b"),
            elementAtAlias(102, slot(12, "c12"), "a"),
            slot(11, "c11")
    );

    private final LogicalOneRowRelation rel1 = new LogicalOneRowRelation(new RelationId(1), rel1Output);
    private final LogicalOneRowRelation rel2 = new LogicalOneRowRelation(new RelationId(2), rel2Output);
    private final List<Plan> children = Lists.newArrayList(rel1, rel2);

    /** Union with two regular children and no constant rows. */
    @Test
    public void testPushDownProjectThroughUnionOnlyHasChildren() {
        List<List<SlotReference>> regularOutputs = Lists.newArrayList(regularRel1Output, regularRel2Output);
        LogicalUnion union = new LogicalUnion(Qualifier.ALL, unionOutput,
                regularOutputs, Lists.newArrayList(), true, children);
        LogicalProject<LogicalUnion> project = new LogicalProject<>(pushDownProjections, union);
        StatementContext context = new StatementContext();

        Plan rewritten = PushDownProject.pushThroughUnion(project, context);

        PlanChecker.from(MemoTestUtils.createConnectContext(), rewritten)
                .matchesFromRoot(
                        logicalProject(
                                logicalUnion(
                                        logicalProject(
                                                logicalOneRowRelation()
                                                        .when(r -> r.getRelationId().asInt() == 1)
                                                        .when(r -> r.getOutputs().size() == 2)
                                        ).when(PushDownProjectTest::isRewrittenFirstChild),
                                        logicalProject(
                                                logicalOneRowRelation()
                                                        .when(r -> r.getRelationId().asInt() == 2)
                                                        .when(r -> r.getOutputs().size() == 3)
                                        ).when(PushDownProjectTest::isRewrittenSecondChild)
                                ).when(PushDownProjectTest::keepsUnionColumnsFirst)
                        ).when(p -> p.getProjects().stream().noneMatch(ne -> ne.containsType(ElementAt.class)))
                );
    }

    /** Union with no regular children; both inputs are supplied as constant rows. */
    @Test
    public void testPushDownProjectThroughUnionHasNoChildren() {
        // NOTE(review): the rows are copied into NamedExpression lists on the assumption that
        // LogicalUnion takes its constant rows as List<List<NamedExpression>> -- confirm.
        List<List<NamedExpression>> constantRows = Lists.newArrayList(
                Lists.<NamedExpression>newArrayList(regularRel1Output),
                Lists.<NamedExpression>newArrayList(regularRel2Output));
        LogicalUnion union = new LogicalUnion(Qualifier.ALL, unionOutput, Lists.newArrayList(),
                constantRows, true, Lists.newArrayList());
        LogicalProject<LogicalUnion> project = new LogicalProject<>(pushDownProjections, union);
        StatementContext context = new StatementContext();

        Plan rewritten = PushDownProject.pushThroughUnion(project, context);

        PlanChecker.from(MemoTestUtils.createConnectContext(), rewritten)
                .matchesFromRoot(
                        logicalProject(
                                logicalUnion(
                                        logicalProject(
                                                logicalOneRowRelation()
                                                        .when(r -> r.getOutputs().size() == 3)
                                        ).when(PushDownProjectTest::isRewrittenFirstChild),
                                        logicalProject(
                                                logicalOneRowRelation()
                                                        .when(r -> r.getOutputs().size() == 3)
                                        ).when(PushDownProjectTest::isRewrittenSecondChild)
                                ).when(PushDownProjectTest::keepsUnionColumnsFirst)
                        ).when(p -> p.getProjects().stream().noneMatch(ne -> ne.containsType(ElementAt.class)))
                );
    }

    /** Creates a nullable TINYINT slot with the given expression id, name and an empty qualifier. */
    private static SlotReference slot(int exprId, String name) {
        return new SlotReference(new ExprId(exprId), name, TinyIntType.INSTANCE, true, Lists.newArrayList());
    }

    /** Creates {@code Alias(aliasId, ElementAt(base, key))}. */
    private static Alias elementAtAlias(int aliasId, SlotReference base, String key) {
        return new Alias(new ExprId(aliasId), new ElementAt(base, new StringLiteral(key)));
    }

    /** Expected project over the first input: regular columns 2, 1, 2, then expressions on slot 2. */
    private static boolean isRewrittenFirstChild(LogicalProject<? extends Plan> project) {
        return project.getOutputs().size() == 6
                && startsWithExprIds(project, 2, 1, 2)
                && isElementAt(project, 3, 2, "a")
                && isElementAt(project, 4, 2, "b")
                && isElementAt(project, 5, 2, "a");
    }

    /** Expected project over the second input: regular columns 3, 5, 4, then expressions on slots 3, 3, 4. */
    private static boolean isRewrittenSecondChild(LogicalProject<? extends Plan> project) {
        return project.getOutputs().size() == 6
                && startsWithExprIds(project, 3, 5, 4)
                && isElementAt(project, 3, 3, "a")
                && isElementAt(project, 4, 3, "b")
                && isElementAt(project, 5, 4, "a");
    }

    /** The rewritten union has six outputs and still starts with the original slots 10, 11, 12. */
    private static boolean keepsUnionColumnsFirst(LogicalUnion union) {
        return union.getOutput().size() == 6
                && union.getOutput().get(0).getExprId().asInt() == 10
                && union.getOutput().get(1).getExprId().asInt() == 11
                && union.getOutput().get(2).getExprId().asInt() == 12;
    }

    /** True when the first projections of {@code project} carry exactly the given expression ids, in order. */
    private static boolean startsWithExprIds(LogicalProject<? extends Plan> project, int... exprIds) {
        for (int i = 0; i < exprIds.length; i++) {
            if (project.getProjects().get(i).getExprId().asInt() != exprIds[i]) {
                return false;
            }
        }
        return true;
    }

    /**
     * True when projection {@code index} wraps an {@code ElementAt} whose first argument is the
     * slot with id {@code slotExprId} and whose second argument is the string literal {@code key}.
     */
    private static boolean isElementAt(LogicalProject<? extends Plan> project, int index,
            int slotExprId, String key) {
        NamedExpression projection = project.getProjects().get(index);
        return projection.child(0) instanceof ElementAt
                && ((SlotReference) projection.child(0).child(0)).getExprId().asInt() == slotExprId
                && ((StringLiteral) projection.child(0).child(1)).getValue().equals(key);
    }
}