Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 52 additions & 89 deletions datafusion/spark/src/function/string/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::{Array, ArrayBuilder};
use arrow::array::Array;
use arrow::buffer::NullBuffer;
use arrow::datatypes::DataType;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{
Expand All @@ -31,6 +32,10 @@ use std::sync::Arc;
///
/// Concatenates multiple input strings into a single string.
/// Returns NULL if any input is NULL.
///
/// Differences from DataFusion's `concat`:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

/// - Support 0 arguments
/// - Return NULL if any input is NULL
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkConcat {
signature: Signature,
Expand Down Expand Up @@ -80,6 +85,16 @@ impl ScalarUDFImpl for SparkConcat {
}
}

/// Outcome of resolving the NULL state of the arguments to Spark `concat`.
///
/// Produced by `compute_null_mask` and consumed by `spark_concat` and
/// `apply_null_mask`.
enum NullMaskResolution {
/// The whole result is NULL: every argument is a scalar and at least one of
/// them is NULL. `spark_concat` returns a NULL `Utf8` scalar for this variant.
ReturnNull,
/// No masking is required: either all arguments are non-NULL scalars, or the
/// union of the inputs' null buffers is empty (no input carries nulls).
NoMask,
/// Row-level null mask for array inputs: the union of the inputs' null
/// buffers, to be merged with the nulls of the concatenated result.
Apply(NullBuffer),
}

/// Concatenates strings, returning NULL if any input is NULL
/// This is a Spark-specific wrapper around DataFusion's concat that returns NULL
/// if any argument is NULL (Spark behavior), whereas DataFusion's concat ignores NULLs.
Expand All @@ -103,7 +118,7 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let null_mask = compute_null_mask(&arg_values, number_rows)?;

// If all scalars and any is NULL, return NULL immediately
if null_mask.is_none() {
if matches!(null_mask, NullMaskResolution::ReturnNull) {
return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
}

Expand All @@ -122,13 +137,11 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
apply_null_mask(result, null_mask)
}

/// Compute NULL mask for the arguments
/// Returns None if all scalars and any is NULL, or a Vector of
/// boolean representing the null mask for incoming arrays
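/// Compute the NULL mask for the arguments.
///
/// Returns `ReturnNull` if all arguments are scalars and any of them is NULL,
/// `NoMask` if no argument contributes any NULLs, and otherwise `Apply` with
/// the union (via `NullBuffer::union`) of the inputs' null buffers.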
fn compute_null_mask(
args: &[ColumnarValue],
number_rows: usize,
) -> Result<Option<Vec<bool>>> {
) -> Result<NullMaskResolution> {
// Check if all arguments are scalars
let all_scalars = args
.iter()
Expand All @@ -139,15 +152,14 @@ fn compute_null_mask(
for arg in args {
if let ColumnarValue::Scalar(scalar) = arg {
if scalar.is_null() {
// Return None to indicate all values should be NULL
return Ok(None);
return Ok(NullMaskResolution::ReturnNull);
}
}
}
// No NULLs in scalars
Ok(Some(vec![]))
Ok(NullMaskResolution::NoMask)
} else {
// For arrays, compute NULL mask for each row
// For arrays, compute NULL mask for each row using NullBuffer::union
let array_len = args
.iter()
.find_map(|arg| match arg {
Expand All @@ -166,99 +178,50 @@ fn compute_null_mask(
.collect();
let arrays = arrays?;

// Compute NULL mask
let mut null_mask = vec![false; array_len];
for array in &arrays {
for (i, null_flag) in null_mask.iter_mut().enumerate().take(array_len) {
if array.is_null(i) {
*null_flag = true;
}
}
}
// Use NullBuffer::union to combine all null buffers
let combined_nulls = arrays
.iter()
.map(|arr| arr.nulls())
.fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls));

Ok(Some(null_mask))
match combined_nulls {
Some(nulls) => Ok(NullMaskResolution::Apply(nulls)),
None => Ok(NullMaskResolution::NoMask),
}
}
}

/// Apply NULL mask to the result
/// Apply NULL mask to the result using NullBuffer::union
fn apply_null_mask(
result: ColumnarValue,
null_mask: Option<Vec<bool>>,
null_mask: NullMaskResolution,
) -> Result<ColumnarValue> {
match (result, null_mask) {
// Scalar with NULL mask means return NULL
(ColumnarValue::Scalar(_), None) => {
// Scalar with ReturnNull mask means return NULL
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👨‍🍳 👌

(ColumnarValue::Scalar(_), NullMaskResolution::ReturnNull) => {
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)))
}
// Scalar without NULL mask, return as-is
(scalar @ ColumnarValue::Scalar(_), Some(mask)) if mask.is_empty() => Ok(scalar),
// Array with NULL mask
(ColumnarValue::Array(array), Some(null_mask)) if !null_mask.is_empty() => {
let array_len = array.len();
let return_type = array.data_type();
// Scalar without mask, return as-is
(scalar @ ColumnarValue::Scalar(_), NullMaskResolution::NoMask) => Ok(scalar),
// Array with NULL mask - use NullBuffer::union to combine nulls
(ColumnarValue::Array(array), NullMaskResolution::Apply(null_mask)) => {
// Combine the result's existing nulls with our computed null mask
let combined_nulls = NullBuffer::union(array.nulls(), Some(&null_mask));

let mut builder: Box<dyn ArrayBuilder> = match return_type {
DataType::Utf8 => {
let string_array = array
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let mut builder =
arrow::array::StringBuilder::with_capacity(array_len, 0);
for (i, &is_null) in null_mask.iter().enumerate().take(array_len) {
if is_null || string_array.is_null(i) {
builder.append_null();
} else {
builder.append_value(string_array.value(i));
}
}
Box::new(builder)
}
DataType::LargeUtf8 => {
let string_array = array
.as_any()
.downcast_ref::<arrow::array::LargeStringArray>()
.unwrap();
let mut builder =
arrow::array::LargeStringBuilder::with_capacity(array_len, 0);
for (i, &is_null) in null_mask.iter().enumerate().take(array_len) {
if is_null || string_array.is_null(i) {
builder.append_null();
} else {
builder.append_value(string_array.value(i));
}
}
Box::new(builder)
}
DataType::Utf8View => {
let string_array = array
.as_any()
.downcast_ref::<arrow::array::StringViewArray>()
.unwrap();
let mut builder =
arrow::array::StringViewBuilder::with_capacity(array_len);
for (i, &is_null) in null_mask.iter().enumerate().take(array_len) {
if is_null || string_array.is_null(i) {
builder.append_null();
} else {
builder.append_value(string_array.value(i));
}
}
Box::new(builder)
}
_ => {
return datafusion_common::exec_err!(
"Unsupported return type for concat: {:?}",
return_type
);
}
};
// Create new array with combined nulls
let new_array = array
.into_data()
.into_builder()
.nulls(combined_nulls)
.build()?;

Ok(ColumnarValue::Array(builder.finish()))
Ok(ColumnarValue::Array(Arc::new(arrow::array::make_array(
new_array,
))))
}
// Array without NULL mask, return as-is
(array @ ColumnarValue::Array(_), _) => Ok(array),
// Shouldn't happen
(array @ ColumnarValue::Array(_), NullMaskResolution::NoMask) => Ok(array),
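// Remaining combinations (a scalar result with `Apply`, or an array result
// with `ReturnNull`) are not expected in practice; pass the result through unchanged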
(scalar, _) => Ok(scalar),
}
}
Expand Down