Skip to content

Commit

Permalink
add cast test macro function; refactor other type to decimal type; ad…
Browse files Browse the repository at this point in the history
…d decimal to signed numeric type
  • Loading branch information
liukun4515 committed Dec 21, 2021
1 parent f3e452c commit d17180f
Showing 1 changed file with 239 additions and 53 deletions.
292 changes: 239 additions & 53 deletions arrow/src/compute/kernels/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,11 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
}

match (from_type, to_type) {
// TODO now just support signed numeric to decimal, support decimal to numeric later
(Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal(_, _))
// TODO UTF8/unsigned numeric to decimal
// signed numeric to decimal
(Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal(_, _)) |
// decimal to signed numeric
(Decimal(_, _), Int8 | Int16 | Int32 | Int64 | Float32 | Float64)
| (
Null,
Boolean
Expand Down Expand Up @@ -1906,26 +1909,191 @@ where
mod tests {
use super::*;
use crate::{buffer::Buffer, util::display::array_value_to_string};
use num::traits::Pow;

macro_rules! generate_cast_test_case {
($INPUT_ARRAY: expr, $INPUT_ARRAY_TYPE: expr, $OUTPUT_TYPE_ARRAY: ident, $OUTPUT_TYPE: expr, $OUTPUT_VALUES: expr) => {
// assert cast type
assert!(can_cast_types($INPUT_ARRAY_TYPE, $OUTPUT_TYPE));
let casted_array = cast($INPUT_ARRAY, $OUTPUT_TYPE).unwrap();
let result_array = casted_array
.as_any()
.downcast_ref::<$OUTPUT_TYPE_ARRAY>()
.unwrap();
assert_eq!($OUTPUT_TYPE, result_array.data_type());
assert_eq!(result_array.len(), $OUTPUT_VALUES.len());
for (i, x) in $OUTPUT_VALUES.iter().enumerate() {
match x {
Some(x) => {
assert_eq!(result_array.value(i), *x);
}
None => {
assert!(result_array.is_null(i));
}
}
}
};
}

// TODO remove this function if the decimal array has the creator function
fn create_decimal_array(
array: &Vec<Option<i128>>,
precision: usize,
scale: usize,
) -> Result<DecimalArray> {
let mut decimal_builder = DecimalBuilder::new(array.len(), precision, scale);
for value in array {
match value {
None => {
decimal_builder.append_null()?;
}
Some(v) => {
decimal_builder.append_value(*v)?;
}
}
}
Ok(decimal_builder.finish())
}

#[test]
fn test_cast_decimal_to_numeric() {
let decimal_type = DataType::Decimal(38, 2);
// negative test
assert!(!can_cast_types(&decimal_type, &DataType::UInt8));
let value_array: Vec<Option<i128>> =
vec![Some(125), Some(225), Some(325), None, Some(525)];
let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap();
let array = Arc::new(decimal_array) as ArrayRef;
// i8
generate_cast_test_case!(
&array,
&decimal_type,
Int8Array,
&DataType::Int8,
vec![Some(1_i8), Some(2_i8), Some(3_i8), None, Some(5_i8)]
);
// i16
generate_cast_test_case!(
&array,
&decimal_type,
Int16Array,
&DataType::Int16,
vec![Some(1_i16), Some(2_i16), Some(3_i16), None, Some(5_i16)]
);
// i32
generate_cast_test_case!(
&array,
&decimal_type,
Int32Array,
&DataType::Int32,
vec![Some(1_i32), Some(2_i32), Some(3_i32), None, Some(5_i32)]
);
// i64
generate_cast_test_case!(
&array,
&decimal_type,
Int64Array,
&DataType::Int64,
vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)]
);
// f32
generate_cast_test_case!(
&array,
&decimal_type,
Int64Array,
&DataType::Int64,
vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)]
);
// f64
generate_cast_test_case!(
&array,
&decimal_type,
Int64Array,
&DataType::Int64,
vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)]
);

// overflow test: out of range of max i8
let value_array: Vec<Option<i128>> = vec![Some(24400)];
let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap();
let array = Arc::new(decimal_array) as ArrayRef;
let casted_array = cast(&array, &DataType::Int8);
assert!(casted_array.is_err());
// TODO add out of range for the cast

// loss the precision: convert decimal to f32、f64
// f32
// 112345678_f32 and 112345679_f32 are same, so the 112345679_f32 will lose precision.
let value_array: Vec<Option<i128>> = vec![
Some(125),
Some(225),
Some(325),
None,
Some(525),
Some(112345678),
Some(112345679),
];
let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap();
let array = Arc::new(decimal_array) as ArrayRef;
generate_cast_test_case!(
&array,
&decimal_type,
Float32Array,
&DataType::Float32,
vec![
Some(1.25_f32),
Some(2.25_f32),
Some(3.25_f32),
None,
Some(5.25_f32),
Some(1123456.78_f32),
Some(1123456.78_f32)
]
);

// f64
// 112345678901234568_f64 and 112345678901234560_f64 are same, so the 112345678901234568_f64 will lose precision.
let value_array: Vec<Option<i128>> = vec![
Some(125),
Some(225),
Some(325),
None,
Some(525),
Some(112345678901234568),
Some(112345678901234560),
];
let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap();
let array = Arc::new(decimal_array) as ArrayRef;
generate_cast_test_case!(
&array,
&decimal_type,
Float64Array,
&DataType::Float64,
vec![
Some(1.25_f64),
Some(2.25_f64),
Some(3.25_f64),
None,
Some(5.25_f64),
Some(1123456789012345.60_f64),
Some(1123456789012345.60_f64)
]
);
}

#[test]
fn test_cast_numeric_to_decimal() {
// test cast type
let decimal_type = DataType::Decimal(38, 6);
assert!(!can_cast_types(&DataType::UInt64, &decimal_type));

// test cast data
let data_types = vec![
DataType::Int8,
DataType::Int16,
DataType::Int32,
DataType::Int64,
DataType::Float32,
DataType::Float64,
];
let decimal_type = DataType::Decimal(38, 6);
for data_type in data_types {
assert!(can_cast_types(&data_type, &decimal_type))
}
assert!(!can_cast_types(&DataType::UInt64, &decimal_type));

// test cast data
let input_datas = vec![
Arc::new(Int8Array::from(vec![
Some(1),
Expand Down Expand Up @@ -1956,25 +2124,21 @@ mod tests {
Some(5),
])) as ArrayRef, // i64
];

// i8, i16, i32, i64
for array in input_datas {
let casted_array = cast(&array, &decimal_type).unwrap();
let decimal_array = casted_array
.as_any()
.downcast_ref::<DecimalArray>()
.unwrap();
assert_eq!(&decimal_type, decimal_array.data_type());
for i in 0..array.len() {
if i == 3 {
assert!(decimal_array.is_null(i as usize));
} else {
assert_eq!(
10_i128.pow(6) * (i as i128 + 1),
decimal_array.value(i as usize)
);
}
}
for (i, array) in input_datas.iter().enumerate() {
generate_cast_test_case!(
array,
&data_types[i],
DecimalArray,
&decimal_type,
vec![
Some(1000000_i128),
Some(2000000_i128),
Some(3000000_i128),
None,
Some(5000000_i128)
]
);
}

// test i8 to decimal type with overflow the result type
Expand All @@ -1986,34 +2150,56 @@ mod tests {
assert_eq!("Invalid argument error: The value of 1000 i128 is not compatible with Decimal(3,1)", casted_array.unwrap_err().to_string());

// test f32 to decimal type
let f_data: Vec<f32> = vec![1.1, 2.2, 4.4, 1.123_456_8];
let array = Float32Array::from(f_data.clone());
let array = Float32Array::from(vec![
Some(1.1),
Some(2.2),
Some(4.4),
None,
Some(1.123_456_78),
Some(1.123_456_79),
]);
let array = Arc::new(array) as ArrayRef;
let casted_array = cast(&array, &decimal_type).unwrap();
let decimal_array = casted_array
.as_any()
.downcast_ref::<DecimalArray>()
.unwrap();
assert_eq!(&decimal_type, decimal_array.data_type());
for (i, item) in f_data.iter().enumerate().take(array.len()) {
let left = (*item as f64) * 10_f64.pow(6);
assert_eq!(left as i128, decimal_array.value(i as usize));
}
generate_cast_test_case!(
&array,
&DataType::Float32,
DecimalArray,
&decimal_type,
vec![
Some(1100000_i128),
Some(2200000_i128),
Some(4400000_i128),
None,
Some(1123456_i128),
Some(1123456_i128),
]
);

// test f64 to decimal type
let f_data: Vec<f64> = vec![1.1, 2.2, 4.4, 1.123_456_789_123_4];
let array = Float64Array::from(f_data.clone());
let array = Float64Array::from(vec![
Some(1.1),
Some(2.2),
Some(4.4),
None,
Some(1.123_456_789_123_4),
Some(1.123_456_789_012_345_68),
Some(1.123_456_789_012_345_60),
]);
let array = Arc::new(array) as ArrayRef;
let casted_array = cast(&array, &decimal_type).unwrap();
let decimal_array = casted_array
.as_any()
.downcast_ref::<DecimalArray>()
.unwrap();
assert_eq!(&decimal_type, decimal_array.data_type());
for (i, item) in f_data.iter().enumerate().take(array.len()) {
let left = (*item as f64) * 10_f64.pow(6);
assert_eq!(left as i128, decimal_array.value(i as usize));
}
generate_cast_test_case!(
&array,
&DataType::Float64,
DecimalArray,
&decimal_type,
vec![
Some(1100000_i128),
Some(2200000_i128),
Some(4400000_i128),
None,
Some(1123456_i128),
Some(1123456_i128),
Some(1123456_i128),
]
);
}

#[test]
Expand Down

0 comments on commit d17180f

Please sign in to comment.