diff --git a/datafusion/core/src/logical_plan/plan.rs b/datafusion/core/src/logical_plan/plan.rs index 0ef1ecd81caf..6529695ad33f 100644 --- a/datafusion/core/src/logical_plan/plan.rs +++ b/datafusion/core/src/logical_plan/plan.rs @@ -97,337 +97,3 @@ pub fn source_as_provider( )), } } - -#[cfg(test)] -mod tests { - use super::super::{col, lit}; - use super::*; - use crate::test_util::scan_empty; - use arrow::datatypes::{DataType, Field, Schema}; - - fn employee_schema() -> Schema { - Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("first_name", DataType::Utf8, false), - Field::new("last_name", DataType::Utf8, false), - Field::new("state", DataType::Utf8, false), - Field::new("salary", DataType::Int32, false), - ]) - } - - fn display_plan() -> LogicalPlan { - scan_empty(Some("employee_csv"), &employee_schema(), Some(vec![0, 3])) - .unwrap() - .filter(col("state").eq(lit("CO"))) - .unwrap() - .project(vec![col("id")]) - .unwrap() - .build() - .unwrap() - } - - #[test] - fn test_display_indent() { - let plan = display_plan(); - - let expected = "Projection: #employee_csv.id\ - \n Filter: #employee_csv.state = Utf8(\"CO\")\ - \n TableScan: employee_csv projection=Some([id, state])"; - - assert_eq!(expected, format!("{}", plan.display_indent())); - } - - #[test] - fn test_display_indent_schema() { - let plan = display_plan(); - - let expected = "Projection: #employee_csv.id [id:Int32]\ - \n Filter: #employee_csv.state = Utf8(\"CO\") [id:Int32, state:Utf8]\ - \n TableScan: employee_csv projection=Some([id, state]) [id:Int32, state:Utf8]"; - - assert_eq!(expected, format!("{}", plan.display_indent_schema())); - } - - #[test] - fn test_display_graphviz() { - let plan = display_plan(); - - // just test for a few key lines in the output rather than the - // whole thing to make test mainteance easier. - let graphviz = format!("{}", plan.display_graphviz()); - - assert!( - graphviz.contains( - r#"// Begin DataFusion GraphViz Plan (see https://graphviz.org)"# - ), - "\n{}", - plan.display_graphviz() - ); - assert!( - graphviz.contains( - r#"[shape=box label="TableScan: employee_csv projection=Some([id, state])"]"# - ), - "\n{}", - plan.display_graphviz() - ); - assert!(graphviz.contains(r#"[shape=box label="TableScan: employee_csv projection=Some([id, state])\nSchema: [id:Int32, state:Utf8]"]"#), - "\n{}", plan.display_graphviz()); - assert!( - graphviz.contains(r#"// End DataFusion GraphViz Plan"#), - "\n{}", - plan.display_graphviz() - ); - } - - /// Tests for the Visitor trait and walking logical plan nodes - #[derive(Debug, Default)] - struct OkVisitor { - strings: Vec, - } - - impl PlanVisitor for OkVisitor { - type Error = String; - - fn pre_visit( - &mut self, - plan: &LogicalPlan, - ) -> std::result::Result { - let s = match plan { - LogicalPlan::Projection { .. } => "pre_visit Projection", - LogicalPlan::Filter { .. } => "pre_visit Filter", - LogicalPlan::TableScan { .. } => "pre_visit TableScan", - _ => unimplemented!("unknown plan type"), - }; - - self.strings.push(s.into()); - Ok(true) - } - - fn post_visit( - &mut self, - plan: &LogicalPlan, - ) -> std::result::Result { - let s = match plan { - LogicalPlan::Projection { .. } => "post_visit Projection", - LogicalPlan::Filter { .. } => "post_visit Filter", - LogicalPlan::TableScan { .. } => "post_visit TableScan", - _ => unimplemented!("unknown plan type"), - }; - - self.strings.push(s.into()); - Ok(true) - } - } - - #[test] - fn visit_order() { - let mut visitor = OkVisitor::default(); - let plan = test_plan(); - let res = plan.accept(&mut visitor); - assert!(res.is_ok()); - - assert_eq!( - visitor.strings, - vec![ - "pre_visit Projection", - "pre_visit Filter", - "pre_visit TableScan", - "post_visit TableScan", - "post_visit Filter", - "post_visit Projection", - ] - ); - } - - #[derive(Debug, Default)] - /// Counter than counts to zero and returns true when it gets there - struct OptionalCounter { - val: Option, - } - - impl OptionalCounter { - fn new(val: usize) -> Self { - Self { val: Some(val) } - } - // Decrements the counter by 1, if any, returning true if it hits zero - fn dec(&mut self) -> bool { - if Some(0) == self.val { - true - } else { - self.val = self.val.take().map(|i| i - 1); - false - } - } - } - - #[derive(Debug, Default)] - /// Visitor that returns false after some number of visits - struct StoppingVisitor { - inner: OkVisitor, - /// When Some(0) returns false from pre_visit - return_false_from_pre_in: OptionalCounter, - /// When Some(0) returns false from post_visit - return_false_from_post_in: OptionalCounter, - } - - impl PlanVisitor for StoppingVisitor { - type Error = String; - - fn pre_visit( - &mut self, - plan: &LogicalPlan, - ) -> std::result::Result { - if self.return_false_from_pre_in.dec() { - return Ok(false); - } - self.inner.pre_visit(plan) - } - - fn post_visit( - &mut self, - plan: &LogicalPlan, - ) -> std::result::Result { - if self.return_false_from_post_in.dec() { - return Ok(false); - } - - self.inner.post_visit(plan) - } - } - - /// test early stopping in pre-visit - #[test] - fn early_stopping_pre_visit() { - let mut visitor = StoppingVisitor { - return_false_from_pre_in: OptionalCounter::new(2), - ..Default::default() - }; - let plan = test_plan(); - let res = plan.accept(&mut visitor); - assert!(res.is_ok()); - - assert_eq!( - visitor.inner.strings, - vec!["pre_visit Projection", "pre_visit Filter"] - ); - } - - #[test] - fn early_stopping_post_visit() { - let mut visitor = StoppingVisitor { - return_false_from_post_in: OptionalCounter::new(1), - ..Default::default() - }; - let plan = test_plan(); - let res = plan.accept(&mut visitor); - assert!(res.is_ok()); - - assert_eq!( - visitor.inner.strings, - vec![ - "pre_visit Projection", - "pre_visit Filter", - "pre_visit TableScan", - "post_visit TableScan", - ] - ); - } - - #[derive(Debug, Default)] - /// Visitor that returns an error after some number of visits - struct ErrorVisitor { - inner: OkVisitor, - /// When Some(0) returns false from pre_visit - return_error_from_pre_in: OptionalCounter, - /// When Some(0) returns false from post_visit - return_error_from_post_in: OptionalCounter, - } - - impl PlanVisitor for ErrorVisitor { - type Error = String; - - fn pre_visit( - &mut self, - plan: &LogicalPlan, - ) -> std::result::Result { - if self.return_error_from_pre_in.dec() { - return Err("Error in pre_visit".into()); - } - - self.inner.pre_visit(plan) - } - - fn post_visit( - &mut self, - plan: &LogicalPlan, - ) -> std::result::Result { - if self.return_error_from_post_in.dec() { - return Err("Error in post_visit".into()); - } - - self.inner.post_visit(plan) - } - } - - #[test] - fn error_pre_visit() { - let mut visitor = ErrorVisitor { - return_error_from_pre_in: OptionalCounter::new(2), - ..Default::default() - }; - let plan = test_plan(); - let res = plan.accept(&mut visitor); - - if let Err(e) = res { - assert_eq!("Error in pre_visit", e); - } else { - panic!("Expected an error"); - } - - assert_eq!( - visitor.inner.strings, - vec!["pre_visit Projection", "pre_visit Filter"] - ); - } - - #[test] - fn error_post_visit() { - let mut visitor = ErrorVisitor { - return_error_from_post_in: OptionalCounter::new(1), - ..Default::default() - }; - let plan = test_plan(); - let res = plan.accept(&mut visitor); - if let Err(e) = res { - assert_eq!("Error in post_visit", e); - } else { - panic!("Expected an error"); - } - - assert_eq!( - visitor.inner.strings, - vec![ - "pre_visit Projection", - "pre_visit Filter", - "pre_visit TableScan", - "post_visit TableScan", - ] - ); - } - - fn test_plan() -> LogicalPlan { - let schema = Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("state", DataType::Utf8, false), - ]); - - scan_empty(None, &schema, Some(vec![0, 1])) - .unwrap() - .filter(col("state").eq(lit("CO"))) - .unwrap() - .project(vec![col("id")]) - .unwrap() - .build() - .unwrap() - } -} diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 7174a4c7788a..022f50ea320c 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1344,3 +1344,337 @@ pub trait ToStringifiedPlan { /// Create a stringified plan with the specified type fn to_stringified(&self, plan_type: PlanType) -> StringifiedPlan; } + +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::table_scan; + use crate::{col, lit}; + use arrow::datatypes::{DataType, Field, Schema}; + + fn employee_schema() -> Schema { + Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("first_name", DataType::Utf8, false), + Field::new("last_name", DataType::Utf8, false), + Field::new("state", DataType::Utf8, false), + Field::new("salary", DataType::Int32, false), + ]) + } + + fn display_plan() -> LogicalPlan { + table_scan(Some("employee_csv"), &employee_schema(), Some(vec![0, 3])) + .unwrap() + .filter(col("state").eq(lit("CO"))) + .unwrap() + .project(vec![col("id")]) + .unwrap() + .build() + .unwrap() + } + + #[test] + fn test_display_indent() { + let plan = display_plan(); + + let expected = "Projection: #employee_csv.id\ + \n Filter: #employee_csv.state = Utf8(\"CO\")\ + \n TableScan: employee_csv projection=Some([id, state])"; + + assert_eq!(expected, format!("{}", plan.display_indent())); + } + + #[test] + fn test_display_indent_schema() { + let plan = display_plan(); + + let expected = "Projection: #employee_csv.id [id:Int32]\ + \n Filter: #employee_csv.state = Utf8(\"CO\") [id:Int32, state:Utf8]\ + \n TableScan: employee_csv projection=Some([id, state]) [id:Int32, state:Utf8]"; + + assert_eq!(expected, format!("{}", plan.display_indent_schema())); + } + + #[test] + fn test_display_graphviz() { + let plan = display_plan(); + + // just test for a few key lines in the output rather than the + // whole thing to make test mainteance easier. + let graphviz = format!("{}", plan.display_graphviz()); + + assert!( + graphviz.contains( + r#"// Begin DataFusion GraphViz Plan (see https://graphviz.org)"# + ), + "\n{}", + plan.display_graphviz() + ); + assert!( + graphviz.contains( + r#"[shape=box label="TableScan: employee_csv projection=Some([id, state])"]"# + ), + "\n{}", + plan.display_graphviz() + ); + assert!(graphviz.contains(r#"[shape=box label="TableScan: employee_csv projection=Some([id, state])\nSchema: [id:Int32, state:Utf8]"]"#), + "\n{}", plan.display_graphviz()); + assert!( + graphviz.contains(r#"// End DataFusion GraphViz Plan"#), + "\n{}", + plan.display_graphviz() + ); + } + + /// Tests for the Visitor trait and walking logical plan nodes + #[derive(Debug, Default)] + struct OkVisitor { + strings: Vec, + } + + impl PlanVisitor for OkVisitor { + type Error = String; + + fn pre_visit( + &mut self, + plan: &LogicalPlan, + ) -> std::result::Result { + let s = match plan { + LogicalPlan::Projection { .. } => "pre_visit Projection", + LogicalPlan::Filter { .. } => "pre_visit Filter", + LogicalPlan::TableScan { .. } => "pre_visit TableScan", + _ => unimplemented!("unknown plan type"), + }; + + self.strings.push(s.into()); + Ok(true) + } + + fn post_visit( + &mut self, + plan: &LogicalPlan, + ) -> std::result::Result { + let s = match plan { + LogicalPlan::Projection { .. } => "post_visit Projection", + LogicalPlan::Filter { .. } => "post_visit Filter", + LogicalPlan::TableScan { .. } => "post_visit TableScan", + _ => unimplemented!("unknown plan type"), + }; + + self.strings.push(s.into()); + Ok(true) + } + } + + #[test] + fn visit_order() { + let mut visitor = OkVisitor::default(); + let plan = test_plan(); + let res = plan.accept(&mut visitor); + assert!(res.is_ok()); + + assert_eq!( + visitor.strings, + vec![ + "pre_visit Projection", + "pre_visit Filter", + "pre_visit TableScan", + "post_visit TableScan", + "post_visit Filter", + "post_visit Projection", + ] + ); + } + + #[derive(Debug, Default)] + /// Counter than counts to zero and returns true when it gets there + struct OptionalCounter { + val: Option, + } + + impl OptionalCounter { + fn new(val: usize) -> Self { + Self { val: Some(val) } + } + // Decrements the counter by 1, if any, returning true if it hits zero + fn dec(&mut self) -> bool { + if Some(0) == self.val { + true + } else { + self.val = self.val.take().map(|i| i - 1); + false + } + } + } + + #[derive(Debug, Default)] + /// Visitor that returns false after some number of visits + struct StoppingVisitor { + inner: OkVisitor, + /// When Some(0) returns false from pre_visit + return_false_from_pre_in: OptionalCounter, + /// When Some(0) returns false from post_visit + return_false_from_post_in: OptionalCounter, + } + + impl PlanVisitor for StoppingVisitor { + type Error = String; + + fn pre_visit( + &mut self, + plan: &LogicalPlan, + ) -> std::result::Result { + if self.return_false_from_pre_in.dec() { + return Ok(false); + } + self.inner.pre_visit(plan) + } + + fn post_visit( + &mut self, + plan: &LogicalPlan, + ) -> std::result::Result { + if self.return_false_from_post_in.dec() { + return Ok(false); + } + + self.inner.post_visit(plan) + } + } + + /// test early stopping in pre-visit + #[test] + fn early_stopping_pre_visit() { + let mut visitor = StoppingVisitor { + return_false_from_pre_in: OptionalCounter::new(2), + ..Default::default() + }; + let plan = test_plan(); + let res = plan.accept(&mut visitor); + assert!(res.is_ok()); + + assert_eq!( + visitor.inner.strings, + vec!["pre_visit Projection", "pre_visit Filter"] + ); + } + + #[test] + fn early_stopping_post_visit() { + let mut visitor = StoppingVisitor { + return_false_from_post_in: OptionalCounter::new(1), + ..Default::default() + }; + let plan = test_plan(); + let res = plan.accept(&mut visitor); + assert!(res.is_ok()); + + assert_eq!( + visitor.inner.strings, + vec![ + "pre_visit Projection", + "pre_visit Filter", + "pre_visit TableScan", + "post_visit TableScan", + ] + ); + } + + #[derive(Debug, Default)] + /// Visitor that returns an error after some number of visits + struct ErrorVisitor { + inner: OkVisitor, + /// When Some(0) returns false from pre_visit + return_error_from_pre_in: OptionalCounter, + /// When Some(0) returns false from post_visit + return_error_from_post_in: OptionalCounter, + } + + impl PlanVisitor for ErrorVisitor { + type Error = String; + + fn pre_visit( + &mut self, + plan: &LogicalPlan, + ) -> std::result::Result { + if self.return_error_from_pre_in.dec() { + return Err("Error in pre_visit".into()); + } + + self.inner.pre_visit(plan) + } + + fn post_visit( + &mut self, + plan: &LogicalPlan, + ) -> std::result::Result { + if self.return_error_from_post_in.dec() { + return Err("Error in post_visit".into()); + } + + self.inner.post_visit(plan) + } + } + + #[test] + fn error_pre_visit() { + let mut visitor = ErrorVisitor { + return_error_from_pre_in: OptionalCounter::new(2), + ..Default::default() + }; + let plan = test_plan(); + let res = plan.accept(&mut visitor); + + if let Err(e) = res { + assert_eq!("Error in pre_visit", e); + } else { + panic!("Expected an error"); + } + + assert_eq!( + visitor.inner.strings, + vec!["pre_visit Projection", "pre_visit Filter"] + ); + } + + #[test] + fn error_post_visit() { + let mut visitor = ErrorVisitor { + return_error_from_post_in: OptionalCounter::new(1), + ..Default::default() + }; + let plan = test_plan(); + let res = plan.accept(&mut visitor); + if let Err(e) = res { + assert_eq!("Error in post_visit", e); + } else { + panic!("Expected an error"); + } + + assert_eq!( + visitor.inner.strings, + vec![ + "pre_visit Projection", + "pre_visit Filter", + "pre_visit TableScan", + "post_visit TableScan", + ] + ); + } + + fn test_plan() -> LogicalPlan { + let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("state", DataType::Utf8, false), + ]); + + table_scan(None, &schema, Some(vec![0, 1])) + .unwrap() + .filter(col("state").eq(lit("CO"))) + .unwrap() + .project(vec![col("id")]) + .unwrap() + .build() + .unwrap() + } +}