diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs
index 62a27b0a025ad..3b868acffb16e 100644
--- a/datafusion/expr/src/logical_plan/tree_node.rs
+++ b/datafusion/expr/src/logical_plan/tree_node.rs
@@ -41,17 +41,66 @@ use crate::{
     Aggregate, Analyze, CreateMemoryTable, CreateView, DdlStatement, Distinct,
     DistinctOn, DmlStatement, Execute, Explain, Expr, Extension, Filter, Join, Limit,
     LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, Repartition, Sort,
-    Statement, Subquery, SubqueryAlias, TableScan, Union, Unnest, UserDefinedLogicalNode,
-    Values, Window, dml::CopyTo,
+    Statement, Subquery, SubqueryAlias, TableScan, TableSource, Union, Unnest,
+    UserDefinedLogicalNode, Values, Window, dml::CopyTo,
 };
 use datafusion_common::tree_node::TreeNodeRefContainer;
 
 use crate::expr::{Exists, InSubquery};
+use arrow::datatypes::SchemaRef;
 use datafusion_common::tree_node::{
     Transformed, TreeNode, TreeNodeContainer, TreeNodeIterator, TreeNodeRecursion,
     TreeNodeRewriter, TreeNodeVisitor,
 };
 use datafusion_common::{Result, internal_err};
+use std::{any::Any, borrow::Cow, sync::Arc};
+
+/// Wrapper around a [`TableSource`] that substitutes its own logical plan for
+/// the one exposed by `inner`, without requiring the `TableSource` API to be
+/// modified. Schema, constraints, table type, filter pushdown support and
+/// column defaults are all answered by `inner`.
+struct TableSourceWithPlan {
+    /// Source being wrapped.
+    inner: Arc<dyn TableSource>,
+    /// Plan returned from `get_logical_plan` in place of the one held by `inner`.
+    logical_plan: LogicalPlan,
+}
+
+impl TableSource for TableSourceWithPlan {
+    // NOTE(review): this returns the wrapper itself, so a caller that downcasts
+    // the source to the concrete type of `inner` will not succeed on a wrapped
+    // scan -- confirm that is acceptable for every consumer of rewritten plans.
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.inner.schema()
+    }
+
+    fn constraints(&self) -> Option<&datafusion_common::Constraints> {
+        self.inner.constraints()
+    }
+
+    fn table_type(&self) -> crate::TableType {
+        self.inner.table_type()
+    }
+
+    fn supports_filters_pushdown(
+        &self,
+        filters: &[&Expr],
+    ) -> Result<Vec<crate::TableProviderFilterPushDown>> {
+        self.inner.supports_filters_pushdown(filters)
+    }
+
+    fn get_logical_plan(&'_ self) -> Option<Cow<'_, LogicalPlan>> {
+        Some(Cow::Borrowed(&self.logical_plan))
+    }
+
+    fn get_column_default(&self, column: &str) -> Option<&Expr> {
+        self.inner.get_column_default(column)
+    }
+}
 
 impl TreeNode for LogicalPlan {
     fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
@@ -346,8 +395,39 @@ impl TreeNode for LogicalPlan {
             }
             .update_data(LogicalPlan::Statement),
             // plans without inputs
-            LogicalPlan::TableScan { .. }
-            | LogicalPlan::EmptyRelation { .. }
+            LogicalPlan::TableScan(scan) => {
+                // A source may be backed by its own logical plan; treat that
+                // plan as the scan's child so `f` gets a chance to rewrite it.
+                if let Some(inner_cow) = scan.source.get_logical_plan() {
+                    let inner_plan_owned = inner_cow.into_owned();
+
+                    let rewritten = inner_plan_owned.map_elements(f)?;
+                    if rewritten.transformed {
+                        rewritten.update_data(|new_inner| {
+                            let new_source = Arc::new(TableSourceWithPlan {
+                                inner: Arc::clone(&scan.source),
+                                logical_plan: new_inner,
+                            })
+                                as Arc<dyn TableSource>;
+                            LogicalPlan::TableScan(TableScan {
+                                table_name: scan.table_name,
+                                source: new_source,
+                                projection: scan.projection,
+                                projected_schema: scan.projected_schema,
+                                filters: scan.filters,
+                                fetch: scan.fetch,
+                            })
+                        })
+                    } else {
+                        // Nothing changed: keep the original source rather
+                        // than stacking another wrapper on every traversal.
+                        rewritten.update_data(|_| LogicalPlan::TableScan(scan))
+                    }
+                } else {
+                    Transformed::no(LogicalPlan::TableScan(scan))
+                }
+            }
+            LogicalPlan::EmptyRelation { .. }
             | LogicalPlan::Values { .. }
             | LogicalPlan::DescribeTable(_) => Transformed::no(self),
         })
@@ -868,3 +948,135 @@ impl LogicalPlan {
         })
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::{EmptyRelation, table_source::TableSource};
+
+    use super::*;
+    use std::any::Any;
+    use std::borrow::Cow;
+    use std::sync::Arc;
+
+    use arrow::datatypes::{Schema, SchemaRef}; // arrow crate types
+    use datafusion_common::tree_node::Transformed;
+    use datafusion_common::{DFSchema, DFSchemaRef, Result};
+
+    #[derive(Clone)]
+    struct TestProvider {
+        plan: Option<LogicalPlan>,
+        schema: SchemaRef,
+    }
+
+    impl TestProvider {
+        fn with_plan(plan: LogicalPlan) -> Self {
+            Self {
+                plan: Some(plan),
+                schema: Arc::new(Schema::empty()),
+            }
+        }
+
+        fn without_plan() -> Self {
+            Self {
+                plan: None,
+                schema: Arc::new(Schema::empty()),
+            }
+        }
+    }
+
+    impl TableSource for TestProvider {
+        fn as_any(&self) -> &dyn Any {
+            self
+        }
+
+        fn schema(&self) -> SchemaRef {
+            Arc::clone(&self.schema)
+        }
+
+        fn get_logical_plan(&'_ self) -> Option<Cow<'_, LogicalPlan>> {
+            // return an owned LogicalPlan so tests don't need lifetime juggling
+            self.plan.as_ref().map(|p| Cow::Owned(p.clone()))
+        }
+    }
+
+    #[test]
+    fn test_table_scan_with_inner_plan_is_visited() -> Result<()> {
+        let df_schema_ref: DFSchemaRef = Arc::new(DFSchema::empty());
+
+        let inner_empty = EmptyRelation {
+            produce_one_row: false,
+            schema: Arc::clone(&df_schema_ref),
+        };
+
+        let inner_plan = LogicalPlan::EmptyRelation(inner_empty);
+
+        let provider =
+            Arc::new(TestProvider::with_plan(inner_plan.clone())) as Arc<dyn TableSource>;
+
+        let scan = TableScan::try_new("t", provider, None, vec![], None)?;
+
+        let plan = LogicalPlan::TableScan(scan);
+
+        let visited = Arc::new(std::sync::Mutex::new(false));
+        let visited_clone = Arc::clone(&visited);
+
+        let _ = plan.map_children(|child_plan: LogicalPlan| {
+            if matches!(&child_plan, LogicalPlan::EmptyRelation(_)) {
+                let mut flag = visited_clone.lock().unwrap();
+                *flag = true;
+            }
+            Ok(Transformed::no(child_plan))
+        })?;
+
+        assert!(
+            *visited.lock().unwrap(),
+            "expected inner logical plan to be visited"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_table_scan_without_inner_plan_is_not_visited() -> Result<()> {
+        let provider = Arc::new(TestProvider::without_plan()) as Arc<dyn TableSource>;
+        let scan = TableScan::try_new("t", provider, None, vec![], None)?;
+        let plan = LogicalPlan::TableScan(scan);
+
+        let visited = Arc::new(std::sync::Mutex::new(false));
+        let visited_clone = Arc::clone(&visited);
+
+        let _ = plan.map_children(|child_plan: LogicalPlan| {
+            // If this is called for any child, mark visited
+            let mut flag = visited_clone.lock().unwrap();
+            *flag = true;
+            Ok(Transformed::no(child_plan))
+        })?;
+
+        assert!(
+            !*visited.lock().unwrap(),
+            "did not expect inner visit when provider had no plan"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_table_scan_source_is_kept_when_inner_plan_is_unchanged() -> Result<()> {
+        let inner_plan = LogicalPlan::EmptyRelation(EmptyRelation {
+            produce_one_row: false,
+            schema: Arc::new(DFSchema::empty()),
+        });
+        let provider =
+            Arc::new(TestProvider::with_plan(inner_plan)) as Arc<dyn TableSource>;
+        let scan = TableScan::try_new("t", provider, None, vec![], None)?;
+
+        let result = LogicalPlan::TableScan(scan)
+            .map_children(|child_plan: LogicalPlan| Ok(Transformed::no(child_plan)))?;
+
+        assert!(!result.transformed);
+        let LogicalPlan::TableScan(scan) = result.data else {
+            panic!("expected the rewritten plan to remain a table scan");
+        };
+        // A no-op rewrite must hand back the original source, not a wrapper.
+        assert!(scan.source.as_any().is::<TestProvider>());
+        Ok(())
+    }
+}