Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ public Expression visitOr(Or or, ExpressionRewriteContext context) {
newChild = child;
}
if (newChild.getDataType().isNullType()) {
newChild = new NullLiteral(BooleanType.INSTANCE);
newChild = NullLiteral.BOOLEAN_INSTANCE;
} else {
newChild = TypeCoercionUtils.castIfNotSameType(newChild, BooleanType.INSTANCE);
}
Expand Down Expand Up @@ -618,7 +618,7 @@ public Expression visitAnd(And and, ExpressionRewriteContext context) {
newChild = child;
}
if (newChild.getDataType().isNullType()) {
newChild = new NullLiteral(BooleanType.INSTANCE);
newChild = NullLiteral.BOOLEAN_INSTANCE;
} else {
newChild = TypeCoercionUtils.castIfNotSameType(newChild, BooleanType.INSTANCE);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.doris.nereids.rules.expression.rules.AddMinMax;
import org.apache.doris.nereids.rules.expression.rules.ArrayContainToArrayOverlap;
import org.apache.doris.nereids.rules.expression.rules.BetweenToEqual;
import org.apache.doris.nereids.rules.expression.rules.CaseWhenToCompoundPredicate;
import org.apache.doris.nereids.rules.expression.rules.CaseWhenToIf;
import org.apache.doris.nereids.rules.expression.rules.DateFunctionRewrite;
import org.apache.doris.nereids.rules.expression.rules.DistinctPredicatesRule;
Expand Down Expand Up @@ -62,6 +63,7 @@ public class ExpressionOptimization extends ExpressionRewrite {
ReplaceNullWithFalseForCond.INSTANCE,
NestedCaseWhenCondToLiteral.INSTANCE,
CaseWhenToIf.INSTANCE,
CaseWhenToCompoundPredicate.INSTANCE,
TopnToMax.INSTANCE,
NullSafeEqualToEqual.INSTANCE,
LikeToEqualRewrite.INSTANCE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public enum ExpressionRuleType {
ADD_MIN_MAX,
ARRAY_CONTAIN_TO_ARRAY_OVERLAP,
BETWEEN_TO_EQUAL,
CASE_WHEN_TO_COMPOUND_PREDICATE,
CASE_WHEN_TO_IF,
CHECK_CAST,
CONVERT_AGG_STATE_CAST,
Expand All @@ -36,6 +37,7 @@ public enum ExpressionRuleType {
FOLD_CONSTANT_ON_BE,
FOLD_CONSTANT_ON_FE,
LOG_TO_LN,
IF_TO_COMPOUND_PREDICATE,
IN_PREDICATE_DEDUP,
IN_PREDICATE_EXTRACT_NON_CONSTANT,
IN_PREDICATE_TO_EQUAL_TO,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;

import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Optional;

/**
* if case when all branch value are true/false literal, and the ELSE default value can be any expression,
* then can eliminate this case when.
*
* for example:
* 1. case when c1 then true when c2 then false end => (c1 <=> true or (not (c2 <=> true) and null))
* 2. if (c1, true, false) => c1 <=> true or false
*/
public class CaseWhenToCompoundPredicate implements ExpressionPatternRuleFactory {
public static CaseWhenToCompoundPredicate INSTANCE = new CaseWhenToCompoundPredicate();

@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
matchesType(CaseWhen.class)
.when(this::checkBooleanType)
.then(this::rewriteCaseWhen)
.toRule(ExpressionRuleType.CASE_WHEN_TO_COMPOUND_PREDICATE),
matchesType(If.class)
.when(this::checkBooleanType)
.then(this::rewriteIf)
.toRule(ExpressionRuleType.IF_TO_COMPOUND_PREDICATE)
);
}

private boolean checkBooleanType(Expression expression) {
return expression.getDataType().isBooleanType();
}

private Expression rewriteCaseWhen(CaseWhen caseWhen) {
Expression defaultValue = caseWhen.getDefaultValue().orElse(NullLiteral.BOOLEAN_INSTANCE);
return rewrite(caseWhen.getWhenClauses(), defaultValue).orElse(caseWhen);
}

private Expression rewriteIf(If ifExpr) {
List<WhenClause> whenClauses = ImmutableList.of(new WhenClause(ifExpr.getCondition(), ifExpr.getTrueValue()));
Expression defaultValue = ifExpr.getFalseValue();
return rewrite(whenClauses, defaultValue).orElse(ifExpr);
}

// for a branch, suppose the branches later it can rewrite to X, then given the branch:
// 1. when c then true ..., will rewrite to (c <=> true OR X),
// 2. when c then false ..., will rewrite to (not(c <=> true) AND X),
// for the ELSE branch, it can rewrite to `when true then defaultValue`,
// process the branches from back to front, the default value process first, while the first when clause will
// process last.
private Optional<Expression> rewrite(List<WhenClause> whenClauses, Expression defaultValue) {
for (WhenClause whenClause : whenClauses) {
Expression result = whenClause.getResult();
if (!(result instanceof BooleanLiteral)) {
return Optional.empty();
}
}
Expression result = defaultValue;
try {
for (int i = whenClauses.size() - 1; i >= 0; i--) {
WhenClause whenClause = whenClauses.get(i);
// operand <=> true
Expression condition = new NullSafeEqual(whenClause.getOperand(), BooleanLiteral.TRUE);
if (whenClause.getResult().equals(BooleanLiteral.TRUE)) {
result = new Or(condition, result);
} else {
result = new And(new Not(condition), result);
}
}
} catch (Exception e) {
// expression may exceed expression limit
return Optional.empty();
}
return Optional.of(result);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@
import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;
Expand Down Expand Up @@ -445,7 +444,7 @@ public Expression visitAnd(And and, ExpressionRewriteContext context) {
}
} else {
// null and null and null and ...
return new NullLiteral(BooleanType.INSTANCE);
return NullLiteral.BOOLEAN_INSTANCE;
}
}

Expand Down Expand Up @@ -491,7 +490,7 @@ public Expression visitOr(Or or, ExpressionRewriteContext context) {
return or.withChildren(nonFalseLiteral);
} else {
// null or null
return new NullLiteral(BooleanType.INSTANCE);
return NullLiteral.BOOLEAN_INSTANCE;
}
}

Expand Down Expand Up @@ -649,7 +648,7 @@ public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteCon
// now the inPredicate contains literal only.
Expression value = inPredicate.child(0);
if (value.isNullLiteral()) {
return new NullLiteral(BooleanType.INSTANCE);
return NullLiteral.BOOLEAN_INSTANCE;
}
boolean isOptionContainsNull = false;
for (Expression item : inPredicate.getOptions()) {
Expand All @@ -660,7 +659,7 @@ public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteCon
}
}
return isOptionContainsNull
? new NullLiteral(BooleanType.INSTANCE)
? NullLiteral.BOOLEAN_INSTANCE
: BooleanLiteral.FALSE;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.types.BooleanType;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -115,7 +114,7 @@ public Expression visitInPredicate(InPredicate inPredicate, Map<Slot, PartitionS

Expression newCompareExpr = inPredicate.getCompareExpr().accept(this, context);
if (newCompareExpr.isNullLiteral()) {
return new NullLiteral(BooleanType.INSTANCE);
return NullLiteral.BOOLEAN_INSTANCE;
}

try {
Expand All @@ -125,7 +124,7 @@ public Expression visitInPredicate(InPredicate inPredicate, Map<Slot, PartitionS
return BooleanLiteral.TRUE;
}
if (inPredicate.optionsContainsNullLiteral()) {
return new NullLiteral(BooleanType.INSTANCE);
return NullLiteral.BOOLEAN_INSTANCE;
}
return BooleanLiteral.FALSE;
} catch (Throwable t) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.NullType;

Expand All @@ -30,9 +31,10 @@
public class NullLiteral extends Literal implements ComparableLiteral {

public static final NullLiteral INSTANCE = new NullLiteral();
public static final NullLiteral BOOLEAN_INSTANCE = new NullLiteral(BooleanType.INSTANCE);

public NullLiteral() {
super(NullType.INSTANCE);
this(NullType.INSTANCE);
}

public NullLiteral(DataType dataType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.visitor.ExpressionLineageReplacer;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.coercion.NumericType;
import org.apache.doris.qe.ConnectContext;

Expand Down Expand Up @@ -218,7 +217,7 @@ public static Expression and(Collection<Expression> expressions) {
}
}

List<Expression> exprList = Lists.newArrayList(distinctExpressions);
List<Expression> exprList = ImmutableList.copyOf(distinctExpressions);
if (exprList.isEmpty()) {
return BooleanLiteral.TRUE;
} else if (exprList.size() == 1) {
Expand Down Expand Up @@ -266,7 +265,7 @@ public static Expression or(Collection<Expression> expressions) {
}
}

List<Expression> exprList = Lists.newArrayList(distinctExpressions);
List<Expression> exprList = ImmutableList.copyOf(distinctExpressions);
if (exprList.isEmpty()) {
return BooleanLiteral.FALSE;
} else if (exprList.size() == 1) {
Expand All @@ -278,15 +277,15 @@ public static Expression or(Collection<Expression> expressions) {

public static Expression falseOrNull(Expression expression) {
if (expression.nullable()) {
return new And(new IsNull(expression), new NullLiteral(BooleanType.INSTANCE));
return new And(new IsNull(expression), NullLiteral.BOOLEAN_INSTANCE);
} else {
return BooleanLiteral.FALSE;
}
}

public static Expression trueOrNull(Expression expression) {
if (expression.nullable()) {
return new Or(new Not(new IsNull(expression)), new NullLiteral(BooleanType.INSTANCE));
return new Or(new Not(new IsNull(expression)), NullLiteral.BOOLEAN_INSTANCE);
} else {
return BooleanLiteral.TRUE;
}
Expand Down Expand Up @@ -668,7 +667,7 @@ public static boolean canInferNotNullForMarkSlot(Expression predicate, Expressio
* and in semi join, we can safely change the mark conjunct to hash conjunct
*/
ImmutableList<Literal> literals =
ImmutableList.of(new NullLiteral(BooleanType.INSTANCE), BooleanLiteral.FALSE);
ImmutableList.of(NullLiteral.BOOLEAN_INSTANCE, BooleanLiteral.FALSE);
List<MarkJoinSlotReference> markJoinSlotReferenceList =
new ArrayList<>((predicate.collect(MarkJoinSlotReference.class::isInstance)));
int markSlotSize = markJoinSlotReferenceList.size();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper;
import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor;

import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;

class CaseWhenToCompoundPredicateTest extends ExpressionRewriteTestHelper {

@Test
void testCaseWhen() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(
CaseWhenToCompoundPredicate.INSTANCE
)
));
assertRewriteAfterTypeCoercion("case when a = 1 then true end", "(a = 1 <=> TRUE) or null");
assertRewriteAfterTypeCoercion("case when a = 1 then true else null end", "(a = 1 <=> TRUE) or null");
assertRewriteAfterTypeCoercion("case when a = 1 then true else false end", "(a = 1 <=> TRUE) or false");
assertRewriteAfterTypeCoercion("case when a = 1 then true else true end", "(a = 1 <=> TRUE) or true");
assertRewriteAfterTypeCoercion("case when a = 1 then true else b = 1 end", "(a = 1 <=> TRUE) or b = 1");
assertRewriteAfterTypeCoercion("case when a = 1 then true when b = 1 then true when c = 1 then true end",
"(a = 1 <=> TRUE) or (b = 1 <=> TRUE) or (c = 1 <=> TRUE) or null");
assertRewriteAfterTypeCoercion("case when a = 1 then false when b = 1 then false when c = 1 then false end",
"not(a = 1 <=> TRUE) and not (b = 1 <=> TRUE) and not(c = 1 <=> TRUE) and null");
assertRewriteAfterTypeCoercion("case when a = 1 then true when b = 1 then false when c = 1 then true end",
"(a = 1 <=> TRUE) or (not (b = 1 <=> TRUE) and ((c = 1 <=> TRUE) or null))");
}

@Test
void testIf() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(
CaseWhenToCompoundPredicate.INSTANCE
)
));

assertRewriteAfterTypeCoercion("if(a = 1, true, a > b)", "(a = 1 <=> TRUE) or a > b");
assertRewriteAfterTypeCoercion("if(a = 1, false, a > b)", "not (a = 1 <=> TRUE) and a > b");
}
}
Loading
Loading