Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 90 additions & 18 deletions datafusion/functions/benches/initcap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::OffsetSizeTrait;
use arrow::array::{ArrayRef, OffsetSizeTrait, StringArray, StringViewBuilder};
use arrow::datatypes::{DataType, Field};
use arrow::util::bench_util::{
create_string_array_with_len, create_string_view_array_with_len,
Expand Down Expand Up @@ -47,52 +47,124 @@ fn create_args<O: OffsetSizeTrait>(
}
}

/// Builds the benchmark input for the non-ASCII `Utf8` case.
///
/// Returns a single-element argument list holding a `StringArray` with `size`
/// rows, each row being the same mixed-case string containing non-ASCII
/// (Latin and Cyrillic) characters.
fn create_unicode_utf8_args(size: usize) -> Vec<ColumnarValue> {
    // Every row is an identical owned copy of the sample text.
    let rows: Vec<String> = vec![String::from("ñAnDÚ ÁrBOL ОлЕГ ÍslENsku"); size];
    let column: ArrayRef = Arc::new(StringArray::from(rows));
    vec![ColumnarValue::Array(column)]
}

/// Builds the benchmark input for the non-ASCII `Utf8View` case.
///
/// Returns a single-element argument list holding a string-view array with
/// `size` rows, each row being the same mixed-case string containing
/// non-ASCII (Latin and Cyrillic) characters.
fn create_unicode_utf8view_args(size: usize) -> Vec<ColumnarValue> {
    const SAMPLE: &str = "ñAnDÚ ÁrBOL ОлЕГ ÍslENsku";

    let mut view_builder = StringViewBuilder::with_capacity(size);
    (0..size).for_each(|_| view_builder.append_value(SAMPLE));

    let column: ArrayRef = Arc::new(view_builder.finish());
    vec![ColumnarValue::Array(column)]
}

fn criterion_benchmark(c: &mut Criterion) {
let initcap = unicode::initcap();
let config_options = Arc::new(ConfigOptions::default());

// Grouped benchmarks for array sizes - to compare with scalar performance
// Array benchmarks: vary both row count and string length
for size in [1024, 4096, 8192] {
for str_len in [16, 128] {
let mut group =
c.benchmark_group(format!("initcap size={size} str_len={str_len}"));
group.sampling_mode(SamplingMode::Flat);
group.sample_size(10);
group.measurement_time(Duration::from_secs(10));

// Utf8
let array_args = create_args::<i32>(size, str_len, false);
let array_arg_fields = vec![Field::new("arg_0", DataType::Utf8, true).into()];

group.bench_function("array_utf8", |b| {
b.iter(|| {
black_box(initcap.invoke_with_args(ScalarFunctionArgs {
args: array_args.clone(),
arg_fields: array_arg_fields.clone(),
number_rows: size,
return_field: Field::new("f", DataType::Utf8, true).into(),
config_options: Arc::clone(&config_options),
}))
})
});

// Utf8View
let array_view_args = create_args::<i32>(size, str_len, true);
let array_view_arg_fields =
vec![Field::new("arg_0", DataType::Utf8View, true).into()];

group.bench_function("array_utf8view", |b| {
b.iter(|| {
black_box(initcap.invoke_with_args(ScalarFunctionArgs {
args: array_view_args.clone(),
arg_fields: array_view_arg_fields.clone(),
number_rows: size,
return_field: Field::new("f", DataType::Utf8View, true).into(),
config_options: Arc::clone(&config_options),
}))
})
});

group.finish();
}
}

// Unicode array benchmarks
for size in [1024, 4096, 8192] {
let mut group = c.benchmark_group(format!("initcap size={size}"));
let mut group = c.benchmark_group(format!("initcap unicode size={size}"));
group.sampling_mode(SamplingMode::Flat);
group.sample_size(10);
group.measurement_time(Duration::from_secs(10));

// Array benchmark - Utf8
let array_args = create_args::<i32>(size, 16, false);
let array_arg_fields = vec![Field::new("arg_0", DataType::Utf8, true).into()];
let batch_len = size;
let unicode_args = create_unicode_utf8_args(size);
let unicode_arg_fields = vec![Field::new("arg_0", DataType::Utf8, true).into()];

group.bench_function("array_utf8", |b| {
b.iter(|| {
black_box(initcap.invoke_with_args(ScalarFunctionArgs {
args: array_args.clone(),
arg_fields: array_arg_fields.clone(),
number_rows: batch_len,
args: unicode_args.clone(),
arg_fields: unicode_arg_fields.clone(),
number_rows: size,
return_field: Field::new("f", DataType::Utf8, true).into(),
config_options: Arc::clone(&config_options),
}))
})
});

// Array benchmark - Utf8View
let array_view_args = create_args::<i32>(size, 16, true);
let array_view_arg_fields =
let unicode_view_args = create_unicode_utf8view_args(size);
let unicode_view_arg_fields =
vec![Field::new("arg_0", DataType::Utf8View, true).into()];

group.bench_function("array_utf8view", |b| {
b.iter(|| {
black_box(initcap.invoke_with_args(ScalarFunctionArgs {
args: array_view_args.clone(),
arg_fields: array_view_arg_fields.clone(),
number_rows: batch_len,
args: unicode_view_args.clone(),
arg_fields: unicode_view_arg_fields.clone(),
number_rows: size,
return_field: Field::new("f", DataType::Utf8View, true).into(),
config_options: Arc::clone(&config_options),
}))
})
});

// Scalar benchmark - Utf8 (the optimization we added)
group.finish();
}

// Scalar benchmarks: independent of array size, run once
{
let mut group = c.benchmark_group("initcap scalar");
group.sampling_mode(SamplingMode::Flat);
group.sample_size(10);
group.measurement_time(Duration::from_secs(10));

// Utf8
let scalar_args = vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some(
"hello world test string".to_string(),
)))];
Expand All @@ -110,7 +182,7 @@ fn criterion_benchmark(c: &mut Criterion) {
})
});

// Scalar benchmark - Utf8View
// Utf8View
let scalar_view_args = vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
"hello world test string".to_string(),
)))];
Expand Down
174 changes: 163 additions & 11 deletions datafusion/functions/src/unicode/initcap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ use std::any::Any;
use std::sync::Arc;

use arrow::array::{
Array, ArrayRef, GenericStringBuilder, OffsetSizeTrait, StringViewBuilder,
Array, ArrayRef, GenericStringArray, GenericStringBuilder, OffsetSizeTrait,
StringViewBuilder,
};
use arrow::buffer::Buffer;
use arrow::datatypes::DataType;

use crate::utils::{make_scalar_function, utf8_to_str_type};
Expand Down Expand Up @@ -148,8 +150,8 @@ impl ScalarUDFImpl for InitcapFunc {
}
}

/// Converts the first letter of each word to upper case and the rest to lower
/// case. Words are sequences of alphanumeric characters separated by
/// Converts the first letter of each word to uppercase and the rest to
/// lowercase. Words are sequences of alphanumeric characters separated by
/// non-alphanumeric characters.
///
/// Example:
Expand All @@ -159,6 +161,10 @@ impl ScalarUDFImpl for InitcapFunc {
fn initcap<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let string_array = as_generic_string_array::<T>(&args[0])?;

if string_array.value_data().is_ascii() {
return Ok(initcap_ascii_array(string_array));
}

let mut builder = GenericStringBuilder::<T>::with_capacity(
string_array.len(),
string_array.value_data().len(),
Expand All @@ -176,12 +182,58 @@ fn initcap<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(builder.finish()) as ArrayRef)
}

/// ASCII-only fast path for initcap over `Utf8` / `LargeUtf8` arrays.
///
/// The value buffer is rewritten byte-by-byte in one pass: within each value
/// the first byte of every alphanumeric run is uppercased and the remaining
/// bytes of the run are lowercased. ASCII case conversion never changes the
/// byte length, so the input's offsets and null buffer are reused as-is.
///
/// Bytes of the value buffer that lie outside the range spanned by the
/// offsets (as happens for sliced arrays) are copied through untouched so the
/// reused offsets keep pointing at the right positions.
fn initcap_ascii_array<T: OffsetSizeTrait>(
    string_array: &GenericStringArray<T>,
) -> ArrayRef {
    let src = string_array.value_data();
    let offsets = string_array.offsets();
    let visible_start = offsets.first().unwrap().as_usize();
    let visible_end = offsets.last().unwrap().as_usize();

    let mut out = Vec::with_capacity(src.len());

    // Leading bytes not referenced by any offset: copy verbatim.
    out.extend_from_slice(&src[..visible_start]);

    for bounds in offsets.windows(2) {
        let value = &src[bounds[0].as_usize()..bounds[1].as_usize()];

        // Word state restarts at every value boundary.
        let mut in_word = false;
        out.extend(value.iter().map(|&byte| {
            let cased = if in_word {
                byte.to_ascii_lowercase()
            } else {
                byte.to_ascii_uppercase()
            };
            in_word = byte.is_ascii_alphanumeric();
            cased
        }));
    }

    // Trailing bytes not referenced by any offset: copy verbatim.
    out.extend_from_slice(&src[visible_end..]);

    let nulls = string_array.nulls().cloned();
    let values = Buffer::from_vec(out);
    // SAFETY: ASCII case conversion preserves byte length, so the original
    // offsets and nulls remain valid for the new value buffer.
    let converted = unsafe {
        GenericStringArray::<T>::new_unchecked(offsets.clone(), values, nulls)
    };
    Arc::new(converted)
}

fn initcap_utf8view(args: &[ArrayRef]) -> Result<ArrayRef> {
let string_view_array = as_string_view_array(&args[0])?;

let mut builder = StringViewBuilder::with_capacity(string_view_array.len());

let mut container = String::new();

string_view_array.iter().for_each(|str| match str {
Some(s) => {
initcap_string(s, &mut container);
Expand All @@ -198,13 +250,16 @@ fn initcap_string(input: &str, container: &mut String) {
let mut prev_is_alphanumeric = false;

if input.is_ascii() {
for c in input.chars() {
container.reserve(input.len());
// SAFETY: each byte is ASCII, so the result is valid UTF-8.
let out = unsafe { container.as_mut_vec() };
for &b in input.as_bytes() {
if prev_is_alphanumeric {
container.push(c.to_ascii_lowercase());
out.push(b.to_ascii_lowercase());
} else {
container.push(c.to_ascii_uppercase());
};
prev_is_alphanumeric = c.is_ascii_alphanumeric();
out.push(b.to_ascii_uppercase());
}
prev_is_alphanumeric = b.is_ascii_alphanumeric();
}
} else {
for c in input.chars() {
Expand All @@ -222,10 +277,11 @@ fn initcap_string(input: &str, container: &mut String) {
mod tests {
use crate::unicode::initcap::InitcapFunc;
use crate::utils::test::test_function;
use arrow::array::{Array, StringArray, StringViewArray};
use arrow::array::{Array, ArrayRef, LargeStringArray, StringArray, StringViewArray};
use arrow::datatypes::DataType::{Utf8, Utf8View};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
use std::sync::Arc;

#[test]
fn test_functions() -> Result<()> {
Expand Down Expand Up @@ -329,4 +385,100 @@ mod tests {

Ok(())
}

#[test]
fn test_initcap_ascii_array() -> Result<()> {
    // Covers plain words, a null, punctuation separators, the empty string,
    // digits adjacent to letters, all-caps input and already-lowercase input.
    let input: ArrayRef = Arc::new(StringArray::from(vec![
        Some("hello world"),
        None,
        Some("foo-bar_baz/baX"),
        Some(""),
        Some("123 abc 456DEF"),
        Some("ALL CAPS"),
        Some("already correct"),
    ]));
    let expected = [
        Some("Hello World"),
        None,
        Some("Foo-Bar_Baz/Bax"),
        Some(""),
        Some("123 Abc 456def"),
        Some("All Caps"),
        Some("Already Correct"),
    ];

    let output = super::initcap::<i32>(&[input])?;
    let output = output.as_any().downcast_ref::<StringArray>().unwrap();

    assert_eq!(output.len(), 7);
    for (row, want) in expected.iter().enumerate() {
        match want {
            Some(text) => assert_eq!(output.value(row), *text),
            None => assert!(output.is_null(row)),
        }
    }
    Ok(())
}

#[test]
fn test_initcap_ascii_large_array() -> Result<()> {
    // Same cases as the `Utf8` test, exercised through 64-bit offsets.
    let input: ArrayRef = Arc::new(LargeStringArray::from(vec![
        Some("hello world"),
        None,
        Some("foo-bar_baz/baX"),
        Some(""),
        Some("123 abc 456DEF"),
        Some("ALL CAPS"),
        Some("already correct"),
    ]));
    let expected = [
        Some("Hello World"),
        None,
        Some("Foo-Bar_Baz/Bax"),
        Some(""),
        Some("123 Abc 456def"),
        Some("All Caps"),
        Some("Already Correct"),
    ];

    let output = super::initcap::<i64>(&[input])?;
    let output = output
        .as_any()
        .downcast_ref::<LargeStringArray>()
        .unwrap();

    assert_eq!(output.len(), 7);
    for (row, want) in expected.iter().enumerate() {
        match want {
            Some(text) => assert_eq!(output.value(row), *text),
            None => assert!(output.is_null(row)),
        }
    }
    Ok(())
}

/// Checks initcap on an ASCII `StringArray` that has been sliced so its
/// offsets no longer start at zero.
#[test]
fn test_initcap_sliced_ascii_array() -> Result<()> {
    let full = StringArray::from(vec![
        Some("hello world"),
        Some("foo bar"),
        Some("baz qux"),
    ]);
    // Keep only the last two elements. The slice's offsets become
    // [11, 18, 25] (non-zero start) while value_data still exposes the
    // complete original buffer.
    let tail: ArrayRef = Arc::new(full.slice(1, 2));

    let output = super::initcap::<i32>(&[tail])?;
    let output = output.as_any().downcast_ref::<StringArray>().unwrap();

    assert_eq!(output.len(), 2);
    for (row, want) in ["Foo Bar", "Baz Qux"].into_iter().enumerate() {
        assert_eq!(output.value(row), want);
    }
    Ok(())
}

/// Checks initcap on an ASCII `LargeStringArray` that has been sliced so its
/// offsets no longer start at zero.
#[test]
fn test_initcap_sliced_ascii_large_array() -> Result<()> {
    let full = LargeStringArray::from(vec![
        Some("hello world"),
        Some("foo bar"),
        Some("baz qux"),
    ]);
    // Keep only the last two elements. The slice's offsets become
    // [11, 18, 25] (non-zero start) while value_data still exposes the
    // complete original buffer.
    let tail: ArrayRef = Arc::new(full.slice(1, 2));

    let output = super::initcap::<i64>(&[tail])?;
    let output = output
        .as_any()
        .downcast_ref::<LargeStringArray>()
        .unwrap();

    assert_eq!(output.len(), 2);
    for (row, want) in ["Foo Bar", "Baz Qux"].into_iter().enumerate() {
        assert_eq!(output.value(row), want);
    }
    Ok(())
}
}