Skip to content

Commit 4e5ee3a

Browse files
committed
chore: use NullBuffer::union for Spark concat (apache#18087)
## Which issue does this PR close?

<!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. -->

- Closes #.

Follow-up to apache#18063 (review)

## Rationale for this change

Use the cheaper `NullBuffer::union` to apply the null mask instead of the per-row iterator approach.

<!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. -->

## What changes are included in this PR?

<!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. -->

## Are these changes tested?

<!-- We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code

If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? -->

## Are there any user-facing changes?

<!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. -->

<!-- If there are any breaking changes to public APIs, please add the `api change` label. -->

(cherry picked from commit 337378a)
1 parent d799512 commit 4e5ee3a

File tree

1 file changed

+52
-89
lines changed
  • datafusion/spark/src/function/string

1 file changed

+52
-89
lines changed

datafusion/spark/src/function/string/concat.rs

Lines changed: 52 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::array::{Array, ArrayBuilder};
18+
use arrow::array::Array;
19+
use arrow::buffer::NullBuffer;
1920
use arrow::datatypes::DataType;
2021
use datafusion_common::{Result, ScalarValue};
2122
use datafusion_expr::{
@@ -31,6 +32,10 @@ use std::sync::Arc;
3132
///
3233
/// Concatenates multiple input strings into a single string.
3334
/// Returns NULL if any input is NULL.
35+
///
36+
/// Differences with DataFusion concat:
37+
/// - Support 0 arguments
38+
/// - Return NULL if any input is NULL
3439
#[derive(Debug, PartialEq, Eq, Hash)]
3540
pub struct SparkConcat {
3641
signature: Signature,
@@ -80,6 +85,16 @@ impl ScalarUDFImpl for SparkConcat {
8085
}
8186
}
8287

88+
/// Represents the null state for Spark concat
89+
enum NullMaskResolution {
90+
/// Return NULL as the result (e.g., scalar inputs with at least one NULL)
91+
ReturnNull,
92+
/// No null mask needed (e.g., all scalar inputs are non-NULL)
93+
NoMask,
94+
/// Null mask to apply for arrays
95+
Apply(NullBuffer),
96+
}
97+
8398
/// Concatenates strings, returning NULL if any input is NULL
8499
/// This is a Spark-specific wrapper around DataFusion's concat that returns NULL
85100
/// if any argument is NULL (Spark behavior), whereas DataFusion's concat ignores NULLs.
@@ -103,7 +118,7 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
103118
let null_mask = compute_null_mask(&arg_values, number_rows)?;
104119

105120
// If all scalars and any is NULL, return NULL immediately
106-
if null_mask.is_none() {
121+
if matches!(null_mask, NullMaskResolution::ReturnNull) {
107122
return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
108123
}
109124

@@ -122,13 +137,11 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
122137
apply_null_mask(result, null_mask)
123138
}
124139

125-
/// Compute NULL mask for the arguments
126-
/// Returns None if all scalars and any is NULL, or a Vector of
127-
/// boolean representing the null mask for incoming arrays
140+
/// Compute NULL mask for the arguments using NullBuffer::union
128141
fn compute_null_mask(
129142
args: &[ColumnarValue],
130143
number_rows: usize,
131-
) -> Result<Option<Vec<bool>>> {
144+
) -> Result<NullMaskResolution> {
132145
// Check if all arguments are scalars
133146
let all_scalars = args
134147
.iter()
@@ -139,15 +152,14 @@ fn compute_null_mask(
139152
for arg in args {
140153
if let ColumnarValue::Scalar(scalar) = arg {
141154
if scalar.is_null() {
142-
// Return None to indicate all values should be NULL
143-
return Ok(None);
155+
return Ok(NullMaskResolution::ReturnNull);
144156
}
145157
}
146158
}
147159
// No NULLs in scalars
148-
Ok(Some(vec![]))
160+
Ok(NullMaskResolution::NoMask)
149161
} else {
150-
// For arrays, compute NULL mask for each row
162+
// For arrays, compute NULL mask for each row using NullBuffer::union
151163
let array_len = args
152164
.iter()
153165
.find_map(|arg| match arg {
@@ -166,99 +178,50 @@ fn compute_null_mask(
166178
.collect();
167179
let arrays = arrays?;
168180

169-
// Compute NULL mask
170-
let mut null_mask = vec![false; array_len];
171-
for array in &arrays {
172-
for (i, null_flag) in null_mask.iter_mut().enumerate().take(array_len) {
173-
if array.is_null(i) {
174-
*null_flag = true;
175-
}
176-
}
177-
}
181+
// Use NullBuffer::union to combine all null buffers
182+
let combined_nulls = arrays
183+
.iter()
184+
.map(|arr| arr.nulls())
185+
.fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls));
178186

179-
Ok(Some(null_mask))
187+
match combined_nulls {
188+
Some(nulls) => Ok(NullMaskResolution::Apply(nulls)),
189+
None => Ok(NullMaskResolution::NoMask),
190+
}
180191
}
181192
}
182193

183-
/// Apply NULL mask to the result
194+
/// Apply NULL mask to the result using NullBuffer::union
184195
fn apply_null_mask(
185196
result: ColumnarValue,
186-
null_mask: Option<Vec<bool>>,
197+
null_mask: NullMaskResolution,
187198
) -> Result<ColumnarValue> {
188199
match (result, null_mask) {
189-
// Scalar with NULL mask means return NULL
190-
(ColumnarValue::Scalar(_), None) => {
200+
// Scalar with ReturnNull mask means return NULL
201+
(ColumnarValue::Scalar(_), NullMaskResolution::ReturnNull) => {
191202
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)))
192203
}
193-
// Scalar without NULL mask, return as-is
194-
(scalar @ ColumnarValue::Scalar(_), Some(mask)) if mask.is_empty() => Ok(scalar),
195-
// Array with NULL mask
196-
(ColumnarValue::Array(array), Some(null_mask)) if !null_mask.is_empty() => {
197-
let array_len = array.len();
198-
let return_type = array.data_type();
204+
// Scalar without mask, return as-is
205+
(scalar @ ColumnarValue::Scalar(_), NullMaskResolution::NoMask) => Ok(scalar),
206+
// Array with NULL mask - use NullBuffer::union to combine nulls
207+
(ColumnarValue::Array(array), NullMaskResolution::Apply(null_mask)) => {
208+
// Combine the result's existing nulls with our computed null mask
209+
let combined_nulls = NullBuffer::union(array.nulls(), Some(&null_mask));
199210

200-
let mut builder: Box<dyn ArrayBuilder> = match return_type {
201-
DataType::Utf8 => {
202-
let string_array = array
203-
.as_any()
204-
.downcast_ref::<arrow::array::StringArray>()
205-
.unwrap();
206-
let mut builder =
207-
arrow::array::StringBuilder::with_capacity(array_len, 0);
208-
for (i, &is_null) in null_mask.iter().enumerate().take(array_len) {
209-
if is_null || string_array.is_null(i) {
210-
builder.append_null();
211-
} else {
212-
builder.append_value(string_array.value(i));
213-
}
214-
}
215-
Box::new(builder)
216-
}
217-
DataType::LargeUtf8 => {
218-
let string_array = array
219-
.as_any()
220-
.downcast_ref::<arrow::array::LargeStringArray>()
221-
.unwrap();
222-
let mut builder =
223-
arrow::array::LargeStringBuilder::with_capacity(array_len, 0);
224-
for (i, &is_null) in null_mask.iter().enumerate().take(array_len) {
225-
if is_null || string_array.is_null(i) {
226-
builder.append_null();
227-
} else {
228-
builder.append_value(string_array.value(i));
229-
}
230-
}
231-
Box::new(builder)
232-
}
233-
DataType::Utf8View => {
234-
let string_array = array
235-
.as_any()
236-
.downcast_ref::<arrow::array::StringViewArray>()
237-
.unwrap();
238-
let mut builder =
239-
arrow::array::StringViewBuilder::with_capacity(array_len);
240-
for (i, &is_null) in null_mask.iter().enumerate().take(array_len) {
241-
if is_null || string_array.is_null(i) {
242-
builder.append_null();
243-
} else {
244-
builder.append_value(string_array.value(i));
245-
}
246-
}
247-
Box::new(builder)
248-
}
249-
_ => {
250-
return datafusion_common::exec_err!(
251-
"Unsupported return type for concat: {:?}",
252-
return_type
253-
);
254-
}
255-
};
211+
// Create new array with combined nulls
212+
let new_array = array
213+
.into_data()
214+
.into_builder()
215+
.nulls(combined_nulls)
216+
.build()?;
256217

257-
Ok(ColumnarValue::Array(builder.finish()))
218+
Ok(ColumnarValue::Array(Arc::new(arrow::array::make_array(
219+
new_array,
220+
))))
258221
}
259222
// Array without NULL mask, return as-is
260-
(array @ ColumnarValue::Array(_), _) => Ok(array),
261-
// Shouldn't happen
223+
(array @ ColumnarValue::Array(_), NullMaskResolution::NoMask) => Ok(array),
224+
// Edge cases that shouldn't happen in practice
262225
(scalar, _) => Ok(scalar),
263226
}
264227
}

0 commit comments

Comments
 (0)