Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Move subquery check from analyzer to PullUpCorrelatedExpr (Fix TPC-DS q41) #13091

Merged
merged 1 commit into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions datafusion/core/benches/sql_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,7 @@ fn criterion_benchmark(c: &mut Criterion) {

let tpcds_ctx = register_defs(SessionContext::new(), tpcds_schemas());

// 41: check_analyzed_plan: Correlated column is not allowed in predicate
let ignored = [41];

let raw_tpcds_sql_queries = (1..100)
.filter(|q| !ignored.contains(q))
.map(|q| std::fs::read_to_string(format!("./tests/tpc-ds/{q}.sql")).unwrap())
.collect::<Vec<_>>();

Expand Down
5 changes: 0 additions & 5 deletions datafusion/core/tests/tpcds_planning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,6 @@ async fn tpcds_logical_q40() -> Result<()> {
}

#[tokio::test]
#[ignore]
// check_analyzed_plan: Correlated column is not allowed in predicate
// issue: https://github.com/apache/datafusion/issues/13074
async fn tpcds_logical_q41() -> Result<()> {
create_logical_plan(41).await
}
Expand Down Expand Up @@ -726,8 +723,6 @@ async fn tpcds_physical_q40() -> Result<()> {
create_physical_plan(40).await
}

#[ignore]
// Context("check_analyzed_plan", Plan("Correlated column is not allowed in predicate: (..)
#[tokio::test]
async fn tpcds_physical_q41() -> Result<()> {
create_physical_plan(41).await
Expand Down
96 changes: 18 additions & 78 deletions datafusion/optimizer/src/analyzer/subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,14 @@
// specific language governing permissions and limitations
// under the License.

use std::ops::Deref;

use crate::analyzer::check_plan;
use crate::utils::collect_subquery_cols;

use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::{plan_err, Result};
use datafusion_expr::expr_rewriter::strip_outer_reference;
use datafusion_expr::utils::split_conjunction;
use datafusion_expr::{
Aggregate, BinaryExpr, Cast, Expr, Filter, Join, JoinType, LogicalPlan, Operator,
Window,
};
use datafusion_expr::{Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window};

/// Do necessary check on subquery expressions and fail the invalid plan
/// 1) Check whether the outer plan is in the allowed outer plans list to use subquery expressions,
Expand Down Expand Up @@ -98,7 +93,7 @@ pub fn check_subquery_expr(
)
}?;
}
check_correlations_in_subquery(inner_plan, true)
check_correlations_in_subquery(inner_plan)
} else {
if let Expr::InSubquery(subquery) = expr {
// InSubquery should only return one column
Expand All @@ -121,58 +116,36 @@ pub fn check_subquery_expr(
Projection, Filter, Window functions, Aggregate and Join plan nodes"
),
}?;
check_correlations_in_subquery(inner_plan, false)
check_correlations_in_subquery(inner_plan)
}
}

// Recursively check the unsupported outer references in the sub query plan.
fn check_correlations_in_subquery(
inner_plan: &LogicalPlan,
is_scalar: bool,
) -> Result<()> {
check_inner_plan(inner_plan, is_scalar, false, true)
fn check_correlations_in_subquery(inner_plan: &LogicalPlan) -> Result<()> {
check_inner_plan(inner_plan, true)
}

// Recursively check the unsupported outer references in the sub query plan.
fn check_inner_plan(
inner_plan: &LogicalPlan,
is_scalar: bool,
is_aggregate: bool,
can_contain_outer_ref: bool,
) -> Result<()> {
fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Result<()> {
if !can_contain_outer_ref && inner_plan.contains_outer_reference() {
return plan_err!("Accessing outer reference columns is not allowed in the plan");
}
// We want to support as many operators as possible inside the correlated subquery
match inner_plan {
LogicalPlan::Aggregate(_) => {
inner_plan.apply_children(|plan| {
check_inner_plan(plan, is_scalar, true, can_contain_outer_ref)?;
check_inner_plan(plan, can_contain_outer_ref)?;
Ok(TreeNodeRecursion::Continue)
})?;
Ok(())
}
LogicalPlan::Filter(Filter {
predicate, input, ..
}) => {
let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate)
.into_iter()
.partition(|e| e.contains_outer());
let maybe_unsupported = correlated
.into_iter()
.filter(|expr| !can_pullup_over_aggregation(expr))
.collect::<Vec<_>>();
if is_aggregate && is_scalar && !maybe_unsupported.is_empty() {
return plan_err!(
"Correlated column is not allowed in predicate: {predicate}"
);
}
check_inner_plan(input, is_scalar, is_aggregate, can_contain_outer_ref)
LogicalPlan::Filter(Filter { input, .. }) => {
check_inner_plan(input, can_contain_outer_ref)
}
LogicalPlan::Window(window) => {
check_mixed_out_refer_in_window(window)?;
inner_plan.apply_children(|plan| {
check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?;
check_inner_plan(plan, can_contain_outer_ref)?;
Ok(TreeNodeRecursion::Continue)
})?;
Ok(())
Expand All @@ -188,7 +161,7 @@ fn check_inner_plan(
| LogicalPlan::Subquery(_)
| LogicalPlan::SubqueryAlias(_) => {
inner_plan.apply_children(|plan| {
check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?;
check_inner_plan(plan, can_contain_outer_ref)?;
Ok(TreeNodeRecursion::Continue)
})?;
Ok(())
Expand All @@ -201,27 +174,22 @@ fn check_inner_plan(
}) => match join_type {
JoinType::Inner => {
inner_plan.apply_children(|plan| {
check_inner_plan(
plan,
is_scalar,
is_aggregate,
can_contain_outer_ref,
)?;
check_inner_plan(plan, can_contain_outer_ref)?;
Ok(TreeNodeRecursion::Continue)
})?;
Ok(())
}
JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => {
check_inner_plan(left, is_scalar, is_aggregate, can_contain_outer_ref)?;
check_inner_plan(right, is_scalar, is_aggregate, false)
check_inner_plan(left, can_contain_outer_ref)?;
check_inner_plan(right, false)
}
JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => {
check_inner_plan(left, is_scalar, is_aggregate, false)?;
check_inner_plan(right, is_scalar, is_aggregate, can_contain_outer_ref)
check_inner_plan(left, false)?;
check_inner_plan(right, can_contain_outer_ref)
}
JoinType::Full => {
inner_plan.apply_children(|plan| {
check_inner_plan(plan, is_scalar, is_aggregate, false)?;
check_inner_plan(plan, false)?;
Ok(TreeNodeRecursion::Continue)
})?;
Ok(())
Expand Down Expand Up @@ -290,34 +258,6 @@ fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result<Vec<Expr>> {
Ok(exprs)
}

/// Check whether the expression can pull up over the aggregation without change the result of the query
fn can_pullup_over_aggregation(expr: &Expr) -> bool {
if let Expr::BinaryExpr(BinaryExpr {
left,
op: Operator::Eq,
right,
}) = expr
{
match (left.deref(), right.deref()) {
(Expr::Column(_), right) => !right.any_column_refs(),
(left, Expr::Column(_)) => !left.any_column_refs(),
(Expr::Cast(Cast { expr, .. }), right)
if matches!(expr.deref(), Expr::Column(_)) =>
{
!right.any_column_refs()
}
(left, Expr::Cast(Cast { expr, .. }))
if matches!(expr.deref(), Expr::Column(_)) =>
{
!left.any_column_refs()
}
(_, _) => false,
}
} else {
false
}
}

/// Check whether the window expressions contain a mixture of out reference columns and inner columns
fn check_mixed_out_refer_in_window(window: &Window) -> Result<()> {
let mixed = window
Expand Down Expand Up @@ -398,6 +338,6 @@ mod test {
}),
});

check_inner_plan(&plan, false, false, true).unwrap();
check_inner_plan(&plan, true).unwrap();
}
}
45 changes: 44 additions & 1 deletion datafusion/optimizer/src/decorrelate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ use datafusion_expr::expr::Alias;
use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction};
use datafusion_expr::{
expr, lit, EmptyRelation, Expr, FetchType, LogicalPlan, LogicalPlanBuilder,
expr, lit, BinaryExpr, Cast, EmptyRelation, Expr, FetchType, LogicalPlan,
LogicalPlanBuilder, Operator,
};
use datafusion_physical_expr::execution_props::ExecutionProps;

Expand All @@ -51,6 +52,9 @@ pub struct PullUpCorrelatedExpr {
pub exists_sub_query: bool,
/// Can the correlated expressions be pulled up. Defaults to **TRUE**
pub can_pull_up: bool,
/// Indicates if we encounter any correlated expression that can not be pulled up
/// above a aggregation without changing the meaning of the query.
can_pull_over_aggregation: bool,
/// Do we need to handle [the Count bug] during the pull up process
///
/// [the Count bug]: https://github.com/apache/datafusion/pull/10500
Expand All @@ -75,6 +79,7 @@ impl PullUpCorrelatedExpr {
in_predicate_opt: None,
exists_sub_query: false,
can_pull_up: true,
can_pull_over_aggregation: true,
need_handle_count_bug: false,
collected_count_expr_map: HashMap::new(),
pull_up_having_expr: None,
Expand Down Expand Up @@ -154,6 +159,11 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr {
match &plan {
LogicalPlan::Filter(plan_filter) => {
let subquery_filter_exprs = split_conjunction(&plan_filter.predicate);
self.can_pull_over_aggregation = self.can_pull_over_aggregation
&& subquery_filter_exprs
.iter()
.filter(|e| e.contains_outer())
.all(|&e| can_pullup_over_aggregation(e));
let (mut join_filters, subquery_filters) =
find_join_exprs(subquery_filter_exprs)?;
if let Some(in_predicate) = &self.in_predicate_opt {
Expand Down Expand Up @@ -259,6 +269,12 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr {
LogicalPlan::Aggregate(aggregate)
if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() =>
{
// If the aggregation is from a distinct it will not change the result for
// exists/in subqueries so we can still pull up all predicates.
let is_distinct = aggregate.aggr_expr.is_empty();
if !is_distinct {
self.can_pull_up = self.can_pull_up && self.can_pull_over_aggregation;
}
let mut local_correlated_cols = BTreeSet::new();
collect_local_correlated_cols(
&plan,
Expand Down Expand Up @@ -385,6 +401,33 @@ impl PullUpCorrelatedExpr {
}
}

fn can_pullup_over_aggregation(expr: &Expr) -> bool {
if let Expr::BinaryExpr(BinaryExpr {
left,
op: Operator::Eq,
right,
}) = expr
{
match (left.deref(), right.deref()) {
(Expr::Column(_), right) => !right.any_column_refs(),
(left, Expr::Column(_)) => !left.any_column_refs(),
(Expr::Cast(Cast { expr, .. }), right)
if matches!(expr.deref(), Expr::Column(_)) =>
{
!right.any_column_refs()
}
(left, Expr::Cast(Cast { expr, .. }))
if matches!(expr.deref(), Expr::Column(_)) =>
{
!left.any_column_refs()
}
(_, _) => false,
}
} else {
false
}
}

fn collect_local_correlated_cols(
plan: &LogicalPlan,
all_cols_map: &HashMap<LogicalPlan, BTreeSet<Column>>,
Expand Down
54 changes: 42 additions & 12 deletions datafusion/optimizer/src/scalar_subquery_to_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -625,11 +625,21 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;

let expected = "check_analyzed_plan\
\ncaused by\
\nError during planning: Correlated column is not allowed in predicate: outer_ref(customer.c_custkey) != orders.o_custkey";
// Unsupported predicate, subquery should not be decorrelated
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
\n Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]\
\n Subquery: [max(orders.o_custkey):Int64;N]\
\n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
\n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
\n Filter: outer_ref(customer.c_custkey) != orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]";

assert_analyzer_check_err(vec![], plan, expected);
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
plan,
expected,
);
Ok(())
}

Expand All @@ -652,11 +662,21 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;

let expected = "check_analyzed_plan\
\ncaused by\
\nError during planning: Correlated column is not allowed in predicate: outer_ref(customer.c_custkey) < orders.o_custkey";
// Unsupported predicate, subquery should not be decorrelated
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
\n Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]\
\n Subquery: [max(orders.o_custkey):Int64;N]\
\n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
\n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
\n Filter: outer_ref(customer.c_custkey) < orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]";

assert_analyzer_check_err(vec![], plan, expected);
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
plan,
expected,
);
Ok(())
}

Expand All @@ -680,11 +700,21 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;

let expected = "check_analyzed_plan\
\ncaused by\
\nError during planning: Correlated column is not allowed in predicate: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1)";
// Unsupported predicate, subquery should not be decorrelated
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
\n Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]\
\n Subquery: [max(orders.o_custkey):Int64;N]\
\n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
\n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
\n Filter: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]";

assert_analyzer_check_err(vec![], plan, expected);
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
plan,
expected,
);
Ok(())
}

Expand Down
14 changes: 12 additions & 2 deletions datafusion/sqllogictest/test_files/subquery.slt
Original file line number Diff line number Diff line change
Expand Up @@ -509,8 +509,18 @@ SELECT t1_id, (SELECT a FROM (select 1 as a) WHERE a = t1.t1_int) as t2_int from
44 NULL

#non_equal_correlated_scalar_subquery
statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated column is not allowed in predicate: t2\.t2_id < outer_ref\(t1\.t1_id\)
SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id < t1.t1_id) as t2_sum from t1
# Currently not supported and should not be decorrelated
query TT
explain SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id < t1.t1_id) as t2_sum from t1
----
logical_plan
01)Projection: t1.t1_id, (<subquery>) AS t2_sum
02)--Subquery:
03)----Projection: sum(t2.t2_int)
04)------Aggregate: groupBy=[[]], aggr=[[sum(CAST(t2.t2_int AS Int64))]]
05)--------Filter: t2.t2_id < outer_ref(t1.t1_id)
06)----------TableScan: t2
07)--TableScan: t1 projection=[t1_id]

#aggregated_correlated_scalar_subquery_with_extra_group_by_columns
statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns
Expand Down