Skip to content

Commit

Permalink
Minor: Move function signature check to planning stage (#9401)
Browse files Browse the repository at this point in the history
* check signature for udf-based function impl

* clippy
  • Loading branch information
2010YOUY01 authored Mar 1, 2024
1 parent ec67380 commit 9aa01f3
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 55 deletions.
18 changes: 17 additions & 1 deletion datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl ExprSchemable for Expr {
.collect::<Result<Vec<_>>>()?;
match func_def {
ScalarFunctionDefinition::BuiltIn(fun) => {
// verify that input data types is consistent with function's `TypeSignature`
// verify that the function is invoked with the correct number and types of arguments, as defined by its `TypeSignature`
data_types(&arg_data_types, &fun.signature()).map_err(|_| {
plan_datafusion_err!(
"{}",
Expand All @@ -134,9 +134,25 @@ impl ExprSchemable for Expr {
)
})?;

// perform additional validation of the function arguments (needed because
// `TypeSignature` has limited expressiveness), then infer the return type
fun.return_type(&arg_data_types)
}
ScalarFunctionDefinition::UDF(fun) => {
// verify that the function is invoked with the correct number and types of arguments, as defined by its `TypeSignature`
data_types(&arg_data_types, fun.signature()).map_err(|_| {
plan_datafusion_err!(
"{}",
utils::generate_signature_error_msg(
fun.name(),
fun.signature().clone(),
&arg_data_types,
)
)
})?;

// perform additional validation of the function arguments (needed because
// `TypeSignature` has limited expressiveness), then infer the return type
Ok(fun.return_type_from_exprs(args, schema)?)
}
ScalarFunctionDefinition::Name(_) => {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1800,7 +1800,7 @@ pub struct Values {

/// Evaluates an arbitrary list of expressions (essentially a
/// SELECT with an expression list) on its input.
#[derive(Clone, PartialEq, Eq, Hash)]
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
// mark non_exhaustive to encourage use of try_new/new()
#[non_exhaustive]
pub struct Projection {
Expand Down
12 changes: 0 additions & 12 deletions datafusion/functions/src/math/abs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@ use arrow::array::Int32Array;
use arrow::array::Int64Array;
use arrow::array::Int8Array;
use arrow::datatypes::DataType;
use datafusion_common::plan_datafusion_err;
use datafusion_common::{exec_err, not_impl_err};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::utils;
use datafusion_expr::ColumnarValue;

use arrow::array::{ArrayRef, Float32Array, Float64Array};
Expand Down Expand Up @@ -131,16 +129,6 @@ impl ScalarUDFImpl for AbsFunc {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if arg_types.len() != 1 {
return Err(plan_datafusion_err!(
"{}",
utils::generate_signature_error_msg(
self.name(),
self.signature().clone(),
arg_types,
)
));
}
match arg_types[0] {
DataType::Float32 => Ok(DataType::Float32),
DataType::Float64 => Ok(DataType::Float64),
Expand Down
17 changes: 2 additions & 15 deletions datafusion/functions/src/math/acos.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@
use arrow::array::{ArrayRef, Float32Array, Float64Array};
use arrow::datatypes::DataType;
use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result};
use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_expr::ColumnarValue;
use datafusion_expr::{
utils::generate_signature_error_msg, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use std::sync::Arc;

Expand Down Expand Up @@ -58,17 +56,6 @@ impl ScalarUDFImpl for AcosFunc {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if arg_types.len() != 1 {
return Err(plan_datafusion_err!(
"{}",
generate_signature_error_msg(
self.name(),
self.signature().clone(),
arg_types,
)
));
}

let arg_type = &arg_types[0];

match arg_type {
Expand Down
19 changes: 3 additions & 16 deletions datafusion/functions/src/math/nans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@
//! Math function: `isnan()`.
use arrow::datatypes::DataType;
use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result};
use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_expr::ColumnarValue;

use arrow::array::{ArrayRef, BooleanArray, Float32Array, Float64Array};
use datafusion_expr::TypeSignature::*;
use datafusion_expr::{
utils::generate_signature_error_msg, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use std::sync::Arc;

Expand Down Expand Up @@ -58,18 +56,7 @@ impl ScalarUDFImpl for IsNanFunc {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if arg_types.len() != 1 {
return Err(plan_datafusion_err!(
"{}",
generate_signature_error_msg(
self.name(),
self.signature().clone(),
arg_types,
)
));
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Boolean)
}

Expand Down
14 changes: 6 additions & 8 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -873,14 +873,12 @@ mod test {
fn scalar_udf_invalid_input() -> Result<()> {
let empty = empty();
let udf = ScalarUDF::from(TestScalarUDF {}).call(vec![lit("Apple")]);
let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?);
let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, "")
.err()
.unwrap();
assert_eq!(
"type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Float32]) failed.",
err.strip_backtrace()
);
let plan_err = Projection::try_new(vec![udf], empty)
.expect_err("Expected an error due to incorrect function input");

let expected_error = "Error during planning: No function matches the given name and argument types 'TestScalarUDF(Utf8)'. You might need to add explicit type casts.";

assert!(plan_err.to_string().starts_with(expected_error));
Ok(())
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/errors.slt
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ statement error Error during planning: No function matches the given name and ar
SELECT concat();

# error message for wrong function signature (Uniform: t args all from some common types)
statement error DataFusion error: Failed to coerce arguments for NULLIF
statement error DataFusion error: Error during planning: No function matches the given name and argument types 'nullif\(Int64\)'. You might need to add explicit type casts.
SELECT nullif(1);

# error message for wrong function signature (Exact: exact number of args of an exact type)
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/scalar.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1858,7 +1858,7 @@ statement error Error during planning: No function matches the given name and ar
SELECT concat();

# error message for wrong function signature (Uniform: t args all from some common types)
statement error DataFusion error: Failed to coerce arguments for NULLIF
statement error DataFusion error: Error during planning: No function matches the given name and argument types 'nullif\(Int64\)'. You might need to add explicit type casts.
SELECT nullif(1);


Expand Down

0 comments on commit 9aa01f3

Please sign in to comment.