diff --git a/datafusion/src/physical_plan/expressions/case.rs b/datafusion/src/physical_plan/expressions/case.rs index e8c500e5ed62..723438df60f7 100644 --- a/datafusion/src/physical_plan/expressions/case.rs +++ b/datafusion/src/physical_plan/expressions/case.rs @@ -17,13 +17,13 @@ use std::{any::Any, sync::Arc}; +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::{ColumnarValue, PhysicalExpr}; use arrow::array::{self, *}; +use arrow::compute::{eq, eq_utf8}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use crate::error::{DataFusionError, Result}; -use crate::physical_plan::{ColumnarValue, PhysicalExpr}; - /// The CASE expression is similar to a series of nested if/else and there are two forms that /// can be used. The first form consists of a series of boolean "when" expressions with /// corresponding "then" expressions, and an optional "else" expression. @@ -265,7 +265,7 @@ fn build_null_array(data_type: &DataType, num_rows: usize) -> Result { } macro_rules! array_equals { - ($TY:ty, $L:expr, $R:expr) => {{ + ($TY:ty, $L:expr, $R:expr, $eq_fn:expr) => {{ let when_value = $L .as_ref() .as_any() @@ -278,15 +278,7 @@ macro_rules! array_equals { .downcast_ref::<$TY>() .expect("array_equals downcast failed"); - let mut builder = BooleanBuilder::new(when_value.len()); - for row in 0..when_value.len() { - if when_value.is_valid(row) && base_value.is_valid(row) { - builder.append_value(when_value.value(row) == base_value.value(row))?; - } else { - builder.append_null()?; - } - } - Ok(builder.finish()) + $eq_fn(when_value, base_value).map_err(DataFusionError::from) }}; } @@ -296,17 +288,39 @@ fn array_equals( base_value: ArrayRef, ) -> Result { match data_type { - DataType::UInt8 => array_equals!(array::UInt8Array, when_value, base_value), - DataType::UInt16 => array_equals!(array::UInt16Array, when_value, base_value), - DataType::UInt32 => array_equals!(array::UInt32Array, when_value, base_value), - DataType::UInt64 => array_equals!(array::UInt64Array, when_value, base_value), - DataType::Int8 => array_equals!(array::Int8Array, when_value, base_value), - DataType::Int16 => array_equals!(array::Int16Array, when_value, base_value), - DataType::Int32 => array_equals!(array::Int32Array, when_value, base_value), - DataType::Int64 => array_equals!(array::Int64Array, when_value, base_value), - DataType::Float32 => array_equals!(array::Float32Array, when_value, base_value), - DataType::Float64 => array_equals!(array::Float64Array, when_value, base_value), - DataType::Utf8 => array_equals!(array::StringArray, when_value, base_value), + DataType::UInt8 => { + array_equals!(array::UInt8Array, when_value, base_value, eq) + } + DataType::UInt16 => { + array_equals!(array::UInt16Array, when_value, base_value, eq) + } + DataType::UInt32 => { + array_equals!(array::UInt32Array, when_value, base_value, eq) + } + DataType::UInt64 => { + array_equals!(array::UInt64Array, when_value, base_value, eq) + } + DataType::Int8 => { + array_equals!(array::Int8Array, when_value, base_value, eq) + } + DataType::Int16 => { + array_equals!(array::Int16Array, when_value, base_value, eq) + } + DataType::Int32 => { + array_equals!(array::Int32Array, when_value, base_value, eq) + } + DataType::Int64 => { + array_equals!(array::Int64Array, when_value, base_value, eq) + } + DataType::Float32 => { + array_equals!(array::Float32Array, when_value, base_value, eq) + } + DataType::Float64 => { + array_equals!(array::Float64Array, when_value, base_value, eq) + } + DataType::Utf8 => { + array_equals!(array::StringArray, when_value, base_value, eq_utf8) + } other => Err(DataFusionError::Execution(format!( "CASE does not support '{:?}'", other