Skip to content

Commit

Permalink
cast binary -> numeric
Browse files Browse the repository at this point in the history
  • Loading branch information
nevi-me committed Mar 6, 2019
1 parent c3b8961 commit 4a8906b
Showing 1 changed file with 69 additions and 3 deletions.
72 changes: 69 additions & 3 deletions rust/arrow/src/compute/cast_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ use crate::builder::*;
use crate::datatypes::*;
use crate::error::{ArrowError, Result};

// TODO:
// * remove downcast unwraps and replace with explicit errors

/// Macro rule to cast between numeric types
macro_rules! cast_numeric_arrays {
($array:expr, $from_ty:ident, $to_ty:ident) => {{
Expand Down Expand Up @@ -55,6 +58,17 @@ macro_rules! cast_numeric_to_string {
}};
}

/// Macro rule to cast a Utf8 array to the numeric type `$to_ty`.
///
/// Downcasts `$array` to a `BinaryArray`, runs the generic
/// `cast_string_to_numeric` function over it and wraps the resulting
/// primitive array in an `ArrayRef`.
///
/// Panics if `$array` is not backed by a `BinaryArray` (the downcast is
/// unwrapped).
macro_rules! cast_string_to_numeric {
    ($array:expr, $to_ty:ident) => {{
        let strings = $array.as_any().downcast_ref::<BinaryArray>().unwrap();
        cast_string_to_numeric::<$to_ty>(strings)
            .map(|converted| Arc::new(converted) as ArrayRef)
    }};
}

/// Cast array to provided data type
pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
use DataType::*;
Expand All @@ -78,7 +92,23 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
(_, List(_)) => unimplemented!("Casting scalars to lists not yet supported"),
(_, Boolean) => unimplemented!("Boolean casts not yet implemented"),
(Boolean, _) => unimplemented!("Boolean casts not yet implemented"),
(Utf8, _) => Err(ArrowError::ComputeError("Unable to cast".to_string())),
(Utf8, _) => match to_type {
UInt8 => cast_string_to_numeric!(array, UInt8Type),
UInt16 => cast_string_to_numeric!(array, UInt16Type),
UInt32 => cast_string_to_numeric!(array, UInt32Type),
UInt64 => cast_string_to_numeric!(array, UInt64Type),
Int8 => cast_string_to_numeric!(array, Int8Type),
Int16 => cast_string_to_numeric!(array, Int16Type),
Int32 => cast_string_to_numeric!(array, Int32Type),
Int64 => cast_string_to_numeric!(array, Int64Type),
Float32 => cast_string_to_numeric!(array, Float32Type),
Float64 => cast_string_to_numeric!(array, Float64Type),
_ => unimplemented!(
"Casting from {:?} to {:?} not yet implemented",
from_type,
to_type
),
},
(_, Utf8) => match from_type {
UInt8 => cast_numeric_to_string!(array, UInt8Type),
UInt16 => cast_numeric_to_string!(array, UInt16Type),
Expand Down Expand Up @@ -218,8 +248,6 @@ where
}

/// Cast numeric types to Utf8
///
/// The trait
fn cast_numeric_to_string<T>(from: &PrimitiveArray<T>) -> Result<BinaryArray>
where
T: ArrowPrimitiveType + ArrowNumericType,
Expand All @@ -238,6 +266,31 @@ where
Ok(b.finish())
}

/// Cast Utf8 (string) values to a numeric type.
///
/// Each slot of `from` is converted independently:
/// * a null slot stays null;
/// * a slot whose bytes are not valid UTF-8, or whose text does not parse as
///   `T::Native`, becomes null instead of failing the whole cast;
/// * otherwise the parsed value is appended.
///
/// # Errors
///
/// Returns an error only if the underlying builder fails to append.
fn cast_string_to_numeric<T>(from: &BinaryArray) -> Result<PrimitiveArray<T>>
where
    T: ArrowPrimitiveType + ArrowNumericType,
    // Parsing needs `FromStr`; the previous `ToString` bound was never used.
    T::Native: ::std::str::FromStr,
{
    let mut b = PrimitiveBuilder::<T>::new(from.len());

    for i in 0..from.len() {
        // `None` covers all three "no value" cases: null input, invalid
        // UTF-8 and unparseable text.
        let parsed = if from.is_null(i) {
            None
        } else {
            std::str::from_utf8(from.value(i))
                .ok()
                .and_then(|s| s.parse::<T::Native>().ok())
        };

        match parsed {
            Some(v) => b.append_value(v)?,
            None => b.append_null()?,
        }
    }

    Ok(b.finish())
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -267,4 +320,17 @@ mod tests {
assert_eq!(8, c.value(3));
assert_eq!(9, c.value(4));
}

/// Casting Utf8 to Int32 parses integer strings and yields null for
/// anything that is not a valid `i32` literal.
#[test]
fn test_cast_utf_to_i32() {
    let a = BinaryArray::from(vec!["5", "6", "seven", "8", "9.1"]);
    let array = Arc::new(a) as ArrayRef;
    let b = cast(&array, &DataType::Int32).unwrap();
    let c = b.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(5, c.value(0));
    assert_eq!(6, c.value(1));
    // "seven" is not numeric, so the slot is null.
    assert_eq!(false, c.is_valid(2));
    assert_eq!(8, c.value(3));
    // "9.1" is not an integer literal, so the slot is null as well.
    // (The original re-checked index 2 here, leaving index 4 untested.)
    assert_eq!(false, c.is_valid(4));
}
}

0 comments on commit 4a8906b

Please sign in to comment.