Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,17 @@

package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
Expand All @@ -44,16 +43,16 @@
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.literal.NumericLiteral;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.coercion.DateLikeType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;

import com.google.common.base.Preconditions;
Expand All @@ -62,9 +61,10 @@
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;
import java.util.Optional;

/**
* simplify comparison
* simplify comparison, not support large int.
* such as: cast(c1 as DateV2) >= DateV2Literal --> c1 >= DateLiteral
* cast(c1 AS double) > 2.0 --> c1 >= 2 (c1 is integer like type)
*/
Expand Down Expand Up @@ -98,22 +98,25 @@ public static Expression simplify(ComparisonPredicate cp) {
Expression left = cp.left();
Expression right = cp.right();

// float like type: float, double
if (left.getDataType().isFloatLikeType() && right.getDataType().isFloatLikeType()) {
return processFloatLikeTypeCoercion(cp, left, right);
}
Expression result;

// decimalv3 type
if (left.getDataType() instanceof DecimalV3Type && right.getDataType() instanceof DecimalV3Type) {
return processDecimalV3TypeCoercion(cp, left, right);
// process type coercion
if (left.getDataType().isFloatLikeType() && right.getDataType().isFloatLikeType()) {
result = processFloatLikeTypeCoercion(cp, left, right);
} else if (left.getDataType() instanceof DecimalV3Type && right.getDataType() instanceof DecimalV3Type) {
result = processDecimalV3TypeCoercion(cp, left, right);
} else if (left.getDataType() instanceof DateLikeType && right.getDataType() instanceof DateLikeType) {
result = processDateLikeTypeCoercion(cp, left, right);
} else {
result = cp;
}

// date like type
if (left.getDataType() instanceof DateLikeType && right.getDataType() instanceof DateLikeType) {
return processDateLikeTypeCoercion(cp, left, right);
if (result instanceof ComparisonPredicate && ((ComparisonPredicate) result).right() instanceof NumericLiteral) {
ComparisonPredicate cmp = (ComparisonPredicate) result;
result = processTypeRangeLimitComparison(cmp, cmp.left(), (NumericLiteral) cmp.right());
}

return cp;
return result;
}

private static Expression processComparisonPredicateDateTimeV2Literal(
Expand All @@ -128,17 +131,13 @@ private static Expression processComparisonPredicateDateTimeV2Literal(
if (right.getMicroSecond() == originValue) {
return comparisonPredicate.withChildren(left, right);
} else {
if (left.nullable()) {
// TODO: the ideal way is to return an If expr like:
// return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE),
// BooleanLiteral.of(false));
// but current fold constant rule can't handle such complex expr with null literal
// before supporting complex conjuncts with null literal folding rules,
// we use a trick way like this:
return new And(new IsNull(left), new NullLiteral(BooleanType.INSTANCE));
} else {
return BooleanLiteral.of(false);
}
// TODO: the ideal way is to return an If expr like:
// return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE),
// BooleanLiteral.of(false));
// but current fold constant rule can't handle such complex expr with null literal
// before supporting complex conjuncts with null literal folding rules,
// we use a trick way like this:
return ExpressionUtils.falseOrNull(left);
}
} else if (comparisonPredicate instanceof NullSafeEqual) {
long originValue = right.getMicroSecond();
Expand Down Expand Up @@ -239,18 +238,13 @@ private static Expression processDecimalV3TypeCoercion(ComparisonPredicate compa
comparisonPredicate.withChildren(left, new DecimalV3Literal(
literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY))));
} catch (ArithmeticException e) {
if (left.nullable()) {
// TODO: the ideal way is to return an If expr like:
// return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE),
// BooleanLiteral.of(false));
// but current fold constant rule can't handle such complex expr with null literal
// before supporting complex conjuncts with null literal folding rules,
// we use a trick way like this:
return new And(new IsNull(left),
new NullLiteral(BooleanType.INSTANCE));
} else {
return BooleanLiteral.of(false);
}
// TODO: the ideal way is to return an If expr like:
// return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE),
// BooleanLiteral.of(false));
// but current fold constant rule can't handle such complex expr with null literal
// before supporting complex conjuncts with null literal folding rules,
// we use a trick way like this:
return ExpressionUtils.falseOrNull(left);
}
} else if (comparisonPredicate instanceof NullSafeEqual) {
try {
Expand Down Expand Up @@ -281,21 +275,18 @@ private static Expression processDecimalV3TypeCoercion(ComparisonPredicate compa
private static Expression processIntegerDecimalLiteralComparison(
ComparisonPredicate comparisonPredicate, Expression left, BigDecimal literal) {
// we only process isIntegerLikeType, which are tinyint, smallint, int, bigint
if (literal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0) {
if (literal.compareTo(new BigDecimal(Long.MIN_VALUE)) >= 0
&& literal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0) {
literal = literal.stripTrailingZeros();
if (literal.scale() > 0) {
if (comparisonPredicate instanceof EqualTo) {
if (left.nullable()) {
// TODO: the ideal way is to return an If expr like:
// return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE),
// BooleanLiteral.of(false));
// but current fold constant rule can't handle such complex expr with null literal
// before supporting complex conjuncts with null literal folding rules,
// we use a trick way like this:
return new And(new IsNull(left), new NullLiteral(BooleanType.INSTANCE));
} else {
return BooleanLiteral.of(false);
}
// TODO: the ideal way is to return an If expr like:
// return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE),
// BooleanLiteral.of(false));
// but current fold constant rule can't handle such complex expr with null literal
// before supporting complex conjuncts with null literal folding rules,
// we use a trick way like this:
return ExpressionUtils.falseOrNull(left);
} else if (comparisonPredicate instanceof NullSafeEqual) {
return BooleanLiteral.of(false);
} else if (comparisonPredicate instanceof GreaterThan
Expand All @@ -320,10 +311,95 @@ private static Expression processIntegerDecimalLiteralComparison(
return comparisonPredicate;
}

private static Expression processTypeRangeLimitComparison(ComparisonPredicate cp, Expression left,
NumericLiteral right) {
BigDecimal typeMinValue = null;
BigDecimal typeMaxValue = null;
// cmp float like have lost precision, for example float.max_value + 0.01 still eval to float.max_value
if (left.getDataType().isIntegerLikeType() || left.getDataType().isDecimalV3Type()) {
Optional<Pair<BigDecimal, BigDecimal>> minMaxOpt =
TypeCoercionUtils.getDataTypeMinMaxValue(left.getDataType());
if (minMaxOpt.isPresent()) {
typeMinValue = minMaxOpt.get().first;
typeMaxValue = minMaxOpt.get().second;
}
}

// cast(child as dataType2) range should be:
// [ max(childDataType.min_value, dataType2.min_value), min(childDataType.max_value, dataType2.max_value)]
if (left instanceof Cast) {
left = ((Cast) left).child();
if (left.getDataType().isIntegerLikeType() || left.getDataType().isDecimalV3Type()) {
Optional<Pair<BigDecimal, BigDecimal>> minMaxOpt =
TypeCoercionUtils.getDataTypeMinMaxValue(left.getDataType());
if (minMaxOpt.isPresent()) {
if (typeMinValue == null || typeMinValue.compareTo(minMaxOpt.get().first) < 0) {
typeMinValue = minMaxOpt.get().first;
}
if (typeMaxValue == null || typeMaxValue.compareTo(minMaxOpt.get().second) > 0) {
typeMaxValue = minMaxOpt.get().second;
}
}
}
}

if (typeMinValue == null || typeMaxValue == null) {
return cp;
}
BigDecimal literal = new BigDecimal(right.getStringValue());
int cmpMin = literal.compareTo(typeMinValue);
int cmpMax = literal.compareTo(typeMaxValue);
if (cp instanceof EqualTo) {
if (cmpMin < 0 || cmpMax > 0) {
return ExpressionUtils.falseOrNull(left);
}
} else if (cp instanceof NullSafeEqual) {
if (cmpMin < 0 || cmpMax > 0) {
return BooleanLiteral.of(false);
}
} else if (cp instanceof GreaterThan) {
if (cmpMin < 0) {
return ExpressionUtils.trueOrNull(left);
}
if (cmpMax >= 0) {
return ExpressionUtils.falseOrNull(left);
}
} else if (cp instanceof GreaterThanEqual) {
if (cmpMin <= 0) {
return ExpressionUtils.trueOrNull(left);
}
if (cmpMax == 0) {
return new EqualTo(cp.left(), cp.right());
}
if (cmpMax > 0) {
return ExpressionUtils.falseOrNull(left);
}
} else if (cp instanceof LessThan) {
if (cmpMin <= 0) {
return ExpressionUtils.falseOrNull(left);
}
if (cmpMax > 0) {
return ExpressionUtils.trueOrNull(left);
}
} else if (cp instanceof LessThanEqual) {
if (cmpMin < 0) {
return ExpressionUtils.falseOrNull(left);
}
if (cmpMin == 0) {
return new EqualTo(cp.left(), cp.right());
}
if (cmpMax >= 0) {
return ExpressionUtils.trueOrNull(left);
}
}
return cp;
}

private static IntegerLikeLiteral convertDecimalToIntegerLikeLiteral(BigDecimal decimal) {
Preconditions.checkArgument(
decimal.scale() <= 0 && decimal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0,
"decimal literal must have 0 scale and smaller than Long.MAX_VALUE");
Preconditions.checkArgument(decimal.scale() <= 0
&& decimal.compareTo(new BigDecimal(Long.MIN_VALUE)) >= 0
&& decimal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0,
"decimal literal must have 0 scale and in range [Long.MIN_VALUE, Long.MAX_VALUE]");
long val = decimal.longValue();
if (val >= Byte.MIN_VALUE && val <= Byte.MAX_VALUE) {
return new TinyIntLiteral((byte) val);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,22 @@ public static Expression or(Collection<Expression> expressions) {
}
}

public static Expression falseOrNull(Expression expression) {
if (expression.nullable()) {
return new And(new IsNull(expression), new NullLiteral(BooleanType.INSTANCE));
} else {
return BooleanLiteral.FALSE;
}
}

public static Expression trueOrNull(Expression expression) {
if (expression.nullable()) {
return new Or(new Not(new IsNull(expression)), new NullLiteral(BooleanType.INSTANCE));
} else {
return BooleanLiteral.TRUE;
}
}

/**
* Use AND/OR to combine expressions together.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.doris.catalog.ScalarType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Config;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Add;
Expand Down Expand Up @@ -116,6 +117,7 @@
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

Expand Down Expand Up @@ -1773,6 +1775,48 @@ private static Expression processDecimalV3BinaryArithmetic(BinaryArithmetic bina
castIfNotSameType(right, dt2));
}

/**
* get min and max value of a data type
*
* @param dataType specific data type
* @return min and max values pair
*/
public static Optional<Pair<BigDecimal, BigDecimal>> getDataTypeMinMaxValue(DataType dataType) {
if (dataType.isTinyIntType()) {
return Optional.of(Pair.of(new BigDecimal(Byte.MIN_VALUE), new BigDecimal(Byte.MAX_VALUE)));
} else if (dataType.isSmallIntType()) {
return Optional.of(Pair.of(new BigDecimal(Short.MIN_VALUE), new BigDecimal(Short.MAX_VALUE)));
} else if (dataType.isIntegerType()) {
return Optional.of(Pair.of(new BigDecimal(Integer.MIN_VALUE), new BigDecimal(Integer.MAX_VALUE)));
} else if (dataType.isBigIntType()) {
return Optional.of(Pair.of(new BigDecimal(Long.MIN_VALUE), new BigDecimal(Long.MAX_VALUE)));
} else if (dataType.isLargeIntType()) {
return Optional.of(Pair.of(new BigDecimal(LargeIntType.MIN_VALUE), new BigDecimal(LargeIntType.MAX_VALUE)));
} else if (dataType.isFloatType()) {
return Optional.of(Pair.of(BigDecimal.valueOf(-Float.MAX_VALUE), new BigDecimal(Float.MAX_VALUE)));
} else if (dataType.isDoubleType()) {
return Optional.of(Pair.of(BigDecimal.valueOf(-Double.MAX_VALUE), new BigDecimal(Double.MAX_VALUE)));
} else if (dataType.isDecimalV3Type()) {
DecimalV3Type type = (DecimalV3Type) dataType;
int precision = type.getPrecision();
int scale = type.getScale();
if (scale >= 0) {
StringBuilder sb = new StringBuilder();
sb.append(StringUtils.repeat('9', precision - scale));
if (sb.length() == 0) {
sb.append('0');
}
if (scale > 0) {
sb.append('.');
sb.append(StringUtils.repeat('9', scale));
}
return Optional.of(Pair.of(new BigDecimal("-" + sb.toString()), new BigDecimal(sb.toString())));
}
}

return Optional.empty();
}

private static boolean supportCompare(DataType dataType) {
if (dataType.isArrayType()) {
return true;
Expand Down
Loading
Loading