
Commit ce4eb18

comphead and alamb authored
[branch-50]: chore: cherry pick concat to 50.3.0 (apache#18128)
## Which issue does this PR close?

- Closes #.

---------

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent 28ad4ef commit ce4eb18

File tree: 3 files changed (+325, -1 lines)

datafusion/spark/src/function/string/concat.rs

Lines changed: 269 additions & 0 deletions
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::Array;
use arrow::buffer::NullBuffer;
use arrow::datatypes::DataType;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
    Volatility,
};
use datafusion_functions::string::concat::ConcatFunc;
use std::any::Any;
use std::sync::Arc;

/// Spark-compatible `concat` expression
/// <https://spark.apache.org/docs/latest/api/sql/index.html#concat>
///
/// Concatenates multiple input strings into a single string.
/// Returns NULL if any input is NULL.
///
/// Differences with DataFusion concat:
/// - Supports 0 arguments
/// - Returns NULL if any input is NULL
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkConcat {
    signature: Signature,
}

impl Default for SparkConcat {
    fn default() -> Self {
        Self::new()
    }
}

impl SparkConcat {
    pub fn new() -> Self {
        Self {
            signature: Signature::one_of(
                vec![TypeSignature::UserDefined, TypeSignature::Nullary],
                Volatility::Immutable,
            ),
        }
    }
}

impl ScalarUDFImpl for SparkConcat {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "concat"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Utf8)
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        spark_concat(args)
    }

    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        // Accept any string types, including zero arguments
        Ok(arg_types.to_vec())
    }
}

/// Represents the null state for Spark concat
enum NullMaskResolution {
    /// Return NULL as the result (e.g., scalar inputs with at least one NULL)
    ReturnNull,
    /// No null mask needed (e.g., all scalar inputs are non-NULL)
    NoMask,
    /// Null mask to apply for arrays
    Apply(NullBuffer),
}

/// Concatenates strings, returning NULL if any input is NULL.
/// This is a Spark-specific wrapper around DataFusion's concat that returns NULL
/// if any argument is NULL (Spark behavior), whereas DataFusion's concat ignores NULLs.
fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
    let ScalarFunctionArgs {
        args: arg_values,
        arg_fields,
        number_rows,
        return_field,
        config_options,
    } = args;

    // Handle zero-argument case: return empty string
    if arg_values.is_empty() {
        return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(
            Some(String::new()),
        )));
    }

    // Step 1: Check for NULL mask in incoming args
    let null_mask = compute_null_mask(&arg_values, number_rows)?;

    // If all scalars and any is NULL, return NULL immediately
    if matches!(null_mask, NullMaskResolution::ReturnNull) {
        return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
    }

    // Step 2: Delegate to DataFusion's concat
    let concat_func = ConcatFunc::new();
    let func_args = ScalarFunctionArgs {
        args: arg_values,
        arg_fields,
        number_rows,
        return_field,
        config_options,
    };
    let result = concat_func.invoke_with_args(func_args)?;

    // Step 3: Apply NULL mask to result
    apply_null_mask(result, null_mask)
}

/// Compute NULL mask for the arguments using NullBuffer::union
fn compute_null_mask(
    args: &[ColumnarValue],
    number_rows: usize,
) -> Result<NullMaskResolution> {
    // Check if all arguments are scalars
    let all_scalars = args
        .iter()
        .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));

    if all_scalars {
        // For scalars, check if any is NULL
        for arg in args {
            if let ColumnarValue::Scalar(scalar) = arg {
                if scalar.is_null() {
                    return Ok(NullMaskResolution::ReturnNull);
                }
            }
        }
        // No NULLs in scalars
        Ok(NullMaskResolution::NoMask)
    } else {
        // For arrays, compute NULL mask for each row using NullBuffer::union
        let array_len = args
            .iter()
            .find_map(|arg| match arg {
                ColumnarValue::Array(array) => Some(array.len()),
                _ => None,
            })
            .unwrap_or(number_rows);

        // Convert all scalars to arrays for uniform processing
        let arrays: Result<Vec<_>> = args
            .iter()
            .map(|arg| match arg {
                ColumnarValue::Array(array) => Ok(Arc::clone(array)),
                ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len),
            })
            .collect();
        let arrays = arrays?;

        // Use NullBuffer::union to combine all null buffers
        let combined_nulls = arrays
            .iter()
            .map(|arr| arr.nulls())
            .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls));

        match combined_nulls {
            Some(nulls) => Ok(NullMaskResolution::Apply(nulls)),
            None => Ok(NullMaskResolution::NoMask),
        }
    }
}

/// Apply NULL mask to the result using NullBuffer::union
fn apply_null_mask(
    result: ColumnarValue,
    null_mask: NullMaskResolution,
) -> Result<ColumnarValue> {
    match (result, null_mask) {
        // Scalar with ReturnNull mask means return NULL
        (ColumnarValue::Scalar(_), NullMaskResolution::ReturnNull) => {
            Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)))
        }
        // Scalar without mask, return as-is
        (scalar @ ColumnarValue::Scalar(_), NullMaskResolution::NoMask) => Ok(scalar),
        // Array with NULL mask - use NullBuffer::union to combine nulls
        (ColumnarValue::Array(array), NullMaskResolution::Apply(null_mask)) => {
            // Combine the result's existing nulls with our computed null mask
            let combined_nulls = NullBuffer::union(array.nulls(), Some(&null_mask));

            // Create new array with combined nulls
            let new_array = array
                .into_data()
                .into_builder()
                .nulls(combined_nulls)
                .build()?;

            Ok(ColumnarValue::Array(Arc::new(arrow::array::make_array(
                new_array,
            ))))
        }
        // Array without NULL mask, return as-is
        (array @ ColumnarValue::Array(_), NullMaskResolution::NoMask) => Ok(array),
        // Edge cases that shouldn't happen in practice
        (scalar, _) => Ok(scalar),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::utils::test::test_scalar_function;
    use arrow::array::StringArray;
    use arrow::datatypes::DataType;
    use datafusion_common::Result;

    #[test]
    fn test_concat_basic() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("Spark".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("SQL".to_string()))),
            ],
            Ok(Some("SparkSQL")),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }

    #[test]
    fn test_concat_with_null() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("Spark".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("SQL".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
            ],
            Ok(None),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }
}
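The null handling above leans entirely on `NullBuffer::union`, which treats a row as null if it is null in either input (the combined validity is the bitwise AND of the two validity buffers) and returns `None` when neither side has nulls. A minimal, self-contained sketch of that behavior, using only arrow-rs (the array values are invented for illustration):

use arrow::array::{Array, StringArray};
use arrow::buffer::NullBuffer;

fn main() {
    // Two columns with nulls in different rows, like two concat arguments.
    let a = StringArray::from(vec![Some("x"), None, Some("z")]);
    let b = StringArray::from(vec![Some("1"), Some("2"), None]);

    // A row in the union is null if it is null in *either* input,
    // which is exactly the Spark concat rule the fold above implements.
    let combined = NullBuffer::union(a.nulls(), b.nulls())
        .expect("at least one input has nulls");

    assert_eq!(combined.null_count(), 2);
    assert!(combined.is_valid(0)); // "x" ++ "1" stays non-null
    assert!(!combined.is_valid(1)); // NULL in `a` poisons row 1
    assert!(!combined.is_valid(2)); // NULL in `b` poisons row 2
}

This mirrors what `compute_null_mask` does with its `fold`, just for two fixed inputs instead of a variadic argument list.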

datafusion/spark/src/function/string/mod.rs

Lines changed: 8 additions & 1 deletion
@@ -17,6 +17,7 @@
 
 pub mod ascii;
 pub mod char;
+pub mod concat;
 pub mod ilike;
 pub mod like;
 pub mod luhn_check;
@@ -27,6 +28,7 @@ use std::sync::Arc;
 
 make_udf_function!(ascii::SparkAscii, ascii);
 make_udf_function!(char::CharFunc, char);
+make_udf_function!(concat::SparkConcat, concat);
 make_udf_function!(ilike::SparkILike, ilike);
 make_udf_function!(like::SparkLike, like);
 make_udf_function!(luhn_check::SparkLuhnCheck, luhn_check);
@@ -44,6 +46,11 @@ pub mod expr_fn {
         "Returns the ASCII character having the binary equivalent to col. If col is larger than 256 the result is equivalent to char(col % 256).",
         arg1
     ));
+    export_functions!((
+        concat,
+        "Concatenates multiple input strings into a single string. Returns NULL if any input is NULL.",
+        args
+    ));
     export_functions!((
         ilike,
         "Returns true if str matches pattern (case insensitive).",
@@ -62,5 +69,5 @@
 }
 
 pub fn functions() -> Vec<Arc<ScalarUDF>> {
-    vec![ascii(), char(), ilike(), like(), luhn_check()]
+    vec![ascii(), char(), concat(), ilike(), like(), luhn_check()]
 }
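With the module wired up, `SparkConcat` is exposed through `functions()` and the `expr_fn::concat` helper. Below is a minimal sketch of registering it on a `SessionContext`; the crate name `datafusion_spark` is an assumption inferred from the `datafusion/spark` directory, and the snippet presumes a `tokio` runtime:

use datafusion::error::Result;
use datafusion::logical_expr::ScalarUDF;
use datafusion::prelude::SessionContext;
// Assumed crate/module path, mirroring datafusion/spark/src/function/string/concat.rs
use datafusion_spark::function::string::concat::SparkConcat;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Registering under the name "concat" shadows DataFusion's built-in
    // concat for this context, swapping in the Spark NULL semantics.
    ctx.register_udf(ScalarUDF::new_from_impl(SparkConcat::new()));

    // Spark semantics: a single NULL argument nulls the entire result.
    ctx.sql("SELECT concat('Spark', 'SQL', NULL) AS c")
        .await?
        .show()
        .await?;
    Ok(())
}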
Lines changed: 48 additions & 0 deletions
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

query T
SELECT concat('Spark', 'SQL');
----
SparkSQL

query T
SELECT concat('Spark', 'SQL', NULL);
----
NULL

query T
SELECT concat('', '1', '', '2');
----
12

query T
SELECT concat();
----
(empty)

query T
SELECT concat('');
----
(empty)

query T
SELECT concat(a, b, c) from (select 'a' a, 'b' b, 'c' c union all select null a, 'b', 'c') order by 1 nulls last;
----
abc
NULL
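Two sqllogictest conventions worth noting: `----` separates each query from its expected output, and `(empty)` is the harness's spelling of an empty string, since a bare blank line would be ambiguous. Assuming this file lives alongside the other Spark string tests, it should run under the repository's sqllogictest harness (per the datafusion/sqllogictest README, `cargo test --test sqllogictests`, optionally filtered by file name).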
