Skip to content

Commit 62f5cd6

Browse files
authored
feat: support named variables & defaults for CREATE FUNCTION (#18450)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #17887. ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> See linked issue above ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> - Add validation when planning `CreateFunction` - Enforce consistent parameter style (positional or named) - Non-default params cannot follow default params - `CreateFunction` parameter names are now preserved from the parse tree - If we encounter a named parameter when constructing a `Placeholder` - We try to rewrite this to a positional parameter from the available param types - If no matching param type is found, report an error - Update `ScalarFunctionWrapper` to handle defaults - Preserve the parsed defaults - Generate all valid signatures for all possible combinations of arguments - Fall back to default expr when no matching argument is provided ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> Yes, see added / adjusted unit tests ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> Yes, if the approach here is acceptable we should update the SQL UDF examples in `datafusion-examples/examples/function_factory.rs` (TODO). Also, note that due to ambiguity between `PREPARE` and `CREATE FUNCTION` param context one error message now has reduced fidelity. > Invalid placeholder, not a number: $foo This can now be triggered either by using named params in a prepared statement or when referencing an undefined named param in a SQL UDF. New message: > Unknown placeholder: $foo Finally, there are two additional user-facing errors when planning `CreateFunction`: - When named / positional parameter styles are mixed - When non-default arguments follow default arguments <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent 92a3a33 commit 62f5cd6

File tree

5 files changed

+276
-21
lines changed

5 files changed

+276
-21
lines changed

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 220 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ use datafusion_expr::{
4747
LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs,
4848
ScalarUDF, ScalarUDFImpl, Signature, Volatility,
4949
};
50+
use datafusion_expr_common::signature::TypeSignature;
5051
use datafusion_functions_nested::range::range_udf;
5152
use parking_lot::Mutex;
5253
use regex::Regex;
@@ -945,6 +946,7 @@ struct ScalarFunctionWrapper {
945946
expr: Expr,
946947
signature: Signature,
947948
return_type: DataType,
949+
defaults: Vec<Option<Expr>>,
948950
}
949951

950952
impl ScalarUDFImpl for ScalarFunctionWrapper {
@@ -973,27 +975,39 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
973975
args: Vec<Expr>,
974976
_info: &dyn SimplifyInfo,
975977
) -> Result<ExprSimplifyResult> {
976-
let replacement = Self::replacement(&self.expr, &args)?;
978+
let replacement = Self::replacement(&self.expr, &args, &self.defaults)?;
977979

978980
Ok(ExprSimplifyResult::Simplified(replacement))
979981
}
980982
}
981983

982984
impl ScalarFunctionWrapper {
983985
// replaces placeholders with actual arguments
984-
fn replacement(expr: &Expr, args: &[Expr]) -> Result<Expr> {
986+
fn replacement(
987+
expr: &Expr,
988+
args: &[Expr],
989+
defaults: &[Option<Expr>],
990+
) -> Result<Expr> {
985991
let result = expr.clone().transform(|e| {
986992
let r = match e {
987993
Expr::Placeholder(placeholder) => {
988994
let placeholder_position =
989995
Self::parse_placeholder_identifier(&placeholder.id)?;
990996
if placeholder_position < args.len() {
991997
Transformed::yes(args[placeholder_position].clone())
992-
} else {
998+
} else if placeholder_position >= defaults.len() {
993999
exec_err!(
994-
"Function argument {} not provided, argument missing!",
1000+
"Invalid placeholder, out of range: {}",
9951001
placeholder.id
9961002
)?
1003+
} else {
1004+
match defaults[placeholder_position] {
1005+
Some(ref default) => Transformed::yes(default.clone()),
1006+
None => exec_err!(
1007+
"Function argument {} not provided, argument missing!",
1008+
placeholder.id
1009+
)?,
1010+
}
9971011
}
9981012
}
9991013
_ => Transformed::no(e),
@@ -1021,6 +1035,32 @@ impl TryFrom<CreateFunction> for ScalarFunctionWrapper {
10211035
type Error = DataFusionError;
10221036

10231037
fn try_from(definition: CreateFunction) -> std::result::Result<Self, Self::Error> {
1038+
let args = definition.args.unwrap_or_default();
1039+
let defaults: Vec<Option<Expr>> =
1040+
args.iter().map(|a| a.default_expr.clone()).collect();
1041+
let signature: Signature = match defaults.iter().position(|v| v.is_some()) {
1042+
Some(pos) => {
1043+
let mut type_signatures: Vec<TypeSignature> = vec![];
1044+
// Generate all valid signatures
1045+
for n in pos..defaults.len() + 1 {
1046+
if n == 0 {
1047+
type_signatures.push(TypeSignature::Nullary)
1048+
} else {
1049+
type_signatures.push(TypeSignature::Exact(
1050+
args.iter().take(n).map(|a| a.data_type.clone()).collect(),
1051+
))
1052+
}
1053+
}
1054+
Signature::one_of(
1055+
type_signatures,
1056+
definition.params.behavior.unwrap_or(Volatility::Volatile),
1057+
)
1058+
}
1059+
None => Signature::exact(
1060+
args.iter().map(|a| a.data_type.clone()).collect(),
1061+
definition.params.behavior.unwrap_or(Volatility::Volatile),
1062+
),
1063+
};
10241064
Ok(Self {
10251065
name: definition.name,
10261066
expr: definition
@@ -1030,15 +1070,8 @@ impl TryFrom<CreateFunction> for ScalarFunctionWrapper {
10301070
return_type: definition
10311071
.return_type
10321072
.expect("Return type has to be defined!"),
1033-
signature: Signature::exact(
1034-
definition
1035-
.args
1036-
.unwrap_or_default()
1037-
.into_iter()
1038-
.map(|a| a.data_type)
1039-
.collect(),
1040-
definition.params.behavior.unwrap_or(Volatility::Volatile),
1041-
),
1073+
signature,
1074+
defaults,
10421075
})
10431076
}
10441077
}
@@ -1109,6 +1142,180 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> {
11091142
"#;
11101143
assert!(ctx.sql(bad_definition_sql).await.is_err());
11111144

1145+
// FIXME: Definitions with invalid placeholders are allowed, fail at runtime
1146+
let bad_expression_sql = r#"
1147+
CREATE FUNCTION better_add(DOUBLE, DOUBLE)
1148+
RETURNS DOUBLE
1149+
RETURN $1 + $3
1150+
"#;
1151+
assert!(ctx.sql(bad_expression_sql).await.is_ok());
1152+
1153+
let err = ctx
1154+
.sql("select better_add(2.0, 2.0)")
1155+
.await?
1156+
.collect()
1157+
.await
1158+
.expect_err("unknown placeholder");
1159+
let expected = "Optimizer rule 'simplify_expressions' failed\ncaused by\nExecution error: Invalid placeholder, out of range: $3";
1160+
assert!(expected.starts_with(&err.strip_backtrace()));
1161+
1162+
Ok(())
1163+
}
1164+
1165+
#[tokio::test]
1166+
async fn create_scalar_function_from_sql_statement_named_arguments() -> Result<()> {
1167+
let function_factory = Arc::new(CustomFunctionFactory::default());
1168+
let ctx = SessionContext::new().with_function_factory(function_factory.clone());
1169+
1170+
let sql = r#"
1171+
CREATE FUNCTION better_add(a DOUBLE, b DOUBLE)
1172+
RETURNS DOUBLE
1173+
RETURN $a + $b
1174+
"#;
1175+
1176+
assert!(ctx.sql(sql).await.is_ok());
1177+
1178+
let result = ctx
1179+
.sql("select better_add(2.0, 2.0)")
1180+
.await?
1181+
.collect()
1182+
.await?;
1183+
1184+
assert_batches_eq!(
1185+
&[
1186+
"+-----------------------------------+",
1187+
"| better_add(Float64(2),Float64(2)) |",
1188+
"+-----------------------------------+",
1189+
"| 4.0 |",
1190+
"+-----------------------------------+",
1191+
],
1192+
&result
1193+
);
1194+
1195+
// cannot mix named and positional style
1196+
let bad_expression_sql = r#"
1197+
CREATE FUNCTION bad_expression_fun(DOUBLE, b DOUBLE)
1198+
RETURNS DOUBLE
1199+
RETURN $1 + $b
1200+
"#;
1201+
let err = ctx
1202+
.sql(bad_expression_sql)
1203+
.await
1204+
.expect_err("cannot mix named and positional style");
1205+
let expected = "Error during planning: All function arguments must use either named or positional style.";
1206+
assert!(expected.starts_with(&err.strip_backtrace()));
1207+
1208+
Ok(())
1209+
}
1210+
1211+
#[tokio::test]
1212+
async fn create_scalar_function_from_sql_statement_default_arguments() -> Result<()> {
1213+
let function_factory = Arc::new(CustomFunctionFactory::default());
1214+
let ctx = SessionContext::new().with_function_factory(function_factory.clone());
1215+
1216+
let sql = r#"
1217+
CREATE FUNCTION better_add(a DOUBLE = 2.0, b DOUBLE = 2.0)
1218+
RETURNS DOUBLE
1219+
RETURN $a + $b
1220+
"#;
1221+
1222+
assert!(ctx.sql(sql).await.is_ok());
1223+
1224+
// Check all function arity supported
1225+
let result = ctx.sql("select better_add()").await?.collect().await?;
1226+
1227+
assert_batches_eq!(
1228+
&[
1229+
"+--------------+",
1230+
"| better_add() |",
1231+
"+--------------+",
1232+
"| 4.0 |",
1233+
"+--------------+",
1234+
],
1235+
&result
1236+
);
1237+
1238+
let result = ctx.sql("select better_add(2.0)").await?.collect().await?;
1239+
1240+
assert_batches_eq!(
1241+
&[
1242+
"+------------------------+",
1243+
"| better_add(Float64(2)) |",
1244+
"+------------------------+",
1245+
"| 4.0 |",
1246+
"+------------------------+",
1247+
],
1248+
&result
1249+
);
1250+
1251+
let result = ctx
1252+
.sql("select better_add(2.0, 2.0)")
1253+
.await?
1254+
.collect()
1255+
.await?;
1256+
1257+
assert_batches_eq!(
1258+
&[
1259+
"+-----------------------------------+",
1260+
"| better_add(Float64(2),Float64(2)) |",
1261+
"+-----------------------------------+",
1262+
"| 4.0 |",
1263+
"+-----------------------------------+",
1264+
],
1265+
&result
1266+
);
1267+
1268+
assert!(ctx.sql("select better_add(2.0, 2.0, 2.0)").await.is_err());
1269+
assert!(ctx.sql("drop function better_add").await.is_ok());
1270+
1271+
// works with positional style
1272+
let sql = r#"
1273+
CREATE FUNCTION better_add(DOUBLE, DOUBLE = 2.0)
1274+
RETURNS DOUBLE
1275+
RETURN $1 + $2
1276+
"#;
1277+
assert!(ctx.sql(sql).await.is_ok());
1278+
1279+
assert!(ctx.sql("select better_add()").await.is_err());
1280+
let result = ctx.sql("select better_add(2.0)").await?.collect().await?;
1281+
assert_batches_eq!(
1282+
&[
1283+
"+------------------------+",
1284+
"| better_add(Float64(2)) |",
1285+
"+------------------------+",
1286+
"| 4.0 |",
1287+
"+------------------------+",
1288+
],
1289+
&result
1290+
);
1291+
1292+
// non-default argument cannot follow default argument
1293+
let bad_expression_sql = r#"
1294+
CREATE FUNCTION bad_expression_fun(a DOUBLE = 2.0, b DOUBLE)
1295+
RETURNS DOUBLE
1296+
RETURN $a + $b
1297+
"#;
1298+
let err = ctx
1299+
.sql(bad_expression_sql)
1300+
.await
1301+
.expect_err("non-default argument cannot follow default argument");
1302+
let expected =
1303+
"Error during planning: Non-default arguments cannot follow default arguments.";
1304+
assert!(expected.starts_with(&err.strip_backtrace()));
1305+
1306+
// FIXME: The `DEFAULT` syntax does not work with positional params
1307+
let bad_expression_sql = r#"
1308+
CREATE FUNCTION bad_expression_fun(DOUBLE, DOUBLE DEFAULT 2.0)
1309+
RETURNS DOUBLE
1310+
RETURN $1 + $2
1311+
"#;
1312+
let err = ctx
1313+
.sql(bad_expression_sql)
1314+
.await
1315+
.expect_err("sqlparser error");
1316+
let expected =
1317+
"SQL error: ParserError(\"Expected: ), found: 2.0 at Line: 2, Column: 63\")";
1318+
assert!(expected.starts_with(&err.strip_backtrace()));
11121319
Ok(())
11131320
}
11141321

datafusion/sql/src/expr/value.rs

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,13 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
104104
}
105105

106106
/// Create a placeholder expression
107-
/// This is the same as Postgres's prepare statement syntax in which a placeholder starts with `$` sign and then
108-
/// number 1, 2, ... etc. For example, `$1` is the first placeholder; $2 is the second one and so on.
107+
/// Both named (`$foo`) and positional (`$1`, `$2`, ...) placeholder styles are supported.
109108
fn create_placeholder_expr(
110109
param: String,
111110
param_data_types: &[FieldRef],
112111
) -> Result<Expr> {
113-
// Parse the placeholder as a number because it is the only support from sqlparser and postgres
112+
// Try to parse the placeholder as a number. If the placeholder does not have a valid
113+
// positional value, assume we have a named placeholder.
114114
let index = param[1..].parse::<usize>();
115115
let idx = match index {
116116
Ok(0) => {
@@ -123,12 +123,24 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
123123
return if param_data_types.is_empty() {
124124
Ok(Expr::Placeholder(Placeholder::new_with_field(param, None)))
125125
} else {
126-
// when PREPARE Statement, param_data_types length is always 0
127-
plan_err!("Invalid placeholder, not a number: {param}")
126+
// FIXME: This branch is shared by params from PREPARE and CREATE FUNCTION, but
127+
// only CREATE FUNCTION currently supports named params. For now, we rewrite
128+
// these to positional params.
129+
let named_param_pos = param_data_types
130+
.iter()
131+
.position(|v| v.name() == &param[1..]);
132+
match named_param_pos {
133+
Some(pos) => Ok(Expr::Placeholder(Placeholder::new_with_field(
134+
format!("${}", pos + 1),
135+
param_data_types.get(pos).cloned(),
136+
))),
137+
None => plan_err!("Unknown placeholder: {param}"),
138+
}
128139
};
129140
}
130141
};
131142
// Check if the placeholder is in the parameter list
143+
// FIXME: In the CREATE FUNCTION branch, param_type = None should raise an error
132144
let param_type = param_data_types.get(idx);
133145
// Data type of the parameter
134146
debug!("type of param {param} param_data_types[idx]: {param_type:?}");

datafusion/sql/src/statement.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1222,6 +1222,28 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
12221222
}
12231223
None => None,
12241224
};
1225+
// Validate default arguments
1226+
let first_default = match args.as_ref() {
1227+
Some(arg) => arg.iter().position(|t| t.default_expr.is_some()),
1228+
None => None,
1229+
};
1230+
let last_non_default = match args.as_ref() {
1231+
Some(arg) => arg
1232+
.iter()
1233+
.rev()
1234+
.position(|t| t.default_expr.is_none())
1235+
.map(|reverse_pos| arg.len() - reverse_pos - 1),
1236+
None => None,
1237+
};
1238+
if let (Some(pos_default), Some(pos_non_default)) =
1239+
(first_default, last_non_default)
1240+
{
1241+
if pos_non_default > pos_default {
1242+
return plan_err!(
1243+
"Non-default arguments cannot follow default arguments."
1244+
);
1245+
}
1246+
}
12251247
// At the moment functions can't be qualified `schema.name`
12261248
let name = match &name.0[..] {
12271249
[] => exec_err!("Function should have name")?,
@@ -1233,9 +1255,23 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
12331255
//
12341256
let arg_types = args.as_ref().map(|arg| {
12351257
arg.iter()
1236-
.map(|t| Arc::new(Field::new("", t.data_type.clone(), true)))
1258+
.map(|t| {
1259+
let name = match t.name.clone() {
1260+
Some(name) => name.value,
1261+
None => "".to_string(),
1262+
};
1263+
Arc::new(Field::new(name, t.data_type.clone(), true))
1264+
})
12371265
.collect::<Vec<_>>()
12381266
});
1267+
// Validate parameter style
1268+
if let Some(ref fields) = arg_types {
1269+
let count_positional =
1270+
fields.iter().filter(|f| f.name() == "").count();
1271+
if !(count_positional == 0 || count_positional == fields.len()) {
1272+
return plan_err!("All function arguments must use either named or positional style.");
1273+
}
1274+
}
12391275
let mut planner_context = PlannerContext::new()
12401276
.with_prepare_param_data_types(arg_types.unwrap_or_default());
12411277

datafusion/sql/tests/cases/params.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ fn test_prepare_statement_to_plan_panic_param_format() {
105105
assert_snapshot!(
106106
logical_plan(sql).unwrap_err().strip_backtrace(),
107107
@r###"
108-
Error during planning: Invalid placeholder, not a number: $foo
108+
Error during planning: Unknown placeholder: $foo
109109
"###
110110
);
111111
}

0 commit comments

Comments
 (0)