diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index f989d05c80dd..36ee78fe5d9a 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -37,6 +37,7 @@ use datafusion::logical_expr::{ }; use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; use datafusion::prelude::*; +use insta::assert_snapshot; use std::hash::Hash; use std::sync::Arc; use substrait::proto::extensions::simple_extension_declaration::MappingType; @@ -188,13 +189,16 @@ async fn simple_select() -> Result<()> { #[tokio::test] async fn wildcard_select() -> Result<()> { - assert_expected_plan_unoptimized( - "SELECT * FROM data", - "Projection: data.a, data.b, data.c, data.d, data.e, data.f\ - \n TableScan: data", - true, - ) - .await + let plan = generate_plan_from_sql("SELECT * FROM data", true, false).await?; + + assert_snapshot!( + plan, + @r#" + Projection: data.a, data.b, data.c, data.d, data.e, data.f + TableScan: data + "# + ); + Ok(()) } #[tokio::test] @@ -299,24 +303,42 @@ async fn aggregate_grouping_sets() -> Result<()> { #[tokio::test] async fn aggregate_grouping_rollup() -> Result<()> { - assert_expected_plan( + let plan = generate_plan_from_sql( "SELECT a, c, e, avg(b) FROM data GROUP BY ROLLUP (a, c, e)", - "Projection: data.a, data.c, data.e, avg(data.b)\ - \n Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e), (data.a, data.c), (data.a), ())]], aggr=[[avg(data.b)]]\ - \n TableScan: data projection=[a, b, c, e]", - true - ).await + true, + true, + ) + .await?; + + assert_snapshot!( + plan, + @r#" + Projection: data.a, data.c, data.e, avg(data.b) + Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e), (data.a, data.c), (data.a), ())]], aggr=[[avg(data.b)]] + TableScan: data projection=[a, b, c, e] + "# + ); + Ok(()) } #[tokio::test] async fn multilayer_aggregate() -> Result<()> { - assert_expected_plan( + let plan = generate_plan_from_sql( "SELECT a, sum(partial_count_b) FROM (SELECT a, count(b) as partial_count_b FROM data GROUP BY a) GROUP BY a", - "Aggregate: groupBy=[[data.a]], aggr=[[sum(count(data.b)) AS sum(partial_count_b)]]\ - \n Aggregate: groupBy=[[data.a]], aggr=[[count(data.b)]]\ - \n TableScan: data projection=[a, b]", - true - ).await + true, + true, + ) + .await?; + + assert_snapshot!( + plan, + @r#" + Aggregate: groupBy=[[data.a]], aggr=[[sum(count(data.b)) AS sum(partial_count_b)]] + Aggregate: groupBy=[[data.a]], aggr=[[count(data.b)]] + TableScan: data projection=[a, b] + "# + ); + Ok(()) } #[tokio::test] @@ -454,13 +476,21 @@ async fn try_cast_decimal_to_string() -> Result<()> { #[tokio::test] async fn aggregate_case() -> Result<()> { - assert_expected_plan( + let plan = generate_plan_from_sql( "SELECT sum(CASE WHEN a > 0 THEN 1 ELSE NULL END) FROM data", - "Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END) AS sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE NULL END)]]\ - \n TableScan: data projection=[a]", - true + true, + true, ) - .await + .await?; + + assert_snapshot!( + plan, + @r#" + Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END) AS sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE NULL END)]] + TableScan: data projection=[a] + "# + ); + Ok(()) } #[tokio::test] @@ -493,18 +523,27 @@ async fn roundtrip_inlist_4() -> Result<()> { #[tokio::test] async fn roundtrip_inlist_5() -> Result<()> { // on roundtrip there is an additional projection during TableScan which includes all column of the table, - // using assert_expected_plan here as a workaround - assert_expected_plan( + // using assert_and_generate_plan and assert_snapshot! here as a workaround + let plan = generate_plan_from_sql( "SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT data2.a FROM data2 WHERE f IN ('b', 'c', 'd')))", + true, + true, + ) + .await?; - "Projection: data.a, data.f\ - \n Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data2.mark\ - \n LeftMark Join: data.a = data2.a\ - \n TableScan: data projection=[a, f]\ - \n Projection: data2.a\ - \n Filter: data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")\ - \n TableScan: data2 projection=[a, f], partial_filters=[data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")]", - true).await + assert_snapshot!( + plan, + @r#" + Projection: data.a, data.f + Filter: data.f = Utf8("a") OR data.f = Utf8("b") OR data.f = Utf8("c") OR data2.mark + LeftMark Join: data.a = data2.a + TableScan: data projection=[a, f] + Projection: data2.a + Filter: data2.f = Utf8("b") OR data2.f = Utf8("c") OR data2.f = Utf8("d") + TableScan: data2 projection=[a, f], partial_filters=[data2.f = Utf8("b") OR data2.f = Utf8("c") OR data2.f = Utf8("d")] + "# + ); + Ok(()) } #[tokio::test] @@ -535,27 +574,44 @@ async fn roundtrip_non_equi_join() -> Result<()> { #[tokio::test] async fn roundtrip_exists_filter() -> Result<()> { - assert_expected_plan( + let plan = generate_plan_from_sql( "SELECT b FROM data d1 WHERE EXISTS (SELECT * FROM data2 d2 WHERE d2.a = d1.a AND d2.e != d1.e)", - "Projection: data.b\ - \n LeftSemi Join: data.a = data2.a Filter: data2.e != CAST(data.e AS Int64)\ - \n TableScan: data projection=[a, b, e]\ - \n TableScan: data2 projection=[a, e]", - false // "d1" vs "data" field qualifier - ).await + false, + true, + ) + .await?; + + assert_snapshot!( + plan, + @r#" + Projection: data.b + LeftSemi Join: data.a = data2.a Filter: data2.e != CAST(data.e AS Int64) + TableScan: data projection=[a, b, e] + TableScan: data2 projection=[a, e] + "# + ); + Ok(()) } #[tokio::test] async fn inner_join() -> Result<()> { - assert_expected_plan( + let plan = generate_plan_from_sql( "SELECT data.a FROM data JOIN data2 ON data.a = data2.a", - "Projection: data.a\ - \n Inner Join: data.a = data2.a\ - \n TableScan: data projection=[a]\ - \n TableScan: data2 projection=[a]", + true, true, ) - .await + .await?; + + assert_snapshot!( + plan, + @r#" + Projection: data.a + Inner Join: data.a = data2.a + TableScan: data projection=[a] + TableScan: data2 projection=[a] + "# + ); + Ok(()) } #[tokio::test] @@ -592,17 +648,25 @@ async fn roundtrip_self_implicit_cross_join() -> Result<()> { #[tokio::test] async fn self_join_introduces_aliases() -> Result<()> { - assert_expected_plan( + let plan = generate_plan_from_sql( "SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.b = d2.b", - "Projection: left.b, right.c\ - \n Inner Join: left.b = right.b\ - \n SubqueryAlias: left\ - \n TableScan: data projection=[b]\ - \n SubqueryAlias: right\ - \n TableScan: data projection=[b, c]", false, + true, ) - .await + .await?; + + assert_snapshot!( + plan, + @r#" + Projection: left.b, right.c + Inner Join: left.b = right.b + SubqueryAlias: left + TableScan: data projection=[b] + SubqueryAlias: right + TableScan: data projection=[b, c] + "# + ); + Ok(()) } #[tokio::test] @@ -747,12 +811,15 @@ async fn aggregate_wo_projection_consume() -> Result<()> { let proto_plan = read_json("tests/testdata/test_plans/aggregate_no_project.substrait.json"); - assert_expected_plan_substrait( - proto_plan, - "Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) AS countA]]\ - \n TableScan: data projection=[a]", - ) - .await + let plan = generate_plan_from_substrait(proto_plan).await?; + assert_snapshot!( + plan, + @r#" + Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) AS countA]] + TableScan: data projection=[a] + "# + ); + Ok(()) } #[tokio::test] @@ -760,12 +827,15 @@ async fn aggregate_wo_projection_group_expression_ref_consume() -> Result<()> { let proto_plan = read_json("tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json"); - assert_expected_plan_substrait( - proto_plan, - "Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) AS countA]]\ - \n TableScan: data projection=[a]", - ) - .await + let plan = generate_plan_from_substrait(proto_plan).await?; + assert_snapshot!( + plan, + @r#" + Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) AS countA]] + TableScan: data projection=[a] + "# + ); + Ok(()) } #[tokio::test] @@ -773,12 +843,15 @@ async fn aggregate_wo_projection_sorted_consume() -> Result<()> { let proto_plan = read_json("tests/testdata/test_plans/aggregate_sorted_no_project.substrait.json"); - assert_expected_plan_substrait( - proto_plan, - "Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) ORDER BY [data.a DESC NULLS FIRST] AS countA]]\ - \n TableScan: data projection=[a]", - ) - .await + let plan = generate_plan_from_substrait(proto_plan).await?; + assert_snapshot!( + plan, + @r#" + Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) ORDER BY [data.a DESC NULLS FIRST] AS countA]] + TableScan: data projection=[a] + "# + ); + Ok(()) } #[tokio::test] @@ -986,19 +1059,27 @@ async fn roundtrip_literal_list() -> Result<()> { #[tokio::test] async fn roundtrip_literal_struct() -> Result<()> { - assert_expected_plan( + let plan = generate_plan_from_sql( "SELECT STRUCT(1, true, CAST(NULL AS STRING)) FROM data", - "Projection: Struct({c0:1,c1:true,c2:}) AS struct(Int64(1),Boolean(true),NULL)\ - \n TableScan: data projection=[]", - false, // "Struct(..)" vs "struct(..)" + false, + true, ) - .await + .await?; + + assert_snapshot!( + plan, + @r#" + Projection: Struct({c0:1,c1:true,c2:}) AS struct(Int64(1),Boolean(true),NULL) + TableScan: data projection=[] + "# + ); + Ok(()) } #[tokio::test] async fn roundtrip_values() -> Result<()> { // TODO: would be nice to have a struct inside the LargeList, but arrow_cast doesn't support that currently - assert_expected_plan( + let plan = generate_plan_from_sql( "VALUES \ (\ 1, \ @@ -1009,17 +1090,18 @@ async fn roundtrip_values() -> Result<()> { [STRUCT(STRUCT('a' AS string_field) AS struct_field), STRUCT(STRUCT('b' AS string_field) AS struct_field)]\ ), \ (NULL, NULL, NULL, NULL, NULL, NULL)", - "Values: \ - (\ - Int64(1), \ - Utf8(\"a\"), \ - List([[-213.1, , 5.5, 2.0, 1.0], []]), \ - LargeList([1, 2, 3]), \ - Struct({c0:true,int_field:1,c2:}), \ - List([{struct_field: {string_field: a}}, {struct_field: {string_field: b}}])\ - ), \ - (Int64(NULL), Utf8(NULL), List(), LargeList(), Struct({c0:,int_field:,c2:}), List())", - true).await + true, + true, + ) + .await?; + + assert_snapshot!( + plan, + @r#" + Values: (Int64(1), Utf8("a"), List([[-213.1, , 5.5, 2.0, 1.0], []]), LargeList([1, 2, 3]), Struct({c0:true,int_field:1,c2:}), List([{struct_field: {string_field: a}}, {struct_field: {string_field: b}}])), (Int64(NULL), Utf8(NULL), List(), LargeList(), Struct({c0:,int_field:,c2:}), List()) + "# + ); + Ok(()) } #[tokio::test] @@ -1061,14 +1143,22 @@ async fn duplicate_column() -> Result<()> { // only. DataFusion however, is strict about not having duplicate column names appear in the plan. // This test confirms that we generate aliases for columns in the plan which would otherwise have // colliding names. - assert_expected_plan( + let plan = generate_plan_from_sql( "SELECT a + 1 as sum_a, a + 1 as sum_a_2 FROM data", - "Projection: data.a + Int64(1) AS sum_a, data.a + Int64(1) AS data.a + Int64(1)__temp__0 AS sum_a_2\ - \n Projection: data.a + Int64(1)\ - \n TableScan: data projection=[a]", + true, true, ) - .await + .await?; + + assert_snapshot!( + plan, + @r#" + Projection: data.a + Int64(1) AS sum_a, data.a + Int64(1) AS data.a + Int64(1)__temp__0 AS sum_a_2 + Projection: data.a + Int64(1) + TableScan: data projection=[a] + "# + ); + Ok(()) } /// Construct a plan that cast columns. Only those SQL types are supported for now. @@ -1374,30 +1464,32 @@ async fn assert_read_filter_count( Ok(()) } -async fn assert_expected_plan_unoptimized( +async fn generate_plan_from_sql( sql: &str, - expected_plan_str: &str, assert_schema: bool, -) -> Result<()> { + optimized: bool, +) -> Result { let ctx = create_context().await?; - let df = ctx.sql(sql).await?; - let plan = df.into_unoptimized_plan(); - let proto = to_substrait_plan(&plan, &ctx.state())?; - let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; - - println!("{plan}"); - println!("{plan2}"); + let df: DataFrame = ctx.sql(sql).await?; - println!("{proto:?}"); + let plan = if optimized { + df.into_optimized_plan()? + } else { + df.into_unoptimized_plan() + }; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = if optimized { + let temp = from_substrait_plan(&ctx.state(), &proto).await?; + ctx.state().optimize(&temp)? + } else { + from_substrait_plan(&ctx.state(), &proto).await? + }; if assert_schema { assert_eq!(plan.schema(), plan2.schema()); } - let plan2str = format!("{plan2}"); - assert_eq!(expected_plan_str, &plan2str); - - Ok(()) + Ok(plan2) } async fn assert_expected_plan( @@ -1412,11 +1504,6 @@ async fn assert_expected_plan( let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; - println!("{plan}"); - println!("{plan2}"); - - println!("{proto:?}"); - if assert_schema { assert_eq!(plan.schema(), plan2.schema()); } @@ -1427,20 +1514,14 @@ async fn assert_expected_plan( Ok(()) } -async fn assert_expected_plan_substrait( - substrait_plan: Plan, - expected_plan_str: &str, -) -> Result<()> { +async fn generate_plan_from_substrait(substrait_plan: Plan) -> Result { let ctx = create_context().await?; let plan = from_substrait_plan(&ctx.state(), &substrait_plan).await?; let plan = ctx.state().optimize(&plan)?; - let planstr = format!("{plan}"); - assert_eq!(planstr, expected_plan_str); - - Ok(()) + Ok(plan) } async fn assert_substrait_sql(substrait_plan: Plan, sql: &str) -> Result<()> { @@ -1491,9 +1572,6 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> { let proto = to_substrait_plan(&df.into_optimized_plan()?, &ctx.state())?; let plan = from_substrait_plan(&ctx.state(), &proto).await?; - println!("{plan_with_alias}"); - println!("{plan}"); - let plan1str = format!("{plan_with_alias}"); let plan2str = format!("{plan}"); assert_eq!(plan1str, plan2str); @@ -1510,11 +1588,6 @@ async fn roundtrip_logical_plan_with_ctx( let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; - println!("{plan}"); - println!("{plan2}"); - - println!("{proto:?}"); - let plan1str = format!("{plan}"); let plan2str = format!("{plan2}"); assert_eq!(plan1str, plan2str);