diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index c63aec6d336a..3a38d9f8eb03 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -368,8 +368,8 @@ impl CommonSubexprEliminate { None => Ok((new_window_expr_list, input, None)), })? - // Recurse into the new input. this is similar to top-down optimizer rule's - // logic. + // Recurse into the new input. + // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) .transform_data(|(new_window_expr_list, new_input, window_expr_list)| { self.rewrite(new_input, config)?.map_data(|new_input| { Ok((new_window_expr_list, new_input, window_expr_list)) @@ -467,8 +467,8 @@ impl CommonSubexprEliminate { None => Ok((new_aggr_expr, new_group_expr, input, None)), } })? - // Recurse into the new input. this is similar to top-down optimizer rule's - // logic. + // Recurse into the new input. + // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| { self.rewrite(new_input, config)?.map_data(|new_input| { Ok(( @@ -636,8 +636,8 @@ impl CommonSubexprEliminate { None => Ok((new_exprs, input)), } })? - // Recurse into the new input. This is similar to top-down optimizer rule's - // logic. + // Recurse into the new input. + // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) .transform_data(|(new_exprs, new_input)| { self.rewrite(new_input, config)? .map_data(|new_input| Ok((new_exprs, new_input))) @@ -702,7 +702,10 @@ impl OptimizerRule for CommonSubexprEliminate { } fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) + // This rule handles recursion itself in a `ApplyOrder::TopDown` like manner. + // This is because in some cases adjacent nodes are collected (e.g. `Window`) and + // CSEd as a group, which can't be done in a simple `ApplyOrder::TopDown` rule. + None } fn rewrite( @@ -740,8 +743,9 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Unnest(_) | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Prepare(_) => { - // ApplyOrder::TopDown handles recursion - Transformed::no(plan) + // This rule handles recursion itself in a `ApplyOrder::TopDown` like + // manner. + plan.map_children(|c| self.rewrite(c, config))? } }; @@ -1187,42 +1191,22 @@ mod test { }; use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; + use super::*; use crate::optimizer::OptimizerContext; use crate::test::*; + use crate::Optimizer; use datafusion_expr::test::function_stub::{avg, sum}; - use super::*; - - fn assert_non_optimized_plan_eq( - expected: &str, - plan: LogicalPlan, - config: Option<&dyn OptimizerConfig>, - ) { - assert_eq!(expected, format!("{plan}"), "Unexpected starting plan"); - let optimizer = CommonSubexprEliminate::new(); - let default_config = OptimizerContext::new(); - let config = config.unwrap_or(&default_config); - let optimized_plan = optimizer.rewrite(plan, config).unwrap(); - assert!(!optimized_plan.transformed, "unexpectedly optimize plan"); - let optimized_plan = optimized_plan.data; - assert_eq!( - expected, - format!("{optimized_plan}"), - "Unexpected optimized plan" - ); - } - fn assert_optimized_plan_eq( expected: &str, plan: LogicalPlan, config: Option<&dyn OptimizerConfig>, ) { - let optimizer = CommonSubexprEliminate::new(); + let optimizer = + Optimizer::with_rules(vec![Arc::new(CommonSubexprEliminate::new())]); let default_config = OptimizerContext::new(); let config = config.unwrap_or(&default_config); - let optimized_plan = optimizer.rewrite(plan, config).unwrap(); - assert!(optimized_plan.transformed, "failed to optimize plan"); - let optimized_plan = optimized_plan.data; + let optimized_plan = optimizer.optimize(plan, config, |_, _| ()).unwrap(); let formatted_plan = format!("{optimized_plan}"); assert_eq!(expected, formatted_plan); } @@ -1612,7 +1596,7 @@ mod test { let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\ \n TableScan: test"; - assert_non_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_eq(expected, plan, None); Ok(()) } @@ -1630,7 +1614,7 @@ mod test { \n Projection: Int32(1) + test.a, test.a\ \n TableScan: test"; - assert_non_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_eq(expected, plan, None); Ok(()) }