Skip to content

Commit

Permalink
feat: optimize lower and upper functions (#9971)
Browse files Browse the repository at this point in the history
* feat: optimize lower and upper functions

* chore: pass cargo check

* chore: pass cargo clippy

* fix: lower and upper bug

* optimize

* using iter to find the first nonascii

* chore: rename function

* refactor: case_conversion_array function

* refactor: remove !string_array.is_nullable() from case_conversion_array
  • Loading branch information
JasonLi-cn authored Apr 15, 2024
1 parent c3f48d4 commit 483663b
Show file tree
Hide file tree
Showing 6 changed files with 420 additions and 59 deletions.
11 changes: 11 additions & 0 deletions datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ unicode-segmentation = { version = "^1.7.1", optional = true }
uuid = { version = "1.7", features = ["v4"], optional = true }

[dev-dependencies]
arrow = { workspace = true, features = ["test_utils"] }
criterion = "0.5"
rand = { workspace = true }
rstest = { workspace = true }
Expand Down Expand Up @@ -117,3 +118,13 @@ required-features = ["unicode_expressions"]
harness = false
name = "ltrim"
required-features = ["string_expressions"]

[[bench]]
harness = false
name = "lower"
required-features = ["string_expressions"]

[[bench]]
harness = false
name = "upper"
required-features = ["string_expressions"]
91 changes: 91 additions & 0 deletions datafusion/functions/benches/lower.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

extern crate criterion;

use arrow::array::{ArrayRef, StringArray};
use arrow::util::bench_util::create_string_array_with_len;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use datafusion_expr::ColumnarValue;
use datafusion_functions::string;
use std::sync::Arc;

/// Create an array of args containing a StringArray, where all the values in the
/// StringArray are ASCII.
/// * `size` - the length of the StringArray, and
/// * `str_len` - the length of the strings within the StringArray.
fn create_args1(size: usize, str_len: usize) -> Vec<ColumnarValue> {
let array = Arc::new(create_string_array_with_len::<i32>(size, 0.2, str_len));
vec![ColumnarValue::Array(array)]
}

/// Create an array of args containing a StringArray, where the first value in the
/// StringArray is non-ASCII.
/// * `size` - the length of the StringArray, and
/// * `str_len` - the length of the strings within the StringArray.
fn create_args2(size: usize) -> Vec<ColumnarValue> {
let mut items = Vec::with_capacity(size);
items.push("农历新年".to_string());
for i in 1..size {
items.push(format!("DATAFUSION {}", i));
}
let array = Arc::new(StringArray::from(items)) as ArrayRef;
vec![ColumnarValue::Array(array)]
}

/// Create an array of args containing a StringArray, where the middle value of the
/// StringArray is non-ASCII.
/// * `size` - the length of the StringArray, and
/// * `str_len` - the length of the strings within the StringArray.
fn create_args3(size: usize) -> Vec<ColumnarValue> {
let mut items = Vec::with_capacity(size);
let half = size / 2;
for i in 0..half {
items.push(format!("DATAFUSION {}", i));
}
items.push("Ⱦ".to_string());
for i in half + 1..size {
items.push(format!("DATAFUSION {}", i));
}
let array = Arc::new(StringArray::from(items)) as ArrayRef;
vec![ColumnarValue::Array(array)]
}

fn criterion_benchmark(c: &mut Criterion) {
let lower = string::lower();
for size in [1024, 4096, 8192] {
let args = create_args1(size, 32);
c.bench_function(&format!("lower_all_values_are_ascii: {}", size), |b| {
b.iter(|| black_box(lower.invoke(&args)))
});

let args = create_args2(size);
c.bench_function(
&format!("lower_the_first_value_is_nonascii: {}", size),
|b| b.iter(|| black_box(lower.invoke(&args))),
);

let args = create_args3(size);
c.bench_function(
&format!("lower_the_middle_value_is_nonascii: {}", size),
|b| b.iter(|| black_box(lower.invoke(&args))),
);
}
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
46 changes: 46 additions & 0 deletions datafusion/functions/benches/upper.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

extern crate criterion;

use arrow::util::bench_util::create_string_array_with_len;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use datafusion_expr::ColumnarValue;
use datafusion_functions::string;
use std::sync::Arc;

/// Create an array of args containing a StringArray, where all the values in the
/// StringArray are ASCII.
/// * `size` - the length of the StringArray, and
/// * `str_len` - the length of the strings within the StringArray.
fn create_args(size: usize, str_len: usize) -> Vec<ColumnarValue> {
let array = Arc::new(create_string_array_with_len::<i32>(size, 0.2, str_len));
vec![ColumnarValue::Array(array)]
}

fn criterion_benchmark(c: &mut Criterion) {
let upper = string::upper();
for size in [1024, 4096, 8192] {
let args = create_args(size, 32);
c.bench_function("upper_all_values_are_ascii", |b| {
b.iter(|| black_box(upper.invoke(&args)))
});
}
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
137 changes: 82 additions & 55 deletions datafusion/functions/src/string/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ use std::fmt::{Display, Formatter};
use std::sync::Arc;

use arrow::array::{
new_null_array, Array, ArrayRef, GenericStringArray, OffsetSizeTrait,
new_null_array, Array, ArrayRef, GenericStringArray, GenericStringBuilder,
OffsetSizeTrait,
};
use arrow::buffer::Buffer;
use arrow::datatypes::DataType;

use datafusion_common::cast::as_generic_string_array;
Expand Down Expand Up @@ -112,80 +114,105 @@ pub(crate) fn general_trim<T: OffsetSizeTrait>(
}
}

/// applies a unary expression to `args[0]` that is expected to be downcastable to
/// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset)
/// # Errors
/// This function errors when:
/// * the number of arguments is not 1
/// * the first argument is not castable to a `GenericStringArray`
pub(crate) fn unary_string_function<'a, T, O, F, R>(
args: &[&'a dyn Array],
op: F,
name: &str,
) -> Result<GenericStringArray<O>>
where
R: AsRef<str>,
O: OffsetSizeTrait,
T: OffsetSizeTrait,
F: Fn(&'a str) -> R,
{
if args.len() != 1 {
return exec_err!(
"{:?} args were supplied but {} takes exactly one argument",
args.len(),
name
);
}

let string_array = as_generic_string_array::<T>(args[0])?;
pub(crate) fn to_lower(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
case_conversion(args, |string| string.to_lowercase(), name)
}

// first map is the iterator, second is for the `Option<_>`
Ok(string_array.iter().map(|string| string.map(&op)).collect())
pub(crate) fn to_upper(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
case_conversion(args, |string| string.to_uppercase(), name)
}

pub(crate) fn handle<'a, F, R>(
fn case_conversion<'a, F>(
args: &'a [ColumnarValue],
op: F,
name: &str,
) -> Result<ColumnarValue>
where
R: AsRef<str>,
F: Fn(&'a str) -> R,
F: Fn(&'a str) -> String,
{
match &args[0] {
ColumnarValue::Array(a) => match a.data_type() {
DataType::Utf8 => {
Ok(ColumnarValue::Array(Arc::new(unary_string_function::<
i32,
i32,
_,
_,
>(
&[a.as_ref()], op, name
)?)))
}
DataType::LargeUtf8 => {
Ok(ColumnarValue::Array(Arc::new(unary_string_function::<
i64,
i64,
_,
_,
>(
&[a.as_ref()], op, name
)?)))
}
ColumnarValue::Array(array) => match array.data_type() {
DataType::Utf8 => Ok(ColumnarValue::Array(case_conversion_array::<i32, _>(
array, op,
)?)),
DataType::LargeUtf8 => Ok(ColumnarValue::Array(case_conversion_array::<
i64,
_,
>(array, op)?)),
other => exec_err!("Unsupported data type {other:?} for function {name}"),
},
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8(a) => {
let result = a.as_ref().map(|x| (op)(x).as_ref().to_string());
let result = a.as_ref().map(|x| op(x));
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
}
ScalarValue::LargeUtf8(a) => {
let result = a.as_ref().map(|x| (op)(x).as_ref().to_string());
let result = a.as_ref().map(|x| op(x));
Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result)))
}
other => exec_err!("Unsupported data type {other:?} for function {name}"),
},
}
}

fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result<ArrayRef>
where
O: OffsetSizeTrait,
F: Fn(&'a str) -> String,
{
const PRE_ALLOC_BYTES: usize = 8;

let string_array = as_generic_string_array::<O>(array)?;
let value_data = string_array.value_data();

// All values are ASCII.
if value_data.is_ascii() {
return case_conversion_ascii_array::<O, _>(string_array, op);
}

// Values contain non-ASCII.
let item_len = string_array.len();
let capacity = string_array.value_data().len() + PRE_ALLOC_BYTES;
let mut builder = GenericStringBuilder::<O>::with_capacity(item_len, capacity);

if string_array.null_count() == 0 {
let iter =
(0..item_len).map(|i| Some(op(unsafe { string_array.value_unchecked(i) })));
builder.extend(iter);
} else {
let iter = string_array.iter().map(|string| string.map(&op));
builder.extend(iter);
}
Ok(Arc::new(builder.finish()))
}

/// All values of string_array are ASCII, and when converting case, there is no changes in the byte
/// array length. Therefore, the StringArray can be treated as a complete ASCII string for
/// case conversion, and we can reuse the offsets buffer and the nulls buffer.
fn case_conversion_ascii_array<'a, O, F>(
string_array: &'a GenericStringArray<O>,
op: F,
) -> Result<ArrayRef>
where
O: OffsetSizeTrait,
F: Fn(&'a str) -> String,
{
let value_data = string_array.value_data();
// SAFETY: all items stored in value_data satisfy UTF8.
// ref: impl ByteArrayNativeType for str {...}
let str_values = unsafe { std::str::from_utf8_unchecked(value_data) };

// conversion
let converted_values = op(str_values);
assert_eq!(converted_values.len(), str_values.len());
let bytes = converted_values.into_bytes();

// build result
let values = Buffer::from_vec(bytes);
let offsets = string_array.offsets().clone();
let nulls = string_array.nulls().cloned();
// SAFETY: offsets and nulls are consistent with the input array.
Ok(Arc::new(unsafe {
GenericStringArray::<O>::new_unchecked(offsets, values, nulls)
}))
}
Loading

0 comments on commit 483663b

Please sign in to comment.