From b256d3d5f868ab593efbdcade2edfdf63b3f97e9 Mon Sep 17 00:00:00 2001 From: sundyli <543950155@qq.com> Date: Sat, 22 Jun 2024 21:40:01 +0800 Subject: [PATCH] fix(query): decimal div overflow (#15856) * fix(query): decimal div overflow * fix(query): decimal div overflow * fix(query): decimal div overflow --- src/query/expression/src/types/decimal.rs | 20 +++++++++++++++++++ .../src/scalars/decimal/arithmetic.rs | 14 +++++++++---- .../11_0006_data_type_decimal.test | 15 ++++++++++++++ 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/query/expression/src/types/decimal.rs b/src/query/expression/src/types/decimal.rs index 71334284f129..ccf96c7bd040 100644 --- a/src/query/expression/src/types/decimal.rs +++ b/src/query/expression/src/types/decimal.rs @@ -348,6 +348,8 @@ pub trait Decimal: fn checked_mul(self, rhs: Self) -> Option; fn checked_rem(self, rhs: Self) -> Option; + fn do_round_div(self, rhs: Self, mul: Self) -> Option; + fn min_for_precision(precision: u8) -> Self; fn max_for_precision(precision: u8) -> Self; @@ -443,6 +445,16 @@ impl Decimal for i128 { self.checked_rem(rhs) } + fn do_round_div(self, rhs: Self, mul: Self) -> Option { + if self.is_negative() == rhs.is_negative() { + let res = (i256::from(self) * i256::from(mul) + i256::from(rhs) / 2) / i256::from(rhs); + Some(*res.low()) + } else { + let res = (i256::from(self) * i256::from(mul) - i256::from(rhs) / 2) / i256::from(rhs); + Some(*res.low()) + } + } + fn min_for_precision(to_precision: u8) -> Self { MIN_DECIMAL_FOR_EACH_PRECISION[to_precision as usize - 1] } @@ -647,6 +659,14 @@ impl Decimal for i256 { self.checked_rem(rhs) } + fn do_round_div(self, rhs: Self, mul: Self) -> Option { + if self.is_negative() == rhs.is_negative() { + self.checked_mul(mul).map(|x| (x + rhs / 2) / rhs) + } else { + self.checked_mul(mul).map(|x| (x - rhs / 2) / rhs) + } + } + fn min_for_precision(to_precision: u8) -> Self { MIN_DECIMAL256_BYTES_FOR_EACH_PRECISION[to_precision as usize - 1] } diff --git a/src/query/functions/src/scalars/decimal/arithmetic.rs b/src/query/functions/src/scalars/decimal/arithmetic.rs index 7bc9b63b20b9..088595e88e95 100644 --- a/src/query/functions/src/scalars/decimal/arithmetic.rs +++ b/src/query/functions/src/scalars/decimal/arithmetic.rs @@ -88,7 +88,6 @@ macro_rules! binary_decimal { let scale_a = $left.scale(); let scale_b = $right.scale(); - // Note: the result scale is always larger than the left scale let scale_mul = scale_b + $size.scale - scale_a; let multiplier = T::e(scale_mul as u32); @@ -102,10 +101,17 @@ macro_rules! binary_decimal { if std::intrinsics::unlikely(b == zero) { ctx.set_error(result.len(), "divided by zero"); result.push(one); - } else if a.is_negative() == b.is_negative() { - result.push((a * multiplier + b / 2).div(b)); } else { - result.push((a * multiplier - b / 2).div(b)); + match a.do_round_div(b, multiplier) { + Some(t) => result.push(t), + None => { + ctx.set_error( + result.len(), + concat!("Decimal overflow at line : ", line!()), + ); + result.push(one); + } + } } }; diff --git a/tests/sqllogictests/suites/base/11_data_type/11_0006_data_type_decimal.test b/tests/sqllogictests/suites/base/11_data_type/11_0006_data_type_decimal.test index 43ac9c2f2a82..ddf1a4198930 100644 --- a/tests/sqllogictests/suites/base/11_data_type/11_0006_data_type_decimal.test +++ b/tests/sqllogictests/suites/base/11_data_type/11_0006_data_type_decimal.test @@ -266,6 +266,10 @@ SELECT CAST(987654321.34 AS DECIMAL(76, 2)) / CAST(1.23 AS DECIMAL(76, 2)) AS re ---- 802970992.95934959 +query I +SELECT 404.754480000000000000000001 / 563.653044520000000000000001, 404.754480000000000000000000 / 563.653044520000000000000000; +---- +0.718091535094401799683905 0.718091535094401799683905 ## negative @@ -1114,5 +1118,16 @@ select cast(b as int), cast(c as int), cast(d as int) from t -1 -1 -1 -2 -2 -2 +statement ok +create table decimal_test2(a decimal(28,8), b decimal(24,16)); + +statement ok +insert into decimal_test2 values(300.07878791,5325.0000000000000000),(2.00000000,10491.0000000000000000); + +query I +select sum(a * b) / sum(a * b), sum(a + b) / sum(a + b), sum(a - b) / sum(a - b) from decimal_test2; +---- +1.000000000000000000000000 1.0000000000000000 1.0000000000000000 + statement ok drop database decimal_t;