Skip to content

Commit 264030c

Browse files
authored
feat: support Spark concat string function (#18063)
* chore: Extend backtrace coverage * fmt * part2 * feedback * clippy * feat: support Spark `concat` * clippy * comments * test * doc
1 parent 41fdab9 commit 264030c

File tree

3 files changed

+362
-0
lines changed

3 files changed

+362
-0
lines changed
Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{Array, ArrayBuilder, BooleanArray};
use arrow::compute::nullif;
use arrow::datatypes::DataType;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
    Volatility,
};
use datafusion_functions::string::concat::ConcatFunc;
use std::any::Any;
use std::sync::Arc;
28+
29+
/// Spark-compatible `concat` expression
/// <https://spark.apache.org/docs/latest/api/sql/index.html#concat>
///
/// Concatenates multiple input strings into a single string.
/// Returns NULL if any input is NULL.
///
/// This differs from DataFusion's built-in `concat`, which *skips* NULL
/// arguments rather than propagating them; see [`spark_concat`] for how the
/// NULL propagation is layered on top of the built-in function.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkConcat {
    // Accepts a user-defined coercion (any string types) plus the
    // zero-argument call form; see `SparkConcat::new`.
    signature: Signature,
}
38+
39+
impl Default for SparkConcat {
    /// Delegates to [`SparkConcat::new`]; the UDF is stateless apart from its
    /// signature, so the default instance is the only configuration.
    fn default() -> Self {
        Self::new()
    }
}
44+
45+
impl SparkConcat {
    /// Creates the Spark `concat` UDF.
    ///
    /// The signature is `one_of(UserDefined, Nullary)`:
    /// - `UserDefined` routes type handling through `coerce_types`, which
    ///   accepts the argument types as-is (any string types).
    /// - `Nullary` additionally permits the zero-argument call `concat()`,
    ///   which Spark defines to return an empty string.
    pub fn new() -> Self {
        Self {
            signature: Signature::one_of(
                vec![TypeSignature::UserDefined, TypeSignature::Nullary],
                Volatility::Immutable,
            ),
        }
    }
}
55+
56+
impl ScalarUDFImpl for SparkConcat {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Registered under Spark's function name `concat`.
    fn name(&self) -> &str {
        "concat"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    // NOTE(review): always declares `Utf8`, yet the delegated concat can hand
    // back `LargeUtf8`/`Utf8View` arrays (both are handled downstream in
    // `apply_null_mask`) — confirm the planner tolerates this mismatch for
    // non-`Utf8` input types.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Utf8)
    }

    /// Evaluation is implemented in [`spark_concat`], which wraps DataFusion's
    /// `concat` with Spark's NULL-propagation semantics.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        spark_concat(args)
    }

    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        // Accept any string types, including zero arguments
        Ok(arg_types.to_vec())
    }
}
82+
83+
/// Concatenates strings, returning NULL if any input is NULL
84+
/// This is a Spark-specific wrapper around DataFusion's concat that returns NULL
85+
/// if any argument is NULL (Spark behavior), whereas DataFusion's concat ignores NULLs.
86+
fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
87+
let ScalarFunctionArgs {
88+
args: arg_values,
89+
arg_fields,
90+
number_rows,
91+
return_field,
92+
config_options,
93+
} = args;
94+
95+
// Handle zero-argument case: return empty string
96+
if arg_values.is_empty() {
97+
return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(
98+
Some(String::new()),
99+
)));
100+
}
101+
102+
// Step 1: Check for NULL mask in incoming args
103+
let null_mask = compute_null_mask(&arg_values, number_rows)?;
104+
105+
// If all scalars and any is NULL, return NULL immediately
106+
if null_mask.is_none() {
107+
return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
108+
}
109+
110+
// Step 2: Delegate to DataFusion's concat
111+
let concat_func = ConcatFunc::new();
112+
let func_args = ScalarFunctionArgs {
113+
args: arg_values,
114+
arg_fields,
115+
number_rows,
116+
return_field,
117+
config_options,
118+
};
119+
let result = concat_func.invoke_with_args(func_args)?;
120+
121+
// Step 3: Apply NULL mask to result
122+
apply_null_mask(result, null_mask)
123+
}
124+
125+
/// Compute NULL mask for the arguments
126+
/// Returns None if all scalars and any is NULL, or a Vector of
127+
/// boolean representing the null mask for incoming arrays
128+
fn compute_null_mask(
129+
args: &[ColumnarValue],
130+
number_rows: usize,
131+
) -> Result<Option<Vec<bool>>> {
132+
// Check if all arguments are scalars
133+
let all_scalars = args
134+
.iter()
135+
.all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
136+
137+
if all_scalars {
138+
// For scalars, check if any is NULL
139+
for arg in args {
140+
if let ColumnarValue::Scalar(scalar) = arg {
141+
if scalar.is_null() {
142+
// Return None to indicate all values should be NULL
143+
return Ok(None);
144+
}
145+
}
146+
}
147+
// No NULLs in scalars
148+
Ok(Some(vec![]))
149+
} else {
150+
// For arrays, compute NULL mask for each row
151+
let array_len = args
152+
.iter()
153+
.find_map(|arg| match arg {
154+
ColumnarValue::Array(array) => Some(array.len()),
155+
_ => None,
156+
})
157+
.unwrap_or(number_rows);
158+
159+
// Convert all scalars to arrays for uniform processing
160+
let arrays: Result<Vec<_>> = args
161+
.iter()
162+
.map(|arg| match arg {
163+
ColumnarValue::Array(array) => Ok(Arc::clone(array)),
164+
ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len),
165+
})
166+
.collect();
167+
let arrays = arrays?;
168+
169+
// Compute NULL mask
170+
let mut null_mask = vec![false; array_len];
171+
for array in &arrays {
172+
for (i, null_flag) in null_mask.iter_mut().enumerate().take(array_len) {
173+
if array.is_null(i) {
174+
*null_flag = true;
175+
}
176+
}
177+
}
178+
179+
Ok(Some(null_mask))
180+
}
181+
}
182+
183+
/// Apply NULL mask to the result
184+
fn apply_null_mask(
185+
result: ColumnarValue,
186+
null_mask: Option<Vec<bool>>,
187+
) -> Result<ColumnarValue> {
188+
match (result, null_mask) {
189+
// Scalar with NULL mask means return NULL
190+
(ColumnarValue::Scalar(_), None) => {
191+
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)))
192+
}
193+
// Scalar without NULL mask, return as-is
194+
(scalar @ ColumnarValue::Scalar(_), Some(mask)) if mask.is_empty() => Ok(scalar),
195+
// Array with NULL mask
196+
(ColumnarValue::Array(array), Some(null_mask)) if !null_mask.is_empty() => {
197+
let array_len = array.len();
198+
let return_type = array.data_type();
199+
200+
let mut builder: Box<dyn ArrayBuilder> = match return_type {
201+
DataType::Utf8 => {
202+
let string_array = array
203+
.as_any()
204+
.downcast_ref::<arrow::array::StringArray>()
205+
.unwrap();
206+
let mut builder =
207+
arrow::array::StringBuilder::with_capacity(array_len, 0);
208+
for (i, &is_null) in null_mask.iter().enumerate().take(array_len) {
209+
if is_null || string_array.is_null(i) {
210+
builder.append_null();
211+
} else {
212+
builder.append_value(string_array.value(i));
213+
}
214+
}
215+
Box::new(builder)
216+
}
217+
DataType::LargeUtf8 => {
218+
let string_array = array
219+
.as_any()
220+
.downcast_ref::<arrow::array::LargeStringArray>()
221+
.unwrap();
222+
let mut builder =
223+
arrow::array::LargeStringBuilder::with_capacity(array_len, 0);
224+
for (i, &is_null) in null_mask.iter().enumerate().take(array_len) {
225+
if is_null || string_array.is_null(i) {
226+
builder.append_null();
227+
} else {
228+
builder.append_value(string_array.value(i));
229+
}
230+
}
231+
Box::new(builder)
232+
}
233+
DataType::Utf8View => {
234+
let string_array = array
235+
.as_any()
236+
.downcast_ref::<arrow::array::StringViewArray>()
237+
.unwrap();
238+
let mut builder =
239+
arrow::array::StringViewBuilder::with_capacity(array_len);
240+
for (i, &is_null) in null_mask.iter().enumerate().take(array_len) {
241+
if is_null || string_array.is_null(i) {
242+
builder.append_null();
243+
} else {
244+
builder.append_value(string_array.value(i));
245+
}
246+
}
247+
Box::new(builder)
248+
}
249+
_ => {
250+
return datafusion_common::exec_err!(
251+
"Unsupported return type for concat: {:?}",
252+
return_type
253+
);
254+
}
255+
};
256+
257+
Ok(ColumnarValue::Array(builder.finish()))
258+
}
259+
// Array without NULL mask, return as-is
260+
(array @ ColumnarValue::Array(_), _) => Ok(array),
261+
// Shouldn't happen
262+
(scalar, _) => Ok(scalar),
263+
}
264+
}
265+
266+
#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::utils::test::test_scalar_function;
    use arrow::array::StringArray;
    use arrow::datatypes::DataType;
    use datafusion_common::Result;

    /// All-non-NULL scalars concatenate like the built-in concat.
    #[test]
    fn test_concat_basic() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("Spark".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("SQL".to_string()))),
            ],
            Ok(Some("SparkSQL")),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }

    /// Spark semantics: a single NULL argument makes the whole result NULL
    /// (DataFusion's built-in concat would instead skip it).
    #[test]
    fn test_concat_with_null() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("Spark".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("SQL".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
            ],
            Ok(None),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }
}

datafusion/spark/src/function/string/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
pub mod ascii;
1919
pub mod char;
20+
pub mod concat;
2021
pub mod elt;
2122
pub mod format_string;
2223
pub mod ilike;
@@ -30,6 +31,7 @@ use std::sync::Arc;
3031

3132
make_udf_function!(ascii::SparkAscii, ascii);
3233
make_udf_function!(char::CharFunc, char);
34+
make_udf_function!(concat::SparkConcat, concat);
3335
make_udf_function!(ilike::SparkILike, ilike);
3436
make_udf_function!(length::SparkLengthFunc, length);
3537
make_udf_function!(elt::SparkElt, elt);
@@ -50,6 +52,11 @@ pub mod expr_fn {
5052
"Returns the ASCII character having the binary equivalent to col. If col is larger than 256 the result is equivalent to char(col % 256).",
5153
arg1
5254
));
55+
export_functions!((
56+
concat,
57+
"Concatenates multiple input strings into a single string. Returns NULL if any input is NULL.",
58+
args
59+
));
5360
export_functions!((
5461
elt,
5562
"Returns the n-th input (1-indexed), e.g. returns 2nd input when n is 2. The function returns NULL if the index is 0 or exceeds the length of the array.",
@@ -86,6 +93,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
8693
vec![
8794
ascii(),
8895
char(),
96+
concat(),
8997
elt(),
9098
ilike(),
9199
length(),
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Tests for the Spark-compatible `concat` string function.

query T
SELECT concat('Spark', 'SQL');
----
SparkSQL

# Spark semantics: any NULL argument makes the whole result NULL.
query T
SELECT concat('Spark', 'SQL', NULL);
----
NULL

# Empty-string arguments contribute nothing but do not produce NULL.
query T
SELECT concat('', '1', '', '2');
----
12

# Zero arguments yield the empty string, not NULL.
query T
SELECT concat();
----
(empty)

query T
SELECT concat('');
----
(empty)

# Array (non-scalar) inputs: only the row containing a NULL becomes NULL.
query T
SELECT concat(a, b, c) from (select 'a' a, 'b' b, 'c' c union all select null a, 'b', 'c') order by 1 nulls last;
----
abc
NULL

0 commit comments

Comments
 (0)