diff --git a/rust/arrow/src/compute/cast_kernels.rs b/rust/arrow/src/compute/cast_kernels.rs index c2c5dde365fe0..62940b5729b73 100644 --- a/rust/arrow/src/compute/cast_kernels.rs +++ b/rust/arrow/src/compute/cast_kernels.rs @@ -18,6 +18,37 @@ //! Defines cast kernels for `ArrayRef`. //! //! Allows casting arrays between supported datatypes. +//! +//! ## Behavior +//! +//! * Boolean to Utf8: `true` => `"1"`, `false` => `"0"` +//! * Utf8 to numeric: strings that can't be parsed to numbers return null, float strings +//! in integer casts return null +//! * Numeric to boolean: 0 returns `false`, any other value returns `true` +//! +//! ## Unsupported Casts +//! +//! * To or from `StructArray` +//! * To or from `ListArray` +//! * Boolean to float +//! * Utf8 to boolean +//! +//! Example: +//! +//! ``` +//! use arrow::array::*; +//! use arrow::compute::cast; +//! use arrow::datatypes::DataType; +//! use std::sync::Arc; +//! +//! let a = Int32Array::from(vec![5, 6, 7]); +//! let array = Arc::new(a) as ArrayRef; +//! let b = cast(&array, &DataType::Float64).unwrap(); +//! let c = b.as_any().downcast_ref::<Float64Array>().unwrap(); +//! assert_eq!(5.0, c.value(0)); +//! assert_eq!(6.0, c.value(1)); +//! assert_eq!(7.0, c.value(2)); +//! ``` use std::sync::Arc; @@ -26,9 +57,6 @@ use crate::builder::*; use crate::datatypes::*; use crate::error::{ArrowError, Result}; -// TODO: -// * remove downcast unwraps and replace with explicit errors - /// Macro rule to cast between numeric types macro_rules! cast_numeric_arrays { ($array:expr, $from_ty:ident, $to_ty:ident) => {{ @@ -69,6 +97,31 @@ macro_rules! cast_string_to_numeric { }}; } +macro_rules! cast_numeric_to_bool { + ($array:expr, $from_ty:ident) => {{ + match cast_numeric_to_bool::<$from_ty>( + $array + .as_any() + .downcast_ref::>() + .unwrap(), + ) { + Ok(to) => Ok(Arc::new(to) as ArrayRef), + Err(e) => Err(e), + } + }}; +} + +macro_rules! 
cast_bool_to_numeric { + ($array:expr, $to_ty:ident) => {{ + match cast_bool_to_numeric::<$to_ty>( + $array.as_any().downcast_ref::().unwrap(), + ) { + Ok(to) => Ok(Arc::new(to) as ArrayRef), + Err(e) => Err(e), + } + }}; +} + /// Cast array to provided data type pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { use DataType::*; @@ -85,13 +138,69 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { (_, Struct(_)) => Err(ArrowError::ComputeError( "Cannot cast to struct from other types".to_string(), )), - (List(_), List(_)) => unimplemented!("Casting between lists not yet supported"), + (List(_), List(_)) => Err(ArrowError::ComputeError( + "Casting between lists not yet supported".to_string(), + )), (List(_), _) => Err(ArrowError::ComputeError( "Cannot cast list to non-list data types".to_string(), )), - (_, List(_)) => unimplemented!("Casting scalars to lists not yet supported"), - (_, Boolean) => unimplemented!("Boolean casts not yet implemented"), - (Boolean, _) => unimplemented!("Boolean casts not yet implemented"), + (_, List(_)) => Err(ArrowError::ComputeError( + "Cannot cast primitive types to lists".to_string(), + )), + (_, Boolean) => match from_type { + UInt8 => cast_numeric_to_bool!(array, UInt8Type), + UInt16 => cast_numeric_to_bool!(array, UInt16Type), + UInt32 => cast_numeric_to_bool!(array, UInt32Type), + UInt64 => cast_numeric_to_bool!(array, UInt64Type), + Int8 => cast_numeric_to_bool!(array, Int8Type), + Int16 => cast_numeric_to_bool!(array, Int16Type), + Int32 => cast_numeric_to_bool!(array, Int32Type), + Int64 => cast_numeric_to_bool!(array, Int64Type), + Float32 => cast_numeric_to_bool!(array, Float32Type), + Float64 => cast_numeric_to_bool!(array, Float64Type), + Utf8 => Err(ArrowError::ComputeError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), + _ => Err(ArrowError::ComputeError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), + }, + (Boolean, _) => 
match to_type { + UInt8 => cast_bool_to_numeric!(array, UInt8Type), + UInt16 => cast_bool_to_numeric!(array, UInt16Type), + UInt32 => cast_bool_to_numeric!(array, UInt32Type), + UInt64 => cast_bool_to_numeric!(array, UInt64Type), + Int8 => cast_bool_to_numeric!(array, Int8Type), + Int16 => cast_bool_to_numeric!(array, Int16Type), + Int32 => cast_bool_to_numeric!(array, Int32Type), + Int64 => cast_bool_to_numeric!(array, Int64Type), + Float32 | Float64 => Err(ArrowError::ComputeError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), + Utf8 => { + let from = array.as_any().downcast_ref::().unwrap(); + let mut b = BinaryBuilder::new(array.len()); + for i in 0..array.len() { + if array.is_null(i) { + b.append(false)?; + } else { + b.append_string(match from.value(i) { + true => "1", + false => "0", + })?; + } + } + + Ok(Arc::new(b.finish()) as ArrayRef) + } + _ => Err(ArrowError::ComputeError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), + }, (Utf8, _) => match to_type { UInt8 => cast_string_to_numeric!(array, UInt8Type), UInt16 => cast_string_to_numeric!(array, UInt16Type), @@ -103,11 +212,10 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { Int64 => cast_string_to_numeric!(array, Int64Type), Float32 => cast_string_to_numeric!(array, Float32Type), Float64 => cast_string_to_numeric!(array, Float64Type), - _ => unimplemented!( - "Casting from {:?} to {:?} not yet implemented", - from_type, - to_type - ), + _ => Err(ArrowError::ComputeError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), }, (_, Utf8) => match from_type { UInt8 => cast_numeric_to_string!(array, UInt8Type), @@ -120,11 +228,10 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { Int64 => cast_numeric_to_string!(array, Int64Type), Float32 => cast_numeric_to_string!(array, Float32Type), Float64 => cast_numeric_to_string!(array, Float64Type), - _ => unimplemented!( - "Casting from 
{:?} to {:?} not yet implemented", - from_type, - to_type - ), + _ => Err(ArrowError::ComputeError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), }, // start numeric casts @@ -218,7 +325,10 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { (Float64, Int64) => cast_numeric_arrays!(array, Float64Type, Int64Type), (Float64, Float32) => cast_numeric_arrays!(array, Float64Type, Float32Type), // end numeric casts - (_, _) => unimplemented!("Unable to cast from {:?} to {:?}", from_type, to_type), + (_, _) => Err(ArrowError::ComputeError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), } } @@ -291,6 +401,59 @@ where Ok(b.finish()) } +/// Cast numeric types to Boolean +/// +/// Any zero value returns `false` while non-zero returns `true`; null slots stay null +fn cast_numeric_to_bool(from: &PrimitiveArray) -> Result +where + T: ArrowPrimitiveType + ArrowNumericType, +{ + let mut b = BooleanBuilder::new(from.len()); + + for i in 0..from.len() { + if from.is_null(i) { + b.append_null()?; + } else { + if from.value(i) != T::default_value() { + b.append_value(true)?; + } else { + b.append_value(false)?; + } + } + } + + Ok(b.finish()) +} + +/// Cast Boolean types to numeric +/// +/// `true` becomes 1 and `false` becomes `T::default_value()`; null slots stay null +fn cast_bool_to_numeric(from: &BooleanArray) -> Result> +where + T: ArrowPrimitiveType + ArrowNumericType, + T::Native: num::NumCast, +{ + let mut b = PrimitiveBuilder::::new(from.len()); + + for i in 0..from.len() { + if from.is_null(i) { + b.append_null()?; + } else { + if from.value(i) { + // workaround: `num::cast::cast` converts the literal 1 to T::Native; expected to be infallible, but `None` is mapped to null + match num::cast::cast(1) { + Some(v) => b.append_value(v)?, + None => b.append_null()?, + }; + } else { + b.append_value(T::default_value())?; + } + } + } + + Ok(b.finish()) +} + #[cfg(test)] mod tests { use super::*; @@ -308,6 
+471,20 @@ mod tests { assert_eq!(9.0, c.value(4)); } + #[test] + fn test_cast_i32_to_u8() { + let a 
= Int32Array::from(vec![-5, 6, -7, 8, 100000000]); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::UInt8).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(false, c.is_valid(0)); + assert_eq!(6, c.value(1)); + assert_eq!(false, c.is_valid(2)); + assert_eq!(8, c.value(3)); + // 100000000 overflows u8, so the cast yields a null slot + assert_eq!(false, c.is_valid(4)); + } + #[test] fn test_cast_i32_to_i32() { let a = Int32Array::from(vec![5, 6, 7, 8, 9]); @@ -333,4 +510,33 @@ mod tests { assert_eq!(8, c.value(3)); assert_eq!(false, c.is_valid(2)); } + + #[test] + fn test_cast_bool_to_i32() { + let a = BooleanArray::from(vec![Some(true), Some(false), None]); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Int32).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(1, c.value(0)); + assert_eq!(0, c.value(1)); + assert_eq!(false, c.is_valid(2)); + } + + #[test] + #[should_panic(expected = "Casting from Boolean to Float64 not supported")] + fn test_cast_bool_to_f64() { + let a = BooleanArray::from(vec![Some(true), Some(false), None]); + let array = Arc::new(a) as ArrayRef; + cast(&array, &DataType::Float64).unwrap(); + } + + #[test] + #[should_panic( + expected = "Casting from Int32 to Timestamp(Microsecond) not supported" + )] + fn test_cast_int32_to_timestamp() { + let a = Int32Array::from(vec![Some(2), Some(10), None]); + let array = Arc::new(a) as ArrayRef; + cast(&array, &DataType::Timestamp(TimeUnit::Microsecond)).unwrap(); + } }