Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pushdown single column predicates from ON join clauses #3578

Merged
merged 6 commits into from
Oct 15, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion datafusion/core/tests/sql/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1484,7 +1484,8 @@ async fn reduce_left_join_2() -> Result<()> {
" Filter: CAST(#t2.t2_int AS Int64) < Int64(10) OR CAST(#t1.t1_int AS Int64) > Int64(2) AND #t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" Inner Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" Filter: CAST(#t2.t2_int AS Int64) < Int64(10) OR #t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 nice

" TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand Down
7 changes: 4 additions & 3 deletions datafusion/core/tests/sql/predicates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,9 +430,10 @@ async fn multiple_or_predicates() -> Result<()> {
" Projection: #part.p_size >= Int32(1) AS #part.p_size >= Int32(1)Int32(1)#part.p_size, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [#part.p_size >= Int32(1)Int32(1)#part.p_size:Boolean;N, l_partkey:Int64, l_quantity:Decimal128(15, 2), p_brand:Utf8, p_size:Int32]",
" Filter: #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Decimal128(Some(100),15,2) AND #lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND #part.p_size <= Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND #part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" Inner Join: #lineitem.l_partkey = #part.p_partkey [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
" Filter: #part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[#part.p_size >= Int32(1)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" Filter: #lineitem.l_quantity >= Decimal128(Some(100),15,2) AND #lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR #lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR #lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(3000),15,2) [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went through this plan and I agree it seems correct (as in the pushed down filters don't filter out anything that would have passed the original filter)

" TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[#lineitem.l_quantity >= Decimal128(Some(100),15,2) AND #lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR #lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR #lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(3000),15,2)] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
" Filter: #part.p_size >= Int32(1) AND #part.p_brand = Utf8(\"Brand#12\") AND #part.p_size <= Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #part.p_size <= Int32(15) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[#part.p_size >= Int32(1), #part.p_brand = Utf8(\"Brand#12\") AND #part.p_size <= Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #part.p_size <= Int32(15)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand Down
223 changes: 221 additions & 2 deletions datafusion/optimizer/src/filter_push_down.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@
use crate::{utils, OptimizerConfig, OptimizerRule};
use datafusion_common::{Column, DFSchema, DataFusionError, Result};
use datafusion_expr::{
col,
and, col,
expr_rewriter::{replace_col, ExprRewritable, ExprRewriter},
logical_plan::{
Aggregate, CrossJoin, Filter, Join, JoinType, Limit, LogicalPlan, Projection,
TableScan, Union,
},
or,
utils::{expr_to_columns, exprlist_to_columns, from_plan},
Expr, TableProviderFilterPushDown,
Expr, Operator, TableProviderFilterPushDown,
};
use std::collections::{HashMap, HashSet};
use std::iter::once;
Expand Down Expand Up @@ -248,6 +249,156 @@ fn get_pushable_join_predicates<'a>(
.unzip()
}

// examine OR clause to see if any useful clauses can be extracted and push down.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this transformation is correct. In particular, I don't think the results will always be the same

Schematically, we have this type of predicate (that is being evaluated during the join)

(A AND B) OR (C AND D)

This transformation proposes adding another (A OR B) clause (evaluated before the join), so effectively

 ((A AND B) OR (C AND D)) AND (A OR B)

In order to do this transformation, the boolean statements must be equivalent for all inputs.

However, a counter example is

A: false, B: false, C: true, D: true

In this case, the original predicate would be true, but the rewrite would be false

Here is the program I wrote to generate the entire truth table: https://play.rust-lang.org/?version=stable&mode=debug&edition=2018&gist=334938478775ba3cd55e7c400ea89b06

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This transformation should extract at least one quals from each sub-clauses of OR, else do nothing.

(A AND B) OR (C AND D)

will be transformed to

((A AND B) OR (C AND D)) AND (A OR C)

OR

((A AND B) OR (C AND D)) AND ((A AND B) OR C)

OR

do nothing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see -- thanks -- I checked those rewrites and https://play.rust-lang.org/?version=stable&mode=debug&edition=2018&gist=3b41b0409c8ecf4df0027f323668e0db they do look good to me

// extract at least one qual from each sub clauses of OR clause, then form the quals
// to new OR clause as predicate.
//
Comment on lines +253 to +254
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to explain the conditions under which a qual can be extracted as it may not be obvious to someone when they initially look at this.

Suggested change
// to new OR clause as predicate.
//
// to new OR clause as predicate.
//
// A qual is extracted if it it contains (only) common set of column references with the other quals.

I am not sure that is correct

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also like to see the return type documented here (as in what does the (Vec<Expr>, Vec<HashSet<Column>>) represent? I think it is the extracted quals and their column references but I am not sure

// Filter: (a = c and a < 20) or (b = d and b > 10)
// join/crossjoin:
// TableScan: projection=[a, b]
// TableScan: projection=[c, d]
//
// is optimized to
//
// Filter: (a = c and a < 20) or (b = d and b > 10)
// join/crossjoin:
// Filter: (a < 20) or (b > 10)
// TableScan: projection=[a, b]
// TableScan: projection=[c, d]
//
// In general, predicates of this form:
//
// (A AND B) OR (C AND D)
//
// will be transformed to
//
// ((A AND B) OR (C AND D)) AND (A OR C)
//
// OR
//
// ((A AND B) OR (C AND D)) AND ((A AND B) OR C)
//
// OR
//
// do nothing.
//
fn extract_or_clauses_for_join(
filters: &[&Expr],
schema: &DFSchema,
preserved: bool,
) -> (Vec<Expr>, Vec<HashSet<Column>>) {
if !preserved {
return (vec![], vec![]);
}

let schema_columns = schema
.fields()
.iter()
.flat_map(|f| {
[
f.qualified_column(),
// we need to push down filter using unqualified column as well
f.unqualified_column(),
]
})
.collect::<HashSet<_>>();

let mut exprs = vec![];
let mut expr_columns = vec![];
for expr in filters.iter() {
if let Expr::BinaryExpr {
left,
op: Operator::Or,
right,
} = expr
{
let left_expr = extract_or_clause(left.as_ref(), &schema_columns);
let right_expr = extract_or_clause(right.as_ref(), &schema_columns);

// If nothing can be extracted from any sub clauses, do nothing for this OR clause.
if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
let predicate = or(left_expr, right_expr);
let mut columns: HashSet<Column> = HashSet::new();
expr_to_columns(&predicate, &mut columns).ok().unwrap();

exprs.push(predicate);
expr_columns.push(columns);
}
}
}

// new formed OR clauses and their column references
(exprs, expr_columns)
}

// extract qual from OR sub-clause.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add some additional comments under what conditions the OR clause is extracted? I tried to explain above

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, I can add more comments for this.

//
// A qual is extracted if it only contains set of column references in schema_columns.
//
// For AND clause, we extract from both sub-clauses, then make new AND clause by extracted
// clauses if both extracted; Otherwise, use the extracted clause from any sub-clauses or None.
//
// For OR clause, we extract from both sub-clauses, then make new OR clause by extracted clauses if both extracted;
// Otherwise, return None.
//
// For other clause, apply the rule above to extract clause.
fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
let mut predicate = None;

match expr {
Expr::BinaryExpr {
left: l_expr,
op: Operator::Or,
right: r_expr,
} => {
let l_expr = extract_or_clause(l_expr, schema_columns);
let r_expr = extract_or_clause(r_expr, schema_columns);

if let (Some(l_expr), Some(r_expr)) = (l_expr, r_expr) {
predicate = Some(or(l_expr, r_expr));
}
}
Expr::BinaryExpr {
left: l_expr,
op: Operator::And,
right: r_expr,
} => {
let l_expr = extract_or_clause(l_expr, schema_columns);
let r_expr = extract_or_clause(r_expr, schema_columns);

match (l_expr, r_expr) {
(Some(l_expr), Some(r_expr)) => {
predicate = Some(and(l_expr, r_expr));
}
(Some(l_expr), None) => {
predicate = Some(l_expr);
}
(None, Some(r_expr)) => {
predicate = Some(r_expr);
}
(None, None) => {
predicate = None;
}
}
}
_ => {
let mut columns: HashSet<Column> = HashSet::new();
expr_to_columns(expr, &mut columns).ok().unwrap();

if schema_columns
.intersection(&columns)
.collect::<HashSet<_>>()
.len()
== columns.len()
{
predicate = Some(expr.clone());
}
}
}

predicate
}

fn optimize_join(
mut state: State,
plan: &LogicalPlan,
Expand Down Expand Up @@ -286,17 +437,54 @@ fn optimize_join(
(on_to_left, on_to_right, on_to_keep)
};

// Extract from OR clause, generate new predicates for both side of join if possible.
// We only track the unpushable predicates above.
let or_to_left =
extract_or_clauses_for_join(&to_keep.0, left.schema(), left_preserved);
let or_to_right =
extract_or_clauses_for_join(&to_keep.0, right.schema(), right_preserved);
let on_or_to_left = extract_or_clauses_for_join(
&on_to_keep.iter().collect::<Vec<_>>(),
left.schema(),
left_preserved,
);
let on_or_to_right = extract_or_clauses_for_join(
&on_to_keep.iter().collect::<Vec<_>>(),
right.schema(),
right_preserved,
);

// Build new filter states using pushable predicates
// from current optimizer states and from ON clause.
// Then recursively call optimization for both join inputs
let mut left_state = State { filters: vec![] };
left_state.append_predicates(to_left);
left_state.append_predicates(on_to_left);
or_to_left
.0
.into_iter()
.zip(or_to_left.1)
.for_each(|(expr, cols)| left_state.filters.push((expr, cols)));
on_or_to_left
.0
.into_iter()
.zip(on_or_to_left.1)
.for_each(|(expr, cols)| left_state.filters.push((expr, cols)));
let left = optimize(left, left_state)?;

let mut right_state = State { filters: vec![] };
right_state.append_predicates(to_right);
right_state.append_predicates(on_to_right);
or_to_right
.0
.into_iter()
.zip(or_to_right.1)
.for_each(|(expr, cols)| right_state.filters.push((expr, cols)));
on_or_to_right
.0
.into_iter()
.zip(on_or_to_right.1)
.for_each(|(expr, cols)| right_state.filters.push((expr, cols)));
let right = optimize(right, right_state)?;

// Create a new Join with the new `left` and `right`
Expand Down Expand Up @@ -2137,4 +2325,35 @@ mod tests {

Ok(())
}

#[test]
fn test_crossjoin_with_or_clause() -> Result<()> {
// select * from test,test1 where (test.a = test1.a and test.b > 1) or (test.b = test1.b and test.c < 10);
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.build()?;
let right_table_scan = test_table_scan_with_name("test1")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a").alias("d"), col("a").alias("e")])?
.build()?;
let filter = or(
and(col("a").eq(col("d")), col("b").gt(lit(1u32))),
and(col("b").eq(col("e")), col("c").lt(lit(10u32))),
);
let plan = LogicalPlanBuilder::from(left)
.cross_join(&right)?
.filter(filter)?
.build()?;

let expected = "Filter: #test.a = #d AND #test.b > UInt32(1) OR #test.b = #e AND #test.c < UInt32(10)\
\n CrossJoin:\
\n Projection: #test.a, #test.b, #test.c\
\n Filter: #test.b > UInt32(1) OR #test.c < UInt32(10)\
\n TableScan: test\
\n Projection: #test1.a AS d, #test1.a AS e\
\n TableScan: test1";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
}