Skip to content

Commit 988ee87

Browse files
committed
clippy
1 parent 3f59d81 commit 988ee87

File tree

2 files changed

+125
-93
lines changed

2 files changed

+125
-93
lines changed

datafusion/spark/src/function/string/concat.rs

Lines changed: 117 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,16 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::array::Array;
18+
use arrow::array::{Array, ArrayBuilder};
1919
use arrow::datatypes::DataType;
2020
use datafusion_common::{Result, ScalarValue};
2121
use datafusion_expr::{
2222
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
2323
Volatility,
2424
};
2525
use datafusion_functions::string::concat::ConcatFunc;
26-
use std::{any::Any, sync::Arc};
26+
use std::any::Any;
27+
use std::sync::Arc;
2728

2829
/// Spark-compatible `concat` expression
2930
/// <https://spark.apache.org/docs/latest/api/sql/index.html#concat>
@@ -97,44 +98,64 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
9798
)));
9899
}
99100

101+
// Step 1: Check for NULL mask in incoming args
102+
let null_mask = compute_null_mask(&arg_values, number_rows)?;
103+
104+
// If all scalars and any is NULL, return NULL immediately
105+
if null_mask.is_none() {
106+
return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
107+
}
108+
109+
// Step 2: Delegate to DataFusion's concat
110+
let concat_func = ConcatFunc::new();
111+
let func_args = ScalarFunctionArgs {
112+
args: arg_values,
113+
arg_fields,
114+
number_rows,
115+
return_field,
116+
config_options,
117+
};
118+
let result = concat_func.invoke_with_args(func_args)?;
119+
120+
// Step 3: Apply NULL mask to result
121+
apply_null_mask(result, null_mask)
122+
}
123+
124+
/// Compute NULL mask for the arguments
125+
/// Returns None if all scalars and any is NULL, or a Vec<bool> for arrays
126+
fn compute_null_mask(
127+
args: &[ColumnarValue],
128+
number_rows: usize,
129+
) -> Result<Option<Vec<bool>>> {
100130
// Check if all arguments are scalars
101-
let all_scalars = arg_values
131+
let all_scalars = args
102132
.iter()
103133
.all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
104134

105135
if all_scalars {
106136
// For scalars, check if any is NULL
107-
for arg in &arg_values {
137+
for arg in args {
108138
if let ColumnarValue::Scalar(scalar) = arg {
109139
if scalar.is_null() {
110-
// Return NULL if any argument is NULL (Spark behavior)
111-
return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
140+
// Return None to indicate all values should be NULL
141+
return Ok(None);
112142
}
113143
}
114144
}
115-
// No NULLs found, delegate to DataFusion's concat
116-
let concat_func = ConcatFunc::new();
117-
let func_args = ScalarFunctionArgs {
118-
args: arg_values,
119-
arg_fields,
120-
number_rows,
121-
return_field,
122-
config_options,
123-
};
124-
concat_func.invoke_with_args(func_args)
145+
// No NULLs in scalars
146+
Ok(Some(vec![]))
125147
} else {
126-
// For arrays, we need to check each row for NULLs and return NULL for that row
127-
// Get array length
128-
let array_len = arg_values
148+
// For arrays, compute NULL mask for each row
149+
let array_len = args
129150
.iter()
130151
.find_map(|arg| match arg {
131152
ColumnarValue::Array(array) => Some(array.len()),
132153
_ => None,
133154
})
134155
.unwrap_or(number_rows);
135156

136-
// Convert all scalars to arrays
137-
let arrays: Result<Vec<_>> = arg_values
157+
// Convert all scalars to arrays for uniform processing
158+
let arrays: Result<Vec<_>> = args
138159
.iter()
139160
.map(|arg| match arg {
140161
ColumnarValue::Array(array) => Ok(Arc::clone(array)),
@@ -143,7 +164,7 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
143164
.collect();
144165
let arrays = arrays?;
145166

146-
// Check for NULL values in each row
167+
// Compute NULL mask
147168
let mut null_mask = vec![false; array_len];
148169
for array in &arrays {
149170
for (i, null_flag) in null_mask.iter_mut().enumerate().take(array_len) {
@@ -153,86 +174,90 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
153174
}
154175
}
155176

156-
// Delegate to DataFusion's concat
157-
let concat_func = ConcatFunc::new();
158-
let func_args = ScalarFunctionArgs {
159-
args: arg_values,
160-
arg_fields,
161-
number_rows,
162-
return_field,
163-
config_options,
164-
};
177+
Ok(Some(null_mask))
178+
}
179+
}
165180

166-
let result = concat_func.invoke_with_args(func_args)?;
181+
/// Apply NULL mask to the result
182+
fn apply_null_mask(
183+
result: ColumnarValue,
184+
null_mask: Option<Vec<bool>>,
185+
) -> Result<ColumnarValue> {
186+
match (result, null_mask) {
187+
// Scalar with NULL mask means return NULL
188+
(ColumnarValue::Scalar(_), None) => {
189+
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)))
190+
}
191+
// Scalar without NULL mask, return as-is
192+
(scalar @ ColumnarValue::Scalar(_), Some(mask)) if mask.is_empty() => Ok(scalar),
193+
// Array with NULL mask
194+
(ColumnarValue::Array(array), Some(null_mask)) if !null_mask.is_empty() => {
195+
let array_len = array.len();
196+
let return_type = array.data_type();
167197

168-
// Apply NULL mask to the result
169-
match result {
170-
ColumnarValue::Array(array) => {
171-
let return_type = array.data_type();
172-
let mut builder: Box<dyn arrow::array::ArrayBuilder> = match return_type {
173-
DataType::Utf8 => {
174-
let string_array = array
175-
.as_any()
176-
.downcast_ref::<arrow::array::StringArray>()
177-
.unwrap();
178-
let mut builder =
179-
arrow::array::StringBuilder::with_capacity(array_len, 0);
180-
for (i, &is_null) in null_mask.iter().enumerate().take(array_len)
181-
{
182-
if is_null || string_array.is_null(i) {
183-
builder.append_null();
184-
} else {
185-
builder.append_value(string_array.value(i));
186-
}
198+
let mut builder: Box<dyn ArrayBuilder> = match return_type {
199+
DataType::Utf8 => {
200+
let string_array = array
201+
.as_any()
202+
.downcast_ref::<arrow::array::StringArray>()
203+
.unwrap();
204+
let mut builder =
205+
arrow::array::StringBuilder::with_capacity(array_len, 0);
206+
for (i, &is_null) in null_mask.iter().enumerate().take(array_len) {
207+
if is_null || string_array.is_null(i) {
208+
builder.append_null();
209+
} else {
210+
builder.append_value(string_array.value(i));
187211
}
188-
Box::new(builder)
189212
}
190-
DataType::LargeUtf8 => {
191-
let string_array = array
192-
.as_any()
193-
.downcast_ref::<arrow::array::LargeStringArray>()
194-
.unwrap();
195-
let mut builder =
196-
arrow::array::LargeStringBuilder::with_capacity(array_len, 0);
197-
for (i, &is_null) in null_mask.iter().enumerate().take(array_len)
198-
{
199-
if is_null || string_array.is_null(i) {
200-
builder.append_null();
201-
} else {
202-
builder.append_value(string_array.value(i));
203-
}
213+
Box::new(builder)
214+
}
215+
DataType::LargeUtf8 => {
216+
let string_array = array
217+
.as_any()
218+
.downcast_ref::<arrow::array::LargeStringArray>()
219+
.unwrap();
220+
let mut builder =
221+
arrow::array::LargeStringBuilder::with_capacity(array_len, 0);
222+
for (i, &is_null) in null_mask.iter().enumerate().take(array_len) {
223+
if is_null || string_array.is_null(i) {
224+
builder.append_null();
225+
} else {
226+
builder.append_value(string_array.value(i));
204227
}
205-
Box::new(builder)
206228
}
207-
DataType::Utf8View => {
208-
let string_array = array
209-
.as_any()
210-
.downcast_ref::<arrow::array::StringViewArray>()
211-
.unwrap();
212-
let mut builder =
213-
arrow::array::StringViewBuilder::with_capacity(array_len);
214-
for (i, &is_null) in null_mask.iter().enumerate().take(array_len)
215-
{
216-
if is_null || string_array.is_null(i) {
217-
builder.append_null();
218-
} else {
219-
builder.append_value(string_array.value(i));
220-
}
229+
Box::new(builder)
230+
}
231+
DataType::Utf8View => {
232+
let string_array = array
233+
.as_any()
234+
.downcast_ref::<arrow::array::StringViewArray>()
235+
.unwrap();
236+
let mut builder =
237+
arrow::array::StringViewBuilder::with_capacity(array_len);
238+
for (i, &is_null) in null_mask.iter().enumerate().take(array_len) {
239+
if is_null || string_array.is_null(i) {
240+
builder.append_null();
241+
} else {
242+
builder.append_value(string_array.value(i));
221243
}
222-
Box::new(builder)
223-
}
224-
_ => {
225-
return datafusion_common::exec_err!(
226-
"Unsupported return type for concat: {:?}",
227-
return_type
228-
);
229244
}
230-
};
245+
Box::new(builder)
246+
}
247+
_ => {
248+
return datafusion_common::exec_err!(
249+
"Unsupported return type for concat: {:?}",
250+
return_type
251+
);
252+
}
253+
};
231254

232-
Ok(ColumnarValue::Array(builder.finish()))
233-
}
234-
other => Ok(other),
255+
Ok(ColumnarValue::Array(builder.finish()))
235256
}
257+
// Array without NULL mask, return as-is
258+
(array @ ColumnarValue::Array(_), _) => Ok(array),
259+
// Shouldn't happen
260+
(scalar, _) => Ok(scalar),
236261
}
237262
}
238263

@@ -243,7 +268,6 @@ mod tests {
243268
use arrow::array::StringArray;
244269
use arrow::datatypes::DataType;
245270
use datafusion_common::Result;
246-
use datafusion_expr::ColumnarValue;
247271

248272
#[test]
249273
fn test_concat_basic() -> Result<()> {

datafusion/spark/src/function/string/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
pub mod ascii;
1919
pub mod char;
20+
pub mod concat;
2021
pub mod elt;
2122
pub mod format_string;
2223
pub mod ilike;
@@ -30,6 +31,7 @@ use std::sync::Arc;
3031

3132
make_udf_function!(ascii::SparkAscii, ascii);
3233
make_udf_function!(char::CharFunc, char);
34+
make_udf_function!(concat::SparkConcat, concat);
3335
make_udf_function!(ilike::SparkILike, ilike);
3436
make_udf_function!(length::SparkLengthFunc, length);
3537
make_udf_function!(elt::SparkElt, elt);
@@ -50,6 +52,11 @@ pub mod expr_fn {
5052
"Returns the ASCII character having the binary equivalent to col. If col is larger than 256 the result is equivalent to char(col % 256).",
5153
arg1
5254
));
55+
export_functions!((
56+
concat,
57+
"Concatenates multiple input strings into a single string. Returns NULL if any input is NULL.",
58+
args
59+
));
5360
export_functions!((
5461
elt,
5562
"Returns the n-th input (1-indexed), e.g. returns 2nd input when n is 2. The function returns NULL if the index is 0 or exceeds the length of the array.",
@@ -86,6 +93,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
8693
vec![
8794
ascii(),
8895
char(),
96+
concat(),
8997
elt(),
9098
ilike(),
9199
length(),

0 commit comments

Comments
 (0)