Skip to content

Commit 70a87c6

Browse files
joroKr21avantgardnerio
authored and committed
Ensure that math functions fulfil the ColumnarValue contract (#275)
If all UDF arguments are scalars, so should be the result. In most cases, such function calls will be constant-folded; however, if for whatever reason they are not optimized, we want to avoid an error due to array length mismatch.
1 parent 213879e commit 70a87c6

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

datafusion/expr-common/src/columnar_value.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,17 @@ impl ColumnarValue {
217217
}
218218
}
219219
}
220+
221+
/// Converts an [`ArrayRef`] to a [`ColumnarValue`] based on the supplied arguments.
222+
/// This is useful for scalar UDF implementations to fulfil their contract:
223+
/// if all arguments are scalar values, the result should also be a scalar value.
224+
pub fn from_args_and_result(args: &[Self], result: ArrayRef) -> Result<Self> {
225+
if result.len() == 1 && args.iter().all(|arg| matches!(arg, Self::Scalar(_))) {
226+
Ok(Self::Scalar(ScalarValue::try_from_array(&result, 0)?))
227+
} else {
228+
Ok(Self::Array(result))
229+
}
230+
}
220231
}
221232

222233
#[cfg(test)]

datafusion/functions/src/macros.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,8 @@ macro_rules! make_math_unary_udf {
208208
$EVALUATE_BOUNDS(inputs)
209209
}
210210

211-
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
212-
let args = ColumnarValue::values_to_arrays(args)?;
211+
fn invoke(&self, col_args: &[ColumnarValue]) -> Result<ColumnarValue> {
212+
let args = ColumnarValue::values_to_arrays(col_args)?;
213213
let arr: ArrayRef = match args[0].data_type() {
214214
DataType::Float64 => Arc::new(
215215
args[0]
@@ -229,7 +229,7 @@ macro_rules! make_math_unary_udf {
229229
}
230230
};
231231

232-
Ok(ColumnarValue::Array(arr))
232+
ColumnarValue::from_args_and_result(col_args, arr)
233233
}
234234

235235
fn documentation(&self) -> Option<&Documentation> {
@@ -316,8 +316,8 @@ macro_rules! make_math_binary_udf {
316316
$OUTPUT_ORDERING(input)
317317
}
318318

319-
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
320-
let args = ColumnarValue::values_to_arrays(args)?;
319+
fn invoke(&self, col_args: &[ColumnarValue]) -> Result<ColumnarValue> {
320+
let args = ColumnarValue::values_to_arrays(col_args)?;
321321
let arr: ArrayRef = match args[0].data_type() {
322322
DataType::Float64 => {
323323
let y = args[0].as_primitive::<Float64Type>();
@@ -347,7 +347,7 @@ macro_rules! make_math_binary_udf {
347347
}
348348
};
349349

350-
Ok(ColumnarValue::Array(arr))
350+
ColumnarValue::from_args_and_result(col_args, arr)
351351
}
352352

353353
fn documentation(&self) -> Option<&Documentation> {

0 commit comments

Comments
 (0)