Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: optimize lower and upper functions #9971

Merged
merged 10 commits into from
Apr 15, 2024
11 changes: 11 additions & 0 deletions datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ unicode-segmentation = { version = "^1.7.1", optional = true }
uuid = { version = "1.7", features = ["v4"], optional = true }

[dev-dependencies]
arrow = { workspace = true, features = ["test_utils"] }
criterion = "0.5"
rand = { workspace = true }
rstest = { workspace = true }
Expand Down Expand Up @@ -118,3 +119,13 @@ required-features = ["unicode_expressions"]
harness = false
name = "ltrim"
required-features = ["string_expressions"]

[[bench]]
harness = false
name = "lower"
required-features = ["string_expressions"]

[[bench]]
harness = false
name = "upper"
required-features = ["string_expressions"]
91 changes: 91 additions & 0 deletions datafusion/functions/benches/lower.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

extern crate criterion;

use arrow::array::{ArrayRef, StringArray};
use arrow::util::bench_util::create_string_array_with_len;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use datafusion_expr::ColumnarValue;
use datafusion_functions::string;
use std::sync::Arc;

/// Create an array of args containing a StringArray, where all the values in the
/// StringArray are ASCII.
/// * `size` - the length of the StringArray, and
/// * `str_len` - the length of the strings within the StringArray.
fn create_args1(size: usize, str_len: usize) -> Vec<ColumnarValue> {
let array = Arc::new(create_string_array_with_len::<i32>(size, 0.2, str_len));
vec![ColumnarValue::Array(array)]
}

/// Create an array of args containing a StringArray, where the first value in the
/// StringArray is non-ASCII.
/// * `size` - the length of the StringArray, and
/// * `str_len` - the length of the strings within the StringArray.
fn create_args2(size: usize) -> Vec<ColumnarValue> {
JasonLi-cn marked this conversation as resolved.
Show resolved Hide resolved
let mut items = Vec::with_capacity(size);
items.push("农历新年".to_string());
for i in 1..size {
items.push(format!("DATAFUSION {}", i));
}
let array = Arc::new(StringArray::from(items)) as ArrayRef;
vec![ColumnarValue::Array(array)]
}

/// Create an array of args containing a StringArray, where the middle value of the
/// StringArray is non-ASCII.
/// * `size` - the length of the StringArray, and
/// * `str_len` - the length of the strings within the StringArray.
fn create_args3(size: usize) -> Vec<ColumnarValue> {
let mut items = Vec::with_capacity(size);
let half = size / 2;
for i in 0..half {
items.push(format!("DATAFUSION {}", i));
}
items.push("Ⱦ".to_string());
for i in half + 1..size {
items.push(format!("DATAFUSION {}", i));
}
let array = Arc::new(StringArray::from(items)) as ArrayRef;
vec![ColumnarValue::Array(array)]
}

fn criterion_benchmark(c: &mut Criterion) {
let lower = string::lower();
for size in [1024, 4096, 8192] {
let args = create_args1(size, 32);
c.bench_function(&format!("lower_all_values_are_ascii: {}", size), |b| {
b.iter(|| black_box(lower.invoke(&args)))
});

let args = create_args2(size);
c.bench_function(
&format!("lower_the_first_value_is_nonascii: {}", size),
|b| b.iter(|| black_box(lower.invoke(&args))),
);

let args = create_args3(size);
c.bench_function(
&format!("lower_the_middle_value_is_nonascii: {}", size),
|b| b.iter(|| black_box(lower.invoke(&args))),
);
}
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
46 changes: 46 additions & 0 deletions datafusion/functions/benches/upper.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

extern crate criterion;

use arrow::util::bench_util::create_string_array_with_len;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use datafusion_expr::ColumnarValue;
use datafusion_functions::string;
use std::sync::Arc;

/// Create an array of args containing a StringArray, where all the values in the
/// StringArray are ASCII.
/// * `size` - the length of the StringArray, and
/// * `str_len` - the length of the strings within the StringArray.
fn create_args(size: usize, str_len: usize) -> Vec<ColumnarValue> {
let array = Arc::new(create_string_array_with_len::<i32>(size, 0.2, str_len));
vec![ColumnarValue::Array(array)]
}

fn criterion_benchmark(c: &mut Criterion) {
let upper = string::upper();
for size in [1024, 4096, 8192] {
let args = create_args(size, 32);
c.bench_function("upper_all_values_are_ascii", |b| {
b.iter(|| black_box(upper.invoke(&args)))
});
}
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
137 changes: 82 additions & 55 deletions datafusion/functions/src/string/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ use std::fmt::{Display, Formatter};
use std::sync::Arc;

use arrow::array::{
new_null_array, Array, ArrayRef, GenericStringArray, OffsetSizeTrait,
new_null_array, Array, ArrayRef, GenericStringArray, GenericStringBuilder,
OffsetSizeTrait,
};
use arrow::buffer::Buffer;
use arrow::datatypes::DataType;

use datafusion_common::cast::as_generic_string_array;
Expand Down Expand Up @@ -112,80 +114,105 @@ pub(crate) fn general_trim<T: OffsetSizeTrait>(
}
}

/// applies a unary expression to `args[0]` that is expected to be downcastable to
/// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset)
/// # Errors
/// This function errors when:
/// * the number of arguments is not 1
/// * the first argument is not castable to a `GenericStringArray`
pub(crate) fn unary_string_function<'a, T, O, F, R>(
args: &[&'a dyn Array],
op: F,
name: &str,
) -> Result<GenericStringArray<O>>
where
R: AsRef<str>,
O: OffsetSizeTrait,
T: OffsetSizeTrait,
F: Fn(&'a str) -> R,
{
if args.len() != 1 {
return exec_err!(
"{:?} args were supplied but {} takes exactly one argument",
args.len(),
name
);
}

let string_array = as_generic_string_array::<T>(args[0])?;
pub(crate) fn to_lower(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
case_conversion(args, |string| string.to_lowercase(), name)
}

// first map is the iterator, second is for the `Option<_>`
Ok(string_array.iter().map(|string| string.map(&op)).collect())
pub(crate) fn to_upper(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
case_conversion(args, |string| string.to_uppercase(), name)
}

pub(crate) fn handle<'a, F, R>(
fn case_conversion<'a, F>(
args: &'a [ColumnarValue],
op: F,
name: &str,
) -> Result<ColumnarValue>
where
R: AsRef<str>,
F: Fn(&'a str) -> R,
F: Fn(&'a str) -> String,
{
match &args[0] {
ColumnarValue::Array(a) => match a.data_type() {
DataType::Utf8 => {
Ok(ColumnarValue::Array(Arc::new(unary_string_function::<
i32,
i32,
_,
_,
>(
&[a.as_ref()], op, name
)?)))
}
DataType::LargeUtf8 => {
Ok(ColumnarValue::Array(Arc::new(unary_string_function::<
i64,
i64,
_,
_,
>(
&[a.as_ref()], op, name
)?)))
}
ColumnarValue::Array(array) => match array.data_type() {
DataType::Utf8 => Ok(ColumnarValue::Array(case_conversion_array::<i32, _>(
array, op,
)?)),
DataType::LargeUtf8 => Ok(ColumnarValue::Array(case_conversion_array::<
i64,
_,
>(array, op)?)),
other => exec_err!("Unsupported data type {other:?} for function {name}"),
},
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8(a) => {
let result = a.as_ref().map(|x| (op)(x).as_ref().to_string());
let result = a.as_ref().map(|x| op(x));
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
}
ScalarValue::LargeUtf8(a) => {
let result = a.as_ref().map(|x| (op)(x).as_ref().to_string());
let result = a.as_ref().map(|x| op(x));
Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result)))
}
other => exec_err!("Unsupported data type {other:?} for function {name}"),
},
}
}

fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result<ArrayRef>
where
O: OffsetSizeTrait,
F: Fn(&'a str) -> String,
{
const PRE_ALLOC_BYTES: usize = 8;

let string_array = as_generic_string_array::<O>(array)?;
let value_data = string_array.value_data();

// All values are ASCII.
if value_data.is_ascii() {
return case_conversion_ascii_array::<O, _>(string_array, op);
}

// Values contain non-ASCII.
let item_len = string_array.len();
let capacity = string_array.value_data().len() + PRE_ALLOC_BYTES;
let mut builder = GenericStringBuilder::<O>::with_capacity(item_len, capacity);

if string_array.null_count() == 0 {
let iter =
(0..item_len).map(|i| Some(op(unsafe { string_array.value_unchecked(i) })));
builder.extend(iter);
} else {
let iter = string_array.iter().map(|string| string.map(&op));
builder.extend(iter);
}
Ok(Arc::new(builder.finish()))
}

/// All values of string_array are ASCII, and when converting case, there is no changes in the byte
/// array length. Therefore, the StringArray can be treated as a complete ASCII string for
/// case conversion, and we can reuse the offsets buffer and the nulls buffer.
fn case_conversion_ascii_array<'a, O, F>(
string_array: &'a GenericStringArray<O>,
op: F,
) -> Result<ArrayRef>
where
O: OffsetSizeTrait,
F: Fn(&'a str) -> String,
{
let value_data = string_array.value_data();
// SAFETY: all items stored in value_data satisfy UTF8.
// ref: impl ByteArrayNativeType for str {...}
let str_values = unsafe { std::str::from_utf8_unchecked(value_data) };

// conversion
let converted_values = op(str_values);
assert_eq!(converted_values.len(), str_values.len());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you for this check

let bytes = converted_values.into_bytes();

// build result
let values = Buffer::from_vec(bytes);
let offsets = string_array.offsets().clone();
let nulls = string_array.nulls().cloned();
// SAFETY: offsets and nulls are consistent with the input array.
JasonLi-cn marked this conversation as resolved.
Show resolved Hide resolved
Ok(Arc::new(unsafe {
GenericStringArray::<O>::new_unchecked(offsets, values, nulls)
}))
}
Loading