diff --git a/datafusion/spark/src/function/string/ascii.rs b/datafusion/spark/src/function/string/ascii.rs index 50f870ee29b96..44e3501b86adb 100644 --- a/datafusion/spark/src/function/string/ascii.rs +++ b/datafusion/spark/src/function/string/ascii.rs @@ -15,13 +15,14 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::DataType; -use datafusion_common::Result; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::types::{NativeType, logical_string}; -use datafusion_expr::ColumnarValue; +use datafusion_common::{Result, internal_err}; use datafusion_expr::{ - Coercion, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignatureClass, - Volatility, + Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignatureClass, Volatility, }; use datafusion_functions::string::ascii::ascii; use datafusion_functions::utils::make_scalar_function; @@ -75,10 +76,61 @@ impl ScalarUDFImpl for SparkAscii { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Int32) + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + // ascii returns an Int32 value + // The result is nullable only if any of the input arguments is nullable + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + Ok(Arc::new(Field::new("ascii", DataType::Int32, nullable))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(ascii, vec![])(&args.args) } } + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_expr::ReturnFieldArgs; + + #[test] + fn test_return_field_nullable_input() { + let ascii_func = SparkAscii::new(); + let nullable_field = Arc::new(Field::new("input", DataType::Utf8, true)); + + let result = ascii_func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[nullable_field], + scalar_arguments: &[], + }) + .unwrap(); + + assert_eq!(result.data_type(), &DataType::Int32); + assert!( + result.is_nullable(), + "Output should be nullable when input is nullable" + ); + } + + #[test] + fn test_return_field_non_nullable_input() { + let ascii_func = SparkAscii::new(); + let non_nullable_field = Arc::new(Field::new("input", DataType::Utf8, false)); + + let result = ascii_func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[non_nullable_field], + scalar_arguments: &[], + }) + .unwrap(); + + assert_eq!(result.data_type(), &DataType::Int32); + assert!( + !result.is_nullable(), + "Output should not be nullable when input is not nullable" + ); + } +}