Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions datafusion/expr/src/predicate_bounds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,9 @@ mod tests {
(binary_expr(t.clone(), And, f.clone()), NullableInterval::FALSE),
(binary_expr(f.clone(), And, t.clone()), NullableInterval::FALSE),
(binary_expr(f.clone(), And, f.clone()), NullableInterval::FALSE),
(binary_expr(t.clone(), And, func.clone()), NullableInterval::ANY_TRUTH_VALUE),
(binary_expr(func.clone(), And, t.clone()), NullableInterval::ANY_TRUTH_VALUE),
// UDF with no arguments has no nullable inputs, so output is not nullable (TRUE_OR_FALSE)
(binary_expr(t.clone(), And, func.clone()), NullableInterval::TRUE_OR_FALSE),
(binary_expr(func.clone(), And, t.clone()), NullableInterval::TRUE_OR_FALSE),
(binary_expr(f.clone(), And, func.clone()), NullableInterval::FALSE),
(binary_expr(func.clone(), And, f.clone()), NullableInterval::FALSE),
(binary_expr(null.clone(), And, func.clone()), NullableInterval::FALSE_OR_UNKNOWN),
Expand Down Expand Up @@ -344,8 +345,9 @@ mod tests {
(binary_expr(f.clone(), Or, f.clone()), NullableInterval::FALSE),
(binary_expr(t.clone(), Or, func.clone()), NullableInterval::TRUE),
(binary_expr(func.clone(), Or, t.clone()), NullableInterval::TRUE),
(binary_expr(f.clone(), Or, func.clone()), NullableInterval::ANY_TRUTH_VALUE),
(binary_expr(func.clone(), Or, f.clone()), NullableInterval::ANY_TRUTH_VALUE),
// UDF with no arguments has no nullable inputs, so output is not nullable (TRUE_OR_FALSE)
(binary_expr(f.clone(), Or, func.clone()), NullableInterval::TRUE_OR_FALSE),
(binary_expr(func.clone(), Or, f.clone()), NullableInterval::TRUE_OR_FALSE),
(binary_expr(null.clone(), Or, func.clone()), NullableInterval::TRUE_OR_UNKNOWN),
(binary_expr(func.clone(), Or, null.clone()), NullableInterval::TRUE_OR_UNKNOWN),
];
Expand Down Expand Up @@ -376,7 +378,8 @@ mod tests {
(not(zero.clone()), NullableInterval::TRUE),
(not(t.clone()), NullableInterval::FALSE),
(not(f.clone()), NullableInterval::TRUE),
(not(func.clone()), NullableInterval::ANY_TRUTH_VALUE),
// UDF with no arguments has no nullable inputs, so output is not nullable (TRUE_OR_FALSE)
(not(func.clone()), NullableInterval::TRUE_OR_FALSE),
];

for case in cases {
Expand Down Expand Up @@ -654,11 +657,13 @@ mod tests {
fn evaluate_bounds_udf() {
let func = make_scalar_func_expr();

// UDF with no arguments has no nullable inputs, so output is not nullable
// This means the predicate can be true or false, but never null (UNKNOWN)
#[rustfmt::skip]
let cases = vec![
(func.clone(), NullableInterval::ANY_TRUTH_VALUE),
(not(func.clone()), NullableInterval::ANY_TRUTH_VALUE),
(binary_expr(func.clone(), And, func.clone()), NullableInterval::ANY_TRUTH_VALUE),
(func.clone(), NullableInterval::TRUE_OR_FALSE),
(not(func.clone()), NullableInterval::TRUE_OR_FALSE),
(binary_expr(func.clone(), And, func.clone()), NullableInterval::TRUE_OR_FALSE),
];

for case in cases {
Expand Down
82 changes: 79 additions & 3 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync {
/// fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
/// // report output is only nullable if any one of the arguments are nullable
/// let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
/// let field = Arc::new(Field::new("ignored_name", DataType::Int32, true));
/// let field = Arc::new(Field::new("ignored_name", DataType::Int32, nullable));
/// Ok(field)
/// }
/// # }
Expand Down Expand Up @@ -638,7 +638,11 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync {
.cloned()
.collect::<Vec<_>>();
let return_type = self.return_type(&data_types)?;
Ok(Arc::new(Field::new(self.name(), return_type, true)))
// The output is nullable if any of the input arguments are nullable.
// For functions with different null semantics (e.g., concat ignores nulls,
// coalesce only returns null if all inputs are null), override this method.
let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
Ok(Arc::new(Field::new(self.name(), return_type, nullable)))
}

#[deprecated(
Expand Down Expand Up @@ -969,6 +973,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
#[cfg(test)]
mod tests {
use super::*;
use arrow::datatypes::Field;
use datafusion_expr_common::signature::Volatility;
use std::hash::DefaultHasher;

Expand All @@ -992,14 +997,85 @@ mod tests {
}

/// Return type of the test UDF: always `Int32`, regardless of arguments.
///
/// A concrete type (rather than `unimplemented!()`) is required so that the
/// default `return_field_from_args` implementation — which calls
/// `return_type` — can be exercised by the nullability tests.
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
    Ok(DataType::Int32)
}

// Evaluation is intentionally a stub: the tests in this module only exercise
// metadata paths (return type / field nullability) and never invoke the UDF.
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
    unimplemented!()
}
}

/// Test that the default `return_field_from_args` implementation correctly
/// computes nullability based on input argument nullability.
///
/// This is a regression test for the bug where UDFs always returned
/// `is_nullable = true` regardless of input nullability.
#[test]
fn test_return_field_nullability_from_args() {
    let udf = ScalarUDF::from(TestScalarUDFImpl {
        name: "test_func",
        field: "test",
        signature: Signature::any(2, Volatility::Immutable),
    });

    // Each case: (nullability of each input field, expected output
    // nullability, failure message). Field names follow the original
    // fixtures: first argument "a", second argument "b".
    let cases: [(&[bool], bool, &str); 4] = [
        (
            &[false, false],
            false,
            "Output should be non-nullable when all inputs are non-nullable",
        ),
        (
            &[false, true],
            true,
            "Output should be nullable when any input is nullable",
        ),
        (
            &[true, true],
            true,
            "Output should be nullable when all inputs are nullable",
        ),
        (
            &[],
            false,
            "Output should be non-nullable when there are no inputs",
        ),
    ];

    for (input_nullability, expected_nullable, message) in cases {
        let arg_fields: Vec<FieldRef> = input_nullability
            .iter()
            .enumerate()
            .map(|(i, &nullable)| {
                Arc::new(Field::new(["a", "b"][i], DataType::Int32, nullable)) as FieldRef
            })
            .collect();
        let args = ReturnFieldArgs {
            arg_fields: &arg_fields,
            scalar_arguments: &[],
        };
        let result = udf.return_field_from_args(args).unwrap();
        if expected_nullable {
            assert!(result.is_nullable(), "{}", message);
        } else {
            assert!(!result.is_nullable(), "{}", message);
        }
    }
}

// PartialEq and Hash must be consistent, and also PartialEq and PartialOrd
// must be consistent, so they are tested together.
#[test]
Expand Down