diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index d0dd24621d3e..276a1cc4c59c 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -615,6 +615,11 @@ impl Transformed { } } + /// Create a `Transformed` with `transformed and [`TreeNodeRecursion::Continue`]. + pub fn new_transformed(data: T, transformed: bool) -> Self { + Self::new(data, transformed, TreeNodeRecursion::Continue) + } + /// Wrapper for transformed data with [`TreeNodeRecursion::Continue`] statement. pub fn yes(data: T) -> Self { Self::new(data, true, TreeNodeRecursion::Continue) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 02378ab3fc1b..85958223ac97 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -870,37 +870,7 @@ impl LogicalPlan { LogicalPlan::Filter { .. } => { assert_eq!(1, expr.len()); let predicate = expr.pop().unwrap(); - - // filter predicates should not contain aliased expressions so we remove any aliases - // before this logic was added we would have aliases within filters such as for - // benchmark q6: - // - // lineitem.l_shipdate >= Date32(\"8766\") - // AND lineitem.l_shipdate < Date32(\"9131\") - // AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount >= - // Decimal128(Some(49999999999999),30,15) - // AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount <= - // Decimal128(Some(69999999999999),30,15) - // AND lineitem.l_quantity < Decimal128(Some(2400),15,2) - - let predicate = predicate - .transform_down(|expr| { - match expr { - Expr::Exists { .. } - | Expr::ScalarSubquery(_) - | Expr::InSubquery(_) => { - // subqueries could contain aliases so we don't recurse into those - Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)) - } - Expr::Alias(_) => Ok(Transformed::new( - expr.unalias(), - true, - TreeNodeRecursion::Jump, - )), - _ => Ok(Transformed::no(expr)), - } - }) - .data()?; + let predicate = Filter::remove_aliases(predicate)?.data; Filter::try_new(predicate, Arc::new(inputs.swap_remove(0))) .map(LogicalPlan::Filter) @@ -2230,6 +2200,38 @@ impl Filter { } false } + + /// Remove aliases from a predicate for use in a `Filter` + /// + /// filter predicates should not contain aliased expressions so we remove + /// any aliases. + /// + /// before this logic was added we would have aliases within filters such as + /// for benchmark q6: + /// + /// ```sql + /// lineitem.l_shipdate >= Date32(\"8766\") + /// AND lineitem.l_shipdate < Date32(\"9131\") + /// AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount >= + /// Decimal128(Some(49999999999999),30,15) + /// AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount <= + /// Decimal128(Some(69999999999999),30,15) + /// AND lineitem.l_quantity < Decimal128(Some(2400),15,2) + /// ``` + pub fn remove_aliases(predicate: Expr) -> Result> { + predicate.transform_down(|expr| { + match expr { + Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::InSubquery(_) => { + // subqueries could contain aliases so we don't recurse into those + Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)) + } + Expr::Alias(Alias { expr, .. }) => { + Ok(Transformed::new(*expr, true, TreeNodeRecursion::Jump)) + } + _ => Ok(Transformed::no(expr)), + } + }) + } } /// Window its input based on a set of window spec and window function (e.g. SUM or RANK) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index e150a957bfcf..7f4093ba110e 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -20,16 +20,22 @@ use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::{OptimizerConfig, OptimizerRule}; +use crate::optimizer::ApplyOrder; +use crate::utils::NamePreserver; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, - TreeNodeVisitor, + Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, +}; +use datafusion_common::{ + internal_datafusion_err, internal_err, qualified_name, Column, DFSchema, Result, }; -use datafusion_common::{qualified_name, Column, DFSchema, DataFusionError, Result}; use datafusion_expr::expr::Alias; -use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection, Window}; +use datafusion_expr::logical_plan::tree_node::unwrap_arc; +use datafusion_expr::logical_plan::{ + Aggregate, Filter, LogicalPlan, Projection, Sort, Window, +}; use datafusion_expr::{col, Expr, ExprSchemable}; use indexmap::IndexMap; @@ -123,32 +129,39 @@ impl CommonSubexprEliminate { /// Returns the rewritten expressions fn rewrite_exprs_list( &self, - exprs_list: &[&[Expr]], + exprs_list: Vec>, arrays_list: &[&[IdArray]], expr_stats: &ExprStats, common_exprs: &mut CommonExprs, alias_generator: &AliasGenerator, - ) -> Result>> { + ) -> Result>>> { + let mut transformed = false; exprs_list - .iter() + .into_iter() .zip(arrays_list.iter()) .map(|(exprs, arrays)| { exprs - .iter() - .cloned() + .into_iter() .zip(arrays.iter()) .map(|(expr, id_array)| { - replace_common_expr( + let replaced = replace_common_expr( expr, id_array, expr_stats, common_exprs, alias_generator, - ) + )?; + // remember if this expression was actually replaced + transformed |= replaced.transformed; + Ok(replaced.data) }) .collect::>>() }) .collect::>>() + .map(|rewritten_exprs_list| { + // propagate back transformed information + Transformed::new_transformed(rewritten_exprs_list, transformed) + }) } /// Rewrites the expression in `exprs_list` with common sub-expressions @@ -161,13 +174,15 @@ impl CommonSubexprEliminate { /// common sub-expressions that were used fn rewrite_expr( &self, - exprs_list: &[&[Expr]], + exprs_list: Vec>, arrays_list: &[&[IdArray]], - input: &LogicalPlan, + input: LogicalPlan, expr_stats: &ExprStats, config: &dyn OptimizerConfig, - ) -> Result<(Vec>, LogicalPlan)> { + ) -> Result>, LogicalPlan)>> { + let mut transformed = false; let mut common_exprs = CommonExprs::new(); + let rewrite_exprs = self.rewrite_exprs_list( exprs_list, arrays_list, @@ -175,115 +190,193 @@ impl CommonSubexprEliminate { &mut common_exprs, &config.alias_generator(), )?; + transformed |= rewrite_exprs.transformed; - let mut new_input = self - .try_optimize(input, config)? - .unwrap_or_else(|| input.clone()); + let new_input = self.rewrite(input, config)?; + transformed |= new_input.transformed; + let mut new_input = new_input.data; if !common_exprs.is_empty() { + assert!(transformed); new_input = build_common_expr_project_plan(new_input, common_exprs)?; } - Ok((rewrite_exprs, new_input)) + // return the transformed information + + Ok(Transformed::new_transformed( + (rewrite_exprs.data, new_input), + transformed, + )) } - fn try_optimize_window( + fn try_optimize_proj( &self, - window: &Window, + projection: Projection, config: &dyn OptimizerConfig, - ) -> Result { - let mut window_exprs = vec![]; - let mut arrays_per_window = vec![]; - let mut expr_stats = ExprStats::new(); - - // Get all window expressions inside the consecutive window operators. - // Consecutive window expressions may refer to same complex expression. - // If same complex expression is referred more than once by subsequent `WindowAggr`s, - // we can cache complex expression by evaluating it with a projection before the - // first WindowAggr. - // This enables us to cache complex expression "c3+c4" for following plan: - // WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] - // --WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] - // where, it is referred once by each `WindowAggr` (total of 2) in the plan. - let mut plan = LogicalPlan::Window(window.clone()); - while let LogicalPlan::Window(window) = plan { - let Window { - input, window_expr, .. - } = window; - plan = input.as_ref().clone(); + ) -> Result> { + let Projection { + expr, + input, + schema, + .. + } = projection; + let input = unwrap_arc(input); + self.try_unary_plan(expr, input, config)? + .map_data(|(new_expr, new_input)| { + Projection::try_new_with_schema(new_expr, Arc::new(new_input), schema) + .map(LogicalPlan::Projection) + }) + } + fn try_optimize_sort( + &self, + sort: Sort, + config: &dyn OptimizerConfig, + ) -> Result> { + let Sort { expr, input, fetch } = sort; + let input = unwrap_arc(input); + let new_sort = self.try_unary_plan(expr, input, config)?.update_data( + |(new_expr, new_input)| { + LogicalPlan::Sort(Sort { + expr: new_expr, + input: Arc::new(new_input), + fetch, + }) + }, + ); + Ok(new_sort) + } - let arrays = to_arrays(&window_expr, &mut expr_stats, ExprMask::Normal)?; + fn try_optimize_filter( + &self, + filter: Filter, + config: &dyn OptimizerConfig, + ) -> Result> { + let Filter { + predicate, input, .. + } = filter; + let input = unwrap_arc(input); + let expr = vec![predicate]; + self.try_unary_plan(expr, input, config)? + .transform_data(|(mut new_expr, new_input)| { + assert_eq!(new_expr.len(), 1); // passed in vec![predicate] + let new_predicate = new_expr.pop().unwrap(); + Ok(Filter::remove_aliases(new_predicate)? + .update_data(|new_predicate| (new_predicate, new_input))) + })? + .map_data(|(new_predicate, new_input)| { + Filter::try_new(new_predicate, Arc::new(new_input)) + .map(LogicalPlan::Filter) + }) + } - window_exprs.push(window_expr); - arrays_per_window.push(arrays); - } + fn try_optimize_window( + &self, + window: Window, + config: &dyn OptimizerConfig, + ) -> Result> { + // collect all window expressions from any number of LogicalPlanWindow + let ConsecutiveWindowExprs { + window_exprs, + arrays_per_window, + expr_stats, + plan, + } = ConsecutiveWindowExprs::try_new(window)?; - let mut window_exprs = window_exprs - .iter() - .map(|expr| expr.as_slice()) - .collect::>(); let arrays_per_window = arrays_per_window .iter() .map(|arrays| arrays.as_slice()) .collect::>(); + // save the original names + let name_preserver = NamePreserver::new(&plan); + let mut saved_names = window_exprs + .iter() + .map(|exprs| { + exprs + .iter() + .map(|expr| name_preserver.save(expr)) + .collect::>>() + }) + .collect::>>()?; + assert_eq!(window_exprs.len(), arrays_per_window.len()); - let (mut new_expr, new_input) = self.rewrite_expr( - &window_exprs, + let num_window_exprs = window_exprs.len(); + let rewritten_window_exprs = self.rewrite_expr( + window_exprs, &arrays_per_window, - &plan, + plan, &expr_stats, config, )?; - assert_eq!(window_exprs.len(), new_expr.len()); + let transformed = rewritten_window_exprs.transformed; + + let (mut new_expr, new_input) = rewritten_window_exprs.data; - // Construct consecutive window operator, with their corresponding new window expressions. - plan = new_input; - while let Some(new_window_expr) = new_expr.pop() { - // Since `new_expr` and `window_exprs` length are same. We can safely `.unwrap` here. - let orig_window_expr = window_exprs.pop().unwrap(); - assert_eq!(new_window_expr.len(), orig_window_expr.len()); + let mut plan = new_input; - // Rename new re-written window expressions with original name (by giving alias) - // Otherwise we may receive schema error, in subsequent operators. + // Construct consecutive window operator, with their corresponding new + // window expressions. + // + // Note this iterates over, `new_expr` and `saved_names` which are the + // same length, in reverse order + assert_eq!(num_window_exprs, new_expr.len()); + assert_eq!(num_window_exprs, saved_names.len()); + while let (Some(new_window_expr), Some(saved_names)) = + (new_expr.pop(), saved_names.pop()) + { + assert_eq!(new_window_expr.len(), saved_names.len()); + + // Rename re-written window expressions with original name, to + // preserve the output schema let new_window_expr = new_window_expr .into_iter() - .zip(orig_window_expr.iter()) - .map(|(new_window_expr, window_expr)| { - let original_name = window_expr.name_for_alias()?; - new_window_expr.alias_if_changed(original_name) - }) + .zip(saved_names.into_iter()) + .map(|(new_window_expr, saved_name)| saved_name.restore(new_window_expr)) .collect::>>()?; plan = LogicalPlan::Window(Window::try_new(new_window_expr, Arc::new(plan))?); } - Ok(plan) + Ok(Transformed::new_transformed(plan, transformed)) } fn try_optimize_aggregate( &self, - aggregate: &Aggregate, + aggregate: Aggregate, config: &dyn OptimizerConfig, - ) -> Result { + ) -> Result> { let Aggregate { group_expr, aggr_expr, input, + schema: orig_schema, .. } = aggregate; let mut expr_stats = ExprStats::new(); + // track transformed information + let mut transformed = false; + // rewrite inputs - let group_arrays = to_arrays(group_expr, &mut expr_stats, ExprMask::Normal)?; - let aggr_arrays = to_arrays(aggr_expr, &mut expr_stats, ExprMask::Normal)?; + let group_arrays = to_arrays(&group_expr, &mut expr_stats, ExprMask::Normal)?; + let aggr_arrays = to_arrays(&aggr_expr, &mut expr_stats, ExprMask::Normal)?; + + let name_perserver = NamePreserver::new_for_projection(); + let saved_names = aggr_expr + .iter() + .map(|expr| name_perserver.save(expr)) + .collect::>>()?; - let (mut new_expr, new_input) = self.rewrite_expr( - &[group_expr, aggr_expr], + // rewrite both group exprs and aggr_expr + let rewritten = self.rewrite_expr( + vec![group_expr, aggr_expr], &[&group_arrays, &aggr_arrays], - input, + unwrap_arc(input), &expr_stats, config, )?; + transformed |= rewritten.transformed; + let (mut new_expr, new_input) = rewritten.data; + // note the reversed pop order. let new_aggr_expr = pop_expr(&mut new_expr)?; let new_group_expr = pop_expr(&mut new_expr)?; @@ -296,108 +389,208 @@ impl CommonSubexprEliminate { &mut expr_stats, ExprMask::NormalAndAggregates, )?; - let mut common_exprs = CommonExprs::new(); - let mut rewritten = self.rewrite_exprs_list( - &[&new_aggr_expr], + let mut common_exprs = IndexMap::new(); + let mut rewritten_exprs = self.rewrite_exprs_list( + vec![new_aggr_expr.clone()], &[&aggr_arrays], &expr_stats, &mut common_exprs, &config.alias_generator(), )?; - let rewritten = pop_expr(&mut rewritten)?; + transformed |= rewritten_exprs.transformed; + let rewritten = pop_expr(&mut rewritten_exprs.data)?; if common_exprs.is_empty() { // Alias aggregation expressions if they have changed let new_aggr_expr = new_aggr_expr - .iter() - .zip(aggr_expr.iter()) - .map(|(new_expr, old_expr)| { - new_expr.clone().alias_if_changed(old_expr.display_name()?) - }) + .into_iter() + .zip(saved_names.into_iter()) + .map(|(new_expr, saved_name)| saved_name.restore(new_expr)) .collect::>>()?; - // Since group_epxr changes, schema changes also. Use try_new method. - Aggregate::try_new(Arc::new(new_input), new_group_expr, new_aggr_expr) - .map(LogicalPlan::Aggregate) - } else { - let mut agg_exprs = common_exprs - .into_values() - .map(|(expr, expr_alias)| expr.alias(expr_alias)) - .collect::>(); - - let mut proj_exprs = vec![]; - for expr in &new_group_expr { - extract_expressions(expr, &new_input_schema, &mut proj_exprs)? - } - for (expr_rewritten, expr_orig) in rewritten.into_iter().zip(new_aggr_expr) { - if expr_rewritten == expr_orig { - if let Expr::Alias(Alias { expr, name, .. }) = expr_rewritten { - agg_exprs.push(expr.alias(&name)); - proj_exprs.push(Expr::Column(Column::from_name(name))); - } else { - let expr_alias = config.alias_generator().next(CSE_PREFIX); - let (qualifier, field) = - expr_rewritten.to_field(&new_input_schema)?; - let out_name = qualified_name(qualifier.as_ref(), field.name()); - - agg_exprs.push(expr_rewritten.alias(&expr_alias)); - proj_exprs.push( - Expr::Column(Column::from_name(expr_alias)).alias(out_name), - ); - } + // Since group_expr may have changed, schema may also. Use try_new method. + let new_agg = if transformed { + Aggregate::try_new(Arc::new(new_input), new_group_expr, new_aggr_expr)? + } else { + Aggregate::try_new_with_schema( + Arc::new(new_input), + new_group_expr, + new_aggr_expr, + orig_schema, + )? + }; + let new_agg = LogicalPlan::Aggregate(new_agg); + return Ok(Transformed::new_transformed(new_agg, transformed)); + } + let mut agg_exprs = common_exprs + .into_values() + .map(|(expr, expr_alias)| expr.alias(expr_alias)) + .collect::>(); + + let mut proj_exprs = vec![]; + for expr in &new_group_expr { + extract_expressions(expr, &new_input_schema, &mut proj_exprs)? + } + for (expr_rewritten, expr_orig) in rewritten.into_iter().zip(new_aggr_expr) { + if expr_rewritten == expr_orig { + if let Expr::Alias(Alias { expr, name, .. }) = expr_rewritten { + agg_exprs.push(expr.alias(&name)); + proj_exprs.push(Expr::Column(Column::from_name(name))); } else { - proj_exprs.push(expr_rewritten); + let expr_alias = config.alias_generator().next(CSE_PREFIX); + let (qualifier, field) = + expr_rewritten.to_field(&new_input_schema)?; + let out_name = qualified_name(qualifier.as_ref(), field.name()); + + agg_exprs.push(expr_rewritten.alias(&expr_alias)); + proj_exprs.push( + Expr::Column(Column::from_name(expr_alias)).alias(out_name), + ); } + } else { + proj_exprs.push(expr_rewritten); } + } - let agg = LogicalPlan::Aggregate(Aggregate::try_new( - Arc::new(new_input), - new_group_expr, - agg_exprs, - )?); + let agg = LogicalPlan::Aggregate(Aggregate::try_new( + Arc::new(new_input), + new_group_expr, + agg_exprs, + )?); - Ok(LogicalPlan::Projection(Projection::try_new( - proj_exprs, - Arc::new(agg), - )?)) - } + Projection::try_new(proj_exprs, Arc::new(agg)) + .map(LogicalPlan::Projection) + .map(Transformed::yes) } + /// Rewrites the expr list and input to remove common subexpressions + /// + /// # Parameters + /// + /// * `exprs`: List of expressions in the node + /// * `input`: input plan (that produces the columns referred to in `exprs`) + /// + /// # Return value + /// + /// Returns `(rewritten_exprs, new_input)`. `new_input` is either: + /// + /// 1. The original `input` of no common subexpressions were extracted + /// 2. A newly added projection on top of the original input + /// that computes the common subexpressions fn try_unary_plan( &self, - plan: &LogicalPlan, + expr: Vec, + input: LogicalPlan, config: &dyn OptimizerConfig, - ) -> Result { - let expr = plan.expressions(); - let inputs = plan.inputs(); - let input = inputs[0]; + ) -> Result, LogicalPlan)>> { let mut expr_stats = ExprStats::new(); - - // Visit expr list and build expr identifier to occuring count map (`expr_stats`). let arrays = to_arrays(&expr, &mut expr_stats, ExprMask::Normal)?; - let (mut new_expr, new_input) = - self.rewrite_expr(&[&expr], &[&arrays], input, &expr_stats, config)?; + self.rewrite_expr(vec![expr], &[&arrays], input, &expr_stats, config)? + .map_data(|(mut new_expr, new_input)| { + assert_eq!(new_expr.len(), 1); + Ok((new_expr.pop().unwrap(), new_input)) + }) + } +} - plan.with_new_exprs(pop_expr(&mut new_expr)?, vec![new_input]) +/// Get all window expressions inside the consecutive window operators. +/// +/// Returns the window expressions, and the input to the deepest child +/// LogicalPlan. +/// +/// For example, if the input widnow looks like +/// +/// ```text +/// LogicalPlan::Window(exprs=[a, b, c]) +/// LogicalPlan::Window(exprs=[d]) +/// InputPlan +/// ``` +/// +/// Returns: +/// * `window_exprs`: `[a, b, c, d]` +/// * InputPlan +/// +/// Consecutive window expressions may refer to same complex expression. +/// +/// If same complex expression is referred more than once by subsequent +/// `WindowAggr`s, we can cache complex expression by evaluating it with a +/// projection before the first WindowAggr. +/// +/// This enables us to cache complex expression "c3+c4" for following plan: +/// +/// ```text +/// WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +/// --WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +/// ``` +/// +/// where, it is referred once by each `WindowAggr` (total of 2) in the plan. +struct ConsecutiveWindowExprs { + window_exprs: Vec>, + /// result of calling `to_arrays` on each set of window exprs + arrays_per_window: Vec>>, + expr_stats: ExprStats, + /// input plan to the window + plan: LogicalPlan, +} + +impl ConsecutiveWindowExprs { + fn try_new(window: Window) -> Result { + let mut window_exprs = vec![]; + let mut arrays_per_window = vec![]; + let mut expr_stats = ExprStats::new(); + + let mut plan = LogicalPlan::Window(window); + while let LogicalPlan::Window(Window { + input, window_expr, .. + }) = plan + { + plan = unwrap_arc(input); + + let arrays = to_arrays(&window_expr, &mut expr_stats, ExprMask::Normal)?; + + window_exprs.push(window_expr); + arrays_per_window.push(arrays); + } + + Ok(Self { + window_exprs, + arrays_per_window, + expr_stats, + plan, + }) } } impl OptimizerRule for CommonSubexprEliminate { fn try_optimize( &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _plan: &LogicalPlan, + _config: &dyn OptimizerConfig, ) -> Result> { + internal_err!("Should have called CommonSubexprEliminate::rewrite") + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + let original_schema = Arc::clone(plan.schema()); + let optimized_plan = match plan { - LogicalPlan::Projection(_) - | LogicalPlan::Sort(_) - | LogicalPlan::Filter(_) => Some(self.try_unary_plan(plan, config)?), - LogicalPlan::Window(window) => { - Some(self.try_optimize_window(window, config)?) - } - LogicalPlan::Aggregate(aggregate) => { - Some(self.try_optimize_aggregate(aggregate, config)?) - } + LogicalPlan::Projection(proj) => self.try_optimize_proj(proj, config)?, + LogicalPlan::Sort(sort) => self.try_optimize_sort(sort, config)?, + LogicalPlan::Filter(filter) => self.try_optimize_filter(filter, config)?, + LogicalPlan::Window(window) => self.try_optimize_window(window, config)?, + LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, config)?, LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_) | LogicalPlan::Repartition(_) @@ -420,21 +613,19 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Unnest(_) | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Prepare(_) => { - // apply the optimization to all inputs of the plan - utils::optimize_children(self, plan, config)? + // ApplyOrder::TopDown handles recursion + Transformed::no(plan) } }; - let original_schema = plan.schema(); - match optimized_plan { - Some(optimized_plan) if optimized_plan.schema() != original_schema => { - // add an additional projection if the output schema changed. - Ok(Some(build_recover_project_plan( - original_schema, - optimized_plan, - )?)) - } - plan => Ok(plan), + // If we rewrote the plan, ensure the schema stays the same + if optimized_plan.transformed && optimized_plan.data.schema() != &original_schema + { + optimized_plan.map_data(|optimized_plan| { + build_recover_project_plan(&original_schema, optimized_plan) + }) + } else { + Ok(optimized_plan) } } @@ -459,22 +650,29 @@ impl CommonSubexprEliminate { fn pop_expr(new_expr: &mut Vec>) -> Result> { new_expr .pop() - .ok_or_else(|| DataFusionError::Internal("Failed to pop expression".to_string())) + .ok_or_else(|| internal_datafusion_err!("Failed to pop expression")) } +/// Returns the identifier list for each element in `exprs` +/// +/// Returns and array with 1 element for each input expr in `exprs` +/// +/// Each element is itself the result of [`expr_to_identifier`] for that expr +/// (e.g. the identifiers for each node in the tree) fn to_arrays( - expr: &[Expr], + exprs: &[Expr], expr_stats: &mut ExprStats, expr_mask: ExprMask, ) -> Result> { - expr.iter() + exprs + .iter() .map(|e| { let mut id_array = vec![]; expr_to_identifier(e, expr_stats, &mut id_array, expr_mask)?; Ok(id_array) }) - .collect::>>() + .collect() } /// Build the "intermediate" projection plan that evaluates the extracted common @@ -506,10 +704,7 @@ fn build_common_expr_project_plan( } } - Ok(LogicalPlan::Projection(Projection::try_new( - project_exprs, - Arc::new(input), - )?)) + Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection) } /// Build the projection plan to eliminate unnecessary columns produced by @@ -522,10 +717,7 @@ fn build_recover_project_plan( input: LogicalPlan, ) -> Result { let col_exprs = schema.iter().map(Expr::from).collect(); - Ok(LogicalPlan::Projection(Projection::try_new( - col_exprs, - Arc::new(input), - )?)) + Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection) } fn extract_expressions( @@ -807,7 +999,7 @@ fn replace_common_expr( expr_stats: &ExprStats, common_exprs: &mut CommonExprs, alias_generator: &AliasGenerator, -) -> Result { +) -> Result> { expr.rewrite(&mut CommonSubexprRewriter { expr_stats, id_array, @@ -816,7 +1008,6 @@ fn replace_common_expr( alias_counter: 0, alias_generator, }) - .data() } #[cfg(test)] @@ -839,18 +1030,36 @@ mod test { use super::*; + fn assert_non_optimized_plan_eq( + expected: &str, + plan: LogicalPlan, + config: Option<&dyn OptimizerConfig>, + ) { + assert_eq!(expected, format!("{plan:?}"), "Unexpected starting plan"); + let optimizer = CommonSubexprEliminate {}; + let default_config = OptimizerContext::new(); + let config = config.unwrap_or(&default_config); + let optimized_plan = optimizer.rewrite(plan, config).unwrap(); + assert!(!optimized_plan.transformed, "unexpectedly optimize plan"); + let optimized_plan = optimized_plan.data; + assert_eq!( + expected, + format!("{optimized_plan:?}"), + "Unexpected optimized plan" + ); + } + fn assert_optimized_plan_eq( expected: &str, - plan: &LogicalPlan, + plan: LogicalPlan, config: Option<&dyn OptimizerConfig>, ) { let optimizer = CommonSubexprEliminate {}; let default_config = OptimizerContext::new(); let config = config.unwrap_or(&default_config); - let optimized_plan = optimizer - .try_optimize(plan, config) - .unwrap() - .expect("failed to optimize plan"); + let optimized_plan = optimizer.rewrite(plan, config).unwrap(); + assert!(optimized_plan.transformed, "failed to optimize plan"); + let optimized_plan = optimized_plan.data; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(expected, formatted_plan); } @@ -933,7 +1142,7 @@ mod test { \n Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan, None); + assert_optimized_plan_eq(expected, plan, None); Ok(()) } @@ -953,7 +1162,7 @@ mod test { \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan, None); + assert_optimized_plan_eq(expected, plan, None); Ok(()) } @@ -1006,7 +1215,7 @@ mod test { \n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, AVG(test.b) AS col3, AVG(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan, None); + assert_optimized_plan_eq(expected, plan, None); // test: trafo after aggregate let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -1025,7 +1234,7 @@ mod test { \n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]]\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan, None); + assert_optimized_plan_eq(expected, plan, None); // test: transformation before aggregate let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -1042,7 +1251,7 @@ mod test { \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan, None); + assert_optimized_plan_eq(expected, plan, None); // test: common between agg and group let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -1059,7 +1268,7 @@ mod test { \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan, None); + assert_optimized_plan_eq(expected, plan, None); // test: all mixed let plan = LogicalPlanBuilder::from(table_scan) @@ -1081,7 +1290,7 @@ mod test { \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan, None); + assert_optimized_plan_eq(expected, plan, None); Ok(()) } @@ -1108,7 +1317,7 @@ mod test { \n Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a\ \n TableScan: table.test"; - assert_optimized_plan_eq(expected, &plan, None); + assert_optimized_plan_eq(expected, plan, None); Ok(()) } @@ -1128,7 +1337,7 @@ mod test { \n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan, None); + assert_optimized_plan_eq(expected, plan, None); Ok(()) } @@ -1144,7 +1353,7 @@ mod test { let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan, None); + assert_non_optimized_plan_eq(expected, plan, None); Ok(()) } @@ -1162,7 +1371,7 @@ mod test { \n Projection: Int32(1) + test.a, test.a\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan, None); + assert_non_optimized_plan_eq(expected, plan, None); Ok(()) } @@ -1257,10 +1466,9 @@ mod test { .build() .unwrap(); let rule = CommonSubexprEliminate {}; - let optimized_plan = rule - .try_optimize(&plan, &OptimizerContext::new()) - .unwrap() - .unwrap(); + let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap(); + assert!(!optimized_plan.transformed); + let optimized_plan = optimized_plan.data; let schema = optimized_plan.schema(); let fields_with_datatypes: Vec<_> = schema @@ -1299,7 +1507,7 @@ mod test { \n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan, None); + assert_optimized_plan_eq(expected, plan, None); Ok(()) } @@ -1365,7 +1573,7 @@ mod test { \n Projection: test.a + test.b AS __common_expr_1, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan, Some(config)); + assert_optimized_plan_eq(expected, plan, Some(config)); let config = &OptimizerContext::new(); let _common_expr_1 = config.alias_generator().next(CSE_PREFIX); @@ -1388,7 +1596,7 @@ mod test { \n Projection: test.a + test.b AS __common_expr_2, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan, Some(config)); + assert_optimized_plan_eq(expected, plan, Some(config)); Ok(()) }