diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 6bf7b776c8db..7df0068c5f54 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -56,6 +56,7 @@ use super::{ can_columns_satisfy_exprs, expand_wildcard, expr_as_column_expr, extract_aliases, find_aggregate_exprs, find_column_exprs, find_window_exprs, group_window_expr_by_sort_keys, rebase_expr, resolve_aliases_to_exprs, + resolve_positions_to_exprs, }, }; @@ -582,15 +583,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // All of the aggregate expressions (deduplicated). let aggr_exprs = find_aggregate_exprs(&aggr_expr_haystack); + let alias_map = extract_aliases(&select_exprs); let group_by_exprs = select .group_by .iter() .map(|e| { let group_by_expr = self.sql_expr_to_logical_expr(e)?; - let group_by_expr = resolve_aliases_to_exprs( - &group_by_expr, - &extract_aliases(&select_exprs), - )?; + let group_by_expr = resolve_aliases_to_exprs(&group_by_expr, &alias_map)?; + let group_by_expr = + resolve_positions_to_exprs(&group_by_expr, &select_exprs)?; self.validate_schema_satisfies_exprs( plan.schema(), &[group_by_expr.clone()], @@ -2326,6 +2327,39 @@ mod tests { ); } + #[test] + fn select_simple_aggregate_with_groupby_can_use_positions() { + quick_test( + "SELECT state, age AS b, COUNT(1) FROM person GROUP BY 1, 2", + "Projection: #state, #age AS b, #COUNT(UInt8(1))\ + \n Aggregate: groupBy=[[#state, #age]], aggr=[[COUNT(UInt8(1))]]\ + \n TableScan: person projection=None", + ); + quick_test( + "SELECT state, age AS b, COUNT(1) FROM person GROUP BY 2, 1", + "Projection: #state, #age AS b, #COUNT(UInt8(1))\ + \n Aggregate: groupBy=[[#age, #state]], aggr=[[COUNT(UInt8(1))]]\ + \n TableScan: person projection=None", + ); + } + + #[test] + fn select_simple_aggregate_with_groupby_position_out_of_range() { + let sql = "SELECT state, MIN(age) FROM person GROUP BY 0"; + let err = logical_plan(sql).expect_err("query should have failed"); + assert_eq!( + "Plan(\"Projection references non-aggregate values\")", + format!("{:?}", err) + ); + + let sql2 = "SELECT state, MIN(age) FROM person GROUP BY 5"; + let err2 = logical_plan(sql2).expect_err("query should have failed"); + assert_eq!( + "Plan(\"Projection references non-aggregate values\")", + format!("{:?}", err2) + ); + } + #[test] fn select_simple_aggregate_with_groupby_can_use_alias() { quick_test( diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs index 7a5dc0da1b53..848fb3ee31fc 100644 --- a/datafusion/src/sql/utils.rs +++ b/datafusion/src/sql/utils.rs @@ -18,6 +18,7 @@ //! SQL Utility Functions use crate::logical_plan::{DFSchema, Expr, LogicalPlan}; +use crate::scalar::ScalarValue; use crate::{ error::{DataFusionError, Result}, logical_plan::{ExpressionVisitor, Recursion}, @@ -392,6 +393,27 @@ pub(crate) fn extract_aliases(exprs: &[Expr]) -> HashMap { .collect::>() } +pub(crate) fn resolve_positions_to_exprs( + expr: &Expr, + select_exprs: &[Expr], +) -> Result { + match expr { + // sql_expr_to_logical_expr maps number to i64 + // https://github.com/apache/arrow-datafusion/blob/8d175c759e17190980f270b5894348dc4cff9bbf/datafusion/src/sql/planner.rs#L882-L887 + Expr::Literal(ScalarValue::Int64(Some(position))) + if position > &0_i64 && position <= &(select_exprs.len() as i64) => + { + let index = (position - 1) as usize; + let select_expr = &select_exprs[index]; + match select_expr { + Expr::Alias(nested_expr, _alias_name) => Ok(*nested_expr.clone()), + _ => Ok(select_expr.clone()), + } + } + _ => Ok(expr.clone()), + } +} + /// Rebuilds an `Expr` with columns that refer to aliases replaced by the /// alias' underlying `Expr`. pub(crate) fn resolve_aliases_to_exprs(