diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md index 1767137b09..3ed8be8515 100644 --- a/docs/source/user-guide/latest/configs.md +++ b/docs/source/user-guide/latest/configs.md @@ -307,6 +307,7 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.StringRepeat.enabled` | Enable Comet acceleration for `StringRepeat` | true | | `spark.comet.expression.StringReplace.enabled` | Enable Comet acceleration for `StringReplace` | true | | `spark.comet.expression.StringSpace.enabled` | Enable Comet acceleration for `StringSpace` | true | +| `spark.comet.expression.StringSplit.enabled` | Enable Comet acceleration for `StringSplit` | true | | `spark.comet.expression.StringTranslate.enabled` | Enable Comet acceleration for `StringTranslate` | true | | `spark.comet.expression.StringTrim.enabled` | Enable Comet acceleration for `StringTrim` | true | | `spark.comet.expression.StringTrimBoth.enabled` | Enable Comet acceleration for `StringTrimBoth` | true | diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala index 2e29cb930b..a510584b9b 100644 --- a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala +++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala @@ -184,6 +184,11 @@ object Meta { FunctionSignature(Seq(SparkStringType, SparkIntegralType)), FunctionSignature(Seq(SparkStringType, SparkIntegralType, SparkStringType)))), createUnaryStringFunction("rtrim"), + createFunctions( + "split", + Seq( + FunctionSignature(Seq(SparkStringType, SparkStringType)), + FunctionSignature(Seq(SparkStringType, SparkStringType, SparkIntType)))), createFunctionWithInputTypes("starts_with", Seq(SparkStringType, SparkStringType)), createFunctionWithInputTypes("string_space", Seq(SparkIntType)), createFunctionWithInputTypes("substring", Seq(SparkStringType, SparkIntType, SparkIntType)), diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 021bb1c78f..1844942630 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -185,6 +185,10 @@ pub fn create_comet_physical_fun_with_eval_mode( let func = Arc::new(abs); make_comet_scalar_udf!("abs", func, without data_type) } + "split" => { + let func = Arc::new(crate::string_funcs::spark_split); + make_comet_scalar_udf!("split", func, without data_type) + } _ => registry.udf(fun_name).map_err(|e| { DataFusionError::Execution(format!( "Function {fun_name} not found in the registry: {e}", diff --git a/native/spark-expr/src/string_funcs/mod.rs b/native/spark-expr/src/string_funcs/mod.rs index aac8204e29..ae00349ba1 100644 --- a/native/spark-expr/src/string_funcs/mod.rs +++ b/native/spark-expr/src/string_funcs/mod.rs @@ -15,8 +15,10 @@ // specific language governing permissions and limitations // under the License. +mod split; mod string_space; mod substring; +pub use split::spark_split; pub use string_space::SparkStringSpace; pub use substring::SubstringExpr; diff --git a/native/spark-expr/src/string_funcs/split.rs b/native/spark-expr/src/string_funcs/split.rs new file mode 100644 index 0000000000..f3c2c33782 --- /dev/null +++ b/native/spark-expr/src/string_funcs/split.rs @@ -0,0 +1,314 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
+// See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{Array, ArrayRef, GenericStringArray, ListArray};
+use arrow::datatypes::{DataType, Field};
+use datafusion::common::{
+    cast::as_generic_string_array, exec_err, DataFusionError, Result as DataFusionResult,
+    ScalarValue,
+};
+use datafusion::logical_expr::ColumnarValue;
+use regex::Regex;
+use std::sync::Arc;
+
+/// Spark-compatible split function.
+/// Splits a string around matches of a regex pattern, with an optional limit.
+///
+/// Arguments:
+/// - string: The string to split
+/// - pattern: The regex pattern to split on
+/// - limit (optional): Controls the number of splits
+///   - limit > 0: At most limit-1 splits, array length <= limit
+///   - limit = 0: As many splits as possible, trailing empty strings removed
+///   - limit < 0: As many splits as possible, trailing empty strings kept
+pub fn spark_split(args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
+    if args.len() < 2 || args.len() > 3 {
+        return exec_err!(
+            "split expects 2 or 3 arguments (string, pattern, [limit]), got {}",
+            args.len()
+        );
+    }
+
+    // Get limit parameter (default to -1 if not provided)
+    let limit = if args.len() == 3 {
+        match &args[2] {
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(l))) => *l,
+            ColumnarValue::Scalar(ScalarValue::Int32(None)) => {
+                // NULL limit, return NULL
+                return Ok(ColumnarValue::Scalar(ScalarValue::Null));
+            }
+            _ => {
+                return exec_err!("split limit argument must be an Int32 scalar");
+            }
+        }
+    } else {
+        -1
+    };
+
+    match (&args[0], &args[1]) {
+        (ColumnarValue::Array(string_array), ColumnarValue::Scalar(ScalarValue::Utf8(pattern)))
+        | (
+            ColumnarValue::Array(string_array),
+            ColumnarValue::Scalar(ScalarValue::LargeUtf8(pattern)),
+        ) => {
+            if pattern.is_none() {
+                // NULL pattern returns NULL
+                let null_array = new_null_list_array(string_array.len());
+                return Ok(ColumnarValue::Array(null_array));
+            }
+
+            let pattern_str = pattern.as_ref().unwrap();
+            split_array(string_array.as_ref(), pattern_str, limit)
+        }
+        (ColumnarValue::Scalar(ScalarValue::Utf8(string)), ColumnarValue::Scalar(pattern_val))
+        | (
+            ColumnarValue::Scalar(ScalarValue::LargeUtf8(string)),
+            ColumnarValue::Scalar(pattern_val),
+        ) => {
+            if string.is_none() {
+                return Ok(ColumnarValue::Scalar(ScalarValue::Null));
+            }
+
+            let pattern_str = match pattern_val {
+                ScalarValue::Utf8(Some(p)) | ScalarValue::LargeUtf8(Some(p)) => p,
+                ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) => {
+                    return Ok(ColumnarValue::Scalar(ScalarValue::Null));
+                }
+                _ => {
+                    return exec_err!("split pattern must be a string");
+                }
+            };
+
+            let result = split_string(string.as_ref().unwrap(), pattern_str, limit)?;
+            let string_array = GenericStringArray::<i32>::from(result);
+            let list_array = create_list_array(Arc::new(string_array));
+
+            Ok(ColumnarValue::Scalar(ScalarValue::List(Arc::new(
+                list_array,
+            ))))
+        }
+        _ => exec_err!("split expects (array, scalar) or (scalar, scalar) arguments"),
+    }
+}
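+
+// Worked example of the limit semantics documented above, for input "a,b,c,,"
+// and pattern "," (cf. the unit tests at the bottom of this file):
+//   limit =  3  -> ["a", "b", "c,,"]        (at most limit-1 = 2 splits)
+//   limit =  0  -> ["a", "b", "c"]          (trailing empty strings removed)
+//   limit = -1  -> ["a", "b", "c", "", ""]  (trailing empty strings kept)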
exec_err!("split expects (array, scalar) or (scalar, scalar) arguments"), + } +} + +fn split_array( + string_array: &dyn arrow::array::Array, + pattern: &str, + limit: i32, +) -> DataFusionResult { + // Compile regex once for the entire array + let regex = Regex::new(pattern).map_err(|e| { + DataFusionError::Execution(format!("Invalid regex pattern '{}': {}", pattern, e)) + })?; + + let string_array = match string_array.data_type() { + DataType::Utf8 => as_generic_string_array::(string_array)?, + DataType::LargeUtf8 => { + // Convert LargeUtf8 to Utf8 for processing + let large_array = as_generic_string_array::(string_array)?; + return split_large_string_array(&large_array, ®ex, limit); + } + _ => { + return exec_err!( + "split expects Utf8 or LargeUtf8 string array, got {:?}", + string_array.data_type() + ); + } + }; + + // Build the result ListArray + let mut offsets: Vec = Vec::with_capacity(string_array.len() + 1); + let mut values: Vec = Vec::new(); + offsets.push(0); + + for i in 0..string_array.len() { + if string_array.is_null(i) { + // NULL input produces empty array element (maintain position) + offsets.push(offsets[i]); + } else { + let string_val = string_array.value(i); + let parts = split_with_regex(string_val, ®ex, limit); + values.extend(parts); + offsets.push(values.len() as i32); + } + } + + let values_array = Arc::new(GenericStringArray::::from(values)) as ArrayRef; + let field = Arc::new(Field::new("item", DataType::Utf8, false)); + let list_array = ListArray::new( + field, + arrow::buffer::OffsetBuffer::new(offsets.into()), + values_array, + None, // No nulls at list level + ); + + Ok(ColumnarValue::Array(Arc::new(list_array))) +} + +fn split_large_string_array( + string_array: &GenericStringArray, + regex: &Regex, + limit: i32, +) -> DataFusionResult { + let mut offsets: Vec = Vec::with_capacity(string_array.len() + 1); + let mut values: Vec = Vec::new(); + offsets.push(0); + + for i in 0..string_array.len() { + if string_array.is_null(i) { + offsets.push(offsets[i]); + } else { + let string_val = string_array.value(i); + let parts = split_with_regex(string_val, regex, limit); + values.extend(parts); + offsets.push(values.len() as i32); + } + } + + let values_array = Arc::new(GenericStringArray::::from(values)) as ArrayRef; + let field = Arc::new(Field::new("item", DataType::Utf8, false)); + let list_array = ListArray::new( + field, + arrow::buffer::OffsetBuffer::new(offsets.into()), + values_array, + None, + ); + + Ok(ColumnarValue::Array(Arc::new(list_array))) +} + +fn split_string(string: &str, pattern: &str, limit: i32) -> DataFusionResult> { + let regex = Regex::new(pattern).map_err(|e| { + DataFusionError::Execution(format!("Invalid regex pattern '{}': {}", pattern, e)) + })?; + + Ok(split_with_regex(string, ®ex, limit)) +} + +fn split_with_regex(string: &str, regex: &Regex, limit: i32) -> Vec { + if limit == 0 { + // limit = 0: split as many times as possible, discard trailing empty strings + let mut parts: Vec = regex.split(string).map(|s| s.to_string()).collect(); + // Remove trailing empty strings + while parts.last().map_or(false, |s| s.is_empty()) { + parts.pop(); + } + if parts.is_empty() { + vec!["".to_string()] + } else { + parts + } + } else if limit > 0 { + // limit > 0: at most limit-1 splits (array length <= limit) + let mut parts: Vec = Vec::new(); + let mut last_end = 0; + let mut count = 0; + + for mat in regex.find_iter(string) { + if count >= limit - 1 { + break; + } + parts.push(string[last_end..mat.start()].to_string()); + last_end = 
+
+fn split_with_regex(string: &str, regex: &Regex, limit: i32) -> Vec<String> {
+    if limit == 0 {
+        // limit = 0: split as many times as possible, discard trailing empty strings
+        let mut parts: Vec<String> = regex.split(string).map(|s| s.to_string()).collect();
+        // Remove trailing empty strings
+        while parts.last().is_some_and(|s| s.is_empty()) {
+            parts.pop();
+        }
+        if parts.is_empty() && string.is_empty() {
+            // Match Java: "".split(p) is [""], while a non-empty input made up
+            // entirely of separators collapses to an empty array
+            vec!["".to_string()]
+        } else {
+            parts
+        }
+    } else if limit > 0 {
+        // limit > 0: at most limit-1 splits (array length <= limit)
+        let mut parts: Vec<String> = Vec::new();
+        let mut last_end = 0;
+        let mut count = 0;
+
+        for mat in regex.find_iter(string) {
+            if count >= limit - 1 {
+                break;
+            }
+            parts.push(string[last_end..mat.start()].to_string());
+            last_end = mat.end();
+            count += 1;
+        }
+        // Add the remaining string
+        parts.push(string[last_end..].to_string());
+        parts
+    } else {
+        // limit < 0: split as many times as possible, keep trailing empty strings
+        regex.split(string).map(|s| s.to_string()).collect()
+    }
+}
+
+fn create_list_array(values: ArrayRef) -> ListArray {
+    let field = Arc::new(Field::new("item", DataType::Utf8, false));
+    let offsets = vec![0i32, values.len() as i32];
+    ListArray::new(
+        field,
+        arrow::buffer::OffsetBuffer::new(offsets.into()),
+        values,
+        None,
+    )
+}
+
+fn new_null_list_array(len: usize) -> ArrayRef {
+    let field = Arc::new(Field::new("item", DataType::Utf8, false));
+    let values = Arc::new(GenericStringArray::<i32>::from(Vec::<String>::new())) as ArrayRef;
+    let offsets = vec![0i32; len + 1];
+    let nulls = arrow::buffer::NullBuffer::new_null(len);
+
+    Arc::new(ListArray::new(
+        field,
+        arrow::buffer::OffsetBuffer::new(offsets.into()),
+        values,
+        Some(nulls),
+    ))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::StringArray;
+
+    #[test]
+    fn test_split_basic() {
+        let string_array = Arc::new(StringArray::from(vec!["a,b,c", "x,y,z"])) as ArrayRef;
+        let pattern = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string())));
+        let args = vec![ColumnarValue::Array(string_array), pattern];
+
+        let result = spark_split(&args).unwrap();
+        // Should produce [["a", "b", "c"], ["x", "y", "z"]]
+        assert!(matches!(result, ColumnarValue::Array(_)));
+    }
+
+    #[test]
+    fn test_split_with_limit() {
+        let string_array = Arc::new(StringArray::from(vec!["a,b,c,d"])) as ArrayRef;
+        let pattern = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string())));
+        let limit = ColumnarValue::Scalar(ScalarValue::Int32(Some(2)));
+        let args = vec![ColumnarValue::Array(string_array), pattern, limit];
+
+        let result = spark_split(&args).unwrap();
+        // Should produce [["a", "b,c,d"]]
+        assert!(matches!(result, ColumnarValue::Array(_)));
+    }
+
+    #[test]
+    fn test_split_regex() {
+        let parts = split_string("foo123bar456baz", r"\d+", -1).unwrap();
+        assert_eq!(parts, vec!["foo", "bar", "baz"]);
+    }
+
+    #[test]
+    fn test_split_limit_positive() {
+        let parts = split_string("a,b,c,d,e", ",", 3).unwrap();
+        assert_eq!(parts, vec!["a", "b", "c,d,e"]);
+    }
+
+    #[test]
+    fn test_split_limit_zero() {
+        let parts = split_string("a,b,c,,", ",", 0).unwrap();
+        assert_eq!(parts, vec!["a", "b", "c"]);
+    }
+
+    #[test]
+    fn test_split_limit_negative() {
+        let parts = split_string("a,b,c,,", ",", -1).unwrap();
+        assert_eq!(parts, vec!["a", "b", "c", "", ""]);
+    }
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 6bf3776a23..c1c8311ba3 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -162,6 +162,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[StringRPad] -> CometStringRPad,
     classOf[StringLPad] -> CometStringLPad,
     classOf[StringSpace] -> CometScalarFunction("string_space"),
+    classOf[StringSplit] -> CometScalarFunction("split"),
     classOf[StringTranslate] -> CometScalarFunction("translate"),
     classOf[StringTrim] -> CometScalarFunction("trim"),
     classOf[StringTrimBoth] -> CometScalarFunction("btrim"),
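
Both the two-argument and three-argument SQL forms of split parse to Spark's
StringSplit (the two-argument form supplies a default limit of -1), so the
single mapping above covers both, matching the two fuzz signatures registered
in Meta.scala. A quick end-to-end check (a sketch, assuming a spark-shell
session bound to `spark` and a table `tbl` with a string column `_1`):

    spark.sql("EXPLAIN SELECT split(_1, ',') FROM tbl").show(false)
    spark.sql("EXPLAIN SELECT split(_1, ',', 2) FROM tbl").show(false)

With spark.comet.expression.StringSplit.enabled at its default of true, both
plans should show the projection executing natively in Comet.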
diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
index f9882780c8..ae5e5a73d6 100644
--- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
@@ -148,6 +148,126 @@ class CometStringExpressionSuite extends CometTestBase {
     }
   }
 
+  test("split string basic") {
+    // Basic split tests with 2 arguments (no limit)
+    withParquetTable((0 until 5).map(i => (s"value$i,test$i", i)), "tbl") {
+      checkSparkAnswerAndOperator("SELECT split(_1, ',') FROM tbl")
+      checkSparkAnswerAndOperator("SELECT split('one,two,three', ',') FROM tbl")
+      checkSparkAnswerAndOperator("SELECT split(_1, '-') FROM tbl")
+    }
+  }
+
+  test("split string with limit") {
+    // Split tests with 3 arguments (with limit)
+    withParquetTable((0 until 5).map(i => ("a,b,c,d,e", i)), "tbl") {
+      checkSparkAnswerAndOperator("SELECT split(_1, ',', 2) FROM tbl")
+      checkSparkAnswerAndOperator("SELECT split(_1, ',', 3) FROM tbl")
+      checkSparkAnswerAndOperator("SELECT split(_1, ',', -1) FROM tbl")
+      checkSparkAnswerAndOperator("SELECT split(_1, ',', 0) FROM tbl")
+    }
+  }
+
+  test("split string with regex patterns") {
+    // Test with various regex patterns
+    withParquetTable((0 until 5).map(i => ("word1 word2 word3", i)), "tbl") {
+      checkSparkAnswerAndOperator("SELECT split(_1, ' ') FROM tbl")
+      checkSparkAnswerAndOperator("SELECT split(_1, '\\\\s+') FROM tbl")
+    }
+
+    withParquetTable((0 until 5).map(i => ("foo123bar456baz", i)), "tbl2") {
+      checkSparkAnswerAndOperator("SELECT split(_1, '\\\\d+') FROM tbl2")
+    }
+  }
+
+  test("split string edge cases") {
+    // Test edge cases: empty strings, nulls, single character
+    withParquetTable(Seq(("", 0), ("single", 1), (null, 2), ("a", 3)), "tbl") {
+      checkSparkAnswerAndOperator("SELECT split(_1, ',') FROM tbl")
+    }
+  }
+
+  test("split string with UTF-8 characters") {
+    // Test with multi-byte UTF-8 characters to verify regex engine compatibility
+    // between Java (Spark) and Rust (Comet)
+
+    // CJK characters
+    withParquetTable(Seq(("你好,世界", 0), ("こんにちは,世界", 1)), "tbl_cjk") {
+      checkSparkAnswerAndOperator("SELECT split(_1, ',') FROM tbl_cjk")
+    }
+
+    // Emoji and symbols
+    withParquetTable(Seq(("😀,😃,😄", 0), ("🔥,💧,🌍", 1), ("α,β,γ", 2)), "tbl_emoji") {
+      checkSparkAnswerAndOperator("SELECT split(_1, ',') FROM tbl_emoji")
+    }
+
+    // Combining characters / grapheme clusters:
+    // "é" as a combining sequence (e + combining acute accent)
+    // vs "é" as a single precomposed character
+    withParquetTable(
+      Seq(
+        ("café,naïve", 0), // precomposed
+        ("café,naïve", 1), // combining form (if your editor preserves it)
+        ("मानक,हिन्दी", 2)), // Devanagari script
+      "tbl_graphemes") {
+      checkSparkAnswerAndOperator("SELECT split(_1, ',') FROM tbl_graphemes")
+    }
+
+    // Mixed ASCII and multi-byte with regex patterns
+    withParquetTable(
+      Seq(
+        ("hello世界test你好", 0),
+        ("foo😀bar😃baz", 1),
+        ("abc한글def", 2)), // Korean Hangul
+      "tbl_mixed") {
+      // Split on runs of ASCII lowercase letters
+      checkSparkAnswerAndOperator("SELECT split(_1, '[a-z]+') FROM tbl_mixed")
+    }
+
+    // RTL (Right-to-Left) characters
+    withParquetTable(Seq(("مرحبا,عالم", 0), ("שלום,עולם", 1)), "tbl_rtl") { // Arabic, Hebrew
+      checkSparkAnswerAndOperator("SELECT split(_1, ',') FROM tbl_rtl")
+    }
+
+    // Zero-width characters and special Unicode
+    withParquetTable(
+      Seq(
+        ("test\u200Bword", 0), // Zero-width space
+        ("foo\u00ADbar", 1)), // Soft hyphen
+      "tbl_special") {
+      checkSparkAnswerAndOperator("SELECT split(_1, '\u200B') FROM tbl_special")
+    }
+
+    // Surrogate pairs (4-byte UTF-8)
+    withParquetTable(
+      Seq(
+        ("𝐇𝐞𝐥𝐥𝐨,𝐖𝐨𝐫𝐥𝐝", 0), // Mathematical bold letters (U+1D400 range)
+        ("𠜎,𠜱,𠝹", 1)), // CJK Extension B
+      "tbl_surrogate") {
+      checkSparkAnswerAndOperator("SELECT split(_1, ',') FROM tbl_surrogate")
+    }
+  }
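+
+  // A possible additional case (sketch, not exercised above): a multi-byte
+  // delimiter stresses pattern handling rather than input handling
+  test("split string with multi-byte delimiter") {
+    withParquetTable(Seq(("a界b界c", 0), ("x😀y😀z", 1)), "tbl_mb_delim") {
+      checkSparkAnswerAndOperator("SELECT split(_1, '界') FROM tbl_mb_delim")
+      checkSparkAnswerAndOperator("SELECT split(_1, '😀') FROM tbl_mb_delim")
+    }
+  }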
+
+  test("split string with UTF-8 regex patterns") {
+    // Test regex patterns that involve UTF-8 characters
+
+    // Split on Unicode character classes
+    withParquetTable(
+      Seq(
+        ("word1 word2 word3", 0), // regular space and ideographic space (U+3000)
+        ("test1\u00A0test2", 1)), // non-breaking space
+      "tbl_space") {
+      // Split on runs of whitespace
+      checkSparkAnswerAndOperator("SELECT split(_1, '\\\\s+') FROM tbl_space")
+    }
+
+    // Split with limit on UTF-8 strings
+    withParquetTable(Seq(("你,好,世,界", 0), ("😀,😃,😄,😁", 1)), "tbl_utf8_limit") {
+      checkSparkAnswerAndOperator("SELECT split(_1, ',', 2) FROM tbl_utf8_limit")
+      checkSparkAnswerAndOperator("SELECT split(_1, ',', -1) FROM tbl_utf8_limit")
+    }
+  }
+
   test("Various String scalar functions") {
     val table = "names"
     withTable(table) {
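
As with the other expression entries in configs.md at the top of this diff, the
new acceleration can be disabled at runtime if an incompatibility turns up, with
no rebuild required (sketch):

    spark.conf.set("spark.comet.expression.StringSplit.enabled", "false")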