Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve Filter pushdown to Join #5770

Merged
merged 1 commit into from
Apr 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 49 additions & 55 deletions benchmarks/expected-plans/q17.txt

Large diffs are not rendered by default.

73 changes: 34 additions & 39 deletions benchmarks/expected-plans/q19.txt

Large diffs are not rendered by default.

165 changes: 80 additions & 85 deletions benchmarks/expected-plans/q20.txt

Large diffs are not rendered by default.

181 changes: 88 additions & 93 deletions benchmarks/expected-plans/q7.txt

Large diffs are not rendered by default.

9 changes: 4 additions & 5 deletions datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -952,21 +952,21 @@ impl DefaultPhysicalPlanner {

let join_filter = match filter {
Some(expr) => {
// Extract columns from filter expression
// Extract columns from filter expression and saved in a HashSet
let cols = expr.to_columns()?;

// Collect left & right field indices
// Collect left & right field indices, the field indices are sorted in ascending order
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the sort order is important for later stages, can you make a note about the rationale (so the comment explains why the sorting is important, in addition to noting the output is sorted)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, will do.

let left_field_indices = cols.iter()
.filter_map(|c| match left_df_schema.index_of_column(c) {
Ok(idx) => Some(idx),
_ => None,
})
}).sorted()
.collect::<Vec<_>>();
let right_field_indices = cols.iter()
.filter_map(|c| match right_df_schema.index_of_column(c) {
Ok(idx) => Some(idx),
_ => None,
})
}).sorted()
.collect::<Vec<_>>();

// Collect DFFields and Fields required for intermediate schemas
Expand All @@ -986,7 +986,6 @@ impl DefaultPhysicalPlanner {
)
.unzip();


// Construct intermediate schemas used for filtering data and
// convert logical expression to physical according to filter schema
let filter_df_schema = DFSchema::new_with_metadata(filter_df_fields, HashMap::new())?;
Expand Down
20 changes: 9 additions & 11 deletions datafusion/core/tests/sql/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1103,12 +1103,11 @@ async fn reduce_left_join_2() -> Result<()> {
// the right part `(t2.t2_name != 'w' or t2.t2_int < 10)` could be push down left join side and remove in filter.

let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Filter: t2.t2_int < UInt32(10) OR t1.t1_int > UInt32(2) AND t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" Filter: t2.t2_int < UInt32(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
"Explain [plan_type:Utf8, plan:Utf8]",
" Inner Join: t1.t1_id = t2.t2_id Filter: t2.t2_int < UInt32(10) OR t1.t1_int > UInt32(2) AND t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" Filter: t2.t2_int < UInt32(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand Down Expand Up @@ -1188,11 +1187,10 @@ async fn reduce_right_join_2() -> Result<()> {
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan()?;
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Filter: t1.t1_int != t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
"Explain [plan_type:Utf8, plan:Utf8]",
" Inner Join: t1.t1_id = t2.t2_id Filter: t1.t1_int != t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand Down
12 changes: 5 additions & 7 deletions datafusion/core/tests/sql/predicates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,11 @@ async fn multiple_or_predicates() -> Result<()> {
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: lineitem.l_partkey [l_partkey:Int64]",
" Filter: part.p_brand = Utf8(\"Brand#12\") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_brand:Utf8, p_size:Int32]",
" Projection: lineitem.l_partkey, lineitem.l_quantity, part.p_brand, part.p_size [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_brand:Utf8, p_size:Int32]",
" Inner Join: lineitem.l_partkey = part.p_partkey [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
" TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
" Filter: (part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8(\"Brand#12\") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
" TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
" Filter: (part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand Down
28 changes: 12 additions & 16 deletions datafusion/core/tests/sql/subqueries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,18 @@ where c_acctbal < (
let actual = format!("{}", plan.display_indent());
let expected = "Sort: customer.c_custkey ASC NULLS LAST\
\n Projection: customer.c_custkey\
\n Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_1.__value\
\n Projection: customer.c_custkey, customer.c_acctbal, __scalar_sq_1.__value\
\n Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey\
\n TableScan: customer projection=[c_custkey, c_acctbal]\
\n SubqueryAlias: __scalar_sq_1\
\n Projection: orders.o_custkey, SUM(orders.o_totalprice) AS __value\
\n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[SUM(orders.o_totalprice)]]\
\n Projection: orders.o_custkey, orders.o_totalprice\
\n Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < __scalar_sq_2.__value\
\n Projection: orders.o_custkey, orders.o_totalprice, __scalar_sq_2.__value\
\n Inner Join: orders.o_orderkey = __scalar_sq_2.l_orderkey\
\n TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice]\
\n SubqueryAlias: __scalar_sq_2\
\n Projection: lineitem.l_orderkey, SUM(lineitem.l_extendedprice) AS price AS __value\
\n Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[SUM(lineitem.l_extendedprice)]]\
\n TableScan: lineitem projection=[l_orderkey, l_extendedprice]";
\n Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_1.__value\
\n TableScan: customer projection=[c_custkey, c_acctbal]\
\n SubqueryAlias: __scalar_sq_1\
\n Projection: orders.o_custkey, SUM(orders.o_totalprice) AS __value\
\n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[SUM(orders.o_totalprice)]]\
\n Projection: orders.o_custkey, orders.o_totalprice\
\n Inner Join: orders.o_orderkey = __scalar_sq_2.l_orderkey Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < __scalar_sq_2.__value\
\n TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice]\
\n SubqueryAlias: __scalar_sq_2\
\n Projection: lineitem.l_orderkey, SUM(lineitem.l_extendedprice) AS price AS __value\
\n Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[SUM(lineitem.l_extendedprice)]]\
\n TableScan: lineitem projection=[l_orderkey, l_extendedprice]";
assert_eq!(actual, expected);

Ok(())
Expand Down
Loading