Skip to content

Commit c956104

Browse files
chenkovskyJefffrey
andauthored
fix: window unparsing (#17367)
## Which issue does this PR close? - Closes #17360. ## Rationale for this change in LogicalPlan::Filter unparsing, if there's a window expr, it should be converted to quailify. postgres must has an alias for derived table. otherwise it will complain: ``` ERROR: subquery in FROM must have an alias. ``` fixed this issue at the same time. ## What changes are included in this PR? If window expr is found, convert filter to quailify. ## Are these changes tested? UT ## Are there any user-facing changes? No --------- Co-authored-by: Jeffrey Vo <jeffrey.vo.australia@gmail.com>
1 parent f0ab136 commit c956104

File tree

4 files changed

+192
-2
lines changed

4 files changed

+192
-2
lines changed

datafusion/sql/src/unparser/dialect.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,13 @@ pub trait Dialect: Send + Sync {
207207
Ok(None)
208208
}
209209

210+
/// Allows the dialect to support the QUALIFY clause
211+
///
212+
/// Some dialects, like Postgres, do not support the QUALIFY clause
213+
fn supports_qualify(&self) -> bool {
214+
true
215+
}
216+
210217
/// Allows the dialect to override logic of formatting datetime with tz into string.
211218
fn timestamp_with_tz_to_string(&self, dt: DateTime<Tz>, _unit: TimeUnit) -> String {
212219
dt.to_string()
@@ -274,6 +281,14 @@ impl Dialect for DefaultDialect {
274281
pub struct PostgreSqlDialect {}
275282

276283
impl Dialect for PostgreSqlDialect {
284+
fn supports_qualify(&self) -> bool {
285+
false
286+
}
287+
288+
fn requires_derived_table_alias(&self) -> bool {
289+
true
290+
}
291+
277292
fn identifier_quote_style(&self, _: &str) -> Option<char> {
278293
Some('"')
279294
}
@@ -424,6 +439,10 @@ impl Dialect for DuckDBDialect {
424439
pub struct MySqlDialect {}
425440

426441
impl Dialect for MySqlDialect {
442+
fn supports_qualify(&self) -> bool {
443+
false
444+
}
445+
427446
fn identifier_quote_style(&self, _: &str) -> Option<char> {
428447
Some('`')
429448
}
@@ -485,6 +504,10 @@ impl Dialect for MySqlDialect {
485504
pub struct SqliteDialect {}
486505

487506
impl Dialect for SqliteDialect {
507+
fn supports_qualify(&self) -> bool {
508+
false
509+
}
510+
488511
fn identifier_quote_style(&self, _: &str) -> Option<char> {
489512
Some('`')
490513
}

datafusion/sql/src/unparser/plan.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ use super::{
3232
},
3333
Unparser,
3434
};
35-
use crate::unparser::ast::UnnestRelationBuilder;
3635
use crate::unparser::extension_unparser::{
3736
UnparseToStatementResult, UnparseWithinStatementResult,
3837
};
3938
use crate::unparser::utils::{find_unnest_node_until_relation, unproject_agg_exprs};
39+
use crate::unparser::{ast::UnnestRelationBuilder, rewrite::rewrite_qualify};
4040
use crate::utils::UNNEST_PLACEHOLDER;
4141
use datafusion_common::{
4242
internal_err, not_impl_err,
@@ -95,7 +95,10 @@ pub fn plan_to_sql(plan: &LogicalPlan) -> Result<ast::Statement> {
9595

9696
impl Unparser<'_> {
9797
pub fn plan_to_sql(&self, plan: &LogicalPlan) -> Result<ast::Statement> {
98-
let plan = normalize_union_schema(plan)?;
98+
let mut plan = normalize_union_schema(plan)?;
99+
if !self.dialect.supports_qualify() {
100+
plan = rewrite_qualify(plan)?;
101+
}
99102

100103
match plan {
101104
LogicalPlan::Projection(_)
@@ -428,6 +431,18 @@ impl Unparser<'_> {
428431
unproject_agg_exprs(filter.predicate.clone(), agg, None)?;
429432
let filter_expr = self.expr_to_sql(&unprojected)?;
430433
select.having(Some(filter_expr));
434+
} else if let (Some(window), true) = (
435+
find_window_nodes_within_select(
436+
plan,
437+
None,
438+
select.already_projected(),
439+
),
440+
self.dialect.supports_qualify(),
441+
) {
442+
let unprojected =
443+
unproject_window_exprs(filter.predicate.clone(), &window)?;
444+
let filter_expr = self.expr_to_sql(&unprojected)?;
445+
select.qualify(Some(filter_expr));
431446
} else {
432447
let filter_expr = self.expr_to_sql(&filter.predicate)?;
433448
select.selection(Some(filter_expr));

datafusion/sql/src/unparser/rewrite.rs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,72 @@ fn rewrite_sort_expr_for_union(exprs: Vec<SortExpr>) -> Result<Vec<SortExpr>> {
100100
Ok(sort_exprs)
101101
}
102102

103+
/// Rewrite Filter plans that have a Window as their input by inserting a SubqueryAlias.
104+
///
105+
/// When a Filter directly operates on a Window plan, it can cause issues during SQL unparsing
106+
/// because window functions in a WHERE clause are not valid SQL. The solution is to wrap
107+
/// the Window plan in a SubqueryAlias, effectively creating a derived table.
108+
///
109+
/// Example transformation:
110+
///
111+
/// Filter: condition
112+
/// Window: window_function
113+
/// TableScan: table
114+
///
115+
/// becomes:
116+
///
117+
/// Filter: condition
118+
/// SubqueryAlias: __qualify_subquery
119+
/// Projection: table.column1, table.column2
120+
/// Window: window_function
121+
/// TableScan: table
122+
///
123+
pub(super) fn rewrite_qualify(plan: LogicalPlan) -> Result<LogicalPlan> {
124+
let transformed_plan = plan.transform_up(|plan| match plan {
125+
// Check if the filter's input is a Window plan
126+
LogicalPlan::Filter(mut filter) => {
127+
if matches!(&*filter.input, LogicalPlan::Window(_)) {
128+
// Create a SubqueryAlias around the Window plan
129+
let qualifier = filter
130+
.input
131+
.schema()
132+
.iter()
133+
.find_map(|(q, _)| q)
134+
.map(|q| q.to_string())
135+
.unwrap_or_else(|| "__qualify_subquery".to_string());
136+
137+
// for Postgres, name of column for 'rank() over (...)' is 'rank'
138+
// but in Datafusion, it is 'rank() over (...)'
139+
// without projection, it's still an invalid sql in Postgres
140+
141+
let project_exprs = filter
142+
.input
143+
.schema()
144+
.iter()
145+
.map(|(_, f)| datafusion_expr::col(f.name()).alias(f.name()))
146+
.collect::<Vec<_>>();
147+
148+
let input =
149+
datafusion_expr::LogicalPlanBuilder::from(Arc::clone(&filter.input))
150+
.project(project_exprs)?
151+
.build()?;
152+
153+
let subquery_alias =
154+
datafusion_expr::SubqueryAlias::try_new(Arc::new(input), qualifier)?;
155+
156+
filter.input = Arc::new(LogicalPlan::SubqueryAlias(subquery_alias));
157+
Ok(Transformed::yes(LogicalPlan::Filter(filter)))
158+
} else {
159+
Ok(Transformed::no(LogicalPlan::Filter(filter)))
160+
}
161+
}
162+
163+
_ => Ok(Transformed::no(plan)),
164+
});
165+
166+
transformed_plan.data()
167+
}
168+
103169
/// Rewrite logic plan for query that order by columns are not in projections
104170
/// Plan before rewrite:
105171
///

datafusion/sql/tests/cases/plan_to_sql.rs

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@ use datafusion_common::{
2121
assert_contains, Column, DFSchema, DFSchemaRef, DataFusionError, Result,
2222
TableReference,
2323
};
24+
use datafusion_expr::expr::{WindowFunction, WindowFunctionParams};
2425
use datafusion_expr::test::function_stub::{
2526
count_udaf, max_udaf, min_udaf, sum, sum_udaf,
2627
};
2728
use datafusion_expr::{
2829
cast, col, lit, table_scan, wildcard, EmptyRelation, Expr, Extension, LogicalPlan,
2930
LogicalPlanBuilder, Union, UserDefinedLogicalNode, UserDefinedLogicalNodeCore,
31+
WindowFrame, WindowFunctionDefinition,
3032
};
3133
use datafusion_functions::unicode;
3234
use datafusion_functions_aggregate::grouping::grouping_udaf;
@@ -2521,6 +2523,90 @@ fn test_unparse_left_semi_join_with_table_scan_projection() -> Result<()> {
25212523
Ok(())
25222524
}
25232525

2526+
#[test]
2527+
fn test_unparse_window() -> Result<()> {
2528+
// SubqueryAlias: t
2529+
// Projection: t.k, t.v, rank() PARTITION BY [t.k] ORDER BY [t.v ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS r
2530+
// Filter: rank() PARTITION BY [t.k] ORDER BY [t.v ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW = UInt64(1)
2531+
// WindowAggr: windowExpr=[[rank() PARTITION BY [t.k] ORDER BY [t.v ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
2532+
// TableScan: t projection=[k, v]
2533+
2534+
let schema = Schema::new(vec![
2535+
Field::new("k", DataType::Int32, false),
2536+
Field::new("v", DataType::Int32, false),
2537+
]);
2538+
let window_expr = Expr::WindowFunction(Box::new(WindowFunction {
2539+
fun: WindowFunctionDefinition::WindowUDF(rank_udwf()),
2540+
params: WindowFunctionParams {
2541+
args: vec![],
2542+
partition_by: vec![col("k")],
2543+
order_by: vec![col("v").sort(true, true)],
2544+
window_frame: WindowFrame::new(None),
2545+
null_treatment: None,
2546+
distinct: false,
2547+
filter: None,
2548+
},
2549+
}));
2550+
let table = table_scan(Some("test"), &schema, Some(vec![0, 1]))?.build()?;
2551+
let plan = LogicalPlanBuilder::window_plan(table, vec![window_expr.clone()])?;
2552+
2553+
let name = plan.schema().fields().last().unwrap().name().clone();
2554+
let plan = LogicalPlanBuilder::from(plan)
2555+
.filter(col(name.clone()).eq(lit(1i64)))?
2556+
.project(vec![col("k"), col("v"), col(name)])?
2557+
.build()?;
2558+
2559+
let unparser = Unparser::new(&UnparserPostgreSqlDialect {});
2560+
let sql = unparser.plan_to_sql(&plan)?;
2561+
assert_snapshot!(
2562+
sql,
2563+
@r#"SELECT "test"."k", "test"."v", "rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" FROM (SELECT "test"."k" AS "k", "test"."v" AS "v", rank() OVER (PARTITION BY "test"."k" ORDER BY "test"."v" ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" FROM "test") AS "test" WHERE ("rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" = 1)"#
2564+
);
2565+
2566+
let unparser = Unparser::new(&UnparserMySqlDialect {});
2567+
let sql = unparser.plan_to_sql(&plan)?;
2568+
assert_snapshot!(
2569+
sql,
2570+
@r#"SELECT `test`.`k`, `test`.`v`, `rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` FROM (SELECT `test`.`k` AS `k`, `test`.`v` AS `v`, rank() OVER (PARTITION BY `test`.`k` ORDER BY `test`.`v` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` FROM `test`) AS `test` WHERE (`rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` = 1)"#
2571+
);
2572+
2573+
let unparser = Unparser::new(&SqliteDialect {});
2574+
let sql = unparser.plan_to_sql(&plan)?;
2575+
assert_snapshot!(
2576+
sql,
2577+
@r#"SELECT `test`.`k`, `test`.`v`, `rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` FROM (SELECT `test`.`k` AS `k`, `test`.`v` AS `v`, rank() OVER (PARTITION BY `test`.`k` ORDER BY `test`.`v` ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` FROM `test`) AS `test` WHERE (`rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` = 1)"#
2578+
);
2579+
2580+
let unparser = Unparser::new(&DefaultDialect {});
2581+
let sql = unparser.plan_to_sql(&plan)?;
2582+
assert_snapshot!(
2583+
sql,
2584+
@r#"SELECT test.k, test.v, rank() OVER (PARTITION BY test.k ORDER BY test.v ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM test QUALIFY (rank() OVER (PARTITION BY test.k ORDER BY test.v ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) = 1)"#
2585+
);
2586+
2587+
// without table qualifier
2588+
let table = table_scan(Some("test"), &schema, Some(vec![0, 1]))?.build()?;
2589+
let table = LogicalPlanBuilder::from(table)
2590+
.project(vec![col("k").alias("k"), col("v").alias("v")])?
2591+
.build()?;
2592+
let plan = LogicalPlanBuilder::window_plan(table, vec![window_expr])?;
2593+
2594+
let name = plan.schema().fields().last().unwrap().name().clone();
2595+
let plan = LogicalPlanBuilder::from(plan)
2596+
.filter(col(name.clone()).eq(lit(1i64)))?
2597+
.project(vec![col("k"), col("v"), col(name)])?
2598+
.build()?;
2599+
2600+
let unparser = Unparser::new(&UnparserPostgreSqlDialect {});
2601+
let sql = unparser.plan_to_sql(&plan)?;
2602+
assert_snapshot!(
2603+
sql,
2604+
@r#"SELECT "k", "v", "rank() PARTITION BY [k] ORDER BY [v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" FROM (SELECT "k" AS "k", "v" AS "v", rank() OVER (PARTITION BY "k" ORDER BY "v" ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "rank() PARTITION BY [k] ORDER BY [v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" FROM (SELECT "test"."k" AS "k", "test"."v" AS "v" FROM "test") AS "derived_projection") AS "__qualify_subquery" WHERE ("rank() PARTITION BY [k] ORDER BY [v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" = 1)"#
2605+
);
2606+
2607+
Ok(())
2608+
}
2609+
25242610
#[test]
25252611
fn test_like_filter() {
25262612
let statement = generate_round_trip_statement(

0 commit comments

Comments
 (0)