Skip to content

Commit

Permalink
[DNM](nerieds) avoid exploration unexpected number of group expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
morrySnow committed Jul 25, 2024
1 parent 501270d commit 99517db
Show file tree
Hide file tree
Showing 10 changed files with 65 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ private PhysicalPlan chooseBestPlan(Group rootGroup, PhysicalProperties physical
GroupExpression groupExpression = rootGroup.getLowestCostPlan(physicalProperties).orElseThrow(
() -> new AnalysisException("lowestCostPlans with physicalProperties("
+ physicalProperties + ") doesn't exist in root group")).second;
if (rootGroup.getEnforcers().contains(groupExpression)) {
if (rootGroup.getEnforcers().containsKey(groupExpression)) {
rootGroup.addChosenEnforcerId(groupExpression.getId().asInt());
rootGroup.addChosenEnforcerProperties(physicalProperties);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ public final void execute() throws AnalysisException {
GroupExpressionMatching groupExpressionMatching
= new GroupExpressionMatching(rule.getPattern(), groupExpression);
for (Plan plan : groupExpressionMatching) {
if (rule.isExploration()
&& context.getCascadesContext().getMemo().getGroupExpressionsSize() > context.getCascadesContext()
.getConnectContext().getSessionVariable().memoMaxGroupExpressionSize) {
break;
}
List<Plan> newPlans = rule.transform(plan, context.getCascadesContext());
for (Plan newPlan : newPlans) {
if (newPlan == plan) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public class Group {

private final List<GroupExpression> logicalExpressions = Lists.newArrayList();
private final List<GroupExpression> physicalExpressions = Lists.newArrayList();
private final List<GroupExpression> enforcers = Lists.newArrayList();
private final Map<GroupExpression, GroupExpression> enforcers = Maps.newHashMap();
private boolean isStatsReliable = true;
private LogicalProperties logicalProperties;

Expand Down Expand Up @@ -239,10 +239,10 @@ public GroupExpression getBestPlan(PhysicalProperties properties) {

public void addEnforcer(GroupExpression enforcer) {
enforcer.setOwnerGroup(this);
enforcers.add(enforcer);
enforcers.put(enforcer, enforcer);
}

public List<GroupExpression> getEnforcers() {
public Map<GroupExpression, GroupExpression> getEnforcers() {
return enforcers;
}

Expand Down Expand Up @@ -346,9 +346,9 @@ public void mergeTo(Group target) {
parentExpressions.keySet().forEach(parent -> target.addParentExpression(parent));

// move enforcers Ownership
enforcers.forEach(ge -> ge.children().set(0, target));
enforcers.forEach((k, v) -> k.children().set(0, target));
// TODO: dedup?
enforcers.forEach(enforcer -> target.addEnforcer(enforcer));
enforcers.forEach((k, v) -> target.addEnforcer(k));
enforcers.clear();

// move LogicalExpression PhysicalExpression Ownership
Expand Down Expand Up @@ -458,7 +458,7 @@ public String toString() {
str.append(" ").append(physicalExpression).append("\n");
}
str.append(" enforcers:\n");
for (GroupExpression enforcer : enforcers) {
for (GroupExpression enforcer : enforcers.keySet()) {
str.append(" ").append(enforcer).append("\n");
}
if (!chosenEnforcerIdList.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -552,8 +552,7 @@ public void mergeGroup(Group source, Group destination, HashMap<Long, Group> pla
return;
}
Group parentOwnerGroup = srcParent.getOwnerGroup();
HashSet<GroupExpression> enforcers = new HashSet<>(parentOwnerGroup.getEnforcers());
if (enforcers.contains(srcParent)) {
if (parentOwnerGroup.getEnforcers().containsKey(srcParent)) {
continue;
}
needReplaceChild.add(srcParent);
Expand Down Expand Up @@ -946,7 +945,7 @@ private List<GroupExpression> extractGroupExpressionSatisfyProp(Group group, Phy
List<GroupExpression> exprs = Lists.newArrayList(bestExpr);
Set<GroupExpression> hasVisited = new HashSet<>();
hasVisited.add(bestExpr);
Stream.concat(group.getPhysicalExpressions().stream(), group.getEnforcers().stream())
Stream.concat(group.getPhysicalExpressions().stream(), group.getEnforcers().keySet().stream())
.forEach(groupExpression -> {
if (!groupExpression.getInputPropertiesListOrEmpty(prop).isEmpty()
&& !groupExpression.equals(bestExpr) && !hasVisited.contains(groupExpression)) {
Expand All @@ -969,7 +968,7 @@ private List<List<PhysicalProperties>> extractInputProperties(GroupExpression gr
res.add(groupExpression.getInputPropertiesList(prop));

// return optimized input for enforcer
if (groupExpression.getOwnerGroup().getEnforcers().contains(groupExpression)) {
if (groupExpression.getOwnerGroup().getEnforcers().containsKey(groupExpression)) {
return res;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ public boolean isRewrite() {
return ruleType.getRuleTypeClass() == RuleTypeClass.REWRITE;
}

public boolean isExploration() {
return ruleType.getRuleTypeClass() == RuleTypeClass.EXPLORATION;
}

@Override
public String toString() {
return getRuleType().toString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@
import org.apache.doris.nereids.rules.exploration.join.LogicalJoinSemiJoinTransposeProject;
import org.apache.doris.nereids.rules.exploration.join.OuterJoinAssocProject;
import org.apache.doris.nereids.rules.exploration.join.OuterJoinLAsscomProject;
import org.apache.doris.nereids.rules.exploration.join.PushDownProjectThroughInnerOuterJoin;
import org.apache.doris.nereids.rules.exploration.join.PushDownProjectThroughSemiJoin;
import org.apache.doris.nereids.rules.exploration.join.SemiJoinSemiJoinTransposeProject;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewAggregateOnNoneAggregateRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewAggregateRule;
Expand Down Expand Up @@ -126,8 +124,8 @@ public class RuleSet {
.add(OuterJoinLAsscomProject.INSTANCE)
.add(SemiJoinSemiJoinTransposeProject.INSTANCE)
.add(LogicalJoinSemiJoinTransposeProject.INSTANCE)
.add(PushDownProjectThroughInnerOuterJoin.INSTANCE)
.add(PushDownProjectThroughSemiJoin.INSTANCE)
// .add(PushDownProjectThroughInnerOuterJoin.INSTANCE)
// .add(PushDownProjectThroughSemiJoin.INSTANCE)
.add(TransposeAggSemiJoinProject.INSTANCE)
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -384,29 +384,29 @@ public enum RuleType {
EAGER_SPLIT(RuleTypeClass.EXPLORATION),

EXPLORATION_SENTINEL(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_PROJECT_JOIN(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_FILTER_JOIN(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_PROJECT_FILTER_JOIN(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_FILTER_PROJECT_JOIN(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_ONLY_JOIN(RuleTypeClass.EXPLORATION),

MATERIALIZED_VIEW_PROJECT_AGGREGATE(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_FILTER_AGGREGATE(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_PROJECT_FILTER_AGGREGATE(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_FILTER_PROJECT_AGGREGATE(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_ONLY_AGGREGATE(RuleTypeClass.EXPLORATION),

MATERIALIZED_VIEW_PROJECT_AGGREGATE_ON_NONE_AGGREGATE(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_FILTER_AGGREGATE_ON_NONE_AGGREGATE(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_PROJECT_FILTER_AGGREGATE_ON_NONE_AGGREGATE(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_FILTER_PROJECT_AGGREGATE_ON_NONE_AGGREGATE(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_ONLY_AGGREGATE_ON_NONE_AGGREGATE(RuleTypeClass.EXPLORATION),

MATERIALIZED_VIEW_FILTER_SCAN(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_PROJECT_SCAN(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_FILTER_PROJECT_SCAN(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_PROJECT_FILTER_SCAN(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_ONLY_SCAN(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_PROJECT_JOIN(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_FILTER_JOIN(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_PROJECT_FILTER_JOIN(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_FILTER_PROJECT_JOIN(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_ONLY_JOIN(RuleTypeClass.MATERIALIZE_VIEW),

MATERIALIZED_VIEW_PROJECT_AGGREGATE(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_FILTER_AGGREGATE(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_PROJECT_FILTER_AGGREGATE(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_FILTER_PROJECT_AGGREGATE(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_ONLY_AGGREGATE(RuleTypeClass.MATERIALIZE_VIEW),

MATERIALIZED_VIEW_PROJECT_AGGREGATE_ON_NONE_AGGREGATE(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_FILTER_AGGREGATE_ON_NONE_AGGREGATE(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_PROJECT_FILTER_AGGREGATE_ON_NONE_AGGREGATE(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_FILTER_PROJECT_AGGREGATE_ON_NONE_AGGREGATE(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_ONLY_AGGREGATE_ON_NONE_AGGREGATE(RuleTypeClass.MATERIALIZE_VIEW),

MATERIALIZED_VIEW_FILTER_SCAN(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_PROJECT_SCAN(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_FILTER_PROJECT_SCAN(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_PROJECT_FILTER_SCAN(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_ONLY_SCAN(RuleTypeClass.MATERIALIZE_VIEW),

// implementation rules
LOGICAL_ONE_ROW_RELATION_TO_PHYSICAL_ONE_ROW_RELATION(RuleTypeClass.IMPLEMENTATION),
Expand Down Expand Up @@ -494,6 +494,7 @@ public <INPUT_TYPE extends Plan, OUTPUT_TYPE extends Plan> Rule build(
enum RuleTypeClass {
REWRITE,
EXPLORATION,
MATERIALIZE_VIEW,
// This type is used for unit test only.
CHECK,
IMPLEMENTATION,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
import org.apache.doris.nereids.util.Utils;

import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
Expand All @@ -46,6 +48,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
* Logical project plan.
Expand All @@ -54,6 +57,7 @@ public class LogicalProject<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_
implements Project, OutputPrunable {

private final List<NamedExpression> projects;
private final Supplier<Set<NamedExpression>> projectsSet;
private final List<NamedExpression> excepts;
private final boolean isDistinct;

Expand Down Expand Up @@ -83,6 +87,7 @@ private LogicalProject(List<NamedExpression> projects, List<NamedExpression> exc
this.projects = projects.isEmpty()
? ImmutableList.of(ExpressionUtils.selectMinimumColumn(child.get(0).getOutput()))
: projects;
this.projectsSet = Suppliers.memoize(() -> ImmutableSet.copyOf(projects));
this.excepts = Utils.fastToImmutableList(excepts);
this.isDistinct = isDistinct;
}
Expand Down Expand Up @@ -138,7 +143,7 @@ public boolean equals(Object o) {
return false;
}
LogicalProject<?> that = (LogicalProject<?>) o;
boolean equal = projects.equals(that.projects)
boolean equal = projectsSet.get().equals(that.projectsSet.get())
&& excepts.equals(that.excepts)
&& isDistinct == that.isDistinct;
// TODO: should add exprId for UnBoundStar and BoundStar for equality comparison
Expand All @@ -150,7 +155,7 @@ public boolean equals(Object o) {

@Override
public int hashCode() {
return Objects.hash(projects);
return Objects.hash(projectsSet.get());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,24 @@
import org.apache.doris.statistics.Statistics;

import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;

import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
* Physical project plan.
*/
public class PhysicalProject<CHILD_TYPE extends Plan> extends PhysicalUnary<CHILD_TYPE> implements Project {

private final List<NamedExpression> projects;
private final Supplier<Set<NamedExpression>> projectsSet;
//multiLayerProjects is used to extract common expressions
// projects: (A+B) * 2, (A+B) * 3
// multiLayerProjects:
Expand All @@ -62,6 +67,7 @@ public PhysicalProject(List<NamedExpression> projects, Optional<GroupExpression>
LogicalProperties logicalProperties, CHILD_TYPE child) {
super(PlanType.PHYSICAL_PROJECT, groupExpression, logicalProperties, child);
this.projects = ImmutableList.copyOf(Objects.requireNonNull(projects, "projects can not be null"));
this.projectsSet = Suppliers.memoize(() -> ImmutableSet.copyOf(projects));
}

public PhysicalProject(List<NamedExpression> projects, Optional<GroupExpression> groupExpression,
Expand All @@ -70,6 +76,7 @@ public PhysicalProject(List<NamedExpression> projects, Optional<GroupExpression>
super(PlanType.PHYSICAL_PROJECT, groupExpression, logicalProperties, physicalProperties, statistics,
child);
this.projects = ImmutableList.copyOf(Objects.requireNonNull(projects, "projects can not be null"));
this.projectsSet = Suppliers.memoize(() -> ImmutableSet.copyOf(projects));
}

public List<NamedExpression> getProjects() {
Expand All @@ -96,13 +103,13 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) {
return false;
}
PhysicalProject that = (PhysicalProject) o;
return projects.equals(that.projects);
PhysicalProject<?> that = (PhysicalProject<?>) o;
return projectsSet.get().equals(that.projectsSet.get());
}

@Override
public int hashCode() {
return Objects.hash(projects);
return Objects.hash(projectsSet.get());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {

@Override
public String toString() {
return Utils.toSqlString("PhysicalUnion" + getGroupIdWithPrefix(),
return Utils.toSqlString("PhysicalUnion" + "[" + id.asInt() + "]" + getGroupIdWithPrefix(),
"qualifier", qualifier,
"outputs", outputs,
"regularChildrenOutputs", regularChildrenOutputs,
Expand All @@ -98,7 +98,7 @@ public String toString() {

@Override
public PhysicalUnion withChildren(List<Plan> children) {
return new PhysicalUnion(qualifier, outputs, regularChildrenOutputs, constantExprsList,
return new PhysicalUnion(qualifier, outputs, regularChildrenOutputs, constantExprsList, groupExpression,
getLogicalProperties(), children);
}

Expand All @@ -119,7 +119,7 @@ public Plan withGroupExprLogicalPropChildren(Optional<GroupExpression> groupExpr
public PhysicalUnion withPhysicalPropertiesAndStats(
PhysicalProperties physicalProperties, Statistics statistics) {
return new PhysicalUnion(qualifier, outputs, regularChildrenOutputs, constantExprsList,
Optional.empty(), getLogicalProperties(), physicalProperties, statistics, children);
groupExpression, getLogicalProperties(), physicalProperties, statistics, children);
}

@Override
Expand Down

0 comments on commit 99517db

Please sign in to comment.