Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 178 additions & 4 deletions datafusion/expr/src/logical_plan/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,59 @@ use crate::{
Aggregate, Analyze, CreateMemoryTable, CreateView, DdlStatement, Distinct,
DistinctOn, DmlStatement, Execute, Explain, Expr, Extension, Filter, Join, Limit,
LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, Repartition, Sort,
Statement, Subquery, SubqueryAlias, TableScan, Union, Unnest, UserDefinedLogicalNode,
Values, Window, dml::CopyTo,
Statement, Subquery, SubqueryAlias, TableScan, TableSource, Union, Unnest,
UserDefinedLogicalNode, Values, Window, dml::CopyTo,
};
use datafusion_common::tree_node::TreeNodeRefContainer;

use crate::expr::{Exists, InSubquery};
use arrow::datatypes::SchemaRef;
use datafusion_common::tree_node::{
Transformed, TreeNode, TreeNodeContainer, TreeNodeIterator, TreeNodeRecursion,
TreeNodeRewriter, TreeNodeVisitor,
};
use datafusion_common::{Result, internal_err};
use std::{any::Any, borrow::Cow, sync::Arc};

/// Wrapper around a [`TableSource`] that replaces its logical plan
/// without requiring the `TableSource` API to be modified.
///
/// The `TableSource` implementation in this file forwards every query to
/// `inner` except `get_logical_plan`, which is answered from `logical_plan`.
struct TableSourceWithPlan {
/// The wrapped source; answers schema, constraints, table type, filter
/// pushdown and column-default queries.
inner: Arc<dyn TableSource>,
/// The plan reported by `get_logical_plan` in place of whatever `inner`
/// would report.
logical_plan: LogicalPlan,
}

impl TableSourceWithPlan {
    /// Borrows the wrapped source that all non-overridden queries are
    /// forwarded to.
    fn wrapped(&self) -> &dyn TableSource {
        &*self.inner
    }
}

/// Delegating [`TableSource`] implementation: everything except
/// `get_logical_plan` (and `as_any`) is answered by the wrapped source.
impl TableSource for TableSourceWithPlan {
    /// Returns the wrapper itself for downcasting.
    // NOTE(review): because this exposes the wrapper rather than the wrapped
    // source, a caller that downcasts to the wrapped source's concrete type
    // will not get a match — confirm that is intended.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The one overridden query: always reports the replacement plan stored
    /// in this wrapper, borrowed rather than cloned.
    fn get_logical_plan(&self) -> Option<Cow<'_, LogicalPlan>> {
        let replacement = &self.logical_plan;
        Some(Cow::Borrowed(replacement))
    }

    /// Forwarded to the wrapped source.
    fn schema(&self) -> SchemaRef {
        self.wrapped().schema()
    }

    /// Forwarded to the wrapped source.
    fn table_type(&self) -> crate::TableType {
        self.wrapped().table_type()
    }

    /// Forwarded to the wrapped source.
    fn constraints(&self) -> Option<&datafusion_common::Constraints> {
        self.wrapped().constraints()
    }

    /// Forwarded to the wrapped source.
    fn get_column_default(&self, column: &str) -> Option<&Expr> {
        self.wrapped().get_column_default(column)
    }

    /// Forwarded to the wrapped source, so pushdown support is unchanged by
    /// the wrapping.
    fn supports_filters_pushdown(
        &self,
        filters: &[&Expr],
    ) -> Result<Vec<crate::TableProviderFilterPushDown>> {
        self.wrapped().supports_filters_pushdown(filters)
    }
}

impl TreeNode for LogicalPlan {
fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
Expand Down Expand Up @@ -346,8 +388,30 @@ impl TreeNode for LogicalPlan {
}
.update_data(LogicalPlan::Statement),
// plans without inputs
LogicalPlan::TableScan { .. }
| LogicalPlan::EmptyRelation { .. }
LogicalPlan::TableScan(scan) => {
if let Some(inner_cow) = scan.source.get_logical_plan() {
let inner_plan_owned = inner_cow.into_owned();

inner_plan_owned.map_elements(f)?.update_data(|new_inner| {
let new_source = Arc::new(TableSourceWithPlan {
inner: Arc::clone(&scan.source),
logical_plan: new_inner,
})
as Arc<dyn TableSource>;
LogicalPlan::TableScan(TableScan {
table_name: scan.table_name,
source: new_source,
projection: scan.projection,
projected_schema: scan.projected_schema,
filters: scan.filters,
fetch: scan.fetch,
})
})
} else {
Transformed::no(LogicalPlan::TableScan(scan))
}
}
LogicalPlan::EmptyRelation { .. }
| LogicalPlan::Values { .. }
| LogicalPlan::DescribeTable(_) => Transformed::no(self),
})
Expand Down Expand Up @@ -868,3 +932,113 @@ impl LogicalPlan {
})
}
}

#[cfg(test)]
mod tests {
    use crate::{EmptyRelation, table_source::TableSource};

    use super::*;
    use std::any::Any;
    use std::borrow::Cow;
    use std::sync::Arc;

    use arrow::datatypes::{Schema, SchemaRef}; // arrow crate types
    use datafusion_common::tree_node::Transformed;
    use datafusion_common::{DFSchema, DFSchemaRef, Result};

    /// Minimal [`TableSource`] with an empty schema that optionally exposes an
    /// inner logical plan through `get_logical_plan`.
    #[derive(Clone)]
    struct TestProvider {
        plan: Option<LogicalPlan>,
        schema: SchemaRef,
    }

    impl TestProvider {
        /// Shared constructor: empty schema plus the given optional plan.
        fn new(plan: Option<LogicalPlan>) -> Self {
            Self {
                plan,
                schema: Arc::new(Schema::empty()),
            }
        }

        /// Provider that reports `plan` as its logical plan.
        fn with_plan(plan: LogicalPlan) -> Self {
            Self::new(Some(plan))
        }

        /// Provider that reports no logical plan.
        fn without_plan() -> Self {
            Self::new(None)
        }
    }

    impl TableSource for TestProvider {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }

        fn get_logical_plan(&self) -> Option<Cow<'_, LogicalPlan>> {
            // Borrowing is enough here: the returned `Cow` is tied to `&self`,
            // so there is no need to deep-clone the plan on every call.
            self.plan.as_ref().map(Cow::Borrowed)
        }
    }

    /// Builds an unprojected, unfiltered, unlimited scan of table `t` backed
    /// by `provider`, wrapped as a [`LogicalPlan`].
    fn scan_over(provider: TestProvider) -> Result<LogicalPlan> {
        let source: Arc<dyn TableSource> = Arc::new(provider);
        let scan = TableScan::try_new("t", source, None, vec![], None)?;
        Ok(LogicalPlan::TableScan(scan))
    }

    /// A scan whose source exposes a logical plan must hand that plan to the
    /// `map_children` callback.
    #[test]
    fn test_table_scan_with_inner_plan_is_visited() -> Result<()> {
        let df_schema_ref: DFSchemaRef = Arc::new(DFSchema::empty());
        let inner_plan = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: df_schema_ref,
        });
        let plan = scan_over(TestProvider::with_plan(inner_plan))?;

        // The callback is an `FnMut` that runs on this thread, so a plain
        // local flag captured by mutable reference is sufficient.
        let mut visited = false;
        let _ = plan.map_children(|child_plan: LogicalPlan| {
            if matches!(&child_plan, LogicalPlan::EmptyRelation(_)) {
                visited = true;
            }
            Ok(Transformed::no(child_plan))
        })?;

        assert!(visited, "expected inner logical plan to be visited");
        Ok(())
    }

    /// A scan whose source exposes no logical plan has no children, so the
    /// `map_children` callback must never run.
    #[test]
    fn test_table_scan_without_inner_plan_is_not_visited() -> Result<()> {
        let plan = scan_over(TestProvider::without_plan())?;

        let mut visited = false;
        let _ = plan.map_children(|child_plan: LogicalPlan| {
            // If this is called for any child, mark visited
            visited = true;
            Ok(Transformed::no(child_plan))
        })?;

        assert!(
            !visited,
            "did not expect inner visit when provider had no plan"
        );
        Ok(())
    }
}