diff --git a/datafusion/functions/src/regex/regexpreplace.rs b/datafusion/functions/src/regex/regexpreplace.rs index 378b6ced076c..9f4d0374e125 100644 --- a/datafusion/functions/src/regex/regexpreplace.rs +++ b/datafusion/functions/src/regex/regexpreplace.rs @@ -17,11 +17,14 @@ //! Regx expressions use arrow::array::new_null_array; +use arrow::array::ArrayAccessor; use arrow::array::ArrayDataBuilder; use arrow::array::BufferBuilder; use arrow::array::GenericStringArray; +use arrow::array::StringViewBuilder; use arrow::array::{Array, ArrayRef, OffsetSizeTrait}; use arrow::datatypes::DataType; +use datafusion_common::cast::as_string_view_array; use datafusion_common::exec_err; use datafusion_common::plan_err; use datafusion_common::ScalarValue; @@ -54,6 +57,7 @@ impl RegexpReplaceFunc { signature: Signature::one_of( vec![ Exact(vec![Utf8, Utf8, Utf8]), + Exact(vec![Utf8View, Utf8, Utf8]), Exact(vec![Utf8, Utf8, Utf8, Utf8]), ], Volatility::Immutable, @@ -80,6 +84,7 @@ impl ScalarUDFImpl for RegexpReplaceFunc { Ok(match &arg_types[0] { LargeUtf8 | LargeBinary => LargeUtf8, Utf8 | Binary => Utf8, + Utf8View | BinaryView => Utf8View, Null => Null, Dictionary(_, t) => match **t { LargeUtf8 | LargeBinary => LargeUtf8, @@ -118,15 +123,18 @@ impl ScalarUDFImpl for RegexpReplaceFunc { } } } + fn regexp_replace_func(args: &[ColumnarValue]) -> Result { match args[0].data_type() { DataType::Utf8 => specialize_regexp_replace::(args), DataType::LargeUtf8 => specialize_regexp_replace::(args), + DataType::Utf8View => specialize_regexp_replace::(args), other => { internal_err!("Unsupported data type {other:?} for function regexp_replace") } } } + /// replace POSIX capture groups (like \1) with Rust Regex group (like ${1}) /// used by regexp_replace fn regex_replace_posix_groups(replacement: &str) -> String { @@ -280,8 +288,8 @@ pub fn regexp_replace(args: &[ArrayRef]) -> Result } } -fn _regexp_replace_early_abort( - input_array: &GenericStringArray, +fn _regexp_replace_early_abort( + input_array: T, sz: usize, ) -> Result { // Mimicking the existing behavior of regexp_replace, if any of the scalar arguments @@ -290,13 +298,14 @@ fn _regexp_replace_early_abort( // Also acts like an early abort mechanism when the input array is empty. Ok(new_null_array(input_array.data_type(), sz)) } + /// Get the first argument from the given string array. /// /// Note: If the array is empty or the first argument is null, /// then calls the given early abort function. macro_rules! fetch_string_arg { ($ARG:expr, $NAME:expr, $T:ident, $EARLY_ABORT:ident, $ARRAY_SIZE:expr) => {{ - let array = as_generic_string_array::($ARG)?; + let array = as_generic_string_array::<$T>($ARG)?; if array.len() == 0 || array.is_null(0) { return $EARLY_ABORT(array, $ARRAY_SIZE); } else { @@ -313,25 +322,24 @@ macro_rules! fetch_string_arg { fn _regexp_replace_static_pattern_replace( args: &[ArrayRef], ) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let array_size = string_array.len(); + let array_size = args[0].len(); let pattern = fetch_string_arg!( &args[1], "pattern", - T, + i32, _regexp_replace_early_abort, array_size ); let replacement = fetch_string_arg!( &args[2], "replacement", - T, + i32, _regexp_replace_early_abort, array_size ); let flags = match args.len() { 3 => None, - 4 => Some(fetch_string_arg!(&args[3], "flags", T, _regexp_replace_early_abort, array_size)), + 4 => Some(fetch_string_arg!(&args[3], "flags", i32, _regexp_replace_early_abort, array_size)), other => { return exec_err!( "regexp_replace was called with {other} arguments. It requires at least 3 and at most 4." @@ -358,32 +366,61 @@ fn _regexp_replace_static_pattern_replace( // with rust ones. let replacement = regex_replace_posix_groups(replacement); - // We are going to create the underlying string buffer from its parts - // to be able to re-use the existing null buffer for sparse arrays. - let mut vals = BufferBuilder::::new({ - let offsets = string_array.value_offsets(); - (offsets[string_array.len()] - offsets[0]) - .to_usize() - .expect("Failed to convert usize") - }); - let mut new_offsets = BufferBuilder::::new(string_array.len() + 1); - new_offsets.append(T::zero()); - - string_array.iter().for_each(|val| { - if let Some(val) = val { - let result = re.replacen(val, limit, replacement.as_str()); - vals.append_slice(result.as_bytes()); + let string_array_type = args[0].data_type(); + match string_array_type { + DataType::Utf8 | DataType::LargeUtf8 => { + let string_array = as_generic_string_array::(&args[0])?; + + // We are going to create the underlying string buffer from its parts + // to be able to re-use the existing null buffer for sparse arrays. + let mut vals = BufferBuilder::::new({ + let offsets = string_array.value_offsets(); + (offsets[string_array.len()] - offsets[0]) + .to_usize() + .unwrap() + }); + let mut new_offsets = BufferBuilder::::new(string_array.len() + 1); + new_offsets.append(T::zero()); + + string_array.iter().for_each(|val| { + if let Some(val) = val { + let result = re.replacen(val, limit, replacement.as_str()); + vals.append_slice(result.as_bytes()); + } + new_offsets.append(T::from_usize(vals.len()).unwrap()); + }); + + let data = ArrayDataBuilder::new(GenericStringArray::::DATA_TYPE) + .len(string_array.len()) + .nulls(string_array.nulls().cloned()) + .buffers(vec![new_offsets.finish(), vals.finish()]) + .build()?; + let result_array = GenericStringArray::::from(data); + Ok(Arc::new(result_array) as ArrayRef) } - new_offsets.append(T::from_usize(vals.len()).unwrap()); - }); - - let data = ArrayDataBuilder::new(GenericStringArray::::DATA_TYPE) - .len(string_array.len()) - .nulls(string_array.nulls().cloned()) - .buffers(vec![new_offsets.finish(), vals.finish()]) - .build()?; - let result_array = GenericStringArray::::from(data); - Ok(Arc::new(result_array) as ArrayRef) + DataType::Utf8View => { + let string_view_array = as_string_view_array(&args[0])?; + + let mut builder = StringViewBuilder::with_capacity(string_view_array.len()) + .with_block_size(1024 * 1024 * 2); + + for val in string_view_array.iter() { + if let Some(val) = val { + let result = re.replacen(val, limit, replacement.as_str()); + builder.append_value(result); + } else { + builder.append_null(); + } + } + + let result = builder.finish(); + Ok(Arc::new(result) as ArrayRef) + } + _ => unreachable!( + "Invalid data type for regexp_replace: {}", + string_array_type + ), + } } /// Determine which implementation of the regexp_replace to use based @@ -469,43 +506,91 @@ mod tests { use super::*; - #[test] - fn test_static_pattern_regexp_replace() { - let values = StringArray::from(vec!["abc"; 5]); - let patterns = StringArray::from(vec!["b"; 5]); - let replacements = StringArray::from(vec!["foo"; 5]); - let expected = StringArray::from(vec!["afooc"; 5]); - - let re = _regexp_replace_static_pattern_replace::(&[ - Arc::new(values), - Arc::new(patterns), - Arc::new(replacements), - ]) - .unwrap(); - - assert_eq!(re.as_ref(), &expected); + macro_rules! static_pattern_regexp_replace { + ($name:ident, $T:ty, $O:ty) => { + #[test] + fn $name() { + let values = vec!["abc", "acd", "abcd1234567890123", "123456789012abc"]; + let patterns = vec!["b"; 4]; + let replacement = vec!["foo"; 4]; + let expected = + vec!["afooc", "acd", "afoocd1234567890123", "123456789012afooc"]; + + let values = <$T>::from(values); + let patterns = StringArray::from(patterns); + let replacements = StringArray::from(replacement); + let expected = <$T>::from(expected); + + let re = _regexp_replace_static_pattern_replace::<$O>(&[ + Arc::new(values), + Arc::new(patterns), + Arc::new(replacements), + ]) + .unwrap(); + + assert_eq!(re.as_ref(), &expected); + } + }; } - #[test] - fn test_static_pattern_regexp_replace_with_flags() { - let values = StringArray::from(vec!["abc", "ABC", "aBc", "AbC", "aBC"]); - let patterns = StringArray::from(vec!["b"; 5]); - let replacements = StringArray::from(vec!["foo"; 5]); - let flags = StringArray::from(vec!["i"; 5]); - let expected = - StringArray::from(vec!["afooc", "AfooC", "afooc", "AfooC", "afooC"]); - - let re = _regexp_replace_static_pattern_replace::(&[ - Arc::new(values), - Arc::new(patterns), - Arc::new(replacements), - Arc::new(flags), - ]) - .unwrap(); - - assert_eq!(re.as_ref(), &expected); + static_pattern_regexp_replace!(string_array, StringArray, i32); + static_pattern_regexp_replace!(string_view_array, StringViewArray, i32); + static_pattern_regexp_replace!(large_string_array, LargeStringArray, i64); + + macro_rules! static_pattern_regexp_replace_with_flags { + ($name:ident, $T:ty, $O: ty) => { + #[test] + fn $name() { + let values = vec![ + "abc", + "aBc", + "acd", + "abcd1234567890123", + "aBcd1234567890123", + "123456789012abc", + "123456789012aBc", + ]; + let expected = vec![ + "afooc", + "afooc", + "acd", + "afoocd1234567890123", + "afoocd1234567890123", + "123456789012afooc", + "123456789012afooc", + ]; + + let values = <$T>::from(values); + let patterns = StringArray::from(vec!["b"; 7]); + let replacements = StringArray::from(vec!["foo"; 7]); + let flags = StringArray::from(vec!["i"; 5]); + let expected = <$T>::from(expected); + + let re = _regexp_replace_static_pattern_replace::<$O>(&[ + Arc::new(values), + Arc::new(patterns), + Arc::new(replacements), + Arc::new(flags), + ]) + .unwrap(); + + assert_eq!(re.as_ref(), &expected); + } + }; } + static_pattern_regexp_replace_with_flags!(string_array_with_flags, StringArray, i32); + static_pattern_regexp_replace_with_flags!( + string_view_array_with_flags, + StringViewArray, + i32 + ); + static_pattern_regexp_replace_with_flags!( + large_string_array_with_flags, + LargeStringArray, + i64 + ); + #[test] fn test_static_pattern_regexp_replace_early_abort() { let values = StringArray::from(vec!["abc"; 5]);