Skip to content

Commit ad273ca

Browse files
authored
Improve unparsing for ORDER BY, UNION, Windows functions with Aggregation (#12946)
* Improve unparsing for ORDER BY with Aggregation functions (#38) * Improve UNION unparsing (#39) * Scalar functions in ORDER BY unparsing support (#41) * Improve unparsing for complex Window functions with Aggregation (#42) * WindowFunction order_by should respect `supports_nulls_first_in_sort` dialect setting (#43) * Fix plan_to_sql * Improve
1 parent ccfe020 commit ad273ca

File tree

4 files changed

+148
-36
lines changed

4 files changed

+148
-36
lines changed

datafusion/sql/src/unparser/expr.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,6 @@ pub fn expr_to_sql(expr: &Expr) -> Result<ast::Expr> {
7676
unparser.expr_to_sql(expr)
7777
}
7878

79-
pub fn sort_to_sql(sort: &Sort) -> Result<ast::OrderByExpr> {
80-
let unparser = Unparser::default();
81-
unparser.sort_to_sql(sort)
82-
}
83-
8479
const LOWEST: &BinaryOperator = &BinaryOperator::Or;
8580
// Closest precedence we have to IS operator is BitwiseAnd (any other) in PG docs
8681
// (https://www.postgresql.org/docs/7.2/sql-precedence.html)
@@ -229,9 +224,10 @@ impl Unparser<'_> {
229224
ast::WindowFrameUnits::Groups
230225
}
231226
};
232-
let order_by: Vec<ast::OrderByExpr> = order_by
227+
228+
let order_by = order_by
233229
.iter()
234-
.map(sort_to_sql)
230+
.map(|sort_expr| self.sort_to_sql(sort_expr))
235231
.collect::<Result<Vec<_>>>()?;
236232

237233
let start_bound = self.convert_bound(&window_frame.start_bound)?;

datafusion/sql/src/unparser/plan.rs

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use super::{
2727
},
2828
utils::{
2929
find_agg_node_within_select, find_window_nodes_within_select,
30-
unproject_window_exprs,
30+
unproject_sort_expr, unproject_window_exprs,
3131
},
3232
Unparser,
3333
};
@@ -352,19 +352,30 @@ impl Unparser<'_> {
352352
if select.already_projected() {
353353
return self.derive(plan, relation);
354354
}
355-
if let Some(query_ref) = query {
356-
if let Some(fetch) = sort.fetch {
357-
query_ref.limit(Some(ast::Expr::Value(ast::Value::Number(
358-
fetch.to_string(),
359-
false,
360-
))));
361-
}
362-
query_ref.order_by(self.sorts_to_sql(sort.expr.clone())?);
363-
} else {
355+
let Some(query_ref) = query else {
364356
return internal_err!(
365357
"Sort operator only valid in a statement context."
366358
);
367-
}
359+
};
360+
361+
if let Some(fetch) = sort.fetch {
362+
query_ref.limit(Some(ast::Expr::Value(ast::Value::Number(
363+
fetch.to_string(),
364+
false,
365+
))));
366+
};
367+
368+
let agg = find_agg_node_within_select(plan, select.already_projected());
369+
// unproject sort expressions
370+
let sort_exprs: Vec<SortExpr> = sort
371+
.expr
372+
.iter()
373+
.map(|sort_expr| {
374+
unproject_sort_expr(sort_expr, agg, sort.input.as_ref())
375+
})
376+
.collect::<Result<Vec<_>>>()?;
377+
378+
query_ref.order_by(self.sorts_to_sql(&sort_exprs)?);
368379

369380
self.select_to_sql_recursively(
370381
sort.input.as_ref(),
@@ -402,7 +413,7 @@ impl Unparser<'_> {
402413
.collect::<Result<Vec<_>>>()?;
403414
if let Some(sort_expr) = &on.sort_expr {
404415
if let Some(query_ref) = query {
405-
query_ref.order_by(self.sorts_to_sql(sort_expr.clone())?);
416+
query_ref.order_by(self.sorts_to_sql(sort_expr)?);
406417
} else {
407418
return internal_err!(
408419
"Sort operator only valid in a statement context."
@@ -546,6 +557,11 @@ impl Unparser<'_> {
546557
);
547558
}
548559

560+
// Covers cases where the UNION is a subquery and the projection is at the top level
561+
if select.already_projected() {
562+
return self.derive(plan, relation);
563+
}
564+
549565
let input_exprs: Vec<SetExpr> = union
550566
.inputs
551567
.iter()
@@ -691,7 +707,7 @@ impl Unparser<'_> {
691707
}
692708
}
693709

694-
fn sorts_to_sql(&self, sort_exprs: Vec<SortExpr>) -> Result<Vec<ast::OrderByExpr>> {
710+
fn sorts_to_sql(&self, sort_exprs: &[SortExpr]) -> Result<Vec<ast::OrderByExpr>> {
695711
sort_exprs
696712
.iter()
697713
.map(|sort_expr| self.sort_to_sql(sort_expr))

datafusion/sql/src/unparser/utils.rs

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@ use std::cmp::Ordering;
2020
use datafusion_common::{
2121
internal_err,
2222
tree_node::{Transformed, TreeNode},
23-
Column, DataFusionError, Result, ScalarValue,
23+
Column, Result, ScalarValue,
2424
};
2525
use datafusion_expr::{
26-
utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Window,
26+
utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Projection, SortExpr,
27+
Window,
2728
};
2829
use sqlparser::ast;
2930

@@ -118,21 +119,11 @@ pub(crate) fn unproject_agg_exprs(
118119
if let Expr::Column(c) = sub_expr {
119120
if let Some(unprojected_expr) = find_agg_expr(agg, &c)? {
120121
Ok(Transformed::yes(unprojected_expr.clone()))
121-
} else if let Some(mut unprojected_expr) =
122+
} else if let Some(unprojected_expr) =
122123
windows.and_then(|w| find_window_expr(w, &c.name).cloned())
123124
{
124-
if let Expr::WindowFunction(func) = &mut unprojected_expr {
125-
// Window function can contain an aggregation column, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected
126-
func.args.iter_mut().try_for_each(|arg| {
127-
if let Expr::Column(c) = arg {
128-
if let Some(expr) = find_agg_expr(agg, c)? {
129-
*arg = expr.clone();
130-
}
131-
}
132-
Ok::<(), DataFusionError>(())
133-
})?;
134-
}
135-
Ok(Transformed::yes(unprojected_expr))
125+
// Window function can contain an aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected
126+
return Ok(Transformed::yes(unproject_agg_exprs(&unprojected_expr, agg, None)?));
136127
} else {
137128
internal_err!(
138129
"Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name
@@ -200,6 +191,54 @@ fn find_window_expr<'a>(
200191
.find(|expr| expr.schema_name().to_string() == column_name)
201192
}
202193

194+
/// Transforms a Column expression into the actual expression from aggregation or projection if found.
195+
/// This is required because if an ORDER BY expression is present in an Aggregate or Select, it is replaced
196+
/// with a Column expression (e.g., "sum(catalog_returns.cr_net_loss)"). We need to transform it back to
197+
/// the actual expression, such as sum("catalog_returns"."cr_net_loss").
198+
pub(crate) fn unproject_sort_expr(
199+
sort_expr: &SortExpr,
200+
agg: Option<&Aggregate>,
201+
input: &LogicalPlan,
202+
) -> Result<SortExpr> {
203+
let mut sort_expr = sort_expr.clone();
204+
205+
// Remove alias if present, because ORDER BY cannot use aliases
206+
if let Expr::Alias(alias) = &sort_expr.expr {
207+
sort_expr.expr = *alias.expr.clone();
208+
}
209+
210+
let Expr::Column(ref col_ref) = sort_expr.expr else {
211+
return Ok(sort_expr);
212+
};
213+
214+
if col_ref.relation.is_some() {
215+
return Ok(sort_expr);
216+
};
217+
218+
// In case of aggregation there could be columns containing aggregation functions we need to unproject
219+
if let Some(agg) = agg {
220+
if agg.schema.is_column_from_schema(col_ref) {
221+
let new_expr = unproject_agg_exprs(&sort_expr.expr, agg, None)?;
222+
sort_expr.expr = new_expr;
223+
return Ok(sort_expr);
224+
}
225+
}
226+
227+
// If SELECT and ORDER BY contain the same expression with a scalar function, the ORDER BY expression will
228+
// be replaced by a Column expression (e.g., "substr(customer.c_last_name, Int64(0), Int64(5))"), and we need
229+
// to transform it back to the actual expression.
230+
if let LogicalPlan::Projection(Projection { expr, schema, .. }) = input {
231+
if let Ok(idx) = schema.index_of_column(col_ref) {
232+
if let Some(Expr::ScalarFunction(scalar_fn)) = expr.get(idx) {
233+
sort_expr.expr = Expr::ScalarFunction(scalar_fn.clone());
234+
}
235+
}
236+
return Ok(sort_expr);
237+
}
238+
239+
Ok(sort_expr)
240+
}
241+
203242
/// Converts a date_part function to SQL, tailoring it to the supported date field extraction style.
204243
pub(crate) fn date_part_to_sql(
205244
unparser: &Unparser,

datafusion/sql/tests/cases/plan_to_sql.rs

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ use arrow_schema::*;
2222
use datafusion_common::{DFSchema, Result, TableReference};
2323
use datafusion_expr::test::function_stub::{count_udaf, max_udaf, min_udaf, sum_udaf};
2424
use datafusion_expr::{col, lit, table_scan, wildcard, LogicalPlanBuilder};
25+
use datafusion_functions::unicode;
26+
use datafusion_functions_aggregate::grouping::grouping_udaf;
27+
use datafusion_functions_window::rank::rank_udwf;
2528
use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel};
2629
use datafusion_sql::unparser::dialect::{
2730
DefaultDialect as UnparserDefaultDialect, Dialect as UnparserDialect,
@@ -139,6 +142,13 @@ fn roundtrip_statement() -> Result<()> {
139142
SELECT j2_string as string FROM j2
140143
ORDER BY string DESC
141144
LIMIT 10"#,
145+
r#"SELECT col1, id FROM (
146+
SELECT j1_string AS col1, j1_id AS id FROM j1
147+
UNION ALL
148+
SELECT j2_string AS col1, j2_id AS id FROM j2
149+
UNION ALL
150+
SELECT j3_string AS col1, j3_id AS id FROM j3
151+
) AS subquery GROUP BY col1, id ORDER BY col1 ASC, id ASC"#,
142152
"SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
143153
last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
144154
first_name from person",
@@ -657,7 +667,12 @@ where
657667
.unwrap();
658668

659669
let context = MockContextProvider {
660-
state: MockSessionState::default(),
670+
state: MockSessionState::default()
671+
.with_aggregate_function(sum_udaf())
672+
.with_aggregate_function(max_udaf())
673+
.with_aggregate_function(grouping_udaf())
674+
.with_window_function(rank_udwf())
675+
.with_scalar_function(Arc::new(unicode::substr().as_ref().clone())),
661676
};
662677
let sql_to_rel = SqlToRel::new(&context);
663678
let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
@@ -969,3 +984,49 @@ fn test_with_offset0() {
969984
fn test_with_offset95() {
970985
sql_round_trip(MySqlDialect {}, "select 1 offset 95", "SELECT 1 OFFSET 95");
971986
}
987+
988+
#[test]
989+
fn test_order_by_to_sql() {
990+
// order by aggregation function
991+
sql_round_trip(
992+
GenericDialect {},
993+
r#"SELECT id, first_name, SUM(id) FROM person GROUP BY id, first_name ORDER BY SUM(id) ASC, first_name DESC, id, first_name LIMIT 10"#,
994+
r#"SELECT person.id, person.first_name, sum(person.id) FROM person GROUP BY person.id, person.first_name ORDER BY sum(person.id) ASC NULLS LAST, person.first_name DESC NULLS FIRST, person.id ASC NULLS LAST, person.first_name ASC NULLS LAST LIMIT 10"#,
995+
);
996+
997+
// order by aggregation function alias
998+
sql_round_trip(
999+
GenericDialect {},
1000+
r#"SELECT id, first_name, SUM(id) as total_sum FROM person GROUP BY id, first_name ORDER BY total_sum ASC, first_name DESC, id, first_name LIMIT 10"#,
1001+
r#"SELECT person.id, person.first_name, sum(person.id) AS total_sum FROM person GROUP BY person.id, person.first_name ORDER BY total_sum ASC NULLS LAST, person.first_name DESC NULLS FIRST, person.id ASC NULLS LAST, person.first_name ASC NULLS LAST LIMIT 10"#,
1002+
);
1003+
1004+
// order by scalar function from projection
1005+
sql_round_trip(
1006+
GenericDialect {},
1007+
r#"SELECT id, first_name, substr(first_name,0,5) FROM person ORDER BY id, substr(first_name,0,5)"#,
1008+
r#"SELECT person.id, person.first_name, substr(person.first_name, 0, 5) FROM person ORDER BY person.id ASC NULLS LAST, substr(person.first_name, 0, 5) ASC NULLS LAST"#,
1009+
);
1010+
}
1011+
1012+
#[test]
1013+
fn test_aggregation_to_sql() {
1014+
sql_round_trip(
1015+
GenericDialect {},
1016+
r#"SELECT id, first_name,
1017+
SUM(id) AS total_sum,
1018+
SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum,
1019+
MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total,
1020+
rank() OVER (PARTITION BY grouping(id) + grouping(age), CASE WHEN grouping(age) = 0 THEN id END ORDER BY sum(id) DESC) AS rank_within_parent_1,
1021+
rank() OVER (PARTITION BY grouping(age) + grouping(id), CASE WHEN (CAST(grouping(age) AS BIGINT) = 0) THEN id END ORDER BY sum(id) DESC) AS rank_within_parent_2
1022+
FROM person
1023+
GROUP BY id, first_name;"#,
1024+
r#"SELECT person.id, person.first_name,
1025+
sum(person.id) AS total_sum, sum(person.id) OVER (PARTITION BY person.first_name ROWS BETWEEN '5' PRECEDING AND '2' FOLLOWING) AS moving_sum,
1026+
max(sum(person.id)) OVER (PARTITION BY person.first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total,
1027+
rank() OVER (PARTITION BY (grouping(person.id) + grouping(person.age)), CASE WHEN (grouping(person.age) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_1,
1028+
rank() OVER (PARTITION BY (grouping(person.age) + grouping(person.id)), CASE WHEN (CAST(grouping(person.age) AS BIGINT) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_2
1029+
FROM person
1030+
GROUP BY person.id, person.first_name"#.replace("\n", " ").as_str(),
1031+
);
1032+
}

0 commit comments

Comments
 (0)