diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 6928e98b789b..969443f6235e 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -116,3 +116,65 @@ where o_orderstatus in ( Ok(()) } + +#[tokio::test] +async fn exists_subquery_with_same_table() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", true)?; + + // Subquery and outer query refer to the same table. + // It will not be rewritten to join because it is not a correlated subquery. + let sql = "SELECT t1_id, t1_name, t1_int FROM t1 WHERE EXISTS(SELECT t1_int FROM t1 WHERE t1.t1_id > t1.t1_int)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Filter: EXISTS () [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Subquery: [t1_int:UInt32;N]", + " Projection: t1.t1_int [t1_int:UInt32;N]", + " Filter: t1.t1_id > t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + Ok(()) +} + +#[tokio::test] +async fn in_subquery_with_same_table() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", true)?; + + // Subquery and outer query refer to the same table. + // It will be rewritten to join because in-subquery has extra predicate(`t1.t1_id = __correlated_sq_1.t1_int`). + let sql = "SELECT t1_id, t1_name, t1_int FROM t1 WHERE t1_id IN(SELECT t1_int FROM t1 WHERE t1.t1_id > t1.t1_int)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: t1.t1_id = __correlated_sq_1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " SubqueryAlias: __correlated_sq_1 [t1_int:UInt32;N]", + " Projection: t1.t1_int AS t1_int [t1_int:UInt32;N]", + " Filter: t1.t1_id > t1.t1_int [t1_id:UInt32;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_int] [t1_id:UInt32;N, t1_int:UInt32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + Ok(()) +} diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs index 3f6b160fa9a0..72a68b3123b0 100644 --- a/datafusion/optimizer/src/decorrelate_where_exists.rs +++ b/datafusion/optimizer/src/decorrelate_where_exists.rs @@ -644,4 +644,30 @@ mod tests { assert_plan_eq(&plan, expected) } + + #[test] + fn exists_subquery_with_same_table() -> Result<()> { + let outer_scan = test_table_scan()?; + let subquery_scan = test_table_scan()?; + let subquery = LogicalPlanBuilder::from(subquery_scan) + .filter(col("test.a").gt(col("test.b")))? + .project(vec![col("c")])? + .build()?; + + let plan = LogicalPlanBuilder::from(outer_scan) + .filter(exists(Arc::new(subquery)))? + .project(vec![col("test.b")])? + .build()?; + + // Subquery and outer query refer to the same table. + let expected = "Projection: test.b [b:UInt32]\ + \n Filter: EXISTS () [a:UInt32, b:UInt32, c:UInt32]\ + \n Subquery: [c:UInt32]\ + \n Projection: test.c [c:UInt32]\ + \n Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_plan_eq(&plan, expected) + } } diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs index c8ff65f12523..bc70098610f9 100644 --- a/datafusion/optimizer/src/decorrelate_where_in.rs +++ b/datafusion/optimizer/src/decorrelate_where_in.rs @@ -1149,4 +1149,35 @@ mod tests { ); Ok(()) } + + #[test] + fn in_subquery_with_same_table() -> Result<()> { + let outer_scan = test_table_scan()?; + let subquery_scan = test_table_scan()?; + let subquery = LogicalPlanBuilder::from(subquery_scan) + .filter(col("test.a").gt(col("test.b")))? + .project(vec![col("c")])? + .build()?; + + let plan = LogicalPlanBuilder::from(outer_scan) + .filter(in_subquery(col("test.a"), Arc::new(subquery)))? + .project(vec![col("test.b")])? + .build()?; + + // Subquery and outer query refer to the same table. + let expected = "Projection: test.b [b:UInt32]\ + \n LeftSemi Join: Filter: test.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ + \n Projection: test.c AS c [c:UInt32]\ + \n Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelateWhereIn::new()), + &plan, + expected, + ); + Ok(()) + } }