diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 6104566450c3..f254274edde6 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -18,6 +18,7 @@ //! Array expressions use std::any::type_name; +use std::cmp::Ordering; use std::collections::HashSet; use std::sync::Arc; @@ -377,111 +378,107 @@ fn return_empty(return_null: bool, data_type: DataType) -> Arc { } } -macro_rules! list_slice { - ($ARRAY:expr, $I:expr, $J:expr, $RETURN_ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); - if $I == 0 && $J == 0 || $ARRAY.is_empty() { - return return_empty($RETURN_ELEMENT, $ARRAY.data_type().clone()); - } +fn list_slice( + array: &dyn Array, + i: i64, + j: i64, + return_element: bool, +) -> ArrayRef { + let array = array.as_any().downcast_ref::().unwrap(); - let i = if $I < 0 { - if $I.abs() as usize > array.len() { - return return_empty(true, $ARRAY.data_type().clone()); - } + let array_type = array.data_type().clone(); - (array.len() as i64 + $I + 1) as usize - } else { - if $I == 0 { - 1 - } else { - $I as usize - } - }; - let j = if $J < 0 { - if $J.abs() as usize > array.len() { - return return_empty(true, $ARRAY.data_type().clone()); + if i == 0 && j == 0 || array.is_empty() { + return return_empty(return_element, array_type); + } + + let i = match i.cmp(&0) { + Ordering::Less => { + if i.unsigned_abs() > array.len() as u64 { + return return_empty(true, array_type); } - if $RETURN_ELEMENT { - (array.len() as i64 + $J + 1) as usize - } else { - (array.len() as i64 + $J) as usize + (array.len() as i64 + i + 1) as usize + } + Ordering::Equal => 1, + Ordering::Greater => i as usize, + }; + + let j = match j.cmp(&0) { + Ordering::Less => { + if j.unsigned_abs() as usize > array.len() { + return return_empty(true, array_type); } - } else { - if $J == 0 { - 1 + if return_element { + (array.len() as i64 + j + 1) as usize } else { - if $J as usize > array.len() { - array.len() - } else { - $J as usize - } + (array.len() as i64 + j) as usize } - }; - - if i > j || i as usize > $ARRAY.len() { - return_empty($RETURN_ELEMENT, $ARRAY.data_type().clone()) - } else { - Arc::new(array.slice((i - 1), (j + 1 - i))) } - }}; + Ordering::Equal => 1, + Ordering::Greater => j.min(array.len() as i64) as usize, + }; + + if i > j || i > array.len() { + return_empty(return_element, array_type) + } else { + Arc::new(array.slice(i - 1, j + 1 - i)) + } } -macro_rules! slice { - ($ARRAY:expr, $KEY:expr, $EXTRA_KEY:expr, $RETURN_ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let sliced_array: Vec> = $ARRAY +fn slice( + array: &ListArray, + key: &Int64Array, + extra_key: &Int64Array, + return_element: bool, +) -> Result> { + let sliced_array: Vec> = array + .iter() + .zip(key.iter()) + .zip(extra_key.iter()) + .map(|((arr, i), j)| match (arr, i, j) { + (Some(arr), Some(i), Some(j)) => list_slice::(&arr, i, j, return_element), + (Some(arr), None, Some(j)) => list_slice::(&arr, 1i64, j, return_element), + (Some(arr), Some(i), None) => { + list_slice::(&arr, i, arr.len() as i64, return_element) + } + (Some(arr), None, None) if !return_element => arr.clone(), + _ => return_empty(return_element, array.value_type()), + }) + .collect(); + + // concat requires input of at least one array + if sliced_array.is_empty() { + Ok(return_empty(return_element, array.value_type())) + } else { + let vec = sliced_array .iter() - .zip($KEY.iter()) - .zip($EXTRA_KEY.iter()) - .map(|((arr, i), j)| match (arr, i, j) { - (Some(arr), Some(i), Some(j)) => { - list_slice!(arr, i, j, $RETURN_ELEMENT, $ARRAY_TYPE) - } - (Some(arr), None, Some(j)) => { - list_slice!(arr, 1i64, j, $RETURN_ELEMENT, $ARRAY_TYPE) - } - (Some(arr), Some(i), None) => { - list_slice!(arr, i, arr.len() as i64, $RETURN_ELEMENT, $ARRAY_TYPE) - } - (Some(arr), None, None) if !$RETURN_ELEMENT => arr, - _ => return_empty($RETURN_ELEMENT, $ARRAY.value_type().clone()), - }) - .collect(); + .map(|a| a.as_ref()) + .collect::>(); + let mut i: i32 = 0; + let mut offsets = vec![i]; + offsets.extend( + vec.iter() + .map(|a| { + i += a.len() as i32; + i + }) + .collect::>(), + ); + let values = compute::concat(vec.as_slice()).unwrap(); - // concat requires input of at least one array - if sliced_array.is_empty() { - Ok(return_empty($RETURN_ELEMENT, $ARRAY.value_type())) + if return_element { + Ok(values) } else { - let vec = sliced_array - .iter() - .map(|a| a.as_ref()) - .collect::>(); - let mut i: i32 = 0; - let mut offsets = vec![i]; - offsets.extend( - vec.iter() - .map(|a| { - i += a.len() as i32; - i - }) - .collect::>(), - ); - let values = compute::concat(vec.as_slice()).unwrap(); - - if $RETURN_ELEMENT { - Ok(values) - } else { - let field = - Arc::new(Field::new("item", $ARRAY.value_type().clone(), true)); - Ok(Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - values, - None, - )?)) - } + let field = Arc::new(Field::new("item", array.value_type(), true)); + Ok(Arc::new(ListArray::try_new( + field, + OffsetBuffer::new(offsets.into()), + values, + None, + )?)) } - }}; + } } fn define_array_slice( @@ -492,7 +489,7 @@ fn define_array_slice( ) -> Result { macro_rules! array_function { ($ARRAY_TYPE:ident) => { - slice!(list_array, key, extra_key, return_element, $ARRAY_TYPE) + slice::<$ARRAY_TYPE>(list_array, key, extra_key, return_element) }; } call_array_function!(list_array.value_type(), true)