From 34bd2c2be764ffde03f06bd245bfcf6c843a64e6 Mon Sep 17 00:00:00 2001 From: yujun Date: Wed, 10 Sep 2025 20:35:39 +0800 Subject: [PATCH 1/8] fix simplify integer like compare with double --- .../rules/SimplifyComparisonPredicate.java | 46 +++++++++++++------ .../SimplifyComparisonPredicateTest.java | 11 +++++ 2 files changed, 42 insertions(+), 15 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java index c8328a5c913e08..2b38ae6b4b1461 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java @@ -80,6 +80,9 @@ public class SimplifyComparisonPredicate implements ExpressionPatternRuleFactory { public static SimplifyComparisonPredicate INSTANCE = new SimplifyComparisonPredicate(); + private static final long MAX_LONG_CAST_DOUBLE_NO_LOSS = 1L << 53; + private static final long MIN_LONG_CAST_DOUBLE_NO_LOSS = -MAX_LONG_CAST_DOUBLE_NO_LOSS; + @Override public List> buildRules() { return ImmutableList.of( @@ -460,6 +463,7 @@ private static Expression processIntegerDecimalLiteralComparison( if (literal.compareTo(new BigDecimal(Long.MIN_VALUE)) >= 0 && literal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0) { literal = literal.stripTrailingZeros(); + Optional roundLiteralOpt = Optional.empty(); if (literal.scale() > 0) { if (comparisonPredicate instanceof EqualTo) { // TODO: the ideal way is to return an If expr like: @@ -473,21 +477,23 @@ private static Expression processIntegerDecimalLiteralComparison( return BooleanLiteral.of(false); } else if (comparisonPredicate instanceof GreaterThan || comparisonPredicate instanceof LessThanEqual) { - return TypeCoercionUtils - .processComparisonPredicate((ComparisonPredicate) comparisonPredicate - .withChildren(left, convertDecimalToIntegerLikeLiteral( - literal.setScale(0, RoundingMode.FLOOR)))); + roundLiteralOpt = Optional.of(literal.setScale(0, RoundingMode.FLOOR)); } else if (comparisonPredicate instanceof LessThan || comparisonPredicate instanceof GreaterThanEqual) { + roundLiteralOpt = Optional.of(literal.setScale(0, RoundingMode.CEILING)); + } + } else { + roundLiteralOpt = Optional.of(literal); + } + if (roundLiteralOpt.isPresent()) { + boolean isCastFloatLike = comparisonPredicate.right().getDataType().isFloatLikeType(); + Optional integerLikeLiteralOpt + = convertDecimalToIntegerLikeLiteral(roundLiteralOpt.get(), isCastFloatLike); + if (integerLikeLiteralOpt.isPresent()) { return TypeCoercionUtils .processComparisonPredicate((ComparisonPredicate) comparisonPredicate - .withChildren(left, convertDecimalToIntegerLikeLiteral( - literal.setScale(0, RoundingMode.CEILING)))); + .withChildren(left, integerLikeLiteralOpt.get())); } - } else { - return TypeCoercionUtils - .processComparisonPredicate((ComparisonPredicate) comparisonPredicate - .withChildren(left, convertDecimalToIntegerLikeLiteral(literal))); } } return comparisonPredicate; @@ -615,20 +621,30 @@ private static Optional convertDecimalToSmallerDecimalV3Type(Compari return Optional.empty(); } - private static IntegerLikeLiteral convertDecimalToIntegerLikeLiteral(BigDecimal decimal) { + private static Optional convertDecimalToIntegerLikeLiteral(BigDecimal decimal, + boolean isCastFloatLike) { Preconditions.checkArgument(decimal.scale() <= 0 && decimal.compareTo(new BigDecimal(Long.MIN_VALUE)) >= 0 && decimal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0, "decimal literal must have 0 scale and in range [Long.MIN_VALUE, Long.MAX_VALUE]"); long val = decimal.longValue(); if (val >= Byte.MIN_VALUE && val <= Byte.MAX_VALUE) { - return new TinyIntLiteral((byte) val); + return Optional.of(new TinyIntLiteral((byte) val)); } else if (val >= Short.MIN_VALUE && val <= Short.MAX_VALUE) { - return new SmallIntLiteral((short) val); + return Optional.of(new SmallIntLiteral((short) val)); } else if (val >= Integer.MIN_VALUE && val <= Integer.MAX_VALUE) { - return new IntegerLiteral((int) val); + return Optional.of(new IntegerLiteral((int) val)); + } else if (!isCastFloatLike + || (val > MIN_LONG_CAST_DOUBLE_NO_LOSS && val < MAX_LONG_CAST_DOUBLE_NO_LOSS)) { + // for decimal, all long value can convert to decimal without loss of precision + // for float/double, only [-2^53, 2^53] can convert to long without loss of precision, + // but here need to exclude the boundary value, because + // cast(2^53 as double) = cast(2^53 + 1 as double) = 2^53 = MAX_LONG_CAST_DOUBLE_NO_LOSS, + // so for cast(c_bigint as double) = 2^53, we can't simplify it to c_bigint = 2^53, + // c_bigint can be 2^53 + 1. The same for -2^53 + return Optional.of(new BigIntLiteral(val)); } else { - return new BigIntLiteral(val); + return Optional.empty(); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java index 74f569f52c3b5c..3d6f7c1bb5c510 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java @@ -596,6 +596,17 @@ void testDoubleLiteral() { new LessThan(bigIntSlot, new BigIntLiteral(13L))); assertRewrite(new LessThanEqual(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), new LessThanEqual(bigIntSlot, new BigIntLiteral(12L))); + + // big int and literal near no loss bound + double noLossBound = 9007199254740992.0; + assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(-noLossBound)), + new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(-noLossBound))); + assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(-9007199254740991.0)), + new EqualTo(bigIntSlot, new BigIntLiteral(-9007199254740991L))); + assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(9007199254740991.0)), + new EqualTo(bigIntSlot, new BigIntLiteral(9007199254740991L))); + assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(noLossBound)), + new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(noLossBound))); } @Test From 590f5b432d053aa5dd5d5885772bf736bab07695 Mon Sep 17 00:00:00 2001 From: yujun Date: Thu, 11 Sep 2025 10:41:47 +0800 Subject: [PATCH 2/8] add test --- .../rules/SimplifyComparisonPredicate.java | 48 ++++++++++++------ .../doris/nereids/trees/expressions/Cast.java | 12 ++--- .../SimplifyComparisonPredicateTest.java | 25 +++++++--- ...ify_comparison_predicate_int_vs_double.out | 12 +++++ ..._comparison_predicate_int_vs_double.groovy | 49 +++++++++++++++++++ 5 files changed, 116 insertions(+), 30 deletions(-) create mode 100644 regression-test/data/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.out create mode 100644 regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.groovy diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java index 2b38ae6b4b1461..b42ea7d6f9ec1f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java @@ -80,8 +80,10 @@ public class SimplifyComparisonPredicate implements ExpressionPatternRuleFactory { public static SimplifyComparisonPredicate INSTANCE = new SimplifyComparisonPredicate(); - private static final long MAX_LONG_CAST_DOUBLE_NO_LOSS = 1L << 53; - private static final long MIN_LONG_CAST_DOUBLE_NO_LOSS = -MAX_LONG_CAST_DOUBLE_NO_LOSS; + private static final int MAX_CONTINUE_INT_TO_FLOAT_NO_LOSS = 1 << 24; + private static final int MIN_CONTINUE_INT_TO_FLOAT_NO_LOSS = -MAX_CONTINUE_INT_TO_FLOAT_NO_LOSS; + private static final long MAX_CONTINUE_LONG_TO_DOUBLE_NO_LOSS = 1L << 53; + private static final long MIN_CONTINUE_LONG_TO_DOUBLE_NO_LOSS = -MAX_CONTINUE_LONG_TO_DOUBLE_NO_LOSS; @Override public List> buildRules() { @@ -486,9 +488,9 @@ private static Expression processIntegerDecimalLiteralComparison( roundLiteralOpt = Optional.of(literal); } if (roundLiteralOpt.isPresent()) { - boolean isCastFloatLike = comparisonPredicate.right().getDataType().isFloatLikeType(); + DataType castDataType = comparisonPredicate.left().getDataType(); Optional integerLikeLiteralOpt - = convertDecimalToIntegerLikeLiteral(roundLiteralOpt.get(), isCastFloatLike); + = convertDecimalToIntegerLikeLiteral(roundLiteralOpt.get(), castDataType); if (integerLikeLiteralOpt.isPresent()) { return TypeCoercionUtils .processComparisonPredicate((ComparisonPredicate) comparisonPredicate @@ -622,7 +624,7 @@ private static Optional convertDecimalToSmallerDecimalV3Type(Compari } private static Optional convertDecimalToIntegerLikeLiteral(BigDecimal decimal, - boolean isCastFloatLike) { + DataType castDataType) { Preconditions.checkArgument(decimal.scale() <= 0 && decimal.compareTo(new BigDecimal(Long.MIN_VALUE)) >= 0 && decimal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0, @@ -633,18 +635,32 @@ private static Optional convertDecimalToIntegerLikeLiteral(B } else if (val >= Short.MIN_VALUE && val <= Short.MAX_VALUE) { return Optional.of(new SmallIntLiteral((short) val)); } else if (val >= Integer.MIN_VALUE && val <= Integer.MAX_VALUE) { - return Optional.of(new IntegerLiteral((int) val)); - } else if (!isCastFloatLike - || (val > MIN_LONG_CAST_DOUBLE_NO_LOSS && val < MAX_LONG_CAST_DOUBLE_NO_LOSS)) { - // for decimal, all long value can convert to decimal without loss of precision - // for float/double, only [-2^53, 2^53] can convert to long without loss of precision, - // but here need to exclude the boundary value, because - // cast(2^53 as double) = cast(2^53 + 1 as double) = 2^53 = MAX_LONG_CAST_DOUBLE_NO_LOSS, - // so for cast(c_bigint as double) = 2^53, we can't simplify it to c_bigint = 2^53, - // c_bigint can be 2^53 + 1. The same for -2^53 - return Optional.of(new BigIntLiteral(val)); + // in fact, type convert shouldn't have `cast(integer like as float) cmp float literal`, + // it should convert to `cast(integer like as double) cmp double literal`. + // but we still process it to be more robust + if (castDataType.isFloatType() + && (val <= MIN_CONTINUE_INT_TO_FLOAT_NO_LOSS || val >= MAX_CONTINUE_INT_TO_FLOAT_NO_LOSS)) { + // for float, only [-2^24, 2^24] can convert to int without loss of precision, + // but here need to exclude the boundary value, because + // cast(2^24 as float) = cast(2^24 + 1 as float) = 2^24 = MAX_CONTINUE_INT_TO_FLOAT_NO_LOSS, + // so for cast(c_int as float) = 2^24, we can't simplify it to c_int = 2^24, + // c_int can be 2^24 + 1. The same for -2^24 + return Optional.empty(); + } else { + return Optional.of(new IntegerLiteral((int) val)); + } } else { - return Optional.empty(); + if (castDataType.isDoubleType() + && (val <= MIN_CONTINUE_LONG_TO_DOUBLE_NO_LOSS || val >= MAX_CONTINUE_LONG_TO_DOUBLE_NO_LOSS)) { + // for double, only [-2^53, 2^53] can convert to long without loss of precision, + // but here need to exclude the boundary value, because + // cast(2^53 as double) = cast(2^53 + 1 as double) = 2^53 = MAX_CONTINUE_LONG_TO_DOUBLE_NO_LOSS, + // so for cast(c_bigint as double) = 2^53, we can't simplify it to c_bigint = 2^53, + // c_bigint can be 2^53 + 1. The same for -2^53 + return Optional.empty(); + } else { + return Optional.of(new BigIntLiteral(val)); + } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Cast.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Cast.java index 04ab287a8f9c6e..d225954dec8d47 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Cast.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Cast.java @@ -49,16 +49,12 @@ public class Cast extends Expression implements UnaryExpression, Monotonic { private final DataType targetType; - public Cast(Expression child, DataType targetType, boolean isExplicitType) { - super(ImmutableList.of(child)); - this.targetType = Objects.requireNonNull(targetType, "targetType can not be null"); - this.isExplicitType = isExplicitType; + public Cast(Expression child, DataType targetType) { + this(child, targetType, false); } - public Cast(Expression child, DataType targetType) { - super(ImmutableList.of(child)); - this.targetType = Objects.requireNonNull(targetType, "targetType can not be null"); - this.isExplicitType = false; + public Cast(Expression child, DataType targetType, boolean isExplicitType) { + this(ImmutableList.of(child), targetType, isExplicitType); } private Cast(List child, DataType targetType, boolean isExplicitType) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java index 3d6f7c1bb5c510..98a1f299d329d1 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java @@ -597,16 +597,29 @@ void testDoubleLiteral() { assertRewrite(new LessThanEqual(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), new LessThanEqual(bigIntSlot, new BigIntLiteral(12L))); - // big int and literal near no loss bound - double noLossBound = 9007199254740992.0; - assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(-noLossBound)), - new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(-noLossBound))); + // int and float literal near no loss bound + // in fact, shouldn't have cast(c_int as float) cmp float literal, it will convert to cast(c_int as double) cmp double literal + // but we still test 'cast(c_int as float) cmp float literal' here for more robustness + float noLossBoundF = 16777216.0f; // 2^24 + assertRewrite(new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(-noLossBoundF)), + new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(-noLossBoundF))); + assertRewrite(new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(-16777215.0f)), + new EqualTo(intSlot, new IntegerLiteral(-16777215))); + assertRewrite(new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(16777215.0f)), + new EqualTo(intSlot, new IntegerLiteral(16777215))); + assertRewrite(new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(noLossBoundF)), + new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(noLossBoundF))); + + // big int and double literal near no loss bound + double noLossBoundD = 9007199254740992.0; + assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(-noLossBoundD)), + new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(-noLossBoundD))); assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(-9007199254740991.0)), new EqualTo(bigIntSlot, new BigIntLiteral(-9007199254740991L))); assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(9007199254740991.0)), new EqualTo(bigIntSlot, new BigIntLiteral(9007199254740991L))); - assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(noLossBound)), - new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(noLossBound))); + assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(noLossBoundD)), + new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(noLossBoundD))); } @Test diff --git a/regression-test/data/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.out b/regression-test/data/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.out new file mode 100644 index 00000000000000..9cac5b001cb732 --- /dev/null +++ b/regression-test/data/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.out @@ -0,0 +1,12 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !cast_bigint_as_double_shape -- +PhysicalResultSink +--NestedLoopJoin[LEFT_OUTER_JOIN] +----filter((t1.c_s = '870479087484055553')) +------PhysicalOlapScan[tbl1_test_simplify_comparison_predicate_int_vs_double] +----filter((cast(c_bigint as DOUBLE) = 8.7047908748405555E17)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !cast_bigint_as_double_result -- +100 870479087484055553 200 870479087484055553 + diff --git a/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.groovy b/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.groovy new file mode 100644 index 00000000000000..b044ffbbd0c067 --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.groovy @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite('test_simplify_comparison_predicate_int_vs_double') { + sql """ + set runtime_filter_mode='OFF'; + set disable_join_reorder=false; + set ignore_shape_nodes='PhysicalDistribute'; + + drop table if exists tbl1_test_simplify_comparison_predicate_int_vs_double force; + drop table if exists tbl2_test_simplify_comparison_predicate_int_vs_double force; + + create table tbl1_test_simplify_comparison_predicate_int_vs_double + (k1 int, c_s varchar(100)) properties('replication_num' = '1'); + create table tbl2_test_simplify_comparison_predicate_int_vs_double + (k2 int, c_bigint bigint) properties('replication_num' = '1'); + + insert into tbl1_test_simplify_comparison_predicate_int_vs_double values(100, "870479087484055553"), (101, "870479087484055554"); + insert into tbl2_test_simplify_comparison_predicate_int_vs_double values(200, 870479087484055553); + """ + + explainAndOrderResult 'cast_bigint_as_double', """ + select * + from tbl1_test_simplify_comparison_predicate_int_vs_double t1 + left join tbl2_test_simplify_comparison_predicate_int_vs_double t2 + on t1.c_s = t2.c_bigint + where t1.c_s = '870479087484055553' + """ + + sql """ + drop table if exists tbl1_test_simplify_comparison_predicate_int_vs_double force; + drop table if exists tbl2_test_simplify_comparison_predicate_int_vs_double force; + """ +} From 95eb5830363f9ccd9f0fe7d465b1408ee17a1e3e Mon Sep 17 00:00:00 2001 From: yujun Date: Thu, 11 Sep 2025 11:02:02 +0800 Subject: [PATCH 3/8] udpate --- .../expression/rules/SimplifyComparisonPredicate.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java index b42ea7d6f9ec1f..30147601ccc2e9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java @@ -635,7 +635,10 @@ private static Optional convertDecimalToIntegerLikeLiteral(B } else if (val >= Short.MIN_VALUE && val <= Short.MAX_VALUE) { return Optional.of(new SmallIntLiteral((short) val)); } else if (val >= Integer.MIN_VALUE && val <= Integer.MAX_VALUE) { - // in fact, type convert shouldn't have `cast(integer like as float) cmp float literal`, + // double/decimal can represent all int value without loss of precision, + // but float can't represent all int value without loss of precision. + // need to handle the float case specially. + // but notice that, in fact, type convert shouldn't have `cast(integer like as float) cmp float literal`, // it should convert to `cast(integer like as double) cmp double literal`. // but we still process it to be more robust if (castDataType.isFloatType() @@ -650,7 +653,9 @@ private static Optional convertDecimalToIntegerLikeLiteral(B return Optional.of(new IntegerLiteral((int) val)); } } else { - if (castDataType.isDoubleType() + // decimal can represent all long value without loss of precision, + // but float/double can't represent all long value without loss of precision. + if (castDataType.isFloatLikeType() && (val <= MIN_CONTINUE_LONG_TO_DOUBLE_NO_LOSS || val >= MAX_CONTINUE_LONG_TO_DOUBLE_NO_LOSS)) { // for double, only [-2^53, 2^53] can convert to long without loss of precision, // but here need to exclude the boundary value, because From c4209d4eb37eccaa4a678b778465c66b105c4888 Mon Sep 17 00:00:00 2001 From: yujun Date: Thu, 11 Sep 2025 11:37:36 +0800 Subject: [PATCH 4/8] rename variable --- .../rules/SimplifyComparisonPredicate.java | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java index 30147601ccc2e9..eb8fc60cfa9c79 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java @@ -80,10 +80,10 @@ public class SimplifyComparisonPredicate implements ExpressionPatternRuleFactory { public static SimplifyComparisonPredicate INSTANCE = new SimplifyComparisonPredicate(); - private static final int MAX_CONTINUE_INT_TO_FLOAT_NO_LOSS = 1 << 24; - private static final int MIN_CONTINUE_INT_TO_FLOAT_NO_LOSS = -MAX_CONTINUE_INT_TO_FLOAT_NO_LOSS; - private static final long MAX_CONTINUE_LONG_TO_DOUBLE_NO_LOSS = 1L << 53; - private static final long MIN_CONTINUE_LONG_TO_DOUBLE_NO_LOSS = -MAX_CONTINUE_LONG_TO_DOUBLE_NO_LOSS; + private static final int MAX_INT_TO_FLOAT_NO_LOSS = 1 << 24; + private static final int MIN_INT_TO_FLOAT_NO_LOSS = -MAX_INT_TO_FLOAT_NO_LOSS; + private static final long MAX_LONG_TO_DOUBLE_NO_LOSS = 1L << 53; + private static final long MIN_LONG_TO_DOUBLE_NO_LOSS = -MAX_LONG_TO_DOUBLE_NO_LOSS; @Override public List> buildRules() { @@ -642,10 +642,10 @@ private static Optional convertDecimalToIntegerLikeLiteral(B // it should convert to `cast(integer like as double) cmp double literal`. // but we still process it to be more robust if (castDataType.isFloatType() - && (val <= MIN_CONTINUE_INT_TO_FLOAT_NO_LOSS || val >= MAX_CONTINUE_INT_TO_FLOAT_NO_LOSS)) { + && (val <= MIN_INT_TO_FLOAT_NO_LOSS || val >= MAX_INT_TO_FLOAT_NO_LOSS)) { // for float, only [-2^24, 2^24] can convert to int without loss of precision, // but here need to exclude the boundary value, because - // cast(2^24 as float) = cast(2^24 + 1 as float) = 2^24 = MAX_CONTINUE_INT_TO_FLOAT_NO_LOSS, + // cast(2^24 as float) = cast(2^24 + 1 as float) = 2^24 = MAX_INT_TO_FLOAT_NO_LOSS, // so for cast(c_int as float) = 2^24, we can't simplify it to c_int = 2^24, // c_int can be 2^24 + 1. The same for -2^24 return Optional.empty(); @@ -656,10 +656,10 @@ private static Optional convertDecimalToIntegerLikeLiteral(B // decimal can represent all long value without loss of precision, // but float/double can't represent all long value without loss of precision. if (castDataType.isFloatLikeType() - && (val <= MIN_CONTINUE_LONG_TO_DOUBLE_NO_LOSS || val >= MAX_CONTINUE_LONG_TO_DOUBLE_NO_LOSS)) { + && (val <= MIN_LONG_TO_DOUBLE_NO_LOSS || val >= MAX_LONG_TO_DOUBLE_NO_LOSS)) { // for double, only [-2^53, 2^53] can convert to long without loss of precision, // but here need to exclude the boundary value, because - // cast(2^53 as double) = cast(2^53 + 1 as double) = 2^53 = MAX_CONTINUE_LONG_TO_DOUBLE_NO_LOSS, + // cast(2^53 as double) = cast(2^53 + 1 as double) = 2^53 = MAX_LONG_TO_DOUBLE_NO_LOSS, // so for cast(c_bigint as double) = 2^53, we can't simplify it to c_bigint = 2^53, // c_bigint can be 2^53 + 1. The same for -2^53 return Optional.empty(); From 7fcaa33f9098207efdcfd02d1df752bc4eaebeea Mon Sep 17 00:00:00 2001 From: yujun Date: Thu, 11 Sep 2025 16:08:24 +0800 Subject: [PATCH 5/8] add test --- .../rules/SimplifyComparisonPredicate.java | 8 +- .../SimplifyComparisonPredicateTest.java | 2 - ...ify_comparison_predicate_int_vs_double.out | 108 ++++++++++++++++++ ..._comparison_predicate_int_vs_double.groovy | 46 +++++++- 4 files changed, 155 insertions(+), 9 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java index eb8fc60cfa9c79..9e12dcc5b16bbd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java @@ -638,9 +638,6 @@ private static Optional convertDecimalToIntegerLikeLiteral(B // double/decimal can represent all int value without loss of precision, // but float can't represent all int value without loss of precision. // need to handle the float case specially. - // but notice that, in fact, type convert shouldn't have `cast(integer like as float) cmp float literal`, - // it should convert to `cast(integer like as double) cmp double literal`. - // but we still process it to be more robust if (castDataType.isFloatType() && (val <= MIN_INT_TO_FLOAT_NO_LOSS || val >= MAX_INT_TO_FLOAT_NO_LOSS)) { // for float, only [-2^24, 2^24] can convert to int without loss of precision, @@ -655,8 +652,9 @@ private static Optional convertDecimalToIntegerLikeLiteral(B } else { // decimal can represent all long value without loss of precision, // but float/double can't represent all long value without loss of precision. - if (castDataType.isFloatLikeType() - && (val <= MIN_LONG_TO_DOUBLE_NO_LOSS || val >= MAX_LONG_TO_DOUBLE_NO_LOSS)) { + if (castDataType.isFloatType() + || (castDataType.isDoubleType() + && (val <= MIN_LONG_TO_DOUBLE_NO_LOSS || val >= MAX_LONG_TO_DOUBLE_NO_LOSS))) { // for double, only [-2^53, 2^53] can convert to long without loss of precision, // but here need to exclude the boundary value, because // cast(2^53 as double) = cast(2^53 + 1 as double) = 2^53 = MAX_LONG_TO_DOUBLE_NO_LOSS, diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java index 98a1f299d329d1..10cb088208aca0 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java @@ -598,8 +598,6 @@ void testDoubleLiteral() { new LessThanEqual(bigIntSlot, new BigIntLiteral(12L))); // int and float literal near no loss bound - // in fact, shouldn't have cast(c_int as float) cmp float literal, it will convert to cast(c_int as double) cmp double literal - // but we still test 'cast(c_int as float) cmp float literal' here for more robustness float noLossBoundF = 16777216.0f; // 2^24 assertRewrite(new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(-noLossBoundF)), new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(-noLossBoundF))); diff --git a/regression-test/data/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.out b/regression-test/data/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.out index 9cac5b001cb732..dd4dccdf839cae 100644 --- a/regression-test/data/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.out +++ b/regression-test/data/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.out @@ -10,3 +10,111 @@ PhysicalResultSink -- !cast_bigint_as_double_result -- 100 870479087484055553 200 870479087484055553 +-- !tbl3_float_neg16777216_neg10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = -1.6777226E7)) +------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] + +-- !tbl3_float_neg16777216_0 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = -1.6777216E7)) +------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] + +-- !tbl3_float_neg16777216_10 -- +PhysicalResultSink +--PhysicalProject +----filter((tbl3_test_simplify_comparison_predicate_int_vs_double.c_bigint = -16777206)) +------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] + +-- !tbl3_float_16777216_neg10 -- +PhysicalResultSink +--PhysicalProject +----filter((tbl3_test_simplify_comparison_predicate_int_vs_double.c_bigint = 16777206)) +------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] + +-- !tbl3_float_16777216_0 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = 1.6777216E7)) +------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] + +-- !tbl3_float_16777216_10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = 1.6777226E7)) +------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] + +-- !tbl3_double_neg9007199254740992_neg10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as DOUBLE) = -9.007199254741002E15)) +------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] + +-- !tbl3_float_neg9007199254740992_neg10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = -9.0071993E15)) +------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] + +-- !tbl3_double_neg9007199254740992_0 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as DOUBLE) = -9.007199254740992E15)) +------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] + +-- !tbl3_float_neg9007199254740992_0 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = -9.0071993E15)) +------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] + +-- !tbl3_double_neg9007199254740992_10 -- +PhysicalResultSink +--PhysicalProject +----filter((tbl3_test_simplify_comparison_predicate_int_vs_double.c_bigint = -9007199254740982)) +------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] + +-- !tbl3_float_neg9007199254740992_10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = -9.0071993E15)) +------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] + +-- !tbl3_double_9007199254740992_neg10 -- +PhysicalResultSink +--PhysicalProject +----filter((tbl3_test_simplify_comparison_predicate_int_vs_double.c_bigint = 9007199254740982)) +------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] + +-- !tbl3_float_9007199254740992_neg10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = 9.0071993E15)) +------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] + +-- !tbl3_double_9007199254740992_0 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as DOUBLE) = 9.007199254740992E15)) +------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] + +-- !tbl3_float_9007199254740992_0 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = 9.0071993E15)) +------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] + +-- !tbl3_double_9007199254740992_10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as DOUBLE) = 9.007199254741002E15)) +------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] + +-- !tbl3_float_9007199254740992_10 -- +PhysicalResultSink +--PhysicalProject +----filter((cast(c_bigint as FLOAT) = 9.0071993E15)) +------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] + diff --git a/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.groovy b/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.groovy index b044ffbbd0c067..26bfd61fbc72d3 100644 --- a/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.groovy +++ b/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.groovy @@ -24,16 +24,25 @@ suite('test_simplify_comparison_predicate_int_vs_double') { drop table if exists tbl1_test_simplify_comparison_predicate_int_vs_double force; drop table if exists tbl2_test_simplify_comparison_predicate_int_vs_double force; + drop table if exists tbl3_test_simplify_comparison_predicate_int_vs_double force; create table tbl1_test_simplify_comparison_predicate_int_vs_double (k1 int, c_s varchar(100)) properties('replication_num' = '1'); create table tbl2_test_simplify_comparison_predicate_int_vs_double (k2 int, c_bigint bigint) properties('replication_num' = '1'); + create table tbl3_test_simplify_comparison_predicate_int_vs_double + (k3 int, c_bigint bigint) properties('replication_num' = '1'); - insert into tbl1_test_simplify_comparison_predicate_int_vs_double values(100, "870479087484055553"), (101, "870479087484055554"); - insert into tbl2_test_simplify_comparison_predicate_int_vs_double values(200, 870479087484055553); + insert into tbl1_test_simplify_comparison_predicate_int_vs_double values + (100, "870479087484055553"), + (101,"870479087484055554"); + insert into tbl2_test_simplify_comparison_predicate_int_vs_double values + (200, 870479087484055553); + insert into tbl3_test_simplify_comparison_predicate_int_vs_double values + (300, 870479087484055553); """ + explainAndOrderResult 'cast_bigint_as_double', """ select * from tbl1_test_simplify_comparison_predicate_int_vs_double t1 @@ -42,8 +51,41 @@ suite('test_simplify_comparison_predicate_int_vs_double') { where t1.c_s = '870479087484055553' """ + for (def delimit : [-(1L<<24), 1L<<24]) { + for (def diff : [-10, 0, 10]) { + def tag = "tbl3_float_${delimit}_${diff}".replace('-', 'neg') + "qt_${tag}" """ + explain shape plan + select c_bigint + from tbl3_test_simplify_comparison_predicate_int_vs_double + where cast(c_bigint as float) = cast('${delimit + diff}' as float) + """ + } + } + + for (def delimit : [-(1L<<53), 1L<<53]) { + for (def diff : [-10, 0, 10]) { + def tag = "tbl3_double_${delimit}_${diff}".replace('-', 'neg') + "qt_${tag}" """ + explain shape plan + select c_bigint + from tbl3_test_simplify_comparison_predicate_int_vs_double + where cast(c_bigint as double) = cast('${delimit + diff}' as double) + """ + + tag = "tbl3_float_${delimit}_${diff}".replace('-', 'neg') + "qt_${tag}" """ + explain shape plan + select c_bigint + from tbl3_test_simplify_comparison_predicate_int_vs_double + where cast(c_bigint as float) = cast('${delimit + diff}' as float) + """ + } + } + sql """ drop table if exists tbl1_test_simplify_comparison_predicate_int_vs_double force; drop table if exists tbl2_test_simplify_comparison_predicate_int_vs_double force; + drop table if exists tbl3_test_simplify_comparison_predicate_int_vs_double force; """ } From 0a4fd98ec25107750eb692891a0e9661cb14aef6 Mon Sep 17 00:00:00 2001 From: yujun Date: Thu, 11 Sep 2025 16:15:15 +0800 Subject: [PATCH 6/8] update test --- .../rules/SimplifyComparisonPredicate.java | 48 ++++++++----------- 1 file changed, 20 insertions(+), 28 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java index 9e12dcc5b16bbd..c4343c01c0797d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java @@ -630,40 +630,32 @@ private static Optional convertDecimalToIntegerLikeLiteral(B && decimal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0, "decimal literal must have 0 scale and in range [Long.MIN_VALUE, Long.MAX_VALUE]"); long val = decimal.longValue(); + // for integer like convert to float, only [-2^24, 2^24] can convert to float without loss of precision, + // but here need to exclude the boundary value, because + // cast(2^24 as float) = cast(2^24 + 1 as float) = 2^24 = MAX_INT_TO_FLOAT_NO_LOSS, + // so for cast(c_int as float) = 2^24, we can't simplify it to c_int = 2^24, + // c_int can be 2^24 + 1. The same for -2^24 + if (castDataType.isFloatType() + && (val <= MIN_INT_TO_FLOAT_NO_LOSS || val >= MAX_INT_TO_FLOAT_NO_LOSS)) { + return Optional.empty(); + } + // for long convert to double, only [-2^53, 2^53] can convert to double without loss of precision, + // but here need to exclude the boundary value, because + // cast(2^53 as double) = cast(2^53 + 1 as double) = 2^53 = MAX_LONG_TO_DOUBLE_NO_LOSS, + // so for cast(c_bigint as double) = 2^53, we can't simplify it to c_bigint = 2^53, + // c_bigint can be 2^53 + 1. The same for -2^53 + if (castDataType.isDoubleType() + && (val <= MIN_LONG_TO_DOUBLE_NO_LOSS || val >= MAX_LONG_TO_DOUBLE_NO_LOSS)) { + return Optional.empty(); + } if (val >= Byte.MIN_VALUE && val <= Byte.MAX_VALUE) { return Optional.of(new TinyIntLiteral((byte) val)); } else if (val >= Short.MIN_VALUE && val <= Short.MAX_VALUE) { return Optional.of(new SmallIntLiteral((short) val)); } else if (val >= Integer.MIN_VALUE && val <= Integer.MAX_VALUE) { - // double/decimal can represent all int value without loss of precision, - // but float can't represent all int value without loss of precision. - // need to handle the float case specially. - if (castDataType.isFloatType() - && (val <= MIN_INT_TO_FLOAT_NO_LOSS || val >= MAX_INT_TO_FLOAT_NO_LOSS)) { - // for float, only [-2^24, 2^24] can convert to int without loss of precision, - // but here need to exclude the boundary value, because - // cast(2^24 as float) = cast(2^24 + 1 as float) = 2^24 = MAX_INT_TO_FLOAT_NO_LOSS, - // so for cast(c_int as float) = 2^24, we can't simplify it to c_int = 2^24, - // c_int can be 2^24 + 1. The same for -2^24 - return Optional.empty(); - } else { - return Optional.of(new IntegerLiteral((int) val)); - } + return Optional.of(new IntegerLiteral((int) val)); } else { - // decimal can represent all long value without loss of precision, - // but float/double can't represent all long value without loss of precision. - if (castDataType.isFloatType() - || (castDataType.isDoubleType() - && (val <= MIN_LONG_TO_DOUBLE_NO_LOSS || val >= MAX_LONG_TO_DOUBLE_NO_LOSS))) { - // for double, only [-2^53, 2^53] can convert to long without loss of precision, - // but here need to exclude the boundary value, because - // cast(2^53 as double) = cast(2^53 + 1 as double) = 2^53 = MAX_LONG_TO_DOUBLE_NO_LOSS, - // so for cast(c_bigint as double) = 2^53, we can't simplify it to c_bigint = 2^53, - // c_bigint can be 2^53 + 1. The same for -2^53 - return Optional.empty(); - } else { - return Optional.of(new BigIntLiteral(val)); - } + return Optional.of(new BigIntLiteral(val)); } } From 30e986d45051428652c4ebc3f57ca681602d12e4 Mon Sep 17 00:00:00 2001 From: yujun Date: Fri, 12 Sep 2025 11:47:40 +0800 Subject: [PATCH 7/8] add test --- .../rules/SimplifyComparisonPredicate.java | 8 +++---- .../SimplifyComparisonPredicateTest.java | 21 +++++++++++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java index c4343c01c0797d..9c8b54bb580f39 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java @@ -80,10 +80,10 @@ public class SimplifyComparisonPredicate implements ExpressionPatternRuleFactory { public static SimplifyComparisonPredicate INSTANCE = new SimplifyComparisonPredicate(); - private static final int MAX_INT_TO_FLOAT_NO_LOSS = 1 << 24; - private static final int MIN_INT_TO_FLOAT_NO_LOSS = -MAX_INT_TO_FLOAT_NO_LOSS; - private static final long MAX_LONG_TO_DOUBLE_NO_LOSS = 1L << 53; - private static final long MIN_LONG_TO_DOUBLE_NO_LOSS = -MAX_LONG_TO_DOUBLE_NO_LOSS; + public static final int MAX_INT_TO_FLOAT_NO_LOSS = 1 << 24; + public static final int MIN_INT_TO_FLOAT_NO_LOSS = -MAX_INT_TO_FLOAT_NO_LOSS; + public static final long MAX_LONG_TO_DOUBLE_NO_LOSS = 1L << 53; + public static final long MIN_LONG_TO_DOUBLE_NO_LOSS = -MAX_LONG_TO_DOUBLE_NO_LOSS; @Override public List> buildRules() { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java index 10cb088208aca0..d78f594767d816 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java @@ -620,6 +620,27 @@ void testDoubleLiteral() { new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(noLossBoundD))); } + @Test + void testFloatNoLossBound() { + checkIntConvertFloatLikeLossBound(SimplifyComparisonPredicate.MIN_INT_TO_FLOAT_NO_LOSS, true); + checkIntConvertFloatLikeLossBound(SimplifyComparisonPredicate.MAX_INT_TO_FLOAT_NO_LOSS, true); + checkIntConvertFloatLikeLossBound(SimplifyComparisonPredicate.MIN_LONG_TO_DOUBLE_NO_LOSS, false); + checkIntConvertFloatLikeLossBound(SimplifyComparisonPredicate.MAX_LONG_TO_DOUBLE_NO_LOSS, false); + } + + private void checkIntConvertFloatLikeLossBound(long bound, boolean isFloat) { + for (int i = 0; i < 100000; i++) { + long v = (Math.abs(bound) - i) * Long.signum(bound); + long vCast = isFloat ? (long) ((float) v) : (long) ((double) v); + Assertions.assertEquals(v, vCast); + } + + long firstOutBound = (Math.abs(bound) + 1) * Long.signum(bound); + long firstOutBoundCast = isFloat ? (long) ((float) firstOutBound) : (long) ((double) firstOutBound); + Assertions.assertNotEquals(firstOutBound, firstOutBoundCast); + Assertions.assertEquals(bound, firstOutBoundCast); + } + @Test void testIntCmpDecimalV3Literal() { executor = new ExpressionRuleExecutor(ImmutableList.of( From 6ec3c785d9aea6fb304afa26cce6c631f7018561 Mon Sep 17 00:00:00 2001 From: yujun Date: Fri, 12 Sep 2025 19:34:05 +0800 Subject: [PATCH 8/8] handle with compare with decimal --- .../rules/SimplifyComparisonPredicate.java | 23 ++- .../doris/nereids/types/BigIntType.java | 5 + .../doris/nereids/types/IntegerType.java | 5 + .../doris/nereids/types/LargeIntType.java | 5 + .../doris/nereids/types/SmallIntType.java | 5 + .../doris/nereids/types/TinyIntType.java | 5 + .../nereids/types/coercion/IntegralType.java | 7 + .../SimplifyComparisonPredicateTest.java | 103 +++++++------ ...ify_comparison_predicate_int_vs_double.out | 135 ++++++++++++------ ..._comparison_predicate_int_vs_double.groovy | 98 +++++++++++-- 10 files changed, 291 insertions(+), 100 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java index 9c8b54bb580f39..7c3c4d6c4c615e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java @@ -60,6 +60,7 @@ import org.apache.doris.nereids.types.SmallIntType; import org.apache.doris.nereids.types.TinyIntType; import org.apache.doris.nereids.types.coercion.DateLikeType; +import org.apache.doris.nereids.types.coercion.IntegralType; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.TypeCoercionUtils; @@ -405,19 +406,25 @@ private static Expression processDecimalV3TypeCoercion(ComparisonPredicate compa if (left instanceof Cast && right instanceof DecimalV3Literal) { Cast cast = (Cast) left; left = cast.child(); + DecimalV3Type castDataType = (DecimalV3Type) cast.getDataType(); DecimalV3Literal literal = (DecimalV3Literal) right; if (left.getDataType().isDecimalV3Type()) { + DecimalV3Type leftType = (DecimalV3Type) left.getDataType(); + if (castDataType.getRange() < leftType.getRange() + || (castDataType.getRange() == leftType.getRange() + && castDataType.getScale() < leftType.getScale())) { + // for cast(col as decimal(m2, n2)) cmp literal, + // if cast-to can not hold col's integer part, the cast result maybe null, don't process it. + return comparisonPredicate; + } Optional toSmallerDecimalDataTypeExpr = convertDecimalToSmallerDecimalV3Type( comparisonPredicate, cast, literal); if (toSmallerDecimalDataTypeExpr.isPresent()) { return toSmallerDecimalDataTypeExpr.get(); } - DecimalV3Type leftType = (DecimalV3Type) left.getDataType(); DecimalV3Type literalType = (DecimalV3Type) literal.getDataType(); - if (cast.getDataType().isDecimalV3Type() - && ((DecimalV3Type) cast.getDataType()).getScale() >= leftType.getScale() - && leftType.getScale() < literalType.getScale()) { + if (castDataType.getScale() >= leftType.getScale() && leftType.getScale() < literalType.getScale()) { int toScale = ((DecimalV3Type) left.getDataType()).getScale(); if (comparisonPredicate instanceof EqualTo) { try { @@ -462,6 +469,13 @@ private static Expression processDecimalV3TypeCoercion(ComparisonPredicate compa private static Expression processIntegerDecimalLiteralComparison( ComparisonPredicate comparisonPredicate, Expression left, BigDecimal literal) { // we only process isIntegerLikeType, which are tinyint, smallint, int, bigint + // for `cast(c_int as decimal(m, n)) cmp literal`, + // if c_int's range is wider than decimal's range, cast result maybe null, don't process it. + DataType castDataType = comparisonPredicate.left().getDataType(); + if (castDataType.isDecimalV3Type() + && ((DecimalV3Type) castDataType).getRange() < ((IntegralType) left.getDataType()).range()) { + return comparisonPredicate; + } if (literal.compareTo(new BigDecimal(Long.MIN_VALUE)) >= 0 && literal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0) { literal = literal.stripTrailingZeros(); @@ -488,7 +502,6 @@ private static Expression processIntegerDecimalLiteralComparison( roundLiteralOpt = Optional.of(literal); } if (roundLiteralOpt.isPresent()) { - DataType castDataType = comparisonPredicate.left().getDataType(); Optional integerLikeLiteralOpt = convertDecimalToIntegerLikeLiteral(roundLiteralOpt.get(), castDataType); if (integerLikeLiteralOpt.isPresent()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BigIntType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BigIntType.java index 767db0e89a9d32..f654b9f271ba81 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BigIntType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BigIntType.java @@ -62,4 +62,9 @@ public DataType defaultConcreteType() { public int width() { return WIDTH; } + + @Override + public int range() { + return RANGE; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/IntegerType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/IntegerType.java index 2b22c167d464f0..7b47ff90b72270 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/IntegerType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/IntegerType.java @@ -62,4 +62,9 @@ public DataType defaultConcreteType() { public int width() { return WIDTH; } + + @Override + public int range() { + return RANGE; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/LargeIntType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/LargeIntType.java index 78a9369bd26830..22e88206c6abd3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/LargeIntType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/LargeIntType.java @@ -71,4 +71,9 @@ public DataType defaultConcreteType() { public int width() { return WIDTH; } + + @Override + public int range() { + return RANGE; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/SmallIntType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/SmallIntType.java index 4272052cd1a938..ad2743a4a7e19a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/SmallIntType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/SmallIntType.java @@ -62,4 +62,9 @@ public DataType defaultConcreteType() { public int width() { return WIDTH; } + + @Override + public int range() { + return RANGE; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TinyIntType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TinyIntType.java index 30259582127e88..075cb945db2479 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TinyIntType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TinyIntType.java @@ -62,4 +62,9 @@ public DataType defaultConcreteType() { public int width() { return WIDTH; } + + @Override + public int range() { + return RANGE; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java index 656c6f660f2090..b1e588053881eb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java @@ -20,6 +20,8 @@ import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.DataType; +import org.apache.commons.lang3.NotImplementedException; + /** * Abstract class for all integral data type in Nereids. */ @@ -45,4 +47,9 @@ public String simpleString() { public boolean widerThan(IntegralType other) { return this.width() > other.width(); } + + // The maximum number of digits that Integer can represent. + public int range() { + throw new NotImplementedException("should be implemented by derived class"); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java index d78f594767d816..b5ad2dda3d443b 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java @@ -653,95 +653,110 @@ void testIntCmpDecimalV3Literal() { Expression bigIntSlot = new SlotReference("a", BigIntType.INSTANCE); // tiny int, literal not exceeds data type limit - assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))), + DecimalV3Type tinyDecimalType = DecimalV3Type.createDecimalV3Type(4, 1); + assertRewrite(new EqualTo(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(tinyDecimalType, new BigDecimal("12.0"))), new EqualTo(tinyIntSlot, new TinyIntLiteral((byte) 12))); - assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new EqualTo(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(tinyDecimalType, new BigDecimal("12.3"))), ExpressionUtils.falseOrNull(tinyIntSlot)); - assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(tinyDecimalType, new BigDecimal("12.3"))), BooleanLiteral.FALSE); - assertRewrite(new GreaterThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThan(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(tinyDecimalType, new BigDecimal("12.3"))), new GreaterThan(tinyIntSlot, new TinyIntLiteral((byte) 12))); - assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(tinyDecimalType, new BigDecimal("12.3"))), new GreaterThanEqual(tinyIntSlot, new TinyIntLiteral((byte) 13))); - assertRewrite(new LessThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThan(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(tinyDecimalType, new BigDecimal("12.3"))), new LessThan(tinyIntSlot, new TinyIntLiteral((byte) 13))); - assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(tinyDecimalType, new BigDecimal("12.3"))), new LessThanEqual(tinyIntSlot, new TinyIntLiteral((byte) 12))); // tiny int, literal exceeds data type limit - assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.0"))), + assertRewrite(new EqualTo(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(new BigDecimal("200.0"))), ExpressionUtils.falseOrNull(tinyIntSlot)); - assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))), + assertRewrite(new EqualTo(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(new BigDecimal("200.3"))), ExpressionUtils.falseOrNull(tinyIntSlot)); - assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))), + assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(new BigDecimal("200.3"))), BooleanLiteral.FALSE); - assertRewrite(new GreaterThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))), + assertRewrite(new GreaterThan(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(new BigDecimal("200.3"))), ExpressionUtils.falseOrNull(tinyIntSlot)); - assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))), + assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(new BigDecimal("200.3"))), ExpressionUtils.falseOrNull(tinyIntSlot)); - assertRewrite(new LessThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))), + assertRewrite(new LessThan(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(new BigDecimal("200.3"))), ExpressionUtils.trueOrNull(tinyIntSlot)); - assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))), + assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, tinyDecimalType), new DecimalV3Literal(new BigDecimal("200.3"))), ExpressionUtils.trueOrNull(tinyIntSlot)); // small int - assertRewrite(new EqualTo(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))), + DecimalV3Type smallDecimalType = DecimalV3Type.createDecimalV3Type(6, 1); + assertRewrite(new EqualTo(new Cast(smallIntSlot, smallDecimalType), new DecimalV3Literal(smallDecimalType, new BigDecimal("12.0"))), new EqualTo(smallIntSlot, new SmallIntLiteral((short) 12))); - assertRewrite(new EqualTo(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new EqualTo(new Cast(smallIntSlot, smallDecimalType), new DecimalV3Literal(smallDecimalType, new BigDecimal("12.3"))), ExpressionUtils.falseOrNull(smallIntSlot)); - assertRewrite(new NullSafeEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new NullSafeEqual(new Cast(smallIntSlot, smallDecimalType), new DecimalV3Literal(smallDecimalType, new BigDecimal("12.3"))), BooleanLiteral.FALSE); - assertRewrite(new GreaterThan(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThan(new Cast(smallIntSlot, smallDecimalType), new DecimalV3Literal(smallDecimalType, new BigDecimal("12.3"))), new GreaterThan(smallIntSlot, new SmallIntLiteral((short) 12))); - assertRewrite(new GreaterThanEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThanEqual(new Cast(smallIntSlot, smallDecimalType), new DecimalV3Literal(smallDecimalType, new BigDecimal("12.3"))), new GreaterThanEqual(smallIntSlot, new SmallIntLiteral((short) 13))); - assertRewrite(new LessThan(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThan(new Cast(smallIntSlot, smallDecimalType), new DecimalV3Literal(smallDecimalType, new BigDecimal("12.3"))), new LessThan(smallIntSlot, new SmallIntLiteral((short) 13))); - assertRewrite(new LessThanEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThanEqual(new Cast(smallIntSlot, smallDecimalType), new DecimalV3Literal(smallDecimalType, new BigDecimal("12.3"))), new LessThanEqual(smallIntSlot, new SmallIntLiteral((short) 12))); // int - assertRewrite(new EqualTo(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))), + DecimalV3Type intDecimalType = DecimalV3Type.createDecimalV3Type(11, 1); + assertRewrite(new EqualTo(new Cast(intSlot, intDecimalType), new DecimalV3Literal(intDecimalType, new BigDecimal("12.0"))), new EqualTo(intSlot, new IntegerLiteral(12))); - assertRewrite(new EqualTo(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new EqualTo(new Cast(intSlot, intDecimalType), new DecimalV3Literal(intDecimalType, new BigDecimal("12.3"))), ExpressionUtils.falseOrNull(intSlot)); - assertRewrite(new NullSafeEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new NullSafeEqual(new Cast(intSlot, intDecimalType), new DecimalV3Literal(intDecimalType, new BigDecimal("12.3"))), BooleanLiteral.FALSE); - assertRewrite(new GreaterThan(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThan(new Cast(intSlot, intDecimalType), new DecimalV3Literal(intDecimalType, new BigDecimal("12.3"))), new GreaterThan(intSlot, new IntegerLiteral(12))); - assertRewrite(new GreaterThanEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThanEqual(new Cast(intSlot, intDecimalType), new DecimalV3Literal(intDecimalType, new BigDecimal("12.3"))), new GreaterThanEqual(intSlot, new IntegerLiteral(13))); - assertRewrite(new LessThan(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThan(new Cast(intSlot, intDecimalType), new DecimalV3Literal(intDecimalType, new BigDecimal("12.3"))), new LessThan(intSlot, new IntegerLiteral(13))); - assertRewrite(new LessThanEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThanEqual(new Cast(intSlot, intDecimalType), new DecimalV3Literal(intDecimalType, new BigDecimal("12.3"))), new LessThanEqual(intSlot, new IntegerLiteral(12))); // big int - assertRewrite(new EqualTo(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))), + DecimalV3Type bigDecimalType = DecimalV3Type.createDecimalV3Type(20, 1); + assertRewrite(new EqualTo(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(bigDecimalType, new BigDecimal("12.0"))), new EqualTo(bigIntSlot, new BigIntLiteral(12L))); - assertRewrite(new EqualTo(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new EqualTo(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(bigDecimalType, new BigDecimal("12.3"))), ExpressionUtils.falseOrNull(bigIntSlot)); - assertRewrite(new NullSafeEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new NullSafeEqual(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(bigDecimalType, new BigDecimal("12.3"))), BooleanLiteral.FALSE); - assertRewrite(new GreaterThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThan(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(bigDecimalType, new BigDecimal("12.3"))), new GreaterThan(bigIntSlot, new BigIntLiteral(12L))); - assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(bigDecimalType, new BigDecimal("12.3"))), new GreaterThanEqual(bigIntSlot, new BigIntLiteral(13L))); - assertRewrite(new LessThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThan(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(bigDecimalType, new BigDecimal("12.3"))), new LessThan(bigIntSlot, new BigIntLiteral(13L))); - assertRewrite(new LessThanEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThanEqual(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(bigDecimalType, new BigDecimal("12.3"))), new LessThanEqual(bigIntSlot, new BigIntLiteral(12L))); - assertRewrite(new LessThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(20, 1)), new DecimalV3Literal(new BigDecimal("-9223372036854775808.1"))), + assertRewrite(new LessThan(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(new BigDecimal("-9223372036854775808.1"))), ExpressionUtils.falseOrNull(bigIntSlot)); - assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(20, 1)), new DecimalV3Literal(new BigDecimal("-9223372036854775808.1"))), + assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(new BigDecimal("-9223372036854775808.1"))), ExpressionUtils.trueOrNull(bigIntSlot)); - assertRewrite(new LessThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(20, 1)), new DecimalV3Literal(new BigDecimal("-9223372036854775807.1"))), + assertRewrite(new LessThan(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(new BigDecimal("-9223372036854775807.1"))), new LessThan(bigIntSlot, new BigIntLiteral(-9223372036854775807L))); - assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(20, 1)), new DecimalV3Literal(new BigDecimal("9223372036854775807.1"))), + assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(new BigDecimal("9223372036854775807.1"))), ExpressionUtils.falseOrNull(bigIntSlot)); - assertRewrite(new LessThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(20, 1)), new DecimalV3Literal(new BigDecimal("9223372036854775806.1"))), + assertRewrite(new LessThan(new Cast(bigIntSlot, bigDecimalType), new DecimalV3Literal(new BigDecimal("9223372036854775806.1"))), new LessThan(bigIntSlot, new BigIntLiteral(9223372036854775807L))); + + // do not convert cast(int as decimal) if decimal's range < int's range + DecimalV3Type noConvertDecimalType = DecimalV3Type.createDecimalV3Type(3, 1); + assertRewrite(new EqualTo(new Cast(tinyIntSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0"))), + new EqualTo(new Cast(tinyIntSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0")))); + assertRewrite(new EqualTo(new Cast(smallIntSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0"))), + new EqualTo(new Cast(smallIntSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0")))); + assertRewrite(new EqualTo(new Cast(intSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0"))), + new EqualTo(new Cast(intSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0")))); + assertRewrite(new EqualTo(new Cast(bigIntSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0"))), + new EqualTo(new Cast(bigIntSlot, noConvertDecimalType), new DecimalV3Literal(new BigDecimal("12.0")))); } @Test @@ -904,6 +919,14 @@ void testDecimalCmpDecimalV3Literal() { Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(1)); Assertions.assertEquals(new BigDecimal("1.20000"), ((DecimalV3Literal) rewrittenExpression.child(1)).getValue()); + + // don't convert for cast(b as decimal) if b's range > decimal's range + // or b's range = decimal's range and b's scale < decimal's scale + SlotReference decimalSlot = new SlotReference("slot", DecimalV3Type.createDecimalV3Type(5, 2), true); + assertRewrite(new EqualTo(new Cast(decimalSlot, DecimalV3Type.createDecimalV3Type(7, 5)), new DecimalV3Literal(new BigDecimal("12.34567"))), + new EqualTo(new Cast(decimalSlot, DecimalV3Type.createDecimalV3Type(7, 5)), new DecimalV3Literal(new BigDecimal("12.34567")))); + assertRewrite(new EqualTo(new Cast(decimalSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("123.4"))), + new EqualTo(new Cast(decimalSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("123.4")))); } private enum RangeLimitResult { diff --git a/regression-test/data/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.out b/regression-test/data/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.out index dd4dccdf839cae..59ae6df0d908b5 100644 --- a/regression-test/data/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.out +++ b/regression-test/data/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.out @@ -8,113 +8,166 @@ PhysicalResultSink ------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] -- !cast_bigint_as_double_result -- -100 870479087484055553 200 870479087484055553 +100 870479087484055553 200 870479087484055553 999.999 --- !tbl3_float_neg16777216_neg10 -- +-- !float_neg16777216_neg10 -- PhysicalResultSink --PhysicalProject ----filter((cast(c_bigint as FLOAT) = -1.6777226E7)) -------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] --- !tbl3_float_neg16777216_0 -- +-- !float_neg16777216_0 -- PhysicalResultSink --PhysicalProject ----filter((cast(c_bigint as FLOAT) = -1.6777216E7)) -------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] --- !tbl3_float_neg16777216_10 -- +-- !float_neg16777216_10 -- PhysicalResultSink --PhysicalProject -----filter((tbl3_test_simplify_comparison_predicate_int_vs_double.c_bigint = -16777206)) -------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] +----filter((tbl2_test_simplify_comparison_predicate_int_vs_double.c_bigint = -16777206)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] --- !tbl3_float_16777216_neg10 -- +-- !float_16777216_neg10 -- PhysicalResultSink --PhysicalProject -----filter((tbl3_test_simplify_comparison_predicate_int_vs_double.c_bigint = 16777206)) -------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] +----filter((tbl2_test_simplify_comparison_predicate_int_vs_double.c_bigint = 16777206)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] --- !tbl3_float_16777216_0 -- +-- !float_16777216_0 -- PhysicalResultSink --PhysicalProject ----filter((cast(c_bigint as FLOAT) = 1.6777216E7)) -------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] --- !tbl3_float_16777216_10 -- +-- !float_16777216_10 -- PhysicalResultSink --PhysicalProject ----filter((cast(c_bigint as FLOAT) = 1.6777226E7)) -------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] --- !tbl3_double_neg9007199254740992_neg10 -- +-- !double_neg9007199254740992_neg10 -- PhysicalResultSink --PhysicalProject ----filter((cast(c_bigint as DOUBLE) = -9.007199254741002E15)) -------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] --- !tbl3_float_neg9007199254740992_neg10 -- +-- !float_neg9007199254740992_neg10 -- PhysicalResultSink --PhysicalProject ----filter((cast(c_bigint as FLOAT) = -9.0071993E15)) -------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] --- !tbl3_double_neg9007199254740992_0 -- +-- !double_neg9007199254740992_0 -- PhysicalResultSink --PhysicalProject ----filter((cast(c_bigint as DOUBLE) = -9.007199254740992E15)) -------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] --- !tbl3_float_neg9007199254740992_0 -- +-- !float_neg9007199254740992_0 -- PhysicalResultSink --PhysicalProject ----filter((cast(c_bigint as FLOAT) = -9.0071993E15)) -------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] --- !tbl3_double_neg9007199254740992_10 -- +-- !double_neg9007199254740992_10 -- PhysicalResultSink --PhysicalProject -----filter((tbl3_test_simplify_comparison_predicate_int_vs_double.c_bigint = -9007199254740982)) -------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] +----filter((tbl2_test_simplify_comparison_predicate_int_vs_double.c_bigint = -9007199254740982)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] --- !tbl3_float_neg9007199254740992_10 -- +-- !float_neg9007199254740992_10 -- PhysicalResultSink --PhysicalProject ----filter((cast(c_bigint as FLOAT) = -9.0071993E15)) -------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] --- !tbl3_double_9007199254740992_neg10 -- +-- !double_9007199254740992_neg10 -- PhysicalResultSink --PhysicalProject -----filter((tbl3_test_simplify_comparison_predicate_int_vs_double.c_bigint = 9007199254740982)) -------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] +----filter((tbl2_test_simplify_comparison_predicate_int_vs_double.c_bigint = 9007199254740982)) +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] --- !tbl3_float_9007199254740992_neg10 -- +-- !float_9007199254740992_neg10 -- PhysicalResultSink --PhysicalProject ----filter((cast(c_bigint as FLOAT) = 9.0071993E15)) -------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] --- !tbl3_double_9007199254740992_0 -- +-- !double_9007199254740992_0 -- PhysicalResultSink --PhysicalProject ----filter((cast(c_bigint as DOUBLE) = 9.007199254740992E15)) -------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] --- !tbl3_float_9007199254740992_0 -- +-- !float_9007199254740992_0 -- PhysicalResultSink --PhysicalProject ----filter((cast(c_bigint as FLOAT) = 9.0071993E15)) -------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] --- !tbl3_double_9007199254740992_10 -- +-- !double_9007199254740992_10 -- PhysicalResultSink --PhysicalProject ----filter((cast(c_bigint as DOUBLE) = 9.007199254741002E15)) -------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] --- !tbl3_float_9007199254740992_10 -- +-- !float_9007199254740992_10 -- PhysicalResultSink --PhysicalProject ----filter((cast(c_bigint as FLOAT) = 9.0071993E15)) -------PhysicalOlapScan[tbl3_test_simplify_comparison_predicate_int_vs_double] +------PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !int_vs_double_1_shape -- +PhysicalResultSink +--filter((tbl2_test_simplify_comparison_predicate_int_vs_double.c_bigint > 123)) +----PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !int_vs_double_1_result -- +200 870479087484055553 999.999 + +-- !int_vs_double_2_shape -- +PhysicalResultSink +--filter((cast(c_bigint as DECIMALV3(7, 3)) > 123.456)) +----PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !int_vs_double_2_result -- + +-- !decimal_vs_decimal_1_shape -- +PhysicalResultSink +--filter((tbl2_test_simplify_comparison_predicate_int_vs_double.c_decimal > 123.456)) +----PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !decimal_vs_decimal_1_result -- +200 870479087484055553 999.999 + +-- !decimal_vs_decimal_2_shape -- +PhysicalResultSink +--filter((cast(c_decimal as DECIMALV3(3, 1)) > 12.3)) +----PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !decimal_vs_decimal_2_result -- + +-- !decimal_vs_decimal_3_shape -- +PhysicalResultSink +--filter((cast(c_decimal as DECIMALV3(5, 2)) > 123.45)) +----PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !decimal_vs_decimal_3_result -- + +-- !int_vs_double_3_shape -- +PhysicalResultSink +--filter((tbl2_test_simplify_comparison_predicate_int_vs_double.c_bigint > 123)) +----PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !int_vs_double_3_result -- +200 870479087484055553 999.999 + +-- !decimal_vs_decimal_4_shape -- +PhysicalResultSink +--filter((tbl2_test_simplify_comparison_predicate_int_vs_double.c_decimal > 123.456)) +----PhysicalOlapScan[tbl2_test_simplify_comparison_predicate_int_vs_double] + +-- !decimal_vs_decimal_4_result -- +200 870479087484055553 999.999 diff --git a/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.groovy b/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.groovy index 26bfd61fbc72d3..cfe0b435ef24af 100644 --- a/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.groovy +++ b/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate_int_vs_double.groovy @@ -24,22 +24,17 @@ suite('test_simplify_comparison_predicate_int_vs_double') { drop table if exists tbl1_test_simplify_comparison_predicate_int_vs_double force; drop table if exists tbl2_test_simplify_comparison_predicate_int_vs_double force; - drop table if exists tbl3_test_simplify_comparison_predicate_int_vs_double force; create table tbl1_test_simplify_comparison_predicate_int_vs_double (k1 int, c_s varchar(100)) properties('replication_num' = '1'); create table tbl2_test_simplify_comparison_predicate_int_vs_double - (k2 int, c_bigint bigint) properties('replication_num' = '1'); - create table tbl3_test_simplify_comparison_predicate_int_vs_double - (k3 int, c_bigint bigint) properties('replication_num' = '1'); + (k2 int, c_bigint bigint, c_decimal decimal(6, 3)) properties('replication_num' = '1'); insert into tbl1_test_simplify_comparison_predicate_int_vs_double values (100, "870479087484055553"), (101,"870479087484055554"); insert into tbl2_test_simplify_comparison_predicate_int_vs_double values - (200, 870479087484055553); - insert into tbl3_test_simplify_comparison_predicate_int_vs_double values - (300, 870479087484055553); + (200, 870479087484055553, 999.999); """ @@ -53,11 +48,11 @@ suite('test_simplify_comparison_predicate_int_vs_double') { for (def delimit : [-(1L<<24), 1L<<24]) { for (def diff : [-10, 0, 10]) { - def tag = "tbl3_float_${delimit}_${diff}".replace('-', 'neg') + def tag = "float_${delimit}_${diff}".replace('-', 'neg') "qt_${tag}" """ explain shape plan select c_bigint - from tbl3_test_simplify_comparison_predicate_int_vs_double + from tbl2_test_simplify_comparison_predicate_int_vs_double where cast(c_bigint as float) = cast('${delimit + diff}' as float) """ } @@ -65,27 +60,102 @@ suite('test_simplify_comparison_predicate_int_vs_double') { for (def delimit : [-(1L<<53), 1L<<53]) { for (def diff : [-10, 0, 10]) { - def tag = "tbl3_double_${delimit}_${diff}".replace('-', 'neg') + def tag = "double_${delimit}_${diff}".replace('-', 'neg') "qt_${tag}" """ explain shape plan select c_bigint - from tbl3_test_simplify_comparison_predicate_int_vs_double + from tbl2_test_simplify_comparison_predicate_int_vs_double where cast(c_bigint as double) = cast('${delimit + diff}' as double) """ - tag = "tbl3_float_${delimit}_${diff}".replace('-', 'neg') + tag = "float_${delimit}_${diff}".replace('-', 'neg') "qt_${tag}" """ explain shape plan select c_bigint - from tbl3_test_simplify_comparison_predicate_int_vs_double + from tbl2_test_simplify_comparison_predicate_int_vs_double where cast(c_bigint as float) = cast('${delimit + diff}' as float) """ } } + sql "set enable_strict_cast=false" + + explainAndOrderResult 'int_vs_double_1', """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where c_bigint > 123.456 + """ + + explainAndOrderResult 'int_vs_double_2', """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where cast(c_bigint as decimal(7, 3)) > 123.456 + """ + + explainAndOrderResult 'decimal_vs_decimal_1', """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where c_decimal > 123.4567 + """ + + explainAndOrderResult 'decimal_vs_decimal_2', """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where cast(c_decimal as decimal(3,1)) > cast(12.3 as decimal(5, 1)) + """ + + explainAndOrderResult 'decimal_vs_decimal_3', """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where cast(c_decimal as decimal(5,2)) > cast(123.45 as decimal(5, 2)) + """ + + sql "set enable_strict_cast=true" + + explainAndOrderResult 'int_vs_double_3', """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where c_bigint > 123.456 + """ + + explainAndOrderResult 'decimal_vs_decimal_4', """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where c_decimal > 123.4567 + """ + + test { + sql """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where cast(c_bigint as decimal(7, 3)) > 123.456 + """ + + exception 'Arithmetic overflow when converting value' + } + + test { + sql """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where cast(c_decimal as decimal(3,1)) > cast(12.3 as decimal(5, 1)) + """ + + exception 'Arithmetic overflow when converting value' + } + + test { + sql """ + select * + from tbl2_test_simplify_comparison_predicate_int_vs_double + where cast(c_decimal as decimal(5,2)) > cast(123.45 as decimal(5, 2)) + """ + + exception 'Arithmetic overflow when converting value' + } + sql """ drop table if exists tbl1_test_simplify_comparison_predicate_int_vs_double force; drop table if exists tbl2_test_simplify_comparison_predicate_int_vs_double force; - drop table if exists tbl3_test_simplify_comparison_predicate_int_vs_double force; """ }