diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index 200f1f159d81..6a56c1753328 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -19,12 +19,13 @@ use std::sync::Arc; -use datafusion_common::tree_node::Transformed; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result}; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::utils::merge_schema; +use datafusion_expr::Expr; use crate::optimizer::ApplyOrder; use crate::utils::NamePreserver; @@ -122,14 +123,21 @@ impl SimplifyExpressions { // Preserve expression names to avoid changing the schema of the plan. let name_preserver = NamePreserver::new(&plan); - plan.map_expressions(|e| { - let original_name = name_preserver.save(&e); - let new_e = simplifier - .simplify(e) - .map(|expr| original_name.restore(expr))?; + let mut rewrite_expr = |expr: Expr| { + let name = name_preserver.save(&expr); + let expr = simplifier.simplify(expr)?; // TODO it would be nice to have a way to know if the expression was simplified // or not. For now conservatively return Transformed::yes - Ok(Transformed::yes(new_e)) + Ok(Transformed::yes(name.restore(expr))) + }; + + plan.map_expressions(|expr| { + // Preserve the aliasing of grouping sets. + if let Expr::GroupingSet(_) = &expr { + expr.map_children(&mut rewrite_expr) + } else { + rewrite_expr(expr) + } }) } } @@ -151,11 +159,7 @@ mod tests { use crate::optimizer::Optimizer; use datafusion_expr::logical_plan::builder::table_scan_with_filters; use datafusion_expr::logical_plan::table_scan; - use datafusion_expr::{ - and, binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, Expr, - ExprSchemable, JoinType, - }; - use datafusion_expr::{or, BinaryExpr, Cast, Operator}; + use datafusion_expr::*; use datafusion_functions_aggregate::expr_fn::{max, min}; use crate::test::{assert_fields_eq, test_table_scan_with_name}; @@ -743,4 +747,24 @@ mod tests { assert_optimized_plan_eq(plan, expected) } + + #[test] + fn simplify_grouping_sets() -> Result<()> { + let table_scan = test_table_scan(); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + [grouping_set(vec![ + vec![(lit(42).alias("prev") + lit(1)).alias("age"), col("a")], + vec![col("a").or(col("b")).and(lit(1).lt(lit(0))).alias("cond")], + vec![col("d").alias("e"), (lit(1) + lit(2))], + ])], + [] as [Expr; 0], + )? + .build()?; + + let expected = "Aggregate: groupBy=[[GROUPING SETS ((Int32(43) AS age, test.a), (Boolean(false) AS cond), (test.d AS e, Int32(3) AS Int32(1) + Int32(2)))]], aggr=[[]]\ + \n TableScan: test"; + + assert_optimized_plan_eq(plan, expected) + } }