Skip to content

Commit

Permalink
Support non-tuple expression for exists-subquery to join (#5264)
Browse files Browse the repository at this point in the history
* Support non-tuple expression for exists-subquery to join

* fix tests

* add tests

* add comments

* fix tests

* fix test comment
  • Loading branch information
ygf11 committed Feb 18, 2023
1 parent 5d5b1a0 commit 27b15fd
Show file tree
Hide file tree
Showing 8 changed files with 389 additions and 156 deletions.
12 changes: 7 additions & 5 deletions benchmarks/expected-plans/q21.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ Sort: numwait DESC NULLS FIRST, supplier.s_name ASC NULLS LAST
TableScan: orders projection=[o_orderkey, o_orderstatus]
Filter: nation.n_name = Utf8("SAUDI ARABIA")
TableScan: nation projection=[n_nationkey, n_name]
SubqueryAlias: l2
TableScan: lineitem projection=[l_orderkey, l_suppkey]
SubqueryAlias: l3
Filter: lineitem.l_receiptdate > lineitem.l_commitdate
TableScan: lineitem projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate]
Projection: l2.l_orderkey, l2.l_suppkey
SubqueryAlias: l2
TableScan: lineitem projection=[l_orderkey, l_suppkey]
Projection: l3.l_orderkey, l3.l_suppkey
SubqueryAlias: l3
Filter: lineitem.l_receiptdate > lineitem.l_commitdate
TableScan: lineitem projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate]
3 changes: 2 additions & 1 deletion benchmarks/expected-plans/q22.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ Sort: custsale.cntrycode ASC NULLS LAST
LeftAnti Join: customer.c_custkey = orders.o_custkey
Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
TableScan: customer projection=[c_custkey, c_phone, c_acctbal]
TableScan: orders projection=[o_custkey]
Projection: orders.o_custkey
TableScan: orders projection=[o_custkey]
SubqueryAlias: __scalar_sq_1
Projection: AVG(customer.c_acctbal) AS __value
Aggregate: groupBy=[[]], aggr=[[AVG(customer.c_acctbal)]]
Expand Down
5 changes: 3 additions & 2 deletions benchmarks/expected-plans/q4.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@ Sort: orders.o_orderpriority ASC NULLS LAST
LeftSemi Join: orders.o_orderkey = lineitem.l_orderkey
Filter: orders.o_orderdate >= Date32("8582") AND orders.o_orderdate < Date32("8674")
TableScan: orders projection=[o_orderkey, o_orderdate, o_orderpriority]
Filter: lineitem.l_commitdate < lineitem.l_receiptdate
TableScan: lineitem projection=[l_orderkey, l_commitdate, l_receiptdate]
Projection: lineitem.l_orderkey
Filter: lineitem.l_commitdate < lineitem.l_receiptdate
TableScan: lineitem projection=[l_orderkey, l_commitdate, l_receiptdate]
203 changes: 187 additions & 16 deletions datafusion/core/tests/sql/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2187,9 +2187,8 @@ async fn left_anti_join() -> Result<()> {
}

#[tokio::test]
#[ignore = "Test ignored, will be enabled after fixing the anti join plan bug"]
// https://github.com/apache/arrow-datafusion/issues/4366
async fn error_left_anti_join() -> Result<()> {
// https://github.com/apache/arrow-datafusion/issues/4366
let test_repartition_joins = vec![true, false];
for repartition_joins in test_repartition_joins {
let ctx = create_left_semi_anti_join_context_with_null_ids(
Expand Down Expand Up @@ -2255,27 +2254,29 @@ async fn right_semi_join() -> Result<()> {
let dataframe = ctx.sql(sql).await.expect(&msg);
let physical_plan = dataframe.create_physical_plan().await?;
let expected = if repartition_joins {
vec![ "SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]",
" SortExec: expr=[t1_id@0 ASC NULLS LAST]",
" ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int]",
" CoalesceBatchesExec: target_batch_size=4096",
" HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(Column { name: \"t2_id\", index: 0 }, Column { name: \"t1_id\", index: 0 })], filter=BinaryExpr { left: Column { name: \"t2_name\", index: 1 }, op: NotEq, right: Column { name: \"t1_name\", index: 0 } }",
" CoalesceBatchesExec: target_batch_size=4096",
" RepartitionExec: partitioning=Hash([Column { name: \"t2_id\", index: 0 }], 2), input_partitions=2",
" RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1",
" MemoryExec: partitions=1, partition_sizes=[1]",
" CoalesceBatchesExec: target_batch_size=4096",
" RepartitionExec: partitioning=Hash([Column { name: \"t1_id\", index: 0 }], 2), input_partitions=2",
" RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1",
" MemoryExec: partitions=1, partition_sizes=[1]",
vec!["SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]",
" SortExec: expr=[t1_id@0 ASC NULLS LAST]",
" ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int]",
" CoalesceBatchesExec: target_batch_size=4096",
" HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(Column { name: \"t2_id\", index: 0 }, Column { name: \"t1_id\", index: 0 })], filter=BinaryExpr { left: Column { name: \"t2_name\", index: 1 }, op: NotEq, right: Column { name: \"t1_name\", index: 0 } }",
" CoalesceBatchesExec: target_batch_size=4096",
" RepartitionExec: partitioning=Hash([Column { name: \"t2_id\", index: 0 }], 2), input_partitions=2",
" RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1",
" ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name]",
" MemoryExec: partitions=1, partition_sizes=[1]",
" CoalesceBatchesExec: target_batch_size=4096",
" RepartitionExec: partitioning=Hash([Column { name: \"t1_id\", index: 0 }], 2), input_partitions=2",
" RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1",
" MemoryExec: partitions=1, partition_sizes=[1]",
]
} else {
vec![
"SortExec: expr=[t1_id@0 ASC NULLS LAST]",
" ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int]",
" CoalesceBatchesExec: target_batch_size=4096",
" HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(Column { name: \"t2_id\", index: 0 }, Column { name: \"t1_id\", index: 0 })], filter=BinaryExpr { left: Column { name: \"t2_name\", index: 1 }, op: NotEq, right: Column { name: \"t1_name\", index: 0 } }",
" MemoryExec: partitions=1, partition_sizes=[1]",
" ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name]",
" MemoryExec: partitions=1, partition_sizes=[1]",
" MemoryExec: partitions=1, partition_sizes=[1]",
]
};
Expand Down Expand Up @@ -3393,3 +3394,173 @@ async fn left_as_inner_table_nested_loop_join() -> Result<()> {

Ok(())
}

#[tokio::test]
async fn exists_subquery_to_join_expr_filter() -> Result<()> {
let test_repartition_joins = vec![true, false];
for repartition_joins in test_repartition_joins {
let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;

// exists subquery to LeftSemi join
let sql = "SELECT * FROM t1 WHERE EXISTS(SELECT t2_id FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2)";
let msg = format!("Creating logical plan for '{sql}'");
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan()?;

let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" LeftSemi Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" Projection: t2.t2_id [t2_id:UInt32;N]",
" TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
);
let expected = vec![
"+-------+---------+--------+",
"| t1_id | t1_name | t1_int |",
"+-------+---------+--------+",
"| 22 | b | 2 |",
"| 33 | c | 3 |",
"| 44 | d | 4 |",
"+-------+---------+--------+",
];

let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);
}

Ok(())
}

#[tokio::test]
async fn exists_subquery_to_join_inner_filter() -> Result<()> {
let test_repartition_joins = vec![true, false];
for repartition_joins in test_repartition_joins {
let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;

// exists subquery to LeftSemi join
let sql = "SELECT * FROM t1 WHERE EXISTS(SELECT t2_id FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2 AND t2.t2_int < 3)";
let msg = format!("Creating logical plan for '{sql}'");
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan()?;

// `t2.t2_int < 3` will be kept in the subquery filter.
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" LeftSemi Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" Projection: t2.t2_id [t2_id:UInt32;N]",
" Filter: t2.t2_int < UInt32(3) [t2_id:UInt32;N, t2_int:UInt32;N]",
" TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
);
let expected = vec![
"+-------+---------+--------+",
"| t1_id | t1_name | t1_int |",
"+-------+---------+--------+",
"| 44 | d | 4 |",
"+-------+---------+--------+",
];

let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);
}

Ok(())
}

#[tokio::test]
async fn exists_subquery_to_join_outer_filter() -> Result<()> {
let test_repartition_joins = vec![true, false];
for repartition_joins in test_repartition_joins {
let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;

// exists subquery to LeftSemi join
let sql = "SELECT * FROM t1 WHERE EXISTS(SELECT t2_id FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2 AND t1.t1_int < 3)";
let msg = format!("Creating logical plan for '{sql}'");
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan()?;

// `t1.t1_int < 3` will be moved to the filter of t1.
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" LeftSemi Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" Filter: t1.t1_int < UInt32(3) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" Projection: t2.t2_id [t2_id:UInt32;N]",
" TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
);
let expected = vec![
"+-------+---------+--------+",
"| t1_id | t1_name | t1_int |",
"+-------+---------+--------+",
"| 22 | b | 2 |",
"+-------+---------+--------+",
];

let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);
}

Ok(())
}

#[tokio::test]
async fn not_exists_subquery_to_join_expr_filter() -> Result<()> {
let test_repartition_joins = vec![true, false];
for repartition_joins in test_repartition_joins {
let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;

// not exists subquery to LeftAnti join
let sql = "SELECT * FROM t1 WHERE NOT EXISTS(SELECT t2_id FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2)";
let msg = format!("Creating logical plan for '{sql}'");
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan()?;

let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" Projection: t2.t2_id [t2_id:UInt32;N]",
" TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
);
let expected = vec![
"+-------+---------+--------+",
"| t1_id | t1_name | t1_int |",
"+-------+---------+--------+",
"| 11 | a | 1 |",
"+-------+---------+--------+",
];

let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);
}

Ok(())
}
Loading

0 comments on commit 27b15fd

Please sign in to comment.