Skip to content

Commit

Permalink
extract OR clause for join
Browse files Browse the repository at this point in the history
  • Loading branch information
HuSen8891 committed Sep 21, 2022
1 parent ff718d0 commit f65dd47
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 6 deletions.
3 changes: 2 additions & 1 deletion datafusion/core/tests/sql/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1484,7 +1484,8 @@ async fn reduce_left_join_2() -> Result<()> {
" Filter: CAST(#t2.t2_int AS Int64) < Int64(10) OR CAST(#t1.t1_int AS Int64) > Int64(2) AND #t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" Inner Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" Filter: CAST(#t2.t2_int AS Int64) < Int64(10) OR #t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand Down
7 changes: 4 additions & 3 deletions datafusion/core/tests/sql/predicates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,9 +430,10 @@ async fn multiple_or_predicates() -> Result<()> {
" Projection: #part.p_size >= Int32(1) AS #part.p_size >= Int32(1)Int32(1)#part.p_size, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [#part.p_size >= Int32(1)Int32(1)#part.p_size:Boolean;N, l_partkey:Int64, l_quantity:Decimal128(15, 2), p_brand:Utf8, p_size:Int32]",
" Filter: #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Decimal128(Some(100),15,2) AND #lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND #part.p_size <= Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND #part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" Inner Join: #lineitem.l_partkey = #part.p_partkey [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
" Filter: #part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[#part.p_size >= Int32(1)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" Filter: #lineitem.l_quantity >= Decimal128(Some(100),15,2) AND #lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR #lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR #lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(3000),15,2) [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
" TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[#lineitem.l_quantity >= Decimal128(Some(100),15,2) AND #lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR #lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR #lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(3000),15,2)] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
" Filter: #part.p_size >= Int32(1) AND #part.p_brand = Utf8(\"Brand#12\") AND #part.p_size <= Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #part.p_size <= Int32(15) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[#part.p_size >= Int32(1), #part.p_brand = Utf8(\"Brand#12\") AND #part.p_size <= Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #part.p_size <= Int32(15)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand Down
195 changes: 193 additions & 2 deletions datafusion/optimizer/src/filter_push_down.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@
use crate::{utils, OptimizerConfig, OptimizerRule};
use datafusion_common::{Column, DFSchema, DataFusionError, Result};
use datafusion_expr::{
col,
and, col,
expr_rewriter::{replace_col, ExprRewritable, ExprRewriter},
logical_plan::{
Aggregate, CrossJoin, Filter, Join, JoinType, Limit, LogicalPlan, Projection,
TableScan, Union,
},
or,
utils::{expr_to_columns, exprlist_to_columns, from_plan},
Expr, TableProviderFilterPushDown,
Expr, Operator, TableProviderFilterPushDown,
};
use std::collections::{HashMap, HashSet};
use std::iter::once;
Expand Down Expand Up @@ -248,6 +249,128 @@ fn get_pushable_join_predicates<'a>(
.unzip()
}

// examine OR clause to see if any useful clauses can be extracted and push down.
// extract at least one qual of each sub clauses of OR clause, then form the quals
// to new OR clause as predicate.
//
// Filter: (a = c and a < 20) or (b = d and b > 10)
// join/crossjoin:
// TableScan: projection=[a, b]
// TableScan: projection=[c, d]
//
// is optimized to
//
// Filter: (a = c and a < 20) or (b = d and b > 10)
// join/crossjoin:
// Filter: (a < 20) or (b > 10)
// TableScan: projection=[a, b]
// TableScan: projection=[c, d]
fn extract_or_clauses_for_join(
filters: &[&Expr],
schema: &DFSchema,
preserved: bool,
) -> (Vec<Expr>, Vec<HashSet<Column>>) {
if !preserved {
return (vec![], vec![]);
}

let schema_columns = schema
.fields()
.iter()
.flat_map(|f| {
[
f.qualified_column(),
// we need to push down filter using unqualified column as well
f.unqualified_column(),
]
})
.collect::<HashSet<_>>();

let mut exprs = vec![];
let mut expr_columns = vec![];
for expr in filters.iter() {
if let Expr::BinaryExpr {
left,
op: Operator::Or,
right,
} = expr
{
let left_expr = extract_or_clause(left.as_ref(), &schema_columns);
let right_expr = extract_or_clause(right.as_ref(), &schema_columns);

// If nothing can be extracted from any sub clauses, do nothing for this OR clause.
if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
let predicate = or(left_expr, right_expr);
let mut columns: HashSet<Column> = HashSet::new();
expr_to_columns(&predicate, &mut columns).ok().unwrap();

exprs.push(predicate);
expr_columns.push(columns);
}
}
}

(exprs, expr_columns)
}

// extract qual from OR sub-clause.
fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
let mut predicate = None;

match expr {
Expr::BinaryExpr {
left: l_expr,
op: Operator::Or,
right: r_expr,
} => {
let l_expr = extract_or_clause(l_expr, schema_columns);
let r_expr = extract_or_clause(r_expr, schema_columns);

if let (Some(l_expr), Some(r_expr)) = (l_expr, r_expr) {
predicate = Some(or(l_expr, r_expr));
}
}
Expr::BinaryExpr {
left: l_expr,
op: Operator::And,
right: r_expr,
} => {
let l_expr = extract_or_clause(l_expr, schema_columns);
let r_expr = extract_or_clause(r_expr, schema_columns);

match (l_expr, r_expr) {
(Some(l_expr), Some(r_expr)) => {
predicate = Some(and(l_expr, r_expr));
}
(Some(l_expr), None) => {
predicate = Some(l_expr);
}
(None, Some(r_expr)) => {
predicate = Some(r_expr);
}
(None, None) => {
predicate = None;
}
}
}
_ => {
let mut columns: HashSet<Column> = HashSet::new();
expr_to_columns(expr, &mut columns).ok().unwrap();

if schema_columns
.intersection(&columns)
.collect::<HashSet<_>>()
.len()
== columns.len()
{
predicate = Some(expr.clone());
}
}
}

predicate
}

fn optimize_join(
mut state: State,
plan: &LogicalPlan,
Expand Down Expand Up @@ -286,17 +409,54 @@ fn optimize_join(
(on_to_left, on_to_right, on_to_keep)
};

// Extract from OR clause, generate new predicates for both side of join if possible.
// We only track the unpushable predicates above.
let or_to_left =
extract_or_clauses_for_join(&to_keep.0, left.schema(), left_preserved);
let or_to_right =
extract_or_clauses_for_join(&to_keep.0, right.schema(), right_preserved);
let on_or_to_left = extract_or_clauses_for_join(
&on_to_keep.iter().collect::<Vec<_>>(),
left.schema(),
left_preserved,
);
let on_or_to_right = extract_or_clauses_for_join(
&on_to_keep.iter().collect::<Vec<_>>(),
right.schema(),
right_preserved,
);

// Build new filter states using pushable predicates
// from current optimizer states and from ON clause.
// Then recursively call optimization for both join inputs
let mut left_state = State { filters: vec![] };
left_state.append_predicates(to_left);
left_state.append_predicates(on_to_left);
or_to_left
.0
.into_iter()
.zip(or_to_left.1)
.for_each(|(expr, cols)| left_state.filters.push((expr, cols)));
on_or_to_left
.0
.into_iter()
.zip(on_or_to_left.1)
.for_each(|(expr, cols)| left_state.filters.push((expr, cols)));
let left = optimize(left, left_state)?;

let mut right_state = State { filters: vec![] };
right_state.append_predicates(to_right);
right_state.append_predicates(on_to_right);
or_to_right
.0
.into_iter()
.zip(or_to_right.1)
.for_each(|(expr, cols)| right_state.filters.push((expr, cols)));
on_or_to_right
.0
.into_iter()
.zip(on_or_to_right.1)
.for_each(|(expr, cols)| right_state.filters.push((expr, cols)));
let right = optimize(right, right_state)?;

// Create a new Join with the new `left` and `right`
Expand Down Expand Up @@ -2137,4 +2297,35 @@ mod tests {

Ok(())
}

#[test]
fn test_crossjoin_with_or_clause() -> Result<()> {
// select * from test,test1 where (test.a = test1.a and test.b > 1) or (test.b = test1.b and test.c < 10);
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.build()?;
let right_table_scan = test_table_scan_with_name("test1")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a").alias("d"), col("a").alias("e")])?
.build()?;
let filter = or(
and(col("a").eq(col("d")), col("b").gt(lit(1u32))),
and(col("b").eq(col("e")), col("c").lt(lit(10u32))),
);
let plan = LogicalPlanBuilder::from(left)
.cross_join(&right)?
.filter(filter)?
.build()?;

let expected = "Filter: #test.a = #d AND #test.b > UInt32(1) OR #test.b = #e AND #test.c < UInt32(10)\
\n CrossJoin:\
\n Projection: #test.a, #test.b, #test.c\
\n Filter: #test.b > UInt32(1) OR #test.c < UInt32(10)\
\n TableScan: test\
\n Projection: #test1.a AS d, #test1.a AS e\
\n TableScan: test1";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
}

0 comments on commit f65dd47

Please sign in to comment.