diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 482fc96b519b..ce26cac7970b 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -17,12 +17,19 @@ //! Tree node implementation for logical plan -use crate::LogicalPlan; +use crate::{ + Aggregate, Analyze, CreateMemoryTable, CreateView, CrossJoin, DdlStatement, Distinct, + DistinctOn, DmlStatement, Explain, Extension, Filter, Join, Limit, LogicalPlan, + Prepare, Projection, RecursiveQuery, Repartition, Sort, Subquery, SubqueryAlias, + Union, Unnest, Window, +}; +use std::sync::Arc; +use crate::dml::CopyTo; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, }; -use datafusion_common::Result; +use datafusion_common::{map_until_stop_and_collect, Result}; impl TreeNode for LogicalPlan { fn apply_children Result>( @@ -32,23 +39,359 @@ impl TreeNode for LogicalPlan { self.inputs().into_iter().apply_until_stop(f) } - fn map_children Result>>( - self, - f: F, - ) -> Result> { - let new_children = self - .inputs() - .into_iter() - .cloned() - .map_until_stop_and_collect(f)?; - // Propagate up `new_children.transformed` and `new_children.tnr` - // along with the node containing transformed children. - if new_children.transformed { - new_children.map_data(|new_children| { - self.with_new_exprs(self.expressions(), new_children) - }) - } else { - Ok(new_children.update_data(|_| self)) - } + /// Applies `f` to each child (input) of this plan node, rewriting them *in place.* + /// + /// # Notes + /// + /// Inputs include ONLY direct children, not embedded `LogicalPlan`s for + /// subqueries, for example such as are in [`Expr::Exists`]. + /// + /// [`Expr::Exists`]: crate::Expr::Exists + fn map_children(self, mut f: F) -> Result> + where + F: FnMut(Self) -> Result>, + { + Ok(match self { + LogicalPlan::Projection(Projection { + expr, + input, + schema, + }) => rewrite_arc(input, f)?.update_data(|input| { + LogicalPlan::Projection(Projection { + expr, + input, + schema, + }) + }), + LogicalPlan::Filter(Filter { predicate, input }) => rewrite_arc(input, f)? + .update_data(|input| LogicalPlan::Filter(Filter { predicate, input })), + LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + }) => rewrite_arc(input, f)?.update_data(|input| { + LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + }) + }), + LogicalPlan::Window(Window { + input, + window_expr, + schema, + }) => rewrite_arc(input, f)?.update_data(|input| { + LogicalPlan::Window(Window { + input, + window_expr, + schema, + }) + }), + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema, + }) => rewrite_arc(input, f)?.update_data(|input| { + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema, + }) + }), + LogicalPlan::Sort(Sort { expr, input, fetch }) => rewrite_arc(input, f)? + .update_data(|input| LogicalPlan::Sort(Sort { expr, input, fetch })), + LogicalPlan::Join(Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema, + null_equals_null, + }) => map_until_stop_and_collect!( + rewrite_arc(left, &mut f), + right, + rewrite_arc(right, &mut f) + )? + .update_data(|(left, right)| { + LogicalPlan::Join(Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema, + null_equals_null, + }) + }), + LogicalPlan::CrossJoin(CrossJoin { + left, + right, + schema, + }) => map_until_stop_and_collect!( + rewrite_arc(left, &mut f), + right, + rewrite_arc(right, &mut f) + )? + .update_data(|(left, right)| { + LogicalPlan::CrossJoin(CrossJoin { + left, + right, + schema, + }) + }), + LogicalPlan::Limit(Limit { skip, fetch, input }) => rewrite_arc(input, f)? + .update_data(|input| LogicalPlan::Limit(Limit { skip, fetch, input })), + LogicalPlan::Subquery(Subquery { + subquery, + outer_ref_columns, + }) => rewrite_arc(subquery, f)?.update_data(|subquery| { + LogicalPlan::Subquery(Subquery { + subquery, + outer_ref_columns, + }) + }), + LogicalPlan::SubqueryAlias(SubqueryAlias { + input, + alias, + schema, + }) => rewrite_arc(input, f)?.update_data(|input| { + LogicalPlan::SubqueryAlias(SubqueryAlias { + input, + alias, + schema, + }) + }), + LogicalPlan::Extension(extension) => rewrite_extension_inputs(extension, f)? + .update_data(LogicalPlan::Extension), + LogicalPlan::Union(Union { inputs, schema }) => rewrite_arcs(inputs, f)? + .update_data(|inputs| LogicalPlan::Union(Union { inputs, schema })), + LogicalPlan::Distinct(distinct) => match distinct { + Distinct::All(input) => rewrite_arc(input, f)?.update_data(Distinct::All), + Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + schema, + }) => rewrite_arc(input, f)?.update_data(|input| { + Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + schema, + }) + }), + } + .update_data(LogicalPlan::Distinct), + LogicalPlan::Explain(Explain { + verbose, + plan, + stringified_plans, + schema, + logical_optimization_succeeded, + }) => rewrite_arc(plan, f)?.update_data(|plan| { + LogicalPlan::Explain(Explain { + verbose, + plan, + stringified_plans, + schema, + logical_optimization_succeeded, + }) + }), + LogicalPlan::Analyze(Analyze { + verbose, + input, + schema, + }) => rewrite_arc(input, f)?.update_data(|input| { + LogicalPlan::Analyze(Analyze { + verbose, + input, + schema, + }) + }), + LogicalPlan::Dml(DmlStatement { + table_name, + table_schema, + op, + input, + }) => rewrite_arc(input, f)?.update_data(|input| { + LogicalPlan::Dml(DmlStatement { + table_name, + table_schema, + op, + input, + }) + }), + LogicalPlan::Copy(CopyTo { + input, + output_url, + partition_by, + format_options, + options, + }) => rewrite_arc(input, f)?.update_data(|input| { + LogicalPlan::Copy(CopyTo { + input, + output_url, + partition_by, + format_options, + options, + }) + }), + LogicalPlan::Ddl(ddl) => { + match ddl { + DdlStatement::CreateMemoryTable(CreateMemoryTable { + name, + constraints, + input, + if_not_exists, + or_replace, + column_defaults, + }) => rewrite_arc(input, f)?.update_data(|input| { + DdlStatement::CreateMemoryTable(CreateMemoryTable { + name, + constraints, + input, + if_not_exists, + or_replace, + column_defaults, + }) + }), + DdlStatement::CreateView(CreateView { + name, + input, + or_replace, + definition, + }) => rewrite_arc(input, f)?.update_data(|input| { + DdlStatement::CreateView(CreateView { + name, + input, + or_replace, + definition, + }) + }), + // no inputs in these statements + DdlStatement::CreateExternalTable(_) + | DdlStatement::CreateCatalogSchema(_) + | DdlStatement::CreateCatalog(_) + | DdlStatement::DropTable(_) + | DdlStatement::DropView(_) + | DdlStatement::DropCatalogSchema(_) + | DdlStatement::CreateFunction(_) + | DdlStatement::DropFunction(_) => Transformed::no(ddl), + } + .update_data(LogicalPlan::Ddl) + } + LogicalPlan::Unnest(Unnest { + input, + column, + schema, + options, + }) => rewrite_arc(input, f)?.update_data(|input| { + LogicalPlan::Unnest(Unnest { + input, + column, + schema, + options, + }) + }), + LogicalPlan::Prepare(Prepare { + name, + data_types, + input, + }) => rewrite_arc(input, f)?.update_data(|input| { + LogicalPlan::Prepare(Prepare { + name, + data_types, + input, + }) + }), + LogicalPlan::RecursiveQuery(RecursiveQuery { + name, + static_term, + recursive_term, + is_distinct, + }) => map_until_stop_and_collect!( + rewrite_arc(static_term, &mut f), + recursive_term, + rewrite_arc(recursive_term, &mut f) + )? + .update_data(|(static_term, recursive_term)| { + LogicalPlan::RecursiveQuery(RecursiveQuery { + name, + static_term, + recursive_term, + is_distinct, + }) + }), + // plans without inputs + LogicalPlan::TableScan { .. } + | LogicalPlan::Statement { .. } + | LogicalPlan::EmptyRelation { .. } + | LogicalPlan::Values { .. } + | LogicalPlan::DescribeTable(_) => Transformed::no(self), + }) } } + +/// Converts a `Arc` without copying, if possible. Copies the plan +/// if there is a shared reference +fn unwrap_arc(plan: Arc) -> LogicalPlan { + Arc::try_unwrap(plan) + // if None is returned, there is another reference to this + // LogicalPlan, so we can not own it, and must clone instead + .unwrap_or_else(|node| node.as_ref().clone()) +} + +/// Applies `f` to rewrite a `Arc` without copying, if possible +fn rewrite_arc( + plan: Arc, + mut f: F, +) -> Result>> +where + F: FnMut(LogicalPlan) -> Result>, +{ + f(unwrap_arc(plan))?.map_data(|new_plan| Ok(Arc::new(new_plan))) +} + +/// rewrite a `Vec` of `Arc` without copying, if possible +fn rewrite_arcs( + input_plans: Vec>, + mut f: F, +) -> Result>>> +where + F: FnMut(LogicalPlan) -> Result>, +{ + input_plans + .into_iter() + .map_until_stop_and_collect(|plan| rewrite_arc(plan, &mut f)) +} + +/// Rewrites all inputs for an Extension node "in place" +/// (it currently has to copy values because there are no APIs for in place modification) +/// +/// Should be removed when we have an API for in place modifications of the +/// extension to avoid these copies +fn rewrite_extension_inputs( + extension: Extension, + f: F, +) -> Result> +where + F: FnMut(LogicalPlan) -> Result>, +{ + let Extension { node } = extension; + + node.inputs() + .into_iter() + .cloned() + .map_until_stop_and_collect(f)? + .map_data(|new_inputs| { + let exprs = node.expressions(); + Ok(Extension { + node: node.from_template(&exprs, &new_inputs), + }) + }) +}