Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 36 additions & 12 deletions datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@

use std::sync::Arc;

use datafusion_common::tree_node::Transformed;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::logical_plan::LogicalPlan;
use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::utils::merge_schema;
use datafusion_expr::Expr;

use crate::optimizer::ApplyOrder;
use crate::utils::NamePreserver;
Expand Down Expand Up @@ -122,14 +123,21 @@ impl SimplifyExpressions {

// Preserve expression names to avoid changing the schema of the plan.
let name_preserver = NamePreserver::new(&plan);
plan.map_expressions(|e| {
let original_name = name_preserver.save(&e);
let new_e = simplifier
.simplify(e)
.map(|expr| original_name.restore(expr))?;
let mut rewrite_expr = |expr: Expr| {
let name = name_preserver.save(&expr);
let expr = simplifier.simplify(expr)?;
// TODO it would be nice to have a way to know if the expression was simplified
// or not. For now conservatively return Transformed::yes
Ok(Transformed::yes(new_e))
Ok(Transformed::yes(name.restore(expr)))
};

plan.map_expressions(|expr| {
// Preserve the aliasing of grouping sets.
Copy link
Contributor

@alamb alamb Feb 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line seems like it does the opposite (doesn't preserve the original names 🤔 )

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What makes you think so?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Grouping sets needs to maintain the alias of the children expressions as the field names needs to be based on that rather than the outer expression.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What makes you think so?

I was thinking that the code skips calling rewrite_expr (which calls name preserver) for the GroupingSets (so thus does not preserve the aliases of the Expr::Grouping itself

To be clear I think the code in PR looks good to me, I am just discussing if we can make the comment better

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I meant it in the sense that Expr::GroupinSets is a container of the actual grouping sets (which are a Vec<Expr> or Vec<Vec<Expr>>) and we want to preserve their names. But I'm happy to reword it however you like.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to reword if I am the only one confused :)

if let Expr::GroupingSet(_) = &expr {
expr.map_children(&mut rewrite_expr)
} else {
rewrite_expr(expr)
}
})
}
}
Expand All @@ -151,11 +159,7 @@ mod tests {
use crate::optimizer::Optimizer;
use datafusion_expr::logical_plan::builder::table_scan_with_filters;
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{
and, binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, Expr,
ExprSchemable, JoinType,
};
use datafusion_expr::{or, BinaryExpr, Cast, Operator};
use datafusion_expr::*;
use datafusion_functions_aggregate::expr_fn::{max, min};

use crate::test::{assert_fields_eq, test_table_scan_with_name};
Expand Down Expand Up @@ -743,4 +747,24 @@ mod tests {

assert_optimized_plan_eq(plan, expected)
}

#[test]
fn simplify_grouping_sets() -> Result<()> {
let table_scan = test_table_scan();
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(
[grouping_set(vec![
vec![(lit(42).alias("prev") + lit(1)).alias("age"), col("a")],
vec![col("a").or(col("b")).and(lit(1).lt(lit(0))).alias("cond")],
vec![col("d").alias("e"), (lit(1) + lit(2))],
])],
[] as [Expr; 0],
)?
.build()?;

let expected = "Aggregate: groupBy=[[GROUPING SETS ((Int32(43) AS age, test.a), (Boolean(false) AS cond), (test.d AS e, Int32(3) AS Int32(1) + Int32(2)))]], aggr=[[]]\
\n TableScan: test";

assert_optimized_plan_eq(plan, expected)
}
}