From afdc58cff847618fca56a3df3e081ab711f402f4 Mon Sep 17 00:00:00 2001
From: Yuvraj
Date: Mon, 29 Dec 2025 01:31:12 +0530
Subject: [PATCH] Fix SparkAscii nullability to depend on input nullability

---
 datafusion/spark/src/function/string/ascii.rs | 63 +++++++++++++++++--
 1 file changed, 59 insertions(+), 4 deletions(-)

diff --git a/datafusion/spark/src/function/string/ascii.rs b/datafusion/spark/src/function/string/ascii.rs
index f14a66d4e484d..117881f2f122b 100644
--- a/datafusion/spark/src/function/string/ascii.rs
+++ b/datafusion/spark/src/function/string/ascii.rs
@@ -15,10 +15,14 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::datatypes::DataType;
-use datafusion_common::Result;
+use std::sync::Arc;
+
+use arrow::datatypes::{DataType, Field, FieldRef};
+use datafusion_common::{Result, internal_err};
 use datafusion_expr::ColumnarValue;
-use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
+use datafusion_expr::{
+    ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
 use datafusion_functions::string::ascii::ascii;
 use datafusion_functions::utils::make_scalar_function;
 use std::any::Any;
@@ -62,7 +66,14 @@ impl ScalarUDFImpl for SparkAscii {
     }
 
     fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
-        Ok(DataType::Int32)
+        internal_err!("return_field_from_args should be used instead")
+    }
+
+    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
+        // ascii returns an Int32 value
+        // The result is nullable only if any of the input arguments is nullable
+        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
+        Ok(Arc::new(Field::new("ascii", DataType::Int32, nullable)))
     }
 
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
@@ -73,3 +84,47 @@ impl ScalarUDFImpl for SparkAscii {
         Ok(vec![DataType::Utf8])
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use datafusion_expr::ReturnFieldArgs;
+
+    #[test]
+    fn test_return_field_nullable_input() {
+        let ascii_func = SparkAscii::new();
+        let nullable_field = Arc::new(Field::new("input", DataType::Utf8, true));
+
+        let result = ascii_func
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &[nullable_field],
+                scalar_arguments: &[],
+            })
+            .unwrap();
+
+        assert_eq!(result.data_type(), &DataType::Int32);
+        assert!(
+            result.is_nullable(),
+            "Output should be nullable when input is nullable"
+        );
+    }
+
+    #[test]
+    fn test_return_field_non_nullable_input() {
+        let ascii_func = SparkAscii::new();
+        let non_nullable_field = Arc::new(Field::new("input", DataType::Utf8, false));
+
+        let result = ascii_func
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &[non_nullable_field],
+                scalar_arguments: &[],
+            })
+            .unwrap();
+
+        assert_eq!(result.data_type(), &DataType::Int32);
+        assert!(
+            !result.is_nullable(),
+            "Output should not be nullable when input is not nullable"
+        );
+    }
+}