diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java index c8328a5c913e08..7c3c4d6c4c615e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java @@ -60,6 +60,7 @@ import org.apache.doris.nereids.types.SmallIntType; import org.apache.doris.nereids.types.TinyIntType; import org.apache.doris.nereids.types.coercion.DateLikeType; +import org.apache.doris.nereids.types.coercion.IntegralType; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.TypeCoercionUtils; @@ -80,6 +81,11 @@ public class SimplifyComparisonPredicate implements ExpressionPatternRuleFactory { public static SimplifyComparisonPredicate INSTANCE = new SimplifyComparisonPredicate(); + public static final int MAX_INT_TO_FLOAT_NO_LOSS = 1 << 24; + public static final int MIN_INT_TO_FLOAT_NO_LOSS = -MAX_INT_TO_FLOAT_NO_LOSS; + public static final long MAX_LONG_TO_DOUBLE_NO_LOSS = 1L << 53; + public static final long MIN_LONG_TO_DOUBLE_NO_LOSS = -MAX_LONG_TO_DOUBLE_NO_LOSS; + @Override public List> buildRules() { return ImmutableList.of( @@ -400,19 +406,25 @@ private static Expression processDecimalV3TypeCoercion(ComparisonPredicate compa if (left instanceof Cast && right instanceof DecimalV3Literal) { Cast cast = (Cast) left; left = cast.child(); + DecimalV3Type castDataType = (DecimalV3Type) cast.getDataType(); DecimalV3Literal literal = (DecimalV3Literal) right; if (left.getDataType().isDecimalV3Type()) { + DecimalV3Type leftType = (DecimalV3Type) left.getDataType(); + if (castDataType.getRange() < leftType.getRange() + || (castDataType.getRange() == leftType.getRange() + && castDataType.getScale() < leftType.getScale())) { + // for cast(col as decimal(m2, n2)) cmp literal, + // if cast-to can not hold col's integer part, the cast result maybe null, don't process it. + return comparisonPredicate; + } Optional toSmallerDecimalDataTypeExpr = convertDecimalToSmallerDecimalV3Type( comparisonPredicate, cast, literal); if (toSmallerDecimalDataTypeExpr.isPresent()) { return toSmallerDecimalDataTypeExpr.get(); } - DecimalV3Type leftType = (DecimalV3Type) left.getDataType(); DecimalV3Type literalType = (DecimalV3Type) literal.getDataType(); - if (cast.getDataType().isDecimalV3Type() - && ((DecimalV3Type) cast.getDataType()).getScale() >= leftType.getScale() - && leftType.getScale() < literalType.getScale()) { + if (castDataType.getScale() >= leftType.getScale() && leftType.getScale() < literalType.getScale()) { int toScale = ((DecimalV3Type) left.getDataType()).getScale(); if (comparisonPredicate instanceof EqualTo) { try { @@ -457,9 +469,17 @@ private static Expression processDecimalV3TypeCoercion(ComparisonPredicate compa private static Expression processIntegerDecimalLiteralComparison( ComparisonPredicate comparisonPredicate, Expression left, BigDecimal literal) { // we only process isIntegerLikeType, which are tinyint, smallint, int, bigint + // for `cast(c_int as decimal(m, n)) cmp literal`, + // if c_int's range is wider than decimal's range, cast result maybe null, don't process it. + DataType castDataType = comparisonPredicate.left().getDataType(); + if (castDataType.isDecimalV3Type() + && ((DecimalV3Type) castDataType).getRange() < ((IntegralType) left.getDataType()).range()) { + return comparisonPredicate; + } if (literal.compareTo(new BigDecimal(Long.MIN_VALUE)) >= 0 && literal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0) { literal = literal.stripTrailingZeros(); + Optional roundLiteralOpt = Optional.empty(); if (literal.scale() > 0) { if (comparisonPredicate instanceof EqualTo) { // TODO: the ideal way is to return an If expr like: @@ -473,21 +493,22 @@ private static Expression processIntegerDecimalLiteralComparison( return BooleanLiteral.of(false); } else if (comparisonPredicate instanceof GreaterThan || comparisonPredicate instanceof LessThanEqual) { - return TypeCoercionUtils - .processComparisonPredicate((ComparisonPredicate) comparisonPredicate - .withChildren(left, convertDecimalToIntegerLikeLiteral( - literal.setScale(0, RoundingMode.FLOOR)))); + roundLiteralOpt = Optional.of(literal.setScale(0, RoundingMode.FLOOR)); } else if (comparisonPredicate instanceof LessThan || comparisonPredicate instanceof GreaterThanEqual) { + roundLiteralOpt = Optional.of(literal.setScale(0, RoundingMode.CEILING)); + } + } else { + roundLiteralOpt = Optional.of(literal); + } + if (roundLiteralOpt.isPresent()) { + Optional integerLikeLiteralOpt + = convertDecimalToIntegerLikeLiteral(roundLiteralOpt.get(), castDataType); + if (integerLikeLiteralOpt.isPresent()) { return TypeCoercionUtils .processComparisonPredicate((ComparisonPredicate) comparisonPredicate - .withChildren(left, convertDecimalToIntegerLikeLiteral( - literal.setScale(0, RoundingMode.CEILING)))); + .withChildren(left, integerLikeLiteralOpt.get())); } - } else { - return TypeCoercionUtils - .processComparisonPredicate((ComparisonPredicate) comparisonPredicate - .withChildren(left, convertDecimalToIntegerLikeLiteral(literal))); } } return comparisonPredicate; @@ -615,20 +636,39 @@ private static Optional convertDecimalToSmallerDecimalV3Type(Compari return Optional.empty(); } - private static IntegerLikeLiteral convertDecimalToIntegerLikeLiteral(BigDecimal decimal) { + private static Optional convertDecimalToIntegerLikeLiteral(BigDecimal decimal, + DataType castDataType) { Preconditions.checkArgument(decimal.scale() <= 0 && decimal.compareTo(new BigDecimal(Long.MIN_VALUE)) >= 0 && decimal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0, "decimal literal must have 0 scale and in range [Long.MIN_VALUE, Long.MAX_VALUE]"); long val = decimal.longValue(); + // for integer like convert to float, only [-2^24, 2^24] can convert to float without loss of precision, + // but here need to exclude the boundary value, because + // cast(2^24 as float) = cast(2^24 + 1 as float) = 2^24 = MAX_INT_TO_FLOAT_NO_LOSS, + // so for cast(c_int as float) = 2^24, we can't simplify it to c_int = 2^24, + // c_int can be 2^24 + 1. The same for -2^24 + if (castDataType.isFloatType() + && (val <= MIN_INT_TO_FLOAT_NO_LOSS || val >= MAX_INT_TO_FLOAT_NO_LOSS)) { + return Optional.empty(); + } + // for long convert to double, only [-2^53, 2^53] can convert to double without loss of precision, + // but here need to exclude the boundary value, because + // cast(2^53 as double) = cast(2^53 + 1 as double) = 2^53 = MAX_LONG_TO_DOUBLE_NO_LOSS, + // so for cast(c_bigint as double) = 2^53, we can't simplify it to c_bigint = 2^53, + // c_bigint can be 2^53 + 1. The same for -2^53 + if (castDataType.isDoubleType() + && (val <= MIN_LONG_TO_DOUBLE_NO_LOSS || val >= MAX_LONG_TO_DOUBLE_NO_LOSS)) { + return Optional.empty(); + } if (val >= Byte.MIN_VALUE && val <= Byte.MAX_VALUE) { - return new TinyIntLiteral((byte) val); + return Optional.of(new TinyIntLiteral((byte) val)); } else if (val >= Short.MIN_VALUE && val <= Short.MAX_VALUE) { - return new SmallIntLiteral((short) val); + return Optional.of(new SmallIntLiteral((short) val)); } else if (val >= Integer.MIN_VALUE && val <= Integer.MAX_VALUE) { - return new IntegerLiteral((int) val); + return Optional.of(new IntegerLiteral((int) val)); } else { - return new BigIntLiteral(val); + return Optional.of(new BigIntLiteral(val)); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Cast.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Cast.java index 04ab287a8f9c6e..d225954dec8d47 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Cast.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Cast.java @@ -49,16 +49,12 @@ public class Cast extends Expression implements UnaryExpression, Monotonic { private final DataType targetType; - public Cast(Expression child, DataType targetType, boolean isExplicitType) { - super(ImmutableList.of(child)); - this.targetType = Objects.requireNonNull(targetType, "targetType can not be null"); - this.isExplicitType = isExplicitType; + public Cast(Expression child, DataType targetType) { + this(child, targetType, false); } - public Cast(Expression child, DataType targetType) { - super(ImmutableList.of(child)); - this.targetType = Objects.requireNonNull(targetType, "targetType can not be null"); - this.isExplicitType = false; + public Cast(Expression child, DataType targetType, boolean isExplicitType) { + this(ImmutableList.of(child), targetType, isExplicitType); } private Cast(List child, DataType targetType, boolean isExplicitType) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BigIntType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BigIntType.java index 767db0e89a9d32..f654b9f271ba81 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BigIntType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BigIntType.java @@ -62,4 +62,9 @@ public DataType defaultConcreteType() { public int width() { return WIDTH; } + + @Override + public int range() { + return RANGE; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/IntegerType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/IntegerType.java index 2b22c167d464f0..7b47ff90b72270 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/IntegerType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/IntegerType.java @@ -62,4 +62,9 @@ public DataType defaultConcreteType() { public int width() { return WIDTH; } + + @Override + public int range() { + return RANGE; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/LargeIntType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/LargeIntType.java index 78a9369bd26830..22e88206c6abd3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/LargeIntType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/LargeIntType.java @@ -71,4 +71,9 @@ public DataType defaultConcreteType() { public int width() { return WIDTH; } + + @Override + public int range() { + return RANGE; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/SmallIntType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/SmallIntType.java index 4272052cd1a938..ad2743a4a7e19a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/SmallIntType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/SmallIntType.java @@ -62,4 +62,9 @@ public DataType defaultConcreteType() { public int width() { return WIDTH; } + + @Override + public int range() { + return RANGE; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TinyIntType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TinyIntType.java index 30259582127e88..075cb945db2479 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TinyIntType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TinyIntType.java @@ -62,4 +62,9 @@ public DataType defaultConcreteType() { public int width() { return WIDTH; } + + @Override + public int range() { + return RANGE; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java index 656c6f660f2090..b1e588053881eb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java @@ -20,6 +20,8 @@ import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.DataType; +import org.apache.commons.lang3.NotImplementedException; + /** * Abstract class for all integral data type in Nereids. */ @@ -45,4 +47,9 @@ public String simpleString() { public boolean widerThan(IntegralType other) { return this.width() > other.width(); } + + // The maximum number of digits that Integer can represent. + public int range() { + throw new NotImplementedException("should be implemented by derived class"); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java index 74f569f52c3b5c..b5ad2dda3d443b 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java @@ -596,6 +596,49 @@ void testDoubleLiteral() { new LessThan(bigIntSlot, new BigIntLiteral(13L))); assertRewrite(new LessThanEqual(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), new LessThanEqual(bigIntSlot, new BigIntLiteral(12L))); + + // int and float literal near no loss bound + float noLossBoundF = 16777216.0f; // 2^24 + assertRewrite(new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(-noLossBoundF)), + new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(-noLossBoundF))); + assertRewrite(new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(-16777215.0f)), + new EqualTo(intSlot, new IntegerLiteral(-16777215))); + assertRewrite(new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(16777215.0f)), + new EqualTo(intSlot, new IntegerLiteral(16777215))); + assertRewrite(new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(noLossBoundF)), + new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(noLossBoundF))); + + // big int and double literal near no loss bound + double noLossBoundD = 9007199254740992.0; + assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(-noLossBoundD)), + new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(-noLossBoundD))); + assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(-9007199254740991.0)), + new EqualTo(bigIntSlot, new BigIntLiteral(-9007199254740991L))); + assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(9007199254740991.0)), + new EqualTo(bigIntSlot, new BigIntLiteral(9007199254740991L))); + assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(noLossBoundD)), + new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(noLossBoundD))); + } + + @Test + void testFloatNoLossBound() { + checkIntConvertFloatLikeLossBound(SimplifyComparisonPredicate.MIN_INT_TO_FLOAT_NO_LOSS, true); + checkIntConvertFloatLikeLossBound(SimplifyComparisonPredicate.MAX_INT_TO_FLOAT_NO_LOSS, true); + checkIntConvertFloatLikeLossBound(SimplifyComparisonPredicate.MIN_LONG_TO_DOUBLE_NO_LOSS, false); + checkIntConvertFloatLikeLossBound(SimplifyComparisonPredicate.MAX_LONG_TO_DOUBLE_NO_LOSS, false); + } + + private void checkIntConvertFloatLikeLossBound(long bound, boolean isFloat) { + for (int i = 0; i < 100000; i++) { + long v = (Math.abs(bound) - i) * Long.signum(bound); + long vCast = isFloat ? (long) ((float) v) : (long) ((double) v); + Assertions.assertEquals(v, vCast); + } + + long firstOutBound = (Math.abs(bound) + 1) * Long.signum(bound); + long firstOutBoundCast = isFloat ? (long) ((float) firstOutBound) : (long) ((double) firstOutBound); + Assertions.assertNotEquals(firstOutBound, firstOutBoundCast); + Assertions.assertEquals(bound, firstOutBoundCast); } @Test @@ -610,95 +653,110 @@ void testIntCmpDecimalV3Literal() { Expression bigIntSlot = new SlotReference("a", BigIntType.INSTANCE); // tiny int, literal not exceeds data type limit - assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))), + DecimalV3Type tinyDecimalType = DecimalV3Type.createDecimalV3Type(4, 1); + assertRewrite(new EqualTo(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(tinyDecimalType, new BigDecimal("12.0"))), new EqualTo(tinyIntSlot, new TinyIntLiteral((byte) 12))); - assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new EqualTo(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(tinyDecimalType, new BigDecimal("12.3"))), ExpressionUtils.falseOrNull(tinyIntSlot)); - assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(tinyDecimalType, new BigDecimal("12.3"))), BooleanLiteral.FALSE); - assertRewrite(new GreaterThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThan(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(tinyDecimalType, new BigDecimal("12.3"))), new GreaterThan(tinyIntSlot, new TinyIntLiteral((byte) 12))); - assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(tinyDecimalType, new BigDecimal("12.3"))), new GreaterThanEqual(tinyIntSlot, new TinyIntLiteral((byte) 13))); - assertRewrite(new LessThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThan(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(tinyDecimalType, new BigDecimal("12.3"))), new LessThan(tinyIntSlot, new TinyIntLiteral((byte) 13))); - assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(tinyDecimalType, new BigDecimal("12.3"))), new LessThanEqual(tinyIntSlot, new TinyIntLiteral((byte) 12))); // tiny int, literal exceeds data type limit - assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.0"))), + assertRewrite(new EqualTo(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(new BigDecimal("200.0"))), ExpressionUtils.falseOrNull(tinyIntSlot)); - assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))), + assertRewrite(new EqualTo(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(new BigDecimal("200.3"))), ExpressionUtils.falseOrNull(tinyIntSlot)); - assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))), + assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(new BigDecimal("200.3"))), BooleanLiteral.FALSE); - assertRewrite(new GreaterThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))), + assertRewrite(new GreaterThan(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(new BigDecimal("200.3"))), ExpressionUtils.falseOrNull(tinyIntSlot)); - assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))), + assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(new BigDecimal("200.3"))), ExpressionUtils.falseOrNull(tinyIntSlot)); - assertRewrite(new LessThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))), + assertRewrite(new LessThan(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(new BigDecimal("200.3"))), ExpressionUtils.trueOrNull(tinyIntSlot)); - assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))), + assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(new BigDecimal("200.3"))), ExpressionUtils.trueOrNull(tinyIntSlot)); // small int - assertRewrite(new EqualTo(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))), + DecimalV3Type smallDecimalType = DecimalV3Type.createDecimalV3Type(6, 1); + assertRewrite(new EqualTo(new Cast(smallIntSlot, smallDecimalType), new DecimalV3Literal(smallDecimalType, new BigDecimal("12.0"))), new EqualTo(smallIntSlot, new SmallIntLiteral((short) 12))); - assertRewrite(new EqualTo(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new EqualTo(new Cast(smallIntSlot, smallDecimalType), new DecimalV3Literal(smallDecimalType, new BigDecimal("12.3"))), ExpressionUtils.falseOrNull(smallIntSlot)); - assertRewrite(new NullSafeEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new NullSafeEqual(new Cast(smallIntSlot, smallDecimalType), new DecimalV3Literal(smallDecimalType, new BigDecimal("12.3"))), BooleanLiteral.FALSE); - assertRewrite(new GreaterThan(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThan(new Cast(smallIntSlot, smallDecimalType), new DecimalV3Literal(smallDecimalType, new BigDecimal("12.3"))), new GreaterThan(smallIntSlot, new SmallIntLiteral((short) 12))); - assertRewrite(new GreaterThanEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThanEqual(new Cast(smallIntSlot, smallDecimalType), new DecimalV3Literal(smallDecimalType, new BigDecimal("12.3"))), new GreaterThanEqual(smallIntSlot, new SmallIntLiteral((short) 13))); - assertRewrite(new LessThan(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThan(new Cast(smallIntSlot, smallDecimalType), new DecimalV3Literal(smallDecimalType, new BigDecimal("12.3"))), new LessThan(smallIntSlot, new SmallIntLiteral((short) 13))); - assertRewrite(new LessThanEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThanEqual(new Cast(smallIntSlot, smallDecimalType), new DecimalV3Literal(smallDecimalType, new BigDecimal("12.3"))), new LessThanEqual(smallIntSlot, new SmallIntLiteral((short) 12))); // int - assertRewrite(new EqualTo(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))), + DecimalV3Type intDecimalType = DecimalV3Type.createDecimalV3Type(11, 1); + assertRewrite(new EqualTo(new Cast(intSlot, intDecimalType), new DecimalV3Literal(intDecimalType, new BigDecimal("12.0"))), new EqualTo(intSlot, new IntegerLiteral(12))); - assertRewrite(new EqualTo(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new EqualTo(new Cast(intSlot, intDecimalType), new DecimalV3Literal(intDecimalType, new BigDecimal("12.3"))), ExpressionUtils.falseOrNull(intSlot)); - assertRewrite(new NullSafeEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new NullSafeEqual(new Cast(intSlot, intDecimalType), new DecimalV3Literal(intDecimalType, new BigDecimal("12.3"))), BooleanLiteral.FALSE); - assertRewrite(new GreaterThan(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThan(new Cast(intSlot, intDecimalType), new DecimalV3Literal(intDecimalType, new BigDecimal("12.3"))), new GreaterThan(intSlot, new IntegerLiteral(12))); - assertRewrite(new GreaterThanEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThanEqual(new Cast(intSlot, intDecimalType), new DecimalV3Literal(intDecimalType, new BigDecimal("12.3"))), new GreaterThanEqual(intSlot, new IntegerLiteral(13))); - assertRewrite(new LessThan(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThan(new Cast(intSlot, intDecimalType), new DecimalV3Literal(intDecimalType, new BigDecimal("12.3"))), new LessThan(intSlot, new IntegerLiteral(13))); - assertRewrite(new LessThanEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThanEqual(new Cast(intSlot, intDecimalType), new DecimalV3Literal(intDecimalType, new BigDecimal("12.3"))), new LessThanEqual(intSlot, new IntegerLiteral(12))); // big int - assertRewrite(new EqualTo(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))), + DecimalV3Type bigDecimalType = DecimalV3Type.createDecimalV3Type(20, 1); + assertRewrite(new EqualTo(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(bigDecimalType, new BigDecimal("12.0"))), new EqualTo(bigIntSlot, new BigIntLiteral(12L))); - assertRewrite(new EqualTo(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new EqualTo(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(bigDecimalType, new BigDecimal("12.3"))), ExpressionUtils.falseOrNull(bigIntSlot)); - assertRewrite(new NullSafeEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new NullSafeEqual(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(bigDecimalType, new BigDecimal("12.3"))), BooleanLiteral.FALSE); - assertRewrite(new GreaterThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThan(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(bigDecimalType, new BigDecimal("12.3"))), new GreaterThan(bigIntSlot, new BigIntLiteral(12L))); - assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(bigDecimalType, new BigDecimal("12.3"))), new GreaterThanEqual(bigIntSlot, new BigIntLiteral(13L))); - assertRewrite(new LessThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThan(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(bigDecimalType, new BigDecimal("12.3"))), new LessThan(bigIntSlot, new BigIntLiteral(13L))); - assertRewrite(new LessThanEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThanEqual(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(bigDecimalType, new BigDecimal("12.3"))), new LessThanEqual(bigIntSlot, new BigIntLiteral(12L))); - assertRewrite(new LessThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(20, 1)), new DecimalV3Literal(new BigDecimal("-9223372036854775808.1"))), + assertRewrite(new LessThan(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(new BigDecimal("-9223372036854775808.1"))), ExpressionUtils.falseOrNull(bigIntSlot)); - assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(20, 1)), new DecimalV3Literal(new BigDecimal("-9223372036854775808.1"))), + assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(new BigDecimal("-9223372036854775808.1"))), ExpressionUtils.trueOrNull(bigIntSlot)); - assertRewrite(new LessThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(20, 1)), new DecimalV3Literal(new BigDecimal("-9223372036854775807.1"))), + assertRewrite(new LessThan(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(new BigDecimal("-9223372036854775807.1"))), new LessThan(bigIntSlot, new BigIntLiteral(-9223372036854775807L))); - assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(20, 1)), new DecimalV3Literal(new BigDecimal("9223372036854775807.1"))), + assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(new BigDecimal("9223372036854775807.1"))), ExpressionUtils.falseOrNull(bigIntSlot)); - assertRewrite(new LessThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(20, 1)), new DecimalV3Literal(new BigDecimal("9223372036854775806.1"))), + assertRewrite(new LessThan(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(new BigDecimal("9223372036854775806.1"))), new LessThan(bigIntSlot, new BigIntLiteral(9223372036854775807L))); + + // do not convert cast(int as decimal) if decimal's range < int's range + DecimalV3Type noConvertDecimalType = DecimalV3Type.createDecimalV3Type(3, 1); + assertRewrite(new EqualTo(new Cast(tinyIntSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0"))), + new EqualTo(new Cast(tinyIntSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0")))); + assertRewrite(new EqualTo(new Cast(smallIntSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0"))), + new EqualTo(new Cast(smallIntSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0")))); + assertRewrite(new EqualTo(new Cast(intSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0"))), + new EqualTo(new Cast(intSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0")))); + assertRewrite(new EqualTo(new Cast(bigIntSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0"))), + new EqualTo(new Cast(bigIntSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0")))); } @Test @@ -861,6 +919,14 @@ void testDecimalCmpDecimalV3Literal() { Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(1)); Assertions.assertEquals(new BigDecimal("1.20000"), ((DecimalV3Literal) rewrittenExpression.child(1)).getValue()); + + // don't convert for cast(b as decimal) if b's range > decimal's range + // or b's range = decimal's range and b's scale < decimal's scale + SlotReference decimalSlot = new SlotReference("slot", DecimalV3Type.createDecimalV3Type(5, 2), true); + assertRewrite(new EqualTo(new Cast(decimalSlot, DecimalV3Type.createDecimalV3Type(7, 5)), new DecimalV3Literal(new BigDecimal("12.34567"))), + new EqualTo(new Cast(decimalSlot, DecimalV3Type.createDecimalV3Type(7, 5)), new DecimalV3Literal(new BigDecimal("12.34567")))); + assertRewrite(new EqualTo(new Cast(decimalSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("123.4"))), + new EqualTo(new Cast(decimalSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("123.4")))); } private enum RangeLimitResult { diff --git a/regression-test/data/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.out b/regression-test/data/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.out new file mode 100644 index 00000000000000..59ae6df0d908b5 --- /dev/null +++ b/regression-test/data/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.out @@ -0,0 +1,173 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !cast_bigint_as_double_shape -- +PhysicalResultSink +--NestedLoopJoin[LEFT_OUTER_JOIN] +----filter((t1.c_s = '870479087484055553')) +------PhysicalOlapScan[tbl1_test_simplify_comparison_predicate_int_vs_double] +----filter((cast(c_bigint as DOUBLE) = 8.7047908748405555E17)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !cast_bigint_as_double_result -- +100 870479087484055553 200 870479087484055553 999.999 + +-- !float_neg16777216_neg10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = -1.6777226E7)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !float_neg16777216_0 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = -1.6777216E7)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !float_neg16777216_10 -- +PhysicalResultSink +--PhysicalProject +----filter((tbl2_test_simplify_comparison_predicate_int_vs_double.c_bigint = -16777206)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !float_16777216_neg10 -- +PhysicalResultSink +--PhysicalProject +----filter((tbl2_test_simplify_comparison_predicate_int_vs_double.c_bigint = 16777206)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !float_16777216_0 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = 1.6777216E7)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !float_16777216_10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = 1.6777226E7)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !double_neg9007199254740992_neg10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as DOUBLE) = -9.007199254741002E15)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !float_neg9007199254740992_neg10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = -9.0071993E15)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !double_neg9007199254740992_0 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as DOUBLE) = -9.007199254740992E15)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !float_neg9007199254740992_0 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = -9.0071993E15)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !double_neg9007199254740992_10 -- +PhysicalResultSink +--PhysicalProject +----filter((tbl2_test_simplify_comparison_predicate_int_vs_double.c_bigint = -9007199254740982)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !float_neg9007199254740992_10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = -9.0071993E15)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !double_9007199254740992_neg10 -- +PhysicalResultSink +--PhysicalProject +----filter((tbl2_test_simplify_comparison_predicate_int_vs_double.c_bigint = 9007199254740982)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !float_9007199254740992_neg10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = 9.0071993E15)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !double_9007199254740992_0 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as DOUBLE) = 9.007199254740992E15)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !float_9007199254740992_0 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = 9.0071993E15)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !double_9007199254740992_10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as DOUBLE) = 9.007199254741002E15)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !float_9007199254740992_10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = 9.0071993E15)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !int_vs_double_1_shape -- +PhysicalResultSink +--filter((tbl2_test_simplify_comparison_predicate_int_vs_double.c_bigint > 123)) +----PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !int_vs_double_1_result -- +200 870479087484055553 999.999 + +-- !int_vs_double_2_shape -- +PhysicalResultSink +--filter((cast(c_bigint as DECIMALV3(7, 3)) > 123.456)) +----PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !int_vs_double_2_result -- + +-- !decimal_vs_decimal_1_shape -- +PhysicalResultSink +--filter((tbl2_test_simplify_comparison_predicate_int_vs_double.c_decimal > 123.456)) +----PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !decimal_vs_decimal_1_result -- +200 870479087484055553 999.999 + +-- !decimal_vs_decimal_2_shape -- +PhysicalResultSink +--filter((cast(c_decimal as DECIMALV3(3, 1)) > 12.3)) +----PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !decimal_vs_decimal_2_result -- + +-- !decimal_vs_decimal_3_shape -- +PhysicalResultSink +--filter((cast(c_decimal as DECIMALV3(5, 2)) > 123.45)) +----PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !decimal_vs_decimal_3_result -- + +-- !int_vs_double_3_shape -- +PhysicalResultSink +--filter((tbl2_test_simplify_comparison_predicate_int_vs_double.c_bigint > 123)) +----PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !int_vs_double_3_result -- +200 870479087484055553 999.999 + +-- !decimal_vs_decimal_4_shape -- +PhysicalResultSink +--filter((tbl2_test_simplify_comparison_predicate_int_vs_double.c_decimal > 123.456)) +----PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !decimal_vs_decimal_4_result -- +200 870479087484055553 999.999 + diff --git a/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.groovy b/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.groovy new file mode 100644 index 00000000000000..cfe0b435ef24af --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.groovy @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite('test_simplify_comparison_predicate_int_vs_double') { + sql """ + set runtime_filter_mode='OFF'; + set disable_join_reorder=false; + set ignore_shape_nodes='PhysicalDistribute'; + + drop table if exists tbl1_test_simplify_comparison_predicate_int_vs_double force; + drop table if exists tbl2_test_simplify_comparison_predicate_int_vs_double force; + + create table tbl1_test_simplify_comparison_predicate_int_vs_double + (k1 int, c_s varchar(100)) properties('replication_num' = '1'); + create table tbl2_test_simplify_comparison_predicate_int_vs_double + (k2 int, c_bigint bigint, c_decimal decimal(6, 3)) properties('replication_num' = '1'); + + insert into tbl1_test_simplify_comparison_predicate_int_vs_double values + (100, "870479087484055553"), + (101,"870479087484055554"); + insert into tbl2_test_simplify_comparison_predicate_int_vs_double values + (200, 870479087484055553, 999.999); + """ + + + explainAndOrderResult 'cast_bigint_as_double', """ + select * + from tbl1_test_simplify_comparison_predicate_int_vs_double t1 + left join tbl2_test_simplify_comparison_predicate_int_vs_double t2 + on t1.c_s = t2.c_bigint + where t1.c_s = '870479087484055553' + """ + + for (def delimit : [-(1L<<24), 1L<<24]) { + for (def diff : [-10, 0, 10]) { + def tag = "float_${delimit}_${diff}".replace('-', 'neg') + "qt_${tag}" """ + explain shape plan + select c_bigint + from tbl2_test_simplify_comparison_predicate_int_vs_double + where cast(c_bigint as float) = cast('${delimit + diff}' as float) + """ + } + } + + for (def delimit : [-(1L<<53), 1L<<53]) { + for (def diff : [-10, 0, 10]) { + def tag = "double_${delimit}_${diff}".replace('-', 'neg') + "qt_${tag}" """ + explain shape plan + select c_bigint + from tbl2_test_simplify_comparison_predicate_int_vs_double + where cast(c_bigint as double) = cast('${delimit + diff}' as double) + """ + + tag = "float_${delimit}_${diff}".replace('-', 'neg') + "qt_${tag}" """ + explain shape plan + select c_bigint + from tbl2_test_simplify_comparison_predicate_int_vs_double + where cast(c_bigint as float) = cast('${delimit + diff}' as float) + """ + } + } + + sql "set enable_strict_cast=false" + + explainAndOrderResult 'int_vs_double_1', """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where c_bigint > 123.456 + """ + + explainAndOrderResult 'int_vs_double_2', """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where cast(c_bigint as decimal(7, 3)) > 123.456 + """ + + explainAndOrderResult 'decimal_vs_decimal_1', """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where c_decimal > 123.4567 + """ + + explainAndOrderResult 'decimal_vs_decimal_2', """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where cast(c_decimal as decimal(3,1)) > cast(12.3 as decimal(5, 1)) + """ + + explainAndOrderResult 'decimal_vs_decimal_3', """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where cast(c_decimal as decimal(5,2)) > cast(123.45 as decimal(5, 2)) + """ + + sql "set enable_strict_cast=true" + + explainAndOrderResult 'int_vs_double_3', """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where c_bigint > 123.456 + """ + + explainAndOrderResult 'decimal_vs_decimal_4', """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where c_decimal > 123.4567 + """ + + test { + sql """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where cast(c_bigint as decimal(7, 3)) > 123.456 + """ + + exception 'Arithmetic overflow when converting value' + } + + test { + sql """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where cast(c_decimal as decimal(3,1)) > cast(12.3 as decimal(5, 1)) + """ + + exception 'Arithmetic overflow when converting value' + } + + test { + sql """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where cast(c_decimal as decimal(5,2)) > cast(123.45 as decimal(5, 2)) + """ + + exception 'Arithmetic overflow when converting value' + } + + sql """ + drop table if exists tbl1_test_simplify_comparison_predicate_int_vs_double force; + drop table if exists tbl2_test_simplify_comparison_predicate_int_vs_double force; + """ +}