Skip to content

Commit

Permalink
Issue-9565 - Port ArrayRepeat to function-arrays subcrate (#9568)
Browse files Browse the repository at this point in the history
  • Loading branch information
erenavsarogullari authored Mar 13, 2024
1 parent a9b0db4 commit 78bb64e
Show file tree
Hide file tree
Showing 13 changed files with 215 additions and 189 deletions.
10 changes: 0 additions & 10 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,6 @@ pub enum BuiltinScalarFunction {
ArrayRemoveN,
/// array_remove_all
ArrayRemoveAll,
/// array_repeat
ArrayRepeat,
/// array_replace
ArrayReplace,
/// array_replace_n
Expand Down Expand Up @@ -323,7 +321,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayPopBack => Volatility::Immutable,
BuiltinScalarFunction::ArrayPosition => Volatility::Immutable,
BuiltinScalarFunction::ArrayPositions => Volatility::Immutable,
BuiltinScalarFunction::ArrayRepeat => Volatility::Immutable,
BuiltinScalarFunction::ArrayRemove => Volatility::Immutable,
BuiltinScalarFunction::ArrayRemoveN => Volatility::Immutable,
BuiltinScalarFunction::ArrayRemoveAll => Volatility::Immutable,
Expand Down Expand Up @@ -421,11 +418,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayPositions => {
Ok(List(Arc::new(Field::new("item", UInt64, true))))
}
BuiltinScalarFunction::ArrayRepeat => Ok(List(Arc::new(Field::new(
"item",
input_expr_types[0].clone(),
true,
)))),
BuiltinScalarFunction::ArrayRemove => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayRemoveN => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayRemoveAll => Ok(input_expr_types[0].clone()),
Expand Down Expand Up @@ -652,7 +644,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayPositions => {
Signature::array_and_element(self.volatility())
}
BuiltinScalarFunction::ArrayRepeat => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayRemove => {
Signature::array_and_element(self.volatility())
}
Expand Down Expand Up @@ -1075,7 +1066,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayPositions => {
&["array_positions", "list_positions"]
}
BuiltinScalarFunction::ArrayRepeat => &["array_repeat", "list_repeat"],
BuiltinScalarFunction::ArrayRemove => &["array_remove", "list_remove"],
BuiltinScalarFunction::ArrayRemoveN => &["array_remove_n", "list_remove_n"],
BuiltinScalarFunction::ArrayRemoveAll => {
Expand Down
7 changes: 0 additions & 7 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -622,12 +622,6 @@ scalar_expr!(
array element,
"searches for an element in the array, returns all occurrences."
);
scalar_expr!(
ArrayRepeat,
array_repeat,
element count,
"returns an array containing element `count` times."
);
scalar_expr!(
ArrayRemove,
array_remove,
Expand Down Expand Up @@ -1270,7 +1264,6 @@ mod test {
test_scalar_expr!(ArrayPopBack, array_pop_back, array);
test_scalar_expr!(ArrayPosition, array_position, array, element, index);
test_scalar_expr!(ArrayPositions, array_positions, array, element);
test_scalar_expr!(ArrayRepeat, array_repeat, element, count);
test_scalar_expr!(ArrayRemove, array_remove, array, element);
test_scalar_expr!(ArrayRemoveN, array_remove_n, array, element, max);
test_scalar_expr!(ArrayRemoveAll, array_remove_all, array, element);
Expand Down
153 changes: 150 additions & 3 deletions datafusion/functions-array/src/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@
//! implementation kernels for array functions
use arrow::array::{
Array, ArrayRef, BooleanArray, Date32Array, Float32Array, Float64Array,
Array, ArrayRef, BooleanArray, Capacities, Date32Array, Float32Array, Float64Array,
GenericListArray, Int16Array, Int32Array, Int64Array, Int8Array, LargeListArray,
LargeStringArray, ListArray, ListBuilder, OffsetSizeTrait, StringArray,
StringBuilder, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
LargeStringArray, ListArray, ListBuilder, MutableArrayData, OffsetSizeTrait,
StringArray, StringBuilder, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
use arrow::compute;
use arrow::datatypes::{
DataType, Date32Type, Field, IntervalMonthDayNanoType, UInt64Type,
};
use arrow::row::{RowConverter, SortField};
use arrow_array::new_null_array;
use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer};
use arrow_schema::FieldRef;
use arrow_schema::SortOptions;
Expand Down Expand Up @@ -733,6 +734,152 @@ fn general_array_length<O: OffsetSizeTrait>(array: &[ArrayRef]) -> Result<ArrayR
Ok(Arc::new(result) as ArrayRef)
}

/// Array_repeat SQL function
pub fn array_repeat(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_repeat expects two arguments");
}

let element = &args[0];
let count_array = as_int64_array(&args[1])?;

match element.data_type() {
DataType::List(_) => {
let list_array = as_list_array(element)?;
general_list_repeat::<i32>(list_array, count_array)
}
DataType::LargeList(_) => {
let list_array = as_large_list_array(element)?;
general_list_repeat::<i64>(list_array, count_array)
}
_ => general_repeat::<i32>(element, count_array),
}
}

/// For each element of `array[i]` repeat `count_array[i]` times.
///
/// Assumption for the input:
/// 1. `count[i] >= 0`
/// 2. `array.len() == count_array.len()`
///
/// For example,
/// ```text
/// array_repeat(
/// [1, 2, 3], [2, 0, 1] => [[1, 1], [], [3]]
/// )
/// ```
fn general_repeat<O: OffsetSizeTrait>(
array: &ArrayRef,
count_array: &Int64Array,
) -> Result<ArrayRef> {
let data_type = array.data_type();
let mut new_values = vec![];

let count_vec = count_array
.values()
.to_vec()
.iter()
.map(|x| *x as usize)
.collect::<Vec<_>>();

for (row_index, &count) in count_vec.iter().enumerate() {
let repeated_array = if array.is_null(row_index) {
new_null_array(data_type, count)
} else {
let original_data = array.to_data();
let capacity = Capacities::Array(count);
let mut mutable =
MutableArrayData::with_capacities(vec![&original_data], false, capacity);

for _ in 0..count {
mutable.extend(0, row_index, row_index + 1);
}

let data = mutable.freeze();
arrow_array::make_array(data)
};
new_values.push(repeated_array);
}

let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
let values = compute::concat(&new_values)?;

Ok(Arc::new(GenericListArray::<O>::try_new(
Arc::new(Field::new("item", data_type.to_owned(), true)),
OffsetBuffer::from_lengths(count_vec),
values,
None,
)?))
}

/// Handle List version of `general_repeat`
///
/// For each element of `list_array[i]` repeat `count_array[i]` times.
///
/// For example,
/// ```text
/// array_repeat(
/// [[1, 2, 3], [4, 5], [6]], [2, 0, 1] => [[[1, 2, 3], [1, 2, 3]], [], [[6]]]
/// )
/// ```
fn general_list_repeat<O: OffsetSizeTrait>(
list_array: &GenericListArray<O>,
count_array: &Int64Array,
) -> Result<ArrayRef> {
let data_type = list_array.data_type();
let value_type = list_array.value_type();
let mut new_values = vec![];

let count_vec = count_array
.values()
.to_vec()
.iter()
.map(|x| *x as usize)
.collect::<Vec<_>>();

for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) {
let list_arr = match list_array_row {
Some(list_array_row) => {
let original_data = list_array_row.to_data();
let capacity = Capacities::Array(original_data.len() * count);
let mut mutable = MutableArrayData::with_capacities(
vec![&original_data],
false,
capacity,
);

for _ in 0..count {
mutable.extend(0, 0, original_data.len());
}

let data = mutable.freeze();
let repeated_array = arrow_array::make_array(data);

let list_arr = GenericListArray::<O>::try_new(
Arc::new(Field::new("item", value_type.clone(), true)),
OffsetBuffer::<O>::from_lengths(vec![original_data.len(); count]),
repeated_array,
None,
)?;
Arc::new(list_arr) as ArrayRef
}
None => new_null_array(data_type, count),
};
new_values.push(list_arr);
}

let lengths = new_values.iter().map(|a| a.len()).collect::<Vec<_>>();
let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
let values = compute::concat(&new_values)?;

Ok(Arc::new(ListArray::try_new(
Arc::new(Field::new("item", data_type.to_owned(), true)),
OffsetBuffer::<i32>::from_lengths(lengths),
values,
None,
)?))
}

/// Array_length SQL function
pub fn array_length(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 && args.len() != 2 {
Expand Down
2 changes: 2 additions & 0 deletions datafusion/functions-array/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ pub mod expr_fn {
pub use super::udf::array_empty;
pub use super::udf::array_length;
pub use super::udf::array_ndims;
pub use super::udf::array_repeat;
pub use super::udf::array_sort;
pub use super::udf::array_to_string;
pub use super::udf::cardinality;
Expand Down Expand Up @@ -86,6 +87,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
udf::flatten_udf(),
udf::array_sort_udf(),
udf::array_distinct_udf(),
udf::array_repeat_udf(),
];
functions.into_iter().try_for_each(|udf| {
let existing_udf = registry.register_udf(udf)?;
Expand Down
61 changes: 57 additions & 4 deletions datafusion/functions-array/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use arrow::array::{NullArray, StringArray};
use arrow::datatypes::DataType;
use arrow::datatypes::Field;
use arrow::datatypes::IntervalUnit::MonthDayNano;
use arrow_schema::DataType::{LargeUtf8, List, Utf8};
use datafusion_common::exec_err;
use datafusion_common::plan_err;
use datafusion_common::Result;
Expand Down Expand Up @@ -126,7 +127,7 @@ impl ScalarUDFImpl for StringToArray {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
use DataType::*;
Ok(match arg_types[0] {
Utf8 | LargeUtf8 => {
Expand All @@ -140,18 +141,18 @@ impl ScalarUDFImpl for StringToArray {
})
}

fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let mut args = ColumnarValue::values_to_arrays(args)?;
// Case: delimiter is NULL, needs to be handled as well.
if args[1].as_any().is::<NullArray>() {
args[1] = Arc::new(StringArray::new_null(args[1].len()));
};

match args[0].data_type() {
arrow::datatypes::DataType::Utf8 => {
Utf8 => {
crate::kernels::string_to_array::<i32>(&args).map(ColumnarValue::Array)
}
arrow::datatypes::DataType::LargeUtf8 => {
LargeUtf8 => {
crate::kernels::string_to_array::<i64>(&args).map(ColumnarValue::Array)
}
other => {
Expand Down Expand Up @@ -588,6 +589,58 @@ impl ScalarUDFImpl for ArrayEmpty {
}
}

make_udf_function!(
ArrayRepeat,
array_repeat,
element count, // arg name
"returns an array containing element `count` times.", // doc
array_repeat_udf // internal function name
);
#[derive(Debug)]
pub(super) struct ArrayRepeat {
signature: Signature,
aliases: Vec<String>,
}

impl ArrayRepeat {
pub fn new() -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
aliases: vec![String::from("array_repeat"), String::from("list_repeat")],
}
}
}

impl ScalarUDFImpl for ArrayRepeat {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"array_repeat"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Ok(List(Arc::new(Field::new(
"item",
arg_types[0].clone(),
true,
))))
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(args)?;
crate::kernels::array_repeat(&args).map(ColumnarValue::Array)
}

fn aliases(&self) -> &[String] {
&self.aliases
}
}

make_udf_function!(
ArrayLength,
array_length,
Expand Down
Loading

0 comments on commit 78bb64e

Please sign in to comment.