From 17743a56a77f02ae05f713cd4e28b79cde499e19 Mon Sep 17 00:00:00 2001 From: JasonLi-cn Date: Sat, 6 Apr 2024 12:17:00 +0800 Subject: [PATCH 1/9] feat: optimize lower and upper functions --- datafusion/functions/Cargo.toml | 10 ++++ datafusion/functions/benches/lower.rs | 40 ++++++++++++++ datafusion/functions/benches/upper.rs | 40 ++++++++++++++ datafusion/functions/src/string/common.rs | 64 +++++++++++++++++++++++ datafusion/functions/src/string/lower.rs | 4 +- datafusion/functions/src/string/upper.rs | 47 ++++++++++++++++- 6 files changed, 201 insertions(+), 4 deletions(-) create mode 100644 datafusion/functions/benches/lower.rs create mode 100644 datafusion/functions/benches/upper.rs diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index ef7d2c9b1892..68b32ff0a086 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -83,6 +83,7 @@ unicode-segmentation = { version = "^1.7.1", optional = true } uuid = { version = "1.7", features = ["v4"], optional = true } [dev-dependencies] +arrow = { workspace = true, features = ["test_utils"] } criterion = "0.5" rand = { workspace = true } rstest = { workspace = true } @@ -112,3 +113,12 @@ required-features = ["datetime_expressions"] harness = false name = "substr_index" required-features = ["unicode_expressions"] + +[[bench]] +harness = false +name = "lower" + +[[bench]] +harness = false +name = "upper" + diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs new file mode 100644 index 000000000000..59da4113cc95 --- /dev/null +++ b/datafusion/functions/benches/lower.rs @@ -0,0 +1,40 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::util::bench_util::create_string_array_with_len; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::string; +use std::sync::Arc; + +fn create_args(size: usize, str_len: usize) -> Vec { + let array = Arc::new(create_string_array_with_len::(size, 0.2, str_len)); + vec![ColumnarValue::Array(array)] +} + +fn criterion_benchmark(c: &mut Criterion) { + let lower = string::lower(); + for size in [1024, 4096, 8192] { + let args = create_args(size, 32); + c.bench_function("lower", |b| b.iter(|| black_box(lower.invoke(&args)))); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs new file mode 100644 index 000000000000..5c9a1096ff81 --- /dev/null +++ b/datafusion/functions/benches/upper.rs @@ -0,0 +1,40 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::util::bench_util::create_string_array_with_len; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::string; +use std::sync::Arc; + +fn create_args(size: usize, str_len: usize) -> Vec { + let array = Arc::new(create_string_array_with_len::(size, 0.2, str_len)); + vec![ColumnarValue::Array(array)] +} + +fn criterion_benchmark(c: &mut Criterion) { + let upper = string::upper(); + for size in [1024, 4096, 8192] { + let args = create_args(size, 32); + c.bench_function("upper", |b| b.iter(|| black_box(upper.invoke(&args)))); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 276aad121df2..e8989b64a4c7 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -19,6 +19,7 @@ use std::fmt::{Display, Formatter}; use std::sync::Arc; use arrow::array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::buffer::Buffer; use arrow::datatypes::DataType; use datafusion_common::cast::as_generic_string_array; @@ -103,6 +104,7 @@ pub(crate) fn general_trim( /// This function errors when: /// * the number of arguments is not 1 /// * the first argument is not castable to a `GenericStringArray` +#[allow(dead_code)] pub(crate) fn unary_string_function<'a, T, O, F, R>( args: &[&'a dyn Array], op: F, @@ -128,6 +130,7 @@ where 
Ok(string_array.iter().map(|string| string.map(&op)).collect()) } +#[allow(dead_code)] pub(crate) fn handle<'a, F, R>( args: &'a [ColumnarValue], op: F, @@ -174,3 +177,64 @@ where }, } } + +pub(crate) fn case_conversion<'a, F>( + args: &'a [ColumnarValue], + op: F, + name: &str, +) -> Result +where + F: Fn(&'a str) -> String, +{ + match &args[0] { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Utf8 => Ok(ColumnarValue::Array(convert_array::( + array, + |string| op(string), + )?)), + DataType::LargeUtf8 => Ok(ColumnarValue::Array(convert_array::( + array, + |string| op(string), + )?)), + other => exec_err!("Unsupported data type {other:?} for function {name}"), + }, + ColumnarValue::Scalar(scalar) => match scalar { + ScalarValue::Utf8(a) => { + let result = a.as_ref().map(|x| op(x)); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) + } + ScalarValue::LargeUtf8(a) => { + let result = a.as_ref().map(|x| op(x)); + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result))) + } + other => exec_err!("Unsupported data type {other:?} for function {name}"), + }, + } +} + +fn convert_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result +where + O: OffsetSizeTrait, + F: Fn(&'a str) -> String, +{ + let string_array = as_generic_string_array::(array)?; + let value_data = string_array.value_data(); + + // SAFETY: all items stored in value_data satisfy UTF8. + // ref: impl ByteArrayNativeType for str {...} + let str_values = unsafe { std::str::from_utf8_unchecked(value_data) }; + + // conversion + let converted_values = op(str_values); + let bytes = converted_values.into_bytes(); + + // build result + let values = Buffer::from_vec(bytes); + let offsets = string_array.offsets().clone(); + let nulls = string_array.nulls().cloned(); + + // SAFETY: offsets and nulls are consistent with the input array. 
+ Ok(Arc::new(unsafe { + GenericStringArray::::new_unchecked(offsets, values, nulls) + })) +} diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index a1eff7042211..9d84017d15ec 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -23,7 +23,7 @@ use datafusion_common::Result; use datafusion_expr::ColumnarValue; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use crate::string::common::handle; +use crate::string::common::case_conversion; use crate::utils::utf8_to_str_type; #[derive(Debug)] @@ -62,6 +62,6 @@ impl ScalarUDFImpl for LowerFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - handle(args, |string| string.to_lowercase(), "lower") + case_conversion(args, |string| string.to_lowercase(), "lower") } } diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index c21824d30d53..b7ad6244a849 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use crate::string::common::handle; +use crate::string::common::case_conversion; use crate::utils::utf8_to_str_type; use arrow::datatypes::DataType; use datafusion_common::Result; @@ -59,6 +59,49 @@ impl ScalarUDFImpl for UpperFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - handle(args, |string| string.to_uppercase(), "upper") + case_conversion(args, |string| string.to_uppercase(), "upper") + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ArrayRef, StringArray}; + use std::sync::Arc; + + #[test] + fn upper() -> Result<()> { + let string_array = Arc::new(StringArray::from(vec![ + Some("arrow"), + None, + Some("datafusion"), + Some("@_"), + Some("0123456789"), + None, + Some(""), + Some("\t\n"), + ])) as ArrayRef; + + let args = vec![ColumnarValue::Array(string_array)]; + let result = + match case_conversion(&args, |string| string.to_uppercase(), "upper")? { + ColumnarValue::Array(result) => result, + _ => unreachable!(), + }; + + let expected = Arc::new(StringArray::from(vec![ + Some("ARROW"), + None, + Some("DATAFUSION"), + Some("@_"), + Some("0123456789"), + None, + Some(""), + Some("\t\n"), + ])) as ArrayRef; + + assert_eq!(&expected, &result); + + Ok(()) } } From e80e3dbcf8e5fc7838ee6abd312ed29e4a6e10d9 Mon Sep 17 00:00:00 2001 From: JasonLi-cn Date: Sat, 6 Apr 2024 13:05:23 +0800 Subject: [PATCH 2/9] chore: pass cargo check --- datafusion/functions/Cargo.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 68b32ff0a086..2502df535a53 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -117,8 +117,9 @@ required-features = ["unicode_expressions"] [[bench]] harness = false name = "lower" +required-features = ["string_expressions"] [[bench]] harness = false name = "upper" - +required-features = ["string_expressions"] From bee214058f35adff54b80f5032c7c7cf330b5ab7 Mon Sep 17 00:00:00 2001 From: JasonLi-cn 
Date: Sat, 6 Apr 2024 14:55:49 +0800 Subject: [PATCH 3/9] chore: pass cargo clippy --- datafusion/functions/src/string/common.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index e8989b64a4c7..f617c3f01584 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -188,14 +188,12 @@ where { match &args[0] { ColumnarValue::Array(array) => match array.data_type() { - DataType::Utf8 => Ok(ColumnarValue::Array(convert_array::( - array, - |string| op(string), - )?)), - DataType::LargeUtf8 => Ok(ColumnarValue::Array(convert_array::( - array, - |string| op(string), - )?)), + DataType::Utf8 => { + Ok(ColumnarValue::Array(convert_array::(array, op)?)) + } + DataType::LargeUtf8 => { + Ok(ColumnarValue::Array(convert_array::(array, op)?)) + } other => exec_err!("Unsupported data type {other:?} for function {name}"), }, ColumnarValue::Scalar(scalar) => match scalar { From 2c27a8ec56f85066b41cf2508876dc3cdedb7c91 Mon Sep 17 00:00:00 2001 From: JasonLi-cn Date: Sun, 7 Apr 2024 19:36:54 +0800 Subject: [PATCH 4/9] fix: lower and upper bug --- datafusion/functions/benches/lower.rs | 43 +++++++++- datafusion/functions/src/string/common.rs | 86 +++++++++++++++++++- datafusion/functions/src/string/lower.rs | 97 ++++++++++++++++++++++- datafusion/functions/src/string/upper.rs | 82 +++++++++++++++---- 4 files changed, 283 insertions(+), 25 deletions(-) diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index 59da4113cc95..d99cdeb745a3 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -17,22 +17,59 @@ extern crate criterion; +use arrow::array::{ArrayRef, StringArray}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::ColumnarValue; 
use datafusion_functions::string; use std::sync::Arc; -fn create_args(size: usize, str_len: usize) -> Vec { +fn create_args1(size: usize, str_len: usize) -> Vec { let array = Arc::new(create_string_array_with_len::(size, 0.2, str_len)); vec![ColumnarValue::Array(array)] } +fn create_args2(size: usize) -> Vec { + let mut items = Vec::with_capacity(size); + items.push("农历新年".to_string()); + for i in 1..size { + items.push(format!("DATAFUSION {}", i)); + } + let array = Arc::new(StringArray::from(items)) as ArrayRef; + vec![ColumnarValue::Array(array)] +} + +fn create_args3(size: usize) -> Vec { + let mut items = Vec::with_capacity(size); + let half = size / 2; + for i in 0..half { + items.push(format!("DATAFUSION {}", i)); + } + items.push("Ⱦ".to_string()); + for i in half + 1..size { + items.push(format!("DATAFUSION {}", i)); + } + let array = Arc::new(StringArray::from(items)) as ArrayRef; + vec![ColumnarValue::Array(array)] +} + fn criterion_benchmark(c: &mut Criterion) { let lower = string::lower(); for size in [1024, 4096, 8192] { - let args = create_args(size, 32); - c.bench_function("lower", |b| b.iter(|| black_box(lower.invoke(&args)))); + let args = create_args1(size, 32); + c.bench_function(&format!("lower full optimization: {}", size), |b| { + b.iter(|| black_box(lower.invoke(&args))) + }); + + let args = create_args2(size); + c.bench_function(&format!("lower no optimization: {}", size), |b| { + b.iter(|| black_box(lower.invoke(&args))) + }); + + let args = create_args3(size); + c.bench_function(&format!("lower partial optimization: {}", size), |b| { + b.iter(|| black_box(lower.invoke(&args))) + }); } } diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index f617c3f01584..0e6a872be2e3 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -18,8 +18,8 @@ use std::fmt::{Display, Formatter}; use std::sync::Arc; -use arrow::array::{Array, ArrayRef, 
GenericStringArray, OffsetSizeTrait}; -use arrow::buffer::Buffer; +use arrow::array::{Array, ArrayRef, BufferBuilder, GenericStringArray, OffsetSizeTrait}; +use arrow::buffer::{Buffer, OffsetBuffer, ScalarBuffer}; use arrow::datatypes::DataType; use datafusion_common::cast::as_generic_string_array; @@ -178,7 +178,15 @@ where } } -pub(crate) fn case_conversion<'a, F>( +pub(crate) fn to_lower(args: &[ColumnarValue], name: &str) -> Result { + case_conversion(args, |string| string.to_lowercase(), name) +} + +pub(crate) fn to_upper(args: &[ColumnarValue], name: &str) -> Result { + case_conversion(args, |string| string.to_uppercase(), name) +} + +fn case_conversion<'a, F>( args: &'a [ColumnarValue], op: F, name: &str, @@ -211,6 +219,76 @@ where } fn convert_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result +where + O: OffsetSizeTrait, + F: Fn(&'a str) -> String, +{ + let string_array = as_generic_string_array::(array)?; + let item_len = string_array.len(); + + // Find the ASCII string at the beginning. 
+ let mut i = 0; + while i < item_len { + let item = unsafe { string_array.value_unchecked(i) }; + if !item.as_bytes().is_ascii() { + break; + } + i += 1; + } + + let the_first_nonascii_index = i; + + // Case1: no optimization + if the_first_nonascii_index == 0 { + let result: GenericStringArray = + string_array.iter().map(|string| string.map(&op)).collect(); + return Ok(Arc::new(result)); + } + + // Case2: full optimization + if the_first_nonascii_index == item_len { + return convert_ascii_array::(array, op); + } + + // Case3: partial optimization + let value_data = string_array.value_data(); + let offsets = string_array.offsets(); + let nulls = string_array.nulls().cloned(); + + // Init new offsets buffer builder + let mut offsets_builder = BufferBuilder::::new(item_len + 1); + offsets_builder.append_slice(&offsets.as_ref()[..the_first_nonascii_index + 1]); + + // convert ascii + let end: O = unsafe { *offsets.get_unchecked(the_first_nonascii_index) }; + let end = end.as_usize(); + let ascii = unsafe { std::str::from_utf8_unchecked(&value_data[..end]) }; + let mut converted_values = op(ascii); + // To avoid repeatedly allocating memory, perform a reserve in advance. + converted_values.reserve(value_data.len() - end + 8); + + // Convert remaining items + for j in the_first_nonascii_index..item_len { + let item = unsafe { string_array.value_unchecked(j) }; + // Memory will be continuously allocated here, but it is unavoidable. + let converted = op(item); + converted_values.push_str(&converted); + offsets_builder + .append(O::from_usize(converted_values.len()).expect("offset overflow")); + } + let offsets_buffer = offsets_builder.finish(); + + // Build result + let bytes = converted_values.into_bytes(); + let values = Buffer::from_vec(bytes); + let offsets = OffsetBuffer::new(ScalarBuffer::new(offsets_buffer, 0, item_len + 1)); + // SAFETY: offsets and nulls are consistent with the input array. 
+ Ok(Arc::new(unsafe { + GenericStringArray::::new_unchecked(offsets, values, nulls) + })) +} + +fn convert_ascii_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result where O: OffsetSizeTrait, F: Fn(&'a str) -> String, @@ -224,13 +302,13 @@ where // conversion let converted_values = op(str_values); + assert_eq!(converted_values.len(), str_values.len()); let bytes = converted_values.into_bytes(); // build result let values = Buffer::from_vec(bytes); let offsets = string_array.offsets().clone(); let nulls = string_array.nulls().cloned(); - // SAFETY: offsets and nulls are consistent with the input array. Ok(Arc::new(unsafe { GenericStringArray::::new_unchecked(offsets, values, nulls) diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index 9d84017d15ec..0e01c5aa60a8 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -23,7 +23,7 @@ use datafusion_common::Result; use datafusion_expr::ColumnarValue; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use crate::string::common::case_conversion; +use crate::string::common::to_lower; use crate::utils::utf8_to_str_type; #[derive(Debug)] @@ -62,6 +62,99 @@ impl ScalarUDFImpl for LowerFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - case_conversion(args, |string| string.to_lowercase(), "lower") + to_lower(args, "lower") + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ArrayRef, StringArray}; + use std::sync::Arc; + + fn to_lower(input: ArrayRef, expected: ArrayRef) -> Result<()> { + let func = LowerFunc::new(); + let args = vec![ColumnarValue::Array(input)]; + let result = match func.invoke(&args)? 
{ + ColumnarValue::Array(result) => result, + _ => unreachable!(), + }; + assert_eq!(&expected, &result); + Ok(()) + } + + #[test] + fn lower_no_optimization() -> Result<()> { + let input = Arc::new(StringArray::from(vec![ + Some("农历新年"), + None, + Some("DATAFUSION"), + Some("0123456789"), + Some(""), + ])) as ArrayRef; + + let expected = Arc::new(StringArray::from(vec![ + Some("农历新年"), + None, + Some("datafusion"), + Some("0123456789"), + Some(""), + ])) as ArrayRef; + + to_lower(input, expected) + } + + #[test] + fn lower_full_optimization() -> Result<()> { + let input = Arc::new(StringArray::from(vec![ + Some("ARROW"), + None, + Some("DATAFUSION"), + Some("0123456789"), + Some(""), + ])) as ArrayRef; + + let expected = Arc::new(StringArray::from(vec![ + Some("arrow"), + None, + Some("datafusion"), + Some("0123456789"), + Some(""), + ])) as ArrayRef; + + to_lower(input, expected) + } + + #[test] + fn lower_partial_optimization() -> Result<()> { + let input = Arc::new(StringArray::from(vec![ + Some("ARROW"), + None, + Some("DATAFUSION"), + Some("@_"), + Some("0123456789"), + Some(""), + Some("\t\n"), + Some("ὈΔΥΣΣΕΎΣ"), + Some("TSCHÜSS"), + Some("Ⱦ"), // ⱦ: length change + Some("农历新年"), + ])) as ArrayRef; + + let expected = Arc::new(StringArray::from(vec![ + Some("arrow"), + None, + Some("datafusion"), + Some("@_"), + Some("0123456789"), + Some(""), + Some("\t\n"), + Some("ὀδυσσεύς"), + Some("tschüss"), + Some("ⱦ"), + Some("农历新年"), + ])) as ArrayRef; + + to_lower(input, expected) } } diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index b7ad6244a849..0776cbad5291 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use crate::string::common::case_conversion; +use crate::string::common::to_upper; use crate::utils::utf8_to_str_type; use arrow::datatypes::DataType; use datafusion_common::Result; @@ -59,7 +59,7 @@ impl ScalarUDFImpl for UpperFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - case_conversion(args, |string| string.to_uppercase(), "upper") + to_upper(args, "upper") } } @@ -69,39 +69,89 @@ mod tests { use arrow::array::{ArrayRef, StringArray}; use std::sync::Arc; + fn to_upper(input: ArrayRef, expected: ArrayRef) -> Result<()> { + let func = UpperFunc::new(); + let args = vec![ColumnarValue::Array(input)]; + let result = match func.invoke(&args)? { + ColumnarValue::Array(result) => result, + _ => unreachable!(), + }; + assert_eq!(&expected, &result); + Ok(()) + } + #[test] - fn upper() -> Result<()> { - let string_array = Arc::new(StringArray::from(vec![ - Some("arrow"), + fn upper_no_optimization() -> Result<()> { + let input = Arc::new(StringArray::from(vec![ + Some("农历新年"), None, Some("datafusion"), - Some("@_"), Some("0123456789"), + Some(""), + ])) as ArrayRef; + + let expected = Arc::new(StringArray::from(vec![ + Some("农历新年"), None, + Some("DATAFUSION"), + Some("0123456789"), Some(""), - Some("\t\n"), ])) as ArrayRef; - let args = vec![ColumnarValue::Array(string_array)]; - let result = - match case_conversion(&args, |string| string.to_uppercase(), "upper")? 
{ - ColumnarValue::Array(result) => result, - _ => unreachable!(), - }; + to_upper(input, expected) + } + + #[test] + fn upper_full_optimization() -> Result<()> { + let input = Arc::new(StringArray::from(vec![ + Some("arrow"), + None, + Some("datafusion"), + Some("0123456789"), + Some(""), + ])) as ArrayRef; let expected = Arc::new(StringArray::from(vec![ Some("ARROW"), None, Some("DATAFUSION"), - Some("@_"), Some("0123456789"), + Some(""), + ])) as ArrayRef; + + to_upper(input, expected) + } + + #[test] + fn upper_partial_optimization() -> Result<()> { + let input = Arc::new(StringArray::from(vec![ + Some("arrow"), None, + Some("datafusion"), + Some("@_"), + Some("0123456789"), Some(""), Some("\t\n"), + Some("ὀδυσσεύς"), + Some("tschüß"), + Some("ⱦ"), // Ⱦ: length change + Some("农历新年"), ])) as ArrayRef; - assert_eq!(&expected, &result); + let expected = Arc::new(StringArray::from(vec![ + Some("ARROW"), + None, + Some("DATAFUSION"), + Some("@_"), + Some("0123456789"), + Some(""), + Some("\t\n"), + Some("ὈΔΥΣΣΕΎΣ"), + Some("TSCHÜSS"), + Some("Ⱦ"), + Some("农历新年"), + ])) as ArrayRef; - Ok(()) + to_upper(input, expected) } } From 9f1185b58b457173122dd89b5fa4cfc77b381184 Mon Sep 17 00:00:00 2001 From: JasonLi-cn Date: Sun, 7 Apr 2024 21:57:49 +0800 Subject: [PATCH 5/9] optimize --- datafusion/functions/benches/lower.rs | 2 +- datafusion/functions/src/string/common.rs | 95 +++-------------------- datafusion/functions/src/string/lower.rs | 2 +- datafusion/functions/src/string/upper.rs | 2 +- 4 files changed, 13 insertions(+), 88 deletions(-) diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index d99cdeb745a3..1e9d3505b557 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -62,7 +62,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); let args = create_args2(size); - c.bench_function(&format!("lower no optimization: {}", size), |b| { + c.bench_function(&format!("lower maybe 
optimization: {}", size), |b| { b.iter(|| black_box(lower.invoke(&args))) }); diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 0e6a872be2e3..eaeacb5aa27a 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -18,7 +18,10 @@ use std::fmt::{Display, Formatter}; use std::sync::Arc; -use arrow::array::{Array, ArrayRef, BufferBuilder, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ + Array, ArrayRef, BufferBuilder, GenericStringArray, GenericStringBuilder, + OffsetSizeTrait, +}; use arrow::buffer::{Buffer, OffsetBuffer, ScalarBuffer}; use arrow::datatypes::DataType; @@ -98,86 +101,6 @@ pub(crate) fn general_trim( } } -/// applies a unary expression to `args[0]` that is expected to be downcastable to -/// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset) -/// # Errors -/// This function errors when: -/// * the number of arguments is not 1 -/// * the first argument is not castable to a `GenericStringArray` -#[allow(dead_code)] -pub(crate) fn unary_string_function<'a, T, O, F, R>( - args: &[&'a dyn Array], - op: F, - name: &str, -) -> Result> -where - R: AsRef, - O: OffsetSizeTrait, - T: OffsetSizeTrait, - F: Fn(&'a str) -> R, -{ - if args.len() != 1 { - return exec_err!( - "{:?} args were supplied but {} takes exactly one argument", - args.len(), - name - ); - } - - let string_array = as_generic_string_array::(args[0])?; - - // first map is the iterator, second is for the `Option<_>` - Ok(string_array.iter().map(|string| string.map(&op)).collect()) -} - -#[allow(dead_code)] -pub(crate) fn handle<'a, F, R>( - args: &'a [ColumnarValue], - op: F, - name: &str, -) -> Result -where - R: AsRef, - F: Fn(&'a str) -> R, -{ - match &args[0] { - ColumnarValue::Array(a) => match a.data_type() { - DataType::Utf8 => { - Ok(ColumnarValue::Array(Arc::new(unary_string_function::< - i32, - i32, - _, - _, - >( - &[a.as_ref()], op, 
name - )?))) - } - DataType::LargeUtf8 => { - Ok(ColumnarValue::Array(Arc::new(unary_string_function::< - i64, - i64, - _, - _, - >( - &[a.as_ref()], op, name - )?))) - } - other => exec_err!("Unsupported data type {other:?} for function {name}"), - }, - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(a) => { - let result = a.as_ref().map(|x| (op)(x).as_ref().to_string()); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) - } - ScalarValue::LargeUtf8(a) => { - let result = a.as_ref().map(|x| (op)(x).as_ref().to_string()); - Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result))) - } - other => exec_err!("Unsupported data type {other:?} for function {name}"), - }, - } -} - pub(crate) fn to_lower(args: &[ColumnarValue], name: &str) -> Result { case_conversion(args, |string| string.to_lowercase(), name) } @@ -238,11 +161,13 @@ where let the_first_nonascii_index = i; - // Case1: no optimization + // Case1: maybe optimization if the_first_nonascii_index == 0 { - let result: GenericStringArray = - string_array.iter().map(|string| string.map(&op)).collect(); - return Ok(Arc::new(result)); + let iter = string_array.iter().map(|string| string.map(&op)); + let capacity = string_array.value_data().len() + 8; + let mut builder = GenericStringBuilder::::with_capacity(item_len, capacity); + builder.extend(iter); + return Ok(Arc::new(builder.finish())); } // Case2: full optimization diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index 0e01c5aa60a8..b9b3840252c5 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -84,7 +84,7 @@ mod tests { } #[test] - fn lower_no_optimization() -> Result<()> { + fn lower_maybe_optimization() -> Result<()> { let input = Arc::new(StringArray::from(vec![ Some("农历新年"), None, diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index 0776cbad5291..8f03d7dc6bbc 100644 --- 
a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -81,7 +81,7 @@ mod tests { } #[test] - fn upper_no_optimization() -> Result<()> { + fn upper_maybe_optimization() -> Result<()> { let input = Arc::new(StringArray::from(vec![ Some("农历新年"), None, From d0a1b4f4871c4a3368dfe38b00f5e5c800b4057c Mon Sep 17 00:00:00 2001 From: JasonLi-cn Date: Mon, 8 Apr 2024 23:25:07 +0800 Subject: [PATCH 6/9] using iter to find the first nonascii --- datafusion/functions/src/string/common.rs | 31 +++++++++++++---------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index eaeacb5aa27a..47b0fba88230 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -149,17 +149,18 @@ where let string_array = as_generic_string_array::(array)?; let item_len = string_array.len(); - // Find the ASCII string at the beginning. - let mut i = 0; - while i < item_len { - let item = unsafe { string_array.value_unchecked(i) }; - if !item.as_bytes().is_ascii() { - break; + // Find the first nonascii string at the beginning. + let find_the_first_nonascii = || { + for (i, item) in string_array.iter().enumerate() { + if let Some(str) = item { + if !str.as_bytes().is_ascii() { + return i; + } + } } - i += 1; - } - - let the_first_nonascii_index = i; + item_len + }; + let the_first_nonascii_index = find_the_first_nonascii(); // Case1: maybe optimization if the_first_nonascii_index == 0 { @@ -194,10 +195,12 @@ where // Convert remaining items for j in the_first_nonascii_index..item_len { - let item = unsafe { string_array.value_unchecked(j) }; - // Memory will be continuously allocated here, but it is unavoidable. 
- let converted = op(item); - converted_values.push_str(&converted); + if string_array.is_valid(j) { + let item = unsafe { string_array.value_unchecked(j) }; + // Memory will be continuously allocated here, but it is unavoidable. + let converted = op(item); + converted_values.push_str(&converted); + } offsets_builder .append(O::from_usize(converted_values.len()).expect("offset overflow")); } From fd04d4a59558ff49ee68dc0216c74162cf76d05a Mon Sep 17 00:00:00 2001 From: JasonLi-cn Date: Mon, 8 Apr 2024 23:28:11 +0800 Subject: [PATCH 7/9] chore: rename function --- datafusion/functions/src/string/common.rs | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 47b0fba88230..630ce28afc56 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -119,12 +119,13 @@ where { match &args[0] { ColumnarValue::Array(array) => match array.data_type() { - DataType::Utf8 => { - Ok(ColumnarValue::Array(convert_array::(array, op)?)) - } - DataType::LargeUtf8 => { - Ok(ColumnarValue::Array(convert_array::(array, op)?)) - } + DataType::Utf8 => Ok(ColumnarValue::Array(case_conversion_array::( + array, op, + )?)), + DataType::LargeUtf8 => Ok(ColumnarValue::Array(case_conversion_array::< + i64, + _, + >(array, op)?)), other => exec_err!("Unsupported data type {other:?} for function {name}"), }, ColumnarValue::Scalar(scalar) => match scalar { @@ -141,7 +142,7 @@ where } } -fn convert_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result +fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result where O: OffsetSizeTrait, F: Fn(&'a str) -> String, @@ -173,7 +174,7 @@ where // Case2: full optimization if the_first_nonascii_index == item_len { - return convert_ascii_array::(array, op); + return case_conversion_ascii_array::(array, op); } // Case3: partial optimization @@ -216,7 +217,7 @@ where })) } -fn 
convert_ascii_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result +fn case_conversion_ascii_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result where O: OffsetSizeTrait, F: Fn(&'a str) -> String, From 087f6ed95143352a419b8a60768e80dc3fc5c2e9 Mon Sep 17 00:00:00 2001 From: JasonLi-cn Date: Fri, 12 Apr 2024 19:08:37 +0800 Subject: [PATCH 8/9] refactor: case_conversion_array function --- datafusion/functions/benches/lower.rs | 28 +++++-- datafusion/functions/benches/upper.rs | 8 +- datafusion/functions/src/string/common.rs | 96 +++++++---------------- 3 files changed, 56 insertions(+), 76 deletions(-) diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index 1e9d3505b557..fa963f174e46 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -24,11 +24,19 @@ use datafusion_expr::ColumnarValue; use datafusion_functions::string; use std::sync::Arc; +/// Create an array of args containing a StringArray, where all the values in the +/// StringArray are ASCII. +/// * `size` - the length of the StringArray, and +/// * `str_len` - the length of the strings within the StringArray. fn create_args1(size: usize, str_len: usize) -> Vec { let array = Arc::new(create_string_array_with_len::(size, 0.2, str_len)); vec![ColumnarValue::Array(array)] } +/// Create an array of args containing a StringArray, where the first value in the +/// StringArray is non-ASCII. +/// * `size` - the length of the StringArray, and +/// * `str_len` - the length of the strings within the StringArray. fn create_args2(size: usize) -> Vec { let mut items = Vec::with_capacity(size); items.push("农历新年".to_string()); @@ -39,6 +47,10 @@ fn create_args2(size: usize) -> Vec { vec![ColumnarValue::Array(array)] } +/// Create an array of args containing a StringArray, where the middle value of the +/// StringArray is non-ASCII. 
+/// * `size` - the length of the StringArray, and +/// * `str_len` - the length of the strings within the StringArray. fn create_args3(size: usize) -> Vec { let mut items = Vec::with_capacity(size); let half = size / 2; @@ -57,19 +69,21 @@ fn criterion_benchmark(c: &mut Criterion) { let lower = string::lower(); for size in [1024, 4096, 8192] { let args = create_args1(size, 32); - c.bench_function(&format!("lower full optimization: {}", size), |b| { + c.bench_function(&format!("lower_all_values_are_ascii: {}", size), |b| { b.iter(|| black_box(lower.invoke(&args))) }); let args = create_args2(size); - c.bench_function(&format!("lower maybe optimization: {}", size), |b| { - b.iter(|| black_box(lower.invoke(&args))) - }); + c.bench_function( + &format!("lower_the_first_value_is_nonascii: {}", size), + |b| b.iter(|| black_box(lower.invoke(&args))), + ); let args = create_args3(size); - c.bench_function(&format!("lower partial optimization: {}", size), |b| { - b.iter(|| black_box(lower.invoke(&args))) - }); + c.bench_function( + &format!("lower_the_middle_value_is_nonascii: {}", size), + |b| b.iter(|| black_box(lower.invoke(&args))), + ); } } diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs index 5c9a1096ff81..a3e5fbd7a433 100644 --- a/datafusion/functions/benches/upper.rs +++ b/datafusion/functions/benches/upper.rs @@ -23,6 +23,10 @@ use datafusion_expr::ColumnarValue; use datafusion_functions::string; use std::sync::Arc; +/// Create an array of args containing a StringArray, where all the values in the +/// StringArray are ASCII. +/// * `size` - the length of the StringArray, and +/// * `str_len` - the length of the strings within the StringArray. 
fn create_args(size: usize, str_len: usize) -> Vec { let array = Arc::new(create_string_array_with_len::(size, 0.2, str_len)); vec![ColumnarValue::Array(array)] @@ -32,7 +36,9 @@ fn criterion_benchmark(c: &mut Criterion) { let upper = string::upper(); for size in [1024, 4096, 8192] { let args = create_args(size, 32); - c.bench_function("upper", |b| b.iter(|| black_box(upper.invoke(&args)))); + c.bench_function("upper_all_values_are_ascii", |b| { + b.iter(|| black_box(upper.invoke(&args))) + }); } } diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index c3f7f14820b7..d9e41ba4756e 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -19,10 +19,10 @@ use std::fmt::{Display, Formatter}; use std::sync::Arc; use arrow::array::{ - new_null_array, Array, ArrayRef, BufferBuilder, GenericStringArray, - GenericStringBuilder, OffsetSizeTrait, + new_null_array, Array, ArrayRef, GenericStringArray, GenericStringBuilder, + OffsetSizeTrait, }; -use arrow::buffer::{Buffer, OffsetBuffer, ScalarBuffer}; +use arrow::buffer::Buffer; use arrow::datatypes::DataType; use datafusion_common::cast::as_generic_string_array; @@ -160,84 +160,44 @@ where O: OffsetSizeTrait, F: Fn(&'a str) -> String, { - let string_array = as_generic_string_array::(array)?; - let item_len = string_array.len(); - - // Find the first nonascii string at the beginning. 
- let find_the_first_nonascii = || { - for (i, item) in string_array.iter().enumerate() { - if let Some(str) = item { - if !str.as_bytes().is_ascii() { - return i; - } - } - } - item_len - }; - let the_first_nonascii_index = find_the_first_nonascii(); - - // Case1: maybe optimization - if the_first_nonascii_index == 0 { - let iter = string_array.iter().map(|string| string.map(&op)); - let capacity = string_array.value_data().len() + 8; - let mut builder = GenericStringBuilder::::with_capacity(item_len, capacity); - builder.extend(iter); - return Ok(Arc::new(builder.finish())); - } - - // Case2: full optimization - if the_first_nonascii_index == item_len { - return case_conversion_ascii_array::(array, op); - } + const PRE_ALLOC_BYTES: usize = 8; - // Case3: partial optimization + let string_array = as_generic_string_array::(array)?; let value_data = string_array.value_data(); - let offsets = string_array.offsets(); - let nulls = string_array.nulls().cloned(); - // Init new offsets buffer builder - let mut offsets_builder = BufferBuilder::::new(item_len + 1); - offsets_builder.append_slice(&offsets.as_ref()[..the_first_nonascii_index + 1]); + // All values are ASCII. + if value_data.is_ascii() { + return case_conversion_ascii_array::(string_array, op); + } - // convert ascii - let end: O = unsafe { *offsets.get_unchecked(the_first_nonascii_index) }; - let end = end.as_usize(); - let ascii = unsafe { std::str::from_utf8_unchecked(&value_data[..end]) }; - let mut converted_values = op(ascii); - // To avoid repeatedly allocating memory, perform a reserve in advance. - converted_values.reserve(value_data.len() - end + 8); + // Values contain non-ASCII. 
+ let item_len = string_array.len(); + let capacity = string_array.value_data().len() + PRE_ALLOC_BYTES; + let mut builder = GenericStringBuilder::::with_capacity(item_len, capacity); - // Convert remaining items - for j in the_first_nonascii_index..item_len { - if string_array.is_valid(j) { - let item = unsafe { string_array.value_unchecked(j) }; - // Memory will be continuously allocated here, but it is unavoidable. - let converted = op(item); - converted_values.push_str(&converted); - } - offsets_builder - .append(O::from_usize(converted_values.len()).expect("offset overflow")); + if !string_array.is_nullable() || string_array.null_count() == 0 { + let iter = + (0..item_len).map(|i| Some(op(unsafe { string_array.value_unchecked(i) }))); + builder.extend(iter); + } else { + let iter = string_array.iter().map(|string| string.map(&op)); + builder.extend(iter); } - let offsets_buffer = offsets_builder.finish(); - - // Build result - let bytes = converted_values.into_bytes(); - let values = Buffer::from_vec(bytes); - let offsets = OffsetBuffer::new(ScalarBuffer::new(offsets_buffer, 0, item_len + 1)); - // SAFETY: offsets and nulls are consistent with the input array. - Ok(Arc::new(unsafe { - GenericStringArray::::new_unchecked(offsets, values, nulls) - })) + Ok(Arc::new(builder.finish())) } -fn case_conversion_ascii_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result +/// All values of string_array are ASCII, and when converting case, there is no change in the byte +/// array length. Therefore, the StringArray can be treated as a complete ASCII string for +/// case conversion, and we can reuse the offsets buffer and the nulls buffer. +fn case_conversion_ascii_array<'a, O, F>( + string_array: &'a GenericStringArray, + op: F, +) -> Result where O: OffsetSizeTrait, F: Fn(&'a str) -> String, { - let string_array = as_generic_string_array::(array)?; let value_data = string_array.value_data(); - // SAFETY: all items stored in value_data satisfy UTF8.
// ref: impl ByteArrayNativeType for str {...} let str_values = unsafe { std::str::from_utf8_unchecked(value_data) }; From eeb928207b8cf85cddcb68899ef9ac15901cea60 Mon Sep 17 00:00:00 2001 From: JasonLi-cn Date: Sun, 14 Apr 2024 22:28:21 +0800 Subject: [PATCH 9/9] refactor: remove !string_array.is_nullable() from case_conversion_array --- datafusion/functions/src/string/common.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index d9e41ba4756e..97f9e1d93be5 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -175,7 +175,7 @@ where let capacity = string_array.value_data().len() + PRE_ALLOC_BYTES; let mut builder = GenericStringBuilder::::with_capacity(item_len, capacity); - if !string_array.is_nullable() || string_array.null_count() == 0 { + if string_array.null_count() == 0 { let iter = (0..item_len).map(|i| Some(op(unsafe { string_array.value_unchecked(i) }))); builder.extend(iter);