Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.rules.rewrite.AccessPathExpressionCollector.CollectAccessPathResult;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
Expand All @@ -28,6 +30,7 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalFileScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalTVFRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
Expand All @@ -53,6 +56,28 @@ public Map<Slot, List<CollectAccessPathResult>> collect(Plan root, StatementCont
return scanSlotToAccessPaths;
}

@Override
public Void visitLogicalProject(LogicalProject<? extends Plan> project, StatementContext context) {
AccessPathExpressionCollector exprCollector
= new AccessPathExpressionCollector(context, allSlotToAccessPaths, false);
for (NamedExpression output : project.getProjects()) {
// e.g. select struct_element(s, 'city') from (select s from tbl)a;
// we will not treat the inner `s` access all path
if (output instanceof Slot && allSlotToAccessPaths.containsKey(output.getExprId().asInt())) {
continue;
} else if (output instanceof Alias && output.child(0) instanceof Slot
&& allSlotToAccessPaths.containsKey(output.getExprId().asInt())) {
Slot innerSlot = (Slot) output.child(0);
Collection<CollectAccessPathResult> outerSlotAccessPaths = allSlotToAccessPaths.get(
output.getExprId().asInt());
allSlotToAccessPaths.putAll(innerSlot.getExprId().asInt(), outerSlotAccessPaths);
} else {
exprCollector.collect(output);
}
}
return project.child().accept(this, context);
}

@Override
public Void visitLogicalFilter(LogicalFilter<? extends Plan> filter, StatementContext context) {
boolean bottomFilter = filter.child().arity() == 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;

/** push down project if the expression instance of PreferPushDownProject */
Expand Down Expand Up @@ -320,13 +319,13 @@ private List<NamedExpression> replaceSlot(
private static class PushdownProjectHelper {
private final Plan plan;
private final StatementContext statementContext;
private final Map<Expression, Pair<Slot, Plan>> exprToChildAndSlot;
private final Map<Expression, Expression> oldExprToNewExpr;
private final Multimap<Plan, NamedExpression> childToPushDownProjects;

public PushdownProjectHelper(StatementContext statementContext, Plan plan) {
this.statementContext = statementContext;
this.plan = plan;
this.exprToChildAndSlot = new LinkedHashMap<>();
this.oldExprToNewExpr = new LinkedHashMap<>();
this.childToPushDownProjects = ArrayListMultimap.create();
}

Expand Down Expand Up @@ -357,32 +356,36 @@ public <C extends Collection<E>, E extends Expression> Pair<Boolean, C> pushDown
}

public <E extends Expression> Optional<E> pushDownExpression(E expression) {
if (!(expression instanceof PreferPushDownProject
|| (expression instanceof Alias && expression.child(0) instanceof PreferPushDownProject))) {
if (!expression.containsType(PreferPushDownProject.class)) {
return Optional.empty();
}
Pair<Slot, Plan> existPushdown = exprToChildAndSlot.get(expression);
Expression existPushdown = oldExprToNewExpr.get(expression);
if (existPushdown != null) {
return Optional.of((E) existPushdown.first);
return Optional.of((E) existPushdown);
}

Alias pushDownAlias = null;
if (expression instanceof Alias) {
pushDownAlias = (Alias) expression;
} else {
pushDownAlias = new Alias(statementContext.getNextExprId(), expression);
}

Set<Slot> inputSlots = expression.getInputSlots();
for (Plan child : plan.children()) {
if (child.getOutputSet().containsAll(inputSlots)) {
Slot remaimSlot = pushDownAlias.toSlot();
exprToChildAndSlot.put(expression, Pair.of(remaimSlot, child));
childToPushDownProjects.put(child, pushDownAlias);
return Optional.of((E) remaimSlot);
Expression newExpression = expression.rewriteDownShortCircuit(e -> {
if (e instanceof PreferPushDownProject) {
List<Plan> children = plan.children();
for (int i = 0; i < children.size(); i++) {
Plan child = children.get(i);
if (child.getOutputSet().containsAll(e.getInputSlots())) {
Alias alias = new Alias(statementContext.getNextExprId(), e);
Slot slot = alias.toSlot();
childToPushDownProjects.put(child, alias);
return slot;
}
}
}
return e;
});

if (newExpression != expression) {
oldExprToNewExpr.put(expression, newExpression);
return Optional.of((E) newExpression);
} else {
return Optional.empty();
}
return Optional.empty();
}

public List<Plan> buildNewChildren() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Coalesce;
import org.apache.doris.nereids.trees.expressions.functions.scalar.StructElement;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEConsumer;
Expand Down Expand Up @@ -67,6 +68,7 @@ public void createTable() throws Exception {

createTable("create table tbl(\n"
+ " id int,\n"
+ " value int,\n"
+ " s struct<\n"
+ " city: string,\n"
+ " data: array<map<\n"
Expand All @@ -78,6 +80,7 @@ public void createTable() throws Exception {

createTable("create table tbl2(\n"
+ " id2 int,\n"
+ " value int,\n"
+ " s2 struct<\n"
+ " city2: string,\n"
+ " data2: array<map<\n"
Expand Down Expand Up @@ -371,13 +374,13 @@ public void testCte() throws Throwable {

@Test
public void testUnion() throws Throwable {
assertColumn("select struct_element(s, 'city') from (select s from tbl union all select null)a",
assertColumn("select coalesce(struct_element(s, 'city'), 'abc') from (select s from tbl union all select null)a",
"struct<city:text>",
ImmutableList.of(path("s", "city")),
ImmutableList.of()
);

assertColumn("select * from (select struct_element(s, 'city') from tbl union all select null)a",
assertColumn("select * from (select coalesce(struct_element(s, 'city'), 'abc') from tbl union all select null)a",
"struct<city:text>",
ImmutableList.of(path("s", "city")),
ImmutableList.of()
Expand All @@ -402,7 +405,7 @@ public void testCteAndUnion() throws Throwable {
@Test
public void testPushDownThroughJoin() {
PlanChecker.from(connectContext)
.analyze("select struct_element(s, 'city') from (select * from tbl)a join (select 100 id, 'f1' name)b on a.id=b.id")
.analyze("select coalesce(struct_element(s, 'city'), 'abc') from (select * from tbl)a join (select 100 id, 'f1' name)b on a.id=b.id")
.rewrite()
.matches(
logicalResultSink(
Expand All @@ -421,7 +424,9 @@ public void testPushDownThroughJoin() {
logicalOneRowRelation()
)
).when(p -> {
Assertions.assertTrue(p.getProjects().size() == 1 && p.getProjects().get(0) instanceof SlotReference);
Assertions.assertTrue(p.getProjects().size() == 1 && p.getProjects().get(0) instanceof Alias
&& p.getProjects().get(0).child(0) instanceof Coalesce
&& p.getProjects().get(0).child(0).child(0) instanceof Slot);
return true;
})
)
Expand Down Expand Up @@ -474,7 +479,9 @@ public void testPushDownThroughWindow() {
})
)
).when(p -> {
Assertions.assertTrue(p.getProjects().size() == 2 && p.getProjects().get(0) instanceof SlotReference);
Assertions.assertTrue(p.getProjects().size() == 2
&& (p.getProjects().get(0) instanceof SlotReference
|| (p.getProjects().get(0) instanceof Alias && p.getProjects().get(0).child(0) instanceof SlotReference)));
return true;
})
)
Expand Down Expand Up @@ -504,7 +511,9 @@ public void testPushDownThroughPartitionTopN() {
)
)
).when(p -> {
Assertions.assertTrue(p.getProjects().size() == 2 && p.getProjects().get(0) instanceof SlotReference);
Assertions.assertTrue(p.getProjects().size() == 2
&& (p.getProjects().get(0) instanceof SlotReference
|| p.getProjects().get(0) instanceof Alias && p.getProjects().get(0).child(0) instanceof SlotReference));
return true;
})
)
Expand Down
Loading