Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: optimize lower and upper functions #9971

Merged
merged 10 commits into from
Apr 15, 2024
11 changes: 11 additions & 0 deletions datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ unicode-segmentation = { version = "^1.7.1", optional = true }
uuid = { version = "1.7", features = ["v4"], optional = true }

[dev-dependencies]
arrow = { workspace = true, features = ["test_utils"] }
criterion = "0.5"
rand = { workspace = true }
rstest = { workspace = true }
Expand Down Expand Up @@ -112,3 +113,13 @@ required-features = ["datetime_expressions"]
harness = false
name = "substr_index"
required-features = ["unicode_expressions"]

[[bench]]
harness = false
name = "lower"
required-features = ["string_expressions"]

[[bench]]
harness = false
name = "upper"
required-features = ["string_expressions"]
77 changes: 77 additions & 0 deletions datafusion/functions/benches/lower.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

extern crate criterion;

use arrow::array::{ArrayRef, StringArray};
use arrow::util::bench_util::create_string_array_with_len;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use datafusion_expr::ColumnarValue;
use datafusion_functions::string;
use std::sync::Arc;

fn create_args1(size: usize, str_len: usize) -> Vec<ColumnarValue> {
let array = Arc::new(create_string_array_with_len::<i32>(size, 0.2, str_len));
vec![ColumnarValue::Array(array)]
}

fn create_args2(size: usize) -> Vec<ColumnarValue> {
let mut items = Vec::with_capacity(size);
items.push("农历新年".to_string());
for i in 1..size {
items.push(format!("DATAFUSION {}", i));
}
let array = Arc::new(StringArray::from(items)) as ArrayRef;
vec![ColumnarValue::Array(array)]
}

fn create_args3(size: usize) -> Vec<ColumnarValue> {
let mut items = Vec::with_capacity(size);
let half = size / 2;
for i in 0..half {
items.push(format!("DATAFUSION {}", i));
}
items.push("Ⱦ".to_string());
for i in half + 1..size {
items.push(format!("DATAFUSION {}", i));
}
let array = Arc::new(StringArray::from(items)) as ArrayRef;
vec![ColumnarValue::Array(array)]
}

fn criterion_benchmark(c: &mut Criterion) {
let lower = string::lower();
for size in [1024, 4096, 8192] {
let args = create_args1(size, 32);
c.bench_function(&format!("lower full optimization: {}", size), |b| {
b.iter(|| black_box(lower.invoke(&args)))
});

let args = create_args2(size);
c.bench_function(&format!("lower maybe optimization: {}", size), |b| {
b.iter(|| black_box(lower.invoke(&args)))
});

let args = create_args3(size);
c.bench_function(&format!("lower partial optimization: {}", size), |b| {
b.iter(|| black_box(lower.invoke(&args)))
});
}
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
40 changes: 40 additions & 0 deletions datafusion/functions/benches/upper.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

extern crate criterion;

use arrow::util::bench_util::create_string_array_with_len;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use datafusion_expr::ColumnarValue;
use datafusion_functions::string;
use std::sync::Arc;

fn create_args(size: usize, str_len: usize) -> Vec<ColumnarValue> {
let array = Arc::new(create_string_array_with_len::<i32>(size, 0.2, str_len));
vec![ColumnarValue::Array(array)]
}

fn criterion_benchmark(c: &mut Criterion) {
let upper = string::upper();
for size in [1024, 4096, 8192] {
let args = create_args(size, 32);
c.bench_function("upper", |b| b.iter(|| black_box(upper.invoke(&args))));
}
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
179 changes: 124 additions & 55 deletions datafusion/functions/src/string/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
use std::fmt::{Display, Formatter};
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait};
use arrow::array::{
Array, ArrayRef, BufferBuilder, GenericStringArray, GenericStringBuilder,
OffsetSizeTrait,
};
use arrow::buffer::{Buffer, OffsetBuffer, ScalarBuffer};
use arrow::datatypes::DataType;

use datafusion_common::cast::as_generic_string_array;
Expand Down Expand Up @@ -97,80 +101,145 @@ pub(crate) fn general_trim<T: OffsetSizeTrait>(
}
}

/// applies a unary expression to `args[0]` that is expected to be downcastable to
/// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset)
/// # Errors
/// This function errors when:
/// * the number of arguments is not 1
/// * the first argument is not castable to a `GenericStringArray`
pub(crate) fn unary_string_function<'a, T, O, F, R>(
args: &[&'a dyn Array],
op: F,
name: &str,
) -> Result<GenericStringArray<O>>
where
R: AsRef<str>,
O: OffsetSizeTrait,
T: OffsetSizeTrait,
F: Fn(&'a str) -> R,
{
if args.len() != 1 {
return exec_err!(
"{:?} args were supplied but {} takes exactly one argument",
args.len(),
name
);
}

let string_array = as_generic_string_array::<T>(args[0])?;
pub(crate) fn to_lower(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
case_conversion(args, |string| string.to_lowercase(), name)
}

// first map is the iterator, second is for the `Option<_>`
Ok(string_array.iter().map(|string| string.map(&op)).collect())
pub(crate) fn to_upper(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
case_conversion(args, |string| string.to_uppercase(), name)
}

pub(crate) fn handle<'a, F, R>(
fn case_conversion<'a, F>(
args: &'a [ColumnarValue],
op: F,
name: &str,
) -> Result<ColumnarValue>
where
R: AsRef<str>,
F: Fn(&'a str) -> R,
F: Fn(&'a str) -> String,
{
match &args[0] {
ColumnarValue::Array(a) => match a.data_type() {
DataType::Utf8 => {
Ok(ColumnarValue::Array(Arc::new(unary_string_function::<
i32,
i32,
_,
_,
>(
&[a.as_ref()], op, name
)?)))
}
DataType::LargeUtf8 => {
Ok(ColumnarValue::Array(Arc::new(unary_string_function::<
i64,
i64,
_,
_,
>(
&[a.as_ref()], op, name
)?)))
}
ColumnarValue::Array(array) => match array.data_type() {
DataType::Utf8 => Ok(ColumnarValue::Array(case_conversion_array::<i32, _>(
array, op,
)?)),
DataType::LargeUtf8 => Ok(ColumnarValue::Array(case_conversion_array::<
i64,
_,
>(array, op)?)),
other => exec_err!("Unsupported data type {other:?} for function {name}"),
},
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8(a) => {
let result = a.as_ref().map(|x| (op)(x).as_ref().to_string());
let result = a.as_ref().map(|x| op(x));
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
}
ScalarValue::LargeUtf8(a) => {
let result = a.as_ref().map(|x| (op)(x).as_ref().to_string());
let result = a.as_ref().map(|x| op(x));
Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result)))
}
other => exec_err!("Unsupported data type {other:?} for function {name}"),
},
}
}

fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result<ArrayRef>
where
O: OffsetSizeTrait,
F: Fn(&'a str) -> String,
{
let string_array = as_generic_string_array::<O>(array)?;
let item_len = string_array.len();

// Find the first nonascii string at the beginning.
let find_the_first_nonascii = || {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK it is quite a bit faster to do the check once on the entire string/byte array (including nulls), than to check it individually.
This should simplify the logic as well, e.g. not searching for the index but only do it when the entire array is ascii.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @Dandandan for your suggestion. Based on your suggestion:
The benefits:

  • Simpler logic
  • Helps to further improve the performance of Case1 and Case2

The downside:

  • Giving up on Case3

Is my understanding correct? 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your understanding seems correct :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. If the majority are in favor of this plan, I will implement it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree it would be good to try the simpler approach. However, as long at this current implementation is well tested and shows performance improvements I think we could merge it as is and simplify the implementation in a follow on PR as well.

If this is your preference @JasonLi-cn I will try and find some more time to review the implementation carefully. A simpler implementation has the benefit it is easier (and thus faster) to review.

Some other random optimization thoughts:

  • We could and upper/lower values as a single string in one call, for example, detecting when the relevant value was a different length and doing a special path then
  • We could also special case when the string had nulls and when it didn't (which can make the inner loop simpler and allow a better chance for auto vectorization)

for (i, item) in string_array.iter().enumerate() {
if let Some(str) = item {
if !str.as_bytes().is_ascii() {
return i;
}
}
}
item_len
};
let the_first_nonascii_index = find_the_first_nonascii();

// Case1: maybe optimization
if the_first_nonascii_index == 0 {
let iter = string_array.iter().map(|string| string.map(&op));
let capacity = string_array.value_data().len() + 8;
let mut builder = GenericStringBuilder::<O>::with_capacity(item_len, capacity);
builder.extend(iter);
return Ok(Arc::new(builder.finish()));
}

// Case2: full optimization
if the_first_nonascii_index == item_len {
return case_conversion_ascii_array::<O, _>(array, op);
}

// Case3: partial optimization
let value_data = string_array.value_data();
let offsets = string_array.offsets();
let nulls = string_array.nulls().cloned();

// Init new offsets buffer builder
let mut offsets_builder = BufferBuilder::<O>::new(item_len + 1);
offsets_builder.append_slice(&offsets.as_ref()[..the_first_nonascii_index + 1]);

// convert ascii
let end: O = unsafe { *offsets.get_unchecked(the_first_nonascii_index) };
let end = end.as_usize();
let ascii = unsafe { std::str::from_utf8_unchecked(&value_data[..end]) };
let mut converted_values = op(ascii);
// To avoid repeatedly allocating memory, perform a reserve in advance.
converted_values.reserve(value_data.len() - end + 8);

// Convert remaining items
for j in the_first_nonascii_index..item_len {
if string_array.is_valid(j) {
let item = unsafe { string_array.value_unchecked(j) };
// Memory will be continuously allocated here, but it is unavoidable.
let converted = op(item);
converted_values.push_str(&converted);
}
offsets_builder
.append(O::from_usize(converted_values.len()).expect("offset overflow"));
}
let offsets_buffer = offsets_builder.finish();

// Build result
let bytes = converted_values.into_bytes();
let values = Buffer::from_vec(bytes);
let offsets = OffsetBuffer::new(ScalarBuffer::new(offsets_buffer, 0, item_len + 1));
// SAFETY: offsets and nulls are consistent with the input array.
Ok(Arc::new(unsafe {
GenericStringArray::<O>::new_unchecked(offsets, values, nulls)
}))
}

fn case_conversion_ascii_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result<ArrayRef>
where
O: OffsetSizeTrait,
F: Fn(&'a str) -> String,
{
let string_array = as_generic_string_array::<O>(array)?;
let value_data = string_array.value_data();

// SAFETY: all items stored in value_data satisfy UTF8.
// ref: impl ByteArrayNativeType for str {...}
let str_values = unsafe { std::str::from_utf8_unchecked(value_data) };

// conversion
let converted_values = op(str_values);
assert_eq!(converted_values.len(), str_values.len());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you for this check

let bytes = converted_values.into_bytes();

// build result
let values = Buffer::from_vec(bytes);
let offsets = string_array.offsets().clone();
let nulls = string_array.nulls().cloned();
// SAFETY: offsets and nulls are consistent with the input array.
Ok(Arc::new(unsafe {
GenericStringArray::<O>::new_unchecked(offsets, values, nulls)
}))
}
Loading
Loading