diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 399bb876b3d0..2200b86fc5b9 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -20,7 +20,9 @@ use datafusion_common::{ tree_node::{Transformed, TreeNode}, Column, Result, }; -use datafusion_expr::{Aggregate, Expr, LogicalPlan, Window}; +use datafusion_expr::{ + utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Window, +}; /// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists /// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). @@ -109,16 +111,16 @@ pub(crate) fn unproject_agg_exprs( expr.clone() .transform(|sub_expr| { if let Expr::Column(c) = sub_expr { - if let Some(unprojected_expr) = find_agg_expr(agg, &c) { + if let Some(unprojected_expr) = find_agg_expr(agg, &c)? { Ok(Transformed::yes(unprojected_expr.clone())) } else if let Some(mut unprojected_expr) = windows.and_then(|w| find_window_expr(w, &c.name).cloned()) { if let Expr::WindowFunction(func) = &mut unprojected_expr { - // Window function can contain aggregation column, for ex 'avg(sum(ss_sales_price)) over ..' that needs to be unprojected + // Window function can contain an aggregation column, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected for arg in &mut func.args { if let Expr::Column(c) = arg { - if let Some(expr) = find_agg_expr(agg, c) { + if let Some(expr) = find_agg_expr(agg, c)? { *arg = expr.clone(); } } @@ -158,11 +160,20 @@ pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result .map(|e| e.data) } -fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Option<&'a Expr> { +fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result> { if let Ok(index) = agg.schema.index_of_column(column) { - agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index) + if matches!(agg.group_expr.as_slice(), [Expr::GroupingSet(_)]) { + // For grouping set expr, we must operate by expression list from the grouping set + let grouping_expr = grouping_set_to_exprlist(agg.group_expr.as_slice())?; + return Ok(grouping_expr + .into_iter() + .chain(agg.aggr_expr.iter()) + .nth(index)); + } else { + return Ok(agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index)); + }; } else { - None + Ok(None) } }