Skip to content

Commit

Permalink
Fix bug in subquery join filters referencing outer query (#2416)
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove authored May 3, 2022
1 parent c33ffe0 commit c42dc82
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 9 deletions.
2 changes: 1 addition & 1 deletion datafusion/core/src/logical_plan/expr_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {

/// Recursively call [`Column::normalize_with_schemas`] on all Column expressions
/// in the `expr` expression tree.
fn normalize_col_with_schemas(
pub fn normalize_col_with_schemas(
expr: Expr,
schemas: &[&Arc<DFSchema>],
using_columns: &[HashSet<Column>],
Expand Down
5 changes: 3 additions & 2 deletions datafusion/core/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ pub use expr::{
when, Column, Expr, ExprSchema, Literal,
};
pub use expr_rewriter::{
normalize_col, normalize_cols, replace_col, rewrite_sort_cols_by_aggs,
unnormalize_col, unnormalize_cols, ExprRewritable, ExprRewriter, RewriteRecursion,
normalize_col, normalize_col_with_schemas, normalize_cols, replace_col,
rewrite_sort_cols_by_aggs, unnormalize_col, unnormalize_cols, ExprRewritable,
ExprRewriter, RewriteRecursion,
};
pub use expr_simplier::{ExprSimplifiable, SimplifyInfo};
pub use expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion};
Expand Down
80 changes: 74 additions & 6 deletions datafusion/core/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ use crate::logical_plan::window_frames::{WindowFrame, WindowFrameUnits};
use crate::logical_plan::Expr::Alias;
use crate::logical_plan::{
and, builder::expand_qualified_wildcard, builder::expand_wildcard, col, lit,
normalize_col, union_with_alias, Column, CreateCatalog, CreateCatalogSchema,
CreateExternalTable as PlanCreateExternalTable, CreateMemoryTable, DFSchema,
DFSchemaRef, DropTable, Expr, FileType, LogicalPlan, LogicalPlanBuilder, Operator,
PlanType, ToDFSchema, ToStringifiedPlan,
normalize_col, normalize_col_with_schemas, union_with_alias, Column, CreateCatalog,
CreateCatalogSchema, CreateExternalTable as PlanCreateExternalTable,
CreateMemoryTable, DFSchema, DFSchemaRef, DropTable, Expr, FileType, LogicalPlan,
LogicalPlanBuilder, Operator, PlanType, ToDFSchema, ToStringifiedPlan,
};
use crate::optimizer::utils::exprlist_to_columns;
use crate::prelude::JoinType;
Expand All @@ -50,7 +50,7 @@ use datafusion_expr::{window_function::WindowFunction, BuiltinScalarFunction};
use hashbrown::HashMap;

use datafusion_common::field_not_found;
use datafusion_expr::logical_plan::Subquery;
use datafusion_expr::logical_plan::{Filter, Subquery};
use sqlparser::ast::{
BinaryOperator, DataType as SQLDataType, DateTimeField, Expr as SQLExpr, FunctionArg,
FunctionArgExpr, Ident, Join, JoinConstraint, JoinOperator, ObjectName, Query,
Expand Down Expand Up @@ -803,6 +803,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {

let mut all_join_keys = HashSet::new();

let orig_plans = plans.clone();
let mut plans = plans.into_iter();
let mut left = plans.next().unwrap(); // have at least one plan

Expand Down Expand Up @@ -885,7 +886,33 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// remove join expressions from filter
match remove_join_expressions(&filter_expr, &all_join_keys)? {
Some(filter_expr) => {
LogicalPlanBuilder::from(left).filter(filter_expr)?.build()
// this logic is adapted from [`LogicalPlanBuilder::filter`] to take
// the query outer schema into account so that joins in subqueries
// can reference outer query fields.
let mut all_schemas: Vec<DFSchemaRef> = vec![];
for plan in orig_plans {
for schema in plan.all_schemas() {
all_schemas.push(schema.clone());
}
}
if let Some(outer_query_schema) = outer_query_schema {
all_schemas.push(Arc::new(outer_query_schema.clone()));
}
let mut join_columns = HashSet::new();
for (l, r) in &all_join_keys {
join_columns.insert(l.clone());
join_columns.insert(r.clone());
}
let x: Vec<&DFSchemaRef> = all_schemas.iter().collect();
let filter_expr = normalize_col_with_schemas(
filter_expr,
x.as_slice(),
&[join_columns],
)?;
Ok(LogicalPlan::Filter(Filter {
predicate: filter_expr,
input: Arc::new(left),
}))
}
_ => Ok(left),
}
Expand Down Expand Up @@ -4244,6 +4271,18 @@ mod tests {
Field::new("t_date32", DataType::Date32, false),
Field::new("t_date64", DataType::Date64, false),
])),
"j1" => Some(Schema::new(vec![
Field::new("j1_id", DataType::Int32, false),
Field::new("j1_string", DataType::Utf8, false),
])),
"j2" => Some(Schema::new(vec![
Field::new("j2_id", DataType::Int32, false),
Field::new("j2_string", DataType::Utf8, false),
])),
"j3" => Some(Schema::new(vec![
Field::new("j3_id", DataType::Int32, false),
Field::new("j3_string", DataType::Utf8, false),
])),
"person" => Some(Schema::new(vec![
Field::new("id", DataType::UInt32, false),
Field::new("first_name", DataType::Utf8, false),
Expand Down Expand Up @@ -4518,6 +4557,35 @@ mod tests {
quick_test(sql, &expected);
}

#[test]
fn scalar_subquery_reference_outer_field() {
let sql = "SELECT j1_string, j2_string \
FROM j1, j2 \
WHERE j1_id = j2_id - 1 \
AND j2_id < (SELECT count(*) \
FROM j1, j3 \
WHERE j2_id = j1_id \
AND j1_id = j3_id)";

let subquery = "Subquery: Projection: #COUNT(UInt8(1))\
\n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
\n Filter: #j2.j2_id = #j1.j1_id\
\n Inner Join: #j1.j1_id = #j3.j3_id\
\n TableScan: j1 projection=None\
\n TableScan: j3 projection=None";

let expected = format!(
"Projection: #j1.j1_string, #j2.j2_string\
\n Filter: #j1.j1_id = #j2.j2_id - Int64(1) AND #j2.j2_id < ({})\
\n CrossJoin:\
\n TableScan: j1 projection=None\
\n TableScan: j2 projection=None",
subquery
);

quick_test(sql, &expected);
}

#[tokio::test]
async fn subquery_references_cte() {
let sql = "WITH \
Expand Down

0 comments on commit c42dc82

Please sign in to comment.