Skip to content

Commit 399e840

Browse files
authored
Ensure that math functions fulfil the ColumnarValue contract (#12922)
If all UDF arguments are scalars, so should be the result. In most cases, such function calls will be contant-folded, however if for whatever reason the are not optimized, we want to avoid an error due to array length mismatch.
1 parent 5a0ea0b commit 399e840

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

datafusion/expr-common/src/columnar_value.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717

1818
//! [`ColumnarValue`] represents the result of evaluating an expression.
1919
20-
use arrow::array::ArrayRef;
21-
use arrow::array::NullArray;
20+
use arrow::array::{Array, ArrayRef, NullArray};
2221
use arrow::compute::{kernels, CastOptions};
2322
use arrow::datatypes::{DataType, TimeUnit};
2423
use datafusion_common::format::DEFAULT_CAST_OPTIONS;
@@ -218,6 +217,17 @@ impl ColumnarValue {
218217
}
219218
}
220219
}
220+
221+
/// Converts an [`ArrayRef`] to a [`ColumnarValue`] based on the supplied arguments.
222+
/// This is useful for scalar UDF implementations to fulfil their contract:
223+
/// if all arguments are scalar values, the result should also be a scalar value.
224+
pub fn from_args_and_result(args: &[Self], result: ArrayRef) -> Result<Self> {
225+
if result.len() == 1 && args.iter().all(|arg| matches!(arg, Self::Scalar(_))) {
226+
Ok(Self::Scalar(ScalarValue::try_from_array(&result, 0)?))
227+
} else {
228+
Ok(Self::Array(result))
229+
}
230+
}
221231
}
222232

223233
#[cfg(test)]

datafusion/functions/src/macros.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -228,9 +228,8 @@ macro_rules! make_math_unary_udf {
228228
$EVALUATE_BOUNDS(inputs)
229229
}
230230

231-
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
232-
let args = ColumnarValue::values_to_arrays(args)?;
233-
231+
fn invoke(&self, col_args: &[ColumnarValue]) -> Result<ColumnarValue> {
232+
let args = ColumnarValue::values_to_arrays(col_args)?;
234233
let arr: ArrayRef = match args[0].data_type() {
235234
DataType::Float64 => {
236235
Arc::new(make_function_scalar_inputs_return_type!(
@@ -257,7 +256,8 @@ macro_rules! make_math_unary_udf {
257256
)
258257
}
259258
};
260-
Ok(ColumnarValue::Array(arr))
259+
260+
ColumnarValue::from_args_and_result(col_args, arr)
261261
}
262262

263263
fn documentation(&self) -> Option<&Documentation> {
@@ -344,9 +344,8 @@ macro_rules! make_math_binary_udf {
344344
$OUTPUT_ORDERING(input)
345345
}
346346

347-
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
348-
let args = ColumnarValue::values_to_arrays(args)?;
349-
347+
fn invoke(&self, col_args: &[ColumnarValue]) -> Result<ColumnarValue> {
348+
let args = ColumnarValue::values_to_arrays(col_args)?;
350349
let arr: ArrayRef = match args[0].data_type() {
351350
DataType::Float64 => Arc::new(make_function_inputs2!(
352351
&args[0],
@@ -372,7 +371,8 @@ macro_rules! make_math_binary_udf {
372371
)
373372
}
374373
};
375-
Ok(ColumnarValue::Array(arr))
374+
375+
ColumnarValue::from_args_and_result(col_args, arr)
376376
}
377377

378378
fn documentation(&self) -> Option<&Documentation> {

0 commit comments

Comments
 (0)