diff --git a/arrow-buffer/src/bigint.rs b/arrow-buffer/src/bigint.rs index 23400b4a3f6e..8cf79c917346 100644 --- a/arrow-buffer/src/bigint.rs +++ b/arrow-buffer/src/bigint.rs @@ -16,7 +16,7 @@ // under the License. use num::cast::AsPrimitive; -use num::{BigInt, FromPrimitive, ToPrimitive}; +use num::{BigInt, FromPrimitive, Num, ToPrimitive}; use std::cmp::Ordering; /// A signed 256-bit integer @@ -102,6 +102,19 @@ impl i256 { Self::from_parts(v as u128, v >> 127) } + /// Create an integer value from its representation as string. + #[inline] + pub fn from_string(value_str: &str) -> Option { + let numbers = BigInt::from_str_radix(value_str, 10).ok()?; + let (integer, overflow) = Self::from_bigint_with_overflow(numbers); + + if overflow { + None + } else { + Some(integer) + } + } + /// Create an optional i256 from the provided `f64`. Returning `None` /// if overflow occurred pub fn from_f64(v: f64) -> Option { diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index 7bb3aeb9603f..a8f5738d0c1a 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -125,7 +125,6 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (List(_), _) => false, (_, List(list_to)) => can_cast_types(from_type, list_to.data_type()), (_, LargeList(list_to)) => can_cast_types(from_type, list_to.data_type()), - // TODO UTF8 to decimal // cast one decimal type to another decimal type (Decimal128(_, _), Decimal128(_, _)) => true, (Decimal256(_, _), Decimal256(_, _)) => true, @@ -142,6 +141,9 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { // decimal to signed numeric (Decimal128(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) | (Decimal256(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) => true, + // Utf8 to decimal + (Utf8 | LargeUtf8, Decimal128(_, _)) => true, + (Utf8 | LargeUtf8, Decimal256(_, _)) => true, (Decimal128(_, _), _) => false, (_, Decimal128(_, _)) => false, (Decimal256(_, _), _) => false, @@ -944,6 +946,18 @@ pub fn cast_with_options( *scale, cast_options, ), + Utf8 => cast_string_to_decimal::( + array, + *precision, + *scale, + cast_options, + ), + LargeUtf8 => cast_string_to_decimal::( + array, + *precision, + *scale, + cast_options, + ), Null => Ok(new_null_array(to_type, array.len())), _ => Err(ArrowError::CastError(format!( "Casting from {:?} to {:?} not supported", @@ -995,6 +1009,18 @@ pub fn cast_with_options( *scale, cast_options, ), + Utf8 => cast_string_to_decimal::( + array, + *precision, + *scale, + cast_options, + ), + LargeUtf8 => cast_string_to_decimal::( + array, + *precision, + *scale, + cast_options, + ), Null => Ok(new_null_array(to_type, array.len())), _ => Err(ArrowError::CastError(format!( "Casting from {:?} to {:?} not supported", @@ -2801,6 +2827,176 @@ fn cast_utf8_to_boolean( Ok(Arc::new(output_array)) } +/// Parses given string to specified decimal native (i128/i256) based on given +/// scale. Returns an `Err` if it cannot parse given string. +fn parse_string_to_decimal_native( + value_str: &str, + scale: usize, +) -> Result +where + T::Native: DecimalCast + ArrowNativeTypeOp, +{ + let value_str = value_str.trim(); + let parts: Vec<&str> = value_str.split('.').collect(); + if parts.len() > 2 { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid decimal format: {:?}", + value_str + ))); + } + + let integers = parts[0].trim_start_matches('0'); + let decimals = if parts.len() == 2 { parts[1] } else { "" }; + + // Adjust decimal based on scale + let number_decimals = if decimals.len() > scale { + let decimal_number = i256::from_string(decimals).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Cannot parse decimal format: {}", + value_str + )) + })?; + + let div = + i256::from_i128(10_i128).pow_checked((decimals.len() - scale) as u32)?; + + let half = div.div_wrapping(i256::from_i128(2)); + let half_neg = half.neg_wrapping(); + + let d = decimal_number.div_wrapping(div); + let r = decimal_number.mod_wrapping(div); + + // Round result + let adjusted = match decimal_number >= i256::ZERO { + true if r >= half => d.add_wrapping(i256::ONE), + false if r <= half_neg => d.sub_wrapping(i256::ONE), + _ => d, + }; + + let integers = if !integers.is_empty() { + i256::from_string(integers) + .ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Cannot parse decimal format: {}", + value_str + )) + }) + .map(|v| { + v.mul_wrapping(i256::from_i128(10_i128).pow_wrapping(scale as u32)) + })? + } else { + i256::ZERO + }; + + format!("{}", integers.add_wrapping(adjusted)) + } else { + let padding = if scale > decimals.len() { scale } else { 0 }; + + let decimals = format!("{:0( + from: &GenericStringArray, + precision: u8, + scale: i8, + cast_options: &CastOptions, +) -> Result, ArrowError> +where + T: DecimalType, + T::Native: DecimalCast + ArrowNativeTypeOp, +{ + if cast_options.safe { + let iter = from.iter().map(|v| { + v.and_then(|v| parse_string_to_decimal_native::(v, scale as usize).ok()) + }); + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + Ok(unsafe { + PrimitiveArray::::from_trusted_len_iter(iter) + .with_precision_and_scale(precision, scale)? + }) + } else { + let vec = from + .iter() + .map(|v| { + v.map(|v| { + parse_string_to_decimal_native::(v, scale as usize).map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast string '{}' to value of {:?} type", + v, + T::DATA_TYPE, + )) + }) + }) + .transpose() + }) + .collect::, _>>()?; + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + Ok(unsafe { + PrimitiveArray::::from_trusted_len_iter(vec.iter()) + .with_precision_and_scale(precision, scale)? + }) + } +} + +/// Cast Utf8 to decimal +fn cast_string_to_decimal( + from: &ArrayRef, + precision: u8, + scale: i8, + cast_options: &CastOptions, +) -> Result +where + T: DecimalType, + T::Native: DecimalCast + ArrowNativeTypeOp, +{ + if scale < 0 { + return Err(ArrowError::InvalidArgumentError(format!( + "Cannot cast string to decimal with negative scale {}", + scale + ))); + } + + if scale > T::MAX_SCALE { + return Err(ArrowError::InvalidArgumentError(format!( + "Cannot cast string to decimal greater than maximum scale {}", + T::MAX_SCALE + ))); + } + + Ok(Arc::new(string_to_decimal_cast::( + from.as_any() + .downcast_ref::>() + .unwrap(), + precision, + scale, + cast_options, + )?)) +} + /// Cast numeric types to Boolean /// /// Any zero value returns `false` while non-zero returns `true` @@ -7104,4 +7300,326 @@ mod tests { ] ); } + + #[test] + fn test_parse_string_to_decimal() { + assert_eq!( + Decimal128Type::format_decimal( + parse_string_to_decimal_native::("123.45", 2).unwrap(), + 38, + 2, + ), + "123.45" + ); + assert_eq!( + Decimal128Type::format_decimal( + parse_string_to_decimal_native::("12345", 2).unwrap(), + 38, + 2 + ), + "12345.00" + ); + assert_eq!( + Decimal128Type::format_decimal( + parse_string_to_decimal_native::("0.12345", 2).unwrap(), + 38, + 2 + ), + "0.12" + ); + assert_eq!( + Decimal128Type::format_decimal( + parse_string_to_decimal_native::(".12345", 2).unwrap(), + 38, + 2 + ), + "0.12" + ); + assert_eq!( + Decimal128Type::format_decimal( + parse_string_to_decimal_native::(".1265", 2).unwrap(), + 38, + 2 + ), + "0.13" + ); + assert_eq!( + Decimal128Type::format_decimal( + parse_string_to_decimal_native::(".1265", 2).unwrap(), + 38, + 2 + ), + "0.13" + ); + + assert_eq!( + Decimal256Type::format_decimal( + parse_string_to_decimal_native::("123.45", 3).unwrap(), + 38, + 3 + ), + "123.450" + ); + assert_eq!( + Decimal256Type::format_decimal( + parse_string_to_decimal_native::("12345", 3).unwrap(), + 38, + 3 + ), + "12345.000" + ); + assert_eq!( + Decimal256Type::format_decimal( + parse_string_to_decimal_native::("0.12345", 3).unwrap(), + 38, + 3 + ), + "0.123" + ); + assert_eq!( + Decimal256Type::format_decimal( + parse_string_to_decimal_native::(".12345", 3).unwrap(), + 38, + 3 + ), + "0.123" + ); + assert_eq!( + Decimal256Type::format_decimal( + parse_string_to_decimal_native::(".1265", 3).unwrap(), + 38, + 3 + ), + "0.127" + ); + } + + fn test_cast_string_to_decimal(array: ArrayRef) { + // Decimal128 + let output_type = DataType::Decimal128(38, 2); + assert!(can_cast_types(array.data_type(), &output_type)); + + let casted_array = cast(&array, &output_type).unwrap(); + let decimal_arr = as_primitive_array::(&casted_array); + + assert_eq!("123.45", decimal_arr.value_as_string(0)); + assert_eq!("1.23", decimal_arr.value_as_string(1)); + assert_eq!("0.12", decimal_arr.value_as_string(2)); + assert_eq!("0.13", decimal_arr.value_as_string(3)); + assert_eq!("1.26", decimal_arr.value_as_string(4)); + assert_eq!("12345.00", decimal_arr.value_as_string(5)); + assert_eq!("12345.00", decimal_arr.value_as_string(6)); + assert_eq!("0.12", decimal_arr.value_as_string(7)); + assert_eq!("12.23", decimal_arr.value_as_string(8)); + assert!(decimal_arr.is_null(9)); + assert_eq!("0.00", decimal_arr.value_as_string(10)); + assert_eq!("0.00", decimal_arr.value_as_string(11)); + assert!(decimal_arr.is_null(12)); + + // Decimal256 + let output_type = DataType::Decimal256(76, 3); + assert!(can_cast_types(array.data_type(), &output_type)); + + let casted_array = cast(&array, &output_type).unwrap(); + let decimal_arr = as_primitive_array::(&casted_array); + + assert_eq!("123.450", decimal_arr.value_as_string(0)); + assert_eq!("1.235", decimal_arr.value_as_string(1)); + assert_eq!("0.123", decimal_arr.value_as_string(2)); + assert_eq!("0.127", decimal_arr.value_as_string(3)); + assert_eq!("1.263", decimal_arr.value_as_string(4)); + assert_eq!("12345.000", decimal_arr.value_as_string(5)); + assert_eq!("12345.000", decimal_arr.value_as_string(6)); + assert_eq!("0.123", decimal_arr.value_as_string(7)); + assert_eq!("12.234", decimal_arr.value_as_string(8)); + assert!(decimal_arr.is_null(9)); + assert_eq!("0.000", decimal_arr.value_as_string(10)); + assert_eq!("0.000", decimal_arr.value_as_string(11)); + assert!(decimal_arr.is_null(12)); + } + + #[test] + fn test_cast_utf8_to_decimal() { + let str_array = StringArray::from(vec![ + Some("123.45"), + Some("1.2345"), + Some("0.12345"), + Some("0.1267"), + Some("1.263"), + Some("12345.0"), + Some("12345"), + Some("000.123"), + Some("12.234000"), + None, + Some(""), + Some(" "), + None, + ]); + let array = Arc::new(str_array) as ArrayRef; + + test_cast_string_to_decimal(array); + } + + #[test] + fn test_cast_large_utf8_to_decimal() { + let str_array = LargeStringArray::from(vec![ + Some("123.45"), + Some("1.2345"), + Some("0.12345"), + Some("0.1267"), + Some("1.263"), + Some("12345.0"), + Some("12345"), + Some("000.123"), + Some("12.234000"), + None, + Some(""), + Some(" "), + None, + ]); + let array = Arc::new(str_array) as ArrayRef; + + test_cast_string_to_decimal(array); + } + + #[test] + fn test_cast_invalid_utf8_to_decimal() { + let str_array = StringArray::from(vec!["4.4.5", ". 0.123"]); + let array = Arc::new(str_array) as ArrayRef; + + // Safe cast + let output_type = DataType::Decimal128(38, 2); + let casted_array = cast(&array, &output_type).unwrap(); + assert!(casted_array.is_null(0)); + assert!(casted_array.is_null(1)); + + let output_type = DataType::Decimal256(76, 2); + let casted_array = cast(&array, &output_type).unwrap(); + assert!(casted_array.is_null(0)); + assert!(casted_array.is_null(1)); + + // Non-safe cast + let output_type = DataType::Decimal128(38, 2); + let str_array = StringArray::from(vec!["4.4.5"]); + let array = Arc::new(str_array) as ArrayRef; + let option = CastOptions { safe: false }; + let casted_err = cast_with_options(&array, &output_type, &option).unwrap_err(); + assert!(casted_err + .to_string() + .contains("Cannot cast string '4.4.5' to value of Decimal128(38, 10) type")); + + let str_array = StringArray::from(vec![". 0.123"]); + let array = Arc::new(str_array) as ArrayRef; + let casted_err = cast_with_options(&array, &output_type, &option).unwrap_err(); + assert!(casted_err.to_string().contains( + "Cannot cast string '. 0.123' to value of Decimal128(38, 10) type" + )); + } + + fn test_cast_string_to_decimal128_overflow(overflow_array: ArrayRef) { + let output_type = DataType::Decimal128(38, 2); + let casted_array = cast(&overflow_array, &output_type).unwrap(); + let decimal_arr = as_primitive_array::(&casted_array); + + assert!(decimal_arr.is_null(0)); + assert!(decimal_arr.is_null(1)); + assert!(decimal_arr.is_null(2)); + assert_eq!( + "999999999999999999999999999999999999.99", + decimal_arr.value_as_string(3) + ); + assert_eq!( + "100000000000000000000000000000000000.00", + decimal_arr.value_as_string(4) + ); + } + + #[test] + fn test_cast_utf8_to_decimal128_overflow() { + let overflow_str_array = StringArray::from(vec![ + i128::MAX.to_string(), + i128::MIN.to_string(), + "99999999999999999999999999999999999999".to_string(), + "999999999999999999999999999999999999.99".to_string(), + "99999999999999999999999999999999999.999".to_string(), + ]); + let overflow_array = Arc::new(overflow_str_array) as ArrayRef; + + test_cast_string_to_decimal128_overflow(overflow_array); + } + + #[test] + fn test_cast_large_utf8_to_decimal128_overflow() { + let overflow_str_array = LargeStringArray::from(vec![ + i128::MAX.to_string(), + i128::MIN.to_string(), + "99999999999999999999999999999999999999".to_string(), + "999999999999999999999999999999999999.99".to_string(), + "99999999999999999999999999999999999.999".to_string(), + ]); + let overflow_array = Arc::new(overflow_str_array) as ArrayRef; + + test_cast_string_to_decimal128_overflow(overflow_array); + } + + fn test_cast_string_to_decimal256_overflow(overflow_array: ArrayRef) { + let output_type = DataType::Decimal256(76, 2); + let casted_array = cast(&overflow_array, &output_type).unwrap(); + let decimal_arr = as_primitive_array::(&casted_array); + + assert_eq!( + "170141183460469231731687303715884105727.00", + decimal_arr.value_as_string(0) + ); + assert_eq!( + "-170141183460469231731687303715884105728.00", + decimal_arr.value_as_string(1) + ); + assert_eq!( + "99999999999999999999999999999999999999.00", + decimal_arr.value_as_string(2) + ); + assert_eq!( + "999999999999999999999999999999999999.99", + decimal_arr.value_as_string(3) + ); + assert_eq!( + "100000000000000000000000000000000000.00", + decimal_arr.value_as_string(4) + ); + assert!(decimal_arr.is_null(5)); + assert!(decimal_arr.is_null(6)); + } + + #[test] + fn test_cast_utf8_to_decimal256_overflow() { + let overflow_str_array = StringArray::from(vec![ + i128::MAX.to_string(), + i128::MIN.to_string(), + "99999999999999999999999999999999999999".to_string(), + "999999999999999999999999999999999999.99".to_string(), + "99999999999999999999999999999999999.999".to_string(), + i256::MAX.to_string(), + i256::MIN.to_string(), + ]); + let overflow_array = Arc::new(overflow_str_array) as ArrayRef; + + test_cast_string_to_decimal256_overflow(overflow_array); + } + + #[test] + fn test_cast_large_utf8_to_decimal256_overflow() { + let overflow_str_array = LargeStringArray::from(vec![ + i128::MAX.to_string(), + i128::MIN.to_string(), + "99999999999999999999999999999999999999".to_string(), + "999999999999999999999999999999999999.99".to_string(), + "99999999999999999999999999999999999.999".to_string(), + i256::MAX.to_string(), + i256::MIN.to_string(), + ]); + let overflow_array = Arc::new(overflow_str_array) as ArrayRef; + + test_cast_string_to_decimal256_overflow(overflow_array); + } }