2020use std:: collections:: { HashMap , HashSet } ;
2121use std:: fmt:: { self , Debug , Display , Formatter } ;
2222use std:: hash:: { Hash , Hasher } ;
23- use std:: sync:: Arc ;
23+ use std:: sync:: { Arc , OnceLock } ;
2424
2525use super :: dml:: CopyTo ;
2626use super :: DdlStatement ;
@@ -45,7 +45,8 @@ use crate::{
4545
4646use arrow:: datatypes:: { DataType , Field , Schema , SchemaRef } ;
4747use datafusion_common:: tree_node:: {
48- Transformed , TransformedResult , TreeNode , TreeNodeRecursion , TreeNodeVisitor ,
48+ Transformed , TransformedIterator , TransformedResult , TreeNode , TreeNodeRecursion ,
49+ TreeNodeVisitor ,
4950} ;
5051use datafusion_common:: {
5152 aggregate_functional_dependencies, internal_err, plan_err, Column , Constraints ,
@@ -1131,6 +1132,202 @@ impl LogicalPlan {
11311132 } ) ?;
11321133 Ok ( ( ) )
11331134 }
1135+ }
1136+
// TODO: move the placeholder and the rewrite helpers below somewhere better than here
1138+
1139+ /// A temporary node that is left in place while rewriting the children of a
1140+ /// [`LogicalPlan`]. This is necessary to ensure that the `LogicalPlan` is
1141+ /// always in a valid state (from the Rust perspective)
1142+ static PLACEHOLDER : OnceLock < Arc < LogicalPlan > > = OnceLock :: new ( ) ;
1143+
1144+ /// its inputs, so this code would not be needed. However, for now we try and
1145+ /// unwrap the `Arc` which avoids `clone`ing in most cases.
1146+ ///
1147+ /// On error, node be left with a placeholder logical plan
1148+ fn rewrite_arc < F > (
1149+ node : & mut Arc < LogicalPlan > ,
1150+ mut f : F ,
1151+ ) -> Result < Transformed < & mut Arc < LogicalPlan > > >
1152+ where
1153+ F : FnMut ( LogicalPlan ) -> Result < Transformed < LogicalPlan > > ,
1154+ {
1155+ // We need to leave a valid node in the Arc, while we rewrite the existing
1156+ // one, so use a single global static placeholder node
1157+ let mut new_node = PLACEHOLDER
1158+ . get_or_init ( || {
1159+ Arc :: new ( LogicalPlan :: EmptyRelation ( EmptyRelation {
1160+ produce_one_row : false ,
1161+ schema : DFSchemaRef :: new ( DFSchema :: empty ( ) ) ,
1162+ } ) )
1163+ } )
1164+ . clone ( ) ;
1165+
1166+ // take the old value out of the Arc
1167+ std:: mem:: swap ( node, & mut new_node) ;
1168+
1169+ // try to update existing node, if it isn't shared with others
1170+ let new_node = Arc :: try_unwrap ( new_node)
1171+ // if None is returned, there is another reference to this
1172+ // LogicalPlan, so we must clone instead
1173+ . unwrap_or_else ( |node| node. as_ref ( ) . clone ( ) ) ;
1174+
1175+ // apply the actual transform
1176+ let result = f ( new_node) ?;
1177+
1178+ // put the new value back into the Arc
1179+ let mut new_node = Arc :: new ( result. data ) ;
1180+ std:: mem:: swap ( node, & mut new_node) ;
1181+
1182+ // return the `node` back
1183+ Ok ( Transformed :: new ( node, result. transformed , result. tnr ) )
1184+ }
1185+
1186+ /// Rewrite the arc and discard the contents of Transformed
1187+ fn rewrite_arc_no_data < F > ( node : & mut Arc < LogicalPlan > , f : F ) -> Result < Transformed < ( ) > >
1188+ where
1189+ F : FnMut ( LogicalPlan ) -> Result < Transformed < LogicalPlan > > ,
1190+ {
1191+ rewrite_arc ( node, f) . map ( |res| res. discard_data ( ) )
1192+ }
1193+
1194+ /// Rewrites all inputs for an Extension node "in place"
1195+ /// (it currently has to copy values because there are no APIs for in place modification)
1196+ ///
1197+ /// Should be removed when we have an API for in place modifications of the
1198+ /// extension to avoid these copies
1199+ fn rewrite_extension_inputs < F > (
1200+ node : & mut Arc < dyn UserDefinedLogicalNode > ,
1201+ f : F ,
1202+ ) -> Result < Transformed < ( ) > >
1203+ where
1204+ F : FnMut ( LogicalPlan ) -> Result < Transformed < LogicalPlan > > ,
1205+ {
1206+ let Transformed {
1207+ data : new_inputs,
1208+ transformed,
1209+ tnr,
1210+ } = node
1211+ . inputs ( )
1212+ . into_iter ( )
1213+ . cloned ( )
1214+ . map_until_stop_and_collect ( f) ?;
1215+
1216+ let exprs = node. expressions ( ) ;
1217+ let mut new_node = node. from_template ( & exprs, & new_inputs) ;
1218+ std:: mem:: swap ( node, & mut new_node) ;
1219+ Ok ( Transformed {
1220+ data : ( ) ,
1221+ transformed,
1222+ tnr,
1223+ } )
1224+ }
1225+
1226+ impl LogicalPlan {
1227+ /// applies `f` to each input of this plan node, rewriting them *in place.*
1228+ ///
1229+ /// # Notes
1230+ /// Inputs include both direct children as well as any embedded subquery
1231+ /// `LogicalPlan`s, for example such as are in [`Expr::Exists`].
1232+ ///
1233+ /// If `f` returns an `Err`, that Err is returned, and the inputs are left
1234+ /// in a partially modified state
1235+ pub ( crate ) fn rewrite_children < F > ( & mut self , mut f : F ) -> Result < Transformed < ( ) > >
1236+ where
1237+ F : FnMut ( Self ) -> Result < Transformed < Self > > ,
1238+ {
1239+ let children_result = match self {
1240+ LogicalPlan :: Projection ( Projection { input, .. } ) => {
1241+ rewrite_arc_no_data ( input, & mut f)
1242+ }
1243+ LogicalPlan :: Filter ( Filter { input, .. } ) => {
1244+ rewrite_arc_no_data ( input, & mut f)
1245+ }
1246+ LogicalPlan :: Repartition ( Repartition { input, .. } ) => {
1247+ rewrite_arc_no_data ( input, & mut f)
1248+ }
1249+ LogicalPlan :: Window ( Window { input, .. } ) => {
1250+ rewrite_arc_no_data ( input, & mut f)
1251+ }
1252+ LogicalPlan :: Aggregate ( Aggregate { input, .. } ) => {
1253+ rewrite_arc_no_data ( input, & mut f)
1254+ }
1255+ LogicalPlan :: Sort ( Sort { input, .. } ) => rewrite_arc_no_data ( input, & mut f) ,
1256+ LogicalPlan :: Join ( Join { left, right, .. } ) => {
1257+ let results = [ left, right]
1258+ . into_iter ( )
1259+ . map_until_stop_and_collect ( |input| rewrite_arc ( input, & mut f) ) ?;
1260+ Ok ( results. discard_data ( ) )
1261+ }
1262+ LogicalPlan :: CrossJoin ( CrossJoin { left, right, .. } ) => {
1263+ let results = [ left, right]
1264+ . into_iter ( )
1265+ . map_until_stop_and_collect ( |input| rewrite_arc ( input, & mut f) ) ?;
1266+ Ok ( results. discard_data ( ) )
1267+ }
1268+ LogicalPlan :: Limit ( Limit { input, .. } ) => rewrite_arc_no_data ( input, & mut f) ,
1269+ LogicalPlan :: Subquery ( Subquery { subquery, .. } ) => {
1270+ rewrite_arc_no_data ( subquery, & mut f)
1271+ }
1272+ LogicalPlan :: SubqueryAlias ( SubqueryAlias { input, .. } ) => {
1273+ rewrite_arc_no_data ( input, & mut f)
1274+ }
1275+ LogicalPlan :: Extension ( extension) => {
1276+ rewrite_extension_inputs ( & mut extension. node , & mut f)
1277+ }
1278+ LogicalPlan :: Union ( Union { inputs, .. } ) => {
1279+ let results = inputs
1280+ . iter_mut ( )
1281+ . map_until_stop_and_collect ( |input| rewrite_arc ( input, & mut f) ) ?;
1282+ Ok ( results. discard_data ( ) )
1283+ }
1284+ LogicalPlan :: Distinct (
1285+ Distinct :: All ( input) | Distinct :: On ( DistinctOn { input, .. } ) ,
1286+ ) => rewrite_arc_no_data ( input, & mut f) ,
1287+ LogicalPlan :: Explain ( explain) => {
1288+ rewrite_arc_no_data ( & mut explain. plan , & mut f)
1289+ }
1290+ LogicalPlan :: Analyze ( analyze) => {
1291+ rewrite_arc_no_data ( & mut analyze. input , & mut f)
1292+ }
1293+ LogicalPlan :: Dml ( write) => rewrite_arc_no_data ( & mut write. input , & mut f) ,
1294+ LogicalPlan :: Copy ( copy) => rewrite_arc_no_data ( & mut copy. input , & mut f) ,
1295+ LogicalPlan :: Ddl ( ddl) => {
1296+ if let Some ( input) = ddl. input_mut ( ) {
1297+ rewrite_arc_no_data ( input, & mut f)
1298+ } else {
1299+ Ok ( Transformed :: no ( ( ) ) )
1300+ }
1301+ }
1302+ LogicalPlan :: Unnest ( Unnest { input, .. } ) => {
1303+ rewrite_arc_no_data ( input, & mut f)
1304+ }
1305+ LogicalPlan :: Prepare ( Prepare { input, .. } ) => {
1306+ rewrite_arc_no_data ( input, & mut f)
1307+ }
1308+ LogicalPlan :: RecursiveQuery ( RecursiveQuery {
1309+ static_term,
1310+ recursive_term,
1311+ ..
1312+ } ) => {
1313+ let results = [ static_term, recursive_term]
1314+ . into_iter ( )
1315+ . map_until_stop_and_collect ( |input| rewrite_arc ( input, & mut f) ) ?;
1316+ Ok ( results. discard_data ( ) )
1317+ }
1318+ // plans without inputs
1319+ LogicalPlan :: TableScan { .. }
1320+ | LogicalPlan :: Statement { .. }
1321+ | LogicalPlan :: EmptyRelation { .. }
1322+ | LogicalPlan :: Values { .. }
1323+ | LogicalPlan :: DescribeTable ( _) => Ok ( Transformed :: no ( ( ) ) ) ,
1324+ } ?;
1325+
1326+ // after visiting the actual children we we need to visit any subqueries
1327+ // that are inside the expressions
1328+ // children_result.and_then(|| self.rewrite_subqueries(&mut f))
1329+ Ok ( children_result)
1330+ }
11341331
11351332 /// Return a `LogicalPlan` with all placeholders (e.g $1 $2,
11361333 /// ...) replaced with corresponding values provided in
0 commit comments