Skip to content

Commit df512fa

Browse files
committed
Implement TreeNode::map_children in place
1 parent 4bd7c13 commit df512fa

File tree

3 files changed

+179
-17
lines changed

3 files changed

+179
-17
lines changed

datafusion/expr/src/logical_plan/ddl.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,24 @@ impl DdlStatement {
112112
}
113113
}
114114

115+
/// Return a mutable reference to the input `LogicalPlan`, if any
116+
pub fn input_mut(&mut self) -> Option<&mut Arc<LogicalPlan>> {
117+
match self {
118+
DdlStatement::CreateMemoryTable(CreateMemoryTable { input, .. }) => {
119+
Some(input)
120+
}
121+
DdlStatement::CreateExternalTable(_) => None,
122+
DdlStatement::CreateView(CreateView { input, .. }) => Some(input),
123+
DdlStatement::CreateCatalogSchema(_) => None,
124+
DdlStatement::CreateCatalog(_) => None,
125+
DdlStatement::DropTable(_) => None,
126+
DdlStatement::DropView(_) => None,
127+
DdlStatement::DropCatalogSchema(_) => None,
128+
DdlStatement::CreateFunction(_) => None,
129+
DdlStatement::DropFunction(_) => None,
130+
}
131+
}
132+
115133
/// Return a `format`able structure with a human readable
116134
/// description of this LogicalPlan node per node, not including
117135
/// children.

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 154 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
use std::collections::{HashMap, HashSet};
2121
use std::fmt::{self, Debug, Display, Formatter};
2222
use std::hash::{Hash, Hasher};
23-
use std::sync::Arc;
23+
use std::sync::{Arc, OnceLock};
2424

2525
use super::dml::CopyTo;
2626
use super::DdlStatement;
@@ -1131,6 +1131,159 @@ impl LogicalPlan {
11311131
})?;
11321132
Ok(())
11331133
}
1134+
}
1135+
1136+
// TODO put this somewhere better than here
1137+
1138+
/// A temporary node that is left in place while rewriting the children of a
1139+
/// [`LogicalPlan`]. This is necessary to ensure that the `LogicalPlan` is
1140+
/// always in a valid state (from the Rust perspective)
1141+
static PLACEHOLDER: OnceLock<Arc<LogicalPlan>> = OnceLock::new();
1142+
1143+
/// Ideally a `LogicalPlan` would own its inputs, so this code would not be needed. However, for now we try and
1144+
/// unwrap the `Arc` which avoids `clone`ing in most cases.
1145+
///
1146+
/// On error, `node` will be left holding a placeholder logical plan
1147+
fn rewrite_arc<F>(node: &mut Arc<LogicalPlan>, mut f: F) -> Result<Transformed<()>>
1148+
where
1149+
F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
1150+
{
1151+
// We need to leave a valid node in the Arc, while we rewrite the existing
1152+
// one, so use a single global static placeholder node
1153+
let mut new_node = PLACEHOLDER
1154+
.get_or_init(|| {
1155+
Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1156+
produce_one_row: false,
1157+
schema: DFSchemaRef::new(DFSchema::empty()),
1158+
}))
1159+
})
1160+
.clone();
1161+
1162+
// take the old value out of the Arc
1163+
std::mem::swap(node, &mut new_node);
1164+
1165+
// try to update existing node, if it isn't shared with others
1166+
let mut new_node = Arc::try_unwrap(new_node)
1167+
// if `Err` is returned, there is another reference to this
1168+
// LogicalPlan, so we must clone instead
1169+
.unwrap_or_else(|node| node.as_ref().clone());
1170+
1171+
// apply the actual transform
1172+
let result = f(new_node)?;
1173+
1174+
// put the new value back into the Arc
1175+
let mut new_node = Arc::new(result.data);
1176+
std::mem::swap(node, &mut new_node);
1177+
1178+
// return the `()` back
1179+
Ok(Transformed::new((), result.transformed, result.tnr))
1180+
}
1181+
1182+
/*
1183+
/// Rewrites all inputs for an Extension node "in place"
1184+
/// (it currently has to copy values because there are no APIs for in place modification)
1185+
///
1186+
/// Should be removed when we have an API for in place modifications of the
1187+
/// extension to avoid these copies
1188+
fn rewrite_extension_inputs<F>(
1189+
node: &mut Arc<dyn UserDefinedLogicalNode>,
1190+
mut f: F,
1191+
) -> Result<Transformed<()>>
1192+
where
1193+
F: FnMut(&mut LogicalPlan) -> Result<Transformed<()>>,
1194+
{
1195+
let mut inputs: Vec<_> = node.inputs().into_iter().cloned().collect();
1196+
1197+
let result = inputs
1198+
.iter_mut()
1199+
.try_fold(Transformed::no(()), |acc, input| acc.and_then(|| f(input)))?;
1200+
let exprs = node.expressions();
1201+
let mut new_node = node.from_template(&exprs, &inputs);
1202+
std::mem::swap(node, &mut new_node);
1203+
Ok(result)
1204+
}
1205+
*/
1206+
1207+
impl LogicalPlan {
1208+
/// applies `f` to each input of this plan node, rewriting them in place.
1209+
///
1210+
/// # Notes
1211+
/// Inputs include both direct children as well as any embedded subquery
1212+
/// `LogicalPlan`s, for example such as are in [`Expr::Exists`].
1213+
///
1214+
/// If `f` returns an `Err`, that Err is returned, and the inputs are left
1215+
/// in a partially modified state
1216+
pub(crate) fn rewrite_children<F>(&mut self, mut f: F) -> Result<Transformed<()>>
1217+
where
1218+
F: FnMut(Self) -> Result<Transformed<Self>>,
1219+
{
1220+
let children_result = match self {
1221+
LogicalPlan::Projection(Projection { input, .. }) => {
1222+
rewrite_arc(input, &mut f)
1223+
}
1224+
LogicalPlan::Filter(Filter { input, .. }) => rewrite_arc(input, &mut f),
1225+
LogicalPlan::Repartition(Repartition { input, .. }) => {
1226+
rewrite_arc(input, &mut f)
1227+
}
1228+
LogicalPlan::Window(Window { input, .. }) => rewrite_arc(input, &mut f),
1229+
LogicalPlan::Aggregate(Aggregate { input, .. }) => rewrite_arc(input, &mut f),
1230+
LogicalPlan::Sort(Sort { input, .. }) => rewrite_arc(input, &mut f),
1231+
LogicalPlan::Join(Join { left, right, .. }) => {
1232+
rewrite_arc(left, &mut f)?.and_then(|| rewrite_arc(right, &mut f))
1233+
}
1234+
LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => {
1235+
rewrite_arc(left, &mut f)?.and_then(|| rewrite_arc(right, &mut f))
1236+
}
1237+
LogicalPlan::Limit(Limit { input, .. }) => rewrite_arc(input, &mut f),
1238+
LogicalPlan::Subquery(Subquery { subquery, .. }) => {
1239+
rewrite_arc(subquery, &mut f)
1240+
}
1241+
LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => {
1242+
rewrite_arc(input, &mut f)
1243+
}
1244+
LogicalPlan::Extension(extension) => {
1245+
todo!();
1246+
//rewrite_extension_inputs(&mut extension.node, &mut f)
1247+
}
1248+
LogicalPlan::Union(Union { inputs, .. }) => inputs
1249+
.iter_mut()
1250+
.try_fold(Transformed::no(()), |acc, input| {
1251+
acc.and_then(|| rewrite_arc(input, &mut f))
1252+
}),
1253+
LogicalPlan::Distinct(
1254+
Distinct::All(input) | Distinct::On(DistinctOn { input, .. }),
1255+
) => rewrite_arc(input, &mut f),
1256+
LogicalPlan::Explain(explain) => rewrite_arc(&mut explain.plan, &mut f),
1257+
LogicalPlan::Analyze(analyze) => rewrite_arc(&mut analyze.input, &mut f),
1258+
LogicalPlan::Dml(write) => rewrite_arc(&mut write.input, &mut f),
1259+
LogicalPlan::Copy(copy) => rewrite_arc(&mut copy.input, &mut f),
1260+
LogicalPlan::Ddl(ddl) => {
1261+
if let Some(input) = ddl.input_mut() {
1262+
rewrite_arc(input, &mut f)
1263+
} else {
1264+
Ok(Transformed::no(()))
1265+
}
1266+
}
1267+
LogicalPlan::Unnest(Unnest { input, .. }) => rewrite_arc(input, &mut f),
1268+
LogicalPlan::Prepare(Prepare { input, .. }) => rewrite_arc(input, &mut f),
1269+
LogicalPlan::RecursiveQuery(RecursiveQuery {
1270+
static_term,
1271+
recursive_term,
1272+
..
1273+
}) => rewrite_arc(static_term, &mut f)?
1274+
.and_then(|| rewrite_arc(recursive_term, &mut f)),
1275+
// plans without inputs
1276+
LogicalPlan::TableScan { .. }
1277+
| LogicalPlan::Statement { .. }
1278+
| LogicalPlan::EmptyRelation { .. }
1279+
| LogicalPlan::Values { .. }
1280+
| LogicalPlan::DescribeTable(_) => Ok(Transformed::no(())),
1281+
}?;
1282+
1283+
// after visiting the actual children we need to visit any subqueries
1284+
// that are inside the expressions
1285+
children_result.and_then(|| self.rewrite_subqueries(&mut f))
1286+
}
11341287

11351288
/// Return a `LogicalPlan` with all placeholders (e.g $1 $2,
11361289
/// ...) replaced with corresponding values provided in

datafusion/expr/src/tree_node/plan.rs

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
use crate::LogicalPlan;
2121

2222
use datafusion_common::tree_node::{
23-
Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
23+
Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
2424
};
2525
use datafusion_common::{handle_visit_recursion, Result};
2626

@@ -91,23 +91,14 @@ impl TreeNode for LogicalPlan {
9191
Ok(tnr)
9292
}
9393

94-
fn map_children<F>(self, f: F) -> Result<Transformed<Self>>
94+
fn map_children<F>(mut self, f: F) -> Result<Transformed<Self>>
9595
where
9696
F: FnMut(Self) -> Result<Transformed<Self>>,
9797
{
98-
let new_children = self
99-
.inputs()
100-
.iter()
101-
.map(|&c| c.clone())
102-
.map_until_stop_and_collect(f)?;
103-
// Propagate up `new_children.transformed` and `new_children.tnr`
104-
// along with the node containing transformed children.
105-
if new_children.transformed {
106-
new_children.map_data(|new_children| {
107-
self.with_new_exprs(self.expressions(), new_children)
108-
})
109-
} else {
110-
Ok(new_children.update_data(|_| self))
111-
}
98+
// Apply the rewrites in place for each child
99+
let result = self.rewrite_children(f)?;
100+
101+
// return ourself (by value), now holding the rewritten children
102+
Ok(result.update_data(|_| self))
112103
}
113104
}

0 commit comments

Comments
 (0)