diff --git a/arrow-cast/src/cast/decimal.rs b/arrow-cast/src/cast/decimal.rs index 600f868a3e01..6c020ac69015 100644 --- a/arrow-cast/src/cast/decimal.rs +++ b/arrow-cast/src/cast/decimal.rs @@ -250,7 +250,7 @@ where } }; - let integers = first_part.trim_start_matches('0'); + let integers = first_part; let decimals = if parts.len() == 2 { parts[1] } else { "" }; if !integers.is_empty() && !integers.as_bytes()[0].is_ascii_digit() { @@ -571,3 +571,48 @@ where let array = array.unary::<_, T>(op); Ok(Arc::new(array)) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_string_to_decimal_native() -> Result<(), ArrowError> { + assert_eq!( + parse_string_to_decimal_native::("0", 0)?, + 0_i128 + ); + assert_eq!( + parse_string_to_decimal_native::("0", 5)?, + 0_i128 + ); + + assert_eq!( + parse_string_to_decimal_native::("123", 0)?, + 123_i128 + ); + assert_eq!( + parse_string_to_decimal_native::("123", 5)?, + 12300000_i128 + ); + + assert_eq!( + parse_string_to_decimal_native::("123.45", 0)?, + 123_i128 + ); + assert_eq!( + parse_string_to_decimal_native::("123.45", 5)?, + 12345000_i128 + ); + + assert_eq!( + parse_string_to_decimal_native::("123.4567891", 0)?, + 123_i128 + ); + assert_eq!( + parse_string_to_decimal_native::("123.4567891", 5)?, + 12345679_i128 + ); + Ok(()) + } +} diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index 1770157bcfd9..7641334f793f 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -8209,6 +8209,10 @@ mod tests { assert!(decimal_arr.is_null(25)); assert!(decimal_arr.is_null(26)); assert!(decimal_arr.is_null(27)); + assert_eq!("0.00", decimal_arr.value_as_string(28)); + assert_eq!("0.00", decimal_arr.value_as_string(29)); + assert_eq!("12345.00", decimal_arr.value_as_string(30)); + assert_eq!(decimal_arr.len(), 31); // Decimal256 let output_type = DataType::Decimal256(76, 3); @@ -8245,6 +8249,10 @@ mod tests { assert!(decimal_arr.is_null(25)); assert!(decimal_arr.is_null(26)); assert!(decimal_arr.is_null(27)); + assert_eq!("0.000", decimal_arr.value_as_string(28)); + assert_eq!("0.000", decimal_arr.value_as_string(29)); + assert_eq!("12345.000", decimal_arr.value_as_string(30)); + assert_eq!(decimal_arr.len(), 31); } #[test] @@ -8278,10 +8286,30 @@ mod tests { Some("1.-23499999"), Some("-1.-23499999"), Some("--1.23499999"), + Some("0"), + Some("000.000"), + Some("0000000000000000012345.000"), ]); let array = Arc::new(str_array) as ArrayRef; test_cast_string_to_decimal(array); + + let test_cases = [ + (None, None), + // (Some(""), None), + // (Some(" "), None), + (Some("0"), Some("0")), + (Some("000.000"), Some("0")), + (Some("12345"), Some("12345")), + (Some("000000000000000000000000000012345"), Some("12345")), + (Some("-123"), Some("-123")), + (Some("+123"), Some("123")), + ]; + let inputs = test_cases.iter().map(|entry| entry.0).collect::>(); + let expected = test_cases.iter().map(|entry| entry.1).collect::>(); + + let array = Arc::new(StringArray::from(inputs)) as ArrayRef; + test_cast_string_to_decimal_scale_zero(array, &expected); } #[test] @@ -8315,10 +8343,67 @@ mod tests { Some("1.-23499999"), Some("-1.-23499999"), Some("--1.23499999"), + Some("0"), + Some("000.000"), + Some("0000000000000000012345.000"), ]); let array = Arc::new(str_array) as ArrayRef; test_cast_string_to_decimal(array); + + let test_cases = [ + (None, None), + (Some(""), None), + (Some(" "), None), + (Some("0"), Some("0")), + (Some("000.000"), Some("0")), + (Some("12345"), Some("12345")), + (Some("000000000000000000000000000012345"), Some("12345")), + (Some("-123"), Some("-123")), + (Some("+123"), Some("123")), + ]; + let inputs = test_cases.iter().map(|entry| entry.0).collect::>(); + let expected = test_cases.iter().map(|entry| entry.1).collect::>(); + + let array = Arc::new(LargeStringArray::from(inputs)) as ArrayRef; + test_cast_string_to_decimal_scale_zero(array, &expected); + } + + fn test_cast_string_to_decimal_scale_zero( + array: ArrayRef, + expected_as_string: &[Option<&str>], + ) { + // Decimal128 + let output_type = DataType::Decimal128(38, 0); + assert!(can_cast_types(array.data_type(), &output_type)); + let casted_array = cast(&array, &output_type).unwrap(); + let decimal_arr = casted_array.as_primitive::(); + assert_decimal_array_contents(decimal_arr, expected_as_string); + + // Decimal256 + let output_type = DataType::Decimal256(76, 0); + assert!(can_cast_types(array.data_type(), &output_type)); + let casted_array = cast(&array, &output_type).unwrap(); + let decimal_arr = casted_array.as_primitive::(); + assert_decimal_array_contents(decimal_arr, expected_as_string); + } + + fn assert_decimal_array_contents( + array: &PrimitiveArray, + expected_as_string: &[Option<&str>], + ) where + T: DecimalType + ArrowPrimitiveType, + { + assert_eq!(array.len(), expected_as_string.len()); + for (i, expected) in expected_as_string.iter().enumerate() { + let actual = if array.is_null(i) { + None + } else { + Some(array.value_as_string(i)) + }; + let actual = actual.as_ref().map(|s| s.as_ref()); + assert_eq!(*expected, actual, "Expected at position {}", i); + } } #[test]