Skip to content

Commit

Permalink
boolean casts, documentation, error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
nevi-me committed Mar 6, 2019
1 parent 4a8906b commit 04c7307
Showing 1 changed file with 224 additions and 18 deletions.
242 changes: 224 additions & 18 deletions rust/arrow/src/compute/cast_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,37 @@
//! Defines cast kernels for `ArrayRef`.
//!
//! Allows casting arrays between supported datatypes.
//!
//! ## Behavior
//!
//! * Boolean to Utf8: `true` => '1', `false` => `0`
//! * Utf8 to numeric: strings that can't be parsed to numbers return null, float strings
//! in integer casts return null
//! * Numeric to boolean: 0 returns `false`, any other value returns `true`
//!
//! ## Unsupported Casts
//!
//! * To or from `StructArray`
//! * To or from `ListArray`
//! * Boolean to float
//! * Utf8 to boolean
//!
//! Example:
//!
//! ```
//! use arrow::array::*;
//! use arrow::compute::cast;
//! use arrow::datatypes::DataType;
//! use std::sync::Arc;
//!
//! let a = Int32Array::from(vec![5, 6, 7]);
//! let array = Arc::new(a) as ArrayRef;
//! let b = cast(&array, &DataType::Float64).unwrap();
//! let c = b.as_any().downcast_ref::<Float64Array>().unwrap();
//! assert_eq!(5.0, c.value(0));
//! assert_eq!(6.0, c.value(1));
//! assert_eq!(7.0, c.value(2));
//! ```

use std::sync::Arc;

Expand All @@ -26,9 +57,6 @@ use crate::builder::*;
use crate::datatypes::*;
use crate::error::{ArrowError, Result};

// TODO:
// * remove downcast unwraps and replace with explicit errors

/// Macro rule to cast between numeric types
macro_rules! cast_numeric_arrays {
($array:expr, $from_ty:ident, $to_ty:ident) => {{
Expand Down Expand Up @@ -69,6 +97,31 @@ macro_rules! cast_string_to_numeric {
}};
}

macro_rules! cast_numeric_to_bool {
($array:expr, $from_ty:ident) => {{
match cast_numeric_to_bool::<$from_ty>(
$array
.as_any()
.downcast_ref::<PrimitiveArray<$from_ty>>()
.unwrap(),
) {
Ok(to) => Ok(Arc::new(to) as ArrayRef),
Err(e) => Err(e),
}
}};
}

macro_rules! cast_bool_to_numeric {
($array:expr, $to_ty:ident) => {{
match cast_bool_to_numeric::<$to_ty>(
$array.as_any().downcast_ref::<BooleanArray>().unwrap(),
) {
Ok(to) => Ok(Arc::new(to) as ArrayRef),
Err(e) => Err(e),
}
}};
}

/// Cast array to provided data type
pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
use DataType::*;
Expand All @@ -85,13 +138,69 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
(_, Struct(_)) => Err(ArrowError::ComputeError(
"Cannot cast to struct from other types".to_string(),
)),
(List(_), List(_)) => unimplemented!("Casting between lists not yet supported"),
(List(_), List(_)) => Err(ArrowError::ComputeError(
"Casting between lists not yet supported".to_string(),
)),
(List(_), _) => Err(ArrowError::ComputeError(
"Cannot cast list to non-list data types".to_string(),
)),
(_, List(_)) => unimplemented!("Casting scalars to lists not yet supported"),
(_, Boolean) => unimplemented!("Boolean casts not yet implemented"),
(Boolean, _) => unimplemented!("Boolean casts not yet implemented"),
(_, List(_)) => Err(ArrowError::ComputeError(
"Cannot cast primitive types to lists".to_string(),
)),
(_, Boolean) => match from_type {
UInt8 => cast_numeric_to_bool!(array, UInt8Type),
UInt16 => cast_numeric_to_bool!(array, UInt16Type),
UInt32 => cast_numeric_to_bool!(array, UInt32Type),
UInt64 => cast_numeric_to_bool!(array, UInt64Type),
Int8 => cast_numeric_to_bool!(array, Int8Type),
Int16 => cast_numeric_to_bool!(array, Int16Type),
Int32 => cast_numeric_to_bool!(array, Int32Type),
Int64 => cast_numeric_to_bool!(array, Int64Type),
Float32 => cast_numeric_to_bool!(array, Float32Type),
Float64 => cast_numeric_to_bool!(array, Float64Type),
Utf8 => Err(ArrowError::ComputeError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type,
))),
_ => Err(ArrowError::ComputeError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type,
))),
},
(Boolean, _) => match to_type {
UInt8 => cast_bool_to_numeric!(array, UInt8Type),
UInt16 => cast_bool_to_numeric!(array, UInt16Type),
UInt32 => cast_bool_to_numeric!(array, UInt32Type),
UInt64 => cast_bool_to_numeric!(array, UInt64Type),
Int8 => cast_bool_to_numeric!(array, Int8Type),
Int16 => cast_bool_to_numeric!(array, Int16Type),
Int32 => cast_bool_to_numeric!(array, Int32Type),
Int64 => cast_bool_to_numeric!(array, Int64Type),
Float32 | Float64 => Err(ArrowError::ComputeError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type,
))),
Utf8 => {
let from = array.as_any().downcast_ref::<BooleanArray>().unwrap();
let mut b = BinaryBuilder::new(array.len());
for i in 0..array.len() {
if array.is_null(i) {
b.append(false)?;
} else {
b.append_string(match from.value(i) {
true => "1",
false => "0",
})?;
}
}

Ok(Arc::new(b.finish()) as ArrayRef)
}
_ => Err(ArrowError::ComputeError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type,
))),
},
(Utf8, _) => match to_type {
UInt8 => cast_string_to_numeric!(array, UInt8Type),
UInt16 => cast_string_to_numeric!(array, UInt16Type),
Expand All @@ -103,11 +212,10 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
Int64 => cast_string_to_numeric!(array, Int64Type),
Float32 => cast_string_to_numeric!(array, Float32Type),
Float64 => cast_string_to_numeric!(array, Float64Type),
_ => unimplemented!(
"Casting from {:?} to {:?} not yet implemented",
from_type,
to_type
),
_ => Err(ArrowError::ComputeError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type,
))),
},
(_, Utf8) => match from_type {
UInt8 => cast_numeric_to_string!(array, UInt8Type),
Expand All @@ -120,11 +228,10 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
Int64 => cast_numeric_to_string!(array, Int64Type),
Float32 => cast_numeric_to_string!(array, Float32Type),
Float64 => cast_numeric_to_string!(array, Float64Type),
_ => unimplemented!(
"Casting from {:?} to {:?} not yet implemented",
from_type,
to_type
),
_ => Err(ArrowError::ComputeError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type,
))),
},

// start numeric casts
Expand Down Expand Up @@ -218,7 +325,10 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
(Float64, Int64) => cast_numeric_arrays!(array, Float64Type, Int64Type),
(Float64, Float32) => cast_numeric_arrays!(array, Float64Type, Float32Type),
// end numeric casts
(_, _) => unimplemented!("Unable to cast from {:?} to {:?}", from_type, to_type),
(_, _) => Err(ArrowError::ComputeError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type,
))),
}
}

Expand Down Expand Up @@ -291,6 +401,59 @@ where
Ok(b.finish())
}

/// Cast numeric types to Boolean
///
/// Any zero value returns `false` while non-zero returns `true`
fn cast_numeric_to_bool<T>(from: &PrimitiveArray<T>) -> Result<BooleanArray>
where
T: ArrowPrimitiveType + ArrowNumericType,
{
let mut b = BooleanBuilder::new(from.len());

for i in 0..from.len() {
if from.is_null(i) {
b.append_null()?;
} else {
if from.value(i) != T::default_value() {
b.append_value(true)?;
} else {
b.append_value(false)?;
}
}
}

Ok(b.finish())
}

/// Cast Boolean types to numeric
///
/// Any zero value returns `false` while non-zero returns `true`
fn cast_bool_to_numeric<T>(from: &BooleanArray) -> Result<PrimitiveArray<T>>
where
T: ArrowPrimitiveType + ArrowNumericType,
T::Native: num::NumCast,
{
let mut b = PrimitiveBuilder::<T>::new(from.len());

for i in 0..from.len() {
if from.is_null(i) {
b.append_null()?;
} else {
if from.value(i) {
// a workaround to cast a primitive to T::Native, infallible
match num::cast::cast(1) {
Some(v) => b.append_value(v)?,
None => b.append_null()?,
};
} else {
b.append_value(T::default_value())?;
}
}
}

Ok(b.finish())
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -308,6 +471,20 @@ mod tests {
assert_eq!(9.0, c.value(4));
}

#[test]
fn test_cast_i32_to_u8() {
let a = Int32Array::from(vec![-5, 6, -7, 8, 100000000]);
let array = Arc::new(a) as ArrayRef;
let b = cast(&array, &DataType::UInt8).unwrap();
let c = b.as_any().downcast_ref::<UInt8Array>().unwrap();
assert_eq!(false, c.is_valid(0));
assert_eq!(6, c.value(1));
assert_eq!(false, c.is_valid(2));
assert_eq!(8, c.value(3));
// overflows return None
assert_eq!(false, c.is_valid(4));
}

#[test]
fn test_cast_i32_to_i32() {
let a = Int32Array::from(vec![5, 6, 7, 8, 9]);
Expand All @@ -333,4 +510,33 @@ mod tests {
assert_eq!(8, c.value(3));
assert_eq!(false, c.is_valid(2));
}

#[test]
fn test_cast_bool_to_i32() {
let a = BooleanArray::from(vec![Some(true), Some(false), None]);
let array = Arc::new(a) as ArrayRef;
let b = cast(&array, &DataType::Int32).unwrap();
let c = b.as_any().downcast_ref::<Int32Array>().unwrap();
assert_eq!(1, c.value(0));
assert_eq!(0, c.value(1));
assert_eq!(false, c.is_valid(2));
}

#[test]
#[should_panic(expected = "Casting from Boolean to Float64 not supported")]
fn test_cast_bool_to_f64() {
let a = BooleanArray::from(vec![Some(true), Some(false), None]);
let array = Arc::new(a) as ArrayRef;
cast(&array, &DataType::Float64).unwrap();
}

#[test]
#[should_panic(
expected = "Casting from Int32 to Timestamp(Microsecond) not supported"
)]
fn test_cast_int32_to_timestamp() {
let a = Int32Array::from(vec![Some(2), Some(10), None]);
let array = Arc::new(a) as ArrayRef;
cast(&array, &DataType::Timestamp(TimeUnit::Microsecond)).unwrap();
}
}

0 comments on commit 04c7307

Please sign in to comment.