Skip to content

Commit

Permalink
put subquery's equal clause into join on clauses instead of filter cl… (
Browse files Browse the repository at this point in the history
#3862)

* put subquery's equal clause into join on clauses instead of filter clauses

* only do this optimization for correlated subqueries
  • Loading branch information
HuSen8891 authored Oct 19, 2022
1 parent b004ec7 commit 13addce
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 75 deletions.
43 changes: 21 additions & 22 deletions benchmarks/expected-plans/q2.txt
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
Sort: supplier.s_acctbal DESC NULLS FIRST, nation.n_name ASC NULLS LAST, supplier.s_name ASC NULLS LAST, part.p_partkey ASC NULLS LAST
Projection: supplier.s_acctbal, supplier.s_name, nation.n_name, part.p_partkey, part.p_mfgr, supplier.s_address, supplier.s_phone, supplier.s_comment
Filter: partsupp.ps_supplycost = __sq_1.__value
Inner Join: part.p_partkey = __sq_1.ps_partkey
Inner Join: nation.n_regionkey = region.r_regionkey
Inner Join: supplier.s_nationkey = nation.n_nationkey
Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
Inner Join: part.p_partkey = partsupp.ps_partkey
Filter: part.p_size = Int32(15) AND part.p_type LIKE Utf8("%BRASS")
TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size]
Inner Join: part.p_partkey = __sq_1.ps_partkey, partsupp.ps_supplycost = __sq_1.__value
Inner Join: nation.n_regionkey = region.r_regionkey
Inner Join: supplier.s_nationkey = nation.n_nationkey
Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
Inner Join: part.p_partkey = partsupp.ps_partkey
Filter: part.p_size = Int32(15) AND part.p_type LIKE Utf8("%BRASS")
TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size]
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
Filter: region.r_name = Utf8("EUROPE")
TableScan: region projection=[r_regionkey, r_name]
Projection: partsupp.ps_partkey, MIN(partsupp.ps_supplycost) AS __value, alias=__sq_1
Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[MIN(partsupp.ps_supplycost)]]
Inner Join: nation.n_regionkey = region.r_regionkey
Inner Join: supplier.s_nationkey = nation.n_nationkey
Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
Filter: region.r_name = Utf8("EUROPE")
TableScan: region projection=[r_regionkey, r_name]
Projection: partsupp.ps_partkey, MIN(partsupp.ps_supplycost) AS __value, alias=__sq_1
Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[MIN(partsupp.ps_supplycost)]]
Inner Join: nation.n_regionkey = region.r_regionkey
Inner Join: supplier.s_nationkey = nation.n_nationkey
Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
Filter: region.r_name = Utf8("EUROPE")
TableScan: region projection=[r_regionkey, r_name]
TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
Filter: region.r_name = Utf8("EUROPE")
TableScan: region projection=[r_regionkey, r_name]
43 changes: 21 additions & 22 deletions datafusion/core/tests/sql/subqueries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,29 +141,28 @@ order by s_acctbal desc, n_name, s_name, p_partkey;"#;
let actual = format!("{}", plan.display_indent());
let expected = r#"Sort: supplier.s_acctbal DESC NULLS FIRST, nation.n_name ASC NULLS LAST, supplier.s_name ASC NULLS LAST, part.p_partkey ASC NULLS LAST
Projection: supplier.s_acctbal, supplier.s_name, nation.n_name, part.p_partkey, part.p_mfgr, supplier.s_address, supplier.s_phone, supplier.s_comment
Filter: partsupp.ps_supplycost = __sq_1.__value
Inner Join: part.p_partkey = __sq_1.ps_partkey
Inner Join: nation.n_regionkey = region.r_regionkey
Inner Join: supplier.s_nationkey = nation.n_nationkey
Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
Inner Join: part.p_partkey = partsupp.ps_partkey
Filter: part.p_size = Int32(15) AND part.p_type LIKE Utf8("%BRASS")
TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[part.p_size = Int32(15), part.p_type LIKE Utf8("%BRASS")]
Inner Join: part.p_partkey = __sq_1.ps_partkey, partsupp.ps_supplycost = __sq_1.__value
Inner Join: nation.n_regionkey = region.r_regionkey
Inner Join: supplier.s_nationkey = nation.n_nationkey
Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
Inner Join: part.p_partkey = partsupp.ps_partkey
Filter: part.p_size = Int32(15) AND part.p_type LIKE Utf8("%BRASS")
TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[part.p_size = Int32(15), part.p_type LIKE Utf8("%BRASS")]
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
Filter: region.r_name = Utf8("EUROPE")
TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")]
Projection: partsupp.ps_partkey, MIN(partsupp.ps_supplycost) AS __value, alias=__sq_1
Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[MIN(partsupp.ps_supplycost)]]
Inner Join: nation.n_regionkey = region.r_regionkey
Inner Join: supplier.s_nationkey = nation.n_nationkey
Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
Filter: region.r_name = Utf8("EUROPE")
TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")]
Projection: partsupp.ps_partkey, MIN(partsupp.ps_supplycost) AS __value, alias=__sq_1
Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[MIN(partsupp.ps_supplycost)]]
Inner Join: nation.n_regionkey = region.r_regionkey
Inner Join: supplier.s_nationkey = nation.n_nationkey
Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
Filter: region.r_name = Utf8("EUROPE")
TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")]"#
TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
Filter: region.r_name = Utf8("EUROPE")
TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")]"#
.to_string();
assert_eq!(actual, expected);

Expand Down
134 changes: 103 additions & 31 deletions datafusion/optimizer/src/scalar_subquery_to_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ fn optimize_scalar(
// Grab column names to join on
let (col_exprs, other_subqry_exprs) =
find_join_exprs(subqry_filter_exprs, input.schema())?;
let (outer_cols, subqry_cols, join_filters) =
let (mut outer_cols, subqry_cols, join_filters) =
exprs_to_join_cols(&col_exprs, input.schema(), false)?;
if join_filters.is_some() {
plan_err!("only joins on column equality are presently supported")?;
Expand Down Expand Up @@ -275,13 +275,31 @@ fn optimize_scalar(
.build()?;

// qualify the join columns for outside the subquery
let subqry_cols: Vec<_> = subqry_cols
let mut subqry_cols: Vec<_> = subqry_cols
.iter()
.map(|it| Column {
relation: Some(subqry_alias.clone()),
name: it.name.clone(),
})
.collect();

let qry_expr = Expr::Column(Column {
relation: Some(subqry_alias),
name: "__value".to_string(),
});

// if correlated subquery's operation is column equality, put the clause into join on clause.
let mut restore_where_clause = true;

if let (Operator::Eq, Expr::Column(column)) = (query_info.op, &query_info.expr) {
// only do this optimization for correlated subquery
if !outer_cols.is_empty() {
outer_cols.push(column.clone());
subqry_cols.push(qry_expr.try_into_col().unwrap());
restore_where_clause = false;
}
}

let join_keys = (outer_cols, subqry_cols);

// join our sub query into the main plan
Expand All @@ -295,24 +313,22 @@ fn optimize_scalar(
};

// restore where in condition
let qry_expr = Box::new(Expr::Column(Column {
relation: Some(subqry_alias),
name: "__value".to_string(),
}));
let filter_expr = if query_info.expr_on_left {
Expr::BinaryExpr(BinaryExpr::new(
Box::new(query_info.expr.clone()),
query_info.op,
qry_expr,
))
} else {
Expr::BinaryExpr(BinaryExpr::new(
qry_expr,
query_info.op,
Box::new(query_info.expr.clone()),
))
};
new_plan = new_plan.filter(filter_expr)?;
if restore_where_clause {
let filter_expr = if query_info.expr_on_left {
Expr::BinaryExpr(BinaryExpr::new(
Box::new(query_info.expr.clone()),
query_info.op,
Box::new(qry_expr),
))
} else {
Expr::BinaryExpr(BinaryExpr::new(
Box::new(qry_expr),
query_info.op,
Box::new(query_info.expr.clone()),
))
};
new_plan = new_plan.filter(filter_expr)?;
}

// if the main query had additional expressions, restore them
if let Some(expr) = conjunction(outer_others.to_vec()) {
Expand Down Expand Up @@ -461,13 +477,12 @@ mod tests {
.build()?;

let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
Filter: customer.c_custkey = __sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]
Inner Join: customer.c_custkey = __sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]
TableScan: customer [c_custkey:Int64, c_name:Utf8]
Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value, alias=__sq_1 [o_custkey:Int64, __value:Int64;N]
Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]
Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
Inner Join: customer.c_custkey = __sq_1.o_custkey, customer.c_custkey = __sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]
TableScan: customer [c_custkey:Int64, c_name:Utf8]
Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value, alias=__sq_1 [o_custkey:Int64, __value:Int64;N]
Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]
Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;

assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected);
Ok(())
Expand Down Expand Up @@ -677,7 +692,7 @@ mod tests {

/// Test for correlated scalar subquery filter with additional filters
#[test]
fn scalar_subquery_additional_filters() -> Result<()> {
fn scalar_subquery_additional_filters_with_non_equal_clause() -> Result<()> {
let sq = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
.filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
Expand All @@ -689,15 +704,15 @@ mod tests {
let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
.filter(
col("customer.c_custkey")
.eq(scalar_subquery(sq))
.gt_eq(scalar_subquery(sq))
.and(col("c_custkey").eq(lit(1))),
)?
.project(vec![col("customer.c_custkey")])?
.build()?;

let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]
Filter: customer.c_custkey = __sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]
Filter: customer.c_custkey >= __sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]
Inner Join: customer.c_custkey = __sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]
TableScan: customer [c_custkey:Int64, c_name:Utf8]
Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value, alias=__sq_1 [o_custkey:Int64, __value:Int64;N]
Expand All @@ -708,6 +723,37 @@ mod tests {
Ok(())
}

/// Test for correlated scalar subquery compared with `=`, with additional filters
#[test]
fn scalar_subquery_additional_filters_with_equal_clause() -> Result<()> {
// Subquery over `orders`: filtered on customer.c_custkey = orders.o_custkey,
// then aggregated with no explicit group-by to MAX(orders.o_custkey).
let sq = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
.filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
.aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
.project(vec![max(col("orders.o_custkey"))])?
.build()?,
);

// Outer query: the scalar subquery is compared with `=` and AND-ed with an
// extra predicate on c_custkey.
let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
.filter(
col("customer.c_custkey")
.eq(scalar_subquery(sq))
.and(col("c_custkey").eq(lit(1))),
)?
.project(vec![col("customer.c_custkey")])?
.build()?;

// Expected: the `= __sq_1.__value` comparison appears as a second key of the
// Inner Join rather than as its own Filter; only the `= Int32(1)` predicate
// remains as a Filter above the join.
let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]
Inner Join: customer.c_custkey = __sq_1.o_custkey, customer.c_custkey = __sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]
TableScan: customer [c_custkey:Int64, c_name:Utf8]
Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value, alias=__sq_1 [o_custkey:Int64, __value:Int64;N]
Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;

assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected);
Ok(())
}

/// Test for correlated scalar subquery filter with disjunctions
#[test]
fn scalar_subquery_disjunction() -> Result<()> {
Expand Down Expand Up @@ -771,7 +817,33 @@ mod tests {

/// Test for non-correlated scalar subquery with no filters
#[test]
fn scalar_subquery_non_correlated_no_filters() -> Result<()> {
fn scalar_subquery_non_correlated_no_filters_with_non_equal_clause() -> Result<()> {
// Subquery over `orders` with no filter at all: a global MAX(orders.o_custkey).
let sq = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
.aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
.project(vec![max(col("orders.o_custkey"))])?
.build()?,
);

// Outer query compares c_custkey to the subquery with `<` (not `=`).
let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
.filter(col("customer.c_custkey").lt(scalar_subquery(sq)))?
.project(vec![col("customer.c_custkey")])?
.build()?;

// Expected: the subquery is attached through a CrossJoin and the `<`
// comparison against __sq_1.__value stays as a Filter above it.
let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
Filter: customer.c_custkey < __sq_1.__value [c_custkey:Int64, c_name:Utf8, __value:Int64;N]
CrossJoin: [c_custkey:Int64, c_name:Utf8, __value:Int64;N]
TableScan: customer [c_custkey:Int64, c_name:Utf8]
Projection: MAX(orders.o_custkey) AS __value, alias=__sq_1 [__value:Int64;N]
Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;

assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected);
Ok(())
}

#[test]
fn scalar_subquery_non_correlated_no_filters_with_equal_clause() -> Result<()> {
let sq = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
.aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
Expand Down

0 comments on commit 13addce

Please sign in to comment.