Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support join-filter pushdown for semi/anti join #4923

Merged
merged 6 commits into from
Jan 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions datafusion/core/tests/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,95 @@ async fn join_with_alias_filter() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn right_semi_with_alias_filter() -> Result<()> {
let join_ctx = create_join_context()?;
let t1 = join_ctx.table("t1").await?;
let t2 = join_ctx.table("t2").await?;

// t1.a = t2.a and t1.c > 1 and t2.c > 1
let filter = col("t1.a")
.eq(col("t2.a"))
.and(col("t1.c").gt(lit(1u32)))
.and(col("t2.c").gt(lit(1u32)));

let df = t1
.join(t2, JoinType::RightSemi, &[], &[], Some(filter))?
.select(vec![col("t2.a"), col("t2.b"), col("t2.c")])?;
let optimized_plan = df.clone().into_optimized_plan()?;
let expected = vec![
"Projection: t2.a, t2.b, t2.c [a:UInt32, b:Utf8, c:Int32]",
" RightSemi Join: t1.a = t2.a [a:UInt32, b:Utf8, c:Int32]",
" Filter: t1.c > Int32(1) [a:UInt32, c:Int32]",
" TableScan: t1 projection=[a, c] [a:UInt32, c:Int32]",
" Filter: t2.c > Int32(1) [a:UInt32, b:Utf8, c:Int32]",
" TableScan: t2 projection=[a, b, c] [a:UInt32, b:Utf8, c:Int32]",
];

let formatted = optimized_plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
);

let results = df.collect().await?;
let expected: Vec<&str> = vec![
"+-----+---+---+",
"| a | b | c |",
"+-----+---+---+",
"| 10 | b | 2 |",
"| 100 | d | 4 |",
"+-----+---+---+",
];
assert_batches_sorted_eq!(expected, &results);
Ok(())
}

#[tokio::test]
async fn right_anti_filter_push_down() -> Result<()> {
let join_ctx = create_join_context()?;
let t1 = join_ctx.table("t1").await?;
let t2 = join_ctx.table("t2").await?;

// t1.a = t2.a and t1.c > 1 and t2.c > 1
let filter = col("t1.a")
.eq(col("t2.a"))
.and(col("t1.c").gt(lit(1u32)))
.and(col("t2.c").gt(lit(1u32)));

let df = t1
.join(t2, JoinType::RightAnti, &[], &[], Some(filter))?
.select(vec![col("t2.a"), col("t2.b"), col("t2.c")])?;
let optimized_plan = df.clone().into_optimized_plan()?;
let expected = vec![
"Projection: t2.a, t2.b, t2.c [a:UInt32, b:Utf8, c:Int32]",
" RightAnti Join: t1.a = t2.a Filter: t2.c > Int32(1) [a:UInt32, b:Utf8, c:Int32]",
" Filter: t1.c > Int32(1) [a:UInt32, c:Int32]",
" TableScan: t1 projection=[a, c] [a:UInt32, c:Int32]",
" TableScan: t2 projection=[a, b, c] [a:UInt32, b:Utf8, c:Int32]",
];

let formatted = optimized_plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
);

let results = df.collect().await?;
let expected: Vec<&str> = vec![
"+----+---+---+",
"| a | b | c |",
"+----+---+---+",
"| 13 | c | 3 |",
"| 3 | a | 1 |",
"+----+---+---+",
];
assert_batches_sorted_eq!(expected, &results);
Ok(())
}

async fn create_test_table() -> Result<DataFrame> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Utf8, false),
Expand Down
71 changes: 57 additions & 14 deletions datafusion/core/tests/sql/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2112,6 +2112,34 @@ async fn left_semi_join() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn left_semi_join_pushdown() -> Result<()> {
let ctx = create_left_semi_anti_join_context_with_null_ids("t1_id", "t2_id", false)
.unwrap();

// assert logical plan
let sql = "SELECT t1.t1_id, t1.t1_name FROM t1 LEFT SEMI JOIN t2 ON (t1.t1_id = t2.t2_id and t2.t2_int > 1)";
let msg = format!("Creating logical plan for '{sql}'");
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan()?;
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name [t1_id:UInt32;N, t1_name:Utf8;N]",
" LeftSemi Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N]",
" TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
" Filter: t2.t2_int > UInt32(1) [t2_id:UInt32;N, t2_int:UInt32;N]",
" TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
);

Ok(())
}

#[tokio::test]
async fn left_anti_join() -> Result<()> {
let test_repartition_joins = vec![true, false];
Expand Down Expand Up @@ -3126,13 +3154,12 @@ async fn in_subquery_to_join_with_correlated_outer_filter() -> Result<()> {
let msg = format!("Creating logical plan for '{sql}'");
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan().unwrap();

// The `t1.t1_int > UInt32(0)` should be pushdown by `filter push down rule`.
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int > UInt32(0) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" Filter: t1.t1_int > UInt32(0) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
" Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
" TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
Expand All @@ -3144,20 +3171,36 @@ async fn in_subquery_to_join_with_correlated_outer_filter() -> Result<()> {
expected, actual,
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
);
Ok(())
}

#[tokio::test]
async fn not_in_subquery_to_join_with_correlated_outer_filter() -> Result<()> {
let ctx = create_join_context("t1_id", "t2_id", false)?;

let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 not in
(select t2.t2_id + 1 from t2 where t1.t1_int > 0)";

// assert logical plan
let msg = format!("Creating logical plan for '{sql}'");
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan().unwrap();
let expected = vec![
"+-------+---------+--------+",
"| t1_id | t1_name | t1_int |",
"+-------+---------+--------+",
"| 11 | a | 1 |",
"| 33 | c | 3 |",
"| 44 | d | 4 |",
"+-------+---------+--------+",
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" LeftAnti Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int > UInt32(0) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
" Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
" TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
];

let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);

let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
);
Ok(())
}

Expand Down
43 changes: 41 additions & 2 deletions datafusion/core/tests/sqllogictests/test_files/join.slt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

##########
## JOIN Tests
## Join Tests
##########

statement ok
Expand All @@ -36,9 +36,48 @@ CREATE TABLE grades(grade INT, min INT, max INT) AS VALUES
(5, 80, 100);

# Regression test: https://github.com/apache/arrow-datafusion/issues/4844
query I
query TII
SELECT s.*, g.grade FROM students s join grades g on s.mark between g.min and g.max WHERE grade > 2 ORDER BY s.mark DESC
----
Amina 89 5
Salma 77 4
Christen 50 3

# two tables for join
statement ok
CREATE TABLE t1(t1_id INT, t1_name TEXT, t1_int INT) AS VALUES
(11, 'a', 1),
(22, 'b', 2),
(33, 'c', 3),
(44, 'd', 4);

statement ok
CREATE TABLE t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES
(11, 'z', 3),
(22, 'y', 1),
(44, 'x', 3),
(55, 'w', 3);

# left semi with wrong where clause
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

query error DataFusion error: Schema error: No field named 't2'.'t2_id'. Valid fields are 't1'.'t1_id', 't1'.'t1_name', 't1'.'t1_int'.
SELECT t1.t1_id,
t1.t1_name,
t1.t1_int
FROM t1 LEFT SEMI
JOIN t2
ON (
t1.t1_id = t2.t2_id)
WHERE t2.t2_id > 1

# left semi join with on-filter
query ITI rowsort
SELECT t1.t1_id,
t1.t1_name,
t1.t1_int
FROM t1 LEFT SEMI
JOIN t2
ON (
t1.t1_id = t2.t2_id and t2.t2_int > 1)
----
11 a 1
44 d 4
62 changes: 62 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/subquery.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Licensed to the Apache Software Foundation (ASF) under one
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

#############
## Subquery Tests
#############

# two tables for subquery
statement ok
CREATE TABLE t1(t1_id INT, t1_name TEXT, t1_int INT) AS VALUES
(11, 'a', 1),
(22, 'b', 2),
(33, 'c', 3),
(44, 'd', 4);

statement ok
CREATE TABLE t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES
(11, 'z', 3),
(22, 'y', 1),
(44, 'x', 3),
(55, 'w', 3);


# in_subquery_to_join_with_correlated_outer_filter
query ITI rowsort
select t1.t1_id,
t1.t1_name,
t1.t1_int
from t1
where t1.t1_id + 12 in (
select t2.t2_id + 1 from t2 where t1.t1_int > 0
)
----
11 a 1
33 c 3
44 d 4

# not_in_subquery_to_join_with_correlated_outer_filter
query ITI rowsort
select t1.t1_id,
t1.t1_name,
t1.t1_int
from t1
where t1.t1_id + 12 not in (
select t2.t2_id + 1 from t2 where t1.t1_int > 0
)
----
22 b 2
Loading