Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.types.coercion.DateLikeType;
import org.apache.doris.nereids.types.coercion.IntegralType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;

Expand All @@ -80,6 +81,11 @@
public class SimplifyComparisonPredicate implements ExpressionPatternRuleFactory {
public static SimplifyComparisonPredicate INSTANCE = new SimplifyComparisonPredicate();

public static final int MAX_INT_TO_FLOAT_NO_LOSS = 1 << 24;
public static final int MIN_INT_TO_FLOAT_NO_LOSS = -MAX_INT_TO_FLOAT_NO_LOSS;
public static final long MAX_LONG_TO_DOUBLE_NO_LOSS = 1L << 53;
public static final long MIN_LONG_TO_DOUBLE_NO_LOSS = -MAX_LONG_TO_DOUBLE_NO_LOSS;

@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
Expand Down Expand Up @@ -400,19 +406,25 @@ private static Expression processDecimalV3TypeCoercion(ComparisonPredicate compa
if (left instanceof Cast && right instanceof DecimalV3Literal) {
Cast cast = (Cast) left;
left = cast.child();
DecimalV3Type castDataType = (DecimalV3Type) cast.getDataType();
DecimalV3Literal literal = (DecimalV3Literal) right;
if (left.getDataType().isDecimalV3Type()) {
DecimalV3Type leftType = (DecimalV3Type) left.getDataType();
if (castDataType.getRange() < leftType.getRange()
|| (castDataType.getRange() == leftType.getRange()
&& castDataType.getScale() < leftType.getScale())) {
// for cast(col as decimal(m2, n2)) cmp literal,
// if cast-to can not hold col's integer part, the cast result maybe null, don't process it.
return comparisonPredicate;
}
Optional<Expression> toSmallerDecimalDataTypeExpr = convertDecimalToSmallerDecimalV3Type(
comparisonPredicate, cast, literal);
if (toSmallerDecimalDataTypeExpr.isPresent()) {
return toSmallerDecimalDataTypeExpr.get();
}

DecimalV3Type leftType = (DecimalV3Type) left.getDataType();
DecimalV3Type literalType = (DecimalV3Type) literal.getDataType();
if (cast.getDataType().isDecimalV3Type()
&& ((DecimalV3Type) cast.getDataType()).getScale() >= leftType.getScale()
&& leftType.getScale() < literalType.getScale()) {
if (castDataType.getScale() >= leftType.getScale() && leftType.getScale() < literalType.getScale()) {
int toScale = ((DecimalV3Type) left.getDataType()).getScale();
if (comparisonPredicate instanceof EqualTo) {
try {
Expand Down Expand Up @@ -457,9 +469,17 @@ private static Expression processDecimalV3TypeCoercion(ComparisonPredicate compa
private static Expression processIntegerDecimalLiteralComparison(
ComparisonPredicate comparisonPredicate, Expression left, BigDecimal literal) {
// we only process isIntegerLikeType, which are tinyint, smallint, int, bigint
// for `cast(c_int as decimal(m, n)) cmp literal`,
// if c_int's range is wider than decimal's range, cast result maybe null, don't process it.
DataType castDataType = comparisonPredicate.left().getDataType();
if (castDataType.isDecimalV3Type()
&& ((DecimalV3Type) castDataType).getRange() < ((IntegralType) left.getDataType()).range()) {
return comparisonPredicate;
}
if (literal.compareTo(new BigDecimal(Long.MIN_VALUE)) >= 0
&& literal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0) {
literal = literal.stripTrailingZeros();
Optional<BigDecimal> roundLiteralOpt = Optional.empty();
if (literal.scale() > 0) {
if (comparisonPredicate instanceof EqualTo) {
// TODO: the ideal way is to return an If expr like:
Expand All @@ -473,21 +493,22 @@ private static Expression processIntegerDecimalLiteralComparison(
return BooleanLiteral.of(false);
} else if (comparisonPredicate instanceof GreaterThan
|| comparisonPredicate instanceof LessThanEqual) {
return TypeCoercionUtils
.processComparisonPredicate((ComparisonPredicate) comparisonPredicate
.withChildren(left, convertDecimalToIntegerLikeLiteral(
literal.setScale(0, RoundingMode.FLOOR))));
roundLiteralOpt = Optional.of(literal.setScale(0, RoundingMode.FLOOR));
} else if (comparisonPredicate instanceof LessThan
|| comparisonPredicate instanceof GreaterThanEqual) {
roundLiteralOpt = Optional.of(literal.setScale(0, RoundingMode.CEILING));
}
} else {
roundLiteralOpt = Optional.of(literal);
}
if (roundLiteralOpt.isPresent()) {
Optional<IntegerLikeLiteral> integerLikeLiteralOpt
= convertDecimalToIntegerLikeLiteral(roundLiteralOpt.get(), castDataType);
if (integerLikeLiteralOpt.isPresent()) {
return TypeCoercionUtils
.processComparisonPredicate((ComparisonPredicate) comparisonPredicate
.withChildren(left, convertDecimalToIntegerLikeLiteral(
literal.setScale(0, RoundingMode.CEILING))));
.withChildren(left, integerLikeLiteralOpt.get()));
}
} else {
return TypeCoercionUtils
.processComparisonPredicate((ComparisonPredicate) comparisonPredicate
.withChildren(left, convertDecimalToIntegerLikeLiteral(literal)));
}
}
return comparisonPredicate;
Expand Down Expand Up @@ -615,20 +636,39 @@ private static Optional<Expression> convertDecimalToSmallerDecimalV3Type(Compari
return Optional.empty();
}

private static IntegerLikeLiteral convertDecimalToIntegerLikeLiteral(BigDecimal decimal) {
private static Optional<IntegerLikeLiteral> convertDecimalToIntegerLikeLiteral(BigDecimal decimal,
DataType castDataType) {
Preconditions.checkArgument(decimal.scale() <= 0
&& decimal.compareTo(new BigDecimal(Long.MIN_VALUE)) >= 0
&& decimal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0,
"decimal literal must have 0 scale and in range [Long.MIN_VALUE, Long.MAX_VALUE]");
long val = decimal.longValue();
// for integer like convert to float, only [-2^24, 2^24] can convert to float without loss of precision,
// but here need to exclude the boundary value, because
// cast(2^24 as float) = cast(2^24 + 1 as float) = 2^24 = MAX_INT_TO_FLOAT_NO_LOSS,
// so for cast(c_int as float) = 2^24, we can't simplify it to c_int = 2^24,
// c_int can be 2^24 + 1. The same for -2^24
if (castDataType.isFloatType()
&& (val <= MIN_INT_TO_FLOAT_NO_LOSS || val >= MAX_INT_TO_FLOAT_NO_LOSS)) {
return Optional.empty();
}
// for long convert to double, only [-2^53, 2^53] can convert to double without loss of precision,
// but here need to exclude the boundary value, because
// cast(2^53 as double) = cast(2^53 + 1 as double) = 2^53 = MAX_LONG_TO_DOUBLE_NO_LOSS,
// so for cast(c_bigint as double) = 2^53, we can't simplify it to c_bigint = 2^53,
// c_bigint can be 2^53 + 1. The same for -2^53
if (castDataType.isDoubleType()
&& (val <= MIN_LONG_TO_DOUBLE_NO_LOSS || val >= MAX_LONG_TO_DOUBLE_NO_LOSS)) {
return Optional.empty();
}
if (val >= Byte.MIN_VALUE && val <= Byte.MAX_VALUE) {
return new TinyIntLiteral((byte) val);
return Optional.of(new TinyIntLiteral((byte) val));
} else if (val >= Short.MIN_VALUE && val <= Short.MAX_VALUE) {
return new SmallIntLiteral((short) val);
return Optional.of(new SmallIntLiteral((short) val));
} else if (val >= Integer.MIN_VALUE && val <= Integer.MAX_VALUE) {
return new IntegerLiteral((int) val);
return Optional.of(new IntegerLiteral((int) val));
} else {
return new BigIntLiteral(val);
return Optional.of(new BigIntLiteral(val));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,12 @@ public class Cast extends Expression implements UnaryExpression, Monotonic {

private final DataType targetType;

public Cast(Expression child, DataType targetType, boolean isExplicitType) {
super(ImmutableList.of(child));
this.targetType = Objects.requireNonNull(targetType, "targetType can not be null");
this.isExplicitType = isExplicitType;
public Cast(Expression child, DataType targetType) {
this(child, targetType, false);
}

public Cast(Expression child, DataType targetType) {
super(ImmutableList.of(child));
this.targetType = Objects.requireNonNull(targetType, "targetType can not be null");
this.isExplicitType = false;
public Cast(Expression child, DataType targetType, boolean isExplicitType) {
this(ImmutableList.of(child), targetType, isExplicitType);
}

private Cast(List<Expression> child, DataType targetType, boolean isExplicitType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,9 @@ public DataType defaultConcreteType() {
public int width() {
return WIDTH;
}

@Override
public int range() {
return RANGE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,9 @@ public DataType defaultConcreteType() {
public int width() {
return WIDTH;
}

@Override
public int range() {
return RANGE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,9 @@ public DataType defaultConcreteType() {
public int width() {
return WIDTH;
}

@Override
public int range() {
return RANGE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,9 @@ public DataType defaultConcreteType() {
public int width() {
return WIDTH;
}

@Override
public int range() {
return RANGE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,9 @@ public DataType defaultConcreteType() {
public int width() {
return WIDTH;
}

@Override
public int range() {
return RANGE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;

import org.apache.commons.lang3.NotImplementedException;

/**
* Abstract class for all integral data type in Nereids.
*/
Expand All @@ -45,4 +47,9 @@ public String simpleString() {
public boolean widerThan(IntegralType other) {
return this.width() > other.width();
}

// The maximum number of digits that Integer can represent.
public int range() {
throw new NotImplementedException("should be implemented by derived class");
}
}
Loading
Loading