From c0990de1d4b760ef72caff026fcc6450f5580e3f Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 12 Dec 2023 10:06:33 +0100 Subject: [PATCH 1/9] POC --- datafusion-examples/examples/rewrite_expr.rs | 6 +- datafusion/common/src/tree_node.rs | 403 +++++++++++++----- .../core/src/datasource/listing/helpers.rs | 23 +- .../physical_plan/parquet/row_groups.rs | 6 +- datafusion/core/src/execution/context/mod.rs | 12 +- .../combine_partial_final_agg.rs | 2 +- .../enforce_distribution.rs | 65 ++- .../src/physical_optimizer/enforce_sorting.rs | 68 +-- .../physical_optimizer/pipeline_checker.rs | 33 +- .../physical_optimizer/projection_pushdown.rs | 6 +- .../core/src/physical_optimizer/pruning.rs | 2 +- .../replace_with_order_preserving_variants.rs | 34 +- .../src/physical_optimizer/sort_pushdown.rs | 35 +- datafusion/core/src/physical_planner.rs | 3 + datafusion/expr/src/expr.rs | 19 +- datafusion/expr/src/expr_rewriter/mod.rs | 75 ++-- datafusion/expr/src/expr_rewriter/order_by.rs | 2 +- datafusion/expr/src/expr_schema.rs | 4 +- datafusion/expr/src/logical_plan/builder.rs | 5 +- datafusion/expr/src/logical_plan/display.rs | 22 +- datafusion/expr/src/logical_plan/plan.rs | 216 +++++----- datafusion/expr/src/tree_node/expr.rs | 204 +++++---- datafusion/expr/src/tree_node/plan.rs | 106 ++--- datafusion/expr/src/utils.rs | 77 +--- .../src/analyzer/count_wildcard_rule.rs | 111 ++--- .../src/analyzer/inline_table_scan.rs | 2 +- datafusion/optimizer/src/analyzer/mod.rs | 33 +- datafusion/optimizer/src/analyzer/subquery.rs | 18 +- .../optimizer/src/analyzer/type_coercion.rs | 284 +++++------- .../optimizer/src/common_subexpr_eliminate.rs | 14 +- datafusion/optimizer/src/decorrelate.rs | 4 +- datafusion/optimizer/src/plan_signature.rs | 6 +- datafusion/optimizer/src/push_down_filter.rs | 29 +- .../simplify_expressions/expr_simplifier.rs | 8 +- .../src/unwrap_cast_in_comparison.rs | 83 ++-- datafusion/physical-expr/src/equivalence.rs | 2 +- .../physical-expr/src/expressions/case.rs 
| 2 +- .../physical-expr/src/sort_properties.rs | 23 +- datafusion/physical-expr/src/utils/mod.rs | 51 ++- datafusion/proto/src/logical_plan/to_proto.rs | 4 + 40 files changed, 1109 insertions(+), 993 deletions(-) diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index 5e95562033e6..9dfc238ab9e8 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -91,7 +91,7 @@ impl AnalyzerRule for MyAnalyzerRule { impl MyAnalyzerRule { fn analyze_plan(plan: LogicalPlan) -> Result { - plan.transform(&|plan| { + plan.transform_up(&|plan| { Ok(match plan { LogicalPlan::Filter(filter) => { let predicate = Self::analyze_expr(filter.predicate.clone())?; @@ -106,7 +106,7 @@ impl MyAnalyzerRule { } fn analyze_expr(expr: Expr) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { // closure is invoked for all sub expressions Ok(match expr { Expr::Literal(ScalarValue::Int64(i)) => { @@ -161,7 +161,7 @@ impl OptimizerRule for MyOptimizerRule { /// use rewrite_expr to modify the expression tree. fn my_rewrite(expr: Expr) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { // closure is invoked for all sub expressions Ok(match expr { Expr::Between(Between { diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 5da9636ffe18..39d691a9dcea 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -22,10 +22,24 @@ use std::sync::Arc; use crate::Result; -/// Defines a visitable and rewriteable a tree node. This trait is -/// implemented for plans ([`ExecutionPlan`] and [`LogicalPlan`]) as -/// well as expression trees ([`PhysicalExpr`], [`Expr`]) in -/// DataFusion +/// Defines a tree node that can have children of the same type as the parent node. 
The +/// implementations must provide [`TreeNode::apply_children()`] and +/// [`TreeNode::map_children()`] for visiting and changing the structure of the tree. +/// +/// [`TreeNode`] is implemented for plans ([`ExecutionPlan`] and [`LogicalPlan`]) as well +/// as expression trees ([`PhysicalExpr`], [`Expr`]) in DataFusion. +/// +/// Besides the children, each tree node can define links to embedded trees of the same +/// type. The root node of these trees are called inner children of a node. +/// +/// A logical plan of a query is a tree of [`LogicalPlan`] nodes, where each node can +/// contain multiple expression ([`Expr`]) trees. But expression tree nodes can contain +/// logical plans of subqueries, which are again trees of [`LogicalPlan`] nodes. The root +/// nodes of these subquery plans are the inner children of the containing query plan +/// node. +/// +/// Tree node implementations can provide [`TreeNode::apply_inner_children()`] for +/// visiting the structure of the inner tree. /// /// /// [`ExecutionPlan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html @@ -33,28 +47,40 @@ use crate::Result; /// [`LogicalPlan`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/enum.LogicalPlan.html /// [`Expr`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/expr/enum.Expr.html pub trait TreeNode: Sized { - /// Use preorder to iterate the node on the tree so that we can - /// stop fast for some cases. - /// - /// The `op` closure can be used to collect some info from the - /// tree node or do some checking for the tree node. - fn apply(&self, op: &mut F) -> Result + /// Applies `f` to the tree node, then to its inner children and then to its children + /// depending on the result of `f` in a preorder traversal. + /// See [`TreeNodeRecursion`] for more details on how the preorder traversal can be + /// controlled. + /// If an [`Err`] result is returned, recursion is stopped immediately. 
+ fn visit_down(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - match op(self)? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - }; - - self.apply_children(&mut |node| node.apply(op)) + // Apply `f` on self. + f(self) + // If it returns continue (not prune or stop or stop all) then continue + // traversal on inner children and children. + .and_then_on_continue(|| { + // Run the recursive `apply` on each inner children, but as they are + // unrelated root nodes of inner trees if any returns stop then continue + // with the next one. + self.apply_inner_children(&mut |c| c.visit_down(f).continue_on_stop()) + // Run the recursive `apply` on each children. + .and_then_on_continue(|| { + self.apply_children(&mut |c| c.visit_down(f)) + }) + }) + // Applying `f` on self might have returned prune, but we need to propagate + // continue. + .continue_on_prune() } - /// Visit the tree node using the given [TreeNodeVisitor] - /// It performs a depth first walk of an node and its children. + /// Uses a [`TreeNodeVisitor`] to visit the tree node, then its inner children and + /// then its children depending on the result of [`TreeNodeVisitor::pre_visit()`] and + /// [`TreeNodeVisitor::post_visit()`] in a traversal. + /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. + /// + /// If an [`Err`] result is returned, recursion is stopped immediately. 
/// /// For an node tree such as /// ```text @@ -73,45 +99,54 @@ pub trait TreeNode: Sized { /// post_visit(ParentNode) /// ``` /// - /// If an Err result is returned, recursion is stopped immediately - /// - /// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no - /// children of that node will be visited, nor is post_visit - /// called on that node. Details see [`TreeNodeVisitor`] - /// - /// If using the default [`TreeNodeVisitor::post_visit`] that does - /// nothing, [`Self::apply`] should be preferred. - fn visit>( + /// If using the default [`TreeNodeVisitor::post_visit()`] that does nothing, + /// [`Self::visit_down()`] should be preferred. + fn visit>( &self, visitor: &mut V, - ) -> Result { - match visitor.pre_visit(self)? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - }; - - match self.apply_children(&mut |node| node.visit(visitor))? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - - visitor.post_visit(self) + ) -> Result { + // Apply `pre_visit` on self. + visitor + .pre_visit(self) + // If it returns continue (not prune or stop or stop all) then continue + // traversal on inner children and children. + .and_then_on_continue(|| { + // Run the recursive `visit` on each inner children, but as they are + // unrelated subquery plans if any returns stop then continue with the + // next one. 
+ self.apply_inner_children(&mut |c| c.visit(visitor).continue_on_stop()) + // Run the recursive `visit` on each children. + .and_then_on_continue(|| { + self.apply_children(&mut |c| c.visit(visitor)) + }) + // Apply `post_visit` on self. + .and_then_on_continue(|| visitor.post_visit(self)) + }) + // Applying `pre_visit` or `post_visit` on self might have returned prune, + // but we need to propagate continue. + .continue_on_prune() } - /// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node tree. - /// When `op` does not apply to a given node, it is left unchanged. - /// The default tree traversal direction is transform_up(Postorder Traversal). - fn transform(self, op: &F) -> Result - where - F: Fn(Self) -> Result>, - { - self.transform_up(op) + fn transform>( + &mut self, + transformer: &mut T, + ) -> Result { + // Apply `pre_transform` on self. + transformer + .pre_transform(self) + // If it returns continue (not prune or stop or stop all) then continue + // traversal on inner children and children. + .and_then_on_continue(|| + // Run the recursive `transform` on each children. + self + .transform_children(&mut |c| c.transform(transformer)) + // Apply `post_transform` on new self. + .and_then_on_continue(|| { + transformer.post_transform(self) + })) + // Applying `pre_transform` or `post_transform` on self might have returned + // prune, but we need to propagate continue. + .continue_on_prune() } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its @@ -208,54 +243,109 @@ pub trait TreeNode: Sized { } } - /// Apply the closure `F` to the node's children - fn apply_children(&self, op: &mut F) -> Result + /// Apply `f` to the node's children. + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result; + F: FnMut(&Self) -> Result; + + /// Apply `f` to the node's inner children. 
+ fn apply_inner_children(&self, _f: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + Ok(TreeNodeRecursion::Continue) + } /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) fn map_children(self, transform: F) -> Result where F: FnMut(Self) -> Result; + + /// Apply `f` to the node's children. + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result; + + /// Convenience function to do a preorder traversal of the tree nodes with `f` that + /// can't fail. + fn for_each(&self, f: &mut F) + where + F: FnMut(&Self), + { + self.visit_down(&mut |n| { + f(n); + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + } + + /// Convenience function to collect the first non-empty value that `f` returns in a + /// preorder traversal. + fn collect_first(&self, f: &mut F) -> Option + where + F: FnMut(&Self) -> Option, + { + let mut res = None; + self.visit_down(&mut |n| { + res = f(n); + if res.is_some() { + Ok(TreeNodeRecursion::StopAll) + } else { + Ok(TreeNodeRecursion::Continue) + } + }) + .unwrap(); + res + } + + /// Convenience function to collect all values that `f` returns in a preorder + /// traversal. + fn collect(&self, f: &mut F) -> Vec + where + F: FnMut(&Self) -> Vec, + { + let mut res = vec![]; + self.visit_down(&mut |n| { + res.extend(f(n)); + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + res + } } -/// Implements the [visitor -/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively walking [`TreeNode`]s. -/// -/// [`TreeNodeVisitor`] allows keeping the algorithms -/// separate from the code to traverse the structure of the `TreeNode` -/// tree and makes it easier to add new types of tree node and -/// algorithms. +/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for +/// recursively walking [`TreeNode`]s. 
/// -/// When passed to[`TreeNode::visit`], [`TreeNodeVisitor::pre_visit`] -/// and [`TreeNodeVisitor::post_visit`] are invoked recursively -/// on an node tree. +/// [`TreeNodeVisitor`] allows keeping the algorithms separate from the code to traverse +/// the structure of the [`TreeNode`] tree and makes it easier to add new types of tree +/// node and algorithms. /// -/// If an [`Err`] result is returned, recursion is stopped -/// immediately. +/// When passed to [`TreeNode::visit()`], [`TreeNodeVisitor::pre_visit()`] and +/// [`TreeNodeVisitor::post_visit()`] are invoked recursively on an node tree. +/// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. /// -/// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no -/// children of that tree node are visited, nor is post_visit -/// called on that tree node -/// -/// If [`VisitRecursion::Stop`] is returned on a call to post_visit, no -/// siblings of that tree node are visited, nor is post_visit -/// called on its parent tree node -/// -/// If [`VisitRecursion::Skip`] is returned on a call to pre_visit, no -/// children of that tree node are visited. +/// If an [`Err`] result is returned, recursion is stopped immediately. pub trait TreeNodeVisitor: Sized { /// The node type which is visitable. - type N: TreeNode; + type Node: TreeNode; - /// Invoked before any children of `node` are visited. - fn pre_visit(&mut self, node: &Self::N) -> Result; + /// Invoked before any inner children or children of a node are visited. + fn pre_visit(&mut self, node: &Self::Node) -> Result; - /// Invoked after all children of `node` are visited. Default - /// implementation does nothing. - fn post_visit(&mut self, _node: &Self::N) -> Result { - Ok(VisitRecursion::Continue) - } + /// Invoked after all inner children and children of a node are visited. 
+ fn post_visit(&mut self, _node: &Self::Node) -> Result; +} + +pub trait TreeNodeTransformer: Sized { + /// The node type which is visitable. + type Node: TreeNode; + + /// Invoked before any inner children or children of a node are modified. + fn pre_transform(&mut self, node: &mut Self::Node) -> Result; + + /// Invoked after all inner children and children of a node are modified. + fn post_transform(&mut self, node: &mut Self::Node) -> Result; } /// Trait for potentially recursively transform an [`TreeNode`] node @@ -289,15 +379,108 @@ pub enum RewriteRecursion { Skip, } -/// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::visit`]. +/// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::visit_down()`] and +/// [`TreeNode::visit()`]. #[derive(Debug)] -pub enum VisitRecursion { - /// Continue the visit to this node tree. +pub enum TreeNodeRecursion { + /// Continue the visit to the next node. Continue, - /// Keep recursive but skip applying op on the children - Skip, - /// Stop the visit to this node tree. + + /// Prune the current subtree. + /// If a preorder visit of a tree node returns [`TreeNodeRecursion::Prune`] then inner + /// children and children will not be visited and postorder visit of the node will not + /// be invoked. + Prune, + + /// Stop recursion on current tree. + /// If recursion runs on an inner tree then returning [`TreeNodeRecursion::Stop`] doesn't + /// stop recursion on the outer tree. Stop, + + /// Stop recursion on all (including outer) trees. 
+ StopAll, +} + +impl TreeNodeRecursion { + fn continue_on_prune(self) -> TreeNodeRecursion { + match self { + TreeNodeRecursion::Prune => TreeNodeRecursion::Continue, + o => o, + } + } + + fn fail_on_prune(self) -> TreeNodeRecursion { + match self { + TreeNodeRecursion::Prune => panic!("Recursion can't prune."), + o => o, + } + } + + fn continue_on_stop(self) -> TreeNodeRecursion { + match self { + TreeNodeRecursion::Stop => TreeNodeRecursion::Continue, + o => o, + } + } +} + +/// This helper trait provide functions to control recursion on +/// [`Result`]. +pub trait TreeNodeRecursionResult: Sized { + fn and_then_on_continue(self, f: F) -> Result + where + F: FnOnce() -> Result; + + fn continue_on_prune(self) -> Result; + + fn fail_on_prune(self) -> Result; + + fn continue_on_stop(self) -> Result; +} + +impl TreeNodeRecursionResult for Result { + fn and_then_on_continue(self, f: F) -> Result + where + F: FnOnce() -> Result, + { + match self? { + TreeNodeRecursion::Continue => f(), + o => Ok(o), + } + } + + fn continue_on_prune(self) -> Result { + self.map(|tnr| tnr.continue_on_prune()) + } + + fn fail_on_prune(self) -> Result { + self.map(|tnr| tnr.fail_on_prune()) + } + + fn continue_on_stop(self) -> Result { + self.map(|tnr| tnr.continue_on_stop()) + } +} + +pub trait VisitRecursionIterator: Iterator { + fn for_each_till_continue(self, f: &mut F) -> Result + where + F: FnMut(Self::Item) -> Result; +} + +impl VisitRecursionIterator for I { + fn for_each_till_continue(self, f: &mut F) -> Result + where + F: FnMut(Self::Item) -> Result, + { + for i in self { + match f(i)? 
{ + TreeNodeRecursion::Continue => {} + o => return Ok(o), + } + } + Ok(TreeNodeRecursion::Continue) + } } pub enum Transformed { @@ -342,19 +525,11 @@ pub trait DynTreeNode { /// Blanket implementation for Arc for any tye that implements /// [`DynTreeNode`] (such as [`Arc`]) impl TreeNode for Arc { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - for child in self.arc_children() { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + self.arc_children().iter().for_each_till_continue(f) } fn map_children(self, transform: F) -> Result @@ -371,4 +546,18 @@ impl TreeNode for Arc { Ok(self) } } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + let mut new_children = self.arc_children(); + if !new_children.is_empty() { + let tnr = new_children.iter_mut().for_each_till_continue(f)?; + *self = self.with_new_arc_children(self.clone(), new_children)?; + Ok(tnr) + } else { + Ok(TreeNodeRecursion::Continue) + } + } } diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index be74afa1f4d6..870bddbaaaa5 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -37,7 +37,7 @@ use crate::{error::Result, scalar::ScalarValue}; use super::PartitionedFile; use crate::datasource::listing::ListingTableUrl; use crate::execution::context::SessionState; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError}; use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility}; use 
datafusion_physical_expr::create_physical_expr; @@ -52,17 +52,18 @@ use object_store::{ObjectMeta, ObjectStore}; /// was performed pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { let mut is_applicable = true; - expr.apply(&mut |expr| { + expr.visit_down(&mut |expr| { match expr { Expr::Column(Column { ref name, .. }) => { is_applicable &= col_names.contains(name); if is_applicable { - Ok(VisitRecursion::Skip) + Ok(TreeNodeRecursion::Prune) } else { - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } - Expr::Literal(_) + Expr::Nop + | Expr::Literal(_) | Expr::Alias(_) | Expr::OuterReferenceColumn(_, _) | Expr::ScalarVariable(_, _) @@ -88,27 +89,27 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::ScalarSubquery(_) | Expr::GetIndexedField { .. } | Expr::GroupingSet(_) - | Expr::Case { .. } => Ok(VisitRecursion::Continue), + | Expr::Case { .. } => Ok(TreeNodeRecursion::Continue), Expr::ScalarFunction(scalar_function) => { match &scalar_function.func_def { ScalarFunctionDefinition::BuiltIn(fun) => { match fun.volatility() { - Volatility::Immutable => Ok(VisitRecursion::Continue), + Volatility::Immutable => Ok(TreeNodeRecursion::Continue), // TODO: Stable functions could be `applicable`, but that would require access to the context Volatility::Stable | Volatility::Volatile => { is_applicable = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } } ScalarFunctionDefinition::UDF(fun) => { match fun.signature().volatility { - Volatility::Immutable => Ok(VisitRecursion::Continue), + Volatility::Immutable => Ok(TreeNodeRecursion::Continue), // TODO: Stable functions could be `applicable`, but that would require access to the context Volatility::Stable | Volatility::Volatile => { is_applicable = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } } @@ -128,7 +129,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::Wildcard { .. 
} | Expr::Placeholder(_) => { is_applicable = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } }) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 7c3f7d9384ab..f573fd11a8c2 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -17,7 +17,7 @@ use arrow::{array::ArrayRef, datatypes::Schema}; use arrow_schema::FieldRef; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{Column, DataFusionError, Result, ScalarValue}; use parquet::file::metadata::ColumnChunkMetaData; use parquet::schema::types::SchemaDescriptor; @@ -259,7 +259,7 @@ impl BloomFilterPruningPredicate { fn get_predicate_columns(expr: &Arc) -> HashSet { let mut columns = HashSet::new(); - expr.apply(&mut |expr| { + expr.visit_down(&mut |expr| { if let Some(binary_expr) = expr.as_any().downcast_ref::() { @@ -269,7 +269,7 @@ impl BloomFilterPruningPredicate { columns.insert(column.name().to_string()); } } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // no way to fail as only Ok(VisitRecursion::Continue) is returned .unwrap(); diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 8916fa814a4a..02e1aad80b2d 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -38,7 +38,7 @@ use crate::{ use datafusion_common::{ alias::AliasGenerator, exec_err, not_impl_err, plan_datafusion_err, plan_err, - tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}, + tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}, }; use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ @@ -2093,9 +2093,9 @@ impl<'a> BadPlanVisitor<'a> { } impl<'a> 
TreeNodeVisitor for BadPlanVisitor<'a> { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, node: &Self::N) -> Result { + fn pre_visit(&mut self, node: &Self::Node) -> Result { match node { LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => { plan_err!("DDL not supported: {}", ddl.name()) @@ -2109,9 +2109,13 @@ impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> { LogicalPlan::Statement(stmt) if !self.options.allow_statements => { plan_err!("Statement not supported: {}", stmt.name()) } - _ => Ok(VisitRecursion::Continue), + _ => Ok(TreeNodeRecursion::Continue), } } + + fn post_visit(&mut self, _node: &Self::Node) -> Result { + Ok(TreeNodeRecursion::Continue) + } } #[cfg(test)] diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 7359a6463059..5878650a49e3 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -178,7 +178,7 @@ fn normalize_group_exprs(group_exprs: GroupExprsRef) -> GroupExprs { fn discard_column_index(group_expr: Arc) -> Arc { group_expr .clone() - .transform(&|expr| { + .transform_up(&|expr| { let normalized_form: Option> = match expr.as_any().downcast_ref::() { Some(column) => Some(Arc::new(Column::new(column.name(), 0))), diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 0aef126578f3..9392d443e150 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -47,7 +47,9 @@ use crate::physical_plan::{ }; use arrow::compute::SortOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator, +}; use 
datafusion_expr::logical_plan::JoinType; use datafusion_physical_expr::expressions::{Column, NoOp}; use datafusion_physical_expr::utils::map_columns_before_projection; @@ -1476,18 +1478,11 @@ impl DistributionContext { } impl TreeNode for DistributionContext { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - for child in self.children() { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - Ok(VisitRecursion::Continue) + self.children().iter().for_each_till_continue(f) } fn map_children(self, transform: F) -> Result @@ -1505,6 +1500,23 @@ impl TreeNode for DistributionContext { DistributionContext::new_from_children_nodes(children_nodes, self.plan) } } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + let mut children = self.children(); + if children.is_empty() { + Ok(TreeNodeRecursion::Continue) + } else { + let tnr = children.iter_mut().for_each_till_continue(f)?; + *self = DistributionContext::new_from_children_nodes( + children, + self.plan.clone(), + )?; + Ok(tnr) + } + } } /// implement Display method for `DistributionContext` struct. @@ -1566,20 +1578,11 @@ impl PlanWithKeyRequirements { } impl TreeNode for PlanWithKeyRequirements { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - let children = self.children(); - for child in children { - match op(&child)? 
{ - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + self.children().iter().for_each_till_continue(f) } fn map_children(self, transform: F) -> Result @@ -1605,6 +1608,22 @@ impl TreeNode for PlanWithKeyRequirements { Ok(self) } } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + let mut children = self.children(); + if !children.is_empty() { + let tnr = children.iter_mut().for_each_till_continue(f)?; + let children_plans = children.into_iter().map(|c| c.plan).collect(); + self.plan = + with_new_children_if_necessary(self.plan.clone(), children_plans)?.into(); + Ok(tnr) + } else { + Ok(TreeNodeRecursion::Continue) + } + } } /// Since almost all of these tests explicitly use `ParquetExec` they only run with the parquet feature flag on diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 2b650a42696b..9a57a030fcc6 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -57,7 +57,9 @@ use crate::physical_plan::{ with_new_children_if_necessary, Distribution, ExecutionPlan, InputOrderMode, }; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator, +}; use datafusion_common::{plan_err, DataFusionError}; use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}; @@ -157,20 +159,11 @@ impl PlanWithCorrespondingSort { } impl TreeNode for PlanWithCorrespondingSort { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - let children = self.children(); - for child in 
children { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + self.children().iter().for_each_till_continue(f) } fn map_children(self, transform: F) -> Result @@ -188,6 +181,23 @@ impl TreeNode for PlanWithCorrespondingSort { PlanWithCorrespondingSort::new_from_children_nodes(children_nodes, self.plan) } } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + let mut children = self.children(); + if children.is_empty() { + Ok(TreeNodeRecursion::Continue) + } else { + let tnr = children.iter_mut().for_each_till_continue(f)?; + *self = PlanWithCorrespondingSort::new_from_children_nodes( + children, + self.plan.clone(), + )?; + Ok(tnr) + } + } } /// This object is used within the [EnforceSorting] rule to track the closest @@ -273,20 +283,11 @@ impl PlanWithCorrespondingCoalescePartitions { } impl TreeNode for PlanWithCorrespondingCoalescePartitions { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - let children = self.children(); - for child in children { - match op(&child)? 
{ - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + self.children().iter().for_each_till_continue(f) } fn map_children(self, transform: F) -> Result @@ -307,6 +308,23 @@ impl TreeNode for PlanWithCorrespondingCoalescePartitions { ) } } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + let mut children = self.children(); + if children.is_empty() { + Ok(TreeNodeRecursion::Continue) + } else { + let tnr = children.iter_mut().for_each_till_continue(f)?; + *self = PlanWithCorrespondingCoalescePartitions::new_from_children_nodes( + children, + self.plan.clone(), + )?; + Ok(tnr) + } + } } /// The boolean flag `repartition_sorts` defined in the config indicates diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index d59248aadf05..122ce7171bd3 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -28,7 +28,9 @@ use crate::physical_plan::joins::SymmetricHashJoinExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use datafusion_common::config::OptimizerOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator, +}; use datafusion_common::{plan_err, DataFusionError}; use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; @@ -94,19 +96,11 @@ impl PipelineStatePropagator { } impl TreeNode for PipelineStatePropagator { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - for child in &self.children { - match 
op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + self.children.iter().for_each_till_continue(f) } fn map_children(self, transform: F) -> Result @@ -130,6 +124,21 @@ impl TreeNode for PipelineStatePropagator { Ok(self) } } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + if !self.children.is_empty() { + let tnr = self.children.iter_mut().for_each_till_continue(f)?; + let children_plans = self.children.iter().map(|c| c.plan.clone()).collect(); + self.plan = + with_new_children_if_necessary(self.plan.clone(), children_plans)?.into(); + Ok(tnr) + } else { + Ok(TreeNodeRecursion::Continue) + } + } } /// This function propagates finiteness information and rejects any plan with diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 7e1312dad23e..e2b290f3f5ce 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -43,7 +43,7 @@ use crate::physical_plan::{Distribution, ExecutionPlan}; use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::JoinSide; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::{ @@ -255,12 +255,12 @@ fn try_unifying_projections( // Collect the column references usage in the outer projection. 
projection.expr().iter().for_each(|(expr, _)| { - expr.apply(&mut |expr| { + expr.visit_down(&mut |expr| { Ok({ if let Some(column) = expr.as_any().downcast_ref::() { *column_ref_map.entry(column.clone()).or_default() += 1; } - VisitRecursion::Continue + TreeNodeRecursion::Continue }) }) .unwrap(); diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index b2ba7596db8d..2423ccc4c32e 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -678,7 +678,7 @@ fn rewrite_column_expr( column_old: &phys_expr::Column, column_new: &phys_expr::Column, ) -> Result> { - e.transform(&|expr| { + e.transform_up(&|expr| { if let Some(column) = expr.as_any().downcast_ref::() { if column == column_old { return Ok(Transformed::Yes(Arc::new(column_new.clone()))); diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 671891be433c..21602487640f 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -30,7 +30,9 @@ use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use super::utils::is_repartition; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator, +}; use datafusion_physical_plan::unbounded_output; /// For a given `plan`, this object carries the information one needs from its @@ -118,18 +120,11 @@ impl OrderPreservationContext { } impl TreeNode for OrderPreservationContext { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: 
FnMut(&Self) -> Result, { - for child in self.children() { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - Ok(VisitRecursion::Continue) + self.children().iter().for_each_till_continue(f) } fn map_children(self, transform: F) -> Result @@ -147,6 +142,23 @@ impl TreeNode for OrderPreservationContext { OrderPreservationContext::new_from_children_nodes(children_nodes, self.plan) } } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + let mut children = self.children(); + if children.is_empty() { + Ok(TreeNodeRecursion::Continue) + } else { + let tnr = children.iter_mut().for_each_till_continue(f)?; + *self = OrderPreservationContext::new_from_children_nodes( + children, + self.plan.clone(), + )?; + Ok(tnr) + } + } } /// Calculates the updated plan by replacing executors that lose ordering diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index b9502d92ac12..4b06218df9e9 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -28,7 +28,9 @@ use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator, +}; use datafusion_common::{plan_err, DataFusionError, JoinSide, Result}; use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::Column; @@ -82,20 +84,11 @@ impl SortPushDown { } impl TreeNode for SortPushDown { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: 
FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - let children = self.children(); - for child in children { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + self.children().iter().for_each_till_continue(f) } fn map_children(mut self, transform: F) -> Result @@ -118,6 +111,22 @@ impl TreeNode for SortPushDown { }; Ok(self) } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + let mut children = self.children(); + if !children.is_empty() { + let tnr = children.iter_mut().for_each_till_continue(f)?; + let children_plans = children.into_iter().map(|c| c.plan).collect(); + self.plan = + with_new_children_if_necessary(self.plan.clone(), children_plans)?.into(); + Ok(tnr) + } else { + Ok(TreeNodeRecursion::Continue) + } + } } pub(crate) fn pushdown_sorts( diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index e5816eb49ebb..fd9e81c1b752 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -381,6 +381,9 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Expr::OuterReferenceColumn(_, _) => { internal_err!("Create physical name does not support OuterReferenceColumn") } + Expr::Nop => { + internal_err!("Create physical name does not support Nop expression") + } } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index f0aab95b8f0d..5369d502113b 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -31,10 +31,10 @@ use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{internal_err, DFSchema, OwnedTableReference}; use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue}; use std::collections::HashSet; -use std::fmt; use std::fmt::{Display, 
Formatter, Write}; use std::hash::{BuildHasher, Hash, Hasher}; use std::sync::Arc; +use std::{fmt, mem}; /// `Expr` is a central struct of DataFusion's query API, and /// represent logical expressions such as `A + 1`, or `CAST(c1 AS @@ -81,8 +81,10 @@ use std::sync::Arc; /// assert_eq!(binary_expr.op, Operator::Eq); /// } /// ``` -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialEq, Eq, Hash, Debug, Default)] pub enum Expr { + #[default] + Nop, /// An expression with a specific name. Alias(Alias), /// A named reference to a qualified filed in a schema. @@ -784,6 +786,7 @@ impl Expr { /// Useful for non-rust based bindings pub fn variant_name(&self) -> &str { match self { + Expr::Nop { .. } => "Nop", Expr::AggregateFunction { .. } => "AggregateFunction", Expr::Alias(..) => "Alias", Expr::Between { .. } => "Between", @@ -954,11 +957,11 @@ impl Expr { } /// Remove an alias from an expression if one exists. - pub fn unalias(self) -> Expr { - match self { - Expr::Alias(alias) => alias.expr.as_ref().clone(), - _ => self, + pub fn unalias(&mut self) -> &mut Self { + if let Expr::Alias(alias) = self { + *self = mem::take(alias.expr.as_mut()); } + self } /// Return `self IN ` if `negated` is false, otherwise @@ -1147,7 +1150,7 @@ impl Expr { /// For example, gicen an expression like ` = $0` will infer `$0` to /// have type `int32`. pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result { - self.transform(&|mut expr| { + self.transform_up(&|mut expr| { // Default to assuming the arguments are the same type if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; @@ -1204,6 +1207,7 @@ macro_rules! expr_vec_fmt { impl fmt::Display for Expr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { + Expr::Nop => write!(f, "NOP"), Expr::Alias(Alias { expr, name, .. 
}) => write!(f, "{expr} AS {name}"), Expr::Column(c) => write!(f, "{c}"), Expr::OuterReferenceColumn(_, c) => write!(f, "outer_ref({c})"), @@ -1446,6 +1450,7 @@ fn create_function_name(fun: &str, distinct: bool, args: &[Expr]) -> Result 2)". fn create_name(e: &Expr) -> Result { match e { + Expr::Nop => Ok("NOP".to_string()), Expr::Alias(Alias { name, .. }) => Ok(name.clone()), Expr::Column(c) => Ok(c.flat_name()), Expr::OuterReferenceColumn(_, c) => Ok(format!("outer_ref({})", c.flat_name())), diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 1f04c80833f0..cbdeb16f99b2 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -20,7 +20,7 @@ use crate::expr::Alias; use crate::logical_plan::Projection; use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeTransformer}; use datafusion_common::Result; use datafusion_common::{Column, DFSchema}; use std::collections::HashMap; @@ -33,7 +33,7 @@ pub use order_by::rewrite_sort_cols_by_aggs; /// Recursively call [`Column::normalize_with_schemas`] on all [`Column`] expressions /// in the `expr` expression tree. 
pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = LogicalPlanBuilder::normalize(plan, c)?; @@ -57,7 +57,7 @@ pub fn normalize_col_with_schemas( schemas: &[&Arc], using_columns: &[HashSet], ) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = c.normalize_with_schemas(schemas, using_columns)?; @@ -75,7 +75,7 @@ pub fn normalize_col_with_schemas_and_ambiguity_check( schemas: &[&[&DFSchema]], using_columns: &[HashSet], ) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = @@ -102,7 +102,7 @@ pub fn normalize_cols( /// Recursively replace all [`Column`] expressions in a given expression tree with /// `Column` expressions provided by the hash map argument. pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::Column(c) = &expr { match replace_map.get(c) { @@ -122,7 +122,7 @@ pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Resul /// For example, if there were expressions like `foo.bar` this would /// rewrite it to just `bar`. pub fn unnormalize_col(expr: Expr) -> Expr { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = Column { @@ -164,7 +164,7 @@ pub fn unnormalize_cols(exprs: impl IntoIterator) -> Vec { /// Recursively remove all the ['OuterReferenceColumn'] and return the inside Column /// in the expression tree. 
pub fn strip_outer_reference(expr: Expr) -> Expr { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::OuterReferenceColumn(_, col) = expr { Transformed::Yes(Expr::Column(col)) @@ -248,12 +248,12 @@ pub fn unalias(expr: Expr) -> Expr { /// /// This is important when optimizing plans to ensure the output /// schema of plan nodes don't change after optimization -pub fn rewrite_preserving_name(expr: Expr, rewriter: &mut R) -> Result +pub fn rewrite_preserving_name(mut expr: Expr, transformer: &mut R) -> Result where - R: TreeNodeRewriter, + R: TreeNodeTransformer, { let original_name = expr.name_for_alias()?; - let expr = expr.rewrite(rewriter)?; + expr.transform(transformer)?; expr.alias_if_changed(original_name) } @@ -263,7 +263,7 @@ mod test { use crate::expr::Sort; use crate::{col, lit, Cast}; use arrow::datatypes::DataType; - use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; + use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{DFField, DFSchema, ScalarValue}; use std::ops::Add; @@ -272,17 +272,17 @@ mod test { v: Vec, } - impl TreeNodeRewriter for RecordingRewriter { - type N = Expr; + impl TreeNodeTransformer for RecordingRewriter { + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn pre_transform(&mut self, expr: &mut Expr) -> Result { self.v.push(format!("Previsited {expr}")); - Ok(RewriteRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn mutate(&mut self, expr: Expr) -> Result { + fn post_transform(&mut self, expr: &mut Expr) -> Result { self.v.push(format!("Mutated {expr}")); - Ok(expr) + Ok(TreeNodeRecursion::Continue) } } @@ -305,11 +305,17 @@ mod test { }; // rewrites "foo" --> "bar" - let rewritten = col("state").eq(lit("foo")).transform(&transformer).unwrap(); + let rewritten = col("state") + .eq(lit("foo")) + .transform_up(&transformer) + .unwrap(); assert_eq!(rewritten, col("state").eq(lit("bar"))); // doesn't rewrite - let 
rewritten = col("state").eq(lit("baz")).transform(&transformer).unwrap(); + let rewritten = col("state") + .eq(lit("baz")) + .transform_up(&transformer) + .unwrap(); assert_eq!(rewritten, col("state").eq(lit("baz"))); } @@ -399,7 +405,8 @@ mod test { #[test] fn rewriter_visit() { let mut rewriter = RecordingRewriter::default(); - col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap(); + let mut expr = col("state").eq(lit("CO")); + expr.transform(&mut rewriter).unwrap(); assert_eq!( rewriter.v, @@ -439,22 +446,28 @@ mod test { /// rewrites `expr_from` to `rewrite_to` using /// `rewrite_preserving_name` verifying the result is `expected_expr` fn test_rewrite(expr_from: Expr, rewrite_to: Expr) { - struct TestRewriter { - rewrite_to: Expr, - } + struct TestTransformer {} + + impl TreeNodeTransformer for TestTransformer { + type Node = Expr; - impl TreeNodeRewriter for TestRewriter { - type N = Expr; + fn pre_transform( + &mut self, + _node: &mut Self::Node, + ) -> Result { + Ok(TreeNodeRecursion::Continue) + } - fn mutate(&mut self, _: Expr) -> Result { - Ok(self.rewrite_to.clone()) + fn post_transform( + &mut self, + _node: &mut Self::Node, + ) -> Result { + Ok(TreeNodeRecursion::Continue) } } - let mut rewriter = TestRewriter { - rewrite_to: rewrite_to.clone(), - }; - let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap(); + let mut transformer = TestTransformer {}; + let expr = rewrite_preserving_name(expr_from.clone(), &mut transformer).unwrap(); let original_name = match &expr_from { Expr::Sort(Sort { expr, .. 
}) => expr.display_name(), diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index c87a724d5646..1e7efcafd04d 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -83,7 +83,7 @@ fn rewrite_in_terms_of_projection( ) -> Result { // assumption is that each item in exprs, such as "b + c" is // available as an output column named "b + c" - expr.transform(&|expr| { + expr.transform_up(&|expr| { // search for unnormalized names first such as "c1" (such as aliases) if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) { let col = Expr::Column( diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index e5b0185d90e0..71987667be4a 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -67,6 +67,7 @@ impl ExprSchemable for Expr { /// (e.g. `[utf8] + [bool]`). fn get_type(&self, schema: &S) -> Result { match self { + Expr::Nop => Ok(DataType::Null), Expr::Alias(Alias { expr, name, .. }) => match &**expr { Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type { None => schema.data_type(&Column::from_name(name)).cloned(), @@ -251,7 +252,8 @@ impl ExprSchemable for Expr { | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } | Expr::Placeholder(_) => Ok(true), - Expr::IsNull(_) + Expr::Nop + | Expr::IsNull(_) | Expr::IsNotNull(_) | Expr::IsTrue(_) | Expr::IsFalse(_) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 88310dab82a2..eb2085ec9e15 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1220,9 +1220,10 @@ pub fn project_with_column_index( let alias_expr = expr .into_iter() .enumerate() - .map(|(i, e)| match e { + .map(|(i, mut e)| match e { Expr::Alias(Alias { ref name, .. 
}) if name != schema.field(i).name() => { - e.unalias().alias(schema.field(i).name()) + e.unalias(); + e.alias(schema.field(i).name()) } Expr::Column(Column { relation: _, diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 112dbf74dba1..2a8c4ce5912d 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -19,7 +19,7 @@ use crate::LogicalPlan; use arrow::datatypes::Schema; use datafusion_common::display::GraphvizBuilder; -use datafusion_common::tree_node::{TreeNodeVisitor, VisitRecursion}; +use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeVisitor}; use datafusion_common::DataFusionError; use std::fmt; @@ -49,12 +49,12 @@ impl<'a, 'b> IndentVisitor<'a, 'b> { } impl<'a, 'b> TreeNodeVisitor for IndentVisitor<'a, 'b> { - type N = LogicalPlan; + type Node = LogicalPlan; fn pre_visit( &mut self, plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { if self.indent > 0 { writeln!(self.f)?; } @@ -69,15 +69,15 @@ impl<'a, 'b> TreeNodeVisitor for IndentVisitor<'a, 'b> { } self.indent += 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } fn post_visit( &mut self, _plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { self.indent -= 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } @@ -171,12 +171,12 @@ impl<'a, 'b> GraphvizVisitor<'a, 'b> { } impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> { - type N = LogicalPlan; + type Node = LogicalPlan; fn pre_visit( &mut self, plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { let id = self.graphviz_builder.next_id(); // Create a new graph node for `plan` such as @@ -204,18 +204,18 @@ impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> { } self.parent_ids.push(id); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } fn post_visit( &mut 
self, _plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { // always be non-empty as pre_visit always pushes // So it should always be Ok(true) let res = self.parent_ids.pop(); res.ok_or(DataFusionError::Internal("Fail to format".to_string())) - .map(|_| VisitRecursion::Continue) + .map(|_| TreeNodeRecursion::Continue) } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 1f3711407a14..3d8a8356f397 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -32,8 +32,7 @@ use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::{DmlStatement, Statement}; use crate::utils::{ enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, - grouping_set_expr_count, grouping_set_to_exprlist, inspect_expr_pre, - split_conjunction, + grouping_set_expr_count, grouping_set_to_exprlist, split_conjunction, }; use crate::{ build_join_schema, expr_vec_fmt, BinaryExpr, CreateMemoryTable, CreateView, Expr, @@ -43,8 +42,8 @@ use crate::{ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::{ - RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, TreeNodeVisitor, - VisitRecursion, + Transformed, TreeNode, TreeNodeRecursion, TreeNodeRecursionResult, + TreeNodeTransformer, VisitRecursionIterator, }; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, @@ -277,9 +276,9 @@ impl LogicalPlan { /// children pub fn expressions(self: &LogicalPlan) -> Vec { let mut exprs = vec![]; - self.inspect_expressions(|e| { + self.apply_expressions(&mut |e| { exprs.push(e.clone()); - Ok(()) as Result<()> + Ok(TreeNodeRecursion::Continue) }) // closure always returns OK .unwrap(); @@ -290,13 +289,13 @@ impl LogicalPlan { /// logical plan nodes and all its descendant nodes. 
pub fn all_out_ref_exprs(self: &LogicalPlan) -> Vec { let mut exprs = vec![]; - self.inspect_expressions(|e| { + self.apply_expressions(&mut |e| { find_out_reference_exprs(e).into_iter().for_each(|e| { if !exprs.contains(&e) { exprs.push(e) } }); - Ok(()) as Result<(), DataFusionError> + Ok(TreeNodeRecursion::Continue) }) // closure always returns OK .unwrap(); @@ -311,37 +310,41 @@ impl LogicalPlan { exprs } - /// Calls `f` on all expressions (non-recursively) in the current - /// logical plan node. This does not include expressions in any - /// children. - pub fn inspect_expressions(self: &LogicalPlan, mut f: F) -> Result<(), E> + /// Apply `f` on expressions of the plan node. + /// `f` is not allowed to return [`TreeNodeRecursion::Prune`]. + pub fn apply_expressions(&self, f: &mut F) -> Result where - F: FnMut(&Expr) -> Result<(), E>, + F: FnMut(&Expr) -> Result, { + let f = &mut |e: &Expr| f(e).fail_on_prune(); + match self { LogicalPlan::Projection(Projection { expr, .. }) => { - expr.iter().try_for_each(f) + expr.iter().for_each_till_continue(f) } LogicalPlan::Values(Values { values, .. }) => { - values.iter().flatten().try_for_each(f) + values.iter().flatten().for_each_till_continue(f) } LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate), LogicalPlan::Repartition(Repartition { partitioning_scheme, .. }) => match partitioning_scheme { - Partitioning::Hash(expr, _) => expr.iter().try_for_each(f), - Partitioning::DistributeBy(expr) => expr.iter().try_for_each(f), - Partitioning::RoundRobinBatch(_) => Ok(()), + Partitioning::Hash(expr, _) => expr.iter().for_each_till_continue(f), + Partitioning::DistributeBy(expr) => expr.iter().for_each_till_continue(f), + Partitioning::RoundRobinBatch(_) => Ok(TreeNodeRecursion::Continue), }, LogicalPlan::Window(Window { window_expr, .. }) => { - window_expr.iter().try_for_each(f) + window_expr.iter().for_each_till_continue(f) } LogicalPlan::Aggregate(Aggregate { group_expr, aggr_expr, .. 
- }) => group_expr.iter().chain(aggr_expr.iter()).try_for_each(f), + }) => group_expr + .iter() + .chain(aggr_expr.iter()) + .for_each_till_continue(f), // There are two part of expression for join, equijoin(on) and non-equijoin(filter). // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. // 2. the second part is non-equijoin(filter). @@ -349,22 +352,21 @@ impl LogicalPlan { on.iter() // it not ideal to create an expr here to analyze them, but could cache it on the Join itself .map(|(l, r)| Expr::eq(l.clone(), r.clone())) - .try_for_each(|e| f(&e))?; - - if let Some(filter) = filter.as_ref() { - f(filter) - } else { - Ok(()) - } + .for_each_till_continue(&mut |e| f(&e)) + .and_then_on_continue(|| filter.iter().for_each_till_continue(f)) } - LogicalPlan::Sort(Sort { expr, .. }) => expr.iter().try_for_each(f), + LogicalPlan::Sort(Sort { expr, .. }) => expr.iter().for_each_till_continue(f), LogicalPlan::Extension(extension) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs - extension.node.expressions().iter().try_for_each(f) + extension + .node + .expressions() + .iter() + .for_each_till_continue(f) } LogicalPlan::TableScan(TableScan { filters, .. }) => { - filters.iter().try_for_each(f) + filters.iter().for_each_till_continue(f) } LogicalPlan::Unnest(Unnest { column, .. 
}) => { f(&Expr::Column(column.clone())) @@ -378,7 +380,7 @@ impl LogicalPlan { .iter() .chain(select_expr.iter()) .chain(sort_expr.clone().unwrap_or(vec![]).iter()) - .try_for_each(f), + .for_each_till_continue(f), // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::Subquery(_) @@ -394,7 +396,7 @@ impl LogicalPlan { | LogicalPlan::Ddl(_) | LogicalPlan::Copy(_) | LogicalPlan::DescribeTable(_) - | LogicalPlan::Prepare(_) => Ok(()), + | LogicalPlan::Prepare(_) => Ok(TreeNodeRecursion::Continue), } } @@ -440,7 +442,7 @@ impl LogicalPlan { pub fn using_columns(&self) -> Result>, DataFusionError> { let mut using_columns: Vec> = vec![]; - self.apply(&mut |plan| { + self.visit_down(&mut |plan| { if let LogicalPlan::Join(Join { join_constraint: JoinConstraint::Using, on, @@ -456,7 +458,7 @@ impl LogicalPlan { })?; using_columns.push(columns); } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(using_columns) @@ -642,7 +644,7 @@ impl LogicalPlan { } LogicalPlan::Filter { .. } => { assert_eq!(1, expr.len()); - let predicate = expr.pop().unwrap(); + let mut predicate = expr.pop().unwrap(); // filter predicates should not contain aliased expressions so we remove any aliases // before this logic was added we would have aliases within filters such as for @@ -658,29 +660,39 @@ impl LogicalPlan { struct RemoveAliases {} - impl TreeNodeRewriter for RemoveAliases { - type N = Expr; + impl TreeNodeTransformer for RemoveAliases { + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn pre_transform( + &mut self, + expr: &mut Expr, + ) -> Result { match expr { Expr::Exists { .. 
} | Expr::ScalarSubquery(_) | Expr::InSubquery(_) => { // subqueries could contain aliases so we don't recurse into those - Ok(RewriteRecursion::Stop) + Ok(TreeNodeRecursion::Prune) } - Expr::Alias(_) => Ok(RewriteRecursion::Mutate), - _ => Ok(RewriteRecursion::Continue), + Expr::Alias(_) => { + expr.unalias(); + Ok(TreeNodeRecursion::Prune) + } + _ => Ok(TreeNodeRecursion::Continue), } } - fn mutate(&mut self, expr: Expr) -> Result { - Ok(expr.unalias()) + fn post_transform( + &mut self, + expr: &mut Expr, + ) -> Result { + expr.unalias(); + Ok(TreeNodeRecursion::Continue) } } let mut remove_aliases = RemoveAliases {}; - let predicate = predicate.rewrite(&mut remove_aliases)?; + predicate.transform(&mut remove_aliases)?; Filter::try_new(predicate, Arc::new(inputs[0].clone())) .map(LogicalPlan::Filter) @@ -754,10 +766,10 @@ impl LogicalPlan { // The first part of expr is equi-exprs, // and the struct of each equi-expr is like `left-expr = right-expr`. assert_eq!(expr.len(), equi_expr_count); - let new_on:Vec<(Expr,Expr)> = expr.into_iter().map(|equi_expr| { + let new_on:Vec<(Expr,Expr)> = expr.into_iter().map(|mut equi_expr| { // SimplifyExpression rule may add alias to the equi_expr. - let unalias_expr = equi_expr.clone().unalias(); - if let Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) = unalias_expr { + equi_expr.unalias(); + if let Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) = equi_expr { Ok((*left, *right)) } else { internal_err!( @@ -1126,59 +1138,27 @@ impl LogicalPlan { | LogicalPlan::Extension(_) => None, } } -} -impl LogicalPlan { - /// applies `op` to any subqueries in the plan - pub(crate) fn apply_subqueries(&self, op: &mut F) -> datafusion_common::Result<()> + /// Apply `f` on the root nodes of subquery plans of the plan node. + /// `f` is not allowed to return [`TreeNodeRecursion::Prune`]. 
+ pub fn apply_subqueries(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> datafusion_common::Result, + F: FnMut(&Self) -> Result, { - self.inspect_expressions(|expr| { - // recursively look for subqueries - inspect_expr_pre(expr, |expr| { - match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { - // use a synthetic plan so the collector sees a - // LogicalPlan::Subquery (even though it is - // actually a Subquery alias) - let synthetic_plan = LogicalPlan::Subquery(subquery.clone()); - synthetic_plan.apply(op)?; - } - _ => {} + self.apply_expressions(&mut |e| { + e.visit_down(&mut |e| match e { + Expr::Exists(Exists { subquery, .. }) + | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::ScalarSubquery(subquery) => { + // use a synthetic plan so the collector sees a + // LogicalPlan::Subquery (even though it is + // actually a Subquery alias) + let synthetic_plan = LogicalPlan::Subquery(subquery.clone()); + f(&synthetic_plan).fail_on_prune() } - Ok::<(), DataFusionError>(()) + _ => Ok(TreeNodeRecursion::Continue), }) - })?; - Ok(()) - } - - /// applies visitor to any subqueries in the plan - pub(crate) fn visit_subqueries(&self, v: &mut V) -> datafusion_common::Result<()> - where - V: TreeNodeVisitor, - { - self.inspect_expressions(|expr| { - // recursively look for subqueries - inspect_expr_pre(expr, |expr| { - match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. 
}) - | Expr::ScalarSubquery(subquery) => { - // use a synthetic plan so the visitor sees a - // LogicalPlan::Subquery (even though it is - // actually a Subquery alias) - let synthetic_plan = LogicalPlan::Subquery(subquery.clone()); - synthetic_plan.visit(v)?; - } - _ => {} - } - Ok::<(), DataFusionError>(()) - }) - })?; - Ok(()) + }) } /// Return a `LogicalPlan` with all placeholders (e.g $1 $2, @@ -1214,9 +1194,9 @@ impl LogicalPlan { ) -> Result>, DataFusionError> { let mut param_types: HashMap> = HashMap::new(); - self.apply(&mut |plan| { - plan.inspect_expressions(|expr| { - expr.apply(&mut |expr| { + self.visit_down(&mut |plan| { + plan.apply_expressions(&mut |expr| { + expr.visit_down(&mut |expr| { if let Expr::Placeholder(Placeholder { id, data_type }) = expr { let prev = param_types.get(id); match (prev, data_type) { @@ -1231,11 +1211,9 @@ impl LogicalPlan { _ => {} } } - Ok(VisitRecursion::Continue) - })?; - Ok::<(), DataFusionError>(()) - })?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) + }) + }) })?; Ok(param_types) @@ -1247,7 +1225,7 @@ impl LogicalPlan { expr: Expr, param_values: &ParamValues, ) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { match &expr { Expr::Placeholder(Placeholder { id, data_type }) => { let value = @@ -2762,9 +2740,9 @@ digraph { } impl TreeNodeVisitor for OkVisitor { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { let s = match plan { LogicalPlan::Projection { .. } => "pre_visit Projection", LogicalPlan::Filter { .. } => "pre_visit Filter", @@ -2775,10 +2753,10 @@ digraph { }; self.strings.push(s.into()); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn post_visit(&mut self, plan: &LogicalPlan) -> Result { let s = match plan { LogicalPlan::Projection { .. 
} => "post_visit Projection", LogicalPlan::Filter { .. } => "post_visit Filter", @@ -2789,7 +2767,7 @@ digraph { }; self.strings.push(s.into()); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } @@ -2845,20 +2823,20 @@ digraph { } impl TreeNodeVisitor for StoppingVisitor { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_false_from_pre_in.dec() { - return Ok(VisitRecursion::Stop); + return Ok(TreeNodeRecursion::Stop); } self.inner.pre_visit(plan)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn post_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_false_from_post_in.dec() { - return Ok(VisitRecursion::Stop); + return Ok(TreeNodeRecursion::Stop); } self.inner.post_visit(plan) @@ -2914,9 +2892,9 @@ digraph { } impl TreeNodeVisitor for ErrorVisitor { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_error_from_pre_in.dec() { return not_impl_err!("Error in pre_visit"); } @@ -2924,7 +2902,7 @@ digraph { self.inner.pre_visit(plan) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn post_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_error_from_post_in.dec() { return not_impl_err!("Error in post_visit"); } @@ -3217,7 +3195,7 @@ digraph { // after transformation, because plan is not the same anymore, // the parent plan is built again with call to LogicalPlan::with_new_inputs -> with_new_exprs let plan = plan - .transform(&|plan| match plan { + .transform_up(&|plan| match plan { LogicalPlan::TableScan(table) => { let filter = Filter::try_new( external_filter.clone(), diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 
1098842716b9..8ec4a94204b0 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -24,15 +24,17 @@ use crate::expr::{ }; use crate::{Expr, GetFieldAccess}; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{ + TreeNode, TreeNodeRecursion, TreeNodeRecursionResult, VisitRecursionIterator, +}; use datafusion_common::{internal_err, DataFusionError, Result}; impl TreeNode for Expr { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - let children = match self { + match self { Expr::Alias(Alias{expr,..}) | Expr::Not(expr) | Expr::IsNotNull(expr) @@ -47,30 +49,26 @@ impl TreeNode for Expr { | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) | Expr::Sort(Sort { expr, .. }) - | Expr::InSubquery(InSubquery{ expr, .. }) => vec![expr.as_ref().clone()], + | Expr::InSubquery(InSubquery{ expr, .. }) => f(expr), Expr::GetIndexedField(GetIndexedField { expr, field }) => { - let expr = expr.as_ref().clone(); - match field { + f(expr).and_then_on_continue(|| match field { GetFieldAccess::ListIndex {key} => { - vec![key.as_ref().clone(), expr] + f(key) }, - GetFieldAccess::ListRange {start, stop} => { - vec![start.as_ref().clone(), stop.as_ref().clone(), expr] - } - GetFieldAccess::NamedStructField {name: _name} => { - vec![expr] + GetFieldAccess::ListRange { start, stop} => { + f(start).and_then_on_continue(|| f(stop)) } - } + GetFieldAccess::NamedStructField { name: _name } => Ok(TreeNodeRecursion::Continue) + }) } Expr::GroupingSet(GroupingSet::Rollup(exprs)) - | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.clone(), - Expr::ScalarFunction (ScalarFunction{ args, .. } ) => { - args.clone() - } + | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.iter().for_each_till_continue(f), + Expr::ScalarFunction (ScalarFunction{ args, .. 
} ) => args.iter().for_each_till_continue(f), Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { - lists_of_exprs.clone().into_iter().flatten().collect() + lists_of_exprs.iter().flatten().for_each_till_continue(f) } - Expr::Column(_) + Expr::Nop + | Expr::Column(_) // Treat OuterReferenceColumn as a leaf expression | Expr::OuterReferenceColumn(_, _) | Expr::ScalarVariable(_, _) @@ -78,76 +76,43 @@ impl TreeNode for Expr { | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::Wildcard {..} - | Expr::Placeholder (_) => vec![], + | Expr::Placeholder (_) => Ok(TreeNodeRecursion::Continue), Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { - vec![left.as_ref().clone(), right.as_ref().clone()] + f(left) + .and_then_on_continue(|| f(right)) } Expr::Like(Like { expr, pattern, .. }) | Expr::SimilarTo(Like { expr, pattern, .. }) => { - vec![expr.as_ref().clone(), pattern.as_ref().clone()] + f(expr) + .and_then_on_continue(|| f(pattern)) } - Expr::Between(Between { - expr, low, high, .. - }) => vec![ - expr.as_ref().clone(), - low.as_ref().clone(), - high.as_ref().clone(), - ], - Expr::Case(case) => { - let mut expr_vec = vec![]; - if let Some(expr) = case.expr.as_ref() { - expr_vec.push(expr.as_ref().clone()); - }; - for (when, then) in case.when_then_expr.iter() { - expr_vec.push(when.as_ref().clone()); - expr_vec.push(then.as_ref().clone()); - } - if let Some(else_expr) = case.else_expr.as_ref() { - expr_vec.push(else_expr.as_ref().clone()); - } - expr_vec + Expr::Between(Between { expr, low, high, .. 
}) => { + f(expr) + .and_then_on_continue(|| f(low)) + .and_then_on_continue(|| f(high)) + }, + Expr::Case( Case { expr, when_then_expr, else_expr }) => { + expr.as_deref().into_iter().for_each_till_continue(f) + .and_then_on_continue(|| + when_then_expr.iter().for_each_till_continue(&mut |(w, t)| f(w).and_then_on_continue(|| f(t)))) + .and_then_on_continue(|| else_expr.as_deref().into_iter().for_each_till_continue(f)) } Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) => { - let mut expr_vec = args.clone(); - - if let Some(f) = filter { - expr_vec.push(f.as_ref().clone()); - } - if let Some(o) = order_by { - expr_vec.extend(o.clone()); - } - - expr_vec + args.iter().for_each_till_continue(f) + .and_then_on_continue(|| filter.as_deref().into_iter().for_each_till_continue(f)) + .and_then_on_continue(|| order_by.iter().flatten().for_each_till_continue(f)) } - Expr::WindowFunction(WindowFunction { - args, - partition_by, - order_by, - .. - }) => { - let mut expr_vec = args.clone(); - expr_vec.extend(partition_by.clone()); - expr_vec.extend(order_by.clone()); - expr_vec + Expr::WindowFunction(WindowFunction { args, partition_by, order_by, .. }) => { + args.iter().for_each_till_continue(f) + .and_then_on_continue(|| partition_by.iter().for_each_till_continue(f)) + .and_then_on_continue(|| order_by.iter().for_each_till_continue(f)) } Expr::InList(InList { expr, list, .. }) => { - let mut expr_vec = vec![]; - expr_vec.push(expr.as_ref().clone()); - expr_vec.extend(list.clone()); - expr_vec - } - }; - - for child in children.iter() { - match op(child)? 
{ - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + f(expr) + .and_then_on_continue(|| list.iter().for_each_till_continue(f)) } } - - Ok(VisitRecursion::Continue) } fn map_children(self, transform: F) -> Result @@ -157,6 +122,7 @@ impl TreeNode for Expr { let mut transform = transform; Ok(match self { + Expr::Nop => self, Expr::Alias(Alias { expr, relation, @@ -376,6 +342,90 @@ impl TreeNode for Expr { } }) } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + match self { + Expr::Alias(Alias { expr,.. }) + | Expr::Not(expr) + | Expr::IsNotNull(expr) + | Expr::IsTrue(expr) + | Expr::IsFalse(expr) + | Expr::IsUnknown(expr) + | Expr::IsNotTrue(expr) + | Expr::IsNotFalse(expr) + | Expr::IsNotUnknown(expr) + | Expr::IsNull(expr) + | Expr::Negative(expr) + | Expr::Cast(Cast { expr, .. }) + | Expr::TryCast(TryCast { expr, .. }) + | Expr::Sort(Sort { expr, .. }) + | Expr::InSubquery(InSubquery{ expr, .. }) => f(expr), + Expr::GetIndexedField(GetIndexedField { expr, field }) => { + f(expr).and_then_on_continue(|| match field { + GetFieldAccess::ListIndex {key} => { + f(key) + }, + GetFieldAccess::ListRange { start, stop} => { + f(start).and_then_on_continue(|| f(stop)) + } + GetFieldAccess::NamedStructField { name: _name } => Ok(TreeNodeRecursion::Continue) + }) + } + Expr::GroupingSet(GroupingSet::Rollup(exprs)) + | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.iter_mut().for_each_till_continue(f), + | Expr::ScalarFunction(ScalarFunction{ args, .. 
}) => args.iter_mut().for_each_till_continue(f), + Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { + lists_of_exprs.iter_mut().flatten().for_each_till_continue(f) + } + Expr::Nop + | Expr::Column(_) + // Treat OuterReferenceColumn as a leaf expression + | Expr::OuterReferenceColumn(_, _) + | Expr::ScalarVariable(_, _) + | Expr::Literal(_) + | Expr::Exists { .. } + | Expr::ScalarSubquery(_) + | Expr::Wildcard {..} + | Expr::Placeholder (_) => Ok(TreeNodeRecursion::Continue), + Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { + f(left) + .and_then_on_continue(|| f(right)) + } + Expr::Like(Like { expr, pattern, .. }) + | Expr::SimilarTo(Like { expr, pattern, .. }) => { + f(expr) + .and_then_on_continue(|| f(pattern)) + } + Expr::Between(Between { expr, low, high, .. }) => { + f(expr) + .and_then_on_continue(|| f(low)) + .and_then_on_continue(|| f(high)) + }, + Expr::Case( Case { expr, when_then_expr, else_expr }) => { + expr.as_deref_mut().into_iter().for_each_till_continue(f) + .and_then_on_continue(|| + when_then_expr.iter_mut().for_each_till_continue(&mut |(w, t)| f(w).and_then_on_continue(|| f(t)))) + .and_then_on_continue(|| else_expr.as_deref_mut().into_iter().for_each_till_continue(f)) + } + Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) => { + args.iter_mut().for_each_till_continue(f) + .and_then_on_continue(|| filter.as_deref_mut().into_iter().for_each_till_continue(f)) + .and_then_on_continue(|| order_by.iter_mut().flatten().for_each_till_continue(f)) + } + Expr::WindowFunction(WindowFunction { args, partition_by, order_by, .. }) => { + args.iter_mut().for_each_till_continue(f) + .and_then_on_continue(|| partition_by.iter_mut().for_each_till_continue(f)) + .and_then_on_continue(|| order_by.iter_mut().for_each_till_continue(f)) + } + Expr::InList(InList { expr, list, .. 
}) => { + f(expr) + .and_then_on_continue(|| list.iter_mut().for_each_till_continue(f)) + } + } + } } fn transform_boxed(boxed_expr: Box, transform: &mut F) -> Result> diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index c7621bc17833..e85294ea5f73 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -18,92 +18,22 @@ //! Tree node implementation for logical plan use crate::LogicalPlan; -use datafusion_common::tree_node::{TreeNodeVisitor, VisitRecursion}; -use datafusion_common::{tree_node::TreeNode, Result}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, VisitRecursionIterator}; +use datafusion_common::Result; impl TreeNode for LogicalPlan { - fn apply(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - // Note, - // - // Compared to the default implementation, we need to invoke - // [`Self::apply_subqueries`] before visiting its children - match op(self)? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - }; - - self.apply_subqueries(op)?; - - self.apply_children(&mut |node| node.apply(op)) - } - - /// To use, define a struct that implements the trait [`TreeNodeVisitor`] and then invoke - /// [`LogicalPlan::visit`]. 
- /// - /// For example, for a logical plan like: - /// - /// ```text - /// Projection: id - /// Filter: state Eq Utf8(\"CO\")\ - /// CsvScan: employee.csv projection=Some([0, 3])"; - /// ``` - /// - /// The sequence of visit operations would be: - /// ```text - /// visitor.pre_visit(Projection) - /// visitor.pre_visit(Filter) - /// visitor.pre_visit(CsvScan) - /// visitor.post_visit(CsvScan) - /// visitor.post_visit(Filter) - /// visitor.post_visit(Projection) - /// ``` - fn visit>( - &self, - visitor: &mut V, - ) -> Result { - // Compared to the default implementation, we need to invoke - // [`Self::visit_subqueries`] before visiting its children - - match visitor.pre_visit(self)? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - }; - - self.visit_subqueries(visitor)?; - - match self.apply_children(&mut |node| node.visit(visitor))? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - - visitor.post_visit(self) + self.inputs().into_iter().for_each_till_continue(f) } - fn apply_children(&self, op: &mut F) -> Result + fn apply_inner_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - for child in self.inputs() { - match op(child)? 
{ - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + self.apply_subqueries(f) } fn map_children(self, transform: F) -> Result @@ -128,4 +58,24 @@ impl TreeNode for LogicalPlan { Ok(self) } } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + let old_children = self.inputs(); + let mut new_children = + old_children.iter().map(|&c| c.clone()).collect::>(); + let tnr = new_children.iter_mut().for_each_till_continue(f)?; + + // if any changes made, make a new child + if old_children + .iter() + .zip(new_children.iter()) + .any(|(c1, c2)| c1 != &c2) + { + *self = self.with_new_inputs(new_children.as_slice())?; + } + Ok(tnr) + } } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index abdd7f5f57f6..9d0daa5f4ca2 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -31,7 +31,7 @@ use crate::{ }; use arrow::datatypes::{DataType, TimeUnit}; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{ internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, @@ -261,15 +261,16 @@ pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { /// Recursively walk an expression tree, collecting the unique set of columns /// referenced in the expression pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { - inspect_expr_pre(expr, |expr| { - match expr { + expr.visit_down(&mut |e| { + match e { Expr::Column(qc) => { accum.insert(qc.clone()); } // Use explicit pattern match instead of a default // implementation, so that in the future if someone adds // new Expr types, they will check here as well - Expr::ScalarVariable(_, _) + Expr::Nop + 
| Expr::ScalarVariable(_, _) | Expr::Alias(_) | Expr::Literal(_) | Expr::BinaryExpr { .. } @@ -303,8 +304,9 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::Placeholder(_) | Expr::OuterReferenceColumn { .. } => {} } - Ok(()) + Ok(TreeNodeRecursion::Continue) }) + .map(|_| ()) } /// Find excluded columns in the schema, if any @@ -655,44 +657,22 @@ where F: Fn(&Expr) -> bool, { let mut exprs = vec![]; - expr.apply(&mut |expr| { + expr.visit_down(&mut |expr| { if test_fn(expr) { if !(exprs.contains(expr)) { exprs.push(expr.clone()) } // stop recursing down this expr once we find a match - return Ok(VisitRecursion::Skip); + return Ok(TreeNodeRecursion::Prune); } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // pre_visit always returns OK, so this will always too .expect("no way to return error during recursion"); exprs } -/// Recursively inspect an [`Expr`] and all its children. -pub fn inspect_expr_pre(expr: &Expr, mut f: F) -> Result<(), E> -where - F: FnMut(&Expr) -> Result<(), E>, -{ - let mut err = Ok(()); - expr.apply(&mut |expr| { - if let Err(e) = f(expr) { - // save the error for later (it may not be a DataFusionError - err = Err(e); - Ok(VisitRecursion::Stop) - } else { - // keep going - Ok(VisitRecursion::Continue) - } - }) - // The closure always returns OK, so this will always too - .expect("no way to return error during recursion"); - - err -} - /// Returns a new logical plan based on the original one with inputs /// and expressions replaced. 
/// @@ -825,17 +805,14 @@ pub fn find_column_exprs(exprs: &[Expr]) -> Vec { .collect() } -pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec { - let mut exprs = vec![]; - inspect_expr_pre(e, |expr| { - if let Expr::Column(c) = expr { - exprs.push(c.clone()) +pub(crate) fn find_columns_referenced_by_expr(expr: &Expr) -> Vec { + expr.collect(&mut |e| { + if let Expr::Column(c) = e { + vec![c.clone()] + } else { + vec![] } - Ok(()) as Result<()> }) - // As the closure always returns Ok, this "can't" error - .expect("Unexpected error"); - exprs } /// Convert any `Expr` to an `Expr::Column`. @@ -852,26 +829,16 @@ pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result { /// Recursively walk an expression tree, collecting the column indexes /// referenced in the expression pub(crate) fn find_column_indexes_referenced_by_expr( - e: &Expr, + expr: &Expr, schema: &DFSchemaRef, ) -> Vec { - let mut indexes = vec![]; - inspect_expr_pre(e, |expr| { - match expr { - Expr::Column(qc) => { - if let Ok(idx) = schema.index_of_column(qc) { - indexes.push(idx); - } - } - Expr::Literal(_) => { - indexes.push(std::usize::MAX); - } - _ => {} + expr.collect(&mut |e| match e { + Expr::Column(qc) => schema.index_of_column(qc).into_iter().collect(), + Expr::Literal(_) => { + vec![std::usize::MAX] } - Ok(()) as Result<()> + _ => vec![], }) - .unwrap(); - indexes } /// can this data type be used in hash join equal conditions?? 
diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index fd84bb80160b..17b1ad8cc73f 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -17,14 +17,18 @@ use crate::analyzer::AnalyzerRule; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, TreeNodeTransformer, +}; use datafusion_common::Result; -use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, InSubquery}; +use datafusion_expr::expr::{ + AggregateFunction, AggregateFunctionDefinition, Exists, InSubquery, WindowFunction, +}; use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::Expr::ScalarSubquery; use datafusion_expr::{ - aggregate_function, expr, lit, window_function, Aggregate, Expr, Filter, LogicalPlan, + aggregate_function, lit, window_function, Aggregate, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Projection, Sort, Subquery, }; use std::sync::Arc; @@ -114,108 +118,69 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { struct CountWildcardRewriter {} -impl TreeNodeRewriter for CountWildcardRewriter { - type N = Expr; +impl TreeNodeTransformer for CountWildcardRewriter { + type Node = Expr; + + fn pre_transform(&mut self, _node: &mut Self::Node) -> Result { + Ok(TreeNodeRecursion::Continue) + } - fn mutate(&mut self, old_expr: Expr) -> Result { - let new_expr = match old_expr.clone() { - Expr::WindowFunction(expr::WindowFunction { + fn post_transform(&mut self, expr: &mut Expr) -> Result { + match expr { + Expr::WindowFunction(WindowFunction { fun: window_function::WindowFunction::AggregateFunction( aggregate_function::AggregateFunction::Count, ), args, - partition_by, - order_by, - 
window_frame, - }) if args.len() == 1 => match args[0] { - Expr::Wildcard { qualifier: None } => { - Expr::WindowFunction(expr::WindowFunction { - fun: window_function::WindowFunction::AggregateFunction( - aggregate_function::AggregateFunction::Count, - ), - args: vec![lit(COUNT_STAR_EXPANSION)], - partition_by, - order_by, - window_frame, - }) + .. + }) if args.len() == 1 => { + if let Expr::Wildcard { qualifier: None } = args[0] { + args[0] = lit(COUNT_STAR_EXPANSION) } - - _ => old_expr, - }, + } Expr::AggregateFunction(AggregateFunction { func_def: AggregateFunctionDefinition::BuiltIn( aggregate_function::AggregateFunction::Count, ), args, - distinct, - filter, - order_by, - }) if args.len() == 1 => match args[0] { - Expr::Wildcard { qualifier: None } => { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Count, - vec![lit(COUNT_STAR_EXPANSION)], - distinct, - filter, - order_by, - )) + .. + }) if args.len() == 1 => { + if let Expr::Wildcard { qualifier: None } = args[0] { + args[0] = lit(COUNT_STAR_EXPANSION) } - _ => old_expr, - }, - - ScalarSubquery(Subquery { - subquery, - outer_ref_columns, - }) => { + } + ScalarSubquery(Subquery { subquery, .. }) => { let new_plan = subquery .as_ref() .clone() .transform_down(&analyze_internal)?; - ScalarSubquery(Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns, - }) + *subquery = Arc::new(new_plan); } Expr::InSubquery(InSubquery { - expr, - subquery, - negated, + subquery: Subquery { subquery, .. }, + .. }) => { let new_plan = subquery - .subquery .as_ref() .clone() .transform_down(&analyze_internal)?; - - Expr::InSubquery(InSubquery::new( - expr, - Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - )) + *subquery = Arc::new(new_plan); } - Expr::Exists(expr::Exists { subquery, negated }) => { + Expr::Exists(Exists { + subquery: Subquery { subquery, .. }, + .. 
+ }) => { let new_plan = subquery - .subquery .as_ref() .clone() .transform_down(&analyze_internal)?; - - Expr::Exists(expr::Exists { - subquery: Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - }) + *subquery = Arc::new(new_plan); } - _ => old_expr, + _ => {} }; - Ok(new_expr) + Ok(TreeNodeRecursion::Continue) } } diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index 90af7aec8293..a418fbf5537b 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -74,7 +74,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { Transformed::Yes(plan) } LogicalPlan::Filter(filter) => { - let new_expr = filter.predicate.transform(&rewrite_subquery)?; + let new_expr = filter.predicate.transform_up(&rewrite_subquery)?; Transformed::Yes(LogicalPlan::Filter(Filter::try_new( new_expr, filter.input, diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 14d5ddf47378..0b2c20db3957 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -27,11 +27,10 @@ use crate::analyzer::subquery::check_subquery_expr; use crate::analyzer::type_coercion::TypeCoercion; use crate::utils::log_plan; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::expr::Exists; use datafusion_expr::expr::InSubquery; -use datafusion_expr::utils::inspect_expr_pre; use datafusion_expr::{Expr, LogicalPlan}; use log::debug; use std::sync::Arc; @@ -117,21 +116,21 @@ impl Analyzer { /// Do necessary check and fail the invalid plan fn check_plan(plan: &LogicalPlan) -> Result<()> { - plan.apply(&mut |plan: &LogicalPlan| { - for expr 
in plan.expressions().iter() { + plan.visit_down(&mut |plan: &LogicalPlan| { + plan.apply_expressions(&mut |e| { // recursively look for subqueries - inspect_expr_pre(expr, |expr| match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { - check_subquery_expr(plan, &subquery.subquery, expr) + e.visit_down(&mut |e| { + match e { + Expr::Exists(Exists { subquery, .. }) + | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::ScalarSubquery(subquery) => { + check_subquery_expr(plan, &subquery.subquery, e)? + } + _ => {} } - _ => Ok(()), - })?; - } - - Ok(VisitRecursion::Continue) - })?; - - Ok(()) + Ok(TreeNodeRecursion::Continue) + }) + }) + }) + .map(|_| ()) } diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index 7c5b70b19af0..78c630982d9b 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -17,7 +17,7 @@ use crate::analyzer::check_plan; use crate::utils::collect_subquery_cols; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::expr_rewriter::strip_outer_reference; use datafusion_expr::utils::split_conjunction; @@ -146,7 +146,7 @@ fn check_inner_plan( LogicalPlan::Aggregate(_) => { inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, true, can_contain_outer_ref)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -171,7 +171,7 @@ fn check_inner_plan( check_mixed_out_refer_in_window(window)?; inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -188,7 +188,7 @@ fn check_inner_plan( | LogicalPlan::SubqueryAlias(_) 
=> { inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -206,7 +206,7 @@ fn check_inner_plan( is_aggregate, can_contain_outer_ref, )?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -221,7 +221,7 @@ fn check_inner_plan( JoinType::Full => { inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, false)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -281,7 +281,7 @@ fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan { fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result> { let mut exprs = vec![]; - inner_plan.apply(&mut |plan| { + inner_plan.visit_down(&mut |plan| { if let LogicalPlan::Filter(Filter { predicate, .. }) = plan { let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate) .into_iter() @@ -290,9 +290,9 @@ fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result> { correlated .into_iter() .for_each(|expr| exprs.push(strip_outer_reference(expr.clone()))); - return Ok(VisitRecursion::Continue); + return Ok(TreeNodeRecursion::Continue); } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(exprs) } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index c5e1180b9f97..3d8526bb32b3 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -17,12 +17,13 @@ //! 
Optimizer rule for type validation and coercion +use std::mem; use std::sync::Arc; use arrow::datatypes::{DataType, IntervalUnit}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; +use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeTransformer}; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -44,7 +45,6 @@ use datafusion_expr::type_coercion::other::{ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ - is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, type_coercion, window_function, AggregateFunction, BuiltinScalarFunction, Expr, ExprSchemable, LogicalPlan, Operator, Projection, ScalarFunctionDefinition, Signature, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -125,40 +125,24 @@ pub(crate) struct TypeCoercionRewriter { pub(crate) schema: DFSchemaRef, } -impl TreeNodeRewriter for TypeCoercionRewriter { - type N = Expr; +impl TreeNodeTransformer for TypeCoercionRewriter { + type Node = Expr; - fn pre_visit(&mut self, _expr: &Expr) -> Result { - Ok(RewriteRecursion::Continue) + fn pre_transform(&mut self, _expr: &mut Expr) -> Result { + Ok(TreeNodeRecursion::Continue) } - fn mutate(&mut self, expr: Expr) -> Result { + fn post_transform(&mut self, expr: &mut Expr) -> Result { match expr { - Expr::ScalarSubquery(Subquery { - subquery, - outer_ref_columns, + Expr::ScalarSubquery(Subquery { subquery, .. }) + | Expr::Exists(Exists { + subquery: Subquery { subquery, .. }, + .. 
}) => { - let new_plan = analyze_internal(&self.schema, &subquery)?; - Ok(Expr::ScalarSubquery(Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns, - })) - } - Expr::Exists(Exists { subquery, negated }) => { - let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; - Ok(Expr::Exists(Exists { - subquery: Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - })) + let new_plan = analyze_internal(&self.schema, subquery)?; + *subquery = Arc::new(new_plan); } - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => { + Expr::InSubquery(InSubquery { expr, subquery, .. }) => { let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; let expr_type = expr.get_type(&self.schema)?; let subquery_type = new_plan.schema().field(0).data_type(); @@ -166,53 +150,31 @@ impl TreeNodeRewriter for TypeCoercionRewriter { "expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery" ), )?; + **expr = mem::take(expr.as_mut()).cast_to(&common_type, &self.schema)?; let new_subquery = Subquery { subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, + outer_ref_columns: mem::take(&mut subquery.outer_ref_columns), }; - Ok(Expr::InSubquery(InSubquery::new( - Box::new(expr.cast_to(&common_type, &self.schema)?), - cast_subquery(new_subquery, &common_type)?, - negated, - ))) - } - Expr::IsTrue(expr) => { - let expr = is_true(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsNotTrue(expr) => { - let expr = is_not_true(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsFalse(expr) => { - let expr = is_false(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) + *subquery = cast_subquery(new_subquery, &common_type)?; } - Expr::IsNotFalse(expr) => { - let expr = - is_not_false(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsUnknown(expr) => { - let expr = 
is_unknown(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsNotUnknown(expr) => { - let expr = - is_not_unknown(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) + Expr::IsTrue(expr) + | Expr::IsNotTrue(expr) + | Expr::IsFalse(expr) + | Expr::IsNotFalse(expr) + | Expr::IsUnknown(expr) + | Expr::IsNotUnknown(expr) => { + **expr = get_casted_expr_for_bool_op(expr, &self.schema)? } Expr::Like(Like { - negated, expr, pattern, - escape_char, case_insensitive, + .. }) => { let left_type = expr.get_type(&self.schema)?; let right_type = pattern.get_type(&self.schema)?; let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| { - let op_name = if case_insensitive { + let op_name = if *case_insensitive { "ILIKE" } else { "LIKE" @@ -221,35 +183,21 @@ impl TreeNodeRewriter for TypeCoercionRewriter { "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression" ) })?; - let expr = Box::new(expr.cast_to(&coerced_type, &self.schema)?); - let pattern = Box::new(pattern.cast_to(&coerced_type, &self.schema)?); - let expr = Expr::Like(Like::new( - negated, - expr, - pattern, - escape_char, - case_insensitive, - )); - Ok(expr) + **expr = mem::take(expr.as_mut()).cast_to(&coerced_type, &self.schema)?; + **pattern = + mem::take(pattern.as_mut()).cast_to(&coerced_type, &self.schema)?; } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { let (left_type, right_type) = get_input_types( &left.get_type(&self.schema)?, - &op, + op, &right.get_type(&self.schema)?, )?; - - Ok(Expr::BinaryExpr(BinaryExpr::new( - Box::new(left.cast_to(&left_type, &self.schema)?), - op, - Box::new(right.cast_to(&right_type, &self.schema)?), - ))) + **left = mem::take(left.as_mut()).cast_to(&left_type, &self.schema)?; + **right = mem::take(right.as_mut()).cast_to(&right_type, &self.schema)?; } Expr::Between(Between { - expr, - negated, - low, - high, + expr, low, high, .. 
}) => { let expr_type = expr.get_type(&self.schema)?; let low_type = low.get_type(&self.schema)?; @@ -273,19 +221,13 @@ impl TreeNodeRewriter for TypeCoercionRewriter { "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression" )) })?; - let expr = Expr::Between(Between::new( - Box::new(expr.cast_to(&coercion_type, &self.schema)?), - negated, - Box::new(low.cast_to(&coercion_type, &self.schema)?), - Box::new(high.cast_to(&coercion_type, &self.schema)?), - )); - Ok(expr) + **expr = + mem::take(expr.as_mut()).cast_to(&coercion_type, &self.schema)?; + **low = mem::take(low.as_mut()).cast_to(&coercion_type, &self.schema)?; + **high = + mem::take(high.as_mut()).cast_to(&coercion_type, &self.schema)?; } - Expr::InList(InList { - expr, - list, - negated, - }) => { + Expr::InList(InList { expr, list, .. }) => { let expr_data_type = expr.get_type(&self.schema)?; let list_data_types = list .iter() @@ -296,28 +238,21 @@ impl TreeNodeRewriter for TypeCoercionRewriter { match result_type { None => plan_err!( "Can not find compatible types to compare {expr_data_type:?} with {list_data_types:?}" - ), + )?, Some(coerced_type) => { // find the coerced type - let cast_expr = expr.cast_to(&coerced_type, &self.schema)?; - let cast_list_expr = list - .into_iter() - .map(|list_expr| { - list_expr.cast_to(&coerced_type, &self.schema) - }) - .collect::>>()?; - let expr = Expr::InList(InList ::new( - Box::new(cast_expr), - cast_list_expr, - negated, - )); - Ok(expr) + **expr = mem::take(expr.as_mut()).cast_to(&coerced_type, &self.schema)?; + list.iter_mut() + .try_for_each(|list_expr| { + mem::take(list_expr).cast_to(&coerced_type, &self.schema).map(|r| *list_expr = r) + })?; } } } - Expr::Case(case) => { - let case = coerce_case_expression(case, &self.schema)?; - Ok(Expr::Case(case)) + Expr::Case(_) => { + if let Expr::Case(case) = mem::take(expr) { + *expr = Expr::Case(coerce_case_expression(case, &self.schema)?); + } } Expr::ScalarFunction(ScalarFunction { func_def, args 
}) => match func_def { ScalarFunctionDefinition::BuiltIn(fun) => { @@ -326,12 +261,9 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, &fun.signature(), )?; - let new_args = coerce_arguments_for_fun( - new_args.as_slice(), - &self.schema, - &fun, - )?; - Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) + let new_args = + coerce_arguments_for_fun(new_args.as_slice(), &self.schema, fun)?; + *args = new_args } ScalarFunctionDefinition::UDF(fun) => { let new_expr = coerce_arguments_for_signature( @@ -339,30 +271,23 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, fun.signature(), )?; - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_expr))) + *args = new_expr } ScalarFunctionDefinition::Name(_) => { - internal_err!("Function `Expr` with name should be resolved.") + internal_err!("Function `Expr` with name should be resolved.")? } }, Expr::AggregateFunction(expr::AggregateFunction { - func_def, - args, - distinct, - filter, - order_by, + func_def, args, .. }) => match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { let new_expr = coerce_agg_exprs_for_signature( - &fun, - &args, + fun, + args, &self.schema, &fun.signature(), )?; - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, new_expr, distinct, filter, order_by, - )); - Ok(expr) + *args = new_expr } AggregateFunctionDefinition::UDF(fun) => { let new_expr = coerce_arguments_for_signature( @@ -370,48 +295,47 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, fun.signature(), )?; - let expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( - fun, new_expr, false, filter, order_by, - )); - Ok(expr) + *args = new_expr } AggregateFunctionDefinition::Name(_) => { - internal_err!("Function `Expr` with name should be resolved.") + internal_err!("Function `Expr` with name should be resolved.")? 
} }, - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - }) => { - let window_frame = - coerce_window_frame(window_frame, &self.schema, &order_by)?; - - let args = match &fun { - window_function::WindowFunction::AggregateFunction(fun) => { - coerce_agg_exprs_for_signature( - fun, - &args, - &self.schema, - &fun.signature(), - )? - } - _ => args, - }; - - let expr = Expr::WindowFunction(WindowFunction::new( + Expr::WindowFunction(_) => { + if let Expr::WindowFunction(WindowFunction { fun, args, partition_by, order_by, window_frame, - )); - Ok(expr) + .. + }) = mem::take(expr) + { + let window_frame = + coerce_window_frame(window_frame, &self.schema, &order_by)?; + let args = match &fun { + window_function::WindowFunction::AggregateFunction(fun) => { + coerce_agg_exprs_for_signature( + fun, + &args, + &self.schema, + &fun.signature(), + )? + } + _ => args, + }; + *expr = Expr::WindowFunction(WindowFunction::new( + fun, + args, + partition_by, + order_by, + window_frame, + )); + } } - expr => Ok(expr), + _ => {} } + Ok(TreeNodeRecursion::Continue) } } @@ -1225,7 +1149,7 @@ mod test { None, ), ))); - let expr = Expr::ScalarFunction(ScalarFunction::new( + let mut expr = Expr::ScalarFunction(ScalarFunction::new( BuiltinScalarFunction::MakeArray, vec![val.clone()], )); @@ -1240,8 +1164,8 @@ mod test { )], std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema }; - let result = expr.rewrite(&mut rewriter)?; + let mut transformer = TypeCoercionRewriter { schema }; + expr.transform(&mut transformer)?; let schema = Arc::new(DFSchema::new_with_metadata( vec![DFField::new_unqualified( @@ -1262,7 +1186,7 @@ mod test { vec![expected_casted_expr], )); - assert_eq!(result, expected); + assert_eq!(expr, expected); Ok(()) } @@ -1273,33 +1197,33 @@ mod test { vec![DFField::new_unqualified("a", DataType::Int64, true)], std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { 
schema }; - let expr = is_true(lit(12i32).gt(lit(13i64))); + let mut transformer = TypeCoercionRewriter { schema }; + let mut expr = is_true(lit(12i32).gt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; - assert_eq!(expected, result); + expr.transform(&mut transformer)?; + assert_eq!(expected, expr); // eq let schema = Arc::new(DFSchema::new_with_metadata( vec![DFField::new_unqualified("a", DataType::Int64, true)], std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema }; - let expr = is_true(lit(12i32).eq(lit(13i64))); + let mut transformer = TypeCoercionRewriter { schema }; + let mut expr = is_true(lit(12i32).eq(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; - assert_eq!(expected, result); + expr.transform(&mut transformer)?; + assert_eq!(expected, expr); // lt let schema = Arc::new(DFSchema::new_with_metadata( vec![DFField::new_unqualified("a", DataType::Int64, true)], std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema }; - let expr = is_true(lit(12i32).lt(lit(13i64))); + let mut transfomer = TypeCoercionRewriter { schema }; + let mut expr = is_true(lit(12i32).lt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; - assert_eq!(expected, result); + expr.transform(&mut transfomer)?; + assert_eq!(expected, expr); Ok(()) } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 1e089257c61a..d6cad22eb7e2 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -24,7 +24,7 @@ use crate::{utils, OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{ - 
RewriteRecursion, TreeNode, TreeNodeRewriter, TreeNodeVisitor, VisitRecursion, + RewriteRecursion, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{ internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, @@ -612,18 +612,18 @@ impl ExprIdentifierVisitor<'_> { } impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, _expr: &Expr) -> Result { + fn pre_visit(&mut self, _expr: &Expr) -> Result { self.visit_stack .push(VisitRecord::EnterMark(self.node_count)); self.node_count += 1; // put placeholder self.id_array.push((0, "".to_string())); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, expr: &Expr) -> Result { + fn post_visit(&mut self, expr: &Expr) -> Result { self.series_number += 1; let (idx, sub_expr_desc) = self.pop_enter_mark(); @@ -632,7 +632,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { self.id_array[idx].0 = self.series_number; let desc = Self::desc_expr(expr); self.visit_stack.push(VisitRecord::ExprItem(desc)); - return Ok(VisitRecursion::Continue); + return Ok(TreeNodeRecursion::Continue); } let mut desc = Self::desc_expr(expr); desc.push_str(&sub_expr_desc); @@ -646,7 +646,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { .entry(desc) .or_insert_with(|| (expr.clone(), 0, data_type)) .1 += 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index b1000f042c98..a68f374d9fe6 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -370,7 +370,7 @@ fn agg_exprs_evaluation_result_on_empty_batch( expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result<()> { for e in agg_expr.iter() { - let result_expr = e.clone().transform_up(&|expr| { + let mut result_expr = e.clone().transform_up(&|expr| { let new_expr = 
match expr { Expr::AggregateFunction(expr::AggregateFunction { func_def, .. }) => { match func_def { @@ -396,7 +396,7 @@ fn agg_exprs_evaluation_result_on_empty_batch( Ok(new_expr) })?; - let result_expr = result_expr.unalias(); + result_expr.unalias(); let props = ExecutionProps::new(); let info = SimplifyContext::new(&props).with_schema(schema.clone()); let simplifier = ExprSimplifier::new(info); diff --git a/datafusion/optimizer/src/plan_signature.rs b/datafusion/optimizer/src/plan_signature.rs index 07f495a7262d..97a56f85ef96 100644 --- a/datafusion/optimizer/src/plan_signature.rs +++ b/datafusion/optimizer/src/plan_signature.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use std::{ collections::hash_map::DefaultHasher, hash::{Hash, Hasher}, @@ -73,9 +73,9 @@ impl LogicalPlanSignature { /// Get total number of [`LogicalPlan`]s in the plan. 
fn get_node_number(plan: &LogicalPlan) -> NonZeroUsize { let mut node_number = 0; - plan.apply(&mut |_plan| { + plan.visit_down(&mut |_plan| { node_number += 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // Closure always return Ok .unwrap(); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 4bea17500acc..cb9c06154bad 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -17,7 +17,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::{ internal_err, plan_datafusion_err, Column, DFSchema, DataFusionError, Result, }; @@ -213,11 +213,11 @@ fn can_pushdown_join_predicate(predicate: &Expr, schema: &DFSchema) -> Result Result { let mut is_evaluate = true; - predicate.apply(&mut |expr| match expr { + predicate.visit_down(&mut |expr| match expr { Expr::Column(_) | Expr::Literal(_) | Expr::Placeholder(_) - | Expr::ScalarVariable(_, _) => Ok(VisitRecursion::Skip), + | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Prune), Expr::Exists { .. } | Expr::InSubquery(_) | Expr::ScalarSubquery(_) @@ -227,7 +227,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { .. }) => { is_evaluate = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } Expr::Alias(_) | Expr::BinaryExpr(_) @@ -249,8 +249,9 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::Cast(_) | Expr::TryCast(_) | Expr::ScalarFunction(..) - | Expr::InList { .. } => Ok(VisitRecursion::Continue), + | Expr::InList { .. } => Ok(TreeNodeRecursion::Continue), Expr::Sort(_) + | Expr::Nop | Expr::AggregateFunction(_) | Expr::WindowFunction(_) | Expr::Wildcard { .. 
} @@ -975,29 +976,29 @@ pub fn replace_cols_by_name( /// check whether the expression is volatile predicates fn is_volatile_expression(e: &Expr) -> bool { let mut is_volatile = false; - e.apply(&mut |expr| { + e.visit_down(&mut |expr| { Ok(match expr { Expr::ScalarFunction(f) => match &f.func_def { ScalarFunctionDefinition::BuiltIn(fun) if fun.volatility() == Volatility::Volatile => { is_volatile = true; - VisitRecursion::Stop + TreeNodeRecursion::Stop } ScalarFunctionDefinition::UDF(fun) if fun.signature().volatility == Volatility::Volatile => { is_volatile = true; - VisitRecursion::Stop + TreeNodeRecursion::Stop } ScalarFunctionDefinition::Name(_) => { return internal_err!( "Function `Expr` with name should be resolved." ); } - _ => VisitRecursion::Continue, + _ => TreeNodeRecursion::Continue, }, - _ => VisitRecursion::Continue, + _ => TreeNodeRecursion::Continue, }) }) .unwrap(); @@ -1007,17 +1008,17 @@ fn is_volatile_expression(e: &Expr) -> bool { /// check whether the expression uses the columns in `check_map`. fn contain(e: &Expr, check_map: &HashMap) -> bool { let mut is_contain = false; - e.apply(&mut |expr| { + e.visit_down(&mut |expr| { Ok(if let Expr::Column(c) = &expr { match check_map.get(&c.flat_name()) { Some(_) => { is_contain = true; - VisitRecursion::Stop + TreeNodeRecursion::Stop } - None => VisitRecursion::Continue, + None => TreeNodeRecursion::Continue, } } else { - VisitRecursion::Continue + TreeNodeRecursion::Continue }) }) .unwrap(); diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index e2fbd5e927a1..fe2e1345290b 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -158,10 +158,11 @@ impl ExprSimplifier { // rather than creating an DFSchemaRef coerces rather than doing // it manually. 
// https://github.com/apache/arrow-datafusion/issues/3793 - pub fn coerce(&self, expr: Expr, schema: DFSchemaRef) -> Result { + pub fn coerce(&self, mut expr: Expr, schema: DFSchemaRef) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.rewrite(&mut expr_rewrite) + expr.transform(&mut expr_rewrite)?; + Ok(expr) } /// Input guarantees about the values of columns. @@ -330,7 +331,8 @@ impl<'a> ConstEvaluator<'a> { // at plan time match expr { // Has no runtime cost, but needed during planning - Expr::Alias(..) + Expr::Nop + | Expr::Alias(..) | Expr::AggregateFunction { .. } | Expr::ScalarVariable(_, _) | Expr::Column(_) diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 91603e82a54f..dfe4d1fa9ab8 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -24,17 +24,16 @@ use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; -use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; +use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeTransformer}; use datafusion_common::{ internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast}; use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::merge_schema; -use datafusion_expr::{ - binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, -}; +use datafusion_expr::{lit, Expr, ExprSchemable, LogicalPlan, Operator}; use std::cmp::Ordering; +use std::mem; use std::sync::Arc; /// [`UnwrapCastInComparison`] attempts to remove casts from @@ -126,21 +125,19 @@ struct UnwrapCastExprRewriter { schema: DFSchemaRef, } -impl TreeNodeRewriter for UnwrapCastExprRewriter { - type N = Expr; 
+impl TreeNodeTransformer for UnwrapCastExprRewriter { + type Node = Expr; - fn pre_visit(&mut self, _expr: &Expr) -> Result { - Ok(RewriteRecursion::Continue) + fn pre_transform(&mut self, _expr: &mut Expr) -> Result { + Ok(TreeNodeRecursion::Continue) } - fn mutate(&mut self, expr: Expr) -> Result { - match &expr { + fn post_transform(&mut self, expr: &mut Expr) -> Result { + match expr { // For case: // try_cast/cast(expr as data_type) op literal // literal op try_cast/cast(expr as data_type) Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let left = left.as_ref().clone(); - let right = right.as_ref().clone(); let left_type = left.get_type(&self.schema)?; let right_type = right.get_type(&self.schema)?; // Because the plan has been done the type coercion, the left and right must be equal @@ -148,7 +145,7 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { && is_support_data_type(&right_type) && is_comparison_op(op) { - match (&left, &right) { + match (left.as_mut(), right.as_mut()) { ( Expr::Literal(left_lit_value), Expr::TryCast(TryCast { expr, .. 
}) @@ -161,11 +158,8 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { try_cast_literal_to_type(left_lit_value, &expr_type)?; if let Some(value) = casted_scalar_value { // unwrap the cast/try_cast for the right expr - return Ok(binary_expr( - lit(value), - *op, - expr.as_ref().clone(), - )); + **left = lit(value); + **right = mem::take(expr.as_mut()); } } ( @@ -180,49 +174,42 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { try_cast_literal_to_type(right_lit_value, &expr_type)?; if let Some(value) = casted_scalar_value { // unwrap the cast/try_cast for the left expr - return Ok(binary_expr( - expr.as_ref().clone(), - *op, - lit(value), - )); + **left = mem::take(expr.as_mut()); + **right = lit(value); } } (_, _) => { // do nothing } - }; + } } - // return the new binary op - Ok(binary_expr(left, *op, right)) } // For case: // try_cast/cast(expr as left_type) in (expr1,expr2,expr3) Expr::InList(InList { expr: left_expr, list, - negated, + .. }) => { - if let Some( - Expr::TryCast(TryCast { - expr: internal_left_expr, - .. - }) - | Expr::Cast(Cast { - expr: internal_left_expr, - .. - }), - ) = Some(left_expr.as_ref()) + if let Expr::TryCast(TryCast { + expr: internal_left_expr, + .. + }) + | Expr::Cast(Cast { + expr: internal_left_expr, + .. 
+ }) = left_expr.as_ref() { let internal_left = internal_left_expr.as_ref().clone(); let internal_left_type = internal_left.get_type(&self.schema); if internal_left_type.is_err() { // error data type - return Ok(expr); + return Ok(TreeNodeRecursion::Continue); } let internal_left_type = internal_left_type?; if !is_support_data_type(&internal_left_type) { // not supported data type - return Ok(expr); + return Ok(TreeNodeRecursion::Continue); } let right_exprs = list .iter() @@ -256,19 +243,16 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { } }) .collect::>>(); - match right_exprs { - Ok(right_exprs) => { - Ok(in_list(internal_left, right_exprs, *negated)) - } - Err(_) => Ok(expr), + if let Ok(right_exprs) = right_exprs { + **left_expr = internal_left; + *list = right_exprs; } - } else { - Ok(expr) } } // TODO: handle other expr type and dfs visit them - _ => Ok(expr), + _ => {} } + Ok(TreeNodeRecursion::Continue) } } @@ -730,11 +714,12 @@ mod tests { assert_eq!(optimize_test(expr_lt, &schema), expected); } - fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { + fn optimize_test(mut expr: Expr, schema: &DFSchemaRef) -> Expr { let mut expr_rewriter = UnwrapCastExprRewriter { schema: schema.clone(), }; - expr.rewrite(&mut expr_rewriter).unwrap() + expr.transform(&mut expr_rewriter).unwrap(); + expr } fn expr_test_schema() -> DFSchemaRef { diff --git a/datafusion/physical-expr/src/equivalence.rs b/datafusion/physical-expr/src/equivalence.rs index defd7b5786a3..4899d69bad58 100644 --- a/datafusion/physical-expr/src/equivalence.rs +++ b/datafusion/physical-expr/src/equivalence.rs @@ -352,7 +352,7 @@ impl EquivalenceGroup { /// class it matches with (if any). 
pub fn normalize_expr(&self, expr: Arc) -> Arc { expr.clone() - .transform(&|expr| { + .transform_up(&|expr| { for cls in self.iter() { if cls.contains(&expr) { return Ok(Transformed::Yes(cls.canonical_expr().unwrap())); diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 52fb85657f4e..d637cf1e54e6 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -944,7 +944,7 @@ mod tests { let expr2 = expr .clone() - .transform(&|e| { + .transform_up(&|e| { let transformed = match e.as_any().downcast_ref::() { Some(lit_value) => match lit_value.value() { diff --git a/datafusion/physical-expr/src/sort_properties.rs b/datafusion/physical-expr/src/sort_properties.rs index f51374461776..7c61e14e345a 100644 --- a/datafusion/physical-expr/src/sort_properties.rs +++ b/datafusion/physical-expr/src/sort_properties.rs @@ -20,7 +20,8 @@ use std::{ops::Neg, sync::Arc}; use arrow_schema::SortOptions; use crate::PhysicalExpr; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; + +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, VisitRecursionIterator}; use datafusion_common::Result; /// To propagate [`SortOptions`] across the [`PhysicalExpr`], it is insufficient @@ -173,18 +174,11 @@ impl ExprOrdering { } impl TreeNode for ExprOrdering { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - for child in &self.children { - match op(child)? 
{ - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - Ok(VisitRecursion::Continue) + self.children.iter().for_each_till_continue(f) } fn map_children(mut self, transform: F) -> Result @@ -202,4 +196,11 @@ impl TreeNode for ExprOrdering { Ok(self) } } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + self.children.iter_mut().for_each_till_continue(f) + } } diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 87ef36558b96..cf4d0a077e6f 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -29,7 +29,7 @@ use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; use arrow::datatypes::SchemaRef; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeRewriter, VisitRecursion, + Transformed, TreeNode, TreeNodeRecursion, TreeNodeTransformer, VisitRecursionIterator, }; use datafusion_common::Result; use datafusion_expr::Operator; @@ -154,19 +154,11 @@ impl ExprTreeNode { } impl TreeNode for ExprTreeNode { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - for child in self.children() { - match op(child)? 
{ - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + self.children().iter().for_each_till_continue(f) } fn map_children(mut self, transform: F) -> Result @@ -180,9 +172,16 @@ impl TreeNode for ExprTreeNode { .collect::>>()?; Ok(self) } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + self.child_nodes.iter_mut().for_each_till_continue(f) + } } -/// This struct facilitates the [TreeNodeRewriter] mechanism to convert a +/// This struct facilitates the [TreeNodeTransformer] mechanism to convert a /// [PhysicalExpr] tree into a DAEG (i.e. an expression DAG) by collecting /// identical expressions in one node. Caller specifies the node type in the /// DAEG via the `constructor` argument, which constructs nodes in the DAEG @@ -196,16 +195,21 @@ struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> Result< constructor: &'a F, } -impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeRewriter +impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeTransformer for PhysicalExprDAEGBuilder<'a, T, F> { - type N = ExprTreeNode; + type Node = ExprTreeNode; + + fn pre_transform(&mut self, _node: &mut Self::Node) -> Result { + Ok(TreeNodeRecursion::Continue) + } + // This method mutates an expression node by transforming it to a physical expression // and adding it to the graph. The method returns the mutated expression node. - fn mutate( + fn post_transform( &mut self, - mut node: ExprTreeNode, - ) -> Result> { + node: &mut ExprTreeNode, + ) -> Result { // Get the expression associated with the input expression node. let expr = &node.expr; @@ -217,7 +221,7 @@ impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeRewriter // add edges to its child nodes. Add the visited expression to the vector // of visited expressions and return the newly created node index. 
None => { - let node_idx = self.graph.add_node((self.constructor)(&node)?); + let node_idx = self.graph.add_node((self.constructor)(node)?); for expr_node in node.child_nodes.iter() { self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0); } @@ -228,7 +232,7 @@ impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeRewriter // Set the data field of the input expression node to the corresponding node index. node.data = Some(node_idx); // Return the mutated expression node. - Ok(node) + Ok(TreeNodeRecursion::Continue) } } @@ -249,7 +253,8 @@ where constructor, }; // Use the builder to transform the expression tree node into a DAG. - let root = init.rewrite(&mut builder)?; + let mut root = init; + root.transform(&mut builder)?; // Return a tuple containing the root node index and the DAG. Ok((root.data.unwrap(), builder.graph)) } @@ -257,13 +262,13 @@ where /// Recursively extract referenced [`Column`]s within a [`PhysicalExpr`]. pub fn collect_columns(expr: &Arc) -> HashSet { let mut columns = HashSet::::new(); - expr.apply(&mut |expr| { + expr.visit_down(&mut |expr| { if let Some(column) = expr.as_any().downcast_ref::() { if !columns.iter().any(|c| c.eq(column)) { columns.insert(column.clone()); } } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // pre_visit always returns OK, so this will always too .expect("no way to return error during recursion"); diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 2997d147424d..f2d2779f99a8 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -484,6 +484,10 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { use protobuf::logical_expr_node::ExprType; let expr_node = match expr { + Expr::Nop => Err(Error::NotImplemented( + "Proto serialization error: Trying to serialize a Nop expression" + .to_string(), + ))?, Expr::Column(c) => Self { expr_type: Some(ExprType::Column(c.into())), }, 
From 5c6147049fd48976063724dd69ccff59dc78fba9 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 19 Dec 2023 13:51:08 +0100 Subject: [PATCH 2/9] - refactor `transform_down()` and `transform_up()` to work on mutable `TreeNode`s and use them in a few examples - add `transform_down_with_payload()`, `transform_up_with_payload()`, `transform_with_payload()` and use it in `EnforceSorting` as an example --- datafusion-examples/examples/rewrite_expr.rs | 6 +- datafusion/common/src/tree_node.rs | 207 ++++++++++++------ .../physical_optimizer/coalesce_batches.rs | 2 +- .../combine_partial_final_agg.rs | 103 ++++----- .../enforce_distribution.rs | 6 +- .../src/physical_optimizer/enforce_sorting.rs | 71 +++++- .../src/physical_optimizer/join_selection.rs | 5 +- .../limited_distinct_aggregation.rs | 2 +- .../physical_optimizer/output_requirements.rs | 2 +- .../physical_optimizer/pipeline_checker.rs | 2 +- .../physical_optimizer/projection_pushdown.rs | 2 +- .../core/src/physical_optimizer/pruning.rs | 2 +- .../replace_with_order_preserving_variants.rs | 2 +- .../src/physical_optimizer/sort_pushdown.rs | 165 +------------- .../physical_optimizer/topk_aggregation.rs | 2 +- datafusion/expr/src/expr.rs | 2 +- datafusion/expr/src/expr_rewriter/mod.rs | 16 +- datafusion/expr/src/expr_rewriter/order_by.rs | 2 +- datafusion/expr/src/logical_plan/plan.rs | 13 +- datafusion/expr/src/tree_node/expr.rs | 61 +++--- .../src/analyzer/count_wildcard_rule.rs | 8 +- .../src/analyzer/inline_table_scan.rs | 10 +- datafusion/optimizer/src/decorrelate.rs | 6 +- datafusion/optimizer/src/push_down_filter.rs | 2 +- .../optimizer/src/scalar_subquery_to_join.rs | 7 +- datafusion/physical-expr/src/equivalence.rs | 22 +- .../physical-expr/src/expressions/case.rs | 4 +- datafusion/physical-expr/src/utils/mod.rs | 2 +- .../src/joins/stream_join_utils.rs | 2 +- datafusion/sql/src/utils.rs | 23 +- 30 files changed, 368 insertions(+), 391 deletions(-) diff --git a/datafusion-examples/examples/rewrite_expr.rs 
b/datafusion-examples/examples/rewrite_expr.rs index 9dfc238ab9e8..226f548dd446 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -91,7 +91,7 @@ impl AnalyzerRule for MyAnalyzerRule { impl MyAnalyzerRule { fn analyze_plan(plan: LogicalPlan) -> Result { - plan.transform_up(&|plan| { + plan.transform_up_old(&|plan| { Ok(match plan { LogicalPlan::Filter(filter) => { let predicate = Self::analyze_expr(filter.predicate.clone())?; @@ -106,7 +106,7 @@ impl MyAnalyzerRule { } fn analyze_expr(expr: Expr) -> Result { - expr.transform_up(&|expr| { + expr.transform_up_old(&|expr| { // closure is invoked for all sub expressions Ok(match expr { Expr::Literal(ScalarValue::Int64(i)) => { @@ -161,7 +161,7 @@ impl OptimizerRule for MyOptimizerRule { /// use rewrite_expr to modify the expression tree. fn my_rewrite(expr: Expr) -> Result { - expr.transform_up(&|expr| { + expr.transform_up_old(&|expr| { // closure is invoked for all sub expressions Ok(match expr { Expr::Between(Between { diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 39d691a9dcea..e15c43543a22 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -57,19 +57,19 @@ pub trait TreeNode: Sized { F: FnMut(&Self) -> Result, { // Apply `f` on self. - f(self) + f(self)? // If it returns continue (not prune or stop or stop all) then continue // traversal on inner children and children. .and_then_on_continue(|| { // Run the recursive `apply` on each inner children, but as they are // unrelated root nodes of inner trees if any returns stop then continue // with the next one. - self.apply_inner_children(&mut |c| c.visit_down(f).continue_on_stop()) + self.apply_inner_children(&mut |c| c.visit_down(f)?.continue_on_stop())? // Run the recursive `apply` on each children. .and_then_on_continue(|| { self.apply_children(&mut |c| c.visit_down(f)) }) - }) + })? 
// Applying `f` on self might have returned prune, but we need to propagate // continue. .continue_on_prune() @@ -107,21 +107,21 @@ pub trait TreeNode: Sized { ) -> Result { // Apply `pre_visit` on self. visitor - .pre_visit(self) + .pre_visit(self)? // If it returns continue (not prune or stop or stop all) then continue // traversal on inner children and children. .and_then_on_continue(|| { // Run the recursive `visit` on each inner children, but as they are // unrelated subquery plans if any returns stop then continue with the // next one. - self.apply_inner_children(&mut |c| c.visit(visitor).continue_on_stop()) + self.apply_inner_children(&mut |c| c.visit(visitor)?.continue_on_stop())? // Run the recursive `visit` on each children. .and_then_on_continue(|| { self.apply_children(&mut |c| c.visit(visitor)) - }) + })? // Apply `post_visit` on self. .and_then_on_continue(|| visitor.post_visit(self)) - }) + })? // Applying `pre_visit` or `post_visit` on self might have returned prune, // but we need to propagate continue. .continue_on_prune() @@ -133,31 +133,144 @@ pub trait TreeNode: Sized { ) -> Result { // Apply `pre_transform` on self. transformer - .pre_transform(self) + .pre_transform(self)? // If it returns continue (not prune or stop or stop all) then continue // traversal on inner children and children. .and_then_on_continue(|| // Run the recursive `transform` on each children. self - .transform_children(&mut |c| c.transform(transformer)) + .transform_children(&mut |c| c.transform(transformer))? // Apply `post_transform` on new self. - .and_then_on_continue(|| { - transformer.post_transform(self) - })) + .and_then_on_continue(|| transformer.post_transform(self)))? // Applying `pre_transform` or `post_transform` on self might have returned // prune, but we need to propagate continue. 
.continue_on_prune() } + fn transform_with_payload( + &mut self, + f_down: &mut FD, + payload_down: Option, + f_up: &mut FU, + ) -> Result<(TreeNodeRecursion, Option)> + where + FD: FnMut(&mut Self, Option) -> Result<(TreeNodeRecursion, Vec)>, + FU: FnMut(&mut Self, Vec) -> Result<(TreeNodeRecursion, PU)>, + { + // Apply `f_down` on self. + let (tnr, new_payload_down) = f_down(self, payload_down)?; + let mut new_payload_down_iter = new_payload_down.into_iter(); + // If it returns continue (not prune or stop or stop all) then continue traversal + // on inner children and children. + let mut new_payload_up = None; + tnr.and_then_on_continue(|| { + // Run the recursive `transform` on each children. + let mut payload_up = vec![]; + let tnr = self.transform_children(&mut |c| { + let (tnr, p) = + c.transform_with_payload(f_down, new_payload_down_iter.next(), f_up)?; + p.into_iter().for_each(|p| payload_up.push(p)); + Ok(tnr) + })?; + // Apply `f_up` on self. + tnr.and_then_on_continue(|| { + let (tnr, np) = f_up(self, payload_up)?; + new_payload_up = Some(np); + Ok(tnr) + }) + })? + // Applying `f_down` or `f_up` on self might have returned prune, but we need to propagate + // continue. + .continue_on_prune() + .map(|tnr| (tnr, new_payload_up)) + } + + fn transform_down(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + // Apply `f` on self. + f(self)? + // If it returns continue (not prune or stop or stop all) then continue + // traversal on inner children and children. + .and_then_on_continue(|| + // Run the recursive `transform` on each children. + self.transform_children(&mut |c| c.transform_down(f)))? + // Applying `f` on self might have returned prune, but we need to propagate + // continue. + .continue_on_prune() + } + + fn transform_down_with_payload( + &mut self, + f: &mut F, + payload: P, + ) -> Result + where + F: FnMut(&mut Self, P) -> Result<(TreeNodeRecursion, Vec

)>, + { + // Apply `f` on self. + let (tnr, new_payload) = f(self, payload)?; + let mut new_payload_iter = new_payload.into_iter(); + // If it returns continue (not prune or stop or stop all) then continue + // traversal on inner children and children. + tnr.and_then_on_continue(|| + // Run the recursive `transform` on each children. + self.transform_children(&mut |c| c.transform_down_with_payload(f, new_payload_iter.next().unwrap())))? + // Applying `f` on self might have returned prune, but we need to propagate + // continue. + .continue_on_prune() + } + + fn transform_up(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + // Run the recursive `transform` on each children. + self.transform_children(&mut |c| c.transform_up(f))? + // Apply `f` on self. + .and_then_on_continue(|| f(self))? + // Applying `f` on self might have returned prune, but we need to propagate + // continue. + .continue_on_prune() + } + + fn transform_up_with_payload( + &mut self, + f: &mut F, + ) -> Result<(TreeNodeRecursion, Option

)> + where + F: FnMut(&mut Self, Vec

) -> Result<(TreeNodeRecursion, P)>, + { + // Run the recursive `transform` on each children. + let mut payload = vec![]; + let tnr = self.transform_children(&mut |c| { + let (tnr, p) = c.transform_up_with_payload(f)?; + p.into_iter().for_each(|p| payload.push(p)); + Ok(tnr) + })?; + let mut new_payload = None; + // Apply `f` on self. + tnr.and_then_on_continue(|| { + let (tnr, np) = f(self, payload)?; + new_payload = Some(np); + Ok(tnr) + })? + // Applying `f` on self might have returned prune, but we need to propagate + // continue. + .continue_on_prune() + .map(|tnr| (tnr, new_payload)) + } + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its /// children(Preorder Traversal). /// When the `op` does not apply to a given node, it is left unchanged. - fn transform_down(self, op: &F) -> Result + fn transform_down_old(self, op: &F) -> Result where F: Fn(Self) -> Result>, { let after_op = op(self)?.into(); - after_op.map_children(|node| node.transform_down(op)) + after_op.map_children(|node| node.transform_down_old(op)) } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its @@ -174,11 +287,11 @@ pub trait TreeNode: Sized { /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its /// children and then itself(Postorder Traversal). /// When the `op` does not apply to a given node, it is left unchanged. 
- fn transform_up(self, op: &F) -> Result + fn transform_up_old(self, op: &F) -> Result where F: Fn(Self) -> Result>, { - let after_op_children = self.map_children(|node| node.transform_up(op))?; + let after_op_children = self.map_children(|node| node.transform_up_old(op))?; let new_node = op(after_op_children)?.into(); Ok(new_node) @@ -402,63 +515,35 @@ pub enum TreeNodeRecursion { } impl TreeNodeRecursion { - fn continue_on_prune(self) -> TreeNodeRecursion { - match self { - TreeNodeRecursion::Prune => TreeNodeRecursion::Continue, - o => o, - } - } - - fn fail_on_prune(self) -> TreeNodeRecursion { - match self { - TreeNodeRecursion::Prune => panic!("Recursion can't prune."), - o => o, - } - } - - fn continue_on_stop(self) -> TreeNodeRecursion { - match self { - TreeNodeRecursion::Stop => TreeNodeRecursion::Continue, - o => o, - } - } -} - -/// This helper trait provide functions to control recursion on -/// [`Result`]. -pub trait TreeNodeRecursionResult: Sized { - fn and_then_on_continue(self, f: F) -> Result - where - F: FnOnce() -> Result; - - fn continue_on_prune(self) -> Result; - - fn fail_on_prune(self) -> Result; - - fn continue_on_stop(self) -> Result; -} - -impl TreeNodeRecursionResult for Result { - fn and_then_on_continue(self, f: F) -> Result + pub fn and_then_on_continue(self, f: F) -> Result where F: FnOnce() -> Result, { - match self? 
{ + match self { TreeNodeRecursion::Continue => f(), o => Ok(o), } } - fn continue_on_prune(self) -> Result { - self.map(|tnr| tnr.continue_on_prune()) + pub fn continue_on_prune(self) -> Result { + Ok(match self { + TreeNodeRecursion::Prune => TreeNodeRecursion::Continue, + o => o, + }) } - fn fail_on_prune(self) -> Result { - self.map(|tnr| tnr.fail_on_prune()) + pub fn fail_on_prune(self) -> Result { + Ok(match self { + TreeNodeRecursion::Prune => panic!("Recursion can't prune."), + o => o, + }) } - fn continue_on_stop(self) -> Result { - self.map(|tnr| tnr.continue_on_stop()) + pub fn continue_on_stop(self) -> Result { + Ok(match self { + TreeNodeRecursion::Stop => TreeNodeRecursion::Continue, + o => o, + }) } } diff --git a/datafusion/core/src/physical_optimizer/coalesce_batches.rs b/datafusion/core/src/physical_optimizer/coalesce_batches.rs index 7b66ca529094..19a8701a1003 100644 --- a/datafusion/core/src/physical_optimizer/coalesce_batches.rs +++ b/datafusion/core/src/physical_optimizer/coalesce_batches.rs @@ -52,7 +52,7 @@ impl PhysicalOptimizerRule for CoalesceBatches { } let target_batch_size = config.execution.batch_size; - plan.transform_up(&|plan| { + plan.transform_up_old(&|plan| { let plan_any = plan.as_any(); // The goal here is to detect operators that could produce small batches and only // wrap those ones with a CoalesceBatchesExec operator. 
An alternate approach here diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 5878650a49e3..09963edd5979 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -26,7 +26,7 @@ use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGro use crate::physical_plan::ExecutionPlan; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{AggregateExpr, PhysicalExpr}; @@ -48,27 +48,27 @@ impl CombinePartialFinalAggregate { impl PhysicalOptimizerRule for CombinePartialFinalAggregate { fn optimize( &self, - plan: Arc, + mut plan: Arc, _config: &ConfigOptions, ) -> Result> { - plan.transform_down(&|plan| { - let transformed = - plan.as_any() - .downcast_ref::() - .and_then(|agg_exec| { - if matches!( - agg_exec.mode(), - AggregateMode::Final | AggregateMode::FinalPartitioned - ) { - agg_exec - .input() - .as_any() - .downcast_ref::() - .and_then(|input_agg_exec| { - if matches!( - input_agg_exec.mode(), - AggregateMode::Partial - ) && can_combine( + plan.transform_down(&mut |plan| { + plan.clone() + .as_any() + .downcast_ref::() + .into_iter() + .for_each(|agg_exec| { + if matches!( + agg_exec.mode(), + AggregateMode::Final | AggregateMode::FinalPartitioned + ) { + agg_exec + .input() + .as_any() + .downcast_ref::() + .into_iter() + .for_each(|input_agg_exec| { + if matches!(input_agg_exec.mode(), AggregateMode::Partial) + && can_combine( ( agg_exec.group_by(), agg_exec.aggr_expr(), @@ -79,41 +79,34 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { input_agg_exec.aggr_expr(), input_agg_exec.filter_expr(), ), - ) { - let mode = - 
if agg_exec.mode() == &AggregateMode::Final { - AggregateMode::Single - } else { - AggregateMode::SinglePartitioned - }; - AggregateExec::try_new( - mode, - input_agg_exec.group_by().clone(), - input_agg_exec.aggr_expr().to_vec(), - input_agg_exec.filter_expr().to_vec(), - input_agg_exec.input().clone(), - input_agg_exec.input_schema(), - ) - .map(|combined_agg| { - combined_agg.with_limit(agg_exec.limit()) - }) - .ok() - .map(Arc::new) + ) + { + let mode = if agg_exec.mode() == &AggregateMode::Final + { + AggregateMode::Single } else { - None - } - }) - } else { - None - } - }); - - Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) - } else { - Transformed::No(plan) - }) - }) + AggregateMode::SinglePartitioned + }; + AggregateExec::try_new( + mode, + input_agg_exec.group_by().clone(), + input_agg_exec.aggr_expr().to_vec(), + input_agg_exec.filter_expr().to_vec(), + input_agg_exec.input().clone(), + input_agg_exec.input_schema(), + ) + .map(|combined_agg| { + combined_agg.with_limit(agg_exec.limit()) + }) + .into_iter() + .for_each(|p| *plan = Arc::new(p)) + } + }) + } + }); + Ok(TreeNodeRecursion::Continue) + })?; + Ok(plan) } fn name(&self) -> &str { @@ -178,7 +171,7 @@ fn normalize_group_exprs(group_exprs: GroupExprsRef) -> GroupExprs { fn discard_column_index(group_expr: Arc) -> Arc { group_expr .clone() - .transform_up(&|expr| { + .transform_up_old(&|expr| { let normalized_form: Option> = match expr.as_any().downcast_ref::() { Some(column) => Some(Arc::new(Column::new(column.name(), 0))), diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 9392d443e150..b54ec2d6a7f0 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -200,11 +200,11 @@ impl PhysicalOptimizerRule for EnforceDistribution { // Run a top-down process to adjust input key ordering 
recursively let plan_requirements = PlanWithKeyRequirements::new(plan); let adjusted = - plan_requirements.transform_down(&adjust_input_keys_ordering)?; + plan_requirements.transform_down_old(&adjust_input_keys_ordering)?; adjusted.plan } else { // Run a bottom-up process - plan.transform_up(&|plan| { + plan.transform_up_old(&|plan| { Ok(Transformed::Yes(reorder_join_keys_to_inputs(plan)?)) })? }; @@ -212,7 +212,7 @@ impl PhysicalOptimizerRule for EnforceDistribution { let distribution_context = DistributionContext::new(adjusted); // Distribution enforcement needs to be applied bottom-up. let distribution_context = - distribution_context.transform_up(&|distribution_context| { + distribution_context.transform_up_old(&|distribution_context| { ensure_distribution(distribution_context, config) })?; Ok(distribution_context.plan) diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 9a57a030fcc6..7512f6e8aa2c 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -41,7 +41,7 @@ use crate::error::Result; use crate::physical_optimizer::replace_with_order_preserving_variants::{ replace_with_order_preserving_variants, OrderPreservationContext, }; -use crate::physical_optimizer::sort_pushdown::{pushdown_sorts, SortPushDown}; +use crate::physical_optimizer::sort_pushdown::pushdown_requirement_to_children; use crate::physical_optimizer::utils::{ add_sort_above, is_coalesce_partitions, is_limit, is_repartition, is_sort, is_sort_preserving_merge, is_union, is_window, ExecTree, @@ -340,19 +340,19 @@ impl PhysicalOptimizerRule for EnforceSorting { let plan_requirements = PlanWithCorrespondingSort::new(plan); // Execute a bottom-up traversal to enforce sorting requirements, // remove unnecessary sorts, and optimize sort-sensitive operators: - let adjusted = plan_requirements.transform_up(&ensure_sorting)?; + let 
adjusted = plan_requirements.transform_up_old(&ensure_sorting)?; let new_plan = if config.optimizer.repartition_sorts { let plan_with_coalesce_partitions = PlanWithCorrespondingCoalescePartitions::new(adjusted.plan); let parallel = - plan_with_coalesce_partitions.transform_up(¶llelize_sorts)?; + plan_with_coalesce_partitions.transform_up_old(¶llelize_sorts)?; parallel.plan } else { adjusted.plan }; let plan_with_pipeline_fixer = OrderPreservationContext::new(new_plan); - let updated_plan = - plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| { + let mut updated_plan = + plan_with_pipeline_fixer.transform_up_old(&|plan_with_pipeline_fixer| { replace_with_order_preserving_variants( plan_with_pipeline_fixer, false, @@ -363,9 +363,64 @@ impl PhysicalOptimizerRule for EnforceSorting { // Execute a top-down traversal to exploit sort push-down opportunities // missed by the bottom-up traversal: - let sort_pushdown = SortPushDown::init(updated_plan.plan); - let adjusted = sort_pushdown.transform_down(&pushdown_sorts)?; - Ok(adjusted.plan) + updated_plan.plan.transform_down_with_payload( + &mut |plan, required_ordering: Option>| { + let parent_required = required_ordering.as_deref().unwrap_or(&[]); + if let Some(sort_exec) = plan.as_any().downcast_ref::() { + let new_plan = if !plan + .equivalence_properties() + .ordering_satisfy_requirement(parent_required) + { + // If the current plan is a SortExec, modify it to satisfy parent requirements: + let mut new_plan = sort_exec.input().clone(); + add_sort_above(&mut new_plan, parent_required, sort_exec.fetch()); + new_plan + } else { + plan.clone() + }; + let required_ordering = new_plan + .output_ordering() + .map(PhysicalSortRequirement::from_sort_exprs) + .unwrap_or_default(); + // Since new_plan is a SortExec, we can safely get the 0th index. + let child = new_plan.children().swap_remove(0); + if let Some(adjusted) = + pushdown_requirement_to_children(&child, &required_ordering)? 
+ { + *plan = child; + Ok((TreeNodeRecursion::Continue, adjusted)) + } else { + *plan = new_plan; + // Can not push down requirements + Ok((TreeNodeRecursion::Continue, plan.required_input_ordering())) + } + } else { + // Executors other than SortExec + if plan + .equivalence_properties() + .ordering_satisfy_requirement(parent_required) + { + // Satisfies parent requirements, immediately return. + return Ok(( + TreeNodeRecursion::Continue, + plan.required_input_ordering(), + )); + } + // Can not satisfy the parent requirements, check whether the requirements can be pushed down: + if let Some(adjusted) = + pushdown_requirement_to_children(plan, parent_required)? + { + Ok((TreeNodeRecursion::Continue, adjusted)) + } else { + // Can not push down requirements, add new SortExec: + add_sort_above(plan, parent_required, None); + Ok((TreeNodeRecursion::Continue, plan.required_input_ordering())) + } + } + }, + None, + )?; + Ok(updated_plan.plan) } fn name(&self) -> &str { diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 6b2fe24acf00..66a27aa7cb6b 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -237,7 +237,8 @@ impl PhysicalOptimizerRule for JoinSelection { Box::new(hash_join_convert_symmetric_subrule), Box::new(hash_join_swap_subrule), ]; - let state = pipeline.transform_up(&|p| apply_subrules(p, &subrules, config))?; + let state = + pipeline.transform_up_old(&|p| apply_subrules(p, &subrules, config))?; // Next, we apply another subrule that tries to optimize joins using any // statistics their inputs might have. // - For a hash join with partition mode [`PartitionMode::Auto`], we will @@ -251,7 +252,7 @@ impl PhysicalOptimizerRule for JoinSelection { // side is the small side. 
let config = &config.optimizer; let collect_left_threshold = config.hash_join_single_partition_threshold; - state.plan.transform_up(&|plan| { + state.plan.transform_up_old(&|plan| { statistical_join_selection_subrule(plan, collect_left_threshold) }) } diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs index 540f9a6a132b..249537534ada 100644 --- a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs @@ -160,7 +160,7 @@ impl PhysicalOptimizerRule for LimitedDistinctAggregation { config: &ConfigOptions, ) -> Result> { let plan = if config.optimizer.enable_distinct_aggregation_soft_limit { - plan.transform_down(&|plan| { + plan.transform_down_old(&|plan| { Ok( if let Some(plan) = LimitedDistinctAggregation::transform_limit(plan.clone()) diff --git a/datafusion/core/src/physical_optimizer/output_requirements.rs b/datafusion/core/src/physical_optimizer/output_requirements.rs index f8bf3bb965e8..c817f6b4ad35 100644 --- a/datafusion/core/src/physical_optimizer/output_requirements.rs +++ b/datafusion/core/src/physical_optimizer/output_requirements.rs @@ -192,7 +192,7 @@ impl PhysicalOptimizerRule for OutputRequirements { ) -> Result> { match self.mode { RuleMode::Add => require_top_ordering(plan), - RuleMode::Remove => plan.transform_up(&|plan| { + RuleMode::Remove => plan.transform_up_old(&|plan| { if let Some(sort_req) = plan.as_any().downcast_ref::() { diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index 122ce7171bd3..9176bab57656 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -54,7 +54,7 @@ impl PhysicalOptimizerRule for PipelineChecker { ) -> Result> { let pipeline = 
PipelineStatePropagator::new(plan); let state = pipeline - .transform_up(&|p| check_finiteness_requirements(p, &config.optimizer))?; + .transform_up_old(&|p| check_finiteness_requirements(p, &config.optimizer))?; Ok(state.plan) } diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index e2b290f3f5ce..6b7e139e711a 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -72,7 +72,7 @@ impl PhysicalOptimizerRule for ProjectionPushdown { plan: Arc, _config: &ConfigOptions, ) -> Result> { - plan.transform_down(&remove_unnecessary_projections) + plan.transform_down_old(&remove_unnecessary_projections) } fn name(&self) -> &str { diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 2423ccc4c32e..41a3a9397bbc 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -678,7 +678,7 @@ fn rewrite_column_expr( column_old: &phys_expr::Column, column_new: &phys_expr::Column, ) -> Result> { - e.transform_up(&|expr| { + e.transform_up_old(&|expr| { if let Some(column) = expr.as_any().downcast_ref::() { if column == column_old { return Ok(Transformed::Yes(Arc::new(column_new.clone()))); diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 21602487640f..4b001d67aca9 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -344,7 +344,7 @@ mod tests { // let optimized_physical_plan = physical_plan.transform_down(&replace_repartition_execs)?; let config = SessionConfig::new().with_prefer_existing_sort($ALLOW_BOUNDED); let 
plan_with_pipeline_fixer = OrderPreservationContext::new(physical_plan); - let parallel = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, config.options()))?; + let parallel = plan_with_pipeline_fixer.transform_up_old(&|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, config.options()))?; let optimized_physical_plan = parallel.plan; // Get string representation of the plan diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index 4b06218df9e9..d06adb82c83e 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -18,19 +18,15 @@ use std::sync::Arc; use crate::physical_optimizer::utils::{ - add_sort_above, is_limit, is_sort_preserving_merge, is_union, is_window, + is_limit, is_sort_preserving_merge, is_union, is_window, }; use crate::physical_plan::filter::FilterExec; use crate::physical_plan::joins::utils::calculate_join_output_ordering; use crate::physical_plan::joins::{HashJoinExec, SortMergeJoinExec}; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; -use crate::physical_plan::sorts::sort::SortExec; -use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; +use crate::physical_plan::ExecutionPlan; -use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator, -}; use datafusion_common::{plan_err, DataFusionError, JoinSide, Result}; use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::Column; @@ -38,162 +34,7 @@ use datafusion_physical_expr::{ LexRequirementRef, PhysicalSortExpr, PhysicalSortRequirement, }; -use itertools::izip; - -/// This is a "data class" we use within the [`EnforceSorting`] rule to push -/// down 
[`SortExec`] in the plan. In some cases, we can reduce the total -/// computational cost by pushing down `SortExec`s through some executors. -/// -/// [`EnforceSorting`]: crate::physical_optimizer::enforce_sorting::EnforceSorting -#[derive(Debug, Clone)] -pub(crate) struct SortPushDown { - /// Current plan - pub plan: Arc, - /// Parent required sort ordering - required_ordering: Option>, - /// The adjusted request sort ordering to children. - /// By default they are the same as the plan's required input ordering, but can be adjusted based on parent required sort ordering properties. - adjusted_request_ordering: Vec>>, -} - -impl SortPushDown { - pub fn init(plan: Arc) -> Self { - let request_ordering = plan.required_input_ordering(); - SortPushDown { - plan, - required_ordering: None, - adjusted_request_ordering: request_ordering, - } - } - - pub fn children(&self) -> Vec { - izip!( - self.plan.children().into_iter(), - self.adjusted_request_ordering.clone().into_iter(), - ) - .map(|(child, from_parent)| { - let child_request_ordering = child.required_input_ordering(); - SortPushDown { - plan: child, - required_ordering: from_parent, - adjusted_request_ordering: child_request_ordering, - } - }) - .collect() - } -} - -impl TreeNode for SortPushDown { - fn apply_children(&self, f: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - self.children().iter().for_each_till_continue(f) - } - - fn map_children(mut self, transform: F) -> Result - where - F: FnMut(Self) -> Result, - { - let children = self.children(); - if !children.is_empty() { - let children_plans = children - .into_iter() - .map(transform) - .map(|r| r.map(|s| s.plan)) - .collect::>>()?; - - match with_new_children_if_necessary(self.plan, children_plans)? 
{ - Transformed::Yes(plan) | Transformed::No(plan) => { - self.plan = plan; - } - } - }; - Ok(self) - } - - fn transform_children(&mut self, f: &mut F) -> Result - where - F: FnMut(&mut Self) -> Result, - { - let mut children = self.children(); - if !children.is_empty() { - let tnr = children.iter_mut().for_each_till_continue(f)?; - let children_plans = children.into_iter().map(|c| c.plan).collect(); - self.plan = - with_new_children_if_necessary(self.plan.clone(), children_plans)?.into(); - Ok(tnr) - } else { - Ok(TreeNodeRecursion::Continue) - } - } -} - -pub(crate) fn pushdown_sorts( - requirements: SortPushDown, -) -> Result> { - let plan = &requirements.plan; - let parent_required = requirements.required_ordering.as_deref().unwrap_or(&[]); - if let Some(sort_exec) = plan.as_any().downcast_ref::() { - let new_plan = if !plan - .equivalence_properties() - .ordering_satisfy_requirement(parent_required) - { - // If the current plan is a SortExec, modify it to satisfy parent requirements: - let mut new_plan = sort_exec.input().clone(); - add_sort_above(&mut new_plan, parent_required, sort_exec.fetch()); - new_plan - } else { - requirements.plan - }; - let required_ordering = new_plan - .output_ordering() - .map(PhysicalSortRequirement::from_sort_exprs) - .unwrap_or_default(); - // Since new_plan is a SortExec, we can safely get the 0th index. - let child = new_plan.children().swap_remove(0); - if let Some(adjusted) = - pushdown_requirement_to_children(&child, &required_ordering)? - { - // Can push down requirements - Ok(Transformed::Yes(SortPushDown { - plan: child, - required_ordering: None, - adjusted_request_ordering: adjusted, - })) - } else { - // Can not push down requirements - Ok(Transformed::Yes(SortPushDown::init(new_plan))) - } - } else { - // Executors other than SortExec - if plan - .equivalence_properties() - .ordering_satisfy_requirement(parent_required) - { - // Satisfies parent requirements, immediately return. 
- return Ok(Transformed::Yes(SortPushDown { - required_ordering: None, - ..requirements - })); - } - // Can not satisfy the parent requirements, check whether the requirements can be pushed down: - if let Some(adjusted) = pushdown_requirement_to_children(plan, parent_required)? { - Ok(Transformed::Yes(SortPushDown { - plan: requirements.plan, - required_ordering: None, - adjusted_request_ordering: adjusted, - })) - } else { - // Can not push down requirements, add new SortExec: - let mut new_plan = requirements.plan; - add_sort_above(&mut new_plan, parent_required, None); - Ok(Transformed::Yes(SortPushDown::init(new_plan))) - } - } -} - -fn pushdown_requirement_to_children( +pub fn pushdown_requirement_to_children( plan: &Arc, parent_required: LexRequirementRef, ) -> Result>>>> { diff --git a/datafusion/core/src/physical_optimizer/topk_aggregation.rs b/datafusion/core/src/physical_optimizer/topk_aggregation.rs index dd0261420304..f00c44b3234f 100644 --- a/datafusion/core/src/physical_optimizer/topk_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/topk_aggregation.rs @@ -138,7 +138,7 @@ impl PhysicalOptimizerRule for TopKAggregation { config: &ConfigOptions, ) -> Result> { let plan = if config.optimizer.enable_topk_aggregation { - plan.transform_down(&|plan| { + plan.transform_down_old(&|plan| { Ok( if let Some(plan) = TopKAggregation::transform_sort(plan.clone()) { Transformed::Yes(plan) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 5369d502113b..93ec2f369b41 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1150,7 +1150,7 @@ impl Expr { /// For example, gicen an expression like ` = $0` will infer `$0` to /// have type `int32`. 
pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result { - self.transform_up(&|mut expr| { + self.transform_up_old(&|mut expr| { // Default to assuming the arguments are the same type if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index cbdeb16f99b2..a91cd408aed1 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -33,7 +33,7 @@ pub use order_by::rewrite_sort_cols_by_aggs; /// Recursively call [`Column::normalize_with_schemas`] on all [`Column`] expressions /// in the `expr` expression tree. pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { - expr.transform_up(&|expr| { + expr.transform_up_old(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = LogicalPlanBuilder::normalize(plan, c)?; @@ -57,7 +57,7 @@ pub fn normalize_col_with_schemas( schemas: &[&Arc], using_columns: &[HashSet], ) -> Result { - expr.transform_up(&|expr| { + expr.transform_up_old(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = c.normalize_with_schemas(schemas, using_columns)?; @@ -75,7 +75,7 @@ pub fn normalize_col_with_schemas_and_ambiguity_check( schemas: &[&[&DFSchema]], using_columns: &[HashSet], ) -> Result { - expr.transform_up(&|expr| { + expr.transform_up_old(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = @@ -102,7 +102,7 @@ pub fn normalize_cols( /// Recursively replace all [`Column`] expressions in a given expression tree with /// `Column` expressions provided by the hash map argument. 
pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { - expr.transform_up(&|expr| { + expr.transform_up_old(&|expr| { Ok({ if let Expr::Column(c) = &expr { match replace_map.get(c) { @@ -122,7 +122,7 @@ pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Resul /// For example, if there were expressions like `foo.bar` this would /// rewrite it to just `bar`. pub fn unnormalize_col(expr: Expr) -> Expr { - expr.transform_up(&|expr| { + expr.transform_up_old(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = Column { @@ -164,7 +164,7 @@ pub fn unnormalize_cols(exprs: impl IntoIterator) -> Vec { /// Recursively remove all the ['OuterReferenceColumn'] and return the inside Column /// in the expression tree. pub fn strip_outer_reference(expr: Expr) -> Expr { - expr.transform_up(&|expr| { + expr.transform_up_old(&|expr| { Ok({ if let Expr::OuterReferenceColumn(_, col) = expr { Transformed::Yes(Expr::Column(col)) @@ -307,14 +307,14 @@ mod test { // rewrites "foo" --> "bar" let rewritten = col("state") .eq(lit("foo")) - .transform_up(&transformer) + .transform_up_old(&transformer) .unwrap(); assert_eq!(rewritten, col("state").eq(lit("bar"))); // doesn't rewrite let rewritten = col("state") .eq(lit("baz")) - .transform_up(&transformer) + .transform_up_old(&transformer) .unwrap(); assert_eq!(rewritten, col("state").eq(lit("baz"))); } diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index 1e7efcafd04d..e275487c4574 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -83,7 +83,7 @@ fn rewrite_in_terms_of_projection( ) -> Result { // assumption is that each item in exprs, such as "b + c" is // available as an output column named "b + c" - expr.transform_up(&|expr| { + expr.transform_up_old(&|expr| { // search for unnormalized names first such as "c1" (such as aliases) if let Some(found) = 
proj_exprs.iter().find(|a| (**a) == expr) { let col = Expr::Column( diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 3d8a8356f397..013a7b673265 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -42,8 +42,7 @@ use crate::{ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeRecursion, TreeNodeRecursionResult, - TreeNodeTransformer, VisitRecursionIterator, + Transformed, TreeNode, TreeNodeRecursion, TreeNodeTransformer, VisitRecursionIterator, }; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, @@ -316,7 +315,7 @@ impl LogicalPlan { where F: FnMut(&Expr) -> Result, { - let f = &mut |e: &Expr| f(e).fail_on_prune(); + let f = &mut |e: &Expr| f(e)?.fail_on_prune(); match self { LogicalPlan::Projection(Projection { expr, .. }) => { @@ -352,7 +351,7 @@ impl LogicalPlan { on.iter() // it not ideal to create an expr here to analyze them, but could cache it on the Join itself .map(|(l, r)| Expr::eq(l.clone(), r.clone())) - .for_each_till_continue(&mut |e| f(&e)) + .for_each_till_continue(&mut |e| f(&e))? .and_then_on_continue(|| filter.iter().for_each_till_continue(f)) } LogicalPlan::Sort(Sort { expr, .. 
}) => expr.iter().for_each_till_continue(f), @@ -1154,7 +1153,7 @@ impl LogicalPlan { // LogicalPlan::Subquery (even though it is // actually a Subquery alias) let synthetic_plan = LogicalPlan::Subquery(subquery.clone()); - f(&synthetic_plan).fail_on_prune() + f(&synthetic_plan)?.fail_on_prune() } _ => Ok(TreeNodeRecursion::Continue), }) @@ -1225,7 +1224,7 @@ impl LogicalPlan { expr: Expr, param_values: &ParamValues, ) -> Result { - expr.transform_up(&|expr| { + expr.transform_up_old(&|expr| { match &expr { Expr::Placeholder(Placeholder { id, data_type }) => { let value = @@ -3195,7 +3194,7 @@ digraph { // after transformation, because plan is not the same anymore, // the parent plan is built again with call to LogicalPlan::with_new_inputs -> with_new_exprs let plan = plan - .transform_up(&|plan| match plan { + .transform_up_old(&|plan| match plan { LogicalPlan::TableScan(table) => { let filter = Filter::try_new( external_filter.clone(), diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 8ec4a94204b0..de407063d78f 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -24,9 +24,7 @@ use crate::expr::{ }; use crate::{Expr, GetFieldAccess}; -use datafusion_common::tree_node::{ - TreeNode, TreeNodeRecursion, TreeNodeRecursionResult, VisitRecursionIterator, -}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, VisitRecursionIterator}; use datafusion_common::{internal_err, DataFusionError, Result}; impl TreeNode for Expr { @@ -51,12 +49,12 @@ impl TreeNode for Expr { | Expr::Sort(Sort { expr, .. }) | Expr::InSubquery(InSubquery{ expr, .. 
}) => f(expr), Expr::GetIndexedField(GetIndexedField { expr, field }) => { - f(expr).and_then_on_continue(|| match field { + f(expr)?.and_then_on_continue(|| match field { GetFieldAccess::ListIndex {key} => { f(key) }, GetFieldAccess::ListRange { start, stop} => { - f(start).and_then_on_continue(|| f(stop)) + f(start)?.and_then_on_continue(|| f(stop)) } GetFieldAccess::NamedStructField { name: _name } => Ok(TreeNodeRecursion::Continue) }) @@ -78,38 +76,38 @@ impl TreeNode for Expr { | Expr::Wildcard {..} | Expr::Placeholder (_) => Ok(TreeNodeRecursion::Continue), Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { - f(left) + f(left)? .and_then_on_continue(|| f(right)) } Expr::Like(Like { expr, pattern, .. }) | Expr::SimilarTo(Like { expr, pattern, .. }) => { - f(expr) + f(expr)? .and_then_on_continue(|| f(pattern)) } Expr::Between(Between { expr, low, high, .. }) => { - f(expr) - .and_then_on_continue(|| f(low)) + f(expr)? + .and_then_on_continue(|| f(low))? .and_then_on_continue(|| f(high)) }, Expr::Case( Case { expr, when_then_expr, else_expr }) => { - expr.as_deref().into_iter().for_each_till_continue(f) + expr.as_deref().into_iter().for_each_till_continue(f)? .and_then_on_continue(|| - when_then_expr.iter().for_each_till_continue(&mut |(w, t)| f(w).and_then_on_continue(|| f(t)))) + when_then_expr.iter().for_each_till_continue(&mut |(w, t)| f(w)?.and_then_on_continue(|| f(t))))? .and_then_on_continue(|| else_expr.as_deref().into_iter().for_each_till_continue(f)) } Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) => { - args.iter().for_each_till_continue(f) - .and_then_on_continue(|| filter.as_deref().into_iter().for_each_till_continue(f)) + args.iter().for_each_till_continue(f)? + .and_then_on_continue(|| filter.as_deref().into_iter().for_each_till_continue(f))? .and_then_on_continue(|| order_by.iter().flatten().for_each_till_continue(f)) } Expr::WindowFunction(WindowFunction { args, partition_by, order_by, .. 
}) => { - args.iter().for_each_till_continue(f) - .and_then_on_continue(|| partition_by.iter().for_each_till_continue(f)) + args.iter().for_each_till_continue(f)? + .and_then_on_continue(|| partition_by.iter().for_each_till_continue(f))? .and_then_on_continue(|| order_by.iter().for_each_till_continue(f)) } Expr::InList(InList { expr, list, .. }) => { - f(expr) + f(expr)? .and_then_on_continue(|| list.iter().for_each_till_continue(f)) } } @@ -362,14 +360,17 @@ impl TreeNode for Expr { | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) | Expr::Sort(Sort { expr, .. }) - | Expr::InSubquery(InSubquery{ expr, .. }) => f(expr), + | Expr::InSubquery(InSubquery{ expr, .. }) => { + let x = expr; + f(x) + } Expr::GetIndexedField(GetIndexedField { expr, field }) => { - f(expr).and_then_on_continue(|| match field { + f(expr)?.and_then_on_continue(|| match field { GetFieldAccess::ListIndex {key} => { f(key) }, GetFieldAccess::ListRange { start, stop} => { - f(start).and_then_on_continue(|| f(stop)) + f(start)?.and_then_on_continue(|| f(stop)) } GetFieldAccess::NamedStructField { name: _name } => Ok(TreeNodeRecursion::Continue) }) @@ -391,37 +392,37 @@ impl TreeNode for Expr { | Expr::Wildcard {..} | Expr::Placeholder (_) => Ok(TreeNodeRecursion::Continue), Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { - f(left) + f(left)? .and_then_on_continue(|| f(right)) } Expr::Like(Like { expr, pattern, .. }) | Expr::SimilarTo(Like { expr, pattern, .. }) => { - f(expr) + f(expr)? .and_then_on_continue(|| f(pattern)) } Expr::Between(Between { expr, low, high, .. }) => { - f(expr) - .and_then_on_continue(|| f(low)) + f(expr)? + .and_then_on_continue(|| f(low))? .and_then_on_continue(|| f(high)) }, Expr::Case( Case { expr, when_then_expr, else_expr }) => { - expr.as_deref_mut().into_iter().for_each_till_continue(f) + expr.as_deref_mut().into_iter().for_each_till_continue(f)? 
.and_then_on_continue(|| - when_then_expr.iter_mut().for_each_till_continue(&mut |(w, t)| f(w).and_then_on_continue(|| f(t)))) + when_then_expr.iter_mut().for_each_till_continue(&mut |(w, t)| f(w)?.and_then_on_continue(|| f(t))))? .and_then_on_continue(|| else_expr.as_deref_mut().into_iter().for_each_till_continue(f)) } Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) => { - args.iter_mut().for_each_till_continue(f) - .and_then_on_continue(|| filter.as_deref_mut().into_iter().for_each_till_continue(f)) + args.iter_mut().for_each_till_continue(f)? + .and_then_on_continue(|| filter.as_deref_mut().into_iter().for_each_till_continue(f))? .and_then_on_continue(|| order_by.iter_mut().flatten().for_each_till_continue(f)) } Expr::WindowFunction(WindowFunction { args, partition_by, order_by, .. }) => { - args.iter_mut().for_each_till_continue(f) - .and_then_on_continue(|| partition_by.iter_mut().for_each_till_continue(f)) + args.iter_mut().for_each_till_continue(f)? + .and_then_on_continue(|| partition_by.iter_mut().for_each_till_continue(f))? .and_then_on_continue(|| order_by.iter_mut().for_each_till_continue(f)) } Expr::InList(InList { expr, list, .. }) => { - f(expr) + f(expr)? 
.and_then_on_continue(|| list.iter_mut().for_each_till_continue(f)) } } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 17b1ad8cc73f..73309c1882dc 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -47,7 +47,7 @@ impl CountWildcardRule { impl AnalyzerRule for CountWildcardRule { fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - plan.transform_down(&analyze_internal) + plan.transform_down_old(&analyze_internal) } fn name(&self) -> &str { @@ -155,7 +155,7 @@ impl TreeNodeTransformer for CountWildcardRewriter { let new_plan = subquery .as_ref() .clone() - .transform_down(&analyze_internal)?; + .transform_down_old(&analyze_internal)?; *subquery = Arc::new(new_plan); } Expr::InSubquery(InSubquery { @@ -165,7 +165,7 @@ impl TreeNodeTransformer for CountWildcardRewriter { let new_plan = subquery .as_ref() .clone() - .transform_down(&analyze_internal)?; + .transform_down_old(&analyze_internal)?; *subquery = Arc::new(new_plan); } Expr::Exists(Exists { @@ -175,7 +175,7 @@ impl TreeNodeTransformer for CountWildcardRewriter { let new_plan = subquery .as_ref() .clone() - .transform_down(&analyze_internal)?; + .transform_down_old(&analyze_internal)?; *subquery = Arc::new(new_plan); } _ => {} diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index a418fbf5537b..f2e00ed0763d 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -42,7 +42,7 @@ impl InlineTableScan { impl AnalyzerRule for InlineTableScan { fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - plan.transform_up(&analyze_internal) + plan.transform_up_old(&analyze_internal) } fn name(&self) -> &str { @@ -74,7 +74,7 @@ fn analyze_internal(plan: LogicalPlan) -> 
Result> { Transformed::Yes(plan) } LogicalPlan::Filter(filter) => { - let new_expr = filter.predicate.transform_up(&rewrite_subquery)?; + let new_expr = filter.predicate.transform_up_old(&rewrite_subquery)?; Transformed::Yes(LogicalPlan::Filter(Filter::try_new( new_expr, filter.input, @@ -88,7 +88,7 @@ fn rewrite_subquery(expr: Expr) -> Result> { match expr { Expr::Exists(Exists { subquery, negated }) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?; + let new_plan = plan.transform_up_old(&analyze_internal)?; let subquery = subquery.with_plan(Arc::new(new_plan)); Ok(Transformed::Yes(Expr::Exists(Exists { subquery, negated }))) } @@ -98,7 +98,7 @@ fn rewrite_subquery(expr: Expr) -> Result> { negated, }) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?; + let new_plan = plan.transform_up_old(&analyze_internal)?; let subquery = subquery.with_plan(Arc::new(new_plan)); Ok(Transformed::Yes(Expr::InSubquery(InSubquery::new( expr, subquery, negated, @@ -106,7 +106,7 @@ fn rewrite_subquery(expr: Expr) -> Result> { } Expr::ScalarSubquery(subquery) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?; + let new_plan = plan.transform_up_old(&analyze_internal)?; let subquery = subquery.with_plan(Arc::new(new_plan)); Ok(Transformed::Yes(Expr::ScalarSubquery(subquery))) } diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index a68f374d9fe6..3df604a62c8a 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -370,7 +370,7 @@ fn agg_exprs_evaluation_result_on_empty_batch( expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result<()> { for e in agg_expr.iter() { - let mut result_expr = e.clone().transform_up(&|expr| { + let mut result_expr = e.clone().transform_up_old(&|expr| { let new_expr = match expr { 
Expr::AggregateFunction(expr::AggregateFunction { func_def, .. }) => { match func_def { @@ -415,7 +415,7 @@ fn proj_exprs_evaluation_result_on_empty_batch( expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result<()> { for expr in proj_expr.iter() { - let result_expr = expr.clone().transform_up(&|expr| { + let result_expr = expr.clone().transform_up_old(&|expr| { if let Expr::Column(Column { name, .. }) = &expr { if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { Ok(Transformed::Yes(result_expr.clone())) @@ -448,7 +448,7 @@ fn filter_exprs_evaluation_result_on_empty_batch( input_expr_result_map_for_count_bug: &ExprResultMap, expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result> { - let result_expr = filter_expr.clone().transform_up(&|expr| { + let result_expr = filter_expr.clone().transform_up_old(&|expr| { if let Expr::Column(Column { name, .. }) = &expr { if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { Ok(Transformed::Yes(result_expr.clone())) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index cb9c06154bad..7a58944e5ac9 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -961,7 +961,7 @@ pub fn replace_cols_by_name( e: Expr, replace_map: &HashMap, ) -> Result { - e.transform_up(&|expr| { + e.transform_up_old(&|expr| { Ok(if let Expr::Column(c) = &expr { match replace_map.get(&c.flat_name()) { Some(new_c) => Transformed::Yes(new_c.clone()), diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 34ed4a9475cb..378e01f4bfa9 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -87,7 +87,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { { if !expr_check_map.is_empty() { rewrite_expr = - rewrite_expr.clone().transform_up(&|expr| { + 
rewrite_expr.clone().transform_up_old(&|expr| { if let Expr::Column(col) = &expr { if let Some(map_expr) = expr_check_map.get(&col.name) @@ -141,8 +141,9 @@ impl OptimizerRule for ScalarSubqueryToJoin { if let Some(rewrite_expr) = expr_to_rewrite_expr_map.get(expr) { - let new_expr = - rewrite_expr.clone().transform_up(&|expr| { + let new_expr = rewrite_expr + .clone() + .transform_up_old(&|expr| { if let Expr::Column(col) = &expr { if let Some(map_expr) = expr_check_map.get(&col.name) diff --git a/datafusion/physical-expr/src/equivalence.rs b/datafusion/physical-expr/src/equivalence.rs index 4899d69bad58..78e232a6bb59 100644 --- a/datafusion/physical-expr/src/equivalence.rs +++ b/datafusion/physical-expr/src/equivalence.rs @@ -30,7 +30,7 @@ use crate::{ use arrow::datatypes::SchemaRef; use arrow_schema::SortOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::{JoinSide, JoinType, Result}; use indexmap::IndexSet; @@ -169,10 +169,10 @@ impl ProjectionMapping { .enumerate() .map(|(expr_idx, (expression, name))| { let target_expr = Arc::new(Column::new(name, expr_idx)) as _; - expression - .clone() - .transform_down(&|e| match e.as_any().downcast_ref::() { - Some(col) => { + let mut source_expr = expression.clone(); + source_expr + .transform_down(&mut |e| { + if let Some(col) = e.as_any().downcast_ref::() { // Sometimes, an expression and its name in the input_schema // doesn't match. This can cause problems, so we make sure // that the expression name matches with the name in `input_schema`. 
@@ -181,11 +181,11 @@ impl ProjectionMapping { let matching_input_field = input_schema.field(idx); let matching_input_column = Column::new(matching_input_field.name(), idx); - Ok(Transformed::Yes(Arc::new(matching_input_column))) + *e = Arc::new(matching_input_column) } - None => Ok(Transformed::No(e)), + Ok(TreeNodeRecursion::Continue) }) - .map(|source_expr| (source_expr, target_expr)) + .map(|_| (source_expr, target_expr)) }) .collect::>>() .map(|map| Self { map }) @@ -352,7 +352,7 @@ impl EquivalenceGroup { /// class it matches with (if any). pub fn normalize_expr(&self, expr: Arc) -> Arc { expr.clone() - .transform_up(&|expr| { + .transform_up_old(&|expr| { for cls in self.iter() { if cls.contains(&expr) { return Ok(Transformed::Yes(cls.canonical_expr().unwrap())); @@ -752,7 +752,7 @@ pub fn add_offset_to_expr( expr: Arc, offset: usize, ) -> Arc { - expr.transform_down(&|e| match e.as_any().downcast_ref::() { + expr.transform_down_old(&|e| match e.as_any().downcast_ref::() { Some(col) => Ok(Transformed::Yes(Arc::new(Column::new( col.name(), offset + col.index(), @@ -1517,7 +1517,7 @@ impl EquivalenceProperties { /// the given expression. pub fn get_expr_ordering(&self, expr: Arc) -> ExprOrdering { ExprOrdering::new(expr.clone()) - .transform_up(&|expr| Ok(update_ordering(expr, self))) + .transform_up_old(&|expr| Ok(update_ordering(expr, self))) // Guaranteed to always return `Ok`. 
.unwrap() } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index d637cf1e54e6..8958c71c585e 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -944,7 +944,7 @@ mod tests { let expr2 = expr .clone() - .transform_up(&|e| { + .transform_up_old(&|e| { let transformed = match e.as_any().downcast_ref::() { Some(lit_value) => match lit_value.value() { @@ -965,7 +965,7 @@ mod tests { let expr3 = expr .clone() - .transform_down(&|e| { + .transform_down_old(&|e| { let transformed = match e.as_any().downcast_ref::() { Some(lit_value) => match lit_value.value() { diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index cf4d0a077e6f..d07a960b0f71 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -282,7 +282,7 @@ pub fn reassign_predicate_columns( schema: &SchemaRef, ignore_not_found: bool, ) -> Result> { - pred.transform_down(&|expr| { + pred.transform_down_old(&|expr| { let expr_any = expr.as_any(); if let Some(column) = expr_any.downcast_ref::() { diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 64a976a1e39f..5be90a6e3bed 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -281,7 +281,7 @@ pub fn convert_sort_expr_with_filter_schema( if all_columns_are_included { // Since we are sure that one to one column mapping includes all columns, we convert // the sort expression into a filter expression. 
- let converted_filter_expr = expr.transform_up(&|p| { + let converted_filter_expr = expr.transform_up_old(&|p| { convert_filter_columns(p.as_ref(), &column_map).map(|transformed| { match transformed { Some(transformed) => Transformed::Yes(transformed), diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 616a2fc74932..76c270ffe463 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -33,7 +33,7 @@ use std::collections::HashMap; /// Make a best-effort attempt at resolving all columns in the expression tree pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result { - expr.clone().transform_up(&|nested_expr| { + expr.clone().transform_up_old(&|nested_expr| { match nested_expr { Expr::Column(col) => { let field = plan.schema().field_from_column(&col)?; @@ -66,7 +66,7 @@ pub(crate) fn rebase_expr( base_exprs: &[Expr], plan: &LogicalPlan, ) -> Result { - expr.clone().transform_up(&|nested_expr| { + expr.clone().transform_up_old(&|nested_expr| { if base_exprs.contains(&nested_expr) { Ok(Transformed::Yes(expr_as_column_expr(&nested_expr, plan)?)) } else { @@ -170,16 +170,17 @@ pub(crate) fn resolve_aliases_to_exprs( expr: &Expr, aliases: &HashMap, ) -> Result { - expr.clone().transform_up(&|nested_expr| match nested_expr { - Expr::Column(c) if c.relation.is_none() => { - if let Some(aliased_expr) = aliases.get(&c.name) { - Ok(Transformed::Yes(aliased_expr.clone())) - } else { - Ok(Transformed::No(Expr::Column(c))) + expr.clone() + .transform_up_old(&|nested_expr| match nested_expr { + Expr::Column(c) if c.relation.is_none() => { + if let Some(aliased_expr) = aliases.get(&c.name) { + Ok(Transformed::Yes(aliased_expr.clone())) + } else { + Ok(Transformed::No(Expr::Column(c))) + } } - } - _ => Ok(Transformed::No(nested_expr)), - }) + _ => Ok(Transformed::No(nested_expr)), + }) } /// given a slice of window expressions sharing the same sort key, find their common partition From 
8fa80e70d3cfe22335f3e5a9c4a5b91aab6d14ea Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 19 Dec 2023 16:40:34 +0100 Subject: [PATCH 3/9] - refactor `EnforceDistribution` using `transform_down_with_payload()` --- .../enforce_distribution.rs | 251 ++++++------------ 1 file changed, 88 insertions(+), 163 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index b54ec2d6a7f0..1d345bbfaeaf 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -191,17 +191,15 @@ impl EnforceDistribution { impl PhysicalOptimizerRule for EnforceDistribution { fn optimize( &self, - plan: Arc, + mut plan: Arc, config: &ConfigOptions, ) -> Result> { let top_down_join_key_reordering = config.optimizer.top_down_join_key_reordering; let adjusted = if top_down_join_key_reordering { // Run a top-down process to adjust input key ordering recursively - let plan_requirements = PlanWithKeyRequirements::new(plan); - let adjusted = - plan_requirements.transform_down_old(&adjust_input_keys_ordering)?; - adjusted.plan + plan.transform_down_with_payload(&mut adjust_input_keys_ordering, None)?; + plan } else { // Run a bottom-up process plan.transform_up_old(&|plan| { @@ -269,12 +267,15 @@ impl PhysicalOptimizerRule for EnforceDistribution { /// 4) If the current plan is Projection, transform the requirements to the columns before the Projection and push down requirements /// 5) For other types of operators, by default, pushdown the parent requirements to children. 
/// +type RequiredKeyOrdering = Option>>; + fn adjust_input_keys_ordering( - requirements: PlanWithKeyRequirements, -) -> Result> { - let parent_required = requirements.required_key_ordering.clone(); - let plan_any = requirements.plan.as_any(); - let transformed = if let Some(HashJoinExec { + plan: &mut Arc, + required_key_ordering: RequiredKeyOrdering, +) -> Result<(TreeNodeRecursion, Vec)> { + let parent_required = required_key_ordering.unwrap_or_default().clone(); + let plan_any = plan.as_any(); + if let Some(HashJoinExec { left, right, on, @@ -299,13 +300,15 @@ fn adjust_input_keys_ordering( *null_equals_null, )?) as Arc) }; - Some(reorder_partitioned_join_keys( - requirements.plan.clone(), + let (new_plan, request_key_ordering) = reorder_partitioned_join_keys( + plan.clone(), &parent_required, on, vec![], &join_constructor, - )?) + )?; + *plan = new_plan; + Ok((TreeNodeRecursion::Continue, request_key_ordering)) } PartitionMode::CollectLeft => { let new_right_request = match join_type { @@ -323,15 +326,14 @@ fn adjust_input_keys_ordering( }; // Push down requirements to the right side - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![None, new_right_request], - }) + Ok((TreeNodeRecursion::Continue, vec![None, new_right_request])) } PartitionMode::Auto => { // Can not satisfy, clear the current requirements and generate new empty requirements - Some(PlanWithKeyRequirements::new(requirements.plan.clone())) + Ok(( + TreeNodeRecursion::Continue, + vec![None; plan.children().len()], + )) } } } else if let Some(CrossJoinExec { left, .. 
}) = @@ -339,14 +341,13 @@ fn adjust_input_keys_ordering( { let left_columns_len = left.schema().fields().len(); // Push down requirements to the right side - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![ + Ok(( + TreeNodeRecursion::Continue, + vec![ None, shift_right_required(&parent_required, left_columns_len), ], - }) + )) } else if let Some(SortMergeJoinExec { left, right, @@ -368,26 +369,38 @@ fn adjust_input_keys_ordering( *null_equals_null, )?) as Arc) }; - Some(reorder_partitioned_join_keys( - requirements.plan.clone(), + let (new_plan, request_key_ordering) = reorder_partitioned_join_keys( + plan.clone(), &parent_required, on, sort_options.clone(), &join_constructor, - )?) + )?; + *plan = new_plan; + Ok((TreeNodeRecursion::Continue, request_key_ordering)) } else if let Some(aggregate_exec) = plan_any.downcast_ref::() { if !parent_required.is_empty() { match aggregate_exec.mode() { - AggregateMode::FinalPartitioned => Some(reorder_aggregate_keys( - requirements.plan.clone(), - &parent_required, - aggregate_exec, - )?), - _ => Some(PlanWithKeyRequirements::new(requirements.plan.clone())), + AggregateMode::FinalPartitioned => { + let (new_plan, request_key_ordering) = reorder_aggregate_keys( + plan.clone(), + &parent_required, + aggregate_exec, + )?; + *plan = new_plan; + Ok((TreeNodeRecursion::Continue, request_key_ordering)) + } + _ => Ok(( + TreeNodeRecursion::Continue, + vec![None; plan.children().len()], + )), } } else { // Keep everything unchanged - None + Ok(( + TreeNodeRecursion::Continue, + vec![None; plan.children().len()], + )) } } else if let Some(proj) = plan_any.downcast_ref::() { let expr = proj.expr(); @@ -396,34 +409,33 @@ fn adjust_input_keys_ordering( // Construct a mapping from new name to the the orginal Column let new_required = map_columns_before_projection(&parent_required, expr); if new_required.len() == parent_required.len() { - 
Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![Some(new_required.clone())], - }) + Ok(( + TreeNodeRecursion::Continue, + vec![Some(new_required.clone())], + )) } else { // Can not satisfy, clear the current requirements and generate new empty requirements - Some(PlanWithKeyRequirements::new(requirements.plan.clone())) + Ok(( + TreeNodeRecursion::Continue, + vec![None; plan.children().len()], + )) } } else if plan_any.downcast_ref::().is_some() || plan_any.downcast_ref::().is_some() || plan_any.downcast_ref::().is_some() { - Some(PlanWithKeyRequirements::new(requirements.plan.clone())) + Ok(( + TreeNodeRecursion::Continue, + vec![None; plan.children().len()], + )) } else { // By default, push down the parent requirements to children - let children_len = requirements.plan.children().len(); - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![Some(parent_required.clone()); children_len], - }) - }; - Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) - } else { - Transformed::No(requirements) - }) + let children_len = plan.children().len(); + Ok(( + TreeNodeRecursion::Continue, + vec![Some(parent_required.clone()); children_len], + )) + } } fn reorder_partitioned_join_keys( @@ -432,7 +444,7 @@ fn reorder_partitioned_join_keys( on: &[(Column, Column)], sort_options: Vec, join_constructor: &F, -) -> Result +) -> Result<(Arc, Vec)> where F: Fn((Vec<(Column, Column)>, Vec)) -> Result>, { @@ -455,27 +467,21 @@ where new_sort_options.push(sort_options[new_positions[idx]]) } - Ok(PlanWithKeyRequirements { - plan: join_constructor((new_join_on, new_sort_options))?, - required_key_ordering: vec![], - request_key_ordering: vec![Some(left_keys), Some(right_keys)], - }) + Ok(( + join_constructor((new_join_on, new_sort_options))?, + vec![Some(left_keys), Some(right_keys)], + )) } else { - 
Ok(PlanWithKeyRequirements { - plan: join_plan, - required_key_ordering: vec![], - request_key_ordering: vec![Some(left_keys), Some(right_keys)], - }) + Ok((join_plan, vec![Some(left_keys), Some(right_keys)])) } } else { - Ok(PlanWithKeyRequirements { - plan: join_plan, - required_key_ordering: vec![], - request_key_ordering: vec![ + Ok(( + join_plan, + vec![ Some(join_key_pairs.left_keys), Some(join_key_pairs.right_keys), ], - }) + )) } } @@ -483,7 +489,7 @@ fn reorder_aggregate_keys( agg_plan: Arc, parent_required: &[Arc], agg_exec: &AggregateExec, -) -> Result { +) -> Result<(Arc, Vec)> { let output_columns = agg_exec .group_by() .expr() @@ -501,11 +507,15 @@ fn reorder_aggregate_keys( || !agg_exec.group_by().null_expr().is_empty() || physical_exprs_equal(&output_exprs, parent_required) { - Ok(PlanWithKeyRequirements::new(agg_plan)) + let request_key_ordering = vec![None; agg_plan.children().len()]; + Ok((agg_plan, request_key_ordering)) } else { let new_positions = expected_expr_positions(&output_exprs, parent_required); match new_positions { - None => Ok(PlanWithKeyRequirements::new(agg_plan)), + None => { + let request_key_ordering = vec![None; agg_plan.children().len()]; + Ok((agg_plan, request_key_ordering)) + } Some(positions) => { let new_partial_agg = if let Some(agg_exec) = agg_exec.input().as_any().downcast_ref::() @@ -577,11 +587,13 @@ fn reorder_aggregate_keys( .push((Arc::new(Column::new(name, idx)) as _, name.clone())) } // TODO merge adjacent Projections if there are - Ok(PlanWithKeyRequirements::new(Arc::new( - ProjectionExec::try_new(proj_exprs, new_final_agg)?, - ))) + let new_plan = + Arc::new(ProjectionExec::try_new(proj_exprs, new_final_agg)?); + let request_key_ordering = vec![None; new_plan.children().len()]; + Ok((new_plan, request_key_ordering)) } else { - Ok(PlanWithKeyRequirements::new(agg_plan)) + let request_key_ordering = vec![None; agg_plan.children().len()]; + Ok((agg_plan, request_key_ordering)) } } } @@ -1539,93 +1551,6 @@ 
struct JoinKeyPairs { right_keys: Vec>, } -#[derive(Debug, Clone)] -struct PlanWithKeyRequirements { - plan: Arc, - /// Parent required key ordering - required_key_ordering: Vec>, - /// The request key ordering to children - request_key_ordering: Vec>>>, -} - -impl PlanWithKeyRequirements { - fn new(plan: Arc) -> Self { - let children_len = plan.children().len(); - PlanWithKeyRequirements { - plan, - required_key_ordering: vec![], - request_key_ordering: vec![None; children_len], - } - } - - fn children(&self) -> Vec { - let plan_children = self.plan.children(); - assert_eq!(plan_children.len(), self.request_key_ordering.len()); - plan_children - .into_iter() - .zip(self.request_key_ordering.clone()) - .map(|(child, required)| { - let from_parent = required.unwrap_or_default(); - let length = child.children().len(); - PlanWithKeyRequirements { - plan: child, - required_key_ordering: from_parent, - request_key_ordering: vec![None; length], - } - }) - .collect() - } -} - -impl TreeNode for PlanWithKeyRequirements { - fn apply_children(&self, f: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - self.children().iter().for_each_till_continue(f) - } - - fn map_children(self, transform: F) -> Result - where - F: FnMut(Self) -> Result, - { - let children = self.children(); - if !children.is_empty() { - let new_children: Result> = - children.into_iter().map(transform).collect(); - - let children_plans = new_children? 
- .into_iter() - .map(|child| child.plan) - .collect::>(); - let new_plan = with_new_children_if_necessary(self.plan, children_plans)?; - Ok(PlanWithKeyRequirements { - plan: new_plan.into(), - required_key_ordering: self.required_key_ordering, - request_key_ordering: self.request_key_ordering, - }) - } else { - Ok(self) - } - } - - fn transform_children(&mut self, f: &mut F) -> Result - where - F: FnMut(&mut Self) -> Result, - { - let mut children = self.children(); - if !children.is_empty() { - let tnr = children.iter_mut().for_each_till_continue(f)?; - let children_plans = children.into_iter().map(|c| c.plan).collect(); - self.plan = - with_new_children_if_necessary(self.plan.clone(), children_plans)?.into(); - Ok(tnr) - } else { - Ok(TreeNodeRecursion::Continue) - } - } -} - /// Since almost all of these tests explicitly use `ParquetExec` they only run with the parquet feature flag on #[cfg(feature = "parquet")] #[cfg(test)] From f4d28e0d7a340c037e3e9729cd788a28149cfb1e Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 20 Dec 2023 15:47:22 +0100 Subject: [PATCH 4/9] make `TreeNode` methods naming consistent, after this change we have `visit()`, `visit_down()`, `transform()`, `transform_down()`, `transform_up()`, `transform_with_payload()`, `transform_down_with_payload()` and `transform_up_with_payload()` functions on `TreeNode`, others can be deprecated and removed once no longer used --- datafusion/common/src/tree_node.rs | 20 +++++++++---------- .../enforce_distribution.rs | 2 +- .../src/physical_optimizer/enforce_sorting.rs | 4 ++-- .../physical_optimizer/pipeline_checker.rs | 2 +- .../replace_with_order_preserving_variants.rs | 2 +- datafusion/expr/src/logical_plan/plan.rs | 12 +++++------ datafusion/expr/src/tree_node/expr.rs | 2 +- datafusion/expr/src/tree_node/plan.rs | 6 +++--- datafusion/optimizer/src/analyzer/mod.rs | 2 +- datafusion/optimizer/src/analyzer/subquery.rs | 10 +++++----- .../physical-expr/src/sort_properties.rs | 2 +-
datafusion/physical-expr/src/utils/mod.rs | 2 +- 12 files changed, 33 insertions(+), 33 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index e15c43543a22..1cc73bdd1e0d 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -23,8 +23,8 @@ use std::sync::Arc; use crate::Result; /// Defines a tree node that can have children of the same type as the parent node. The -/// implementations must provide [`TreeNode::apply_children()`] and -/// [`TreeNode::map_children()`] for visiting and changing the structure of the tree. +/// implementations must provide [`TreeNode::visit_children()`] and +/// [`TreeNode::transform_children()`] for visiting and changing the structure of the tree. /// /// [`TreeNode`] is implemented for plans ([`ExecutionPlan`] and [`LogicalPlan`]) as well /// as expression trees ([`PhysicalExpr`], [`Expr`]) in DataFusion. @@ -38,7 +38,7 @@ use crate::Result; /// nodes of these subquery plans are the inner children of the containing query plan /// node. /// -/// Tree node implementations can provide [`TreeNode::apply_inner_children()`] for +/// Tree node implementations can provide [`TreeNode::visit_inner_children()`] for /// visiting the structure of the inner tree. /// /// @@ -64,10 +64,10 @@ pub trait TreeNode: Sized { // Run the recursive `apply` on each inner children, but as they are // unrelated root nodes of inner trees if any returns stop then continue // with the next one. - self.apply_inner_children(&mut |c| c.visit_down(f)?.continue_on_stop())? + self.visit_inner_children(&mut |c| c.visit_down(f)?.continue_on_stop())? // Run the recursive `apply` on each children. .and_then_on_continue(|| { - self.apply_children(&mut |c| c.visit_down(f)) + self.visit_children(&mut |c| c.visit_down(f)) }) })? 
// Applying `f` on self might have returned prune, but we need to propagate @@ -114,10 +114,10 @@ pub trait TreeNode: Sized { // Run the recursive `visit` on each inner children, but as they are // unrelated subquery plans if any returns stop then continue with the // next one. - self.apply_inner_children(&mut |c| c.visit(visitor)?.continue_on_stop())? + self.visit_inner_children(&mut |c| c.visit(visitor)?.continue_on_stop())? // Run the recursive `visit` on each children. .and_then_on_continue(|| { - self.apply_children(&mut |c| c.visit(visitor)) + self.visit_children(&mut |c| c.visit(visitor)) })? // Apply `post_visit` on self. .and_then_on_continue(|| visitor.post_visit(self)) @@ -357,12 +357,12 @@ pub trait TreeNode: Sized { } /// Apply `f` to the node's children. - fn apply_children(&self, f: &mut F) -> Result + fn visit_children(&self, f: &mut F) -> Result where F: FnMut(&Self) -> Result; /// Apply `f` to the node's inner children. - fn apply_inner_children(&self, _f: &mut F) -> Result + fn visit_inner_children(&self, _f: &mut F) -> Result where F: FnMut(&Self) -> Result, { @@ -610,7 +610,7 @@ pub trait DynTreeNode { /// Blanket implementation for Arc for any tye that implements /// [`DynTreeNode`] (such as [`Arc`]) impl TreeNode for Arc { - fn apply_children(&self, f: &mut F) -> Result + fn visit_children(&self, f: &mut F) -> Result where F: FnMut(&Self) -> Result, { diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 1d345bbfaeaf..766727496abb 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -1490,7 +1490,7 @@ impl DistributionContext { } impl TreeNode for DistributionContext { - fn apply_children(&self, f: &mut F) -> Result + fn visit_children(&self, f: &mut F) -> Result where F: FnMut(&Self) -> Result, { diff --git 
a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 7512f6e8aa2c..376ba3bc9f23 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -159,7 +159,7 @@ impl PlanWithCorrespondingSort { } impl TreeNode for PlanWithCorrespondingSort { - fn apply_children(&self, f: &mut F) -> Result + fn visit_children(&self, f: &mut F) -> Result where F: FnMut(&Self) -> Result, { @@ -283,7 +283,7 @@ impl PlanWithCorrespondingCoalescePartitions { } impl TreeNode for PlanWithCorrespondingCoalescePartitions { - fn apply_children(&self, f: &mut F) -> Result + fn visit_children(&self, f: &mut F) -> Result where F: FnMut(&Self) -> Result, { diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index 9176bab57656..3f07025d9eb0 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -96,7 +96,7 @@ impl PipelineStatePropagator { } impl TreeNode for PipelineStatePropagator { - fn apply_children(&self, f: &mut F) -> Result + fn visit_children(&self, f: &mut F) -> Result where F: FnMut(&Self) -> Result, { diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 4b001d67aca9..ed6c09346d7d 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -120,7 +120,7 @@ impl OrderPreservationContext { } impl TreeNode for OrderPreservationContext { - fn apply_children(&self, f: &mut F) -> Result + fn visit_children(&self, f: &mut F) -> Result where F: FnMut(&Self) -> Result, { diff --git a/datafusion/expr/src/logical_plan/plan.rs 
b/datafusion/expr/src/logical_plan/plan.rs index 013a7b673265..471e0a12d63d 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -275,7 +275,7 @@ impl LogicalPlan { /// children pub fn expressions(self: &LogicalPlan) -> Vec { let mut exprs = vec![]; - self.apply_expressions(&mut |e| { + self.visit_expressions(&mut |e| { exprs.push(e.clone()); Ok(TreeNodeRecursion::Continue) }) @@ -288,7 +288,7 @@ impl LogicalPlan { /// logical plan nodes and all its descendant nodes. pub fn all_out_ref_exprs(self: &LogicalPlan) -> Vec { let mut exprs = vec![]; - self.apply_expressions(&mut |e| { + self.visit_expressions(&mut |e| { find_out_reference_exprs(e).into_iter().for_each(|e| { if !exprs.contains(&e) { exprs.push(e) @@ -311,7 +311,7 @@ impl LogicalPlan { /// Apply `f` on expressions of the plan node. /// `f` is not allowed to return [`TreeNodeRecursion::Prune`]. - pub fn apply_expressions(&self, f: &mut F) -> Result + pub fn visit_expressions(&self, f: &mut F) -> Result where F: FnMut(&Expr) -> Result, { @@ -1140,11 +1140,11 @@ impl LogicalPlan { /// Apply `f` on the root nodes of subquery plans of the plan node. /// `f` is not allowed to return [`TreeNodeRecursion::Prune`]. - pub fn apply_subqueries(&self, f: &mut F) -> Result + pub fn visit_subqueries(&self, f: &mut F) -> Result where F: FnMut(&Self) -> Result, { - self.apply_expressions(&mut |e| { + self.visit_expressions(&mut |e| { e.visit_down(&mut |e| match e { Expr::Exists(Exists { subquery, .. }) | Expr::InSubquery(InSubquery { subquery, .. 
}) @@ -1194,7 +1194,7 @@ impl LogicalPlan { let mut param_types: HashMap> = HashMap::new(); self.visit_down(&mut |plan| { - plan.apply_expressions(&mut |expr| { + plan.visit_expressions(&mut |expr| { expr.visit_down(&mut |expr| { if let Expr::Placeholder(Placeholder { id, data_type }) = expr { let prev = param_types.get(id); diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index de407063d78f..e80a4cae7e58 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -28,7 +28,7 @@ use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, VisitRecursionIt use datafusion_common::{internal_err, DataFusionError, Result}; impl TreeNode for Expr { - fn apply_children(&self, f: &mut F) -> Result + fn visit_children(&self, f: &mut F) -> Result where F: FnMut(&Self) -> Result, { diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index e85294ea5f73..91d336cac498 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -22,18 +22,18 @@ use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, VisitRecursionIt use datafusion_common::Result; impl TreeNode for LogicalPlan { - fn apply_children(&self, f: &mut F) -> Result + fn visit_children(&self, f: &mut F) -> Result where F: FnMut(&Self) -> Result, { self.inputs().into_iter().for_each_till_continue(f) } - fn apply_inner_children(&self, f: &mut F) -> Result + fn visit_inner_children(&self, f: &mut F) -> Result where F: FnMut(&Self) -> Result, { - self.apply_subqueries(f) + self.visit_subqueries(f) } fn map_children(self, transform: F) -> Result diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 0b2c20db3957..25b2b1246062 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -117,7 +117,7 @@ impl Analyzer { /// Do necessary check and fail the invalid plan fn 
check_plan(plan: &LogicalPlan) -> Result<()> { plan.visit_down(&mut |plan: &LogicalPlan| { - plan.apply_expressions(&mut |e| { + plan.visit_expressions(&mut |e| { // recursively look for subqueries e.visit_down(&mut |e| { match e { diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index 78c630982d9b..d676be9a1087 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -144,7 +144,7 @@ fn check_inner_plan( // We want to support as many operators as possible inside the correlated subquery match inner_plan { LogicalPlan::Aggregate(_) => { - inner_plan.apply_children(&mut |plan| { + inner_plan.visit_children(&mut |plan| { check_inner_plan(plan, is_scalar, true, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; @@ -169,7 +169,7 @@ fn check_inner_plan( } LogicalPlan::Window(window) => { check_mixed_out_refer_in_window(window)?; - inner_plan.apply_children(&mut |plan| { + inner_plan.visit_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; @@ -186,7 +186,7 @@ fn check_inner_plan( | LogicalPlan::Values(_) | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) => { - inner_plan.apply_children(&mut |plan| { + inner_plan.visit_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; @@ -199,7 +199,7 @@ fn check_inner_plan( .. 
}) => match join_type { JoinType::Inner => { - inner_plan.apply_children(&mut |plan| { + inner_plan.visit_children(&mut |plan| { check_inner_plan( plan, is_scalar, @@ -219,7 +219,7 @@ fn check_inner_plan( check_inner_plan(right, is_scalar, is_aggregate, can_contain_outer_ref) } JoinType::Full => { - inner_plan.apply_children(&mut |plan| { + inner_plan.visit_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, false)?; Ok(TreeNodeRecursion::Continue) })?; diff --git a/datafusion/physical-expr/src/sort_properties.rs b/datafusion/physical-expr/src/sort_properties.rs index 7c61e14e345a..99e99e8bb657 100644 --- a/datafusion/physical-expr/src/sort_properties.rs +++ b/datafusion/physical-expr/src/sort_properties.rs @@ -174,7 +174,7 @@ impl ExprOrdering { } impl TreeNode for ExprOrdering { - fn apply_children(&self, f: &mut F) -> Result + fn visit_children(&self, f: &mut F) -> Result where F: FnMut(&Self) -> Result, { diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index d07a960b0f71..bfe23dcce74c 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -154,7 +154,7 @@ impl ExprTreeNode { } impl TreeNode for ExprTreeNode { - fn apply_children(&self, f: &mut F) -> Result + fn visit_children(&self, f: &mut F) -> Result where F: FnMut(&Self) -> Result, { From 9279c6a9daca35035a257cdcd0a8e82e3bbe777e Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 20 Dec 2023 15:49:35 +0100 Subject: [PATCH 5/9] fix `transform_with_payload()` to behave like `transform_down_with_payload()` in its pre-order transform (`f_down`) function --- datafusion/common/src/tree_node.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 1cc73bdd1e0d..6bfbd6dac7b5 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -150,11 +150,11 @@ pub trait 
TreeNode: Sized { fn transform_with_payload( &mut self, f_down: &mut FD, - payload_down: Option, + payload_down: PD, f_up: &mut FU, ) -> Result<(TreeNodeRecursion, Option)> where - FD: FnMut(&mut Self, Option) -> Result<(TreeNodeRecursion, Vec)>, + FD: FnMut(&mut Self, PD) -> Result<(TreeNodeRecursion, Vec)>, FU: FnMut(&mut Self, Vec) -> Result<(TreeNodeRecursion, PU)>, { // Apply `f_down` on self. @@ -168,7 +168,7 @@ pub trait TreeNode: Sized { let mut payload_up = vec![]; let tnr = self.transform_children(&mut |c| { let (tnr, p) = - c.transform_with_payload(f_down, new_payload_down_iter.next(), f_up)?; + c.transform_with_payload(f_down, new_payload_down_iter.next().unwrap(), f_up)?; p.into_iter().for_each(|p| payload_up.push(p)); Ok(tnr) })?; From 888228584d9dd95beae6644e648ee0bb2b34ba6e Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 20 Dec 2023 15:50:07 +0100 Subject: [PATCH 6/9] add docs --- datafusion/common/src/tree_node.rs | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 6bfbd6dac7b5..9898760eab36 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -428,7 +428,7 @@ pub trait TreeNode: Sized { } /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for -/// recursively walking [`TreeNode`]s. +/// recursively visiting [`TreeNode`]s. /// /// [`TreeNodeVisitor`] allows keeping the algorithms separate from the code to traverse /// the structure of the [`TreeNode`] tree and makes it easier to add new types of tree @@ -450,6 +450,14 @@ pub trait TreeNodeVisitor: Sized { fn post_visit(&mut self, _node: &Self::Node) -> Result; } +/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for +/// recursively transforming [`TreeNode`]s. 
+/// +/// When passed to [`TreeNode::transform()`], [`TreeNodeVisitor::pre_transform()`] and +/// [`TreeNodeVisitor::post_transform()`] are invoked recursively on an node tree. +/// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. +/// +/// If an [`Err`] result is returned, recursion is stopped immediately. pub trait TreeNodeTransformer: Sized { /// The node type which is visitable. type Node: TreeNode; @@ -492,8 +500,9 @@ pub enum RewriteRecursion { Skip, } -/// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::visit_down()`] and -/// [`TreeNode::visit()`]. +/// Controls how a [`TreeNode`] recursion should proceed for [`TreeNode::visit_down()`], +/// [`TreeNode::visit()`], [`TreeNode::transform_down()`], [`TreeNode::transform_up()`] +/// and [`TreeNode::transform()`]. #[derive(Debug)] pub enum TreeNodeRecursion { /// Continue the visit to the next node. @@ -515,6 +524,8 @@ pub enum TreeNodeRecursion { } impl TreeNodeRecursion { + /// Helper function to define behavior of a [`TreeNode`] recursion to continue with a + /// closure if the recursion so far resulted in [`TreeNodeRecursion::Continue`].
pub fn and_then_on_continue(self, f: F) -> Result where F: FnOnce() -> Result, @@ -525,7 +536,7 @@ impl TreeNodeRecursion { } } - pub fn continue_on_prune(self) -> Result { + fn continue_on_prune(self) -> Result { Ok(match self { TreeNodeRecursion::Prune => TreeNodeRecursion::Continue, o => o, @@ -539,7 +550,7 @@ impl TreeNodeRecursion { }) } - pub fn continue_on_stop(self) -> Result { + fn continue_on_stop(self) -> Result { Ok(match self { TreeNodeRecursion::Stop => TreeNodeRecursion::Continue, o => o, From 6cd5d39bfd4eb17adc139103f996fe834b26fbf2 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 20 Dec 2023 17:41:28 +0100 Subject: [PATCH 7/9] remove `Expr::Nop`, define `Expr::Literal(ScalarValue::Null)` as default for `Expr` --- datafusion/core/src/datasource/listing/helpers.rs | 3 +-- datafusion/core/src/physical_planner.rs | 3 --- datafusion/expr/src/expr.rs | 13 +++++++------ datafusion/expr/src/expr_schema.rs | 4 +--- datafusion/expr/src/tree_node/expr.rs | 3 --- datafusion/expr/src/utils.rs | 3 +-- datafusion/optimizer/src/push_down_filter.rs | 1 - .../src/simplify_expressions/expr_simplifier.rs | 3 +-- datafusion/proto/src/logical_plan/to_proto.rs | 4 ---- 9 files changed, 11 insertions(+), 26 deletions(-) diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 870bddbaaaa5..3852e79e5bb3 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -62,8 +62,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { Ok(TreeNodeRecursion::Stop) } } - Expr::Nop - | Expr::Literal(_) + Expr::Literal(_) | Expr::Alias(_) | Expr::OuterReferenceColumn(_, _) | Expr::ScalarVariable(_, _) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index fd9e81c1b752..e5816eb49ebb 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ 
-381,9 +381,6 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Expr::OuterReferenceColumn(_, _) => { internal_err!("Create physical name does not support OuterReferenceColumn") } - Expr::Nop => { - internal_err!("Create physical name does not support Nop expression") - } } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 93ec2f369b41..5618e11fc595 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -81,10 +81,8 @@ use std::{fmt, mem}; /// assert_eq!(binary_expr.op, Operator::Eq); /// } /// ``` -#[derive(Clone, PartialEq, Eq, Hash, Debug, Default)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub enum Expr { - #[default] - Nop, /// An expression with a specific name. Alias(Alias), /// A named reference to a qualified filed in a schema. @@ -181,6 +179,12 @@ pub enum Expr { OuterReferenceColumn(DataType, Column), } +impl Default for Expr { + fn default() -> Self { + Expr::Literal(ScalarValue::Null) + } +} + /// Alias expression #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Alias { @@ -786,7 +790,6 @@ impl Expr { /// Useful for non-rust based bindings pub fn variant_name(&self) -> &str { match self { - Expr::Nop { .. } => "Nop", Expr::AggregateFunction { .. } => "AggregateFunction", Expr::Alias(..) => "Alias", Expr::Between { .. } => "Between", @@ -1207,7 +1210,6 @@ macro_rules! expr_vec_fmt { impl fmt::Display for Expr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Expr::Nop => write!(f, "NOP"), Expr::Alias(Alias { expr, name, .. }) => write!(f, "{expr} AS {name}"), Expr::Column(c) => write!(f, "{c}"), Expr::OuterReferenceColumn(_, c) => write!(f, "outer_ref({c})"), @@ -1450,7 +1452,6 @@ fn create_function_name(fun: &str, distinct: bool, args: &[Expr]) -> Result 2)". fn create_name(e: &Expr) -> Result { match e { - Expr::Nop => Ok("NOP".to_string()), Expr::Alias(Alias { name, .. 
}) => Ok(name.clone()), Expr::Column(c) => Ok(c.flat_name()), Expr::OuterReferenceColumn(_, c) => Ok(format!("outer_ref({})", c.flat_name())), diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 71987667be4a..e5b0185d90e0 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -67,7 +67,6 @@ impl ExprSchemable for Expr { /// (e.g. `[utf8] + [bool]`). fn get_type(&self, schema: &S) -> Result { match self { - Expr::Nop => Ok(DataType::Null), Expr::Alias(Alias { expr, name, .. }) => match &**expr { Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type { None => schema.data_type(&Column::from_name(name)).cloned(), @@ -252,8 +251,7 @@ impl ExprSchemable for Expr { | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } | Expr::Placeholder(_) => Ok(true), - Expr::Nop - | Expr::IsNull(_) + Expr::IsNull(_) | Expr::IsNotNull(_) | Expr::IsTrue(_) | Expr::IsFalse(_) diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index e80a4cae7e58..bd451fb8d55a 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -65,7 +65,6 @@ impl TreeNode for Expr { Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { lists_of_exprs.iter().flatten().for_each_till_continue(f) } - Expr::Nop | Expr::Column(_) // Treat OuterReferenceColumn as a leaf expression | Expr::OuterReferenceColumn(_, _) @@ -120,7 +119,6 @@ impl TreeNode for Expr { let mut transform = transform; Ok(match self { - Expr::Nop => self, Expr::Alias(Alias { expr, relation, @@ -381,7 +379,6 @@ impl TreeNode for Expr { Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { lists_of_exprs.iter_mut().flatten().for_each_till_continue(f) } - Expr::Nop | Expr::Column(_) // Treat OuterReferenceColumn as a leaf expression | Expr::OuterReferenceColumn(_, _) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 
285119749f1a..522bb3ec8bca 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -270,8 +270,7 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { // Use explicit pattern match instead of a default // implementation, so that in the future if someone adds // new Expr types, they will check here as well - Expr::Nop - | Expr::ScalarVariable(_, _) + Expr::ScalarVariable(_, _) | Expr::Alias(_) | Expr::Literal(_) | Expr::BinaryExpr { .. } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 7a58944e5ac9..7e5b5b784e14 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -251,7 +251,6 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::ScalarFunction(..) | Expr::InList { .. } => Ok(TreeNodeRecursion::Continue), Expr::Sort(_) - | Expr::Nop | Expr::AggregateFunction(_) | Expr::WindowFunction(_) | Expr::Wildcard { .. } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index fe2e1345290b..fdad693f3463 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -331,8 +331,7 @@ impl<'a> ConstEvaluator<'a> { // at plan time match expr { // Has no runtime cost, but needed during planning - Expr::Nop - | Expr::Alias(..) + Expr::Alias(..) | Expr::AggregateFunction { .. 
} | Expr::ScalarVariable(_, _) | Expr::Column(_) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index f2d2779f99a8..2997d147424d 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -484,10 +484,6 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { use protobuf::logical_expr_node::ExprType; let expr_node = match expr { - Expr::Nop => Err(Error::NotImplemented( - "Proto serialization error: Trying to serialize a Nop expression" - .to_string(), - ))?, Expr::Column(c) => Self { expr_type: Some(ExprType::Column(c.into())), }, From 25b75bb1f90fb4e1f03df14a3a49bf761da28433 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 20 Dec 2023 18:02:13 +0100 Subject: [PATCH 8/9] fix docs --- datafusion/common/src/tree_node.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 682c85033b0b..acb00fde2fac 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -456,8 +456,8 @@ pub trait TreeNodeVisitor: Sized { /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for /// recursively transforming [`TreeNode`]s. /// -/// When passed to [`TreeNode::transform()`], [`TreeNodeVisitor::pre_transform()`] and -/// [`TreeNodeVisitor::post_transform()`] are invoked recursively on an node tree. +/// When passed to [`TreeNode::transform()`], [`TreeNodeTransformer::pre_transform()`] and +/// [`TreeNodeTransformer::post_transform()`] are invoked recursively on a node tree. /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. /// /// If an [`Err`] result is returned, recursion is stopped immediately.
From 9e13bea7afaa3357496108b067acf440227891b0 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 2 Jan 2024 11:28:59 +0100 Subject: [PATCH 9/9] revert `transform_with_payload`, `transform_down_with_payload` and `transform_up_with_payload` related changes --- datafusion/common/src/tree_node.rs | 89 ------- .../enforce_distribution.rs | 251 ++++++++++++------ .../src/physical_optimizer/enforce_sorting.rs | 65 +---- .../src/physical_optimizer/sort_pushdown.rs | 165 +++++++++++- 4 files changed, 330 insertions(+), 240 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index acb00fde2fac..e3187c041ea6 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -147,47 +147,6 @@ pub trait TreeNode: Sized { .continue_on_prune() } - fn transform_with_payload( - &mut self, - f_down: &mut FD, - payload_down: PD, - f_up: &mut FU, - ) -> Result<(TreeNodeRecursion, Option)> - where - FD: FnMut(&mut Self, PD) -> Result<(TreeNodeRecursion, Vec)>, - FU: FnMut(&mut Self, Vec) -> Result<(TreeNodeRecursion, PU)>, - { - // Apply `f_down` on self. - let (tnr, new_payload_down) = f_down(self, payload_down)?; - let mut new_payload_down_iter = new_payload_down.into_iter(); - // If it returns continue (not prune or stop or stop all) then continue traversal - // on inner children and children. - let mut new_payload_up = None; - tnr.and_then_on_continue(|| { - // Run the recursive `transform` on each children. - let mut payload_up = vec![]; - let tnr = self.transform_children(&mut |c| { - let (tnr, p) = c.transform_with_payload( - f_down, - new_payload_down_iter.next().unwrap(), - f_up, - )?; - p.into_iter().for_each(|p| payload_up.push(p)); - Ok(tnr) - })?; - // Apply `f_up` on self. - tnr.and_then_on_continue(|| { - let (tnr, np) = f_up(self, payload_up)?; - new_payload_up = Some(np); - Ok(tnr) - }) - })? 
- // Applying `f_down` or `f_up` on self might have returned prune, but we need to propagate - // continue. - .continue_on_prune() - .map(|tnr| (tnr, new_payload_up)) - } - fn transform_down(&mut self, f: &mut F) -> Result where F: FnMut(&mut Self) -> Result, @@ -204,27 +163,6 @@ pub trait TreeNode: Sized { .continue_on_prune() } - fn transform_down_with_payload( - &mut self, - f: &mut F, - payload: P, - ) -> Result - where - F: FnMut(&mut Self, P) -> Result<(TreeNodeRecursion, Vec

)>, - { - // Apply `f` on self. - let (tnr, new_payload) = f(self, payload)?; - let mut new_payload_iter = new_payload.into_iter(); - // If it returns continue (not prune or stop or stop all) then continue - // traversal on inner children and children. - tnr.and_then_on_continue(|| - // Run the recursive `transform` on each children. - self.transform_children(&mut |c| c.transform_down_with_payload(f, new_payload_iter.next().unwrap())))? - // Applying `f` on self might have returned prune, but we need to propagate - // continue. - .continue_on_prune() - } - fn transform_up(&mut self, f: &mut F) -> Result where F: FnMut(&mut Self) -> Result, @@ -238,33 +176,6 @@ pub trait TreeNode: Sized { .continue_on_prune() } - fn transform_up_with_payload( - &mut self, - f: &mut F, - ) -> Result<(TreeNodeRecursion, Option

)> - where - F: FnMut(&mut Self, Vec

) -> Result<(TreeNodeRecursion, P)>, - { - // Run the recursive `transform` on each children. - let mut payload = vec![]; - let tnr = self.transform_children(&mut |c| { - let (tnr, p) = c.transform_up_with_payload(f)?; - p.into_iter().for_each(|p| payload.push(p)); - Ok(tnr) - })?; - let mut new_payload = None; - // Apply `f` on self. - tnr.and_then_on_continue(|| { - let (tnr, np) = f(self, payload)?; - new_payload = Some(np); - Ok(tnr) - })? - // Applying `f` on self might have returned prune, but we need to propagate - // continue. - .continue_on_prune() - .map(|tnr| (tnr, new_payload)) - } - /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its /// children(Preorder Traversal). /// When the `op` does not apply to a given node, it is left unchanged. diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 766727496abb..7e0230fe4bf7 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -191,15 +191,17 @@ impl EnforceDistribution { impl PhysicalOptimizerRule for EnforceDistribution { fn optimize( &self, - mut plan: Arc, + plan: Arc, config: &ConfigOptions, ) -> Result> { let top_down_join_key_reordering = config.optimizer.top_down_join_key_reordering; let adjusted = if top_down_join_key_reordering { // Run a top-down process to adjust input key ordering recursively - plan.transform_down_with_payload(&mut adjust_input_keys_ordering, None)?; - plan + let plan_requirements = PlanWithKeyRequirements::new(plan); + let adjusted = + plan_requirements.transform_down_old(&adjust_input_keys_ordering)?; + adjusted.plan } else { // Run a bottom-up process plan.transform_up_old(&|plan| { @@ -267,15 +269,12 @@ impl PhysicalOptimizerRule for EnforceDistribution { /// 4) If the current plan is Projection, transform the requirements to the 
columns before the Projection and push down requirements /// 5) For other types of operators, by default, pushdown the parent requirements to children. /// -type RequiredKeyOrdering = Option>>; - fn adjust_input_keys_ordering( - plan: &mut Arc, - required_key_ordering: RequiredKeyOrdering, -) -> Result<(TreeNodeRecursion, Vec)> { - let parent_required = required_key_ordering.unwrap_or_default().clone(); - let plan_any = plan.as_any(); - if let Some(HashJoinExec { + requirements: PlanWithKeyRequirements, +) -> Result> { + let parent_required = requirements.required_key_ordering.clone(); + let plan_any = requirements.plan.as_any(); + let transformed = if let Some(HashJoinExec { left, right, on, @@ -300,15 +299,13 @@ fn adjust_input_keys_ordering( *null_equals_null, )?) as Arc) }; - let (new_plan, request_key_ordering) = reorder_partitioned_join_keys( - plan.clone(), + Some(reorder_partitioned_join_keys( + requirements.plan.clone(), &parent_required, on, vec![], &join_constructor, - )?; - *plan = new_plan; - Ok((TreeNodeRecursion::Continue, request_key_ordering)) + )?) } PartitionMode::CollectLeft => { let new_right_request = match join_type { @@ -326,14 +323,15 @@ fn adjust_input_keys_ordering( }; // Push down requirements to the right side - Ok((TreeNodeRecursion::Continue, vec![None, new_right_request])) + Some(PlanWithKeyRequirements { + plan: requirements.plan.clone(), + required_key_ordering: vec![], + request_key_ordering: vec![None, new_right_request], + }) } PartitionMode::Auto => { // Can not satisfy, clear the current requirements and generate new empty requirements - Ok(( - TreeNodeRecursion::Continue, - vec![None; plan.children().len()], - )) + Some(PlanWithKeyRequirements::new(requirements.plan.clone())) } } } else if let Some(CrossJoinExec { left, .. 
}) = @@ -341,13 +339,14 @@ fn adjust_input_keys_ordering( { let left_columns_len = left.schema().fields().len(); // Push down requirements to the right side - Ok(( - TreeNodeRecursion::Continue, - vec![ + Some(PlanWithKeyRequirements { + plan: requirements.plan.clone(), + required_key_ordering: vec![], + request_key_ordering: vec![ None, shift_right_required(&parent_required, left_columns_len), ], - )) + }) } else if let Some(SortMergeJoinExec { left, right, @@ -369,38 +368,26 @@ fn adjust_input_keys_ordering( *null_equals_null, )?) as Arc) }; - let (new_plan, request_key_ordering) = reorder_partitioned_join_keys( - plan.clone(), + Some(reorder_partitioned_join_keys( + requirements.plan.clone(), &parent_required, on, sort_options.clone(), &join_constructor, - )?; - *plan = new_plan; - Ok((TreeNodeRecursion::Continue, request_key_ordering)) + )?) } else if let Some(aggregate_exec) = plan_any.downcast_ref::() { if !parent_required.is_empty() { match aggregate_exec.mode() { - AggregateMode::FinalPartitioned => { - let (new_plan, request_key_ordering) = reorder_aggregate_keys( - plan.clone(), - &parent_required, - aggregate_exec, - )?; - *plan = new_plan; - Ok((TreeNodeRecursion::Continue, request_key_ordering)) - } - _ => Ok(( - TreeNodeRecursion::Continue, - vec![None; plan.children().len()], - )), + AggregateMode::FinalPartitioned => Some(reorder_aggregate_keys( + requirements.plan.clone(), + &parent_required, + aggregate_exec, + )?), + _ => Some(PlanWithKeyRequirements::new(requirements.plan.clone())), } } else { // Keep everything unchanged - Ok(( - TreeNodeRecursion::Continue, - vec![None; plan.children().len()], - )) + None } } else if let Some(proj) = plan_any.downcast_ref::() { let expr = proj.expr(); @@ -409,33 +396,34 @@ fn adjust_input_keys_ordering( // Construct a mapping from new name to the the orginal Column let new_required = map_columns_before_projection(&parent_required, expr); if new_required.len() == parent_required.len() { - Ok(( - 
TreeNodeRecursion::Continue, - vec![Some(new_required.clone())], - )) + Some(PlanWithKeyRequirements { + plan: requirements.plan.clone(), + required_key_ordering: vec![], + request_key_ordering: vec![Some(new_required.clone())], + }) } else { // Can not satisfy, clear the current requirements and generate new empty requirements - Ok(( - TreeNodeRecursion::Continue, - vec![None; plan.children().len()], - )) + Some(PlanWithKeyRequirements::new(requirements.plan.clone())) } } else if plan_any.downcast_ref::().is_some() || plan_any.downcast_ref::().is_some() || plan_any.downcast_ref::().is_some() { - Ok(( - TreeNodeRecursion::Continue, - vec![None; plan.children().len()], - )) + Some(PlanWithKeyRequirements::new(requirements.plan.clone())) } else { // By default, push down the parent requirements to children - let children_len = plan.children().len(); - Ok(( - TreeNodeRecursion::Continue, - vec![Some(parent_required.clone()); children_len], - )) - } + let children_len = requirements.plan.children().len(); + Some(PlanWithKeyRequirements { + plan: requirements.plan.clone(), + required_key_ordering: vec![], + request_key_ordering: vec![Some(parent_required.clone()); children_len], + }) + }; + Ok(if let Some(transformed) = transformed { + Transformed::Yes(transformed) + } else { + Transformed::No(requirements) + }) } fn reorder_partitioned_join_keys( @@ -444,7 +432,7 @@ fn reorder_partitioned_join_keys( on: &[(Column, Column)], sort_options: Vec, join_constructor: &F, -) -> Result<(Arc, Vec)> +) -> Result where F: Fn((Vec<(Column, Column)>, Vec)) -> Result>, { @@ -467,21 +455,27 @@ where new_sort_options.push(sort_options[new_positions[idx]]) } - Ok(( - join_constructor((new_join_on, new_sort_options))?, - vec![Some(left_keys), Some(right_keys)], - )) + Ok(PlanWithKeyRequirements { + plan: join_constructor((new_join_on, new_sort_options))?, + required_key_ordering: vec![], + request_key_ordering: vec![Some(left_keys), Some(right_keys)], + }) } else { - Ok((join_plan, 
vec![Some(left_keys), Some(right_keys)])) + Ok(PlanWithKeyRequirements { + plan: join_plan, + required_key_ordering: vec![], + request_key_ordering: vec![Some(left_keys), Some(right_keys)], + }) } } else { - Ok(( - join_plan, - vec![ + Ok(PlanWithKeyRequirements { + plan: join_plan, + required_key_ordering: vec![], + request_key_ordering: vec![ Some(join_key_pairs.left_keys), Some(join_key_pairs.right_keys), ], - )) + }) } } @@ -489,7 +483,7 @@ fn reorder_aggregate_keys( agg_plan: Arc, parent_required: &[Arc], agg_exec: &AggregateExec, -) -> Result<(Arc, Vec)> { +) -> Result { let output_columns = agg_exec .group_by() .expr() @@ -507,15 +501,11 @@ fn reorder_aggregate_keys( || !agg_exec.group_by().null_expr().is_empty() || physical_exprs_equal(&output_exprs, parent_required) { - let request_key_ordering = vec![None; agg_plan.children().len()]; - Ok((agg_plan, request_key_ordering)) + Ok(PlanWithKeyRequirements::new(agg_plan)) } else { let new_positions = expected_expr_positions(&output_exprs, parent_required); match new_positions { - None => { - let request_key_ordering = vec![None; agg_plan.children().len()]; - Ok((agg_plan, request_key_ordering)) - } + None => Ok(PlanWithKeyRequirements::new(agg_plan)), Some(positions) => { let new_partial_agg = if let Some(agg_exec) = agg_exec.input().as_any().downcast_ref::() @@ -587,13 +577,11 @@ fn reorder_aggregate_keys( .push((Arc::new(Column::new(name, idx)) as _, name.clone())) } // TODO merge adjacent Projections if there are - let new_plan = - Arc::new(ProjectionExec::try_new(proj_exprs, new_final_agg)?); - let request_key_ordering = vec![None; new_plan.children().len()]; - Ok((new_plan, request_key_ordering)) + Ok(PlanWithKeyRequirements::new(Arc::new( + ProjectionExec::try_new(proj_exprs, new_final_agg)?, + ))) } else { - let request_key_ordering = vec![None; agg_plan.children().len()]; - Ok((agg_plan, request_key_ordering)) + Ok(PlanWithKeyRequirements::new(agg_plan)) } } } @@ -1551,6 +1539,93 @@ struct JoinKeyPairs 
{ right_keys: Vec>, } +#[derive(Debug, Clone)] +struct PlanWithKeyRequirements { + plan: Arc, + /// Parent required key ordering + required_key_ordering: Vec>, + /// The request key ordering to children + request_key_ordering: Vec>>>, +} + +impl PlanWithKeyRequirements { + fn new(plan: Arc) -> Self { + let children_len = plan.children().len(); + PlanWithKeyRequirements { + plan, + required_key_ordering: vec![], + request_key_ordering: vec![None; children_len], + } + } + + fn children(&self) -> Vec { + let plan_children = self.plan.children(); + assert_eq!(plan_children.len(), self.request_key_ordering.len()); + plan_children + .into_iter() + .zip(self.request_key_ordering.clone()) + .map(|(child, required)| { + let from_parent = required.unwrap_or_default(); + let length = child.children().len(); + PlanWithKeyRequirements { + plan: child, + required_key_ordering: from_parent, + request_key_ordering: vec![None; length], + } + }) + .collect() + } +} + +impl TreeNode for PlanWithKeyRequirements { + fn visit_children(&self, f: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + self.children().iter().for_each_till_continue(f) + } + + fn map_children(self, transform: F) -> Result + where + F: FnMut(Self) -> Result, + { + let children = self.children(); + if !children.is_empty() { + let new_children: Result> = + children.into_iter().map(transform).collect(); + + let children_plans = new_children? 
+ .into_iter() + .map(|child| child.plan) + .collect::>(); + let new_plan = with_new_children_if_necessary(self.plan, children_plans)?; + Ok(PlanWithKeyRequirements { + plan: new_plan.into(), + required_key_ordering: self.required_key_ordering, + request_key_ordering: self.request_key_ordering, + }) + } else { + Ok(self) + } + } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + let mut children = self.children(); + if !children.is_empty() { + let tnr = children.iter_mut().for_each_till_continue(f)?; + let children_plans = children.into_iter().map(|c| c.plan).collect(); + self.plan = + with_new_children_if_necessary(self.plan.clone(), children_plans)?.into(); + Ok(tnr) + } else { + Ok(TreeNodeRecursion::Continue) + } + } +} + /// Since almost all of these tests explicitly use `ParquetExec` they only run with the parquet feature flag on #[cfg(feature = "parquet")] #[cfg(test)] diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 44d53531d090..6d44ab85237d 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -41,7 +41,7 @@ use crate::error::Result; use crate::physical_optimizer::replace_with_order_preserving_variants::{ replace_with_order_preserving_variants, OrderPreservationContext, }; -use crate::physical_optimizer::sort_pushdown::pushdown_requirement_to_children; +use crate::physical_optimizer::sort_pushdown::{pushdown_sorts, SortPushDown}; use crate::physical_optimizer::utils::{ add_sort_above, is_coalesce_partitions, is_limit, is_repartition, is_sort, is_sort_preserving_merge, is_union, is_window, ExecTree, @@ -351,7 +351,7 @@ impl PhysicalOptimizerRule for EnforceSorting { adjusted.plan }; let plan_with_pipeline_fixer = OrderPreservationContext::new(new_plan); - let mut updated_plan = + let updated_plan = 
plan_with_pipeline_fixer.transform_up_old(&|plan_with_pipeline_fixer| { replace_with_order_preserving_variants( plan_with_pipeline_fixer, @@ -363,64 +363,9 @@ impl PhysicalOptimizerRule for EnforceSorting { // Execute a top-down traversal to exploit sort push-down opportunities // missed by the bottom-up traversal: - updated_plan.plan.transform_down_with_payload( - &mut |plan, required_ordering: Option>| { - let parent_required = required_ordering.as_deref().unwrap_or(&[]); - if let Some(sort_exec) = plan.as_any().downcast_ref::() { - let new_plan = if !plan - .equivalence_properties() - .ordering_satisfy_requirement(parent_required) - { - // If the current plan is a SortExec, modify it to satisfy parent requirements: - let mut new_plan = sort_exec.input().clone(); - add_sort_above(&mut new_plan, parent_required, sort_exec.fetch()); - new_plan - } else { - plan.clone() - }; - let required_ordering = new_plan - .output_ordering() - .map(PhysicalSortRequirement::from_sort_exprs) - .unwrap_or_default(); - // Since new_plan is a SortExec, we can safely get the 0th index. - let child = new_plan.children().swap_remove(0); - if let Some(adjusted) = - pushdown_requirement_to_children(&child, &required_ordering)? - { - *plan = child; - Ok((TreeNodeRecursion::Continue, adjusted)) - } else { - *plan = new_plan; - // Can not push down requirements - Ok((TreeNodeRecursion::Continue, plan.required_input_ordering())) - } - } else { - // Executors other than SortExec - if plan - .equivalence_properties() - .ordering_satisfy_requirement(parent_required) - { - // Satisfies parent requirements, immediately return. - return Ok(( - TreeNodeRecursion::Continue, - plan.required_input_ordering(), - )); - } - // Can not satisfy the parent requirements, check whether the requirements can be pushed down: - if let Some(adjusted) = - pushdown_requirement_to_children(plan, parent_required)? 
- { - Ok((TreeNodeRecursion::Continue, adjusted)) - } else { - // Can not push down requirements, add new SortExec: - add_sort_above(plan, parent_required, None); - Ok((TreeNodeRecursion::Continue, plan.required_input_ordering())) - } - } - }, - None, - )?; - Ok(updated_plan.plan) + let sort_pushdown = SortPushDown::init(updated_plan.plan); + let adjusted = sort_pushdown.transform_down_old(&pushdown_sorts)?; + Ok(adjusted.plan) } fn name(&self) -> &str { diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index d06adb82c83e..62fc2bc77fde 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -18,15 +18,19 @@ use std::sync::Arc; use crate::physical_optimizer::utils::{ - is_limit, is_sort_preserving_merge, is_union, is_window, + add_sort_above, is_limit, is_sort_preserving_merge, is_union, is_window, }; use crate::physical_plan::filter::FilterExec; use crate::physical_plan::joins::utils::calculate_join_output_ordering; use crate::physical_plan::joins::{HashJoinExec, SortMergeJoinExec}; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; -use crate::physical_plan::ExecutionPlan; +use crate::physical_plan::sorts::sort::SortExec; +use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator, +}; use datafusion_common::{plan_err, DataFusionError, JoinSide, Result}; use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::Column; @@ -34,7 +38,162 @@ use datafusion_physical_expr::{ LexRequirementRef, PhysicalSortExpr, PhysicalSortRequirement, }; -pub fn pushdown_requirement_to_children( +use itertools::izip; + +/// This is a "data class" we use within the [`EnforceSorting`] rule to push +/// down [`SortExec`] in the plan. 
In some cases, we can reduce the total +/// computational cost by pushing down `SortExec`s through some executors. +/// +/// [`EnforceSorting`]: crate::physical_optimizer::enforce_sorting::EnforceSorting +#[derive(Debug, Clone)] +pub(crate) struct SortPushDown { + /// Current plan + pub plan: Arc, + /// Parent required sort ordering + required_ordering: Option>, + /// The adjusted request sort ordering to children. + /// By default they are the same as the plan's required input ordering, but can be adjusted based on parent required sort ordering properties. + adjusted_request_ordering: Vec>>, +} + +impl SortPushDown { + pub fn init(plan: Arc) -> Self { + let request_ordering = plan.required_input_ordering(); + SortPushDown { + plan, + required_ordering: None, + adjusted_request_ordering: request_ordering, + } + } + + pub fn children(&self) -> Vec { + izip!( + self.plan.children().into_iter(), + self.adjusted_request_ordering.clone().into_iter(), + ) + .map(|(child, from_parent)| { + let child_request_ordering = child.required_input_ordering(); + SortPushDown { + plan: child, + required_ordering: from_parent, + adjusted_request_ordering: child_request_ordering, + } + }) + .collect() + } +} + +impl TreeNode for SortPushDown { + fn visit_children(&self, f: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + self.children().iter().for_each_till_continue(f) + } + + fn map_children(mut self, transform: F) -> Result + where + F: FnMut(Self) -> Result, + { + let children = self.children(); + if !children.is_empty() { + let children_plans = children + .into_iter() + .map(transform) + .map(|r| r.map(|s| s.plan)) + .collect::>>()?; + + match with_new_children_if_necessary(self.plan, children_plans)? 
{ + Transformed::Yes(plan) | Transformed::No(plan) => { + self.plan = plan; + } + } + }; + Ok(self) + } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + let mut children = self.children(); + if !children.is_empty() { + let tnr = children.iter_mut().for_each_till_continue(f)?; + let children_plans = children.into_iter().map(|c| c.plan).collect(); + self.plan = + with_new_children_if_necessary(self.plan.clone(), children_plans)?.into(); + Ok(tnr) + } else { + Ok(TreeNodeRecursion::Continue) + } + } +} + +pub(crate) fn pushdown_sorts( + requirements: SortPushDown, +) -> Result> { + let plan = &requirements.plan; + let parent_required = requirements.required_ordering.as_deref().unwrap_or(&[]); + if let Some(sort_exec) = plan.as_any().downcast_ref::() { + let new_plan = if !plan + .equivalence_properties() + .ordering_satisfy_requirement(parent_required) + { + // If the current plan is a SortExec, modify it to satisfy parent requirements: + let mut new_plan = sort_exec.input().clone(); + add_sort_above(&mut new_plan, parent_required, sort_exec.fetch()); + new_plan + } else { + requirements.plan + }; + let required_ordering = new_plan + .output_ordering() + .map(PhysicalSortRequirement::from_sort_exprs) + .unwrap_or_default(); + // Since new_plan is a SortExec, we can safely get the 0th index. + let child = new_plan.children().swap_remove(0); + if let Some(adjusted) = + pushdown_requirement_to_children(&child, &required_ordering)? + { + // Can push down requirements + Ok(Transformed::Yes(SortPushDown { + plan: child, + required_ordering: None, + adjusted_request_ordering: adjusted, + })) + } else { + // Can not push down requirements + Ok(Transformed::Yes(SortPushDown::init(new_plan))) + } + } else { + // Executors other than SortExec + if plan + .equivalence_properties() + .ordering_satisfy_requirement(parent_required) + { + // Satisfies parent requirements, immediately return. 
+ return Ok(Transformed::Yes(SortPushDown { + required_ordering: None, + ..requirements + })); + } + // Can not satisfy the parent requirements, check whether the requirements can be pushed down: + if let Some(adjusted) = pushdown_requirement_to_children(plan, parent_required)? { + Ok(Transformed::Yes(SortPushDown { + plan: requirements.plan, + required_ordering: None, + adjusted_request_ordering: adjusted, + })) + } else { + // Can not push down requirements, add new SortExec: + let mut new_plan = requirements.plan; + add_sort_above(&mut new_plan, parent_required, None); + Ok(Transformed::Yes(SortPushDown::init(new_plan))) + } + } +} + +fn pushdown_requirement_to_children( plan: &Arc, parent_required: LexRequirementRef, ) -> Result>>>> {