Skip to content

Commit

Permalink
Revert "[Fix](Nereids) fix infer predicate lost cast of source expres…
Browse files Browse the repository at this point in the history
…sion (#23692)"

This reverts commit 03f029f.
  • Loading branch information
xiaokang committed Sep 13, 2023
1 parent 13652d6 commit 6870802
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ public Set<Expression> infer(Set<Expression> predicates) {
}

/**
* Use the left or right child of `equalExpr` to replace the left or right child of `expression`
* Use the left or right child of `leftSlotEqualToRightSlot` to replace the left or right child of `expression`
* Now only support infer `ComparisonPredicate`.
* TODO: We should determine whether `expression` satisfies the condition for replacement
* eg: Satisfy `expression` is non-deterministic
*/
private Expression doInfer(Expression equalExpr, Expression expression) {
private Expression doInfer(Expression leftSlotEqualToRightSlot, Expression expression) {
return expression.accept(new DefaultExpressionRewriter<Void>() {

@Override
Expand All @@ -76,43 +76,36 @@ public Expression visit(Expression expr, Void context) {
public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) {
// we need to get expression covered by cast, because we want to infer different datatype
if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left()) && (cp.right().isConstant())) {
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left()), equalExpr);
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left()));
} else if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) && cp.left().isConstant()) {
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right()), equalExpr);
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right()));
}
return super.visit(cp, context);
}

private boolean isDataTypeValid(DataType originDataType, Expression expr) {
if ((expr.child(0).getDataType() instanceof IntegralType)
&& (expr.child(1).getDataType() instanceof IntegralType)
if ((leftSlotEqualToRightSlot.child(0).getDataType() instanceof IntegralType)
&& (leftSlotEqualToRightSlot.child(1).getDataType() instanceof IntegralType)
&& (originDataType instanceof IntegralType)) {
// infer filter can not be lower than original datatype, or dataset would be wrong
if (!((IntegralType) originDataType).widerThan(
(IntegralType) expr.child(0).getDataType())
(IntegralType) leftSlotEqualToRightSlot.child(0).getDataType())
&& !((IntegralType) originDataType).widerThan(
(IntegralType) expr.child(1).getDataType())) {
(IntegralType) leftSlotEqualToRightSlot.child(1).getDataType())) {
return true;
}
} else if (expr.child(0).getDataType().equals(expr.child(1).getDataType())) {
return true;
}
return false;
}

private Expression replaceSlot(Expression sourcePredicate, DataType originDataType, Expression equal) {
if (!isDataTypeValid(originDataType, equal)) {
return sourcePredicate;
}
return sourcePredicate.rewriteUp(e -> {
// we can not replace Cast expression to slot because when rewrite up, we have replace child of cast
if (e instanceof Cast) {
return e;
}
if (ExpressionUtils.isTwoExpressionEqualWithCast(e, equal.child(0))) {
return equal.child(1);
} else if (ExpressionUtils.isTwoExpressionEqualWithCast(e, equal.child(1))) {
return equal.child(0);
private Expression replaceSlot(Expression expr, DataType originDataType) {
return expr.rewriteUp(e -> {
if (isDataTypeValid(originDataType, leftSlotEqualToRightSlot)) {
if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(0))) {
return leftSlotEqualToRightSlot.child(1);
} else if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(1))) {
return leftSlotEqualToRightSlot.child(0);
}
}
return e;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,15 @@

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.utframe.TestWithFeService;

import com.google.common.collect.Sets;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.util.Optional;
import java.util.Set;

public class InferPredicatesTest extends TestWithFeService implements MemoPatternMatchSupported {

private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);

private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);

private final PredicatePropagation propagation = new PredicatePropagation();

@Override
protected void runBeforeAll() throws Exception {
createDatabase("test");
Expand Down Expand Up @@ -646,16 +628,4 @@ public void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() {
).when(join -> join.getJoinType() == JoinType.LEFT_OUTER_JOIN)
);
}

@Test
void testInfer() {
EqualTo equalTo = new EqualTo(new Cast(scan1.getOutput().get(0), BigIntType.INSTANCE), Literal.of(1));
EqualTo equalTo2 = new EqualTo(scan2.getOutput().get(0), scan1.getOutput().get(0));
Set<Expression> predicates = Sets.newHashSet();
predicates.add(equalTo2);
predicates.add(equalTo);
Set<Expression> newPredicates = propagation.infer(predicates);
Optional<Expression> newPredicate = newPredicates.stream().findFirst();
Assertions.assertTrue(newPredicate.get().equals(new EqualTo(new Cast(scan2.getOutput().get(0), BigIntType.INSTANCE), Literal.of(1))));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,13 @@ suite("test_infer_predicate") {
sql 'drop table if exists infer_tb1;'
sql 'drop table if exists infer_tb2;'
sql 'drop table if exists infer_tb3;'
sql 'drop table if exists infer_tb4;'
sql 'drop table if exists infer_tb5;'

sql '''create table infer_tb1 (k1 int, k2 int) distributed by hash(k1) buckets 3 properties('replication_num' = '1');'''

sql '''create table infer_tb2 (k1 tinyint, k2 smallint, k3 int, k4 bigint, k5 largeint, k6 date, k7 datetime, k8 float, k9 double) distributed by hash(k1) buckets 3 properties('replication_num' = '1');'''

sql '''create table infer_tb3 (k1 varchar(100), k2 int) distributed by hash(k1) buckets 3 properties('replication_num' = '1');'''

sql '''create table infer_tb4 (k1 varchar(100), k2 date) distributed by hash(k1) buckets 3 properties('replication_num' = '1');'''

sql '''create table infer_tb5 (k1 varchar(100), k3 date) distributed by hash(k1) buckets 3 properties('replication_num' = '1');'''

explain {
sql "select * from infer_tb1 inner join infer_tb2 where infer_tb2.k1 = infer_tb1.k2 and infer_tb2.k1 = 1;"
contains "PREDICATES: k2"
Expand All @@ -61,16 +55,4 @@ suite("test_infer_predicate") {
contains "PREDICATES: k3"
contains "PREDICATES: k2"
}

explain {
sql "select * from infer_tb4 left join infer_tb5 on infer_tb4.k2 = infer_tb5.k3 where infer_tb4.k2 = '20230901';"
contains "PREDICATES: k3"
contains "PREDICATES: k2"
}

sql 'drop table if exists infer_tb1;'
sql 'drop table if exists infer_tb2;'
sql 'drop table if exists infer_tb3;'
sql 'drop table if exists infer_tb4;'
sql 'drop table if exists infer_tb5;'
}

0 comments on commit 6870802

Please sign in to comment.