|
20 | 20 | use std::collections::{HashMap, HashSet}; |
21 | 21 | use std::fmt::{self, Debug, Display, Formatter}; |
22 | 22 | use std::hash::{Hash, Hasher}; |
23 | | -use std::sync::Arc; |
| 23 | +use std::sync::{Arc, OnceLock}; |
24 | 24 |
|
25 | 25 | use super::dml::CopyTo; |
26 | 26 | use super::DdlStatement; |
@@ -1131,6 +1131,159 @@ impl LogicalPlan { |
1131 | 1131 | })?; |
1132 | 1132 | Ok(()) |
1133 | 1133 | } |
| 1134 | +} |
| 1135 | + |
| 1136 | +// TODO put this somewhere better than here |
| 1137 | + |
| 1138 | +/// A temporary node that is left in place while rewriting the children of a |
| 1139 | +/// [`LogicalPlan`]. This is necessary to ensure that the `LogicalPlan` is |
| 1140 | +/// always in a valid state (from the Rust perspective) |
| 1141 | +static PLACEHOLDER: OnceLock<Arc<LogicalPlan>> = OnceLock::new(); |
| 1142 | + |
| 1143 | +/// its inputs, so this code would not be needed. However, for now we try and |
| 1144 | +/// unwrap the `Arc` which avoids `clone`ing in most cases. |
| 1145 | +/// |
| 1146 | +/// On error, node be left with a placeholder logical plan |
| 1147 | +fn rewrite_arc<F>(node: &mut Arc<LogicalPlan>, mut f: F) -> Result<Transformed<()>> |
| 1148 | +where |
| 1149 | + F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>, |
| 1150 | +{ |
| 1151 | + // We need to leave a valid node in the Arc, while we rewrite the existing |
| 1152 | + // one, so use a single global static placeholder node |
| 1153 | + let mut new_node = PLACEHOLDER |
| 1154 | + .get_or_init(|| { |
| 1155 | + Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { |
| 1156 | + produce_one_row: false, |
| 1157 | + schema: DFSchemaRef::new(DFSchema::empty()), |
| 1158 | + })) |
| 1159 | + }) |
| 1160 | + .clone(); |
| 1161 | + |
| 1162 | + // take the old value out of the Arc |
| 1163 | + std::mem::swap(node, &mut new_node); |
| 1164 | + |
| 1165 | + // try to update existing node, if it isn't shared with others |
| 1166 | + let mut new_node = Arc::try_unwrap(new_node) |
| 1167 | + // if None is returned, there is another reference to this |
| 1168 | + // LogicalPlan, so we must clone instead |
| 1169 | + .unwrap_or_else(|node| node.as_ref().clone()); |
| 1170 | + |
| 1171 | + // apply the actual transform |
| 1172 | + let result = f(new_node)?; |
| 1173 | + |
| 1174 | + // put the new value back into the Arc |
| 1175 | + let mut new_node = Arc::new(result.data); |
| 1176 | + std::mem::swap(node, &mut new_node); |
| 1177 | + |
| 1178 | + // return the `()` back |
| 1179 | + Ok(Transformed::new((), result.transformed, result.tnr)) |
| 1180 | +} |
| 1181 | + |
| 1182 | +/* |
 | 1183 | +/// Rewrites all inputs for an Extension node "in place"
| 1184 | +/// (it currently has to copy values because there are no APIs for in place modification) |
| 1185 | +/// |
| 1186 | +/// Should be removed when we have an API for in place modifications of the |
| 1187 | +/// extension to avoid these copies |
| 1188 | +fn rewrite_extension_inputs<F>( |
| 1189 | + node: &mut Arc<dyn UserDefinedLogicalNode>, |
| 1190 | + mut f: F, |
| 1191 | +) -> Result<Transformed<()>> |
| 1192 | +where |
| 1193 | + F: FnMut(&mut LogicalPlan) -> Result<Transformed<()>>, |
| 1194 | +{ |
| 1195 | + let mut inputs: Vec<_> = node.inputs().into_iter().cloned().collect(); |
| 1196 | +
|
| 1197 | + let result = inputs |
| 1198 | + .iter_mut() |
| 1199 | + .try_fold(Transformed::no(()), |acc, input| acc.and_then(|| f(input)))?; |
| 1200 | + let exprs = node.expressions(); |
| 1201 | + let mut new_node = node.from_template(&exprs, &inputs); |
| 1202 | + std::mem::swap(node, &mut new_node); |
| 1203 | + Ok(result) |
| 1204 | +} |
| 1205 | +*/ |
| 1206 | + |
| 1207 | +impl LogicalPlan { |
| 1208 | + /// applies `f` to each input of this plan node, rewriting them in place. |
| 1209 | + /// |
| 1210 | + /// # Notes |
| 1211 | + /// Inputs include both direct children as well as any embedded subquery |
| 1212 | + /// `LogicalPlan`s, for example such as are in [`Expr::Exists`]. |
| 1213 | + /// |
| 1214 | + /// If `f` returns an `Err`, that Err is returned, and the inputs are left |
| 1215 | + /// in a partially modified state |
| 1216 | + pub(crate) fn rewrite_children<F>(&mut self, mut f: F) -> Result<Transformed<()>> |
| 1217 | + where |
| 1218 | + F: FnMut(Self) -> Result<Transformed<Self>>, |
| 1219 | + { |
| 1220 | + let children_result = match self { |
| 1221 | + LogicalPlan::Projection(Projection { input, .. }) => { |
| 1222 | + rewrite_arc(input, &mut f) |
| 1223 | + } |
| 1224 | + LogicalPlan::Filter(Filter { input, .. }) => rewrite_arc(input, &mut f), |
| 1225 | + LogicalPlan::Repartition(Repartition { input, .. }) => { |
| 1226 | + rewrite_arc(input, &mut f) |
| 1227 | + } |
| 1228 | + LogicalPlan::Window(Window { input, .. }) => rewrite_arc(input, &mut f), |
| 1229 | + LogicalPlan::Aggregate(Aggregate { input, .. }) => rewrite_arc(input, &mut f), |
| 1230 | + LogicalPlan::Sort(Sort { input, .. }) => rewrite_arc(input, &mut f), |
| 1231 | + LogicalPlan::Join(Join { left, right, .. }) => { |
| 1232 | + rewrite_arc(left, &mut f)?.and_then(|| rewrite_arc(right, &mut f)) |
| 1233 | + } |
| 1234 | + LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { |
| 1235 | + rewrite_arc(left, &mut f)?.and_then(|| rewrite_arc(right, &mut f)) |
| 1236 | + } |
| 1237 | + LogicalPlan::Limit(Limit { input, .. }) => rewrite_arc(input, &mut f), |
| 1238 | + LogicalPlan::Subquery(Subquery { subquery, .. }) => { |
| 1239 | + rewrite_arc(subquery, &mut f) |
| 1240 | + } |
| 1241 | + LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => { |
| 1242 | + rewrite_arc(input, &mut f) |
| 1243 | + } |
| 1244 | + LogicalPlan::Extension(extension) => { |
| 1245 | + todo!(); |
| 1246 | + //rewrite_extension_inputs(&mut extension.node, &mut f) |
| 1247 | + } |
| 1248 | + LogicalPlan::Union(Union { inputs, .. }) => inputs |
| 1249 | + .iter_mut() |
| 1250 | + .try_fold(Transformed::no(()), |acc, input| { |
| 1251 | + acc.and_then(|| rewrite_arc(input, &mut f)) |
| 1252 | + }), |
| 1253 | + LogicalPlan::Distinct( |
| 1254 | + Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), |
| 1255 | + ) => rewrite_arc(input, &mut f), |
| 1256 | + LogicalPlan::Explain(explain) => rewrite_arc(&mut explain.plan, &mut f), |
| 1257 | + LogicalPlan::Analyze(analyze) => rewrite_arc(&mut analyze.input, &mut f), |
| 1258 | + LogicalPlan::Dml(write) => rewrite_arc(&mut write.input, &mut f), |
| 1259 | + LogicalPlan::Copy(copy) => rewrite_arc(&mut copy.input, &mut f), |
| 1260 | + LogicalPlan::Ddl(ddl) => { |
| 1261 | + if let Some(input) = ddl.input_mut() { |
| 1262 | + rewrite_arc(input, &mut f) |
| 1263 | + } else { |
| 1264 | + Ok(Transformed::no(())) |
| 1265 | + } |
| 1266 | + } |
| 1267 | + LogicalPlan::Unnest(Unnest { input, .. }) => rewrite_arc(input, &mut f), |
| 1268 | + LogicalPlan::Prepare(Prepare { input, .. }) => rewrite_arc(input, &mut f), |
| 1269 | + LogicalPlan::RecursiveQuery(RecursiveQuery { |
| 1270 | + static_term, |
| 1271 | + recursive_term, |
| 1272 | + .. |
| 1273 | + }) => rewrite_arc(static_term, &mut f)? |
| 1274 | + .and_then(|| rewrite_arc(recursive_term, &mut f)), |
| 1275 | + // plans without inputs |
| 1276 | + LogicalPlan::TableScan { .. } |
| 1277 | + | LogicalPlan::Statement { .. } |
| 1278 | + | LogicalPlan::EmptyRelation { .. } |
| 1279 | + | LogicalPlan::Values { .. } |
| 1280 | + | LogicalPlan::DescribeTable(_) => Ok(Transformed::no(())), |
| 1281 | + }?; |
| 1282 | + |
| 1283 | + // after visiting the actual children we we need to visit any subqueries |
| 1284 | + // that are inside the expressions |
| 1285 | + children_result.and_then(|| self.rewrite_subqueries(&mut f)) |
| 1286 | + } |
1134 | 1287 |
|
1135 | 1288 | /// Return a `LogicalPlan` with all placeholders (e.g $1 $2, |
1136 | 1289 | /// ...) replaced with corresponding values provided in |
|
0 commit comments