Commit d3f1c9a

Introduce return_type_from_args for ScalarFunction. (#14094)
* switch func
* fix test
* fix test
* deprecate old
* add try new
* deprecate
* rm deprecate
* replace deprecated func
* cleanup
* combine type and nullable
* fix slowdown
* clippy
* fix take
* fmt
* rm duplicated test
* refactor: remove unused documentation sections from scalar functions
* upd doc
* use scalar value
* fix test
* fix test
* use try_as_str
* refactor: improve error handling for constant string arguments in UDFs
* refactor: enhance error messages for constant string requirements in UDFs
* refactor: streamline argument validation in return_type_from_args for UDFs
* rename and doc
* refactor: add documentation for nullability of scalar arguments in ReturnTypeArgs
* rm test
* refactor: remove unused import of Int32Array in utils tests

Signed-off-by: Jay Zhan <jayzhan211@gmail.com>
1 parent acf66d6 commit d3f1c9a
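
The new API in a nutshell: instead of `return_type_from_exprs(&[Expr], &dyn ExprSchema, &[DataType])`, a `ScalarUDFImpl` can now implement `return_type_from_args(ReturnTypeArgs) -> Result<ReturnInfo>`, where `ReturnTypeArgs` carries the coerced argument types (`arg_types`), any literal arguments as `ScalarValue`s (`scalar_arguments`), and per-argument nullability (`nullables`). Below is a minimal sketch of the hook, not taken from this commit: the `FirstArgPassthrough` UDF is hypothetical, the other required `ScalarUDFImpl` methods are elided, and only names visible in the diffs below (`ReturnTypeArgs`, `ReturnInfo::new_nullable`) are used.

use datafusion_common::{plan_err, Result};
use datafusion_expr::{ReturnInfo, ReturnTypeArgs, ScalarUDFImpl};

// Hypothetical UDF that returns whatever type its first argument has.
#[derive(Debug)]
struct FirstArgPassthrough;

impl ScalarUDFImpl for FirstArgPassthrough {
    // ...as_any, name, signature, return_type, invoke, etc. elided,
    // so this sketch does not compile as-is...

    /// Report the return type (and nullability) from the argument types.
    fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
        let Some(first) = args.arg_types.first() else {
            return plan_err!("expected at least one argument");
        };
        // new_nullable conservatively marks the result nullable; args.nullables
        // holds per-argument nullability if a tighter answer is wanted.
        Ok(ReturnInfo::new_nullable(first.clone()))
    }
}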

File tree

21 files changed (+475 −326 lines)


datafusion/core/tests/fuzz_cases/equivalence/ordering.rs

Lines changed: 8 additions & 8 deletions
@@ -21,9 +21,10 @@ use crate::fuzz_cases::equivalence::utils::{
     is_table_same_after_sort, TestScalarUDF,
 };
 use arrow_schema::SortOptions;
-use datafusion_common::{DFSchema, Result};
+use datafusion_common::Result;
 use datafusion_expr::{Operator, ScalarUDF};
 use datafusion_physical_expr::expressions::{col, BinaryExpr};
+use datafusion_physical_expr::ScalarFunctionExpr;
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
 use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
 use itertools::Itertools;
@@ -103,14 +104,13 @@ fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> {
     let table_data_with_properties =
         generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?;
 
-    let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new());
-    let floor_a = datafusion_physical_expr::udf::create_physical_expr(
-        &test_fun,
-        &[col("a", &test_schema)?],
+    let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
+    let col_a = col("a", &test_schema)?;
+    let floor_a = Arc::new(ScalarFunctionExpr::try_new(
+        Arc::clone(&test_fun),
+        vec![col_a],
         &test_schema,
-        &[],
-        &DFSchema::empty(),
-    )?;
+    )?);
     let a_plus_b = Arc::new(BinaryExpr::new(
         col("a", &test_schema)?,
         Operator::Plus,
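
The same migration is applied in projection.rs and properties.rs below: the removed `datafusion_physical_expr::udf::create_physical_expr` helper (which needed an empty argument list and a `DFSchema`) is replaced by building the expression directly with `ScalarFunctionExpr::try_new`. Pulled out of the diff as a standalone sketch, assuming it runs inside a test returning `Result<()>` with `test_schema` and `TestScalarUDF` in scope as above:

use std::sync::Arc;
use datafusion_expr::ScalarUDF;
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr::{PhysicalExprRef, ScalarFunctionExpr};

// floor(a) as a physical expression, resolved against the Arrow schema alone.
let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
let col_a = col("a", &test_schema)?;
let floor_a: PhysicalExprRef = Arc::new(ScalarFunctionExpr::try_new(
    Arc::clone(&test_fun),
    vec![col_a],
    &test_schema,
)?);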

datafusion/core/tests/fuzz_cases/equivalence/projection.rs

Lines changed: 14 additions & 15 deletions
@@ -20,10 +20,11 @@ use crate::fuzz_cases::equivalence::utils::{
     is_table_same_after_sort, TestScalarUDF,
 };
 use arrow_schema::SortOptions;
-use datafusion_common::{DFSchema, Result};
+use datafusion_common::Result;
 use datafusion_expr::{Operator, ScalarUDF};
 use datafusion_physical_expr::equivalence::ProjectionMapping;
 use datafusion_physical_expr::expressions::{col, BinaryExpr};
+use datafusion_physical_expr::{PhysicalExprRef, ScalarFunctionExpr};
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
 use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
 use itertools::Itertools;
@@ -42,14 +43,13 @@ fn project_orderings_random() -> Result<()> {
     let table_data_with_properties =
         generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?;
     // Floor(a)
-    let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new());
-    let floor_a = datafusion_physical_expr::udf::create_physical_expr(
-        &test_fun,
-        &[col("a", &test_schema)?],
+    let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
+    let col_a = col("a", &test_schema)?;
+    let floor_a = Arc::new(ScalarFunctionExpr::try_new(
+        Arc::clone(&test_fun),
+        vec![col_a],
         &test_schema,
-        &[],
-        &DFSchema::empty(),
-    )?;
+    )?);
     // a + b
     let a_plus_b = Arc::new(BinaryExpr::new(
         col("a", &test_schema)?,
@@ -120,14 +120,13 @@ fn ordering_satisfy_after_projection_random() -> Result<()> {
     let table_data_with_properties =
         generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?;
     // Floor(a)
-    let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new());
-    let floor_a = datafusion_physical_expr::udf::create_physical_expr(
-        &test_fun,
-        &[col("a", &test_schema)?],
+    let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
+    let col_a = col("a", &test_schema)?;
+    let floor_a = Arc::new(ScalarFunctionExpr::try_new(
+        Arc::clone(&test_fun),
+        vec![col_a],
         &test_schema,
-        &[],
-        &DFSchema::empty(),
-    )?;
+    )?) as PhysicalExprRef;
     // a + b
     let a_plus_b = Arc::new(BinaryExpr::new(
         col("a", &test_schema)?,

datafusion/core/tests/fuzz_cases/equivalence/properties.rs

Lines changed: 9 additions & 8 deletions
@@ -19,9 +19,10 @@ use crate::fuzz_cases::equivalence::utils::{
     create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort,
     TestScalarUDF,
 };
-use datafusion_common::{DFSchema, Result};
+use datafusion_common::Result;
 use datafusion_expr::{Operator, ScalarUDF};
 use datafusion_physical_expr::expressions::{col, BinaryExpr};
+use datafusion_physical_expr::{PhysicalExprRef, ScalarFunctionExpr};
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
 use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
 use itertools::Itertools;
@@ -40,14 +41,14 @@ fn test_find_longest_permutation_random() -> Result<()> {
     let table_data_with_properties =
         generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?;
 
-    let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new());
-    let floor_a = datafusion_physical_expr::udf::create_physical_expr(
-        &test_fun,
-        &[col("a", &test_schema)?],
+    let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
+    let col_a = col("a", &test_schema)?;
+    let floor_a = Arc::new(ScalarFunctionExpr::try_new(
+        Arc::clone(&test_fun),
+        vec![col_a],
         &test_schema,
-        &[],
-        &DFSchema::empty(),
-    )?;
+    )?) as PhysicalExprRef;
+
     let a_plus_b = Arc::new(BinaryExpr::new(
         col("a", &test_schema)?,
         Operator::Plus,

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 27 additions & 23 deletions
@@ -34,13 +34,12 @@ use datafusion_common::cast::{as_float64_array, as_int32_array};
 use datafusion_common::tree_node::{Transformed, TreeNode};
 use datafusion_common::{
     assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, internal_err,
-    not_impl_err, plan_err, DFSchema, DataFusionError, ExprSchema, HashMap, Result,
-    ScalarValue,
+    not_impl_err, plan_err, DFSchema, DataFusionError, HashMap, Result, ScalarValue,
 };
 use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
 use datafusion_expr::{
-    Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, ExprSchemable,
-    LogicalPlanBuilder, OperateFunctionArg, ScalarUDF, ScalarUDFImpl, Signature,
+    Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder,
+    OperateFunctionArg, ReturnInfo, ReturnTypeArgs, ScalarUDF, ScalarUDFImpl, Signature,
     Volatility,
 };
 use datafusion_functions_nested::range::range_udf;
@@ -819,32 +818,36 @@ impl ScalarUDFImpl for TakeUDF {
     ///
     /// 1. If the third argument is '0', return the type of the first argument
     /// 2. If the third argument is '1', return the type of the second argument
-    fn return_type_from_exprs(
-        &self,
-        arg_exprs: &[Expr],
-        schema: &dyn ExprSchema,
-        _arg_data_types: &[DataType],
-    ) -> Result<DataType> {
-        if arg_exprs.len() != 3 {
-            return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len());
+    fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
+        if args.arg_types.len() != 3 {
+            return plan_err!("Expected 3 arguments, got {}.", args.arg_types.len());
         }
 
-        let take_idx = if let Some(Expr::Literal(ScalarValue::Int64(Some(idx)))) =
-            arg_exprs.get(2)
-        {
-            if *idx == 0 || *idx == 1 {
-                *idx as usize
+        let take_idx = if let Some(take_idx) = args.scalar_arguments.get(2) {
+            // This is for test only, safe to unwrap
+            let take_idx = take_idx
+                .unwrap()
+                .try_as_str()
+                .unwrap()
+                .unwrap()
+                .parse::<usize>()
+                .unwrap();
+
+            if take_idx == 0 || take_idx == 1 {
+                take_idx
             } else {
-                return plan_err!("The third argument must be 0 or 1, got: {idx}");
+                return plan_err!("The third argument must be 0 or 1, got: {take_idx}");
            }
        } else {
            return plan_err!(
                "The third argument must be a literal of type int64, but got {:?}",
-                arg_exprs.get(2)
+                args.scalar_arguments.get(2)
            );
        };
 
-        arg_exprs.get(take_idx).unwrap().get_type(schema)
+        Ok(ReturnInfo::new_nullable(
+            args.arg_types[take_idx].to_owned(),
+        ))
     }
 
     // The actual implementation
@@ -854,7 +857,8 @@ impl ScalarUDFImpl for TakeUDF {
         _number_rows: usize,
     ) -> Result<ColumnarValue> {
         let take_idx = match &args[2] {
-            ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v as usize,
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) if v == "0" => 0,
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) if v == "1" => 1,
             _ => unreachable!(),
         };
         match &args[take_idx] {
@@ -874,9 +878,9 @@ async fn verify_udf_return_type() -> Result<()> {
     // take(smallint_col, double_col, 1) as take1
     // FROM alltypes_plain;
     let exprs = vec![
-        take.call(vec![col("smallint_col"), col("double_col"), lit(0_i64)])
+        take.call(vec![col("smallint_col"), col("double_col"), lit("0")])
             .alias("take0"),
-        take.call(vec![col("smallint_col"), col("double_col"), lit(1_i64)])
+        take.call(vec![col("smallint_col"), col("double_col"), lit("1")])
             .alias("take1"),
     ];
 
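
The test keeps the unwraps because the planner guarantees a literal third argument there. A UDF that cannot assume that can do the same lookup against `scalar_arguments` with plan errors instead; here is a sketch of that pattern (the `resolve_take_idx` helper is hypothetical, and it assumes `ScalarValue::try_as_str` returns `Option<Option<&str>>`, as the double `unwrap` above suggests):

use datafusion_common::{plan_err, Result, ScalarValue};

/// Resolve the constant third argument ("0" or "1") without unwrapping.
/// `scalar_arguments[i]` is `Some(&ScalarValue)` only when argument `i`
/// was a literal at planning time.
fn resolve_take_idx(scalar_arguments: &[Option<&ScalarValue>]) -> Result<usize> {
    match scalar_arguments.get(2).and_then(|sv| *sv) {
        Some(sv) => match sv.try_as_str() {
            // Outer None: not a string type; inner None: a NULL literal.
            Some(Some("0")) => Ok(0),
            Some(Some("1")) => Ok(1),
            _ => plan_err!("The third argument must be the string '0' or '1'"),
        },
        None => plan_err!("The third argument must be a constant string"),
    }
}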

datafusion/expr/src/expr_schema.rs

Lines changed: 48 additions & 28 deletions
@@ -24,6 +24,7 @@ use crate::type_coercion::binary::get_result_type;
 use crate::type_coercion::functions::{
     data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf,
 };
+use crate::udf::ReturnTypeArgs;
 use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition};
 use arrow::compute::can_cast_types;
 use arrow::datatypes::{DataType, Field};
@@ -145,32 +146,9 @@ impl ExprSchemable for Expr {
                    }
                }
            }
-            Expr::ScalarFunction(ScalarFunction { func, args }) => {
-                let arg_data_types = args
-                    .iter()
-                    .map(|e| e.get_type(schema))
-                    .collect::<Result<Vec<_>>>()?;
-
-                // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
-                let new_data_types = data_types_with_scalar_udf(&arg_data_types, func)
-                    .map_err(|err| {
-                        plan_datafusion_err!(
-                            "{} {}",
-                            match err {
-                                DataFusionError::Plan(msg) => msg,
-                                err => err.to_string(),
-                            },
-                            utils::generate_signature_error_msg(
-                                func.name(),
-                                func.signature().clone(),
-                                &arg_data_types,
-                            )
-                        )
-                    })?;
-
-                // Perform additional function arguments validation (due to limited
-                // expressiveness of `TypeSignature`), then infer return type
-                Ok(func.return_type_from_exprs(args, schema, &new_data_types)?)
+            Expr::ScalarFunction(_func) => {
+                let (return_type, _) = self.data_type_and_nullable(schema)?;
+                Ok(return_type)
             }
             Expr::WindowFunction(window_function) => self
                 .data_type_and_nullable_with_window_function(schema, window_function)
@@ -303,8 +281,9 @@ impl ExprSchemable for Expr {
                }
            }
            Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema),
-            Expr::ScalarFunction(ScalarFunction { func, args }) => {
-                Ok(func.is_nullable(args, input_schema))
+            Expr::ScalarFunction(_func) => {
+                let (_, nullable) = self.data_type_and_nullable(input_schema)?;
+                Ok(nullable)
            }
            Expr::AggregateFunction(AggregateFunction { func, .. }) => {
                Ok(func.is_nullable())
@@ -415,6 +394,47 @@
            Expr::WindowFunction(window_function) => {
                self.data_type_and_nullable_with_window_function(schema, window_function)
            }
+            Expr::ScalarFunction(ScalarFunction { func, args }) => {
+                let (arg_types, nullables): (Vec<DataType>, Vec<bool>) = args
+                    .iter()
+                    .map(|e| e.data_type_and_nullable(schema))
+                    .collect::<Result<Vec<_>>>()?
+                    .into_iter()
+                    .unzip();
+                // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
+                let new_data_types = data_types_with_scalar_udf(&arg_types, func)
+                    .map_err(|err| {
+                        plan_datafusion_err!(
+                            "{} {}",
+                            match err {
+                                DataFusionError::Plan(msg) => msg,
+                                err => err.to_string(),
+                            },
+                            utils::generate_signature_error_msg(
+                                func.name(),
+                                func.signature().clone(),
+                                &arg_types,
+                            )
+                        )
+                    })?;
+
+                let arguments = args
+                    .iter()
+                    .map(|e| match e {
+                        Expr::Literal(sv) => Some(sv),
+                        _ => None,
+                    })
+                    .collect::<Vec<_>>();
+                let args = ReturnTypeArgs {
+                    arg_types: &new_data_types,
+                    scalar_arguments: &arguments,
+                    nullables: &nullables,
+                };
+
+                let (return_type, nullable) =
+                    func.return_type_from_args(args)?.into_parts();
+                Ok((return_type, nullable))
+            }
            _ => Ok((self.get_type(schema)?, self.nullable(schema)?)),
        }
    }
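
With get_type and nullable both routed through data_type_and_nullable, a scalar function's return type and nullability are now computed in a single pass: one signature check, one ReturnTypeArgs handed to the UDF. A small sketch of the caller-facing effect, assuming an Expr and a DFSchema are at hand (the describe helper is hypothetical):

use datafusion_common::{DFSchema, Result};
use datafusion_expr::{Expr, ExprSchemable};

/// One call now yields both pieces that previously required separate
/// get_type() and nullable() calls (each re-checking the UDF signature).
fn describe(expr: &Expr, schema: &DFSchema) -> Result<()> {
    let (data_type, nullable) = expr.data_type_and_nullable(schema)?;
    println!("{expr} evaluates to {data_type} (nullable: {nullable})");
    Ok(())
}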

datafusion/expr/src/lib.rs

Lines changed: 4 additions & 1 deletion
@@ -93,7 +93,10 @@ pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
 pub use udaf::{
     aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs,
 };
-pub use udf::{scalar_doc_sections, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
+pub use udf::{
+    scalar_doc_sections, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF,
+    ScalarUDFImpl,
+};
 pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl};
 pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
