diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 8c079056e21d..bba994dd11b5 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1362,6 +1362,12 @@ impl ScalarValue { DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(0.0))), DataType::Float32 => ScalarValue::Float32(Some(0.0)), DataType::Float64 => ScalarValue::Float64(Some(0.0)), + DataType::Decimal32(precision, scale) => { + ScalarValue::Decimal32(Some(0), *precision, *scale) + } + DataType::Decimal64(precision, scale) => { + ScalarValue::Decimal64(Some(0), *precision, *scale) + } DataType::Decimal128(precision, scale) => { ScalarValue::Decimal128(Some(0), *precision, *scale) } diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 1c99f49d26cf..52bb211d9b99 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -327,6 +327,16 @@ impl<'a> BinaryTypeCoercer<'a> { // TODO Move the rest inside of BinaryTypeCoercer +fn is_decimal(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Decimal32(..) + | DataType::Decimal64(..) + | DataType::Decimal128(..) + | DataType::Decimal256(..) + ) +} + /// Coercion rules for mathematics operators between decimal and non-decimal types. fn math_decimal_coercion( lhs_type: &DataType, @@ -357,6 +367,15 @@ fn math_decimal_coercion( | (Decimal256(_, _), Decimal256(_, _)) => { Some((lhs_type.clone(), rhs_type.clone())) } + // Cross-variant decimal coercion - choose larger variant with appropriate precision/scale + (lhs, rhs) + if is_decimal(lhs) + && is_decimal(rhs) + && std::mem::discriminant(lhs) != std::mem::discriminant(rhs) => + { + let coerced_type = get_wider_decimal_type_cross_variant(lhs_type, rhs_type)?; + Some((coerced_type.clone(), coerced_type)) + } // Unlike with comparison we don't coerce to a decimal in the case of floating point // numbers, instead falling back to floating point arithmetic instead ( @@ -953,21 +972,92 @@ pub fn binary_numeric_coercion( pub fn decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; + // Prefer decimal data type over floating point for comparison operation match (lhs_type, rhs_type) { - // Prefer decimal data type over floating point for comparison operation - (Decimal128(_, _), Decimal128(_, _)) => { + // Same decimal types + (lhs_type, rhs_type) + if is_decimal(lhs_type) + && is_decimal(rhs_type) + && std::mem::discriminant(lhs_type) + == std::mem::discriminant(rhs_type) => + { get_wider_decimal_type(lhs_type, rhs_type) } - (Decimal128(_, _), _) => get_common_decimal_type(lhs_type, rhs_type), - (_, Decimal128(_, _)) => get_common_decimal_type(rhs_type, lhs_type), - (Decimal256(_, _), Decimal256(_, _)) => { - get_wider_decimal_type(lhs_type, rhs_type) + // Mismatched decimal types + (lhs_type, rhs_type) + if is_decimal(lhs_type) + && is_decimal(rhs_type) + && std::mem::discriminant(lhs_type) + != std::mem::discriminant(rhs_type) => + { + get_wider_decimal_type_cross_variant(lhs_type, rhs_type) + } + // Decimal + non-decimal types + (Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), _) => { + get_common_decimal_type(lhs_type, rhs_type) + } + (_, Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _)) => { + get_common_decimal_type(rhs_type, lhs_type) } - (Decimal256(_, _), _) => get_common_decimal_type(lhs_type, rhs_type), - (_, Decimal256(_, _)) => get_common_decimal_type(rhs_type, lhs_type), (_, _) => None, } } +/// Handle cross-variant decimal widening by choosing the larger variant +fn get_wider_decimal_type_cross_variant( + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option { + use arrow::datatypes::DataType::*; + + let (p1, s1) = match lhs_type { + Decimal32(p, s) => (*p, *s), + Decimal64(p, s) => (*p, *s), + Decimal128(p, s) => (*p, *s), + Decimal256(p, s) => (*p, *s), + _ => return None, + }; + + let (p2, s2) = match rhs_type { + Decimal32(p, s) => (*p, *s), + Decimal64(p, s) => (*p, *s), + Decimal128(p, s) => (*p, *s), + Decimal256(p, s) => (*p, *s), + _ => return None, + }; + + // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) + let s = s1.max(s2); + let range = (p1 as i8 - s1).max(p2 as i8 - s2); + let required_precision = (range + s) as u8; + + // Choose the larger variant between the two input types, while making sure we don't overflow the precision. + match (lhs_type, rhs_type) { + (Decimal32(_, _), Decimal64(_, _)) | (Decimal64(_, _), Decimal32(_, _)) + if required_precision <= DECIMAL64_MAX_PRECISION => + { + Some(Decimal64(required_precision, s)) + } + (Decimal32(_, _), Decimal128(_, _)) + | (Decimal128(_, _), Decimal32(_, _)) + | (Decimal64(_, _), Decimal128(_, _)) + | (Decimal128(_, _), Decimal64(_, _)) + if required_precision <= DECIMAL128_MAX_PRECISION => + { + Some(Decimal128(required_precision, s)) + } + (Decimal32(_, _), Decimal256(_, _)) + | (Decimal256(_, _), Decimal32(_, _)) + | (Decimal64(_, _), Decimal256(_, _)) + | (Decimal256(_, _), Decimal64(_, _)) + | (Decimal128(_, _), Decimal256(_, _)) + | (Decimal256(_, _), Decimal128(_, _)) + if required_precision <= DECIMAL256_MAX_PRECISION => + { + Some(Decimal256(required_precision, s)) + } + _ => None, + } +} /// Coerce `lhs_type` and `rhs_type` to a common type. fn get_common_decimal_type( @@ -976,7 +1066,15 @@ fn get_common_decimal_type( ) -> Option { use arrow::datatypes::DataType::*; match decimal_type { - Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) => { + Decimal32(_, _) => { + let other_decimal_type = coerce_numeric_type_to_decimal32(other_type)?; + get_wider_decimal_type(decimal_type, &other_decimal_type) + } + Decimal64(_, _) => { + let other_decimal_type = coerce_numeric_type_to_decimal64(other_type)?; + get_wider_decimal_type(decimal_type, &other_decimal_type) + } + Decimal128(_, _) => { let other_decimal_type = coerce_numeric_type_to_decimal128(other_type)?; get_wider_decimal_type(decimal_type, &other_decimal_type) } @@ -988,7 +1086,7 @@ fn get_common_decimal_type( } } -/// Returns a `DataType::Decimal128` that can store any value from either +/// Returns a decimal [`DataType`] variant that can store any value from either /// `lhs_decimal_type` and `rhs_decimal_type` /// /// The result decimal type is `(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))`. @@ -1209,14 +1307,14 @@ fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option DataType { - DataType::Decimal128( + DataType::Decimal32( DECIMAL32_MAX_PRECISION.min(precision), DECIMAL32_MAX_SCALE.min(scale), ) } fn create_decimal64_type(precision: u8, scale: i8) -> DataType { - DataType::Decimal128( + DataType::Decimal64( DECIMAL64_MAX_PRECISION.min(precision), DECIMAL64_MAX_SCALE.min(scale), ) diff --git a/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs b/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs index e6238ba0078d..bfedcf071387 100644 --- a/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs +++ b/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs @@ -291,3 +291,133 @@ fn test_coercion_arithmetic_decimal() -> Result<()> { Ok(()) } + +#[test] +fn test_coercion_arithmetic_decimal_cross_variant() -> Result<()> { + let test_cases = [ + ( + DataType::Decimal32(5, 2), + DataType::Decimal64(10, 3), + DataType::Decimal64(10, 3), + DataType::Decimal64(10, 3), + ), + ( + DataType::Decimal32(7, 1), + DataType::Decimal128(15, 4), + DataType::Decimal128(15, 4), + DataType::Decimal128(15, 4), + ), + ( + DataType::Decimal32(9, 0), + DataType::Decimal256(20, 5), + DataType::Decimal256(20, 5), + DataType::Decimal256(20, 5), + ), + ( + DataType::Decimal64(12, 3), + DataType::Decimal128(18, 2), + DataType::Decimal128(19, 3), + DataType::Decimal128(19, 3), + ), + ( + DataType::Decimal64(15, 4), + DataType::Decimal256(25, 6), + DataType::Decimal256(25, 6), + DataType::Decimal256(25, 6), + ), + ( + DataType::Decimal128(20, 5), + DataType::Decimal256(30, 8), + DataType::Decimal256(30, 8), + DataType::Decimal256(30, 8), + ), + // Reverse order cases + ( + DataType::Decimal64(10, 3), + DataType::Decimal32(5, 2), + DataType::Decimal64(10, 3), + DataType::Decimal64(10, 3), + ), + ( + DataType::Decimal128(15, 4), + DataType::Decimal32(7, 1), + DataType::Decimal128(15, 4), + DataType::Decimal128(15, 4), + ), + ( + DataType::Decimal256(20, 5), + DataType::Decimal32(9, 0), + DataType::Decimal256(20, 5), + DataType::Decimal256(20, 5), + ), + ( + DataType::Decimal128(18, 2), + DataType::Decimal64(12, 3), + DataType::Decimal128(19, 3), + DataType::Decimal128(19, 3), + ), + ( + DataType::Decimal256(25, 6), + DataType::Decimal64(15, 4), + DataType::Decimal256(25, 6), + DataType::Decimal256(25, 6), + ), + ( + DataType::Decimal256(30, 8), + DataType::Decimal128(20, 5), + DataType::Decimal256(30, 8), + DataType::Decimal256(30, 8), + ), + ]; + + for (lhs_type, rhs_type, expected_lhs_type, expected_rhs_type) in test_cases { + test_math_decimal_coercion_rule( + lhs_type, + rhs_type, + expected_lhs_type, + expected_rhs_type, + ); + } + + Ok(()) +} + +#[test] +fn test_decimal_precision_overflow_cross_variant() -> Result<()> { + // s = max(0, 1) = 1, range = max(76-0, 38-1) = 76, required_precision = 76 + 1 = 77 (overflow) + let result = get_wider_decimal_type_cross_variant( + &DataType::Decimal256(76, 0), + &DataType::Decimal128(38, 1), + ); + assert!(result.is_none()); + + // s = max(0, 10) = 10, range = max(9-0, 18-10) = 9, required_precision = 9 + 10 = 19 (overflow > 18) + let result = get_wider_decimal_type_cross_variant( + &DataType::Decimal32(9, 0), + &DataType::Decimal64(18, 10), + ); + assert!(result.is_none()); + + // s = max(5, 26) = 26, range = max(18-5, 38-26) = 13, required_precision = 13 + 26 = 39 (overflow > 38) + let result = get_wider_decimal_type_cross_variant( + &DataType::Decimal64(18, 5), + &DataType::Decimal128(38, 26), + ); + assert!(result.is_none()); + + // s = max(10, 49) = 49, range = max(38-10, 76-49) = 28, required_precision = 28 + 49 = 77 (overflow > 76) + let result = get_wider_decimal_type_cross_variant( + &DataType::Decimal128(38, 10), + &DataType::Decimal256(76, 49), + ); + assert!(result.is_none()); + + // s = max(2, 3) = 3, range = max(5-2, 10-3) = 7, required_precision = 7 + 3 = 10 (valid <= 18) + let result = get_wider_decimal_type_cross_variant( + &DataType::Decimal32(5, 2), + &DataType::Decimal64(10, 3), + ); + assert!(result.is_some()); + + Ok(()) +} diff --git a/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs b/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs index 208edae4ffc2..5401264e43e3 100644 --- a/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs +++ b/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs @@ -697,3 +697,91 @@ fn test_map_coercion() -> Result<()> { ); Ok(()) } + +#[test] +fn test_decimal_cross_variant_comparison_coercion() -> Result<()> { + let test_cases = [ + // (lhs, rhs, expected_result) + ( + DataType::Decimal32(5, 2), + DataType::Decimal64(10, 3), + DataType::Decimal64(10, 3), + ), + ( + DataType::Decimal32(7, 1), + DataType::Decimal128(15, 4), + DataType::Decimal128(15, 4), + ), + ( + DataType::Decimal32(9, 0), + DataType::Decimal256(20, 5), + DataType::Decimal256(20, 5), + ), + ( + DataType::Decimal64(12, 3), + DataType::Decimal128(18, 2), + DataType::Decimal128(19, 3), + ), + ( + DataType::Decimal64(15, 4), + DataType::Decimal256(25, 6), + DataType::Decimal256(25, 6), + ), + ( + DataType::Decimal128(20, 5), + DataType::Decimal256(30, 8), + DataType::Decimal256(30, 8), + ), + // Reverse order cases + ( + DataType::Decimal64(10, 3), + DataType::Decimal32(5, 2), + DataType::Decimal64(10, 3), + ), + ( + DataType::Decimal128(15, 4), + DataType::Decimal32(7, 1), + DataType::Decimal128(15, 4), + ), + ( + DataType::Decimal256(20, 5), + DataType::Decimal32(9, 0), + DataType::Decimal256(20, 5), + ), + ( + DataType::Decimal128(18, 2), + DataType::Decimal64(12, 3), + DataType::Decimal128(19, 3), + ), + ( + DataType::Decimal256(25, 6), + DataType::Decimal64(15, 4), + DataType::Decimal256(25, 6), + ), + ( + DataType::Decimal256(30, 8), + DataType::Decimal128(20, 5), + DataType::Decimal256(30, 8), + ), + ]; + + let comparison_op_types = [ + Operator::NotEq, + Operator::Eq, + Operator::Gt, + Operator::GtEq, + Operator::Lt, + Operator::LtEq, + ]; + + for (lhs_type, rhs_type, expected_type) in test_cases { + for op in comparison_op_types { + let (lhs, rhs) = + BinaryTypeCoercer::new(&lhs_type, &op, &rhs_type).get_input_types()?; + assert_eq!(expected_type, lhs, "Coercion of type {lhs_type:?} with {rhs_type:?} resulted in unexpected type: {lhs:?}"); + assert_eq!(expected_type, rhs, "Coercion of type {rhs_type:?} with {lhs_type:?} resulted in unexpected type: {rhs:?}"); + } + } + + Ok(()) +} diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index 8af8e4c2c849..040f13c01449 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -21,8 +21,8 @@ use std::any::Any; use std::sync::Arc; use arrow::array::{ - ArrayRef, Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, + ArrayRef, Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array, + Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, }; use arrow::datatypes::DataType; use arrow::error::ArrowError; @@ -98,6 +98,8 @@ fn create_abs_function(input_data_type: &DataType) -> Result | DataType::UInt64 => Ok(|input: &ArrayRef| Ok(Arc::clone(input))), // Decimal types + DataType::Decimal32(_, _) => Ok(make_decimal_abs_function!(Decimal32Array)), + DataType::Decimal64(_, _) => Ok(make_decimal_abs_function!(Decimal64Array)), DataType::Decimal128(_, _) => Ok(make_decimal_abs_function!(Decimal128Array)), DataType::Decimal256(_, _) => Ok(make_decimal_abs_function!(Decimal256Array)), @@ -162,6 +164,12 @@ impl ScalarUDFImpl for AbsFunc { DataType::UInt16 => Ok(DataType::UInt16), DataType::UInt32 => Ok(DataType::UInt32), DataType::UInt64 => Ok(DataType::UInt64), + DataType::Decimal32(precision, scale) => { + Ok(DataType::Decimal32(precision, scale)) + } + DataType::Decimal64(precision, scale) => { + Ok(DataType::Decimal64(precision, scale)) + } DataType::Decimal128(precision, scale) => { Ok(DataType::Decimal128(precision, scale)) } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs index d28a9bad17ec..f16ef24fd172 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs @@ -1922,6 +1922,8 @@ fn compare_join_arrays( DataType::BinaryView => compare_value!(BinaryViewArray), DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), DataType::LargeBinary => compare_value!(LargeBinaryArray), + DataType::Decimal32(..) => compare_value!(Decimal32Array), + DataType::Decimal64(..) => compare_value!(Decimal64Array), DataType::Decimal128(..) => compare_value!(Decimal128Array), DataType::Timestamp(time_unit, None) => match time_unit { TimeUnit::Second => compare_value!(TimestampSecondArray), @@ -1994,7 +1996,10 @@ fn is_join_arrays_equal( DataType::BinaryView => compare_value!(BinaryViewArray), DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), DataType::LargeBinary => compare_value!(LargeBinaryArray), + DataType::Decimal32(..) => compare_value!(Decimal32Array), + DataType::Decimal64(..) => compare_value!(Decimal64Array), DataType::Decimal128(..) => compare_value!(Decimal128Array), + DataType::Decimal256(..) => compare_value!(Decimal256Array), DataType::Timestamp(time_unit, None) => match time_unit { TimeUnit::Second => compare_value!(TimestampSecondArray), TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray), diff --git a/datafusion/spark/src/function/math/width_bucket.rs b/datafusion/spark/src/function/math/width_bucket.rs index 24f8fe6b2456..45a0d843b7ed 100644 --- a/datafusion/spark/src/function/math/width_bucket.rs +++ b/datafusion/spark/src/function/math/width_bucket.rs @@ -32,6 +32,7 @@ use datafusion_common::cast::{ }; use datafusion_common::{exec_err, Result}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::type_coercion::is_signed_numeric; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_functions::utils::make_scalar_function; @@ -93,16 +94,11 @@ impl ScalarUDFImpl for SparkWidthBucket { let (v, lo, hi, n) = (&types[0], &types[1], &types[2], &types[3]); - let is_num = |t: &DataType| { - matches!( - t, - Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _) - ) - }; - match (v, lo, hi, n) { (a, b, c, &(Int8 | Int16 | Int32 | Int64)) - if is_num(a) && is_num(b) && is_num(c) => + if is_signed_numeric(a) + && is_signed_numeric(b) + && is_signed_numeric(c) => { Ok(vec![Float64, Float64, Float64, Int32]) }