Skip to content

Commit

Permalink
Directly support utf8view in array_to_string & string_to_array. apach…
Browse files Browse the repository at this point in the history
  • Loading branch information
Omega359 committed Nov 13, 2024
1 parent 4e1f839 commit db18086
Show file tree
Hide file tree
Showing 4 changed files with 284 additions and 58 deletions.
2 changes: 1 addition & 1 deletion datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ pub enum TypeSignature {
Numeric(usize),
/// Fixed number of arguments of all the same string types.
/// The precedence of type from high to low is Utf8View, LargeUtf8 and Utf8.
/// Null is considerd as `Utf8` by default
/// Null is considered as `Utf8` by default
/// Dictionary with string value type is also handled.
String(usize),
/// Zero argument
Expand Down
279 changes: 222 additions & 57 deletions datafusion/functions-nested/src/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,18 @@ use std::any::{type_name, Any};

use crate::utils::{downcast_arg, make_scalar_function};
use arrow::compute::cast;
use arrow_array::cast::AsArray;
use arrow_array::{GenericStringArray, StringViewArray};
use arrow_schema::DataType::{
Dictionary, FixedSizeList, LargeList, LargeUtf8, List, Null, Utf8,
};
use datafusion_common::cast::{
as_generic_string_array, as_large_list_array, as_list_array, as_string_array,
Dictionary, FixedSizeList, LargeList, LargeUtf8, List, Null, Utf8, Utf8View,
};
use datafusion_common::cast::{as_large_list_array, as_list_array, as_string_array};
use datafusion_common::exec_err;
use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_functions::strings::StringArrayType;
use std::sync::{Arc, OnceLock};

macro_rules! to_string {
Expand All @@ -69,6 +70,7 @@ macro_rules! to_string {
macro_rules! call_array_function {
($DATATYPE:expr, false) => {
match $DATATYPE {
DataType::Utf8View => array_function!(StringViewArray),
DataType::Utf8 => array_function!(StringArray),
DataType::LargeUtf8 => array_function!(LargeStringArray),
DataType::Boolean => array_function!(BooleanArray),
Expand All @@ -88,6 +90,7 @@ macro_rules! call_array_function {
($DATATYPE:expr, $INCLUDE_LIST:expr) => {{
match $DATATYPE {
DataType::List(_) => array_function!(ListArray),
DataType::Utf8View => array_function!(StringViewArray),
DataType::Utf8 => array_function!(StringArray),
DataType::LargeUtf8 => array_function!(LargeStringArray),
DataType::Boolean => array_function!(BooleanArray),
Expand Down Expand Up @@ -219,8 +222,8 @@ impl StringToArray {
Self {
signature: Signature::one_of(
vec![
TypeSignature::Uniform(2, vec![Utf8, LargeUtf8]),
TypeSignature::Uniform(3, vec![Utf8, LargeUtf8]),
TypeSignature::Uniform(2, vec![Utf8View, Utf8, LargeUtf8]),
TypeSignature::Uniform(3, vec![Utf8View, Utf8, LargeUtf8]),
],
Volatility::Immutable,
),
Expand All @@ -247,20 +250,21 @@ impl ScalarUDFImpl for StringToArray {
Utf8 | LargeUtf8 => {
List(Arc::new(Field::new("item", arg_types[0].clone(), true)))
}
Utf8View => List(Arc::new(Field::new("item", Utf8, true))),
_ => {
return plan_err!(
"The string_to_array function can only accept Utf8 or LargeUtf8."
"The string_to_array function can only accept Utf8, LargeUtf8 or Utf8View."
);
}
})
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match args[0].data_type() {
Utf8 => make_scalar_function(string_to_array_inner::<i32>)(args),
Utf8View | Utf8 => make_scalar_function(string_to_array_inner::<i32>)(args),
LargeUtf8 => make_scalar_function(string_to_array_inner::<i64>)(args),
other => {
exec_err!("unsupported type for string_to_array function as {other}")
exec_err!("unsupported type for string_to_array function as {other:?}")
}
}
}
Expand Down Expand Up @@ -493,16 +497,186 @@ pub fn string_to_array_inner<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<Ar
if args.len() < 2 || args.len() > 3 {
return exec_err!("string_to_array expects two or three arguments");
}
let string_array = as_generic_string_array::<T>(&args[0])?;
let delimiter_array = as_generic_string_array::<T>(&args[1])?;
match (args[0].data_type(), args[1].data_type()) {
(Utf8View, Utf8View) => {
let string_array = args[0].as_string_view();
let delimiter_array = args[1].as_string_view();
if args.len() == 3 {
match args[2].data_type() {
Utf8View => {
let null_type_array = Some(args[2].as_string_view());
string_to_array_impl::<
&StringViewArray,
&StringViewArray,
&StringViewArray,
>(
string_array, delimiter_array, null_type_array
)
}
Utf8 | LargeUtf8 => {
let null_type_array = Some(args[2].as_string::<T>());
string_to_array_impl::<
&StringViewArray,
&StringViewArray,
&GenericStringArray<T>,
>(
string_array, delimiter_array, null_type_array
)
}
other => {
exec_err!(
"unsupported type for string_to_array function as {other:?}"
)
}
}
} else {
string_to_array_impl::<
&StringViewArray,
&StringViewArray,
&GenericStringArray<T>,
>(string_array, delimiter_array, None)
}
}
(Utf8View, Utf8 | LargeUtf8) => {
let string_array = args[0].as_string_view();
let delimiter_array = args[1].as_string::<T>();
if args.len() == 3 {
match args[2].data_type() {
Utf8View => {
let null_type_array = Some(args[2].as_string_view());
string_to_array_impl::<
&StringViewArray,
&GenericStringArray<T>,
&StringViewArray,
>(
string_array, delimiter_array, null_type_array
)
}
Utf8 | LargeUtf8 => {
let null_type_array = Some(args[2].as_string::<T>());
string_to_array_impl::<
&StringViewArray,
&GenericStringArray<T>,
&GenericStringArray<T>,
>(
string_array, delimiter_array, null_type_array
)
}
other => {
exec_err!(
"unsupported type for string_to_array function as {other:?}"
)
}
}
} else {
string_to_array_impl::<
&StringViewArray,
&GenericStringArray<T>,
&GenericStringArray<T>,
>(string_array, delimiter_array, None)
}
}
(Utf8 | LargeUtf8, Utf8 | LargeUtf8) => {
let string_array = args[0].as_string::<T>();
let delimiter_array = args[1].as_string::<T>();
if args.len() == 3 {
match args[2].data_type() {
Utf8View => {
let null_type_array = Some(args[2].as_string_view());
string_to_array_impl::<
&GenericStringArray<T>,
&GenericStringArray<T>,
&StringViewArray,
>(
string_array, delimiter_array, null_type_array
)
}
Utf8 | LargeUtf8 => {
let null_type_array = Some(args[2].as_string::<T>());
string_to_array_impl::<
&GenericStringArray<T>,
&GenericStringArray<T>,
&GenericStringArray<T>,
>(
string_array, delimiter_array, null_type_array
)
}
other => {
exec_err!(
"unsupported type for string_to_array function as {other:?}"
)
}
}
} else {
string_to_array_impl::<
&GenericStringArray<T>,
&GenericStringArray<T>,
&GenericStringArray<T>,
>(string_array, delimiter_array, None)
}
}
(Utf8 | LargeUtf8, Utf8View) => {
let string_array = args[0].as_string::<T>();
let delimiter_array = args[1].as_string_view();
if args.len() == 3 {
match args[2].data_type() {
Utf8View => {
let null_type_array = Some(args[2].as_string_view());
string_to_array_impl::<
&GenericStringArray<T>,
&StringViewArray,
&StringViewArray,
>(
string_array, delimiter_array, null_type_array
)
}
Utf8 | LargeUtf8 => {
let null_type_array = Some(args[2].as_string::<T>());
string_to_array_impl::<
&GenericStringArray<T>,
&StringViewArray,
&GenericStringArray<T>,
>(
string_array, delimiter_array, null_type_array
)
}
other => {
exec_err!(
"unsupported type for string_to_array function as {other:?}"
)
}
}
} else {
string_to_array_impl::<
&GenericStringArray<T>,
&StringViewArray,
&GenericStringArray<T>,
>(string_array, delimiter_array, None)
}
}
other => {
exec_err!("unsupported type for string_to_array function as {other:?}")
}
}
}

fn string_to_array_impl<'a, StringArrType, DelimiterArrType, NullValueArrType>(
string_array: StringArrType,
delimiter_array: DelimiterArrType,
null_value_array: Option<NullValueArrType>,
) -> Result<ArrayRef>
where
StringArrType: StringArrayType<'a>,
DelimiterArrType: StringArrayType<'a>,
NullValueArrType: StringArrayType<'a>,
{
let mut list_builder = ListBuilder::new(StringBuilder::with_capacity(
string_array.len(),
string_array.get_buffer_memory_size(),
));

match args.len() {
2 => {
match null_value_array {
None => {
string_array.iter().zip(delimiter_array.iter()).for_each(
|(string, delimiter)| {
match (string, delimiter) {
Expand All @@ -525,55 +699,46 @@ pub fn string_to_array_inner<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<Ar
_ => list_builder.append(false), // null value
}
},
);
)
}

3 => {
let null_value_array = as_generic_string_array::<T>(&args[2])?;
string_array
.iter()
.zip(delimiter_array.iter())
.zip(null_value_array.iter())
.for_each(|((string, delimiter), null_value)| {
match (string, delimiter) {
(Some(string), Some("")) => {
if Some(string) == null_value {
Some(null_value_array) => string_array
.iter()
.zip(delimiter_array.iter())
.zip(null_value_array.iter())
.for_each(|((string, delimiter), null_value)| {
match (string, delimiter) {
(Some(string), Some("")) => {
if Some(string) == null_value {
list_builder.values().append_null();
} else {
list_builder.values().append_value(string);
}
list_builder.append(true);
}
(Some(string), Some(delimiter)) => {
string.split(delimiter).for_each(|s| {
if Some(s) == null_value {
list_builder.values().append_null();
} else {
list_builder.values().append_value(string);
list_builder.values().append_value(s);
}
list_builder.append(true);
}
(Some(string), Some(delimiter)) => {
string.split(delimiter).for_each(|s| {
if Some(s) == null_value {
list_builder.values().append_null();
} else {
list_builder.values().append_value(s);
}
});
list_builder.append(true);
}
(Some(string), None) => {
string.chars().map(|c| c.to_string()).for_each(|c| {
if Some(c.as_str()) == null_value {
list_builder.values().append_null();
} else {
list_builder.values().append_value(c);
}
});
list_builder.append(true);
}
_ => list_builder.append(false), // null value
});
list_builder.append(true);
}
});
}
_ => {
return exec_err!(
"Expect string_to_array function to take two or three parameters"
)
}
}
(Some(string), None) => {
string.chars().map(|c| c.to_string()).for_each(|c| {
if Some(c.as_str()) == null_value {
list_builder.values().append_null();
} else {
list_builder.values().append_value(c);
}
});
list_builder.append(true);
}
_ => list_builder.append(false), // null value
}
}),
};

let list_array = list_builder.finish();
Ok(Arc::new(list_array) as ArrayRef)
Expand Down
Loading

0 comments on commit db18086

Please sign in to comment.