Skip to content

Commit 74ea1d5

Browse files
AdamGSJefffrey
andcommitted
More decimal 32/64 support - type coercsion and misc gaps (#17808)
* More small decimal support * CR comments Co-authored-by: Jeffrey Vo <jeffrey.vo.australia@gmail.com> * Add tests and cleanup some code --------- Co-authored-by: Jeffrey Vo <jeffrey.vo.australia@gmail.com>
1 parent d8b844d commit 74ea1d5

File tree

7 files changed

+353
-22
lines changed

7 files changed

+353
-22
lines changed

datafusion/common/src/scalar/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1366,6 +1366,12 @@ impl ScalarValue {
13661366
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(0.0))),
13671367
DataType::Float32 => ScalarValue::Float32(Some(0.0)),
13681368
DataType::Float64 => ScalarValue::Float64(Some(0.0)),
1369+
DataType::Decimal32(precision, scale) => {
1370+
ScalarValue::Decimal32(Some(0), *precision, *scale)
1371+
}
1372+
DataType::Decimal64(precision, scale) => {
1373+
ScalarValue::Decimal64(Some(0), *precision, *scale)
1374+
}
13691375
DataType::Decimal128(precision, scale) => {
13701376
ScalarValue::Decimal128(Some(0), *precision, *scale)
13711377
}

datafusion/expr-common/src/type_coercion/binary.rs

Lines changed: 110 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,16 @@ impl<'a> BinaryTypeCoercer<'a> {
320320

321321
// TODO Move the rest inside of BinaryTypeCoercer
322322

323+
fn is_decimal(data_type: &DataType) -> bool {
324+
matches!(
325+
data_type,
326+
DataType::Decimal32(..)
327+
| DataType::Decimal64(..)
328+
| DataType::Decimal128(..)
329+
| DataType::Decimal256(..)
330+
)
331+
}
332+
323333
/// Coercion rules for mathematics operators between decimal and non-decimal types.
324334
fn math_decimal_coercion(
325335
lhs_type: &DataType,
@@ -350,6 +360,15 @@ fn math_decimal_coercion(
350360
| (Decimal256(_, _), Decimal256(_, _)) => {
351361
Some((lhs_type.clone(), rhs_type.clone()))
352362
}
363+
// Cross-variant decimal coercion - choose larger variant with appropriate precision/scale
364+
(lhs, rhs)
365+
if is_decimal(lhs)
366+
&& is_decimal(rhs)
367+
&& std::mem::discriminant(lhs) != std::mem::discriminant(rhs) =>
368+
{
369+
let coerced_type = get_wider_decimal_type_cross_variant(lhs_type, rhs_type)?;
370+
Some((coerced_type.clone(), coerced_type))
371+
}
353372
// Unlike with comparison we don't coerce to a decimal in the case of floating point
354373
// numbers, instead falling back to floating point arithmetic instead
355374
(
@@ -946,21 +965,92 @@ pub fn binary_numeric_coercion(
946965
pub fn decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
947966
use arrow::datatypes::DataType::*;
948967

968+
// Prefer decimal data type over floating point for comparison operation
949969
match (lhs_type, rhs_type) {
950-
// Prefer decimal data type over floating point for comparison operation
951-
(Decimal128(_, _), Decimal128(_, _)) => {
970+
// Same decimal types
971+
(lhs_type, rhs_type)
972+
if is_decimal(lhs_type)
973+
&& is_decimal(rhs_type)
974+
&& std::mem::discriminant(lhs_type)
975+
== std::mem::discriminant(rhs_type) =>
976+
{
952977
get_wider_decimal_type(lhs_type, rhs_type)
953978
}
954-
(Decimal128(_, _), _) => get_common_decimal_type(lhs_type, rhs_type),
955-
(_, Decimal128(_, _)) => get_common_decimal_type(rhs_type, lhs_type),
956-
(Decimal256(_, _), Decimal256(_, _)) => {
957-
get_wider_decimal_type(lhs_type, rhs_type)
979+
// Mismatched decimal types
980+
(lhs_type, rhs_type)
981+
if is_decimal(lhs_type)
982+
&& is_decimal(rhs_type)
983+
&& std::mem::discriminant(lhs_type)
984+
!= std::mem::discriminant(rhs_type) =>
985+
{
986+
get_wider_decimal_type_cross_variant(lhs_type, rhs_type)
987+
}
988+
// Decimal + non-decimal types
989+
(Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), _) => {
990+
get_common_decimal_type(lhs_type, rhs_type)
991+
}
992+
(_, Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _)) => {
993+
get_common_decimal_type(rhs_type, lhs_type)
958994
}
959-
(Decimal256(_, _), _) => get_common_decimal_type(lhs_type, rhs_type),
960-
(_, Decimal256(_, _)) => get_common_decimal_type(rhs_type, lhs_type),
961995
(_, _) => None,
962996
}
963997
}
998+
/// Handle cross-variant decimal widening by choosing the larger variant
999+
fn get_wider_decimal_type_cross_variant(
1000+
lhs_type: &DataType,
1001+
rhs_type: &DataType,
1002+
) -> Option<DataType> {
1003+
use arrow::datatypes::DataType::*;
1004+
1005+
let (p1, s1) = match lhs_type {
1006+
Decimal32(p, s) => (*p, *s),
1007+
Decimal64(p, s) => (*p, *s),
1008+
Decimal128(p, s) => (*p, *s),
1009+
Decimal256(p, s) => (*p, *s),
1010+
_ => return None,
1011+
};
1012+
1013+
let (p2, s2) = match rhs_type {
1014+
Decimal32(p, s) => (*p, *s),
1015+
Decimal64(p, s) => (*p, *s),
1016+
Decimal128(p, s) => (*p, *s),
1017+
Decimal256(p, s) => (*p, *s),
1018+
_ => return None,
1019+
};
1020+
1021+
// max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
1022+
let s = s1.max(s2);
1023+
let range = (p1 as i8 - s1).max(p2 as i8 - s2);
1024+
let required_precision = (range + s) as u8;
1025+
1026+
// Choose the larger variant between the two input types, while making sure we don't overflow the precision.
1027+
match (lhs_type, rhs_type) {
1028+
(Decimal32(_, _), Decimal64(_, _)) | (Decimal64(_, _), Decimal32(_, _))
1029+
if required_precision <= DECIMAL64_MAX_PRECISION =>
1030+
{
1031+
Some(Decimal64(required_precision, s))
1032+
}
1033+
(Decimal32(_, _), Decimal128(_, _))
1034+
| (Decimal128(_, _), Decimal32(_, _))
1035+
| (Decimal64(_, _), Decimal128(_, _))
1036+
| (Decimal128(_, _), Decimal64(_, _))
1037+
if required_precision <= DECIMAL128_MAX_PRECISION =>
1038+
{
1039+
Some(Decimal128(required_precision, s))
1040+
}
1041+
(Decimal32(_, _), Decimal256(_, _))
1042+
| (Decimal256(_, _), Decimal32(_, _))
1043+
| (Decimal64(_, _), Decimal256(_, _))
1044+
| (Decimal256(_, _), Decimal64(_, _))
1045+
| (Decimal128(_, _), Decimal256(_, _))
1046+
| (Decimal256(_, _), Decimal128(_, _))
1047+
if required_precision <= DECIMAL256_MAX_PRECISION =>
1048+
{
1049+
Some(Decimal256(required_precision, s))
1050+
}
1051+
_ => None,
1052+
}
1053+
}
9641054

9651055
/// Coerce `lhs_type` and `rhs_type` to a common type.
9661056
fn get_common_decimal_type(
@@ -969,7 +1059,15 @@ fn get_common_decimal_type(
9691059
) -> Option<DataType> {
9701060
use arrow::datatypes::DataType::*;
9711061
match decimal_type {
972-
Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) => {
1062+
Decimal32(_, _) => {
1063+
let other_decimal_type = coerce_numeric_type_to_decimal32(other_type)?;
1064+
get_wider_decimal_type(decimal_type, &other_decimal_type)
1065+
}
1066+
Decimal64(_, _) => {
1067+
let other_decimal_type = coerce_numeric_type_to_decimal64(other_type)?;
1068+
get_wider_decimal_type(decimal_type, &other_decimal_type)
1069+
}
1070+
Decimal128(_, _) => {
9731071
let other_decimal_type = coerce_numeric_type_to_decimal128(other_type)?;
9741072
get_wider_decimal_type(decimal_type, &other_decimal_type)
9751073
}
@@ -981,7 +1079,7 @@ fn get_common_decimal_type(
9811079
}
9821080
}
9831081

984-
/// Returns a `DataType::Decimal128` that can store any value from either
1082+
/// Returns a decimal [`DataType`] variant that can store any value from either
9851083
/// `lhs_decimal_type` and `rhs_decimal_type`
9861084
///
9871085
/// The result decimal type is `(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))`.
@@ -1202,14 +1300,14 @@ fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataTy
12021300
}
12031301

12041302
fn create_decimal32_type(precision: u8, scale: i8) -> DataType {
1205-
DataType::Decimal128(
1303+
DataType::Decimal32(
12061304
DECIMAL32_MAX_PRECISION.min(precision),
12071305
DECIMAL32_MAX_SCALE.min(scale),
12081306
)
12091307
}
12101308

12111309
fn create_decimal64_type(precision: u8, scale: i8) -> DataType {
1212-
DataType::Decimal128(
1310+
DataType::Decimal64(
12131311
DECIMAL64_MAX_PRECISION.min(precision),
12141312
DECIMAL64_MAX_SCALE.min(scale),
12151313
)

datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,3 +291,133 @@ fn test_coercion_arithmetic_decimal() -> Result<()> {
291291

292292
Ok(())
293293
}
294+
295+
#[test]
296+
fn test_coercion_arithmetic_decimal_cross_variant() -> Result<()> {
297+
let test_cases = [
298+
(
299+
DataType::Decimal32(5, 2),
300+
DataType::Decimal64(10, 3),
301+
DataType::Decimal64(10, 3),
302+
DataType::Decimal64(10, 3),
303+
),
304+
(
305+
DataType::Decimal32(7, 1),
306+
DataType::Decimal128(15, 4),
307+
DataType::Decimal128(15, 4),
308+
DataType::Decimal128(15, 4),
309+
),
310+
(
311+
DataType::Decimal32(9, 0),
312+
DataType::Decimal256(20, 5),
313+
DataType::Decimal256(20, 5),
314+
DataType::Decimal256(20, 5),
315+
),
316+
(
317+
DataType::Decimal64(12, 3),
318+
DataType::Decimal128(18, 2),
319+
DataType::Decimal128(19, 3),
320+
DataType::Decimal128(19, 3),
321+
),
322+
(
323+
DataType::Decimal64(15, 4),
324+
DataType::Decimal256(25, 6),
325+
DataType::Decimal256(25, 6),
326+
DataType::Decimal256(25, 6),
327+
),
328+
(
329+
DataType::Decimal128(20, 5),
330+
DataType::Decimal256(30, 8),
331+
DataType::Decimal256(30, 8),
332+
DataType::Decimal256(30, 8),
333+
),
334+
// Reverse order cases
335+
(
336+
DataType::Decimal64(10, 3),
337+
DataType::Decimal32(5, 2),
338+
DataType::Decimal64(10, 3),
339+
DataType::Decimal64(10, 3),
340+
),
341+
(
342+
DataType::Decimal128(15, 4),
343+
DataType::Decimal32(7, 1),
344+
DataType::Decimal128(15, 4),
345+
DataType::Decimal128(15, 4),
346+
),
347+
(
348+
DataType::Decimal256(20, 5),
349+
DataType::Decimal32(9, 0),
350+
DataType::Decimal256(20, 5),
351+
DataType::Decimal256(20, 5),
352+
),
353+
(
354+
DataType::Decimal128(18, 2),
355+
DataType::Decimal64(12, 3),
356+
DataType::Decimal128(19, 3),
357+
DataType::Decimal128(19, 3),
358+
),
359+
(
360+
DataType::Decimal256(25, 6),
361+
DataType::Decimal64(15, 4),
362+
DataType::Decimal256(25, 6),
363+
DataType::Decimal256(25, 6),
364+
),
365+
(
366+
DataType::Decimal256(30, 8),
367+
DataType::Decimal128(20, 5),
368+
DataType::Decimal256(30, 8),
369+
DataType::Decimal256(30, 8),
370+
),
371+
];
372+
373+
for (lhs_type, rhs_type, expected_lhs_type, expected_rhs_type) in test_cases {
374+
test_math_decimal_coercion_rule(
375+
lhs_type,
376+
rhs_type,
377+
expected_lhs_type,
378+
expected_rhs_type,
379+
);
380+
}
381+
382+
Ok(())
383+
}
384+
385+
#[test]
386+
fn test_decimal_precision_overflow_cross_variant() -> Result<()> {
387+
// s = max(0, 1) = 1, range = max(76-0, 38-1) = 76, required_precision = 76 + 1 = 77 (overflow)
388+
let result = get_wider_decimal_type_cross_variant(
389+
&DataType::Decimal256(76, 0),
390+
&DataType::Decimal128(38, 1),
391+
);
392+
assert!(result.is_none());
393+
394+
// s = max(0, 10) = 10, range = max(9-0, 18-10) = 9, required_precision = 9 + 10 = 19 (overflow > 18)
395+
let result = get_wider_decimal_type_cross_variant(
396+
&DataType::Decimal32(9, 0),
397+
&DataType::Decimal64(18, 10),
398+
);
399+
assert!(result.is_none());
400+
401+
// s = max(5, 26) = 26, range = max(18-5, 38-26) = 13, required_precision = 13 + 26 = 39 (overflow > 38)
402+
let result = get_wider_decimal_type_cross_variant(
403+
&DataType::Decimal64(18, 5),
404+
&DataType::Decimal128(38, 26),
405+
);
406+
assert!(result.is_none());
407+
408+
// s = max(10, 49) = 49, range = max(38-10, 76-49) = 28, required_precision = 28 + 49 = 77 (overflow > 76)
409+
let result = get_wider_decimal_type_cross_variant(
410+
&DataType::Decimal128(38, 10),
411+
&DataType::Decimal256(76, 49),
412+
);
413+
assert!(result.is_none());
414+
415+
// s = max(2, 3) = 3, range = max(5-2, 10-3) = 7, required_precision = 7 + 3 = 10 (valid <= 18)
416+
let result = get_wider_decimal_type_cross_variant(
417+
&DataType::Decimal32(5, 2),
418+
&DataType::Decimal64(10, 3),
419+
);
420+
assert!(result.is_some());
421+
422+
Ok(())
423+
}

0 commit comments

Comments
 (0)