Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1837,15 +1837,38 @@ public PlanFragment visitPhysicalProject(PhysicalProject<? extends Plan> project
registerRewrittenSlot(project, (OlapScanNode) inputFragment.getPlanRoot());
}

List<Expr> projectionExprs = project.getProjects()
.stream()
.map(e -> ExpressionTranslator.translate(e, context))
.collect(Collectors.toList());
List<Slot> slots = project.getProjects()
.stream()
.map(NamedExpression::toSlot)
.collect(Collectors.toList());

PlanNode inputPlanNode = inputFragment.getPlanRoot();
List<Expr> projectionExprs = null;
List<Expr> allProjectionExprs = Lists.newArrayList();
List<Slot> slots = null;
if (project.hasMultiLayerProjection()) {
int layerCount = project.getMultiLayerProjects().size();
for (int i = 0; i < layerCount; i++) {
List<NamedExpression> layer = project.getMultiLayerProjects().get(i);
projectionExprs = layer.stream()
.map(e -> ExpressionTranslator.translate(e, context))
.collect(Collectors.toList());
slots = layer.stream()
.map(NamedExpression::toSlot)
.collect(Collectors.toList());
if (i < layerCount - 1) {
inputPlanNode.addIntermediateProjectList(projectionExprs);
TupleDescriptor projectionTuple = generateTupleDesc(slots, null, context);
inputPlanNode.addIntermediateOutputTupleDescList(projectionTuple);
}
allProjectionExprs.addAll(projectionExprs);
}
} else {
projectionExprs = project.getProjects()
.stream()
.map(e -> ExpressionTranslator.translate(e, context))
.collect(Collectors.toList());
slots = project.getProjects()
.stream()
.map(NamedExpression::toSlot)
.collect(Collectors.toList());
allProjectionExprs.addAll(projectionExprs);
}
// process multicast sink
if (inputFragment instanceof MultiCastPlanFragment) {
MultiCastDataSink multiCastDataSink = (MultiCastDataSink) inputFragment.getSink();
Expand All @@ -1857,10 +1880,9 @@ public PlanFragment visitPhysicalProject(PhysicalProject<? extends Plan> project
return inputFragment;
}

PlanNode inputPlanNode = inputFragment.getPlanRoot();
List<Expr> conjuncts = inputPlanNode.getConjuncts();
Set<SlotId> requiredSlotIdSet = Sets.newHashSet();
for (Expr expr : projectionExprs) {
for (Expr expr : allProjectionExprs) {
Expr.extractSlots(expr, requiredSlotIdSet);
}
Set<SlotId> requiredByProjectSlotIdSet = Sets.newHashSet(requiredSlotIdSet);
Expand Down Expand Up @@ -1895,8 +1917,10 @@ public PlanFragment visitPhysicalProject(PhysicalProject<? extends Plan> project
requiredSlotIdSet.forEach(e -> requiredExprIds.add(context.findExprId(e)));
for (ExprId exprId : requiredExprIds) {
SlotId slotId = ((HashJoinNode) joinNode).getHashOutputExprSlotIdMap().get(exprId);
Preconditions.checkState(slotId != null);
((HashJoinNode) joinNode).addSlotIdToHashOutputSlotIds(slotId);
// Preconditions.checkState(slotId != null);
if (slotId != null) {
((HashJoinNode) joinNode).addSlotIdToHashOutputSlotIds(slotId);
}
}
}
return inputFragment;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.processor.post;

import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;

import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

/**
* collect common expr
*/
public class CommonSubExpressionCollector extends ExpressionVisitor<Integer, Void> {
public final Map<Integer, Set<Expression>> commonExprByDepth = new HashMap<>();
private final Map<Integer, Set<Expression>> expressionsByDepth = new HashMap<>();

@Override
public Integer visit(Expression expr, Void context) {
if (expr.children().isEmpty()) {
return 0;
}
return collectCommonExpressionByDepth(expr.children().stream().map(child ->
child.accept(this, context)).reduce(Math::max).map(m -> m + 1).orElse(1), expr);
}

private int collectCommonExpressionByDepth(int depth, Expression expr) {
Set<Expression> expressions = getExpressionsFromDepthMap(depth, expressionsByDepth);
if (expressions.contains(expr)) {
Set<Expression> commonExpression = getExpressionsFromDepthMap(depth, commonExprByDepth);
commonExpression.add(expr);
}
expressions.add(expr);
return depth;
}

public static Set<Expression> getExpressionsFromDepthMap(
int depth, Map<Integer, Set<Expression>> depthMap) {
depthMap.putIfAbsent(depth, new LinkedHashSet<>());
return depthMap.get(depth);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.processor.post;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;

import com.google.common.collect.Lists;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
* Select A+B, (A+B+C)*2, (A+B+C)*3, D from T
*
* before optimize
* projection:
* Proj: A+B, (A+B+C)*2, (A+B+C)*3, D
*
* ---
* after optimize:
* Projection: List < List < Expression > >
* A+B, C, D
* A+B, A+B+C, D
* A+B, (A+B+C)*2, (A+B+C)*3, D
*/
public class CommonSubExpressionOpt extends PlanPostProcessor {
@Override
public PhysicalProject visitPhysicalProject(PhysicalProject<? extends Plan> project, CascadesContext ctx) {

List<List<NamedExpression>> multiLayers = computeMultiLayerProjections(
project.getInputSlots(), project.getProjects());
project.setMultiLayerProjects(multiLayers);
return project;
}

private List<List<NamedExpression>> computeMultiLayerProjections(
Set<Slot> inputSlots, List<NamedExpression> projects) {

List<List<NamedExpression>> multiLayers = Lists.newArrayList();
CommonSubExpressionCollector collector = new CommonSubExpressionCollector();
for (Expression expr : projects) {
expr.accept(collector, null);
}
Map<Expression, Alias> commonExprToAliasMap = new HashMap<>();
collector.commonExprByDepth.values().stream().flatMap(expressions -> expressions.stream())
.forEach(expression -> {
if (expression instanceof Alias) {
commonExprToAliasMap.put(expression, (Alias) expression);
} else {
commonExprToAliasMap.put(expression, new Alias(expression));
}
});
Map<Expression, Alias> aliasMap = new HashMap<>();
if (!collector.commonExprByDepth.isEmpty()) {
for (int i = 1; i <= collector.commonExprByDepth.size(); i++) {
List<NamedExpression> layer = Lists.newArrayList();
layer.addAll(inputSlots);
Set<Expression> exprsInDepth = CommonSubExpressionCollector
.getExpressionsFromDepthMap(i, collector.commonExprByDepth);
exprsInDepth.forEach(expr -> {
Expression rewritten = expr.accept(ExpressionReplacer.INSTANCE, aliasMap);
Alias alias = new Alias(rewritten);
aliasMap.put(expr, alias);
});
layer.addAll(aliasMap.values());
multiLayers.add(layer);
}
// final layer
List<NamedExpression> finalLayer = Lists.newArrayList();
projects.forEach(expr -> {
Expression rewritten = expr.accept(ExpressionReplacer.INSTANCE, aliasMap);
if (rewritten instanceof Slot) {
finalLayer.add((NamedExpression) rewritten);
} else if (rewritten instanceof Alias) {
finalLayer.add(new Alias(expr.getExprId(), ((Alias) rewritten).child(), expr.getName()));
}
});
multiLayers.add(finalLayer);
}
return multiLayers;
}

/**
* replace sub expr by aliasMap
*/
public static class ExpressionReplacer
extends DefaultExpressionRewriter<Map<? extends Expression, ? extends Alias>> {
public static final ExpressionReplacer INSTANCE = new ExpressionReplacer();

private ExpressionReplacer() {
}

@Override
public Expression visit(Expression expr, Map<? extends Expression, ? extends Alias> replaceMap) {
if (replaceMap.containsKey(expr)) {
return replaceMap.get(expr).toSlot();
}
return super.visit(expr, replaceMap);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ public List<PlanPostProcessor> getProcessors() {
builder.add(new MergeProjectPostProcessor());
builder.add(new RecomputeLogicalPropertiesProcessor());
builder.add(new AddOffsetIntoDistribute());
builder.add(new CommonSubExpressionOpt());
// DO NOT replace PLAN NODE from here
builder.add(new TopNScanOpt());
// after generate rf, DO NOT replace PLAN NODE
builder.add(new FragmentProcessor());
if (!cascadesContext.getConnectContext().getSessionVariable().getRuntimeFilterMode()
.toUpperCase().equals(TRuntimeFilterMode.OFF.name())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.doris.nereids.processor.post.RuntimeFilterGenerator;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
Expand All @@ -41,6 +42,7 @@

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

import java.util.List;
import java.util.Objects;
Expand All @@ -52,6 +54,12 @@
public class PhysicalProject<CHILD_TYPE extends Plan> extends PhysicalUnary<CHILD_TYPE> implements Project {

private final List<NamedExpression> projects;
//multiLayerProjects is used to extract common expressions
// projects: (A+B) * 2, (A+B) * 3
// multiLayerProjects:
// L1: A+B as x
// L2: x*2, x*3
private List<List<NamedExpression>> multiLayerProjects = Lists.newArrayList();

public PhysicalProject(List<NamedExpression> projects, LogicalProperties logicalProperties, CHILD_TYPE child) {
this(projects, Optional.empty(), logicalProperties, child);
Expand Down Expand Up @@ -227,7 +235,12 @@ public boolean pushDownRuntimeFilter(CascadesContext context, IdGenerator<Runtim

@Override
public List<Slot> computeOutput() {
return projects.stream()
List<NamedExpression> output = projects;
if (! multiLayerProjects.isEmpty()) {
int layers = multiLayerProjects.size();
output = multiLayerProjects.get(layers - 1);
}
return output.stream()
.map(NamedExpression::toSlot)
.collect(ImmutableList.toImmutableList());
}
Expand All @@ -237,4 +250,70 @@ public PhysicalProject<CHILD_TYPE> resetLogicalProperties() {
return new PhysicalProject<>(projects, groupExpression, null, physicalProperties,
statistics, child());
}

/**
* extract common expr, set multi layer projects
*/
public void computeMultiLayerProjectsForCommonExpress() {
// hard code: select (s_suppkey + s_nationkey), 1+(s_suppkey + s_nationkey), s_name from supplier;
if (projects.size() == 3) {
if (projects.get(2) instanceof SlotReference) {
SlotReference sName = (SlotReference) projects.get(2);
if (sName.getName().equals("s_name")) {
Alias a1 = (Alias) projects.get(0); // (s_suppkey + s_nationkey)
Alias a2 = (Alias) projects.get(1); // 1+(s_suppkey + s_nationkey)
// L1: (s_suppkey + s_nationkey) as x, s_name
multiLayerProjects.add(Lists.newArrayList(projects.get(0), projects.get(2)));
List<NamedExpression> l2 = Lists.newArrayList();
l2.add(a1.toSlot());
Alias a3 = new Alias(a2.getExprId(), new Add(a1.toSlot(), a2.child().child(1)), a2.getName());
l2.add(a3);
l2.add(sName);
// L2: x, (1+x) as y, s_name
multiLayerProjects.add(l2);
}
}
}
// hard code:
// select (s_suppkey + n_regionkey) + 1 as x, (s_suppkey + n_regionkey) + 2 as y
// from supplier join nation on s_nationkey=n_nationkey
// projects: x, y
// multi L1: s_suppkey, n_regionkey, (s_suppkey + n_regionkey) as z
// L2: z +1 as x, z+2 as y
if (projects.size() == 2 && projects.get(0) instanceof Alias && projects.get(1) instanceof Alias
&& ((Alias) projects.get(0)).getName().equals("x")
&& ((Alias) projects.get(1)).getName().equals("y")) {
Alias a0 = (Alias) projects.get(0);
Alias a1 = (Alias) projects.get(1);
Add common = (Add) a0.child().child(0); // s_suppkey + n_regionkey
List<NamedExpression> l1 = Lists.newArrayList();
common.children().stream().forEach(child -> l1.add((SlotReference) child));
Alias aliasOfCommon = new Alias(common);
l1.add(aliasOfCommon);
multiLayerProjects.add(l1);
Add add1 = new Add(common, a0.child().child(0).child(1));
Alias aliasOfAdd1 = new Alias(a0.getExprId(), add1, a0.getName());
Add add2 = new Add(common, a1.child().child(0).child(1));
Alias aliasOfAdd2 = new Alias(a1.getExprId(), add2, a1.getName());
List<NamedExpression> l2 = Lists.newArrayList(aliasOfAdd1, aliasOfAdd2);
multiLayerProjects.add(l2);
}
}

public boolean hasMultiLayerProjection() {
return !multiLayerProjects.isEmpty();
}

public List<List<NamedExpression>> getMultiLayerProjects() {
return multiLayerProjects;
}

public void setMultiLayerProjects(List<List<NamedExpression>> multiLayers) {
this.multiLayerProjects = multiLayers;
}

@Override
public List<Slot> getOutput() {
return computeOutput();
}
}
Loading