Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullLiteral;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Array;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ConnectionId;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentCatalog;
Expand All @@ -74,6 +73,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncryptKeyRef;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.functions.scalar.LastQueryId;
import org.apache.doris.nereids.trees.expressions.functions.scalar.NullIf;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Password;
import org.apache.doris.nereids.trees.expressions.functions.scalar.SessionUser;
Expand Down Expand Up @@ -186,6 +186,7 @@ public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
matches(SessionUser.class, this::visitSessionUser),
matches(LastQueryId.class, this::visitLastQueryId),
matches(Nvl.class, this::visitNvl),
matches(NullIf.class, this::visitNullIf),
matches(Match.class, this::visitMatch)
);
}
Expand Down Expand Up @@ -551,9 +552,6 @@ public Expression visitTryCast(TryCast cast, ExpressionRewriteContext context) {

@Override
public Expression visitBoundFunction(BoundFunction boundFunction, ExpressionRewriteContext context) {
if (!boundFunction.foldable()) {
return boundFunction;
}
boundFunction = rewriteChildren(boundFunction, context);
Optional<Expression> checkedExpr = preProcess(boundFunction);
if (checkedExpr.isPresent()) {
Expand Down Expand Up @@ -735,18 +733,39 @@ public Expression visitVersion(Version version, ExpressionRewriteContext context
public Expression visitNvl(Nvl nvl, ExpressionRewriteContext context) {
Nvl originNvl = nvl;
nvl = rewriteChildren(nvl, context);

for (Expression expr : nvl.children()) {
if (expr.isLiteral()) {
if (!expr.isNullLiteral()) {
return TypeCoercionUtils.ensureSameResultType(originNvl, expr, context);
}
} else {
return TypeCoercionUtils.ensureSameResultType(originNvl, nvl, context);
Expression first = nvl.left();
Expression second = nvl.right();
Expression result = nvl;
if (first.equals(second) || second.isNullLiteral() || (first.isLiteral() && !first.isNullLiteral())) {
result = first;
} else if (first.isNullLiteral()) {
result = second;
}
return TypeCoercionUtils.ensureSameResultType(originNvl, result, context);
}

@Override
public Expression visitNullIf(NullIf nullIf, ExpressionRewriteContext context) {
NullIf originNullIf = nullIf;
nullIf = rewriteChildren(nullIf, context);
Expression first = nullIf.left();
Expression second = nullIf.right();
Expression result = nullIf;
// if first is null, then first = second will be null
if (first.isNullLiteral() || second.isNullLiteral()) {
result = first;
} else if (first.equals(second)) {
// even if first is null, then first = second will be null, then result is first, so the result is also null
result = new NullLiteral(originNullIf.getDataType());
} else if (first.isLiteral() && second.isLiteral()) {
Expression isEqual = visitEqualTo(new EqualTo(first, second), context);
if (isEqual.equals(BooleanLiteral.TRUE)) {
result = new NullLiteral(originNullIf.getDataType());
} else if (isEqual.equals(BooleanLiteral.FALSE) || isEqual.isNullLiteral()) {
result = first;
}
}
// all nulls
return TypeCoercionUtils.ensureSameResultType(originNvl, nvl.child(0), context);
return TypeCoercionUtils.ensureSameResultType(originNullIf, result, context);
}

private <E extends Expression> E rewriteChildren(E expr, ExpressionRewriteContext context) {
Expand Down Expand Up @@ -787,7 +806,7 @@ private <E extends Expression> E rewriteChildren(E expr, ExpressionRewriteContex
}

private Optional<Expression> preProcess(Expression expression) {
if (expression instanceof AggregateFunction || expression instanceof TableGeneratingFunction) {
if (!expression.foldable()) {
return Optional.of(expression);
}
if (ExpressionUtils.hasNullLiteral(expression.getArguments())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import org.apache.doris.nereids.exceptions.NotSupportedException;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeAcquire;
import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeArithmetic;
import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeExtractAndTransform;
Expand Down Expand Up @@ -56,7 +55,7 @@ public enum ExpressionEvaluator {
* Evaluate the value of the expression.
*/
public Expression eval(Expression expression) {
if (!(expression.isConstant() || expression.foldable()) || expression instanceof AggregateFunction) {
if (!(expression.isConstant() || expression.foldable())) {
return expression;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ public int computeHashCode() {
return Objects.hash(distinct, getName(), children);
}

@Override
public boolean foldable() {
return false;
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitAggregateFunction(this, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,9 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
protected GeneratorFunctionParams getFunctionParams(List<Expression> arguments) {
return new GeneratorFunctionParams(this, getName(), arguments, isInferred());
}

@Override
public boolean foldable() {
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1482,7 +1482,8 @@ void testFoldNvl() {

assertRewriteExpression("nvl(NULL, 1)", "1");
assertRewriteExpression("nvl(NULL, NULL)", "NULL");
assertRewriteAfterTypeCoercion("nvl(IA, NULL)", "ifnull(IA, NULL)");
assertRewriteAfterTypeCoercion("nvl(IA, NULL)", "IA");
assertRewriteAfterTypeCoercion("nvl(IA, IA)", "IA");
assertRewriteAfterTypeCoercion("nvl(IA, 1)", "ifnull(IA, 1)");

Expression foldNvl = executor.rewrite(
Expand All @@ -1492,6 +1493,33 @@ void testFoldNvl() {
Assertions.assertEquals(new DateTimeV2Literal(DateTimeV2Type.of(6), "2025-04-17"), foldNvl);
}

@Test
void testFoldNullIf() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(
FoldConstantRule.INSTANCE
)
));
assertRewriteAfterTypeCoercion("nullif(a, b)", "nullif(a, b)");
assertRewriteAfterTypeCoercion("nullif(a, a)", "null");
assertRewriteAfterTypeCoercion("nullif(a, null)", "a");
assertRewriteAfterTypeCoercion("nullif(null, a)", "null");
assertRewriteAfterTypeCoercion("nullif(1, 1)", "null");
assertRewriteAfterTypeCoercion("nullif(1, 2)", "1");
}

@Test
void testNonFoldable() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(
FoldConstantRule.INSTANCE
)
));
assertRewriteAfterTypeCoercion("random(0, 1)", "random(0, 1)");
assertRewriteAfterTypeCoercion("sum(1 + 2)", "sum(3)");
assertRewriteAfterTypeCoercion("explode([1, 2, 3])", "explode([1, 2, 3])");
}

private void assertRewriteExpression(String actualExpression, String expectedExpression) {
ExpressionRewriteContext context = new ExpressionRewriteContext(
MemoTestUtils.createCascadesContext(new UnboundRelation(new RelationId(1), ImmutableList.of("test_table"))));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,18 +394,18 @@ suite("test_mysql_jdbc_catalog", "p0,external,mysql,external_docker,external_doc
contains "QUERY: SELECT `timestamp0` FROM `doris_test`.`dt` WHERE (`timestamp0` > '2022-01-01 00:00:00')"
}
explain {
sql ("select k6, k8 from test1 where nvl(k6, null) = 1;")
sql ("select k6, k8 from test1 where nvl(k6, 1) = 1;")

contains "QUERY: SELECT `k6`, `k8` FROM `doris_test`.`test1` WHERE ((ifnull(`k6`, NULL) = 1))"
contains "QUERY: SELECT `k6`, `k8` FROM `doris_test`.`test1` WHERE ((ifnull(`k6`, 1) = 1))"
}
explain {
sql ("select k6, k8 from test1 where nvl(nvl(k6, null),null) = 1;")
sql ("select k6, k8 from test1 where nvl(nvl(k6, 1), 1) = 1;")

contains "QUERY: SELECT `k6`, `k8` FROM `doris_test`.`test1` WHERE ((ifnull(ifnull(`k6`, NULL), NULL) = 1))"
contains "QUERY: SELECT `k6`, `k8` FROM `doris_test`.`test1` WHERE ((ifnull(`k6`, 1) = 1))"
}
sql """ set enable_ext_func_pred_pushdown = "false"; """
explain {
sql ("select k6, k8 from test1 where nvl(k6, null) = 1 and k8 = 1;")
sql ("select k6, k8 from test1 where nvl(k6, 1) = 1 and k8 = 1;")

contains "QUERY: SELECT `k6`, `k8` FROM `doris_test`.`test1` WHERE ((`k8` = 1))"
}
Expand Down