Skip to content

Commit

Permalink
[opt](nereids) infer in-predicate from or-predicate (apache#46468)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?
previous verion has follow drawbacks
1. inferred some in-predicates, which cannot be pushed down to storage
layer
2. it is easy to lead dead loop, because other expression rewrite rule
may remove its flag in expression's mutableState

in order to solve above issues, we implemented a new version
first, it is a plan node level rule to avoid to be applied to the same
expression repeatedly
second, we define replace mode and extract mode. if in replace mode, the
original expression should be equivalent to the inferred in-pred, which
is used for all plan node's expressions except filter.

for example, orig = "(a=1 and b=1) or (a=2 and c=2)" is equivalent to "a
in (1, 2) and (a=1 and b=1) or (a=2 and c=2)".
orig is a filter condition, "a in (1, 2)" can be pushed down to storage
layer, and this infer is useful. But if this is orig is an other join
condition, this inferrence is useless. So in extract mode, "a in (1, 2)"
is inferred, but in replace mode, it is not.
  • Loading branch information
englefly authored Jan 9, 2025
1 parent 7d3d36e commit af8a975
Show file tree
Hide file tree
Showing 25 changed files with 229 additions and 119 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
import org.apache.doris.nereids.rules.rewrite.FindHashConditionForJoin;
import org.apache.doris.nereids.rules.rewrite.InferAggNotNull;
import org.apache.doris.nereids.rules.rewrite.InferFilterNotNull;
import org.apache.doris.nereids.rules.rewrite.InferInPredicateFromOr;
import org.apache.doris.nereids.rules.rewrite.InferJoinNotNull;
import org.apache.doris.nereids.rules.rewrite.InferPredicates;
import org.apache.doris.nereids.rules.rewrite.InferSetOperatorDistinct;
Expand Down Expand Up @@ -323,6 +324,9 @@ public class Rewriter extends AbstractBatchJobExecutor {
// after EliminateEmptyRelation, project can be pushed into union
topDown(new PushProjectIntoUnion())
),
topic("infer In-predicate from Or-predicate",
topDown(new InferInPredicateFromOr())
),
// putting the "Column pruning and infer predicate" topic behind the "Set operation optimization"
// is because that pulling up predicates from union needs EliminateEmptyRelation in union child
topic("Column pruning and infer predicate",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ public enum RuleType {
TRANSPOSE_LOGICAL_SEMI_JOIN_AGG_PROJECT(RuleTypeClass.REWRITE),

// expression of plan rewrite
EXTRACT_IN_PREDICATE_FROM_OR(RuleTypeClass.REWRITE),
REWRITE_ONE_ROW_RELATION_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_PROJECT_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_AGG_EXPRESSION(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.apache.doris.nereids.rules.expression.rules.ExtractCommonFactorRule;
import org.apache.doris.nereids.rules.expression.rules.LikeToEqualRewrite;
import org.apache.doris.nereids.rules.expression.rules.NullSafeEqualToEqual;
import org.apache.doris.nereids.rules.expression.rules.OrToIn;
import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate;
import org.apache.doris.nereids.rules.expression.rules.SimplifyInPredicate;
import org.apache.doris.nereids.rules.expression.rules.SimplifyRange;
Expand All @@ -47,7 +46,6 @@ public class ExpressionOptimization extends ExpressionRewrite {
SimplifyComparisonPredicate.INSTANCE,
SimplifyInPredicate.INSTANCE,
SimplifyRange.INSTANCE,
OrToIn.INSTANCE,
DateFunctionRewrite.INSTANCE,
ArrayContainToArrayOverlap.INSTANCE,
CaseWhenToIf.INSTANCE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ public enum ExpressionRuleType {
MERGE_DATE_TRUNC,
NORMALIZE_BINARY_PREDICATES,
NULL_SAFE_EQUAL_TO_EQUAL,
OR_TO_IN,
REPLACE_VARIABLE_BY_LITERAL,
SIMPLIFY_ARITHMETIC_COMPARISON,
SIMPLIFY_ARITHMETIC,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,8 @@

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.expression.ExpressionBottomUpRewriter;
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRewrite;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
Expand All @@ -33,7 +30,6 @@
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.MutableState;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
Expand All @@ -45,48 +41,94 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
* dependends on SimplifyRange rule
* Do NOT use this rule in ExpressionOptimization
* apply this rule on filter expressions in extract mode,
* on other expressions in replace mode
*
*/
public class OrToIn implements ExpressionPatternRuleFactory {
public class OrToIn {
/**
* case 1: from (a=1 and b=1) or (a=2), "a in (1, 2)" is inferred,
* inferred expr is not equivalent to the original expr
* - replaceMode: output origin expr
* - extractMode: output a in (1, 2) and (a=1 and b=1) or (a=2)
*
* case 2: from (a=1) or (a=2), "a in (1,2)" is inferred, the inferred expr is equivalent to the original expr
* - replaceMode/extractMode: output a in (1, 2)
*
* extractMode only used for filter, the inferred In-predicate could be pushed down.
*/
public enum Mode {
replaceMode,
extractMode
}

public static final OrToIn INSTANCE = new OrToIn();
public static final OrToIn EXTRACT_MODE_INSTANCE = new OrToIn(Mode.extractMode);
public static final OrToIn REPLACE_MODE_INSTANCE = new OrToIn(Mode.replaceMode);

public static final int REWRITE_OR_TO_IN_PREDICATE_THRESHOLD = 2;

@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
matchesTopType(Or.class).then(OrToIn.INSTANCE::rewrite)
.toRule(ExpressionRuleType.OR_TO_IN)
);
private final Mode mode;

public OrToIn(Mode mode) {
this.mode = mode;
}

/**
* simplify and then rewrite
*/
public Expression rewriteTree(Expression expr, ExpressionRewriteContext context) {
if (expr instanceof CompoundPredicate) {
expr = SimplifyRange.rewrite((CompoundPredicate) expr, context);
ExpressionBottomUpRewriter simplify = ExpressionRewrite.bottomUp(SimplifyRange.INSTANCE);
expr = simplify.rewrite(expr, context);
return rewriteTree(expr);

}

/**
* rewrite tree
*/
public Expression rewriteTree(Expression expr) {
List<Expression> children = expr.children();
if (children.isEmpty()) {
return expr;
}
List<Expression> newChildren = children.stream()
.map(this::rewriteTree).collect(Collectors.toList());
if (expr instanceof And) {
// filter out duplicated conjunct
// example: OrToInTest.testDeDup()
Set<Expression> dedupSet = new LinkedHashSet<>();
for (Expression newChild : newChildren) {
dedupSet.addAll(ExpressionUtils.extractConjunction(newChild));
}
newChildren = Lists.newArrayList(dedupSet);
}
ExpressionBottomUpRewriter bottomUpRewriter = ExpressionRewrite.bottomUp(this);
return bottomUpRewriter.rewrite(expr, context);
if (expr instanceof CompoundPredicate && newChildren.size() == 1) {
// (a=1) and (a=1)
// after rewrite, newChildren=[(a=1)]
expr = newChildren.get(0);
} else {
expr = expr.withChildren(newChildren);
}
if (expr instanceof Or) {
expr = rewrite((Or) expr);
}
return expr;
}

private Expression rewrite(Or or) {
if (or.getMutableState(MutableState.KEY_OR_TO_IN).isPresent()) {
return or;
}
Pair<Expression, Expression> pair = extractCommonConjunct(or);
Expression result = tryToRewriteIn(pair.second);
if (pair.first != null) {
result = new And(pair.first, result);
}
result.setMutableState(MutableState.KEY_OR_TO_IN, 1);
return result;
}

private Expression tryToRewriteIn(Expression or) {
or.setMutableState(MutableState.KEY_OR_TO_IN, 1);
List<Expression> disjuncts = ExpressionUtils.extractDisjunction(or);
for (Expression disjunct : disjuncts) {
if (!hasInOrEqualChildren(disjunct)) {
Expand Down Expand Up @@ -114,7 +156,11 @@ private Expression tryToRewriteIn(Expression or) {
Expression conjunct = candidatesToFinalResult(candidates);
boolean keep = keepOriginalOrExpression(disjuncts);
if (keep) {
return new And(conjunct, or);
if (mode == Mode.extractMode) {
return new And(conjunct, or);
} else {
return or;
}
} else {
return conjunct;
}
Expand All @@ -132,15 +178,6 @@ private boolean keepOriginalOrExpression(List<Expression> disjuncts) {
return false;
}

private boolean containsAny(Set a, Set b) {
for (Object x : a) {
if (b.contains(x)) {
return true;
}
}
return false;
}

private Map<Expression, Set<Literal>> mergeCandidates(
Map<Expression, Set<Literal>> a,
Map<Expression, Set<Literal>> b) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ public static <K extends Comparable<K>> List<K> prune(List<Slot> partitionSlots,
"partitionPruningExpandThreshold",
10, sessionVariable -> sessionVariable.partitionPruningExpandThreshold);

partitionPredicate = OrToIn.INSTANCE.rewriteTree(
partitionPredicate = OrToIn.EXTRACT_MODE_INSTANCE.rewriteTree(
partitionPredicate, new ExpressionRewriteContext(cascadesContext));
if (BooleanLiteral.TRUE.equals(partitionPredicate)) {
return Utils.fastToImmutableList(idToPartitions.keySet());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.expression.rules.OrToIn;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

/**
* infer In-predicate from Or
*
*/
public class InferInPredicateFromOr implements RewriteRuleFactory {

@Override
public List<Rule> buildRules() {
return ImmutableList.of(
logicalFilter().then(this::rewriteFilterExpression).toRule(RuleType.EXTRACT_IN_PREDICATE_FROM_OR),
logicalProject().then(this::rewriteProject).toRule(RuleType.EXTRACT_IN_PREDICATE_FROM_OR),
logicalJoin().whenNot(LogicalJoin::isMarkJoin)
.then(this::rewriteJoin).toRule(RuleType.EXTRACT_IN_PREDICATE_FROM_OR)
);
}

private LogicalFilter<Plan> rewriteFilterExpression(LogicalFilter<Plan> filter) {
Expression rewritten = OrToIn.EXTRACT_MODE_INSTANCE.rewriteTree(filter.getPredicate());
Set<Expression> set = new LinkedHashSet<>(ExpressionUtils.extractConjunction(rewritten));
return filter.withConjuncts(set);
}

private LogicalProject<Plan> rewriteProject(LogicalProject<Plan> project) {
List<NamedExpression> newProjections = Lists.newArrayList();
for (NamedExpression proj : project.getProjects()) {
if (proj instanceof SlotReference) {
newProjections.add(proj);
} else {
Expression rewritten = OrToIn.REPLACE_MODE_INSTANCE.rewriteTree(proj);
newProjections.add((NamedExpression) rewritten);
}
}
return project.withProjects(newProjections);
}

private LogicalJoin<Plan, Plan> rewriteJoin(LogicalJoin<Plan, Plan> join) {
if (!join.isMarkJoin()) {
Expression otherCondition;
if (join.getOtherJoinConjuncts().isEmpty()) {
return join;
} else if (join.getOtherJoinConjuncts().size() == 1) {
otherCondition = join.getOtherJoinConjuncts().get(0);
} else {
otherCondition = new And(join.getOtherJoinConjuncts());
}
Expression rewritten = OrToIn.REPLACE_MODE_INSTANCE.rewriteTree(otherCondition);
join = join.withJoinConjuncts(join.getHashJoinConjuncts(), ExpressionUtils.extractConjunction(rewritten),
join.getJoinReorderContext());
}
return join;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ public interface MutableState {
String KEY_RF_JUMP = "rf-jump";
String KEY_PUSH_TOPN_TO_AGG = "pushTopnToAgg";

String KEY_OR_TO_IN = "or_to_in";

<T> Optional<T> get(String key);

MutableState set(String key, Object value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.apache.doris.nereids.rules.expression.rules.InPredicateDedup;
import org.apache.doris.nereids.rules.expression.rules.InPredicateToEqualToRule;
import org.apache.doris.nereids.rules.expression.rules.NormalizeBinaryPredicatesRule;
import org.apache.doris.nereids.rules.expression.rules.OrToIn;
import org.apache.doris.nereids.rules.expression.rules.SimplifyCastRule;
import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate;
import org.apache.doris.nereids.rules.expression.rules.SimplifyNotExprRule;
Expand Down Expand Up @@ -169,20 +168,6 @@ void testExtractCommonFactorRewrite() {

}

@Test
void testTpcdsCase() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(
SimplifyRange.INSTANCE,
OrToIn.INSTANCE,
ExtractCommonFactorRule.INSTANCE
)
));
assertRewrite(
"(((((customer_address.ca_country = 'United States') AND ca_state IN ('DE', 'FL', 'TX')) OR ((customer_address.ca_country = 'United States') AND ca_state IN ('ID', 'IN', 'ND'))) OR ((customer_address.ca_country = 'United States') AND ca_state IN ('IL', 'MT', 'OH'))))",
"((customer_address.ca_country = 'United States') AND ca_state IN ('DE', 'FL', 'TX', 'ID', 'IN', 'ND', 'IL', 'MT', 'OH'))");
}

@Test
void testInPredicateToEqualToRule() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
Expand Down
Loading

0 comments on commit af8a975

Please sign in to comment.