diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index edc1fe357649..2804a1de0606 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -3388,26 +3388,15 @@ fn ident_normalization_parser_options_ident_normalization() -> ParserOptions { } } -fn prepare_stmt_quick_test( - sql: &str, - expected_plan: &str, - expected_data_types: &str, -) -> LogicalPlan { +fn generate_prepare_stmt_and_data_types(sql: &str) -> (LogicalPlan, String) { let plan = logical_plan(sql).unwrap(); - - let assert_plan = plan.clone(); - // verify plan - assert_eq!(format!("{assert_plan}"), expected_plan); - - // verify data types - if let LogicalPlan::Statement(Statement::Prepare(Prepare { data_types, .. })) = - assert_plan - { - let dt = format!("{data_types:?}"); - assert_eq!(dt, expected_data_types); - } - - plan + let data_types = match &plan { + LogicalPlan::Statement(Statement::Prepare(Prepare { data_types, .. })) => { + format!("{data_types:?}") + } + _ => panic!("Expected a Prepare statement"), + }; + (plan, data_types) } #[test] @@ -4383,8 +4372,12 @@ fn test_prepare_statement_to_plan_panic_prepare_wrong_syntax() { #[test] fn test_prepare_statement_to_plan_panic_no_relation_and_constant_param() { let sql = "PREPARE my_plan(INT) AS SELECT id + $1"; - let expected = "Schema error: No field named id."; - assert_eq!(logical_plan(sql).unwrap_err().strip_backtrace(), expected); + + let plan = logical_plan(sql).unwrap_err().strip_backtrace(); + assert_snapshot!( + plan, + @r"Schema error: No field named id." + ); } #[test] @@ -4426,15 +4419,17 @@ fn test_prepare_statement_to_plan_panic_is_param() { fn test_prepare_statement_to_plan_no_param() { // no embedded parameter but still declare it let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; - - let expected_plan = "Prepare: \"my_plan\" [Int32] \ - \n Projection: person.id, person.age\ - \n Filter: person.age = Int64(10)\ - \n TableScan: person"; - - let expected_dt = "[Int32]"; - - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32] + Projection: person.id, person.age + Filter: person.age = Int64(10) + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"[Int32]"#); /////////////////// // replace params with values @@ -4452,15 +4447,17 @@ fn test_prepare_statement_to_plan_no_param() { ////////////////////////////////////////// // no embedded parameter and no declare it let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age = 10"; - - let expected_plan = "Prepare: \"my_plan\" [] \ - \n Projection: person.id, person.age\ - \n Filter: person.age = Int64(10)\ - \n TableScan: person"; - - let expected_dt = "[]"; - - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [] + Projection: person.id, person.age + Filter: person.age = Int64(10) + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"[]"#); /////////////////// // replace params with values @@ -4532,12 +4529,16 @@ fn test_prepare_statement_to_plan_no_param_on_value_panic() { #[test] fn test_prepare_statement_to_plan_params_as_constants() { let sql = "PREPARE my_plan(INT) AS SELECT $1"; - - let expected_plan = "Prepare: \"my_plan\" [Int32] \ - \n Projection: $1\n EmptyRelation"; - let expected_dt = "[Int32]"; - - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32] + Projection: $1 + EmptyRelation + "# + ); + assert_snapshot!(dt, @r#"[Int32]"#); /////////////////// // replace params with values @@ -4553,12 +4554,16 @@ fn test_prepare_statement_to_plan_params_as_constants() { /////////////////////////////////////// let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1"; - - let expected_plan = "Prepare: \"my_plan\" [Int32] \ - \n Projection: Int64(1) + $1\n EmptyRelation"; - let expected_dt = "[Int32]"; - - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32] + Projection: Int64(1) + $1 + EmptyRelation + "# + ); + assert_snapshot!(dt, @r#"[Int32]"#); /////////////////// // replace params with values @@ -4574,12 +4579,16 @@ fn test_prepare_statement_to_plan_params_as_constants() { /////////////////////////////////////// let sql = "PREPARE my_plan(INT, DOUBLE) AS SELECT 1 + $1 + $2"; - - let expected_plan = "Prepare: \"my_plan\" [Int32, Float64] \ - \n Projection: Int64(1) + $1 + $2\n EmptyRelation"; - let expected_dt = "[Int32, Float64]"; - - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32, Float64] + Projection: Int64(1) + $1 + $2 + EmptyRelation + "# + ); + assert_snapshot!(dt, @r#"[Int32, Float64]"#); /////////////////// // replace params with values @@ -4598,20 +4607,20 @@ fn test_prepare_statement_to_plan_params_as_constants() { } #[test] -fn test_prepare_statement_infer_types_from_join() { +fn test_infer_types_from_join() { let sql = "SELECT id, order_id FROM person JOIN orders ON id = customer_id and age = $1"; - let expected_plan = r#" -Projection: person.id, orders.order_id - Inner Join: Filter: person.id = orders.customer_id AND person.age = $1 - TableScan: person - TableScan: orders + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.id, orders.order_id + Inner Join: Filter: person.id = orders.customer_id AND person.age = $1 + TableScan: person + TableScan: orders "# - .trim(); - - let expected_dt = "[Int32]"; - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + ); let actual_types = plan.get_parameter_types().unwrap(); let expected_types = HashMap::from([("$1".to_string(), Some(DataType::Int32))]); @@ -4633,18 +4642,17 @@ Projection: person.id, orders.order_id } #[test] -fn test_prepare_statement_infer_types_from_predicate() { +fn test_infer_types_from_predicate() { let sql = "SELECT id, age FROM person WHERE age = $1"; - - let expected_plan = r#" -Projection: person.id, person.age - Filter: person.age = $1 - TableScan: person - "# - .trim(); - - let expected_dt = "[Int32]"; - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.id, person.age + Filter: person.age = $1 + TableScan: person + "# + ); let actual_types = plan.get_parameter_types().unwrap(); let expected_types = HashMap::from([("$1".to_string(), Some(DataType::Int32))]); @@ -4665,18 +4673,18 @@ Projection: person.id, person.age } #[test] -fn test_prepare_statement_infer_types_from_between_predicate() { +fn test_infer_types_from_between_predicate() { let sql = "SELECT id, age FROM person WHERE age BETWEEN $1 AND $2"; - let expected_plan = r#" -Projection: person.id, person.age - Filter: person.age BETWEEN $1 AND $2 - TableScan: person - "# - .trim(); - - let expected_dt = "[Int32]"; - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.id, person.age + Filter: person.age BETWEEN $1 AND $2 + TableScan: person + "# + ); let actual_types = plan.get_parameter_types().unwrap(); let expected_types = HashMap::from([ @@ -4700,23 +4708,23 @@ Projection: person.id, person.age } #[test] -fn test_prepare_statement_infer_types_subquery() { +fn test_infer_types_subquery() { let sql = "SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)"; - let expected_plan = r#" -Projection: person.id, person.age - Filter: person.age = () - Subquery: - Projection: max(person.age) - Aggregate: groupBy=[[]], aggr=[[max(person.age)]] - Filter: person.id = $1 - TableScan: person - TableScan: person - "# - .trim(); - - let expected_dt = "[Int32]"; - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.id, person.age + Filter: person.age = () + Subquery: + Projection: max(person.age) + Aggregate: groupBy=[[]], aggr=[[max(person.age)]] + Filter: person.id = $1 + TableScan: person + TableScan: person + "# + ); let actual_types = plan.get_parameter_types().unwrap(); let expected_types = HashMap::from([("$1".to_string(), Some(DataType::UInt32))]); @@ -4742,19 +4750,19 @@ Projection: person.id, person.age } #[test] -fn test_prepare_statement_update_infer() { +fn test_update_infer() { let sql = "update person set age=$1 where id=$2"; - let expected_plan = r#" -Dml: op=[Update] table=[person] - Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, $1 AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 - Filter: person.id = $2 - TableScan: person - "# - .trim(); - - let expected_dt = "[Int32]"; - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Dml: op=[Update] table=[person] + Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, $1 AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 + Filter: person.id = $2 + TableScan: person + "# + ); let actual_types = plan.get_parameter_types().unwrap(); let expected_types = HashMap::from([ @@ -4779,17 +4787,17 @@ Dml: op=[Update] table=[person] } #[test] -fn test_prepare_statement_insert_infer() { +fn test_insert_infer() { let sql = "insert into person (id, first_name, last_name) values ($1, $2, $3)"; - - let expected_plan = "Dml: op=[Insert Into] table=[person]\ - \n Projection: column1 AS id, column2 AS first_name, column3 AS last_name, \ - CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \ - CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀\ - \n Values: ($1, $2, $3)"; - - let expected_dt = "[Int32]"; - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Dml: op=[Insert Into] table=[person] + Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 + Values: ($1, $2, $3) + "# + ); let actual_types = plan.get_parameter_types().unwrap(); let expected_types = HashMap::from([ @@ -4819,15 +4827,17 @@ fn test_prepare_statement_insert_infer() { #[test] fn test_prepare_statement_to_plan_one_param() { let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $1"; - - let expected_plan = "Prepare: \"my_plan\" [Int32] \ - \n Projection: person.id, person.age\ - \n Filter: person.age = $1\ - \n TableScan: person"; - - let expected_dt = "[Int32]"; - - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32] + Projection: person.id, person.age + Filter: person.age = $1 + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"[Int32]"#); /////////////////// // replace params with values @@ -4848,16 +4858,19 @@ fn test_prepare_statement_to_plan_one_param() { fn test_prepare_statement_to_plan_data_type() { let sql = "PREPARE my_plan(DOUBLE) AS SELECT id, age FROM person WHERE age = $1"; - // age is defined as Int32 but prepare statement declares it as DOUBLE/Float64 - // Prepare statement and its logical plan should be created successfully - let expected_plan = "Prepare: \"my_plan\" [Float64] \ - \n Projection: person.id, person.age\ - \n Filter: person.age = $1\ - \n TableScan: person"; - - let expected_dt = "[Float64]"; - - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + // age is defined as Int32 but prepare statement declares it as DOUBLE/Float64 + // Prepare statement and its logical plan should be created successfully + @r#" + Prepare: "my_plan" [Float64] + Projection: person.id, person.age + Filter: person.age = $1 + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"[Float64]"#); /////////////////// // replace params with values still succeed and use Float64 @@ -4880,15 +4893,17 @@ fn test_prepare_statement_to_plan_multi_params() { SELECT id, age, $6 FROM person WHERE age IN ($1, $4) AND salary > $3 and salary < $5 OR first_name < $2"; - - let expected_plan = "Prepare: \"my_plan\" [Int32, Utf8, Float64, Int32, Float64, Utf8] \ - \n Projection: person.id, person.age, $6\ - \n Filter: person.age IN ([$1, $4]) AND person.salary > $3 AND person.salary < $5 OR person.first_name < $2\ - \n TableScan: person"; - - let expected_dt = "[Int32, Utf8, Float64, Int32, Float64, Utf8]"; - - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32, Utf8, Float64, Int32, Float64, Utf8] + Projection: person.id, person.age, $6 + Filter: person.age IN ([$1, $4]) AND person.salary > $3 AND person.salary < $5 OR person.first_name < $2 + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"[Int32, Utf8, Float64, Int32, Float64, Utf8]"#); /////////////////// // replace params with values @@ -4921,17 +4936,19 @@ fn test_prepare_statement_to_plan_having() { GROUP BY id HAVING sum(age) < $1 AND sum(age) > 10 OR sum(age) in ($3, $4)\ "; - - let expected_plan = "Prepare: \"my_plan\" [Int32, Float64, Float64, Float64] \ - \n Projection: person.id, sum(person.age)\ - \n Filter: sum(person.age) < $1 AND sum(person.age) > Int64(10) OR sum(person.age) IN ([$3, $4])\ - \n Aggregate: groupBy=[[person.id]], aggr=[[sum(person.age)]]\ - \n Filter: person.salary > $2\ - \n TableScan: person"; - - let expected_dt = "[Int32, Float64, Float64, Float64]"; - - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32, Float64, Float64, Float64] + Projection: person.id, sum(person.age) + Filter: sum(person.age) < $1 AND sum(person.age) > Int64(10) OR sum(person.age) IN ([$3, $4]) + Aggregate: groupBy=[[person.id]], aggr=[[sum(person.age)]] + Filter: person.salary > $2 + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"[Int32, Float64, Float64, Float64]"#); /////////////////// // replace params with values @@ -4960,15 +4977,17 @@ fn test_prepare_statement_to_plan_limit() { let sql = "PREPARE my_plan(BIGINT, BIGINT) AS SELECT id FROM person \ OFFSET $1 LIMIT $2"; - - let expected_plan = "Prepare: \"my_plan\" [Int64, Int64] \ - \n Limit: skip=$1, fetch=$2\ - \n Projection: person.id\ - \n TableScan: person"; - - let expected_dt = "[Int64, Int64]"; - - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int64, Int64] + Limit: skip=$1, fetch=$2 + Projection: person.id + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"[Int64, Int64]"#); // replace params with values let param_values = vec![ScalarValue::Int64(Some(10)), ScalarValue::Int64(Some(200))];