diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index abc2780e002f..a0ad2e3f0b75 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -23,7 +23,7 @@ use arrow::array::{ ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray, StringArrayType, }; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -79,8 +79,17 @@ impl ScalarUDFImpl for StrposFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_int_type(&arg_types[0], "strpos/instr/position") + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_type_from_args should be used instead") + } + + fn return_type_from_args( + &self, + args: datafusion_expr::ReturnTypeArgs, + ) -> Result { + utf8_to_int_type(&args.arg_types[0], "strpos/instr/position").map(|data_type| { + datafusion_expr::ReturnInfo::new(data_type, args.nullables.iter().any(|x| *x)) + }) } fn invoke_with_args( @@ -201,6 +210,7 @@ mod tests { use arrow::array::{Array, Int32Array, Int64Array}; use arrow::datatypes::DataType::{Int32, Int64}; + use arrow::datatypes::DataType; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -288,4 +298,27 @@ mod tests { test_strpos!("", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array); test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View LargeUtf8 i32 Int32 Int32Array); } + + #[test] + fn nullable_return_type() { + fn get_nullable(string_array_nullable: bool, substring_nullable: bool) -> bool { + let strpos = StrposFunc::new(); + let args = datafusion_expr::ReturnTypeArgs { + arg_types: &[DataType::Utf8, DataType::Utf8], + nullables: &[string_array_nullable, substring_nullable], + scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>], + }; + + let (_, nullable) = strpos.return_type_from_args(args).unwrap().into_parts(); + + nullable + } + + assert!(!get_nullable(false, false)); + + // If any of the arguments is nullable, the result is nullable + assert!(get_nullable(true, false)); + assert!(get_nullable(false, true)); + assert!(get_nullable(true, true)); + } }