From 9aa01f35cf512e03196799c44cfecf7fe8af90e2 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Fri, 1 Mar 2024 15:22:47 +0800 Subject: [PATCH] Minor: Move function signature check to planning stage (#9401) * check signature for udf-based function impl * clippy --- datafusion/expr/src/expr_schema.rs | 18 +++++++++++++++++- datafusion/expr/src/logical_plan/plan.rs | 2 +- datafusion/functions/src/math/abs.rs | 12 ------------ datafusion/functions/src/math/acos.rs | 17 ++--------------- datafusion/functions/src/math/nans.rs | 19 +++---------------- .../optimizer/src/analyzer/type_coercion.rs | 14 ++++++-------- datafusion/sqllogictest/test_files/errors.slt | 2 +- datafusion/sqllogictest/test_files/scalar.slt | 2 +- 8 files changed, 31 insertions(+), 55 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index a453730a0e71..026627a05e62 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -122,7 +122,7 @@ impl ExprSchemable for Expr { .collect::>>()?; match func_def { ScalarFunctionDefinition::BuiltIn(fun) => { - // verify that input data types is consistent with function's `TypeSignature` + // verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` data_types(&arg_data_types, &fun.signature()).map_err(|_| { plan_datafusion_err!( "{}", @@ -134,9 +134,25 @@ impl ExprSchemable for Expr { ) })?; + // perform additional function arguments validation (due to limited + // expressiveness of `TypeSignature`), then infer return type fun.return_type(&arg_data_types) } ScalarFunctionDefinition::UDF(fun) => { + // verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` + data_types(&arg_data_types, fun.signature()).map_err(|_| { + plan_datafusion_err!( + "{}", + utils::generate_signature_error_msg( + fun.name(), + fun.signature().clone(), + &arg_data_types, + ) + ) + })?; + + // perform additional function arguments validation (due to limited + // expressiveness of `TypeSignature`), then infer return type Ok(fun.return_type_from_exprs(args, schema)?) } ScalarFunctionDefinition::Name(_) => { diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index aa5dff25efd8..5cce8f9cd45c 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1800,7 +1800,7 @@ pub struct Values { /// Evaluates an arbitrary list of expressions (essentially a /// SELECT with an expression list) on its input. -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] // mark non_exhaustive to encourage use of try_new/new() #[non_exhaustive] pub struct Projection { diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index 34d7a29e222e..8aa48460ff69 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -24,10 +24,8 @@ use arrow::array::Int32Array; use arrow::array::Int64Array; use arrow::array::Int8Array; use arrow::datatypes::DataType; -use datafusion_common::plan_datafusion_err; use datafusion_common::{exec_err, not_impl_err}; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::utils; use datafusion_expr::ColumnarValue; use arrow::array::{ArrayRef, Float32Array, Float64Array}; @@ -131,16 +129,6 @@ impl ScalarUDFImpl for AbsFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types.len() != 1 { - return Err(plan_datafusion_err!( - "{}", - utils::generate_signature_error_msg( - self.name(), - self.signature().clone(), - arg_types, - ) - )); - } match arg_types[0] { DataType::Float32 => Ok(DataType::Float32), DataType::Float64 => Ok(DataType::Float64), diff --git a/datafusion/functions/src/math/acos.rs b/datafusion/functions/src/math/acos.rs index 22dfd37a0159..f6a440beec0a 100644 --- a/datafusion/functions/src/math/acos.rs +++ b/datafusion/functions/src/math/acos.rs @@ -19,11 +19,9 @@ use arrow::array::{ArrayRef, Float32Array, Float64Array}; use arrow::datatypes::DataType; -use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result}; +use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::ColumnarValue; -use datafusion_expr::{ - utils::generate_signature_error_msg, ScalarUDFImpl, Signature, Volatility, -}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; @@ -58,17 +56,6 @@ impl ScalarUDFImpl for AcosFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types.len() != 1 { - return Err(plan_datafusion_err!( - "{}", - generate_signature_error_msg( - self.name(), - self.signature().clone(), - arg_types, - ) - )); - } - let arg_type = &arg_types[0]; match arg_type { diff --git a/datafusion/functions/src/math/nans.rs b/datafusion/functions/src/math/nans.rs index 8abbe7e7ab83..3f3d7d197c33 100644 --- a/datafusion/functions/src/math/nans.rs +++ b/datafusion/functions/src/math/nans.rs @@ -18,14 +18,12 @@ //! Math function: `isnan()`. use arrow::datatypes::DataType; -use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result}; +use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::ColumnarValue; use arrow::array::{ArrayRef, BooleanArray, Float32Array, Float64Array}; use datafusion_expr::TypeSignature::*; -use datafusion_expr::{ - utils::generate_signature_error_msg, ScalarUDFImpl, Signature, Volatility, -}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; @@ -58,18 +56,7 @@ impl ScalarUDFImpl for IsNanFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types.len() != 1 { - return Err(plan_datafusion_err!( - "{}", - generate_signature_error_msg( - self.name(), - self.signature().clone(), - arg_types, - ) - )); - } - + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Boolean) } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 3a43e3cd7c20..d469e0f8ce0d 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -873,14 +873,12 @@ mod test { fn scalar_udf_invalid_input() -> Result<()> { let empty = empty(); let udf = ScalarUDF::from(TestScalarUDF {}).call(vec![lit("Apple")]); - let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?); - let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, "") - .err() - .unwrap(); - assert_eq!( - "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Float32]) failed.", - err.strip_backtrace() - ); + let plan_err = Projection::try_new(vec![udf], empty) + .expect_err("Expected an error due to incorrect function input"); + + let expected_error = "Error during planning: No function matches the given name and argument types 'TestScalarUDF(Utf8)'. You might need to add explicit type casts."; + + assert!(plan_err.to_string().starts_with(expected_error)); Ok(()) } diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index 3a23f3615d08..ab281eac31f5 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -84,7 +84,7 @@ statement error Error during planning: No function matches the given name and ar SELECT concat(); # error message for wrong function signature (Uniform: t args all from some common types) -statement error DataFusion error: Failed to coerce arguments for NULLIF +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'nullif\(Int64\)'. You might need to add explicit type casts. SELECT nullif(1); # error message for wrong function signature (Exact: exact number of args of an exact type) diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 45334f5b402f..a3e97d6a7d82 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1858,7 +1858,7 @@ statement error Error during planning: No function matches the given name and ar SELECT concat(); # error message for wrong function signature (Uniform: t args all from some common types) -statement error DataFusion error: Failed to coerce arguments for NULLIF +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'nullif\(Int64\)'. You might need to add explicit type casts. SELECT nullif(1);