Skip to content

Commit

Permalink
use arrow cast in datafusion
Browse files Browse the repository at this point in the history
  • Loading branch information
nevi-me committed Mar 4, 2019
1 parent d3687d9 commit 4b897a4
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 93 deletions.
19 changes: 18 additions & 1 deletion rust/arrow/src/compute/cast_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,12 @@ macro_rules! cast_numeric_to_string {
/// Cast array to provided data type
pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
use DataType::*;
// to and from have to be compatible
let from_type = array.data_type();

// clone array if types are the same
if from_type == to_type {
return Ok(array.clone());
}
match (from_type, to_type) {
(Struct(_), _) => Err(ArrowError::ComputeError(
"Cannot cast from struct to other types".to_string(),
Expand Down Expand Up @@ -250,4 +254,17 @@ mod tests {
assert_eq!(8.0, c.value(3));
assert_eq!(9.0, c.value(4));
}

#[test]
fn test_cast_i32_to_i32() {
let a = Int32Array::from(vec![5, 6, 7, 8, 9]);
let array = Arc::new(a) as ArrayRef;
let b = cast(&array, &DataType::Int32).unwrap();
let c = b.as_any().downcast_ref::<Int32Array>().unwrap();
assert_eq!(5, c.value(0));
assert_eq!(6, c.value(1));
assert_eq!(7, c.value(2));
assert_eq!(8, c.value(3));
assert_eq!(9, c.value(4));
}
}
95 changes: 3 additions & 92 deletions rust/datafusion/src/execution/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,54 +245,6 @@ macro_rules! literal_array {
}};
}

// /// Casts a column to an array with a different data type
// macro_rules! cast_column {
// ($INDEX:expr, $TO_TYPE:expr) => {{
// Rc::new(move |batch: &RecordBatch| {
// // get data and cast to known type
// // match batch.column($INDEX).as_any().downcast_ref::<$FROM_TYPE>() {
// // Some(array) => {
// // // create builder for desired type
// // let mut builder = $TO_TYPE::builder(batch.num_rows());
// // for i in 0..batch.num_rows() {
// // if array.is_null(i) {
// // builder.append_null()?;
// // } else {
// // builder.append_value(array.value(i) as $DT)?;
// // }
// // }
// // Ok(Arc::new(builder.finish()) as ArrayRef)
// // }
// // None => Err(ExecutionError::InternalError(format!(
// // "Column at index {} is not of expected type",
// // $INDEX
// // ))),
// // }
// compute::cast(batch.column($INDEX), $TO_TYPE)
// .map_err(|e| ExecutionError::ArrowError(e))
// })
// }};
// }

// macro_rules! cast_column_outer {
// ($INDEX:expr, $FROM_TYPE:ty, $TO_TYPE:expr) => {{
// // match $TO_TYPE {
// // DataType::UInt8 => cast_column!($INDEX, $FROM_TYPE, UInt8Array, u8),
// // DataType::UInt16 => cast_column!($INDEX, $FROM_TYPE, UInt16Array, u16),
// // DataType::UInt32 => cast_column!($INDEX, $FROM_TYPE, UInt32Array, u32),
// // DataType::UInt64 => cast_column!($INDEX, $FROM_TYPE, UInt64Array, u64),
// // DataType::Int8 => cast_column!($INDEX, $FROM_TYPE, Int8Array, i8),
// // DataType::Int16 => cast_column!($INDEX, $FROM_TYPE, Int16Array, i16),
// // DataType::Int32 => cast_column!($INDEX, $FROM_TYPE, Int32Array, i32),
// // DataType::Int64 => cast_column!($INDEX, $FROM_TYPE, Int64Array, i64),
// // DataType::Float32 => cast_column!($INDEX, $FROM_TYPE, Float32Array,
// f32), // DataType::Float64 => cast_column!($INDEX, $FROM_TYPE,
// Float64Array, f64), // _ => unimplemented!(),
// // }
// cast_column!($INDEX, $TO_TYPE)
// }};
// }

/// Compiles a scalar expression into a closure
pub fn compile_scalar_expr(
ctx: &ExecutionContext,
Expand Down Expand Up @@ -334,55 +286,14 @@ pub fn compile_scalar_expr(
} => match expr.as_ref() {
&Expr::Column(index) => {
let col = input_schema.field(index);
let dt = data_type.clone();
Ok(RuntimeExpr::Compiled {
name: col.name().clone(),
t: col.data_type().clone(),
f: Rc::new(|batch: &RecordBatch| {
// match compute::cast(batch.column(index), data_type) {
// Ok(array) => Ok(array),
// Err(e) => Err(e.into())
// }
compute::cast(batch.column(index).clone(), data_type.clone())
f: Rc::new(move |batch: &RecordBatch| {
compute::cast(batch.column(index), &dt)
.map_err(|e| ExecutionError::ArrowError(e))
}),
// f: match col.data_type() {
// DataType::Int8 => {
// cast_column_outer!(index, Int8Array, &data_type)
// }
// DataType::Int16 => {
// cast_column_outer!(index, Int16Array, &data_type)
// }
// DataType::Int32 => {
// cast_column_outer!(index, Int32Array, &data_type)
// }
// DataType::Int64 => {
// cast_column_outer!(index, Int64Array, &data_type)
// }
// DataType::UInt8 => {
// cast_column_outer!(index, UInt8Array, &data_type)
// }
// DataType::UInt16 => {
// cast_column_outer!(index, UInt16Array, &data_type)
// }
// DataType::UInt32 => {
// cast_column_outer!(index, UInt32Array, &data_type)
// }
// DataType::UInt64 => {
// cast_column_outer!(index, UInt64Array, &data_type)
// }
// DataType::Float32 => {
// cast_column_outer!(index, Float32Array, &data_type)
// }
// DataType::Float64 => {
// cast_column_outer!(index, Float64Array, &data_type)
// }
// _ => panic!("unsupported CAST operation"), /*TODO */
// /*Err(ExecutionError::NotImplemented(format!(
// "CAST column from {:?} to {:?}",
// col.data_type(),
// data_type
// )))*/
// },
})
}
&Expr::Literal(ref value) => {
Expand Down

0 comments on commit 4b897a4

Please sign in to comment.