Skip to content

Commit b95f91a

Browse files
committed
fix tests, default argument validation
1 parent 8d50405 commit b95f91a

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,9 +1174,15 @@ async fn create_scalar_function_from_sql_statement_named_arguments() -> Result<(
11741174
let bad_expression_sql = r#"
11751175
CREATE FUNCTION bad_expression_fun(DOUBLE, b DOUBLE)
11761176
RETURNS DOUBLE
1177-
RETURN $1 $b
1177+
RETURN $1 + $b
11781178
"#;
1179-
assert!(ctx.sql(bad_expression_sql).await.is_err());
1179+
let err = ctx
1180+
.sql(bad_expression_sql)
1181+
.await
1182+
.expect_err("cannot mix named and positional style");
1183+
let expected = "Error during planning: All function arguments must use either named or positional style.";
1184+
assert!(expected.starts_with(&err.strip_backtrace()));
1185+
11801186
Ok(())
11811187
}
11821188

@@ -1243,9 +1249,15 @@ async fn create_scalar_function_from_sql_statement_default_arguments() -> Result
12431249
let bad_expression_sql = r#"
12441250
CREATE FUNCTION bad_expression_fun(a DOUBLE DEFAULT 2.0, b DOUBLE)
12451251
RETURNS DOUBLE
1246-
RETURN $a $b
1252+
RETURN $a + $b
12471253
"#;
1248-
assert!(ctx.sql(bad_expression_sql).await.is_err());
1254+
let err = ctx
1255+
.sql(bad_expression_sql)
1256+
.await
1257+
.expect_err("non-default argument cannot follow default argument");
1258+
let expected =
1259+
"Error during planning: Non-default arguments cannot follow default arguments.";
1260+
assert!(expected.starts_with(&err.strip_backtrace()));
12491261
Ok(())
12501262
}
12511263

datafusion/sql/src/statement.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1228,7 +1228,11 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
12281228
None => None,
12291229
};
12301230
let last_non_default = match args.as_ref() {
1231-
Some(arg) => arg.iter().rev().position(|t| t.default_expr.is_none()),
1231+
Some(arg) => arg
1232+
.iter()
1233+
.rev()
1234+
.position(|t| t.default_expr.is_none())
1235+
.map(|reverse_pos| arg.len() - reverse_pos - 1),
12321236
None => None,
12331237
};
12341238
if let (Some(pos_default), Some(pos_non_default)) =

0 commit comments

Comments
 (0)