Skip to content

Commit

Permalink
string view
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
  • Loading branch information
jayzhan211 committed Aug 21, 2024
1 parent 9c5aadb commit 1f38a7c
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 123 deletions.
2 changes: 1 addition & 1 deletion datafusion/functions-nested/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@ name = "map"

[[bench]]
harness = false
name = "array_has"
name = "array_has"
8 changes: 3 additions & 5 deletions datafusion/functions-nested/benches/array_has.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use arrow_array::{ArrayRef, BooleanArray, StringArray};
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use datafusion_common::utils::array_into_list_array;
use datafusion_functions_nested::{
array_has::ComparisonType, array_has_internal, general_array_has_dispatch,
array_has::ComparisonType, array_has_dispatch, general_array_has_dispatch,
};
use rand::Rng;

Expand Down Expand Up @@ -67,10 +67,8 @@ fn criterion_benchmark(c: &mut Criterion) {

c.bench_function("array_has specialized approach", |b| {
b.iter(|| {
let is_contained = black_box(
array_has_internal::<i32>(&array, &sub_array, ComparisonType::Single)
.unwrap(),
);
let is_contained =
black_box(array_has_dispatch::<i32>(&array, &sub_array).unwrap());
assert_eq!(&is_contained, &expected);
});
});
Expand Down
177 changes: 61 additions & 116 deletions datafusion/functions-nested/src/array_has.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
use arrow::array::{Array, ArrayRef, AsArray, BooleanArray, OffsetSizeTrait};
use arrow::datatypes::DataType;
use arrow::row::{RowConverter, SortField};
use arrow_array::GenericListArray;
use datafusion_common::cast::as_generic_list_array;
use datafusion_common::{exec_err, Result};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

use itertools::Itertools;

use crate::utils::check_datatypes;
use crate::utils::{check_datatypes, make_scalar_function};

use std::any::Any;
use std::sync::Arc;
Expand Down Expand Up @@ -93,34 +94,25 @@ impl ScalarUDFImpl for ArrayHas {
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(args)?;

if args.len() != 2 {
return exec_err!("array_has needs two arguments");
}

let array_type = args[0].data_type();

match array_type {
DataType::List(_) => {
array_has_internal::<i32>(&args[0], &args[1], ComparisonType::Single)
.map(ColumnarValue::Array)
}
DataType::LargeList(_) => general_array_has_dispatch::<i64>(
&args[0],
&args[1],
ComparisonType::Single,
)
.map(ColumnarValue::Array),
_ => exec_err!("array_has does not support type '{array_type:?}'."),
}
make_scalar_function(array_has_inner)(args)
}

fn aliases(&self) -> &[String] {
&self.aliases
}
}

fn array_has_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
match args[0].data_type() {
DataType::List(_) => array_has_dispatch::<i32>(&args[0], &args[1]),
DataType::LargeList(_) => array_has_dispatch::<i64>(&args[0], &args[1]),
_ => exec_err!(
"array_has does not support type '{:?}'.",
args[0].data_type()
),
}
}

#[derive(Debug)]
pub struct ArrayHasAll {
signature: Signature,
Expand Down Expand Up @@ -265,6 +257,7 @@ pub fn general_array_has_dispatch<O: OffsetSizeTrait>(
needle: &ArrayRef,
comparison_type: ComparisonType,
) -> Result<ArrayRef> {
// TODO: Handle data types verification via signature
let array = if comparison_type == ComparisonType::Single {
let arr = as_generic_list_array::<O>(haystack)?;
check_datatypes("array_has", &[arr.values(), needle])?;
Expand Down Expand Up @@ -325,33 +318,62 @@ pub fn general_array_has_dispatch<O: OffsetSizeTrait>(
}

/// Public function for internal benchmark, avoid to use it in production
pub fn array_has_internal<O: OffsetSizeTrait>(
pub fn array_has_dispatch<O: OffsetSizeTrait>(
haystack: &ArrayRef,
needle: &ArrayRef,
_comparison_type: ComparisonType,
) -> Result<ArrayRef> {
let data_type = needle.data_type();
if *data_type == DataType::Utf8 {
return array_has_string_internal(haystack, needle);
let haystack = as_generic_list_array::<O>(haystack)?;
match needle.data_type() {
DataType::Utf8 => array_has_string_internal::<O, i32>(haystack, needle),
DataType::LargeUtf8 => array_has_string_internal::<O, i64>(haystack, needle),
DataType::Utf8View => array_has_string_view_internal::<O>(haystack, needle),
_ => general_array_has::<O>(haystack, needle),
}
general_array_has::<O>(haystack, needle)
}

fn array_has_string_internal(haystack: &ArrayRef, needle: &ArrayRef) -> Result<ArrayRef> {
let array = as_generic_list_array::<i32>(haystack)?;
fn array_has_string_internal<O: OffsetSizeTrait, S: OffsetSizeTrait>(
array: &GenericListArray<O>,
needle: &ArrayRef,
) -> Result<ArrayRef> {
let mut boolean_builder = BooleanArray::builder(array.len());
let needle_array = needle.as_string::<i32>();
let needle_array = needle.as_string::<S>();
for (arr, element) in array.iter().zip(needle_array.iter()) {
match (arr, element) {
(Some(arr), Some(element)) => {
let string_arr = arr.as_string::<i32>();
let string_arr = arr.as_string::<S>();
let mut res = false;
for sub_arr in string_arr.iter() {
if let Some(sub_arr) = sub_arr {
if sub_arr == element {
res = true;
break;
}
for sub_arr in string_arr.iter().flatten() {
if sub_arr == element {
res = true;
break;
}
}
boolean_builder.append_value(res);
}
(_, _) => {
boolean_builder.append_null();
}
}
}

Ok(Arc::new(boolean_builder.finish()))
}

fn array_has_string_view_internal<O: OffsetSizeTrait>(
array: &GenericListArray<O>,
needle: &ArrayRef,
) -> Result<ArrayRef> {
let mut boolean_builder = BooleanArray::builder(array.len());
let needle_array = needle.as_string_view();
for (arr, element) in array.iter().zip(needle_array.iter()) {
match (arr, element) {
(Some(arr), Some(element)) => {
let string_arr = arr.as_string_view();
let mut res = false;
for sub_arr in string_arr.iter().flatten() {
if sub_arr == element {
res = true;
break;
}
}
boolean_builder.append_value(res);
Expand All @@ -366,10 +388,9 @@ fn array_has_string_internal(haystack: &ArrayRef, needle: &ArrayRef) -> Result<A
}

fn general_array_has<O: OffsetSizeTrait>(
haystack: &ArrayRef,
array: &GenericListArray<O>,
needle: &ArrayRef,
) -> Result<ArrayRef> {
let array = as_generic_list_array::<O>(haystack)?;
let mut boolean_builder = BooleanArray::builder(array.len());
let converter = RowConverter::new(vec![SortField::new(array.value_type())])?;
let sub_arr_values = converter.convert_columns(&[Arc::clone(needle)])?;
Expand All @@ -389,79 +410,3 @@ fn general_array_has<O: OffsetSizeTrait>(

Ok(Arc::new(boolean_builder.finish()))
}

#[cfg(test)]
mod tests {
use std::vec;

use arrow_array::StringArray;
use datafusion_common::utils::array_into_list_array;
use datafusion_expr::lit;
use rand::Rng;

use crate::make_array::make_array;

use super::*;

fn generate_random_strings(n: usize, size: usize) -> Vec<String> {
let mut rng = rand::thread_rng();
let mut strings = Vec::with_capacity(n);

// Define the characters to use in the random strings
let charset: &[u8] =
b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";

for _ in 0..n {
// Generate a random string of the specified size or length 4
let random_string: String = if rng.gen_bool(0.5) {
(0..4)
.map(|_| {
let idx = rng.gen_range(0..charset.len());
charset[idx] as char
})
.collect()
} else {
(0..size)
.map(|_| {
let idx = rng.gen_range(0..charset.len());
charset[idx] as char
})
.collect()
};

strings.push(random_string);
}

strings
}

#[test]
fn test_array_has_internal() {
let data = generate_random_strings(4, 4);
let haystack_array = make_array(
(0..data.len())
.map(|i| lit(data[i].clone()))
.collect::<Vec<_>>(),
);
let element = lit(data[0].clone());
println!("haystack_array: {:?}", haystack_array);
println!("element: {:?}", element);
let result = array_has_udf().call(vec![haystack_array, element]);
assert_eq!(result, lit(true));
}

#[test]
fn test_array_has_string_internal() {
let array =
Arc::new(StringArray::from(vec!["abcd", "efgh", "ijkl", "mnop"])) as ArrayRef;
let array = Arc::new(array_into_list_array(array, true)) as ArrayRef;

let sub_array = Arc::new(StringArray::from(vec!["abcd"])) as ArrayRef;

let result =
array_has_internal::<i32>(&array, &sub_array, ComparisonType::Single)
.unwrap();
let expected = Arc::new(BooleanArray::from(vec![true])) as ArrayRef;
assert_eq!(&result, &expected);
}
}
2 changes: 1 addition & 1 deletion datafusion/functions-nested/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ pub mod sort;
pub mod string;
pub mod utils;

pub use array_has::array_has_internal;
pub use array_has::array_has_dispatch;
pub use array_has::general_array_has_dispatch;
use datafusion_common::Result;
use datafusion_execution::FunctionRegistry;
Expand Down

0 comments on commit 1f38a7c

Please sign in to comment.