Skip to content

Commit c104bf4

Browse files
committed
feat: support Spark concat
1 parent 72499e6 commit c104bf4

File tree

2 files changed

+321
-0
lines changed

2 files changed

+321
-0
lines changed
Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::Array;
19+
use arrow::datatypes::DataType;
20+
use datafusion_common::{Result, ScalarValue};
21+
use datafusion_expr::{
22+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
23+
Volatility,
24+
};
25+
use datafusion_functions::string::concat::ConcatFunc;
26+
use std::{any::Any, sync::Arc};
27+
28+
/// Spark-compatible `concat` expression
29+
/// <https://spark.apache.org/docs/latest/api/sql/index.html#concat>
30+
///
31+
/// Concatenates multiple input strings into a single string.
32+
/// Returns NULL if any input is NULL.
33+
#[derive(Debug, PartialEq, Eq, Hash)]
34+
pub struct SparkConcat {
35+
signature: Signature,
36+
}
37+
38+
impl Default for SparkConcat {
39+
fn default() -> Self {
40+
Self::new()
41+
}
42+
}
43+
44+
impl SparkConcat {
45+
pub fn new() -> Self {
46+
Self {
47+
signature: Signature::one_of(
48+
vec![TypeSignature::UserDefined, TypeSignature::Nullary],
49+
Volatility::Immutable,
50+
),
51+
}
52+
}
53+
}
54+
55+
impl ScalarUDFImpl for SparkConcat {
56+
fn as_any(&self) -> &dyn Any {
57+
self
58+
}
59+
60+
fn name(&self) -> &str {
61+
"concat"
62+
}
63+
64+
fn signature(&self) -> &Signature {
65+
&self.signature
66+
}
67+
68+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
69+
Ok(DataType::Utf8)
70+
}
71+
72+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
73+
spark_concat(args)
74+
}
75+
76+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
77+
// Accept any string types, including zero arguments
78+
Ok(arg_types.to_vec())
79+
}
80+
}
81+
82+
/// Concatenates strings, returning NULL if any input is NULL
83+
/// This is a Spark-specific wrapper around DataFusion's concat that returns NULL
84+
/// if any argument is NULL (Spark behavior), whereas DataFusion's concat ignores NULLs.
85+
fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
86+
let ScalarFunctionArgs {
87+
args: arg_values,
88+
arg_fields,
89+
number_rows,
90+
return_field,
91+
config_options,
92+
} = args;
93+
94+
if arg_values.is_empty() {
95+
return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(
96+
Some(String::new()),
97+
)));
98+
}
99+
100+
// Check if all arguments are scalars
101+
let all_scalars = arg_values
102+
.iter()
103+
.all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
104+
105+
if all_scalars {
106+
// For scalars, check if any is NULL
107+
for arg in &arg_values {
108+
if let ColumnarValue::Scalar(scalar) = arg {
109+
if scalar.is_null() {
110+
// Return NULL if any argument is NULL (Spark behavior)
111+
return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
112+
}
113+
}
114+
}
115+
// No NULLs found, delegate to DataFusion's concat
116+
let concat_func = ConcatFunc::new();
117+
let func_args = ScalarFunctionArgs {
118+
args: arg_values,
119+
arg_fields,
120+
number_rows,
121+
return_field,
122+
config_options,
123+
};
124+
concat_func.invoke_with_args(func_args)
125+
} else {
126+
// For arrays, we need to check each row for NULLs and return NULL for that row
127+
// Get array length
128+
let array_len = arg_values
129+
.iter()
130+
.find_map(|arg| match arg {
131+
ColumnarValue::Array(array) => Some(array.len()),
132+
_ => None,
133+
})
134+
.unwrap_or(number_rows);
135+
136+
// Convert all scalars to arrays
137+
let arrays: Result<Vec<_>> = arg_values
138+
.iter()
139+
.map(|arg| match arg {
140+
ColumnarValue::Array(array) => Ok(Arc::clone(array)),
141+
ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len),
142+
})
143+
.collect();
144+
let arrays = arrays?;
145+
146+
// Check for NULL values in each row
147+
let mut null_mask = vec![false; array_len];
148+
for array in &arrays {
149+
for (i, null_flag) in null_mask.iter_mut().enumerate().take(array_len) {
150+
if array.is_null(i) {
151+
*null_flag = true;
152+
}
153+
}
154+
}
155+
156+
// Delegate to DataFusion's concat
157+
let concat_func = ConcatFunc::new();
158+
let func_args = ScalarFunctionArgs {
159+
args: arg_values,
160+
arg_fields,
161+
number_rows,
162+
return_field,
163+
config_options,
164+
};
165+
166+
let result = concat_func.invoke_with_args(func_args)?;
167+
168+
// Apply NULL mask to the result
169+
match result {
170+
ColumnarValue::Array(array) => {
171+
let return_type = array.data_type();
172+
let mut builder: Box<dyn arrow::array::ArrayBuilder> = match return_type {
173+
DataType::Utf8 => {
174+
let string_array = array
175+
.as_any()
176+
.downcast_ref::<arrow::array::StringArray>()
177+
.unwrap();
178+
let mut builder =
179+
arrow::array::StringBuilder::with_capacity(array_len, 0);
180+
for (i, &is_null) in null_mask.iter().enumerate().take(array_len)
181+
{
182+
if is_null || string_array.is_null(i) {
183+
builder.append_null();
184+
} else {
185+
builder.append_value(string_array.value(i));
186+
}
187+
}
188+
Box::new(builder)
189+
}
190+
DataType::LargeUtf8 => {
191+
let string_array = array
192+
.as_any()
193+
.downcast_ref::<arrow::array::LargeStringArray>()
194+
.unwrap();
195+
let mut builder =
196+
arrow::array::LargeStringBuilder::with_capacity(array_len, 0);
197+
for (i, &is_null) in null_mask.iter().enumerate().take(array_len)
198+
{
199+
if is_null || string_array.is_null(i) {
200+
builder.append_null();
201+
} else {
202+
builder.append_value(string_array.value(i));
203+
}
204+
}
205+
Box::new(builder)
206+
}
207+
DataType::Utf8View => {
208+
let string_array = array
209+
.as_any()
210+
.downcast_ref::<arrow::array::StringViewArray>()
211+
.unwrap();
212+
let mut builder =
213+
arrow::array::StringViewBuilder::with_capacity(array_len);
214+
for (i, &is_null) in null_mask.iter().enumerate().take(array_len)
215+
{
216+
if is_null || string_array.is_null(i) {
217+
builder.append_null();
218+
} else {
219+
builder.append_value(string_array.value(i));
220+
}
221+
}
222+
Box::new(builder)
223+
}
224+
_ => {
225+
return datafusion_common::exec_err!(
226+
"Unsupported return type for concat: {:?}",
227+
return_type
228+
);
229+
}
230+
};
231+
232+
Ok(ColumnarValue::Array(builder.finish()))
233+
}
234+
other => Ok(other),
235+
}
236+
}
237+
}
238+
239+
#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::utils::test::test_scalar_function;
    use arrow::array::StringArray;
    use arrow::datatypes::DataType;
    use datafusion_common::Result;
    use datafusion_expr::ColumnarValue;

    /// Wraps an optional string as a `Utf8` scalar argument.
    fn utf8_arg(value: Option<&str>) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8(value.map(String::from)))
    }

    /// Two non-NULL scalars are joined in argument order.
    #[test]
    fn test_concat_basic() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            vec![utf8_arg(Some("Spark")), utf8_arg(Some("SQL"))],
            Ok(Some("SparkSQL")),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }

    /// A single NULL scalar makes the whole result NULL.
    #[test]
    fn test_concat_with_null() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            vec![utf8_arg(Some("Spark")), utf8_arg(Some("SQL")), utf8_arg(None)],
            Ok(None),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
query T
19+
SELECT concat('Spark', 'SQL');
20+
----
21+
SparkSQL
22+
23+
query T
24+
SELECT concat('Spark', 'SQL', NULL);
25+
----
26+
NULL
27+
28+
query T
29+
SELECT concat('', '1', '', '2');
30+
----
31+
12
32+
33+
query T
34+
SELECT concat();
35+
----
36+
(empty)
37+
38+
query T
39+
SELECT concat('');
40+
----
41+
(empty)

0 commit comments

Comments
 (0)