From f4a732e85cb1ba99de3f9b2270861677f65ff60f Mon Sep 17 00:00:00 2001 From: XuPengfei Date: Tue, 18 Jun 2024 14:40:21 +0800 Subject: [PATCH] [refactor](Nereids) New expression extractor for partitions pruning (#36326) An exception is thrown in TryEliminateUninterestedPredicates for this case: CREATE TABLE `tbltest` ( `id` INT NULL, `col2` VARCHAR(255) NULL, `col3` VARCHAR(255) NULL, `dt` DATE NULL ) ENGINE=OLAP DUPLICATE KEY(`id`, `col2`) PARTITION BY RANGE(`dt`) (PARTITION p20240617 VALUES [('2024-06-17'), ('2024-06-18'))) DISTRIBUTED BY HASH(`id`) BUCKETS 10 PROPERTIES ( "replication_allocation" = "tag.location.default: 1" ); select * from tbltest where case when col2 = 'xxx' and col3='yyy' then false -- note this is not about partition column when col2 in ('xxx') then false when col2 like 'xxx%' then false else true end CaseWhen requires its children to be WhenClause, but TryEliminateUninterestedPredicates may rewrite a WhenClause into a true/false predicate, which causes this exception: ERROR 1105 (HY000): errCode = 2, detailMessage = The children format needs to be [WhenClause+, DefaultValue?] The original extractor (TryEliminateUninterestedPredicates.java) caused some errors while trying to derive the expressions that can be used for pruning partitions. I tried to write a new extractor (with unit tests) for pruning partitions; it is simpler and more reliable (I think). The theory of the extractor is pretty simple: A: Sort the expressions into two kinds: 1. evaluable-expression (let's mark it as E). Expressions that can be evaluated in the partition pruning stage. In other words: expressions that contain neither non-partition slots nor non-deterministic expressions. 2. un-evaluable-expression (let's mark it as UE). Expressions that can NOT be evaluated in the partition pruning stage. In other words: expressions that contain non-partition slots or non-deterministic expressions. 
B: Traverse the predicate, handling only the AND and OR operators, according to the following rules: (E and UE) -> (E and TRUE) -> E (UE and UE) -> TRUE (E and E) -> (E and E) (E or UE) -> TRUE (UE or UE) -> TRUE (E or E) -> (E or E) --- .../nereids/analyzer/UnboundFunction.java | 18 +- .../PartitionPruneExpressionExtractor.java | 178 ++++++++++++ .../expression/rules/PartitionPruner.java | 2 +- .../TryEliminateUninterestedPredicates.java | 152 ---------- .../expressions/functions/BoundFunction.java | 19 +- .../trees/expressions/functions/Function.java | 13 +- ...PartitionPruneExpressionExtractorTest.java | 273 ++++++++++++++++++ 7 files changed, 475 insertions(+), 180 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionPruneExpressionExtractor.java delete mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TryEliminateUninterestedPredicates.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/PartitionPruneExpressionExtractorTest.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/UnboundFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/UnboundFunction.java index 4934d8ddc4e00e..4f9a69146b344a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/UnboundFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/UnboundFunction.java @@ -36,7 +36,6 @@ */ public class UnboundFunction extends Function implements Unbound, PropagateNullable { private final String dbName; - private final String name; private final boolean isDistinct; public UnboundFunction(String name, List arguments) { @@ -52,16 +51,11 @@ public UnboundFunction(String name, boolean isDistinct, List argumen } public UnboundFunction(String dbName, String name, boolean isDistinct, List arguments) { - super(arguments); + super(name, arguments); this.dbName = dbName; - this.name = Objects.requireNonNull(name, 
"name cannot be null"); this.isDistinct = isDistinct; } - public String getName() { - return name; - } - @Override public String getExpressionName() { if (!this.exprName.isPresent()) { @@ -87,13 +81,13 @@ public String toSql() throws UnboundException { String params = children.stream() .map(Expression::toSql) .collect(Collectors.joining(", ")); - return name + "(" + (isDistinct ? "distinct " : "") + params + ")"; + return getName() + "(" + (isDistinct ? "distinct " : "") + params + ")"; } @Override public String toString() { String params = Joiner.on(", ").join(children); - return "'" + name + "(" + (isDistinct ? "distinct " : "") + params + ")"; + return "'" + getName() + "(" + (isDistinct ? "distinct " : "") + params + ")"; } @Override @@ -103,7 +97,7 @@ public R accept(ExpressionVisitor visitor, C context) { @Override public UnboundFunction withChildren(List children) { - return new UnboundFunction(dbName, name, isDistinct, children); + return new UnboundFunction(dbName, getName(), isDistinct, children); } @Override @@ -118,11 +112,11 @@ public boolean equals(Object o) { return false; } UnboundFunction that = (UnboundFunction) o; - return isDistinct == that.isDistinct && name.equals(that.name); + return isDistinct == that.isDistinct && getName().equals(that.getName()); } @Override public int hashCode() { - return Objects.hash(name, isDistinct); + return Objects.hash(getName(), isDistinct); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionPruneExpressionExtractor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionPruneExpressionExtractor.java new file mode 100644 index 00000000000000..322016fd45c4a9 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionPruneExpressionExtractor.java @@ -0,0 +1,178 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.rules.PartitionPruneExpressionExtractor.Context; +import org.apache.doris.nereids.trees.expressions.And; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Or; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SubqueryExpr; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; + +import com.google.common.annotations.VisibleForTesting; + +import java.util.Objects; +import java.util.Set; + +/** + * PartitionPruneExpressionExtractor + * + * This rewriter only used to extract the expression that can be used in partition pruning from + * the whole predicate expression. + * The theory of extractor is pretty simple: + * A:Sort the expression in two kinds: + * 1. evaluable-expression (let's mark it as E). + * Expressions that can be evaluated in the partition pruning stage. 
+ * In the other word: not contains non-partition slots or deterministic expression. + * 2. un-evaluable-expression (let's mark it as UE). + * Expressions that can NOT be evaluated in the partition pruning stage. + * In the other word: contains non-partition slots or deterministic expression. + * + * B: Travel the predicate, only point on AND and OR operator, following the rule: + * (E and UE) -> (E and TRUE) -> E + * (UE and UE) -> TRUE + * (E and E) -> (E and E) + * (E or UE) -> TRUE + * (UE or UE) -> TRUE + * (E or E) -> (E or E) + * + * e.g. + * (part = 1 and non_part = 'a') or (part = 2) + * -> (part = 1 and true) or (part = 2) + * -> (part = 1) or (part = 2) + * + * It's better that do some expression optimize(like fold, eliminate etc.) on predicate before this step. + */ +public class PartitionPruneExpressionExtractor extends DefaultExpressionRewriter { + private final ExpressionEvaluableDetector expressionEvaluableDetector; + + private PartitionPruneExpressionExtractor(Set interestedSlots) { + this.expressionEvaluableDetector = new ExpressionEvaluableDetector(interestedSlots); + } + + /** + * Extract partition prune expression from predicate + */ + public static Expression extract(Expression predicate, + Set partitionSlots, + CascadesContext cascadesContext) { + predicate = predicate.accept(FoldConstantRuleOnFE.VISITOR_INSTANCE, + new ExpressionRewriteContext(cascadesContext)); + PartitionPruneExpressionExtractor rewriter = new PartitionPruneExpressionExtractor(partitionSlots); + Context context = new Context(); + Expression partitionPruneExpression = predicate.accept(rewriter, context); + if (context.containsUnEvaluableExpression) { + return BooleanLiteral.TRUE; + } + return partitionPruneExpression; + } + + @Override + public Expression visit(Expression originExpr, Context parentContext) { + if (originExpr instanceof And) { + return this.visitAnd((And) originExpr, parentContext); + } + if (originExpr instanceof Or) { + return this.visitOr((Or) originExpr, 
parentContext); + } + + parentContext.containsUnEvaluableExpression = !expressionEvaluableDetector.detect(originExpr); + return originExpr; + } + + @Override + public Expression visitAnd(And node, Context parentContext) { + // handle left node + Context leftContext = new Context(); + Expression newLeft = node.left().accept(this, leftContext); + // handle right node + Context rightContext = new Context(); + Expression newRight = node.right().accept(this, rightContext); + + // if anyone of them is FALSE, the whole expression should be FALSE. + if (newLeft == BooleanLiteral.FALSE || newRight == BooleanLiteral.FALSE) { + return BooleanLiteral.FALSE; + } + + // If left node contains non-partition slot or is TURE, just discard it. + if (newLeft == BooleanLiteral.TRUE || leftContext.containsUnEvaluableExpression) { + return rightContext.containsUnEvaluableExpression ? BooleanLiteral.TRUE : newRight; + } + + // If right node contains non-partition slot or is TURE, just discard it. + if (newRight == BooleanLiteral.TRUE || rightContext.containsUnEvaluableExpression) { + return newLeft; + } + + // both does not contains non-partition slot. + return new And(newLeft, newRight); + } + + @Override + public Expression visitOr(Or node, Context parentContext) { + // handle left node + Context leftContext = new Context(); + Expression newLeft = node.left().accept(this, leftContext); + // handle right node + Context rightContext = new Context(); + Expression newRight = node.right().accept(this, rightContext); + + // if anyone of them is TRUE or contains non-partition slot, just return TRUE. 
+ if (newLeft == BooleanLiteral.TRUE || newRight == BooleanLiteral.TRUE + || leftContext.containsUnEvaluableExpression || rightContext.containsUnEvaluableExpression) { + return BooleanLiteral.TRUE; + } + + return new Or(newLeft, newRight); + } + + /** + * Context + */ + @VisibleForTesting + public static class Context { + private boolean containsUnEvaluableExpression; + } + + /** + * The detector only indicate that whether a predicate contains interested slots or not, + * and do not change the predicate. + */ + @VisibleForTesting + public static class ExpressionEvaluableDetector extends DefaultExpressionRewriter { + private final Set partitionSlots; + + public ExpressionEvaluableDetector(Set partitionSlots) { + this.partitionSlots = Objects.requireNonNull(partitionSlots, "partitionSlots can not be null"); + } + + /** + * Return true if expression does NOT contains un-evaluable expression. + */ + @VisibleForTesting + public boolean detect(Expression expression) { + boolean containsUnEvaluableExpression = expression.anyMatch( + expr -> expr instanceof SubqueryExpr || (expr instanceof Slot && !partitionSlots.contains(expr))); + return !containsUnEvaluableExpression; + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionPruner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionPruner.java index 4a825d7956b839..b65c0d2ec55990 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionPruner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionPruner.java @@ -117,7 +117,7 @@ public List prune() { public static List prune(List partitionSlots, Expression partitionPredicate, Map idToPartitions, CascadesContext cascadesContext, PartitionTableType partitionTableType) { - partitionPredicate = TryEliminateUninterestedPredicates.rewrite( + partitionPredicate = PartitionPruneExpressionExtractor.extract( 
partitionPredicate, ImmutableSet.copyOf(partitionSlots), cascadesContext); partitionPredicate = PredicateRewriteForPartitionPrune.rewrite(partitionPredicate, cascadesContext); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TryEliminateUninterestedPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TryEliminateUninterestedPredicates.java deleted file mode 100644 index ce23219bcc93e2..00000000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TryEliminateUninterestedPredicates.java +++ /dev/null @@ -1,152 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package org.apache.doris.nereids.rules.expression.rules; - -import org.apache.doris.nereids.CascadesContext; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; -import org.apache.doris.nereids.rules.expression.rules.TryEliminateUninterestedPredicates.Context; -import org.apache.doris.nereids.trees.expressions.And; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.Not; -import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; -import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; - -import java.util.Set; - -/** - * TryEliminateUninterestedPredicates - * - * this rewriter usually used to extract the partition columns related predicates, - * and try to eliminate partition columns related predicate. - * - * e.g. - * (part = 1 and non_part = 'a') or (part = 2) - * -> (part = 1 and true) or (part = 2) - * -> (part = 1) or (part = 2) - * - * maybe eliminate failed in some special cases, e.g. (non_part + part) = 2. - * the key point is: if a predicate(return boolean type) only contains the uninterested slots, we can eliminate it. 
- */ -public class TryEliminateUninterestedPredicates extends DefaultExpressionRewriter { - private final Set interestedSlots; - private final ExpressionRewriteContext expressionRewriteContext; - - private TryEliminateUninterestedPredicates(Set interestedSlots, CascadesContext cascadesContext) { - this.interestedSlots = interestedSlots; - this.expressionRewriteContext = new ExpressionRewriteContext(cascadesContext); - } - - /** rewrite */ - public static Expression rewrite(Expression expression, Set interestedSlots, - CascadesContext cascadesContext) { - // before eliminate uninterested predicate, we must push down `Not` under CompoundPredicate - expression = expression.rewriteUp(expr -> { - if (expr instanceof Not) { - return SimplifyNotExprRule.simplify((Not) expr); - } else { - return expr; - } - }); - TryEliminateUninterestedPredicates rewriter = new TryEliminateUninterestedPredicates( - interestedSlots, cascadesContext); - return expression.accept(rewriter, new Context()); - } - - @Override - public Expression visit(Expression originExpr, Context parentContext) { - Context currentContext = new Context(); - // postorder traversal - Expression expr = super.visit(originExpr, currentContext); - - // process predicate - if (expr.getDataType().isBooleanType()) { - // if a predicate contains not only interested slots but also non-interested slots, - // we can not eliminate non-interested slots: - // e.g. - // not(uninterested slot b + interested slot a > 1) - // -> not(uninterested slot b + interested slot a > 1) - if (!currentContext.childrenContainsInterestedSlots && currentContext.childrenContainsNonInterestedSlots) { - // propagate true value up to eliminate uninterested slots, - // because we don't know the runtime value of the slots - // e.g. - // not(uninterested slot b > 1) - // -> not(true) - // -> true - expr = BooleanLiteral.TRUE; - } else { - // simplify the predicate expression, the interested slots may be eliminated too - // e.g. 
- // ((interested slot a) and not(uninterested slot b > 1)) or true - // -> ((interested slot a) and not(true)) or true - // -> ((interested slot a) and true) or true - // -> (interested slot a) or true - // -> true - expr = FoldConstantRuleOnFE.evaluate(expr, expressionRewriteContext); - } - } else { - // ((uninterested slot b > 0) + 1) > 1 - // -> (true + 1) > 1 - // -> ((uninterested slot b > 0) + 1) > 1 (recover to origin expr because `true + 1` is not predicate) - // -> true (not contains interested slot but contains uninterested slot) - expr = originExpr; - } - - parentContext.childrenContainsInterestedSlots |= currentContext.childrenContainsInterestedSlots; - parentContext.childrenContainsNonInterestedSlots |= currentContext.childrenContainsNonInterestedSlots; - - return expr; - } - - @Override - public Expression visitAnd(And and, Context parentContext) { - Expression left = and.left(); - Context leftContext = new Context(); - // Expression newLeft = this.visit(left, leftContext); - Expression newLeft = left.accept(this, leftContext); - - if (leftContext.childrenContainsNonInterestedSlots) { - newLeft = BooleanLiteral.TRUE; - } - - Expression right = and.right(); - Context rightContext = new Context(); - Expression newRight = this.visit(right, rightContext); - if (rightContext.childrenContainsNonInterestedSlots) { - newRight = BooleanLiteral.TRUE; - } - Expression expr = FoldConstantRuleOnFE.evaluate(new And(newLeft, newRight), expressionRewriteContext); - parentContext.childrenContainsInterestedSlots = - rightContext.childrenContainsInterestedSlots || leftContext.childrenContainsInterestedSlots; - return expr; - } - - @Override - public Expression visitSlot(Slot slot, Context context) { - boolean isInterestedSlot = interestedSlots.contains(slot); - context.childrenContainsInterestedSlots |= isInterestedSlot; - context.childrenContainsNonInterestedSlots |= !isInterestedSlot; - return slot; - } - - /** Context */ - public static class Context { - private 
boolean childrenContainsInterestedSlots; - private boolean childrenContainsNonInterestedSlots; - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java index 33b587ce74d8e9..06acb4c5782b2c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java @@ -33,7 +33,6 @@ /** BoundFunction. */ public abstract class BoundFunction extends Function implements ComputeSignature { - private final String name; private final Supplier signatureCache = Suppliers.memoize(() -> { // first step: find the candidate signature in the signature list @@ -43,17 +42,11 @@ public abstract class BoundFunction extends Function implements ComputeSignature }); public BoundFunction(String name, Expression... arguments) { - super(arguments); - this.name = Objects.requireNonNull(name, "name can not be null"); + super(name, arguments); } public BoundFunction(String name, List children) { - super(children); - this.name = Objects.requireNonNull(name, "name can not be null"); - } - - public String getName() { - return name; + super(name, children); } @Override @@ -75,17 +68,17 @@ public R accept(ExpressionVisitor visitor, C context) { @Override protected boolean extraEquals(Expression that) { - return Objects.equals(name, ((BoundFunction) that).name); + return Objects.equals(getName(), ((BoundFunction) that).getName()); } @Override public int hashCode() { - return Objects.hash(name, children); + return Objects.hash(getName(), children); } @Override public String toSql() throws UnboundException { - StringBuilder sql = new StringBuilder(name).append("("); + StringBuilder sql = new StringBuilder(getName()).append("("); int arity = arity(); for (int i = 0; i < arity; i++) { Expression arg = child(i); 
@@ -103,6 +96,6 @@ public String toString() { .stream() .map(Expression::toString) .collect(Collectors.joining(", ")); - return name + "(" + args + ")"; + return getName() + "(" + args + ")"; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Function.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Function.java index d1d23c192bb631..9e4c19365d837f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Function.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Function.java @@ -21,20 +21,29 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda; import java.util.List; +import java.util.Objects; /** * function in nereids. */ public abstract class Function extends Expression { - public Function(Expression... children) { + private final String name; + + public Function(String name, Expression... children) { super(children); + this.name = Objects.requireNonNull(name, "name can not be null"); } - public Function(List children) { + public Function(String name, List children) { super(children); + this.name = Objects.requireNonNull(name, "name can not be null"); } public boolean isHighOrder() { return !children.isEmpty() && children.get(0) instanceof Lambda; } + + public final String getName() { + return name; + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/PartitionPruneExpressionExtractorTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/PartitionPruneExpressionExtractorTest.java new file mode 100644 index 00000000000000..d9f49d88f307fd --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/PartitionPruneExpressionExtractorTest.java @@ -0,0 +1,273 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.analyzer.UnboundRelation; +import org.apache.doris.nereids.analyzer.UnboundSlot; +import org.apache.doris.nereids.parser.NereidsParser; +import org.apache.doris.nereids.rules.expression.rules.PartitionPruneExpressionExtractor; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.plans.RelationId; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.BooleanType; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.StringType; +import org.apache.doris.nereids.types.TinyIntType; +import org.apache.doris.nereids.util.MemoTestUtils; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import org.junit.jupiter.api.Assertions; 
+import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * This unit is used to check whether {@link PartitionPruneExpressionExtractor} is correct or not. + * Slot P1 ~ P5 are partition slots. + */ +public class PartitionPruneExpressionExtractorTest { + private static final NereidsParser PARSER = new NereidsParser(); + private final CascadesContext cascadesContext = MemoTestUtils.createCascadesContext( + new UnboundRelation(new RelationId(1), ImmutableList.of("tbl"))); + private final Map slotMemo = Maps.newHashMap(); + private final Set partitionSlots; + private final PartitionPruneExpressionExtractor.ExpressionEvaluableDetector evaluableDetector; + + public PartitionPruneExpressionExtractorTest() { + Map partitions = createPartitionSlots(); + slotMemo.putAll(partitions); + partitionSlots = ImmutableSet.copyOf(partitions.values()); + evaluableDetector = new PartitionPruneExpressionExtractor.ExpressionEvaluableDetector(partitionSlots); + } + + /** + * Expect: All expressions which contains non-partition slot are not evaluable. + */ + @Test + public void testExpressionEvaluableDetector() { + // expression does not contains any non-partition slot. 
+ assertDeterminateEvaluable("P1 = '20240614'", true); + assertDeterminateEvaluable("P1 = '20240614' and P2 = '20240614'", true); + assertDeterminateEvaluable("P1 = '20240614' or P2 = '20240614'", true); + assertDeterminateEvaluable("P1 = '20240614' and P2 = '20240614' and true", true); + assertDeterminateEvaluable("P1 = '20240614' and P2 = '20240614' and false", true); + assertDeterminateEvaluable("P1 = '20240614' and P2 = '20240614' and 5 > 10", true); + assertDeterminateEvaluable("P1 = '20240614' and P2 = '20240614' or 'a' = 'b'", true); + assertDeterminateEvaluable("P1 = '20240614' and not(P2 = '20240614') or 'a' = 'b'", true); + assertDeterminateEvaluable("P1 = '20240614' and P2 = '20240614' or not('a' = 'b')", true); + assertDeterminateEvaluable("P1 = '20240614' and " + + "case when(P2 = '20240614' or P2 = 'abc') then P3 = 'abc' else false end", true); + assertDeterminateEvaluable("P1 = '20240614' and " + + "case when(P2 = '20240614' and P1 = 'abc') then P3 = 'abc' else false end", true); + assertDeterminateEvaluable("P1 = '20240614' and " + + "if(P2 = '20240614' and P1 = 'abc', P3 = 'abc', false)", true); + assertDeterminateEvaluable("P1 = '20240614' and " + + "if(P2 = '20240614' and '123' = 'abc', P1 = 'abc', false)", true); + assertDeterminateEvaluable("P1 = '20240614' and " + + "to_date('20240614', '%Y%m%d') = P2", true); + + // expression contains non-partition slot. 
+ assertDeterminateEvaluable("P1 = '20240614' and P2 = '20240614' and I1 = 'abc'", false); + assertDeterminateEvaluable("P1 = '20240614' and P2 = '20240614' or I1 = 'abc'", false); + assertDeterminateEvaluable("P1 = '20240614' and (P2 = '20240614' and I1 = 'abc')", false); + assertDeterminateEvaluable("P1 = '20240614' and (P2 = '20240614' or I1 = 'abc')", false); + assertDeterminateEvaluable("P1 = '20240614' and (P2 = '20240614' or I1 = 'abc')", false); + assertDeterminateEvaluable("P1 = '20240614' and not(P2 = '20240614') or 'S1' = 'b'", true); + assertDeterminateEvaluable("P1 = '20240614' and P2 = '20240614' or not('S2' = 'b')", true); + assertDeterminateEvaluable("P1 = '20240614' and " + + "case when(P2 = '20240614' or I1 = 'abc') then I2 = 'abc' else false end", false); + assertDeterminateEvaluable("P1 = '20240614' and " + + "case when(P2 = '20240614' and I1 = 'abc') then I2 = 'abc' else false end", false); + assertDeterminateEvaluable("P1 = '20240614' and " + + "if(P2 = '20240614' and I1 = 'abc', I2 = 'abc', false)", false); + assertDeterminateEvaluable("P1 = '20240614' and " + + "if(P2 = '20240614' and I1 = 'abc', I2 = 'abc', false)", false); + assertDeterminateEvaluable("P1 = '20240614' and " + + "to_date('20240614', '%Y%m%d') = S1", false); + assertDeterminateEvaluable("P1 = '20240614' and " + + "(select 'a' from t limit 1) = S1", false); + } + + @Test + public void testExpressionExtract() { + assertExtract("P1 = '20240614'", "P1 = '20240614'"); + assertExtract("P1 = '20240614' and P2 = '20240614'", "P1 = '20240614' and P2 = '20240614'"); + assertExtract("P1 = '20240614' or P2 = '20240614'", "P1 = '20240614' or P2 = '20240614'"); + + assertExtract("P1 = '20240614' and P2 = '20240614' and true", + "P1 = '20240614' and P2 = '20240614'"); + assertExtract("P1 = '20240614' and P2 = '20240614' and false", "false"); + assertExtract("P1 = '20240614' and P2 = '20240614' and 5 > 10", "false"); + assertExtract("P1 = '20240614' and P2 = '20240614' and I1 = 'abc'", + 
"P1 = '20240614' and P2 = '20240614'"); + assertExtract("P1 = '20240614' and P2 = '20240614' and (5 > 10 and I2 = 123)", "false"); + assertExtract("P1 = '20240614' and P2 = '20240614' or I1 = 'abc'", "true"); + + assertExtract("P1 = '20240614' and P2 = '20240614' or false", + "P1 = '20240614' and P2 = '20240614'"); + assertExtract("P1 = '20240614' and P2 = '20240614' or 5 < 10", "true"); + assertExtract("P1 = '20240614' and P2 = '20240614' or (5 < 10 or I2 = 123)", "true"); + assertExtract("P1 = '20240614' and P2 = '20240614' or (5 < 10 and I2 = 123)", "true"); + assertExtract("P1 = '20240614' and (P2 = '20240614' and I1 = 'abc')", + "P1 = '20240614' and P2 = '20240614'"); + assertExtract("P1 = '20240614' and (P2 = '20240614' or I1 = 'abc')", "P1 = '20240614'"); + assertExtract("P1 = '20240614' and (P2 = '20240614' or I1 = 'abc')", "P1 = '20240614'"); + assertExtract("P1 = '20240614' and P2 = '20240614' or I2 = 123", "true"); + assertExtract("P1 = '20240614' and P2 = '20240614' or not(I2 = 123)", "true"); + assertExtract("P1 = '20240614' and P2 = '20240614' or not(P3 = '20240614' and I2 = 123)", "true"); + + assertExtract("P1 = '20240614' and P2 = '20240614' or (5 > 10 and I2 = 123)", + "P1 = '20240614' and P2 = '20240614'"); + assertExtract("P1 = '20240614' and P2 = '20240614' and not(5 > 10 and I2 = 123)", + "P1 = '20240614' and P2 = '20240614'"); + + assertExtract("P1 = '20240614' and P2 = '20240614' and (P3 = '20240614' and (P4 = '20240614' or I1 = 123))", + "P1 = '20240614' and P2 = '20240614' and P3 = '20240614'"); + assertExtract("P1 = '20240614' and P2 = '20240614' and " + + "(P3 = '20240614' or (P4 = '20240614' and P5 = '20240614' or I1 = 123))", + "P1 = '20240614' and P2 = '20240614'"); + assertExtract("P1 = '20240614' and P2 = '20240614' and " + + "(P3 = '20240614' or (P4 = '20240614' or I1 = 123 and P5 = '20240614'))", + "P1 = '20240614' and P2 = '20240614' and (P3 = '20240614' or (P4 = '20240614' or P5 = '20240614'))"); + assertExtract("P1 = '20240614' 
and P2 = '20240614' and " + + "(P3 = '20240614' or ((P4 = '20240614' or I1 = 123) and P5 = '20240614'))", + "P1 = '20240614' and P2 = '20240614' and (P3 = '20240614' or P5 = '20240614')"); + assertExtract("I2 = 345 or (P1 = '20240614' and P2 = '20240614') and " + + "(P3 = '20240614' or (P4 = '20240614' and P5 = '20240614' and I1 = 123))", + "true"); + assertExtract("(I2 = 345 or P1 = '20240614') and P2 = '20240614' and " + + "(P3 = '20240614' or (P4 = '20240614' and P5 = '20240614' and I1 = 123))", + "P2 = '20240614' and (P3 = '20240614' or (P4 = '20240614' and P5 = '20240614'))"); + assertExtract("(I2 = 345 or P1 = '20240614') and P2 = '20240614' and " + + "(P3 = '20240614' or (P4 = '20240614' and P5 = '20240614' or I1 = 123))", + "P2 = '20240614'"); + assertExtract("P1 = '20240614' and P2 = '20240614' or " + + "(P3 = '20240614' or (P4 = '20240614' and P5 = '20240614' or I1 = 123))", + "true"); + assertExtract("P1 = '20240614' and case when(P2 = '20240614' or P2 = 'abc') then P3 = 'abc' else false end", + "P1 = '20240614' and case when(P2 = '20240614' or P2 = 'abc') then P3 = 'abc' else false end"); + assertExtract("P1 = '20240614' and case when(P2 = '20240614' and P1 = 'abc') then P3 = 'abc' else false end", + "P1 = '20240614' and case when(P2 = '20240614' and P1 = 'abc') then P3 = 'abc' else false end"); + assertExtract("P1 = '20240614' and case when(P2 = '20240614' or I1 = 'abc') then I2 = 'abc' else false end", + "P1 = '20240614'"); + assertExtract("P1 = '20240614' and case when(P2 = '20240614' and I1 = 'abc') then I2 = 'abc' else false end", + "P1 = '20240614'"); + assertExtract("P1 = '20240614' or if(P2 = '20240614' and P1 = 'abc', P3 = 'abc', false)", + "P1 = '20240614' or if(P2 = '20240614' and P1 = 'abc', P3 = 'abc', false)"); + assertExtract("P1 = '20240614' or if(P2 = '20240614' and '123' = 'abc', P1 = 'abc', false)", + "P1 = '20240614' or if(false, P1 = 'abc', false)"); + assertExtract("P1 = '20240614' or if(P2 = '20240614' and I1 = 'abc', I2 = 'abc', 
false)", "true");
+        assertExtract("P1 = '20240614' or if(P2 = '20240614' and I1 = 'abc', I2 = 'abc', false)", "true");
+        assertExtract("P1 = '20240614' and to_date('20240614', '%Y%m%d') = P2",
+                "P1 = '20240614' and to_date('20240614', '%Y%m%d') = P2");
+        assertExtract("P1 = '20240614' and to_date('20240614', '%Y%m%d') = S1",
+                "P1 = '20240614'");
+        assertExtract("P1 = '20240614' or (select 'a' from t limit 1) = S1", "true");
+    }
+
+    /**
+     * Parses {@code expressionString}, binds its unbound slots through {@code slotMemo} and asserts
+     * that {@code evaluableDetector} returns {@code evaluable} for the resulting expression.
+     *
+     * @param expressionString expression text handed to the parser
+     * @param evaluable the detection result the expression is expected to produce
+     */
+    private void assertDeterminateEvaluable(String expressionString, boolean evaluable) {
+        Expression expression = replaceUnboundSlot(PARSER.parseExpression(expressionString), slotMemo);
+        // Expected value first, matching assertExtract, so failure messages label both values correctly.
+        Assertions.assertEquals(evaluable, evaluableDetector.detect(expression));
+    }
+
+    /**
+     * Runs {@code PartitionPruneExpressionExtractor.extract} on the parsed {@code expression} and
+     * asserts that the result equals the expression parsed from {@code expected}.
+     * Both texts are bound against the same {@code slotMemo}, so a slot name that occurs in both
+     * resolves to the same map entry.
+     *
+     * @param expression predicate text to run through the extractor
+     * @param expected text of the predicate the extractor is expected to return
+     */
+    private void assertExtract(String expression, String expected) {
+        Expression needRewriteExpression = replaceUnboundSlot(PARSER.parseExpression(expression), slotMemo);
+        Expression expectedExpression = replaceUnboundSlot(PARSER.parseExpression(expected), slotMemo);
+        Expression rewrittenExpression =
+                PartitionPruneExpressionExtractor.extract(needRewriteExpression, partitionSlots, cascadesContext);
+        Assertions.assertEquals(expectedExpression, rewrittenExpression);
+    }
+
+    /**
+     * Replaces every {@code UnboundSlot} in {@code expression} with a {@code SlotReference} taken from
+     * {@code mem}; a missing entry is created with the two-argument {@code SlotReference} constructor,
+     * typed by {@link #getType(char)} from the first character of the slot name.
+     *
+     * @param expression expression tree to rewrite
+     * @param mem slot name to slot map; entries are added for names not yet present
+     * @return the rewritten tree, or {@code expression} itself when nothing was replaced
+     */
+    private Expression replaceUnboundSlot(Expression expression, Map mem) {
+        return bindUnboundSlots(expression, mem,
+                name -> new SlotReference(name, getType(name.charAt(0))));
+    }
+
+    /**
+     * Same traversal as {@link #replaceUnboundSlot}, but a missing entry is created with
+     * {@code false} as the third {@code SlotReference} constructor argument
+     * (presumably the nullable flag, going by this method's name; confirm against SlotReference).
+     *
+     * @param expression expression tree to rewrite
+     * @param mem slot name to slot map; entries are added for names not yet present
+     * @return the rewritten tree, or {@code expression} itself when nothing was replaced
+     */
+    private Expression replaceNotNullUnboundSlot(Expression expression, Map mem) {
+        return bindUnboundSlots(expression, mem,
+                name -> new SlotReference(name, getType(name.charAt(0)), false));
+    }
+
+    /**
+     * Shared traversal behind the two replace methods: rewrites the children first, then swaps an
+     * {@code UnboundSlot} node for the slot stored in {@code mem} under its name.
+     *
+     * @param expression expression tree to rewrite
+     * @param mem slot name to slot map shared across calls
+     * @param slotFactory builds the slot offered to {@code mem} for a given name
+     * @return the rewritten tree, or {@code expression} itself when nothing was replaced
+     */
+    private Expression bindUnboundSlots(Expression expression, Map mem,
+            java.util.function.Function<String, SlotReference> slotFactory) {
+        List<Expression> children = Lists.newArrayList();
+        boolean hasNewChildren = false;
+        for (Expression child : expression.children()) {
+            Expression newChild = bindUnboundSlots(child, mem, slotFactory);
+            // Identity comparison on purpose: an unchanged subtree is returned as the same object.
+            if (newChild != child) {
+                hasNewChildren = true;
+            }
+            children.add(newChild);
+        }
+        if (expression instanceof UnboundSlot) {
+            String name = ((UnboundSlot) expression).getName();
+            // The candidate slot is built on every visit, as before; putIfAbsent keeps the first one stored.
+            mem.putIfAbsent(name, slotFactory.apply(name));
+            // NOTE(review): mem is declared without type arguments here, so the lookup needs a cast.
+            return (Expression) mem.get(name);
+        }
+        return hasNewChildren ? expression.withChildren(children) : expression;
+    }
+
+    /**
+     * Builds the partition slot map: P1, P3 and P5 are string slots, P2 and P4 are integer slots,
+     * each keyed by its own name.
+     *
+     * @return an immutable map with the five entries P1 to P5
+     */
+    private Map createPartitionSlots() {
+        return ImmutableMap.of(
+                "P1", new SlotReference("P1", StringType.INSTANCE),
+                "P2", new SlotReference("P2", IntegerType.INSTANCE),
+                "P3", new SlotReference("P3", StringType.INSTANCE),
+                "P4", new SlotReference("P4", IntegerType.INSTANCE),
+                "P5", new SlotReference("P5", StringType.INSTANCE));
+    }
+
+    /**
+     * Maps the first letter of a slot name to the data type used for that slot.
+     *
+     * @param t type tag: T = tinyint, I = int, D = double, S = string, B = boolean
+     * @return the matching type instance; any other character yields {@code BigIntType}
+     */
+    private DataType getType(char t) {
+        switch (t) {
+            case 'T':
+                return TinyIntType.INSTANCE;
+            case 'I':
+                return IntegerType.INSTANCE;
+            case 'D':
+                return DoubleType.INSTANCE;
+            case 'S':
+                return StringType.INSTANCE;
+            case 'B':
+                return BooleanType.INSTANCE;
+            default:
+                return BigIntType.INSTANCE;
+        }
+    }
+}