From 8f41a07a970613b1768272d85e630a4eb16d3b70 Mon Sep 17 00:00:00 2001 From: Kun Liu Date: Thu, 23 Dec 2021 00:43:44 +0800 Subject: [PATCH] support cast decimal to signed numeric (#1073) * add cast test macro function; refactor other type to decimal type; add decimal to signed numeric type support decimal to unsigned numeric * address the comments and fix the clippy --- arrow/src/compute/kernels/cast.rs | 370 +++++++++++++++++++++++++----- 1 file changed, 312 insertions(+), 58 deletions(-) diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs index 44fbfae52b07..2275aac7fe68 100644 --- a/arrow/src/compute/kernels/cast.rs +++ b/arrow/src/compute/kernels/cast.rs @@ -68,8 +68,12 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { } match (from_type, to_type) { - // TODO now just support signed numeric to decimal, support decimal to numeric later - (Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal(_, _)) + // TODO UTF8/unsigned numeric to decimal + // TODO decimal to decimal type + // signed numeric to decimal + (Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal(_, _)) | + // decimal to signed numeric + (Decimal(_, _), Int8 | Int16 | Int32 | Int64 | Float32 | Float64) | ( Null, Boolean @@ -108,6 +112,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { | Dictionary(_, _), Null, ) => true, + (Decimal(_, _), _) => false, + (_, Decimal(_, _)) => false, (Struct(_), _) => false, (_, Struct(_)) => false, (LargeList(list_from), LargeList(list_to)) => { @@ -281,6 +287,56 @@ macro_rules! cast_floating_point_to_decimal { }}; } +// cast the decimal array to integer array +macro_rules! 
cast_decimal_to_integer { + ($ARRAY:expr, $SCALE : ident, $VALUE_BUILDER: ident, $NATIVE_TYPE : ident, $DATA_TYPE : expr) => {{ + let array = $ARRAY.as_any().downcast_ref::<DecimalArray>().unwrap(); + let mut value_builder = $VALUE_BUILDER::new(array.len()); + let div: i128 = 10_i128.pow(*$SCALE as u32); + let min_bound = ($NATIVE_TYPE::MIN) as i128; + let max_bound = ($NATIVE_TYPE::MAX) as i128; + for i in 0..array.len() { + if array.is_null(i) { + value_builder.append_null()?; + } else { + let v = array.value(i) / div; + // check the overflow + // For example: Decimal(128,10,0) as i8 + // 128 is out of range i8 + if v <= max_bound && v >= min_bound { + value_builder.append_value(v as $NATIVE_TYPE)?; + } else { + return Err(ArrowError::CastError(format!( + "value of {} is out of range {}", + v, $DATA_TYPE + ))); + } + } + } + Ok(Arc::new(value_builder.finish())) + }}; +} + +// cast the decimal array to floating-point array +macro_rules! cast_decimal_to_float { + ($ARRAY:expr, $SCALE : ident, $VALUE_BUILDER: ident, $NATIVE_TYPE : ty) => {{ + let array = $ARRAY.as_any().downcast_ref::<DecimalArray>().unwrap(); + let div = 10_f64.powi(*$SCALE as i32); + let mut value_builder = $VALUE_BUILDER::new(array.len()); + for i in 0..array.len() { + if array.is_null(i) { + value_builder.append_null()?; + } else { + // The range of f32 or f64 is larger than i128, we don't need to check overflow. + // cast the i128 to f64 will lose precision, for example the `112345678901234568` will be as `112345678901234560`. + let v = (array.value(i) as f64 / div) as $NATIVE_TYPE; + value_builder.append_value(v)?; + } + } + Ok(Arc::new(value_builder.finish())) + }}; +} + /// Cast `array` to the provided data type and return a new Array with /// type `to_type`, if possible. It accepts `CastOptions` to allow consumers /// to configure cast behavior. 
@@ -315,6 +371,33 @@ pub fn cast_with_options( return Ok(array.clone()); } match (from_type, to_type) { + (Decimal(_, scale), _) => { + // cast decimal to other type + match to_type { + Int8 => { + cast_decimal_to_integer!(array, scale, Int8Builder, i8, Int8) + } + Int16 => { + cast_decimal_to_integer!(array, scale, Int16Builder, i16, Int16) + } + Int32 => { + cast_decimal_to_integer!(array, scale, Int32Builder, i32, Int32) + } + Int64 => { + cast_decimal_to_integer!(array, scale, Int64Builder, i64, Int64) + } + Float32 => { + cast_decimal_to_float!(array, scale, Float32Builder, f32) + } + Float64 => { + cast_decimal_to_float!(array, scale, Float64Builder, f64) + } + _ => Err(ArrowError::CastError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type + ))), + } + } (_, Decimal(precision, scale)) => { // cast data to decimal match from_type { @@ -1906,26 +1989,179 @@ where mod tests { use super::*; use crate::{buffer::Buffer, util::display::array_value_to_string}; - use num::traits::Pow; + + macro_rules! 
generate_cast_test_case { + ($INPUT_ARRAY: expr, $OUTPUT_TYPE_ARRAY: ident, $OUTPUT_TYPE: expr, $OUTPUT_VALUES: expr) => { + // assert cast type + let input_array_type = $INPUT_ARRAY.data_type(); + assert!(can_cast_types(input_array_type, $OUTPUT_TYPE)); + let casted_array = cast($INPUT_ARRAY, $OUTPUT_TYPE).unwrap(); + let result_array = casted_array + .as_any() + .downcast_ref::<$OUTPUT_TYPE_ARRAY>() + .unwrap(); + assert_eq!($OUTPUT_TYPE, result_array.data_type()); + assert_eq!(result_array.len(), $OUTPUT_VALUES.len()); + for (i, x) in $OUTPUT_VALUES.iter().enumerate() { + match x { + Some(x) => { + assert_eq!(result_array.value(i), *x); + } + None => { + assert!(result_array.is_null(i)); + } + } + } + }; + } + + // TODO remove this function if the decimal array has the creator function + fn create_decimal_array( + array: &[Option<i128>], + precision: usize, + scale: usize, + ) -> Result<DecimalArray> { + let mut decimal_builder = DecimalBuilder::new(array.len(), precision, scale); + for value in array { + match value { + None => { + decimal_builder.append_null()?; + } + Some(v) => { + decimal_builder.append_value(*v)?; + } + } + } + Ok(decimal_builder.finish()) + } #[test] - fn test_cast_numeric_to_decimal() { - // test cast type - let data_types = vec![ - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, + fn test_cast_decimal_to_numeric() { + let decimal_type = DataType::Decimal(38, 2); + // negative test + assert!(!can_cast_types(&decimal_type, &DataType::UInt8)); + let value_array: Vec<Option<i128>> = + vec![Some(125), Some(225), Some(325), None, Some(525)]; + let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap(); + let array = Arc::new(decimal_array) as ArrayRef; + // i8 + generate_cast_test_case!( + &array, + Int8Array, + &DataType::Int8, + vec![Some(1_i8), Some(2_i8), Some(3_i8), None, Some(5_i8)] + ); + // i16 + generate_cast_test_case!( + &array, + Int16Array, + &DataType::Int16, + vec![Some(1_i16), 
Some(2_i16), Some(3_i16), None, Some(5_i16)] + ); + // i32 + generate_cast_test_case!( + &array, + Int32Array, + &DataType::Int32, + vec![Some(1_i32), Some(2_i32), Some(3_i32), None, Some(5_i32)] + ); + // i64 + generate_cast_test_case!( + &array, + Int64Array, + &DataType::Int64, + vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)] + ); + // f32 + generate_cast_test_case!( + &array, + Int64Array, + &DataType::Int64, + vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)] + ); + // f64 + generate_cast_test_case!( + &array, + Int64Array, + &DataType::Int64, + vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)] + ); + + // overflow test: out of range of max i8 + let value_array: Vec> = vec![Some(24400)]; + let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap(); + let array = Arc::new(decimal_array) as ArrayRef; + let casted_array = cast(&array, &DataType::Int8); + assert_eq!( + "Cast error: value of 244 is out of range Int8".to_string(), + casted_array.unwrap_err().to_string() + ); + + // loss the precision: convert decimal to f32、f64 + // f32 + // 112345678_f32 and 112345679_f32 are same, so the 112345679_f32 will lose precision. + let value_array: Vec> = vec![ + Some(125), + Some(225), + Some(325), + None, + Some(525), + Some(112345678), + Some(112345679), + ]; + let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap(); + let array = Arc::new(decimal_array) as ArrayRef; + generate_cast_test_case!( + &array, + Float32Array, + &DataType::Float32, + vec![ + Some(1.25_f32), + Some(2.25_f32), + Some(3.25_f32), + None, + Some(5.25_f32), + Some(1_123_456.7_f32), + Some(1_123_456.7_f32) + ] + ); + + // f64 + // 112345678901234568_f64 and 112345678901234560_f64 are same, so the 112345678901234568_f64 will lose precision. 
+ let value_array: Vec<Option<i128>> = vec![ + Some(125), + Some(225), + Some(325), + None, + Some(525), + Some(112345678901234568), + Some(112345678901234560), ]; + let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap(); + let array = Arc::new(decimal_array) as ArrayRef; + generate_cast_test_case!( + &array, + Float64Array, + &DataType::Float64, + vec![ + Some(1.25_f64), + Some(2.25_f64), + Some(3.25_f64), + None, + Some(5.25_f64), + Some(1_123_456_789_012_345.6_f64), + Some(1_123_456_789_012_345.6_f64), + ] + ); + } + + #[test] + fn test_cast_numeric_to_decimal() { + // test negative cast type let decimal_type = DataType::Decimal(38, 6); - for data_type in data_types { - assert!(can_cast_types(&data_type, &decimal_type)) - } assert!(!can_cast_types(&DataType::UInt64, &decimal_type)); - // test cast data + // i8, i16, i32, i64 let input_datas = vec![ Arc::new(Int8Array::from(vec![ Some(1), @@ -1956,25 +2192,19 @@ mod tests { Some(5), ])) as ArrayRef, // i64 ]; - - // i8, i16, i32, i64 for array in input_datas { - let casted_array = cast(&array, &decimal_type).unwrap(); - let decimal_array = casted_array - .as_any() - .downcast_ref::<DecimalArray>() - .unwrap(); - assert_eq!(&decimal_type, decimal_array.data_type()); - for i in 0..array.len() { - if i == 3 { - assert!(decimal_array.is_null(i as usize)); - } else { - assert_eq!( - 10_i128.pow(6) * (i as i128 + 1), - decimal_array.value(i as usize) - ); - } - } + generate_cast_test_case!( + &array, + DecimalArray, + &decimal_type, + vec![ + Some(1000000_i128), + Some(2000000_i128), + Some(3000000_i128), + None, + Some(5000000_i128) + ] + ); } // test i8 to decimal type with overflow the result type @@ -1986,34 +2216,54 @@ mod tests { assert_eq!("Invalid argument error: The value of 1000 i128 is not compatible with Decimal(3,1)", casted_array.unwrap_err().to_string()); // test f32 to decimal type - let f_data: Vec<f32> = vec![1.1, 2.2, 4.4, 1.123_456_8]; - let array = Float32Array::from(f_data.clone()); + let array = 
Float32Array::from(vec![ + Some(1.1), + Some(2.2), + Some(4.4), + None, + Some(1.123_456_7), + Some(1.123_456_7), + ]); let array = Arc::new(array) as ArrayRef; - let casted_array = cast(&array, &decimal_type).unwrap(); - let decimal_array = casted_array - .as_any() - .downcast_ref::<DecimalArray>() - .unwrap(); - assert_eq!(&decimal_type, decimal_array.data_type()); - for (i, item) in f_data.iter().enumerate().take(array.len()) { - let left = (*item as f64) * 10_f64.pow(6); - assert_eq!(left as i128, decimal_array.value(i as usize)); - } + generate_cast_test_case!( + &array, + DecimalArray, + &decimal_type, + vec![ + Some(1100000_i128), + Some(2200000_i128), + Some(4400000_i128), + None, + Some(1123456_i128), + Some(1123456_i128), + ] + ); // test f64 to decimal type - let f_data: Vec<f64> = vec![1.1, 2.2, 4.4, 1.123_456_789_123_4]; - let array = Float64Array::from(f_data.clone()); + let array = Float64Array::from(vec![ + Some(1.1), + Some(2.2), + Some(4.4), + None, + Some(1.123_456_789_123_4), + Some(1.123_456_789_012_345_6), + Some(1.123_456_789_012_345_6), + ]); let array = Arc::new(array) as ArrayRef; - let casted_array = cast(&array, &decimal_type).unwrap(); - let decimal_array = casted_array - .as_any() - .downcast_ref::<DecimalArray>() - .unwrap(); - assert_eq!(&decimal_type, decimal_array.data_type()); - for (i, item) in f_data.iter().enumerate().take(array.len()) { - let left = (*item as f64) * 10_f64.pow(6); - assert_eq!(left as i128, decimal_array.value(i as usize)); - } + generate_cast_test_case!( + &array, + DecimalArray, + &decimal_type, + vec![ + Some(1100000_i128), + Some(2200000_i128), + Some(4400000_i128), + None, + Some(1123456_i128), + Some(1123456_i128), + Some(1123456_i128), + ] + ); } #[test] @@ -3968,6 +4218,9 @@ mod tests { Arc::new(DurationMillisecondArray::from(vec![1000, 2000])), Arc::new(DurationMicrosecondArray::from(vec![1000, 2000])), Arc::new(DurationNanosecondArray::from(vec![1000, 2000])), + Arc::new( + create_decimal_array(&[Some(1), Some(2), Some(3), None], 
38, 0).unwrap(), + ), ] } @@ -4142,6 +4395,7 @@ mod tests { Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8)), Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)), + Decimal(38, 0), ] }