Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.coercion.DateLikeType;
import org.apache.doris.nereids.types.coercion.IntegralType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;

Expand All @@ -61,6 +62,7 @@
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;
import java.util.Optional;

/**
* simplify comparison
Expand All @@ -70,6 +72,11 @@
public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule implements ExpressionPatternRuleFactory {
public static SimplifyComparisonPredicate INSTANCE = new SimplifyComparisonPredicate();

public static final int MAX_INT_TO_FLOAT_NO_LOSS = 1 << 24;
public static final int MIN_INT_TO_FLOAT_NO_LOSS = -MAX_INT_TO_FLOAT_NO_LOSS;
public static final long MAX_LONG_TO_DOUBLE_NO_LOSS = 1L << 53;
public static final long MIN_LONG_TO_DOUBLE_NO_LOSS = -MAX_LONG_TO_DOUBLE_NO_LOSS;

@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
Expand Down Expand Up @@ -236,9 +243,17 @@ private static Expression processDecimalV3TypeCoercion(ComparisonPredicate compa
if (left instanceof Cast && right instanceof DecimalV3Literal) {
Cast cast = (Cast) left;
left = cast.child();
DecimalV3Type castDataType = (DecimalV3Type) cast.getDataType();
DecimalV3Literal literal = (DecimalV3Literal) right;
if (left.getDataType().isDecimalV3Type()) {
DecimalV3Type leftType = (DecimalV3Type) left.getDataType();
if (castDataType.getRange() < leftType.getRange()
|| (castDataType.getRange() == leftType.getRange()
&& castDataType.getScale() < leftType.getScale())) {
// for cast(col as decimal(m2, n2)) cmp literal,
// if cast-to can not hold col's integer part, the cast result maybe null, don't process it.
return comparisonPredicate;
}
DecimalV3Type literalType = (DecimalV3Type) literal.getDataType();
if (leftType.getScale() < literalType.getScale()) {
int toScale = ((DecimalV3Type) left.getDataType()).getScale();
Expand Down Expand Up @@ -285,9 +300,17 @@ private static Expression processDecimalV3TypeCoercion(ComparisonPredicate compa
private static Expression processIntegerDecimalLiteralComparison(
ComparisonPredicate comparisonPredicate, Expression left, BigDecimal literal) {
// we only process isIntegerLikeType, which are tinyint, smallint, int, bigint
// for `cast(c_int as decimal(m, n)) cmp literal`,
// if c_int's range is wider than decimal's range, cast result maybe null, don't process it.
DataType castDataType = comparisonPredicate.left().getDataType();
if (castDataType.isDecimalV3Type()
&& ((DecimalV3Type) castDataType).getRange() < ((IntegralType) left.getDataType()).range()) {
return comparisonPredicate;
}
if (literal.compareTo(new BigDecimal(Long.MIN_VALUE)) >= 0
&& literal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0) {
literal = literal.stripTrailingZeros();
Optional<BigDecimal> roundLiteralOpt = Optional.empty();
if (literal.scale() > 0) {
if (comparisonPredicate instanceof EqualTo) {
// TODO: the ideal way is to return an If expr like:
Expand All @@ -301,40 +324,60 @@ private static Expression processIntegerDecimalLiteralComparison(
return BooleanLiteral.of(false);
} else if (comparisonPredicate instanceof GreaterThan
|| comparisonPredicate instanceof LessThanEqual) {
return TypeCoercionUtils
.processComparisonPredicate((ComparisonPredicate) comparisonPredicate
.withChildren(left, convertDecimalToIntegerLikeLiteral(
literal.setScale(0, RoundingMode.FLOOR))));
roundLiteralOpt = Optional.of(literal.setScale(0, RoundingMode.FLOOR));
} else if (comparisonPredicate instanceof LessThan
|| comparisonPredicate instanceof GreaterThanEqual) {
roundLiteralOpt = Optional.of(literal.setScale(0, RoundingMode.CEILING));
}
} else {
roundLiteralOpt = Optional.of(literal);
}
if (roundLiteralOpt.isPresent()) {
Optional<IntegerLikeLiteral> integerLikeLiteralOpt
= convertDecimalToIntegerLikeLiteral(roundLiteralOpt.get(), castDataType);
if (integerLikeLiteralOpt.isPresent()) {
return TypeCoercionUtils
.processComparisonPredicate((ComparisonPredicate) comparisonPredicate
.withChildren(left, convertDecimalToIntegerLikeLiteral(
literal.setScale(0, RoundingMode.CEILING))));
.withChildren(left, integerLikeLiteralOpt.get()));
}
} else {
return TypeCoercionUtils
.processComparisonPredicate((ComparisonPredicate) comparisonPredicate
.withChildren(left, convertDecimalToIntegerLikeLiteral(literal)));
}
}
return comparisonPredicate;
}

private static IntegerLikeLiteral convertDecimalToIntegerLikeLiteral(BigDecimal decimal) {
private static Optional<IntegerLikeLiteral> convertDecimalToIntegerLikeLiteral(BigDecimal decimal,
DataType castDataType) {
Preconditions.checkArgument(decimal.scale() <= 0
&& decimal.compareTo(new BigDecimal(Long.MIN_VALUE)) >= 0
&& decimal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0,
"decimal literal must have 0 scale and in range [Long.MIN_VALUE, Long.MAX_VALUE]");
long val = decimal.longValue();
// for integer like convert to float, only [-2^24, 2^24] can convert to float without loss of precision,
// but here need to exclude the boundary value, because
// cast(2^24 as float) = cast(2^24 + 1 as float) = 2^24 = MAX_INT_TO_FLOAT_NO_LOSS,
// so for cast(c_int as float) = 2^24, we can't simplify it to c_int = 2^24,
// c_int can be 2^24 + 1. The same for -2^24
if (castDataType.isFloatType()
&& (val <= MIN_INT_TO_FLOAT_NO_LOSS || val >= MAX_INT_TO_FLOAT_NO_LOSS)) {
return Optional.empty();
}
// for long convert to double, only [-2^53, 2^53] can convert to double without loss of precision,
// but here need to exclude the boundary value, because
// cast(2^53 as double) = cast(2^53 + 1 as double) = 2^53 = MAX_LONG_TO_DOUBLE_NO_LOSS,
// so for cast(c_bigint as double) = 2^53, we can't simplify it to c_bigint = 2^53,
// c_bigint can be 2^53 + 1. The same for -2^53
if (castDataType.isDoubleType()
&& (val <= MIN_LONG_TO_DOUBLE_NO_LOSS || val >= MAX_LONG_TO_DOUBLE_NO_LOSS)) {
return Optional.empty();
}
if (val >= Byte.MIN_VALUE && val <= Byte.MAX_VALUE) {
return new TinyIntLiteral((byte) val);
return Optional.of(new TinyIntLiteral((byte) val));
} else if (val >= Short.MIN_VALUE && val <= Short.MAX_VALUE) {
return new SmallIntLiteral((short) val);
return Optional.of(new SmallIntLiteral((short) val));
} else if (val >= Integer.MIN_VALUE && val <= Integer.MAX_VALUE) {
return new IntegerLiteral((int) val);
return Optional.of(new IntegerLiteral((int) val));
} else {
return new BigIntLiteral(val);
return Optional.of(new BigIntLiteral(val));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,16 @@ public class SimplifyDecimalV3Comparison implements ExpressionPatternRuleFactory
@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
matchesType(ComparisonPredicate.class).then(SimplifyDecimalV3Comparison::simplify)
matchesType(ComparisonPredicate.class).then(this::simplify)
);
}

/** simplify */
public static Expression simplify(ComparisonPredicate cp) {
public Expression simplify(ComparisonPredicate cp) {
Expression left = cp.left();
Expression right = cp.right();

if (left.getDataType() instanceof DecimalV3Type
&& left instanceof Cast
&& ((Cast) left).child().getDataType() instanceof DecimalV3Type
&& ((DecimalV3Type) left.getDataType()).getScale()
>= ((DecimalV3Type) ((Cast) left).child().getDataType()).getScale()
&& right instanceof DecimalV3Literal) {
if (canProcess(left, right)) {
try {
return doProcess(cp, (Cast) left, (DecimalV3Literal) right);
} catch (ArithmeticException e) {
Expand All @@ -71,7 +66,21 @@ public static Expression simplify(ComparisonPredicate cp) {
return cp;
}

private static Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal right) {
private boolean canProcess(Expression left, Expression right) {
if (!(left.getDataType() instanceof DecimalV3Type
&& left instanceof Cast
&& left.child(0).getDataType() instanceof DecimalV3Type
&& right instanceof DecimalV3Literal)) {
return false;
}

DecimalV3Type castType = (DecimalV3Type) left.getDataType();
DecimalV3Type castChildType = (DecimalV3Type) left.child(0).getDataType();

return castType.getRange() >= castChildType.getRange() && castType.getScale() >= castChildType.getScale();
}

private Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal right) {
BigDecimal trailingZerosValue = right.getValue().stripTrailingZeros();
int scale = org.apache.doris.analysis.DecimalLiteral.getBigDecimalScale(trailingZerosValue);
int precision = org.apache.doris.analysis.DecimalLiteral.getBigDecimalPrecision(trailingZerosValue);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,12 @@ public class Cast extends Expression implements UnaryExpression {

private final DataType targetType;

public Cast(Expression child, DataType targetType, boolean isExplicitType) {
super(ImmutableList.of(child));
this.targetType = Objects.requireNonNull(targetType, "targetType can not be null");
this.isExplicitType = isExplicitType;
public Cast(Expression child, DataType targetType) {
this(child, targetType, false);
}

public Cast(Expression child, DataType targetType) {
super(ImmutableList.of(child));
this.targetType = Objects.requireNonNull(targetType, "targetType can not be null");
this.isExplicitType = false;
public Cast(Expression child, DataType targetType, boolean isExplicitType) {
this(ImmutableList.of(child), targetType, isExplicitType);
}

private Cast(List<Expression> child, DataType targetType, boolean isExplicitType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public class BigIntType extends IntegralType implements Int64OrLessType {
public static final BigIntType INSTANCE = new BigIntType("bigint");
public static final BigIntType SIGNED = new BigIntType("signed");

public static final int RANGE = 19; // The maximum number of digits that BigInt can represent.
private static final int WIDTH = 8;

private final String simpleName;
Expand Down Expand Up @@ -61,4 +62,9 @@ public DataType defaultConcreteType() {
public int width() {
return WIDTH;
}

@Override
public int range() {
return RANGE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
public class IntegerType extends IntegralType implements Int32OrLessType {
public static final IntegerType INSTANCE = new IntegerType();

public static final int RANGE = 10; // The maximum number of digits that Integer can represent.
private static final int WIDTH = 4;

private IntegerType() {
Expand Down Expand Up @@ -61,4 +62,9 @@ public DataType defaultConcreteType() {
public int width() {
return WIDTH;
}

@Override
public int range() {
return RANGE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public class LargeIntType extends IntegralType {

public static final BigInteger MIN_VALUE = new BigInteger("-170141183460469231731687303715884105728");

public static final int RANGE = 39; // The maximum number of digits that LargeInteger can represent.
private static final int WIDTH = 16;

private final String simpleName;
Expand Down Expand Up @@ -71,4 +72,9 @@ public DataType defaultConcreteType() {
public int width() {
return WIDTH;
}

@Override
public int range() {
return RANGE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
public class SmallIntType extends IntegralType implements Int16OrLessType {
public static final SmallIntType INSTANCE = new SmallIntType();

public static final int RANGE = 5; // The maximum number of digits that SmallInteger can represent.
private static final int WIDTH = 2;

private SmallIntType() {
Expand Down Expand Up @@ -61,4 +62,9 @@ public DataType defaultConcreteType() {
public int width() {
return WIDTH;
}

@Override
public int range() {
return RANGE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
public class TinyIntType extends IntegralType implements Int16OrLessType {
public static final TinyIntType INSTANCE = new TinyIntType();

public static final int RANGE = 3; // The maximum number of digits that TinyIntType can represent.
private static final int WIDTH = 1;

private TinyIntType() {
Expand Down Expand Up @@ -61,4 +62,9 @@ public DataType defaultConcreteType() {
public int width() {
return WIDTH;
}

@Override
public int range() {
return RANGE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;

import org.apache.commons.lang3.NotImplementedException;

/**
* Abstract class for all integral data type in Nereids.
*/
Expand All @@ -45,4 +47,9 @@ public String simpleString() {
public boolean widerThan(IntegralType other) {
return this.width() > other.width();
}

// The maximum number of digits that Integer can represent.
public int range() {
throw new NotImplementedException("should be implemented by derived class");
}
}
Loading