Skip to content

Commit

Permalink
[Enhancement](Optimizer) Nereids pattern matching base framework (apa…
Browse files Browse the repository at this point in the history
…che#9474)

This PR provides a new pattern matching framework for the Nereids optimizer.

The new pattern matching framework consists of the following concepts:

1. `Pattern`/`PatternDescriptor`: the multi-level shape of a tree node, e.g. the pattern `logicalJoin(logicalJoin(), any())` describes a plan whose root is a `LogicalJoin` and whose left child is also a `LogicalJoin`.
2. `MatchedAction`: a callback function invoked when the pattern matches; usually it creates a new plan to replace the originally matched plan.
3. `MatchingContext`: the parameter passed to the MatchedAction; it contains the matched plan root and the PlannerContext.
4. `PatternMatcher`: contains PatternDescriptor and MatchedAction
5. `Rule`: a rewrite rule that contains a RuleType, a RulePromise, a Pattern and a transform function (equivalent to the MatchedAction)
6. `RuleFactory`: a factory that helps us build Rules easily. RuleFactory extends the Patterns interface and provides some predefined pattern descriptors.

For example, join commutativity:
```java
public class JoinCommutative extends OneExplorationRuleFactory {
    @Override
    public Rule<Plan> build() {
        return innerLogicalJoin().thenApply(ctx -> {
            return new LogicalJoin(
                JoinType.INNER_JOIN,
                ctx.root.getOnClause(),
                ctx.root.right(),
                ctx.root.left()
            );
        }).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE);
    }
}
```

The code above shows the three steps to create a Rule:
1. 'innerLogicalJoin()' declares that the pattern is an inner logical join. 'innerLogicalJoin' is a predefined pattern.
2. Invoke the 'thenApply()' function to attach a MatchedAction, which returns a new LogicalJoin with its children exchanged.
3. Invoke the 'toRule()' function to convert the result to a Rule.

You can think of a Rule as containing three parts:
1. Pattern
2. transform function / MatchedAction
3. RuleType and RulePromise

So
1. `innerLogicalJoin()` creates a `PatternDescriptor`, which contains a `Pattern`
2. `PatternDescriptor.then()` converts the `PatternDescriptor` to a `PatternMatcher`, which contains the Pattern and the MatchedAction
3. `PatternMatcher.toRule()` converts the `PatternMatcher` to a Rule

These three steps are inspired by currying in functional programming.

Note that apache#9446 provides a generic type for TreeNode's children, so this pattern matching framework can infer multi-level types and you can get the real tree node type without unsafe casts, like this:
```java
logicalJoin(logicalJoin(), any()).then(j -> {
     // j is inferred to be of type LogicalJoin<LogicalJoin<Plan, Plan>, Plan>,
     // so j.left() is inferred to be of type LogicalJoin<Plan, Plan>,
     // and you don't need to cast j.left() from 'Plan' to 'LogicalJoin'
     var node = j.left().left();
})
```
  • Loading branch information
924060929 authored May 10, 2022
1 parent e61d296 commit 99b8e08
Show file tree
Hide file tree
Showing 56 changed files with 1,235 additions and 191 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
/**
* Expression for unbound alias.
*/
public class UnboundAlias<CHILD_TYPE extends Expression<CHILD_TYPE>>
public class UnboundAlias<CHILD_TYPE extends Expression>
extends UnaryExpression<UnboundAlias<CHILD_TYPE>, CHILD_TYPE>
implements NamedExpression<UnboundAlias<CHILD_TYPE>> {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.doris.nereids.memo.PlanReference;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleSet;
import org.apache.doris.nereids.trees.plans.Plan;

import java.util.List;

Expand All @@ -45,7 +46,7 @@ public RuleSet getRuleSet() {
return context.getOptimizerContext().getRuleSet();
}

public void prunedInvalidRules(PlanReference planReference, List<Rule> candidateRules) {
public void prunedInvalidRules(PlanReference planReference, List<Rule<Plan>> candidateRules) {

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
*/
public class ApplyRuleJob extends Job {
private final PlanReference planReference;
private final Rule rule;
private final Rule<Plan> rule;
private final boolean exploredOnly;

/**
Expand All @@ -44,7 +44,7 @@ public class ApplyRuleJob extends Job {
* @param rule rule to be applied
* @param context context of optimization
*/
public ApplyRuleJob(PlanReference planReference, Rule rule, PlannerContext context) {
public ApplyRuleJob(PlanReference planReference, Rule<Plan> rule, PlannerContext context) {
super(JobType.APPLY_RULE, context);
this.planReference = planReference;
this.rule = rule;
Expand All @@ -63,8 +63,8 @@ public void execute() throws AnalysisException {
if (!rule.check(plan, context)) {
continue;
}
List<Plan<?>> newPlanList = rule.transform(plan, context);
for (Plan<?> newPlan : newPlanList) {
List<Plan> newPlanList = rule.transform(plan, context);
for (Plan newPlan : newPlanList) {
PlanReference newReference = context.getOptimizerContext().getMemo()
.newPlanReference(newPlan, planReference.getParent());
// TODO need to check return is a new Reference, other wise will be into a dead loop
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.doris.nereids.memo.PlanReference;
import org.apache.doris.nereids.pattern.Pattern;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.plans.Plan;

import java.util.Comparator;
import java.util.List;
Expand All @@ -47,15 +48,15 @@ public ExplorePlanJob(PlanReference planReference, PlannerContext context) {

@Override
public void execute() {
List<Rule> explorationRules = getRuleSet().getExplorationRules();
List<Rule<Plan>> explorationRules = getRuleSet().getExplorationRules();
prunedInvalidRules(planReference, explorationRules);
explorationRules.sort(Comparator.comparingInt(o -> o.getRulePromise().promise()));

for (Rule rule : explorationRules) {
pushTask(new ApplyRuleJob(planReference, rule, context));
for (int i = 0; i < rule.getPattern().children().size(); ++i) {
Pattern childPattern = rule.getPattern().child(i);
if (!childPattern.children().isEmpty()) {
if (childPattern.arity() > 0) {
Group childSet = planReference.getChildren().get(i);
pushTask(new ExploreGroupJob(childSet, context));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
import org.apache.doris.nereids.memo.PlanReference;
import org.apache.doris.nereids.pattern.Pattern;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.plans.Plan;

import com.clearspring.analytics.util.Lists;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

Expand All @@ -43,9 +43,9 @@ public OptimizePlanJob(PlanReference planReference, PlannerContext context) {

@Override
public void execute() {
List<Rule> validRules = Lists.newArrayList();
List<Rule> explorationRules = getRuleSet().getExplorationRules();
List<Rule> implementationRules = getRuleSet().getImplementationRules();
List<Rule<Plan>> validRules = new ArrayList<>();
List<Rule<Plan>> explorationRules = getRuleSet().getExplorationRules();
List<Rule<Plan>> implementationRules = getRuleSet().getImplementationRules();
prunedInvalidRules(planReference, explorationRules);
prunedInvalidRules(planReference, implementationRules);
validRules.addAll(explorationRules);
Expand All @@ -59,7 +59,7 @@ public void execute() {
// child before applying the rule. (assumes task pool is effectively a stack)
for (int i = 0; i < rule.getPattern().children().size(); ++i) {
Pattern childPattern = rule.getPattern().child(i);
if (!childPattern.children().isEmpty()) {
if (childPattern.arity() > 0) {
Group childSet = planReference.getChildren().get(i);
pushTask(new ExploreGroupJob(childSet, context));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.pattern;

import org.apache.doris.nereids.trees.TreeNode;

/**
 * Callback invoked when a pattern matches; usually implemented as the body of a rule,
 * e.g. exchanging the join children for the JoinCommutative rule.
 *
 * <p>The interface declares a single abstract method, so it can be supplied as a lambda
 * or a method reference.
 *
 * @param <INPUT_TYPE> type of the matched tree node exposed through the {@link MatchingContext}
 * @param <OUTPUT_TYPE> type of the tree node produced by the action
 */
@FunctionalInterface
public interface MatchedAction<INPUT_TYPE extends TreeNode, OUTPUT_TYPE extends TreeNode> {
    /**
     * Runs the action for one match.
     *
     * @param ctx the matching context handed to the action
     * @return the tree node produced by the action
     */
    OUTPUT_TYPE apply(MatchingContext<INPUT_TYPE> ctx);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.pattern;

import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.trees.TreeNode;

/**
 * Context handed to a MatchedAction once a pattern has matched.
 */
public class MatchingContext<T extends TreeNode> {
    /** Root of the matched tree node. */
    public final T root;
    /** Pattern that was defined for the match. */
    public final Pattern<T> pattern;
    /** Planner context supplied at construction. */
    public final PlannerContext plannerContext;

    /**
     * Bundles everything that is passed through to the MatchedAction.
     *
     * @param root the matched tree node root
     * @param pattern the defined pattern
     * @param plannerContext the planner context
     */
    public MatchingContext(T root, Pattern<T> pattern, PlannerContext plannerContext) {
        this.plannerContext = plannerContext;
        this.pattern = pattern;
        this.root = root;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,23 @@

import org.apache.doris.nereids.trees.AbstractTreeNode;
import org.apache.doris.nereids.trees.NodeType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.TreeNode;

import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Objects;
import java.util.function.Predicate;


/**
* Pattern node used in pattern matching.
*/
public class Pattern extends AbstractTreeNode<Pattern> {
public static final Pattern PATTERN_MULTI_LEAF_INSTANCE = new Pattern(NodeType.PATTERN_MULTI_LEAF);
public static final Pattern PATTERN_LEAF_INSTANCE = new Pattern(NodeType.PATTERN_LEAF);
public class Pattern<T extends TreeNode> extends AbstractTreeNode<Pattern<T>> {
public static final Pattern ANY = new Pattern(NodeType.ANY);
public static final Pattern MULTI = new Pattern(NodeType.MULTI);

public final List<Predicate<T>> predicates;
private final NodeType nodeType;

/**
Expand All @@ -41,6 +47,20 @@ public class Pattern extends AbstractTreeNode<Pattern> {
public Pattern(NodeType nodeType, Pattern... children) {
super(NodeType.PATTERN, children);
this.nodeType = nodeType;
this.predicates = ImmutableList.of();
}

/**
* Constructor for Pattern.
*
* @param nodeType node type to matching
* @param predicates custom matching predicate
* @param children sub pattern
*/
public Pattern(NodeType nodeType, List<Predicate<T>> predicates, Pattern... children) {
super(NodeType.PATTERN, children);
this.nodeType = nodeType;
this.predicates = ImmutableList.copyOf(predicates);
}

/**
Expand All @@ -55,23 +75,49 @@ public NodeType getNodeType() {
/**
* Return ture if current Pattern match Plan in params.
*
* @param plan wait to match
* @param root wait to match
* @return ture if current Pattern match Plan in params
*/
public boolean matchRoot(Plan<?> plan) {
if (plan == null) {
public boolean matchRoot(T root) {
if (root == null) {
return false;
}

if (plan.children().size() < this.children().size() && children.contains(PATTERN_MULTI_LEAF_INSTANCE)) {
if (root.children().size() < this.children().size() && !children.contains(MULTI)) {
return false;
}

if (nodeType == NodeType.PATTERN_MULTI_LEAF || nodeType == NodeType.PATTERN_LEAF) {
if (nodeType == NodeType.MULTI || nodeType == NodeType.ANY) {
return true;
}

return getNodeType().equals(plan.getType());
return getNodeType().equals(root.getType())
&& predicates.stream().allMatch(predicate -> predicate.test(root));
}

/**
* Return true if the children patterns match the children of the tree node in params.
*
* @param root the tree node whose children are to be matched
* @return true if children Patterns match root's children in params
*/
public boolean matchChildren(T root) {
for (int i = 0; i < arity(); i++) {
if (!child(i).match(root.child(i))) {
return false;
}
}
return true;
}

/**
* Return true if the current pattern and its children patterns match the tree node in params.
*
* @param root the tree node to match
* @return true if current pattern and children patterns match root in params
*/
public boolean match(T root) {
return matchRoot(root) && matchChildren(root);
}

@Override
Expand All @@ -90,4 +136,14 @@ public boolean equals(Object o) {
public int hashCode() {
return Objects.hash(nodeType);
}

@Override
public List<Pattern> children() {
return (List) children;
}

@Override
public Pattern child(int index) {
return (Pattern) children.get(index);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.pattern;

import org.apache.doris.nereids.rules.RulePromise;
import org.apache.doris.nereids.trees.TreeNode;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
import java.util.function.Predicate;

/**
 * Descriptor that wraps a pattern tree to define a pattern shape.
 * It carries the pattern's generic type through to the MatchedAction.
 *
 * @param <INPUT_TYPE> type of the tree node matched by the pattern
 * @param <RULE_TYPE> upper bound shared by the matched and the produced tree node
 */
public class PatternDescriptor<INPUT_TYPE extends RULE_TYPE, RULE_TYPE extends TreeNode> {
    public final Pattern<INPUT_TYPE> pattern;
    public final RulePromise defaultPromise;
    /** Predicates registered through {@link #when(Predicate)}, in registration order. */
    public final List<Predicate<INPUT_TYPE>> predicates = new ArrayList<>();

    /**
     * Creates a descriptor for the given pattern.
     *
     * @param pattern the pattern tree to wrap, must not be null
     * @param defaultPromise the promise handed to every PatternMatcher built from this descriptor,
     *                       must not be null
     */
    public PatternDescriptor(Pattern<INPUT_TYPE> pattern, RulePromise defaultPromise) {
        this.pattern = Objects.requireNonNull(pattern, "pattern can not be null");
        this.defaultPromise = Objects.requireNonNull(defaultPromise, "defaultPromise can not be null");
    }

    /**
     * Registers an additional predicate that the matched node has to satisfy.
     *
     * @param predicate custom matching predicate, must not be null
     * @return this descriptor, so calls can be chained
     */
    public PatternDescriptor<INPUT_TYPE, RULE_TYPE> when(Predicate<INPUT_TYPE> predicate) {
        // Fail at the call site: a null element would otherwise only surface later, when the
        // predicate list is copied into the Pattern.
        predicates.add(Objects.requireNonNull(predicate, "predicate can not be null"));
        return this;
    }

    /**
     * Builds a PatternMatcher whose action only needs the matched root node.
     *
     * @param matchedAction function applied to the matched root
     * @return a PatternMatcher combining the pattern (with predicates), the default promise
     *         and the action
     */
    public <OUTPUT_TYPE extends RULE_TYPE> PatternMatcher<INPUT_TYPE, OUTPUT_TYPE, RULE_TYPE> then(
            Function<INPUT_TYPE, OUTPUT_TYPE> matchedAction) {
        return new PatternMatcher<>(patternWithPredicates(), defaultPromise, ctx -> matchedAction.apply(ctx.root));
    }

    /**
     * Builds a PatternMatcher whose action receives the whole MatchingContext.
     *
     * @param matchedAction action applied to the matching context
     * @return a PatternMatcher combining the pattern (with predicates), the default promise
     *         and the action
     */
    public <OUTPUT_TYPE extends RULE_TYPE> PatternMatcher<INPUT_TYPE, OUTPUT_TYPE, RULE_TYPE> thenApply(
            MatchedAction<INPUT_TYPE, OUTPUT_TYPE> matchedAction) {
        return new PatternMatcher<>(patternWithPredicates(), defaultPromise, matchedAction);
    }

    /**
     * Creates a copy of the wrapped pattern that carries the predicates of the wrapped pattern
     * followed by the predicates registered on this descriptor.
     *
     * @return a new Pattern with the same node type and children as the wrapped pattern
     */
    public Pattern<INPUT_TYPE> patternWithPredicates() {
        // Keep the predicates the wrapped pattern already carries; building the new pattern from
        // this descriptor's predicates alone would silently drop them.
        List<Predicate<INPUT_TYPE>> allPredicates = new ArrayList<>(pattern.predicates);
        allPredicates.addAll(predicates);
        Pattern[] children = pattern.children().toArray(new Pattern[0]);
        return new Pattern<>(pattern.getNodeType(), allPredicates, children);
    }
}
Loading

0 comments on commit 99b8e08

Please sign in to comment.