diff --git a/Cargo.toml b/Cargo.toml index 1a2f4a84af38..6558642d4a5e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,6 +53,7 @@ arrow = { version = "48.0.0", features = ["prettyprint"] } arrow-array = { version = "48.0.0", default-features = false, features = ["chrono-tz"] } arrow-buffer = { version = "48.0.0", default-features = false } arrow-flight = { version = "48.0.0", features = ["flight-sql-experimental"] } +arrow-ord = { version = "48.0.0", default-features = false } arrow-schema = { version = "48.0.0", default-features = false } async-trait = "0.1.73" bigdecimal = "0.4.1" diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index dc828f018fd5..a5eafa68cf44 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1241,6 +1241,7 @@ dependencies = [ "arrow", "arrow-array", "arrow-buffer", + "arrow-ord", "arrow-schema", "base64", "blake2", diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 4be625e384b9..4496e7215204 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -44,6 +44,7 @@ ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] arrow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } +arrow-ord = { workspace = true } arrow-schema = { workspace = true } base64 = { version = "0.21", optional = true } blake2 = { version = "^0.10.2", optional = true } diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 687502e79fed..e296e9c96fad 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -31,7 +31,8 @@ use datafusion_common::cast::{ }; use datafusion_common::utils::array_into_list_array; use datafusion_common::{ - exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result, + exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, + DataFusionError, Result, }; use itertools::Itertools; @@ -1221,217 +1222,121 @@ array_removement_function!( "Array_remove_all SQL function" ); -macro_rules! general_replace { - ($ARRAY:expr, $FROM:expr, $TO:expr, $MAX:expr, $ARRAY_TYPE:ident) => {{ - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array($FROM.data_type()), $ARRAY_TYPE).clone(); - - let from_array = downcast_arg!($FROM, $ARRAY_TYPE); - let to_array = downcast_arg!($TO, $ARRAY_TYPE); - for (((arr, from), to), max) in $ARRAY - .iter() - .zip(from_array.iter()) - .zip(to_array.iter()) - .zip($MAX.iter()) - { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty")) - })?; - match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, $ARRAY_TYPE); - let mut counter = 0; - let max = if max < Some(1) { 1 } else { max.unwrap() }; - - let replaced_array = child_array - .iter() - .map(|el| { - if counter != max && el == from { - counter += 1; - to +fn general_replace(args: &[ArrayRef], arr_n: Vec) -> Result { + let list_array = as_list_array(&args[0])?; + let from_array = &args[1]; + let to_array = &args[2]; + + let mut offsets: Vec = vec![0]; + let data_type = list_array.value_type(); + let mut values = new_empty_array(&data_type); + + for (row_index, (arr, n)) in list_array.iter().zip(arr_n.iter()).enumerate() { + let last_offset: i32 = offsets + .last() + .copied() + .ok_or_else(|| internal_datafusion_err!("offsets should not be empty"))?; + match arr { + Some(arr) => { + let indices = UInt32Array::from(vec![row_index as u32]); + let from_arr = arrow::compute::take(from_array, &indices, None)?; + + let eq_array = match from_arr.data_type() { + // arrow_ord::cmp_eq does not support ListArray, so we need to compare it by loop + DataType::List(_) => { + let from_a = as_list_array(&from_arr)?.value(0); + let list_arr = as_list_array(&arr)?; + + let mut bool_values = vec![]; + for arr in list_arr.iter() { + if let Some(a) = arr { + bool_values.push(Some(a.eq(&from_a))); } else { - el + return internal_err!( + "Null value is not supported in array_replace" + ); } - }) - .collect::<$ARRAY_TYPE>(); - - values = downcast_arg!( - compute::concat(&[&values, &replaced_array])?.clone(), - $ARRAY_TYPE - ) - .clone(); - offsets.push(last_offset + replaced_array.len() as i32); - } - None => { - offsets.push(last_offset); - } - } - } - - let field = Arc::new(Field::new("item", $FROM.data_type().clone(), true)); - - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; -} - -macro_rules! general_replace_list { - ($ARRAY:expr, $FROM:expr, $TO:expr, $MAX:expr, $ARRAY_TYPE:ident) => {{ - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array($FROM.data_type()), ListArray).clone(); - - let from_array = downcast_arg!($FROM, ListArray); - let to_array = downcast_arg!($TO, ListArray); - for (((arr, from), to), max) in $ARRAY - .iter() - .zip(from_array.iter()) - .zip(to_array.iter()) - .zip($MAX.iter()) - { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty")) - })?; - match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, ListArray); - let mut counter = 0; - let max = if max < Some(1) { 1 } else { max.unwrap() }; + } + BooleanArray::from(bool_values) + } + _ => { + let from_arr = Scalar::new(from_arr); + arrow_ord::cmp::eq(&arr, &from_arr)? + } + }; - let replaced_vec = child_array - .iter() - .map(|el| { - if counter != max && el == from { - counter += 1; - to.clone().unwrap() - } else { - el.clone().unwrap() + // Use MutableArrayData to build the replaced array + // First array is the original array, second array is the element to replace with. + let arrays = vec![arr, to_array.clone()]; + let arrays_data = arrays + .iter() + .map(|a| a.to_data()) + .collect::>(); + let arrays_data = arrays_data.iter().collect::>(); + + let arrays = arrays + .iter() + .map(|arr| arr.as_ref()) + .collect::>(); + let capacity = Capacities::Array(arrays.iter().map(|a| a.len()).sum()); + + let mut mutable = + MutableArrayData::with_capacities(arrays_data, false, capacity); + + let mut counter = 0; + for (i, to_replace) in eq_array.iter().enumerate() { + if let Some(to_replace) = to_replace { + if to_replace { + mutable.extend(1, row_index, row_index + 1); + counter += 1; + if counter == *n { + // extend the rest of the array + mutable.extend(0, i + 1, eq_array.len()); + break; } - }) - .collect::>(); - - let mut i: i32 = 0; - let mut replaced_offsets = vec![i]; - replaced_offsets.extend( - replaced_vec - .clone() - .into_iter() - .map(|a| { - i += a.len() as i32; - i - }) - .collect::>(), - ); - - let mut replaced_values = downcast_arg!( - new_empty_array(&from_array.value_type()), - $ARRAY_TYPE - ) - .clone(); - for replaced_list in replaced_vec { - replaced_values = downcast_arg!( - compute::concat(&[&replaced_values, &replaced_list])?, - $ARRAY_TYPE - ) - .clone(); + } else { + mutable.extend(0, i, i + 1); + } + } else { + return internal_err!("eq_array should not contain None"); } + } - let field = Arc::new(Field::new( - "item", - from_array.value_type().clone(), - true, - )); - let replaced_array = ListArray::try_new( - field, - OffsetBuffer::new(replaced_offsets.clone().into()), - Arc::new(replaced_values), - None, - )?; + let data = mutable.freeze(); + let replaced_array = arrow_array::make_array(data); - values = downcast_arg!( - compute::concat(&[&values, &replaced_array,])?.clone(), - ListArray - ) - .clone(); - offsets.push(last_offset + replaced_array.len() as i32); - } - None => { - offsets.push(last_offset); - } + let v = arrow::compute::concat(&[&values, &replaced_array])?; + values = v; + offsets.push(last_offset + replaced_array.len() as i32); + } + None => { + offsets.push(last_offset); } } + } - let field = Arc::new(Field::new("item", $FROM.data_type().clone(), true)); - - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; -} - -macro_rules! array_replacement_function { - ($FUNC:ident, $MAX_FUNC:expr, $DOC:expr) => { - #[doc = $DOC] - pub fn $FUNC(args: &[ArrayRef]) -> Result { - let arr = as_list_array(&args[0])?; - let from = &args[1]; - let to = &args[2]; - let max = $MAX_FUNC(args)?; - - check_datatypes(stringify!($FUNC), &[arr.values(), from, to])?; - let res = match arr.value_type() { - DataType::List(field) => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - general_replace_list!(arr, from, to, max, $ARRAY_TYPE) - }; - } - call_array_function!(field.data_type(), true) - } - data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - general_replace!(arr, from, to, max, $ARRAY_TYPE) - }; - } - call_array_function!(data_type, false) - } - }; - - Ok(res) - } - }; + Ok(Arc::new(ListArray::try_new( + Arc::new(Field::new("item", data_type, true)), + OffsetBuffer::new(offsets.into()), + values, + None, + )?)) } -fn replace_one(args: &[ArrayRef]) -> Result { - Ok(Int64Array::from_value(1, args[0].len())) +pub fn array_replace(args: &[ArrayRef]) -> Result { + general_replace(args, vec![1; args[0].len()]) } -fn replace_n(args: &[ArrayRef]) -> Result { - as_int64_array(&args[3]).cloned() +pub fn array_replace_n(args: &[ArrayRef]) -> Result { + let arr = as_int64_array(&args[3])?; + let arr_n = arr.values().to_vec(); + general_replace(args, arr_n) } -fn replace_all(args: &[ArrayRef]) -> Result { - Ok(Int64Array::from_value(i64::MAX, args[0].len())) +pub fn array_replace_all(args: &[ArrayRef]) -> Result { + general_replace(args, vec![i64::MAX; args[0].len()]) } -// array replacement functions -array_replacement_function!(array_replace, replace_one, "Array_replace SQL function"); -array_replacement_function!(array_replace_n, replace_n, "Array_replace_n SQL function"); -array_replacement_function!( - array_replace_all, - replace_all, - "Array_replace_all SQL function" -); - macro_rules! to_string { ($ARG:expr, $ARRAY:expr, $DELIMITER:expr, $NULL_STRING:expr, $WITH_NULL_STRING:expr, $ARRAY_TYPE:ident) => {{ let arr = downcast_arg!($ARRAY, $ARRAY_TYPE);