diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java index 5e6968b9142333..da1f245ab1ab68 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java @@ -52,6 +52,7 @@ import org.apache.doris.nereids.types.DateV2Type; import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.coercion.DateLikeType; +import org.apache.doris.nereids.types.coercion.IntegralType; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.TypeCoercionUtils; @@ -61,6 +62,7 @@ import java.math.BigDecimal; import java.math.RoundingMode; import java.util.List; +import java.util.Optional; /** * simplify comparison @@ -70,6 +72,11 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule implements ExpressionPatternRuleFactory { public static SimplifyComparisonPredicate INSTANCE = new SimplifyComparisonPredicate(); + public static final int MAX_INT_TO_FLOAT_NO_LOSS = 1 << 24; + public static final int MIN_INT_TO_FLOAT_NO_LOSS = -MAX_INT_TO_FLOAT_NO_LOSS; + public static final long MAX_LONG_TO_DOUBLE_NO_LOSS = 1L << 53; + public static final long MIN_LONG_TO_DOUBLE_NO_LOSS = -MAX_LONG_TO_DOUBLE_NO_LOSS; + @Override public List> buildRules() { return ImmutableList.of( @@ -236,9 +243,17 @@ private static Expression processDecimalV3TypeCoercion(ComparisonPredicate compa if (left instanceof Cast && right instanceof DecimalV3Literal) { Cast cast = (Cast) left; left = cast.child(); + DecimalV3Type castDataType = (DecimalV3Type) cast.getDataType(); DecimalV3Literal literal = (DecimalV3Literal) right; if (left.getDataType().isDecimalV3Type()) { DecimalV3Type leftType = (DecimalV3Type) left.getDataType(); + if (castDataType.getRange() < leftType.getRange() + || (castDataType.getRange() == leftType.getRange() + && castDataType.getScale() < leftType.getScale())) { + // for cast(col as decimal(m2, n2)) cmp literal, + // if cast-to can not hold col's integer part, the cast result maybe null, don't process it. + return comparisonPredicate; + } DecimalV3Type literalType = (DecimalV3Type) literal.getDataType(); if (leftType.getScale() < literalType.getScale()) { int toScale = ((DecimalV3Type) left.getDataType()).getScale(); @@ -285,9 +300,17 @@ private static Expression processDecimalV3TypeCoercion(ComparisonPredicate compa private static Expression processIntegerDecimalLiteralComparison( ComparisonPredicate comparisonPredicate, Expression left, BigDecimal literal) { // we only process isIntegerLikeType, which are tinyint, smallint, int, bigint + // for `cast(c_int as decimal(m, n)) cmp literal`, + // if c_int's range is wider than decimal's range, cast result maybe null, don't process it. + DataType castDataType = comparisonPredicate.left().getDataType(); + if (castDataType.isDecimalV3Type() + && ((DecimalV3Type) castDataType).getRange() < ((IntegralType) left.getDataType()).range()) { + return comparisonPredicate; + } if (literal.compareTo(new BigDecimal(Long.MIN_VALUE)) >= 0 && literal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0) { literal = literal.stripTrailingZeros(); + Optional roundLiteralOpt = Optional.empty(); if (literal.scale() > 0) { if (comparisonPredicate instanceof EqualTo) { // TODO: the ideal way is to return an If expr like: @@ -301,40 +324,60 @@ private static Expression processIntegerDecimalLiteralComparison( return BooleanLiteral.of(false); } else if (comparisonPredicate instanceof GreaterThan || comparisonPredicate instanceof LessThanEqual) { - return TypeCoercionUtils - .processComparisonPredicate((ComparisonPredicate) comparisonPredicate - .withChildren(left, convertDecimalToIntegerLikeLiteral( - literal.setScale(0, RoundingMode.FLOOR)))); + roundLiteralOpt = Optional.of(literal.setScale(0, RoundingMode.FLOOR)); } else if (comparisonPredicate instanceof LessThan || comparisonPredicate instanceof GreaterThanEqual) { + roundLiteralOpt = Optional.of(literal.setScale(0, RoundingMode.CEILING)); + } + } else { + roundLiteralOpt = Optional.of(literal); + } + if (roundLiteralOpt.isPresent()) { + Optional integerLikeLiteralOpt + = convertDecimalToIntegerLikeLiteral(roundLiteralOpt.get(), castDataType); + if (integerLikeLiteralOpt.isPresent()) { return TypeCoercionUtils .processComparisonPredicate((ComparisonPredicate) comparisonPredicate - .withChildren(left, convertDecimalToIntegerLikeLiteral( - literal.setScale(0, RoundingMode.CEILING)))); + .withChildren(left, integerLikeLiteralOpt.get())); } - } else { - return TypeCoercionUtils - .processComparisonPredicate((ComparisonPredicate) comparisonPredicate - .withChildren(left, convertDecimalToIntegerLikeLiteral(literal))); } } return comparisonPredicate; } - private static IntegerLikeLiteral convertDecimalToIntegerLikeLiteral(BigDecimal decimal) { + private static Optional convertDecimalToIntegerLikeLiteral(BigDecimal decimal, + DataType castDataType) { Preconditions.checkArgument(decimal.scale() <= 0 && decimal.compareTo(new BigDecimal(Long.MIN_VALUE)) >= 0 && decimal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0, "decimal literal must have 0 scale and in range [Long.MIN_VALUE, Long.MAX_VALUE]"); long val = decimal.longValue(); + // for integer like convert to float, only [-2^24, 2^24] can convert to float without loss of precision, + // but here need to exclude the boundary value, because + // cast(2^24 as float) = cast(2^24 + 1 as float) = 2^24 = MAX_INT_TO_FLOAT_NO_LOSS, + // so for cast(c_int as float) = 2^24, we can't simplify it to c_int = 2^24, + // c_int can be 2^24 + 1. The same for -2^24 + if (castDataType.isFloatType() + && (val <= MIN_INT_TO_FLOAT_NO_LOSS || val >= MAX_INT_TO_FLOAT_NO_LOSS)) { + return Optional.empty(); + } + // for long convert to double, only [-2^53, 2^53] can convert to double without loss of precision, + // but here need to exclude the boundary value, because + // cast(2^53 as double) = cast(2^53 + 1 as double) = 2^53 = MAX_LONG_TO_DOUBLE_NO_LOSS, + // so for cast(c_bigint as double) = 2^53, we can't simplify it to c_bigint = 2^53, + // c_bigint can be 2^53 + 1. The same for -2^53 + if (castDataType.isDoubleType() + && (val <= MIN_LONG_TO_DOUBLE_NO_LOSS || val >= MAX_LONG_TO_DOUBLE_NO_LOSS)) { + return Optional.empty(); + } if (val >= Byte.MIN_VALUE && val <= Byte.MAX_VALUE) { - return new TinyIntLiteral((byte) val); + return Optional.of(new TinyIntLiteral((byte) val)); } else if (val >= Short.MIN_VALUE && val <= Short.MAX_VALUE) { - return new SmallIntLiteral((short) val); + return Optional.of(new SmallIntLiteral((short) val)); } else if (val >= Integer.MIN_VALUE && val <= Integer.MAX_VALUE) { - return new IntegerLiteral((int) val); + return Optional.of(new IntegerLiteral((int) val)); } else { - return new BigIntLiteral(val); + return Optional.of(new BigIntLiteral(val)); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java index f6a91f1dc0c522..23c83f81e1310d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java @@ -46,21 +46,16 @@ public class SimplifyDecimalV3Comparison implements ExpressionPatternRuleFactory @Override public List> buildRules() { return ImmutableList.of( - matchesType(ComparisonPredicate.class).then(SimplifyDecimalV3Comparison::simplify) + matchesType(ComparisonPredicate.class).then(this::simplify) ); } /** simplify */ - public static Expression simplify(ComparisonPredicate cp) { + public Expression simplify(ComparisonPredicate cp) { Expression left = cp.left(); Expression right = cp.right(); - if (left.getDataType() instanceof DecimalV3Type - && left instanceof Cast - && ((Cast) left).child().getDataType() instanceof DecimalV3Type - && ((DecimalV3Type) left.getDataType()).getScale() - >= ((DecimalV3Type) ((Cast) left).child().getDataType()).getScale() - && right instanceof DecimalV3Literal) { + if (canProcess(left, right)) { try { return doProcess(cp, (Cast) left, (DecimalV3Literal) right); } catch (ArithmeticException e) { @@ -71,7 +66,21 @@ public static Expression simplify(ComparisonPredicate cp) { return cp; } - private static Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal right) { + private boolean canProcess(Expression left, Expression right) { + if (!(left.getDataType() instanceof DecimalV3Type + && left instanceof Cast + && left.child(0).getDataType() instanceof DecimalV3Type + && right instanceof DecimalV3Literal)) { + return false; + } + + DecimalV3Type castType = (DecimalV3Type) left.getDataType(); + DecimalV3Type castChildType = (DecimalV3Type) left.child(0).getDataType(); + + return castType.getRange() >= castChildType.getRange() && castType.getScale() >= castChildType.getScale(); + } + + private Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal right) { BigDecimal trailingZerosValue = right.getValue().stripTrailingZeros(); int scale = org.apache.doris.analysis.DecimalLiteral.getBigDecimalScale(trailingZerosValue); int precision = org.apache.doris.analysis.DecimalLiteral.getBigDecimalPrecision(trailingZerosValue); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Cast.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Cast.java index c6104777bb5d0e..677be83ccce523 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Cast.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Cast.java @@ -38,16 +38,12 @@ public class Cast extends Expression implements UnaryExpression { private final DataType targetType; - public Cast(Expression child, DataType targetType, boolean isExplicitType) { - super(ImmutableList.of(child)); - this.targetType = Objects.requireNonNull(targetType, "targetType can not be null"); - this.isExplicitType = isExplicitType; + public Cast(Expression child, DataType targetType) { + this(child, targetType, false); } - public Cast(Expression child, DataType targetType) { - super(ImmutableList.of(child)); - this.targetType = Objects.requireNonNull(targetType, "targetType can not be null"); - this.isExplicitType = false; + public Cast(Expression child, DataType targetType, boolean isExplicitType) { + this(ImmutableList.of(child), targetType, isExplicitType); } private Cast(List child, DataType targetType, boolean isExplicitType) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BigIntType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BigIntType.java index 762cc78e09fb7c..f654b9f271ba81 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BigIntType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BigIntType.java @@ -29,6 +29,7 @@ public class BigIntType extends IntegralType implements Int64OrLessType { public static final BigIntType INSTANCE = new BigIntType("bigint"); public static final BigIntType SIGNED = new BigIntType("signed"); + public static final int RANGE = 19; // The maximum number of digits that BigInt can represent. private static final int WIDTH = 8; private final String simpleName; @@ -61,4 +62,9 @@ public DataType defaultConcreteType() { public int width() { return WIDTH; } + + @Override + public int range() { + return RANGE; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/IntegerType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/IntegerType.java index 58838abbc8c2a1..7b47ff90b72270 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/IntegerType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/IntegerType.java @@ -27,6 +27,7 @@ public class IntegerType extends IntegralType implements Int32OrLessType { public static final IntegerType INSTANCE = new IntegerType(); + public static final int RANGE = 10; // The maximum number of digits that Integer can represent. private static final int WIDTH = 4; private IntegerType() { @@ -61,4 +62,9 @@ public DataType defaultConcreteType() { public int width() { return WIDTH; } + + @Override + public int range() { + return RANGE; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/LargeIntType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/LargeIntType.java index 8f619fd9e7cbef..ee6c4c9c0560f1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/LargeIntType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/LargeIntType.java @@ -34,6 +34,7 @@ public class LargeIntType extends IntegralType { public static final BigInteger MIN_VALUE = new BigInteger("-170141183460469231731687303715884105728"); + public static final int RANGE = 39; // The maximum number of digits that LargeInteger can represent. private static final int WIDTH = 16; private final String simpleName; @@ -71,4 +72,9 @@ public DataType defaultConcreteType() { public int width() { return WIDTH; } + + @Override + public int range() { + return RANGE; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/SmallIntType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/SmallIntType.java index 176bf90ddcabc8..ad2743a4a7e19a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/SmallIntType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/SmallIntType.java @@ -27,6 +27,7 @@ public class SmallIntType extends IntegralType implements Int16OrLessType { public static final SmallIntType INSTANCE = new SmallIntType(); + public static final int RANGE = 5; // The maximum number of digits that SmallInteger can represent. private static final int WIDTH = 2; private SmallIntType() { @@ -61,4 +62,9 @@ public DataType defaultConcreteType() { public int width() { return WIDTH; } + + @Override + public int range() { + return RANGE; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TinyIntType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TinyIntType.java index 3c5a351bf02ee2..075cb945db2479 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TinyIntType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TinyIntType.java @@ -27,6 +27,7 @@ public class TinyIntType extends IntegralType implements Int16OrLessType { public static final TinyIntType INSTANCE = new TinyIntType(); + public static final int RANGE = 3; // The maximum number of digits that TinyIntType can represent. private static final int WIDTH = 1; private TinyIntType() { @@ -61,4 +62,9 @@ public DataType defaultConcreteType() { public int width() { return WIDTH; } + + @Override + public int range() { + return RANGE; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java index 656c6f660f2090..b1e588053881eb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java @@ -20,6 +20,8 @@ import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.DataType; +import org.apache.commons.lang3.NotImplementedException; + /** * Abstract class for all integral data type in Nereids. */ @@ -45,4 +47,9 @@ public String simpleString() { public boolean widerThan(IntegralType other) { return this.width() > other.width(); } + + // The maximum number of digits that Integer can represent. + public int range() { + throw new NotImplementedException("should be implemented by derived class"); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java index e0b7302d89fcef..53270848703fb0 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java @@ -425,6 +425,49 @@ void testDoubleLiteral() { new LessThan(bigIntSlot, new BigIntLiteral(13L))); assertRewrite(new LessThanEqual(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), new LessThanEqual(bigIntSlot, new BigIntLiteral(12L))); + + // int and float literal near no loss bound + float noLossBoundF = 16777216.0f; // 2^24 + assertRewrite(new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(-noLossBoundF)), + new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(-noLossBoundF))); + assertRewrite(new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(-16777215.0f)), + new EqualTo(intSlot, new IntegerLiteral(-16777215))); + assertRewrite(new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(16777215.0f)), + new EqualTo(intSlot, new IntegerLiteral(16777215))); + assertRewrite(new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(noLossBoundF)), + new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(noLossBoundF))); + + // big int and double literal near no loss bound + double noLossBoundD = 9007199254740992.0; + assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(-noLossBoundD)), + new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(-noLossBoundD))); + assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(-9007199254740991.0)), + new EqualTo(bigIntSlot, new BigIntLiteral(-9007199254740991L))); + assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(9007199254740991.0)), + new EqualTo(bigIntSlot, new BigIntLiteral(9007199254740991L))); + assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(noLossBoundD)), + new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(noLossBoundD))); + } + + @Test + void testFloatNoLossBound() { + checkIntConvertFloatLikeLossBound(SimplifyComparisonPredicate.MIN_INT_TO_FLOAT_NO_LOSS, true); + checkIntConvertFloatLikeLossBound(SimplifyComparisonPredicate.MAX_INT_TO_FLOAT_NO_LOSS, true); + checkIntConvertFloatLikeLossBound(SimplifyComparisonPredicate.MIN_LONG_TO_DOUBLE_NO_LOSS, false); + checkIntConvertFloatLikeLossBound(SimplifyComparisonPredicate.MAX_LONG_TO_DOUBLE_NO_LOSS, false); + } + + private void checkIntConvertFloatLikeLossBound(long bound, boolean isFloat) { + for (int i = 0; i < 100000; i++) { + long v = (Math.abs(bound) - i) * Long.signum(bound); + long vCast = isFloat ? (long) ((float) v) : (long) ((double) v); + Assertions.assertEquals(v, vCast); + } + + long firstOutBound = (Math.abs(bound) + 1) * Long.signum(bound); + long firstOutBoundCast = isFloat ? (long) ((float) firstOutBound) : (long) ((double) firstOutBound); + Assertions.assertNotEquals(firstOutBound, firstOutBoundCast); + Assertions.assertEquals(bound, firstOutBoundCast); } @Test @@ -439,77 +482,83 @@ void testIntCmpDecimalV3Literal() { Expression bigIntSlot = new SlotReference("a", BigIntType.INSTANCE); // tiny int, literal not exceeds data type limit - assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))), + DecimalV3Type tinyDecimalType = DecimalV3Type.createDecimalV3Type(4, 1); + assertRewrite(new EqualTo(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(tinyDecimalType, new BigDecimal("12.0"))), new EqualTo(tinyIntSlot, new TinyIntLiteral((byte) 12))); - assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new EqualTo(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(tinyDecimalType, new BigDecimal("12.3"))), ExpressionUtils.falseOrNull(tinyIntSlot)); - assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(tinyDecimalType, new BigDecimal("12.3"))), BooleanLiteral.FALSE); - assertRewrite(new GreaterThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThan(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(tinyDecimalType, new BigDecimal("12.3"))), new GreaterThan(tinyIntSlot, new TinyIntLiteral((byte) 12))); - assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(tinyDecimalType, new BigDecimal("12.3"))), new GreaterThanEqual(tinyIntSlot, new TinyIntLiteral((byte) 13))); - assertRewrite(new LessThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThan(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(tinyDecimalType, new BigDecimal("12.3"))), new LessThan(tinyIntSlot, new TinyIntLiteral((byte) 13))); - assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(tinyDecimalType, new BigDecimal("12.3"))), new LessThanEqual(tinyIntSlot, new TinyIntLiteral((byte) 12))); // small int - assertRewrite(new EqualTo(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))), + DecimalV3Type smallDecimalType = DecimalV3Type.createDecimalV3Type(6, 1); + assertRewrite(new EqualTo(new Cast(smallIntSlot, smallDecimalType), new DecimalV3Literal(smallDecimalType, new BigDecimal("12.0"))), new EqualTo(smallIntSlot, new SmallIntLiteral((short) 12))); - assertRewrite(new EqualTo(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new EqualTo(new Cast(smallIntSlot, smallDecimalType), new DecimalV3Literal(smallDecimalType, new BigDecimal("12.3"))), ExpressionUtils.falseOrNull(smallIntSlot)); - assertRewrite(new NullSafeEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new NullSafeEqual(new Cast(smallIntSlot, smallDecimalType), new DecimalV3Literal(smallDecimalType, new BigDecimal("12.3"))), BooleanLiteral.FALSE); - assertRewrite(new GreaterThan(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThan(new Cast(smallIntSlot, smallDecimalType), new DecimalV3Literal(smallDecimalType, new BigDecimal("12.3"))), new GreaterThan(smallIntSlot, new SmallIntLiteral((short) 12))); - assertRewrite(new GreaterThanEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThanEqual(new Cast(smallIntSlot, smallDecimalType), new DecimalV3Literal(smallDecimalType, new BigDecimal("12.3"))), new GreaterThanEqual(smallIntSlot, new SmallIntLiteral((short) 13))); - assertRewrite(new LessThan(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThan(new Cast(smallIntSlot, smallDecimalType), new DecimalV3Literal(smallDecimalType, new BigDecimal("12.3"))), new LessThan(smallIntSlot, new SmallIntLiteral((short) 13))); - assertRewrite(new LessThanEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThanEqual(new Cast(smallIntSlot, smallDecimalType), new DecimalV3Literal(smallDecimalType, new BigDecimal("12.3"))), new LessThanEqual(smallIntSlot, new SmallIntLiteral((short) 12))); // int - assertRewrite(new EqualTo(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))), + DecimalV3Type intDecimalType = DecimalV3Type.createDecimalV3Type(11, 1); + assertRewrite(new EqualTo(new Cast(intSlot, intDecimalType), new DecimalV3Literal(intDecimalType, new BigDecimal("12.0"))), new EqualTo(intSlot, new IntegerLiteral(12))); - assertRewrite(new EqualTo(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new EqualTo(new Cast(intSlot, intDecimalType), new DecimalV3Literal(intDecimalType, new BigDecimal("12.3"))), ExpressionUtils.falseOrNull(intSlot)); - assertRewrite(new NullSafeEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new NullSafeEqual(new Cast(intSlot, intDecimalType), new DecimalV3Literal(intDecimalType, new BigDecimal("12.3"))), BooleanLiteral.FALSE); - assertRewrite(new GreaterThan(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThan(new Cast(intSlot, intDecimalType), new DecimalV3Literal(intDecimalType, new BigDecimal("12.3"))), new GreaterThan(intSlot, new IntegerLiteral(12))); - assertRewrite(new GreaterThanEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThanEqual(new Cast(intSlot, intDecimalType), new DecimalV3Literal(intDecimalType, new BigDecimal("12.3"))), new GreaterThanEqual(intSlot, new IntegerLiteral(13))); - assertRewrite(new LessThan(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThan(new Cast(intSlot, intDecimalType), new DecimalV3Literal(intDecimalType, new BigDecimal("12.3"))), new LessThan(intSlot, new IntegerLiteral(13))); - assertRewrite(new LessThanEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThanEqual(new Cast(intSlot, intDecimalType), new DecimalV3Literal(intDecimalType, new BigDecimal("12.3"))), new LessThanEqual(intSlot, new IntegerLiteral(12))); // big int - assertRewrite(new EqualTo(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))), + DecimalV3Type bigDecimalType = DecimalV3Type.createDecimalV3Type(20, 1); + assertRewrite(new EqualTo(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(bigDecimalType, new BigDecimal("12.0"))), new EqualTo(bigIntSlot, new BigIntLiteral(12L))); - assertRewrite(new EqualTo(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new EqualTo(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(bigDecimalType, new BigDecimal("12.3"))), ExpressionUtils.falseOrNull(bigIntSlot)); - assertRewrite(new NullSafeEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new NullSafeEqual(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(bigDecimalType, new BigDecimal("12.3"))), BooleanLiteral.FALSE); - assertRewrite(new GreaterThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThan(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(bigDecimalType, new BigDecimal("12.3"))), new GreaterThan(bigIntSlot, new BigIntLiteral(12L))); - assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(bigDecimalType, new BigDecimal("12.3"))), new GreaterThanEqual(bigIntSlot, new BigIntLiteral(13L))); - assertRewrite(new LessThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThan(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(bigDecimalType, new BigDecimal("12.3"))), new LessThan(bigIntSlot, new BigIntLiteral(13L))); - assertRewrite(new LessThanEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThanEqual(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(bigDecimalType, new BigDecimal("12.3"))), new LessThanEqual(bigIntSlot, new BigIntLiteral(12L))); - assertRewrite(new LessThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(20, 1)), new DecimalV3Literal(new BigDecimal("-9223372036854775808.1"))), - new LessThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(20, 1)), new DecimalV3Literal(new BigDecimal("-9223372036854775808.1")))); - assertRewrite(new LessThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(20, 1)), new DecimalV3Literal(new BigDecimal("-9223372036854775807.1"))), - new LessThan(bigIntSlot, new BigIntLiteral(-9223372036854775807L))); - assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(20, 1)), new DecimalV3Literal(new BigDecimal("9223372036854775807.1"))), - new GreaterThanEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(20, 1)), new DecimalV3Literal(new BigDecimal("9223372036854775807.1")))); - assertRewrite(new LessThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(20, 1)), new DecimalV3Literal(new BigDecimal("9223372036854775806.1"))), - new LessThan(bigIntSlot, new BigIntLiteral(9223372036854775807L))); + // do not convert cast(int as decimal) if decimal's range < int's range + DecimalV3Type noConvertDecimalType = DecimalV3Type.createDecimalV3Type(3, 1); + assertRewrite(new EqualTo(new Cast(tinyIntSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0"))), + new EqualTo(new Cast(tinyIntSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0")))); + assertRewrite(new EqualTo(new Cast(smallIntSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0"))), + new EqualTo(new Cast(smallIntSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0")))); + assertRewrite(new EqualTo(new Cast(intSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0"))), + new EqualTo(new Cast(intSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0")))); + assertRewrite(new EqualTo(new Cast(bigIntSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0"))), + new EqualTo(new Cast(bigIntSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0")))); } @Test @@ -640,5 +689,13 @@ void testDecimalCmpDecimalV3Literal() { rewrittenExpression.child(0).getDataType()); Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(1)); Assertions.assertEquals(new BigDecimal("12345.12"), ((DecimalV3Literal) rewrittenExpression.child(1)).getValue()); + + // don't convert for cast(b as decimal) if b's range > decimal's range + // or b's range = decimal's range and b's scale < decimal's scale + SlotReference decimalSlot = new SlotReference("slot", DecimalV3Type.createDecimalV3Type(5, 2), true); + assertRewrite(new EqualTo(new Cast(decimalSlot, DecimalV3Type.createDecimalV3Type(7, 5)), new DecimalV3Literal(new BigDecimal("12.34567"))), + new EqualTo(new Cast(decimalSlot, DecimalV3Type.createDecimalV3Type(7, 5)), new DecimalV3Literal(new BigDecimal("12.34567")))); + assertRewrite(new EqualTo(new Cast(decimalSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("123.4"))), + new EqualTo(new Cast(decimalSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("123.4")))); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3ComparisonTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3ComparisonTest.java index bc6b94836ea89a..9bf0c779b83d5b 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3ComparisonTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3ComparisonTest.java @@ -22,11 +22,11 @@ import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal; import org.apache.doris.nereids.types.DecimalV3Type; import com.google.common.collect.ImmutableList; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import java.math.BigDecimal; @@ -39,14 +39,24 @@ void testChildScaleLargerThanCast() { bottomUp(SimplifyDecimalV3Comparison.INSTANCE) )); - Expression leftChild = new DecimalV3Literal(new BigDecimal("1.23456")); - Expression left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(3, 2)); - Expression right = new DecimalV3Literal(new BigDecimal("1.20")); - Expression expression = new EqualTo(left, right); - Expression rewrittenExpression = executor.rewrite(expression, context); - Assertions.assertInstanceOf(Cast.class, rewrittenExpression.child(0)); - Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(3, 2), - rewrittenExpression.child(0).getDataType()); + DecimalV3Type castType = DecimalV3Type.createDecimalV3Type(10, 2); + Expression a = new SlotReference("a", DecimalV3Type.createDecimalV3Type(6, 5)); + Expression left = new Cast(a, castType); + Expression right = new DecimalV3Literal(castType, new BigDecimal("1.20")); + assertRewrite(new EqualTo(left, right), new EqualTo(left, right)); + } + + @Test + void testChildRangeLargerThanCast() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(SimplifyDecimalV3Comparison.INSTANCE) + )); + + DecimalV3Type castType = DecimalV3Type.createDecimalV3Type(7, 5); + Expression a = new SlotReference("a", DecimalV3Type.createDecimalV3Type(6, 3)); + Expression left = new Cast(a, castType); + Expression right = new DecimalV3Literal(castType, new BigDecimal("1.20")); + assertRewrite(new EqualTo(left, right), new EqualTo(left, right)); } @Test @@ -55,16 +65,11 @@ void testChildScaleSmallerThanCast() { bottomUp(SimplifyDecimalV3Comparison.INSTANCE) )); - Expression leftChild = new DecimalV3Literal(new BigDecimal("1.23456")); - Expression left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(10, 9)); + DecimalV3Type castType = DecimalV3Type.createDecimalV3Type(10, 9); + Expression a = new SlotReference("a", DecimalV3Type.createDecimalV3Type(6, 5)); + Expression left = new Cast(a, castType); Expression right = new DecimalV3Literal(new BigDecimal("1.200000000")); - Expression expression = new EqualTo(left, right); - Expression rewrittenExpression = executor.rewrite(expression, context); - Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(0)); - Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(6, 5), - rewrittenExpression.child(0).getDataType()); - Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(1)); - Assertions.assertEquals(new BigDecimal("1.20000"), - ((DecimalV3Literal) rewrittenExpression.child(1)).getValue()); + assertRewrite(new EqualTo(left, right), + new EqualTo(a, new DecimalV3Literal(new BigDecimal("1.20000")))); } } diff --git a/regression-test/data/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.out b/regression-test/data/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.out new file mode 100644 index 00000000000000..091cda0772c1cb --- /dev/null +++ b/regression-test/data/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.out @@ -0,0 +1,138 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !cast_bigint_as_double_shape -- +PhysicalResultSink +--PhysicalProject +----hashJoin[LEFT_OUTER_JOIN] hashCondition=((expr_cast(c_s as DOUBLE) = expr_cast(c_bigint as DOUBLE))) otherCondition=() +------PhysicalProject +--------filter((t1.c_s = '870479087484055553')) +----------PhysicalOlapScan[tbl1_test_simplify_comparison_predicate_int_vs_double] +------PhysicalProject +--------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !cast_bigint_as_double_result -- +100 870479087484055553 200 870479087484055553 999.999 + +-- !float_neg16777216_neg10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = -1.6777226E7)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !float_neg16777216_0 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = -1.6777216E7)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !float_neg16777216_10 -- +PhysicalResultSink +--PhysicalProject +----filter((tbl2_test_simplify_comparison_predicate_int_vs_double.c_bigint = -16777206)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !float_16777216_neg10 -- +PhysicalResultSink +--PhysicalProject +----filter((tbl2_test_simplify_comparison_predicate_int_vs_double.c_bigint = 16777206)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !float_16777216_0 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = 1.6777216E7)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !float_16777216_10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = 1.6777226E7)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !double_neg9007199254740992_neg10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as DOUBLE) = -9.007199254741002E15)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !float_neg9007199254740992_neg10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = -9.0071993E15)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !double_neg9007199254740992_0 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as DOUBLE) = -9.007199254740992E15)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !float_neg9007199254740992_0 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = -9.0071993E15)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !double_neg9007199254740992_10 -- +PhysicalResultSink +--PhysicalProject +----filter((tbl2_test_simplify_comparison_predicate_int_vs_double.c_bigint = -9007199254740982)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !float_neg9007199254740992_10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = -9.0071993E15)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !double_9007199254740992_neg10 -- +PhysicalResultSink +--PhysicalProject +----filter((tbl2_test_simplify_comparison_predicate_int_vs_double.c_bigint = 9007199254740982)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !float_9007199254740992_neg10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = 9.0071993E15)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !double_9007199254740992_0 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as DOUBLE) = 9.007199254740992E15)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !float_9007199254740992_0 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = 9.0071993E15)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !double_9007199254740992_10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as DOUBLE) = 9.007199254741002E15)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !float_9007199254740992_10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = 9.0071993E15)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !int_vs_double_3_shape -- +PhysicalResultSink +--filter((tbl2_test_simplify_comparison_predicate_int_vs_double.c_bigint > 123)) +----PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !int_vs_double_3_result -- +200 870479087484055553 999.999 + +-- !decimal_vs_decimal_4_shape -- +PhysicalResultSink +--filter((tbl2_test_simplify_comparison_predicate_int_vs_double.c_decimal > 123.456)) +----PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !decimal_vs_decimal_4_result -- +200 870479087484055553 999.999 + diff --git a/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.groovy b/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.groovy new file mode 100644 index 00000000000000..84ec6a3490abde --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.groovy @@ -0,0 +1,127 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite('test_simplify_comparison_predicate_int_vs_double') { + sql """ + set runtime_filter_mode='OFF'; + set disable_join_reorder=false; + set ignore_shape_nodes='PhysicalDistribute'; + + drop table if exists tbl1_test_simplify_comparison_predicate_int_vs_double force; + drop table if exists tbl2_test_simplify_comparison_predicate_int_vs_double force; + + create table tbl1_test_simplify_comparison_predicate_int_vs_double + (k1 int, c_s varchar(100)) properties('replication_num' = '1'); + create table tbl2_test_simplify_comparison_predicate_int_vs_double + (k2 int, c_bigint bigint, c_decimal decimal(6, 3)) properties('replication_num' = '1'); + + insert into tbl1_test_simplify_comparison_predicate_int_vs_double values + (100, "870479087484055553"), + (101,"870479087484055554"); + insert into tbl2_test_simplify_comparison_predicate_int_vs_double values + (200, 870479087484055553, 999.999); + """ + + + explainAndOrderResult 'cast_bigint_as_double', """ + select * + from tbl1_test_simplify_comparison_predicate_int_vs_double t1 + left join tbl2_test_simplify_comparison_predicate_int_vs_double t2 + on t1.c_s = t2.c_bigint + where t1.c_s = '870479087484055553' + """ + + for (def delimit : [-(1L<<24), 1L<<24]) { + for (def diff : [-10, 0, 10]) { + def tag = "float_${delimit}_${diff}".replace('-', 'neg') + "qt_${tag}" """ + explain shape plan + select c_bigint + from tbl2_test_simplify_comparison_predicate_int_vs_double + where cast(c_bigint as float) = cast('${delimit + diff}' as float) + """ + } + } + + for (def delimit : [-(1L<<53), 1L<<53]) { + for (def diff : [-10, 0, 10]) { + def tag = "double_${delimit}_${diff}".replace('-', 'neg') + "qt_${tag}" """ + explain shape plan + select c_bigint + from tbl2_test_simplify_comparison_predicate_int_vs_double + where cast(c_bigint as double) = cast('${delimit + diff}' as double) + """ + + tag = "float_${delimit}_${diff}".replace('-', 'neg') + "qt_${tag}" """ + explain shape plan + select c_bigint + from tbl2_test_simplify_comparison_predicate_int_vs_double + where cast(c_bigint as float) = cast('${delimit + diff}' as float) + """ + } + } + + explainAndOrderResult 'int_vs_double_3', """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where c_bigint > 123.456 + """ + + explainAndOrderResult 'decimal_vs_decimal_4', """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where c_decimal > 123.4567 + """ + + test { + sql """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where cast(c_bigint as decimal(7, 3)) > 123.456 + """ + + exception 'Arithmetic overflow when converting value' + } + + test { + sql """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where cast(c_decimal as decimal(3,1)) > cast(12.3 as decimal(5, 1)) + """ + + exception 'Arithmetic overflow when converting value' + } + + test { + sql """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where cast(c_decimal as decimal(5,2)) > cast(123.45 as decimal(5, 2)) + """ + + exception 'Arithmetic overflow when converting value' + } + + sql """ + drop table if exists tbl1_test_simplify_comparison_predicate_int_vs_double force; + drop table if exists tbl2_test_simplify_comparison_predicate_int_vs_double force; + """ +}