diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md index 307443b02b4ac..114cc41b5186a 100644 --- a/rust/datafusion/README.md +++ b/rust/datafusion/README.md @@ -171,3 +171,24 @@ Below is a checklist of what you need to do to add a new aggregate function to D * a new line in `create_aggregate_expr` mapping the built-in to the implementation * tests to the function. * In [tests/sql.rs](tests/sql.rs), add a new test where the function is called through SQL against well known data and returns the expected result. + +## How to display plans graphically + +The query plans represented by `LogicalPlan` nodes can be graphically +rendered using [Graphviz](http://www.graphviz.org/). + +To do so, save the output of the `display_graphviz` function to a file: + +```rust +// Create plan somehow... +let mut output = File::create("/tmp/plan.dot")?; +write!(output, "{}", plan.display_graphviz()); +``` + +Then, use the `dot` command line tool to render it into a file that +can be displayed. For example, the following command creates a +`/tmp/plan.pdf` file: + +```bash +dot -Tpdf < /tmp/plan.dot > /tmp/plan.pdf +``` diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 19ee56d76ba3a..05633f04e4d5d 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -21,8 +21,8 @@ //! Logical query plans can then be optimized and executed directly, or translated into //! physical query plans and executed. 
-use fmt::Debug; -use std::{any::Any, collections::HashMap, collections::HashSet, fmt, sync::Arc}; +use std::fmt::{self, Debug, Display}; +use std::{any::Any, collections::HashMap, collections::HashSet, sync::Arc}; use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; use arrow::{ @@ -956,117 +956,567 @@ impl LogicalPlan { } } +/// Trait that implements the [Visitor +/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for a +/// depth first walk of `LogicalPlan` nodes. `pre_visit` is called +/// before any children are visited, and then `post_visit` is called +/// after all children have been visited. +//// +/// To use, define a struct that implements this trait and then invoke +/// "LogicalPlan::accept". +/// +/// For example, for a logical plan like: +/// +/// Projection: #id +/// Filter: #state Eq Utf8(\"CO\")\ +/// CsvScan: employee.csv projection=Some([0, 3])"; +/// +/// The sequence of visit operations would be: +/// ```text +/// visitor.pre_visit(Projection) +/// visitor.pre_visit(Filter) +/// visitor.pre_visit(CsvScan) +/// visitor.post_visit(CsvScan) +/// visitor.post_visit(Filter) +/// visitor.post_visit(Projection) +/// ``` +pub trait PlanVisitor { + /// The type of error returned by this visitor + type Error; + + /// Invoked on a logical plan before any of its child inputs have been + /// visited. If Ok(true) is returned, the recursion continues. If + /// Err(..) or Ok(false) are returned, the recursion stops + /// immedately and the error, if any, is returned to `accept` + fn pre_visit(&mut self, plan: &LogicalPlan) + -> std::result::Result; + + /// Invoked on a logical plan after all of its child inputs have + /// been visited. The return value is handled the same as the + /// return value of `pre_visit`. The provided default implementation + /// returns `Ok(true)`. 
+ fn post_visit( + &mut self, + _plan: &LogicalPlan, + ) -> std::result::Result { + Ok(true) + } +} + impl LogicalPlan { - fn fmt_with_indent(&self, f: &mut fmt::Formatter, indent: usize) -> fmt::Result { - if indent > 0 { - writeln!(f)?; - for _ in 0..indent { - write!(f, " ")?; - } + /// returns all inputs in the logical plan. Returns Ok(true) if + /// all nodes were visited, and Ok(false) if any call to + /// `pre_visit` or `post_visit` returned Ok(false) and may have + /// cut short the recursion + pub fn accept(&self, visitor: &mut V) -> std::result::Result + where + V: PlanVisitor, + { + if !visitor.pre_visit(self)? { + return Ok(false); } - match *self { - LogicalPlan::EmptyRelation { .. } => write!(f, "EmptyRelation"), - LogicalPlan::TableScan { - ref source, - ref projection, - .. - } => match source { - TableSource::FromContext(table_name) => { - write!(f, "TableScan: {} projection={:?}", table_name, projection) - } - TableSource::FromProvider(_) => { - write!(f, "TableScan: projection={:?}", projection) - } - }, - LogicalPlan::InMemoryScan { ref projection, .. } => { - write!(f, "InMemoryScan: projection={:?}", projection) - } - LogicalPlan::CsvScan { - ref path, - ref projection, - .. - } => write!(f, "CsvScan: {} projection={:?}", path, projection), - LogicalPlan::ParquetScan { - ref path, - ref projection, - .. - } => write!(f, "ParquetScan: {} projection={:?}", path, projection), - LogicalPlan::Projection { - ref expr, - ref input, - .. - } => { - write!(f, "Projection: ")?; - for i in 0..expr.len() { - if i > 0 { - write!(f, ", ")?; + + let recurse = match self { + LogicalPlan::Projection { input, .. } => input.accept(visitor)?, + LogicalPlan::Filter { input, .. } => input.accept(visitor)?, + LogicalPlan::Aggregate { input, .. } => input.accept(visitor)?, + LogicalPlan::Sort { input, .. } => input.accept(visitor)?, + LogicalPlan::Limit { input, .. 
} => input.accept(visitor)?, + LogicalPlan::Extension { node } => { + for input in node.inputs() { + if !input.accept(visitor)? { + return Ok(false); } - write!(f, "{:?}", expr[i])?; } - input.fmt_with_indent(f, indent + 1) + true } - LogicalPlan::Filter { - predicate: ref expr, - ref input, - .. - } => { - write!(f, "Filter: {:?}", expr)?; - input.fmt_with_indent(f, indent + 1) - } - LogicalPlan::Aggregate { - ref input, - ref group_expr, - ref aggr_expr, - .. - } => { + // plans without inputs + LogicalPlan::TableScan { .. } + | LogicalPlan::InMemoryScan { .. } + | LogicalPlan::ParquetScan { .. } + | LogicalPlan::CsvScan { .. } + | LogicalPlan::EmptyRelation { .. } + | LogicalPlan::CreateExternalTable { .. } + | LogicalPlan::Explain { .. } => true, + }; + if !recurse { + return Ok(false); + } + + if !visitor.post_visit(self)? { + return Ok(false); + } + + Ok(true) + } +} + +/// Formats plans with a single line per node. For example: +/// +/// Projection: #id +/// Filter: #state Eq Utf8(\"CO\")\ +/// CsvScan: employee.csv projection=Some([0, 3])"; +struct IndentVisitor<'a, 'b> { + f: &'a mut fmt::Formatter<'b>, + /// If true, includes summarized schema information + with_schema: bool, + indent: u32, +} + +impl<'a, 'b> IndentVisitor<'a, 'b> { + fn write_indent(&mut self) -> fmt::Result { + for _ in 0..self.indent { + write!(self.f, " ")?; + } + Ok(()) + } +} + +impl<'a, 'b> PlanVisitor for IndentVisitor<'a, 'b> { + type Error = fmt::Error; + + fn pre_visit(&mut self, plan: &LogicalPlan) -> std::result::Result { + if self.indent > 0 { + writeln!(self.f)?; + } + self.write_indent()?; + + write!(self.f, "{}", plan.display())?; + if self.with_schema { + write!(self.f, " {}", display_schema(plan.schema()))?; + } + + self.indent += 1; + Ok(true) + } + + fn post_visit( + &mut self, + _plan: &LogicalPlan, + ) -> std::result::Result { + self.indent -= 1; + Ok(true) + } +} + +/// Print the schema in a compact representation to `buf` +/// +/// For example: `foo:Utf8` if `foo` 
can not be null, and +/// `foo:Utf8;N` if `foo` is nullable. +/// +/// ``` +/// use arrow::datatypes::{Field, Schema, DataType}; +/// # use datafusion::logical_plan::display_schema; +/// let schema = Schema::new(vec![ +/// Field::new("id", DataType::Int32, false), +/// Field::new("first_name", DataType::Utf8, true), +/// ]); +/// +/// assert_eq!( +/// "[id:Int32, first_name:Utf8;N]", +/// format!("{}", display_schema(&schema)) +/// ); +/// ``` +pub fn display_schema<'a>(schema: &'a Schema) -> impl fmt::Display + 'a { + struct Wrapper<'a>(&'a Schema); + + impl<'a> fmt::Display for Wrapper<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "[")?; + for (idx, field) in self.0.fields().iter().enumerate() { + if idx > 0 { + write!(f, ", ")?; + } + let nullable_str = if field.is_nullable() { ";N" } else { "" }; write!( f, - "Aggregate: groupBy=[{:?}], aggr=[{:?}]", - group_expr, aggr_expr + "{}:{:?}{}", + field.name(), + field.data_type(), + nullable_str )?; - input.fmt_with_indent(f, indent + 1) - } - LogicalPlan::Sort { - ref input, - ref expr, - .. - } => { - write!(f, "Sort: ")?; - for i in 0..expr.len() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{:?}", expr[i])?; - } - input.fmt_with_indent(f, indent + 1) } - LogicalPlan::Limit { - ref input, ref n, .. - } => { - write!(f, "Limit: {}", n)?; - input.fmt_with_indent(f, indent + 1) + write!(f, "]") + } + } + Wrapper(schema) +} + +/// Logic related to creating DOT language graphs. 
+#[derive(Default)] +struct GraphvizBuilder { + id_gen: usize, +} + +impl GraphvizBuilder { + fn next_id(&mut self) -> usize { + self.id_gen += 1; + self.id_gen + } + + // write out the start of the subgraph cluster + fn start_cluster(&mut self, f: &mut fmt::Formatter, title: &str) -> fmt::Result { + writeln!(f, " subgraph cluster_{}", self.next_id())?; + writeln!(f, " {{")?; + writeln!(f, " graph[label={}]", Self::quoted(title)) + } + + // write out the end of the subgraph cluster + fn end_cluster(&mut self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, " }}") + } + + /// makes a quoted string suitable for inclusion in a graphviz chart + fn quoted(label: &str) -> String { + let label = label.replace('"', "_"); + format!("\"{}\"", label) + } +} + +/// Formats plans for graphical display using the `DOT` language. This +/// format can be visualized using software from +/// [`graphviz`](https://graphviz.org/) +struct GraphvizVisitor<'a, 'b> { + f: &'a mut fmt::Formatter<'b>, + graphviz_builder: GraphvizBuilder, + /// If true, includes summarized schema information + with_schema: bool, + + /// Holds the ids (as generated from `graphviz_builder` of all + /// parent nodes + parent_ids: Vec, +} + +impl<'a, 'b> GraphvizVisitor<'a, 'b> { + fn new(f: &'a mut fmt::Formatter<'b>) -> Self { + Self { + f, + graphviz_builder: GraphvizBuilder::default(), + with_schema: false, + parent_ids: Vec::new(), + } + } + + /// Sets a flag which controls if the output schema is displayed + fn set_with_schema(&mut self, with_schema: bool) { + self.with_schema = with_schema; + } + + fn pre_visit_plan(&mut self, label: &str) -> fmt::Result { + self.graphviz_builder.start_cluster(self.f, label) + } + + fn post_visit_plan(&mut self) -> fmt::Result { + self.graphviz_builder.end_cluster(self.f) + } +} + +impl<'a, 'b> PlanVisitor for GraphvizVisitor<'a, 'b> { + type Error = fmt::Error; + + fn pre_visit(&mut self, plan: &LogicalPlan) -> std::result::Result { + let id = 
self.graphviz_builder.next_id(); + + // Create a new graph node for `plan` such as + // id [label="foo"] + let label = if self.with_schema { + format!( + "{}\\nSchema: {}", + plan.display(), + display_schema(plan.schema()) + ) + } else { + format!("{}", plan.display()) + }; + + writeln!( + self.f, + " {}[shape=box label={}]", + id, + GraphvizBuilder::quoted(&label) + )?; + + // Create an edge to our parent node, if any + // parent_id -> id + if let Some(parent_id) = self.parent_ids.last() { + writeln!( + self.f, + " {} -> {} [arrowhead=none, arrowtail=normal, dir=back]", + parent_id, id + )?; + } + + self.parent_ids.push(id); + Ok(true) + } + + fn post_visit( + &mut self, + _plan: &LogicalPlan, + ) -> std::result::Result { + // always be non-empty as pre_visit always pushes + self.parent_ids.pop().unwrap(); + Ok(true) + } +} + +// Various implementations for printing out LogicalPlans +impl LogicalPlan { + /// Return a `format`able structure that produces a single line + /// per node. For example: + /// + /// ```text + /// Projection: #id + /// Filter: #state Eq Utf8(\"CO\")\ + /// CsvScan: employee.csv projection=Some([0, 3]) + /// ``` + /// + /// ``` + /// use arrow::datatypes::{Field, Schema, DataType}; + /// use datafusion::logical_plan::{lit, col, LogicalPlanBuilder}; + /// let schema = Schema::new(vec![ + /// Field::new("id", DataType::Int32, false), + /// ]); + /// let plan = LogicalPlanBuilder::scan("default", "foo.csv", &schema, None).unwrap() + /// .filter(col("id").eq(lit(5))).unwrap() + /// .build().unwrap(); + /// + /// // Format using display_indent + /// let display_string = format!("{}", plan.display_indent()); + /// + /// assert_eq!("Filter: #id Eq Int32(5)\ + /// \n TableScan: foo.csv projection=None", + /// display_string); + /// ``` + pub fn display_indent<'a>(&'a self) -> impl fmt::Display + 'a { + // Boilerplate structure to wrap LogicalPlan with something + // that that can be formatted + struct Wrapper<'a>(&'a LogicalPlan); + impl<'a> 
fmt::Display for Wrapper<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let with_schema = false; + let mut visitor = IndentVisitor { + f, + with_schema, + indent: 0, + }; + self.0.accept(&mut visitor).unwrap(); + Ok(()) } - LogicalPlan::CreateExternalTable { ref name, .. } => { - write!(f, "CreateExternalTable: {:?}", name) + } + Wrapper(self) + } + + /// Return a `format`able structure that produces a single line + /// per node that includes the output schema. For example: + /// + /// ```text + /// Projection: #id [id:Int32]\ + /// Filter: #state Eq Utf8(\"CO\") [id:Int32, state:Utf8]\ + /// TableScan: employee.csv projection=Some([0, 3]) [id:Int32, state:Utf8]"; + /// ``` + /// + /// ``` + /// use arrow::datatypes::{Field, Schema, DataType}; + /// use datafusion::logical_plan::{lit, col, LogicalPlanBuilder}; + /// let schema = Schema::new(vec![ + /// Field::new("id", DataType::Int32, false), + /// ]); + /// let plan = LogicalPlanBuilder::scan("default", "foo.csv", &schema, None).unwrap() + /// .filter(col("id").eq(lit(5))).unwrap() + /// .build().unwrap(); + /// + /// // Format using display_indent_schema + /// let display_string = format!("{}", plan.display_indent_schema()); + /// + /// assert_eq!("Filter: #id Eq Int32(5) [id:Int32]\ + /// \n TableScan: foo.csv projection=None [id:Int32]", + /// display_string); + /// ``` + pub fn display_indent_schema<'a>(&'a self) -> impl fmt::Display + 'a { + // Boilerplate structure to wrap LogicalPlan with something + // that that can be formatted + struct Wrapper<'a>(&'a LogicalPlan); + impl<'a> fmt::Display for Wrapper<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let with_schema = true; + let mut visitor = IndentVisitor { + f, + with_schema, + indent: 0, + }; + self.0.accept(&mut visitor).unwrap(); + Ok(()) } - LogicalPlan::Explain { ref plan, .. 
} => { - write!(f, "Explain")?; - plan.fmt_with_indent(f, indent + 1) + } + Wrapper(self) + } + + /// Return a `format`able structure that produces lines meant for + /// graphical display using the `DOT` language. This format can be + /// visualized using software from + /// [`graphviz`](https://graphviz.org/) + /// + /// This currently produces two graphs -- one with the basic + /// structure, and one with additional details such as schema. + /// + /// ``` + /// use arrow::datatypes::{Field, Schema, DataType}; + /// use datafusion::logical_plan::{lit, col, LogicalPlanBuilder}; + /// let schema = Schema::new(vec![ + /// Field::new("id", DataType::Int32, false), + /// ]); + /// let plan = LogicalPlanBuilder::scan("default", "foo.csv", &schema, None).unwrap() + /// .filter(col("id").eq(lit(5))).unwrap() + /// .build().unwrap(); + /// + /// // Format using display_graphviz + /// let graphviz_string = format!("{}", plan.display_graphviz()); + /// ``` + /// + /// If graphviz string is saved to a file such as `/tmp/example.dot`, the following + /// commands can be used to render it as a pdf: + /// + /// ```bash + /// dot -Tpdf < /tmp/example.dot > /tmp/example.pdf + /// ``` + /// + pub fn display_graphviz<'a>(&'a self) -> impl fmt::Display + 'a { + // Boilerplate structure to wrap LogicalPlan with something + // that that can be formatted + struct Wrapper<'a>(&'a LogicalPlan); + impl<'a> fmt::Display for Wrapper<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!( + f, + "// Begin DataFusion GraphViz Plan (see https://graphviz.org)" + )?; + writeln!(f, "digraph {{")?; + + let mut visitor = GraphvizVisitor::new(f); + + visitor.pre_visit_plan("LogicalPlan")?; + self.0.accept(&mut visitor).unwrap(); + visitor.post_visit_plan()?; + + visitor.set_with_schema(true); + visitor.pre_visit_plan("Detailed LogicalPlan")?; + self.0.accept(&mut visitor).unwrap(); + visitor.post_visit_plan()?; + + writeln!(f, "}}")?; + writeln!(f, "// End DataFusion GraphViz 
Plan")?; + Ok(()) } - LogicalPlan::Extension { ref node } => { - node.fmt_for_explain(f)?; - node.inputs() - .iter() - .map(|input| input.fmt_with_indent(f, indent + 1)) - .collect() + } + Wrapper(self) + } + + /// Return a `format`able structure with the a human readable + /// description of this LogicalPlan node per node, not including + /// children. For example: + /// + /// ```text + /// Projection: #id + /// ``` + /// ``` + /// use arrow::datatypes::{Field, Schema, DataType}; + /// use datafusion::logical_plan::{lit, col, LogicalPlanBuilder}; + /// let schema = Schema::new(vec![ + /// Field::new("id", DataType::Int32, false), + /// ]); + /// let plan = LogicalPlanBuilder::scan("default", "foo.csv", &schema, None).unwrap() + /// .build().unwrap(); + /// + /// // Format using display + /// let display_string = format!("{}", plan.display()); + /// + /// assert_eq!("TableScan: foo.csv projection=None", display_string); + /// ``` + pub fn display<'a>(&'a self) -> impl fmt::Display + 'a { + // Boilerplate structure to wrap LogicalPlan with something + // that that can be formatted + struct Wrapper<'a>(&'a LogicalPlan); + impl<'a> fmt::Display for Wrapper<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self.0 { + LogicalPlan::EmptyRelation { .. } => write!(f, "EmptyRelation"), + LogicalPlan::TableScan { + ref source, + ref projection, + .. + } => match source { + TableSource::FromContext(table_name) => write!( + f, + "TableScan: {} projection={:?}", + table_name, projection + ), + TableSource::FromProvider(_) => { + write!(f, "TableScan: projection={:?}", projection) + } + }, + LogicalPlan::InMemoryScan { ref projection, .. } => { + write!(f, "InMemoryScan: projection={:?}", projection) + } + LogicalPlan::CsvScan { + ref path, + ref projection, + .. + } => write!(f, "CsvScan: {} projection={:?}", path, projection), + LogicalPlan::ParquetScan { + ref path, + ref projection, + .. 
+ } => write!(f, "ParquetScan: {} projection={:?}", path, projection), + LogicalPlan::Projection { ref expr, .. } => { + write!(f, "Projection: ")?; + for i in 0..expr.len() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{:?}", expr[i])?; + } + Ok(()) + } + LogicalPlan::Filter { + predicate: ref expr, + .. + } => write!(f, "Filter: {:?}", expr), + LogicalPlan::Aggregate { + ref group_expr, + ref aggr_expr, + .. + } => write!( + f, + "Aggregate: groupBy=[{:?}], aggr=[{:?}]", + group_expr, aggr_expr + ), + LogicalPlan::Sort { ref expr, .. } => { + write!(f, "Sort: ")?; + for i in 0..expr.len() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{:?}", expr[i])?; + } + Ok(()) + } + LogicalPlan::Limit { ref n, .. } => write!(f, "Limit: {}", n), + LogicalPlan::CreateExternalTable { ref name, .. } => { + write!(f, "CreateExternalTable: {:?}", name) + } + LogicalPlan::Explain { .. } => write!(f, "Explain"), + LogicalPlan::Extension { ref node } => node.fmt_for_explain(f), + } } } + Wrapper(self) } } impl fmt::Debug for LogicalPlan { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - self.fmt_with_indent(f, 0) + self.display_indent().fmt(f) } } @@ -1529,4 +1979,342 @@ mod tests { Ok(()) } + + #[test] + fn test_visitor() { + let schema = Schema::new(vec![]); + assert_eq!("[]", format!("{}", display_schema(&schema))); + } + + #[test] + fn test_display_empty_schema() { + let schema = Schema::new(vec![]); + assert_eq!("[]", format!("{}", display_schema(&schema))); + } + + #[test] + fn test_display_schema() { + let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("first_name", DataType::Utf8, true), + ]); + + assert_eq!( + "[id:Int32, first_name:Utf8;N]", + format!("{}", display_schema(&schema)) + ); + } + + fn display_plan() -> LogicalPlan { + LogicalPlanBuilder::scan( + "default", + "employee.csv", + &employee_schema(), + Some(vec![0, 3]), + ) + .unwrap() + .filter(col("state").eq(lit("CO"))) + .unwrap() + 
.project(vec![col("id")]) + .unwrap() + .build() + .unwrap() + } + + #[test] + fn test_display_indent() { + let plan = display_plan(); + + let expected = "Projection: #id\ + \n Filter: #state Eq Utf8(\"CO\")\ + \n TableScan: employee.csv projection=Some([0, 3])"; + + assert_eq!(expected, format!("{}", plan.display_indent())); + } + + #[test] + fn test_display_indent_schema() { + let plan = display_plan(); + + let expected = "Projection: #id [id:Int32]\ + \n Filter: #state Eq Utf8(\"CO\") [id:Int32, state:Utf8]\ + \n TableScan: employee.csv projection=Some([0, 3]) [id:Int32, state:Utf8]"; + + assert_eq!(expected, format!("{}", plan.display_indent_schema())); + } + + #[test] + fn test_display_graphviz() { + let plan = display_plan(); + + // just test for a few key lines in the output rather than the + // whole thing to make test mainteance easier. + let graphviz = format!("{}", plan.display_graphviz()); + + assert!( + graphviz.contains( + r#"// Begin DataFusion GraphViz Plan (see https://graphviz.org)"# + ), + "\n{}", + plan.display_graphviz() + ); + assert!( + graphviz.contains( + r#"[shape=box label="TableScan: employee.csv projection=Some([0, 3])"]"# + ), + "\n{}", + plan.display_graphviz() + ); + assert!(graphviz.contains(r#"[shape=box label="TableScan: employee.csv projection=Some([0, 3])\nSchema: [id:Int32, state:Utf8]"]"#), + "\n{}", plan.display_graphviz()); + assert!( + graphviz.contains(r#"// End DataFusion GraphViz Plan"#), + "\n{}", + plan.display_graphviz() + ); + } +} + +#[cfg(test)] +/// Tests for the Visitor trait and walking logical plan nodes +mod test_visitor { + use super::*; + + #[derive(Debug, Default)] + struct OkVisitor { + strings: Vec, + } + impl PlanVisitor for OkVisitor { + type Error = String; + + fn pre_visit( + &mut self, + plan: &LogicalPlan, + ) -> std::result::Result { + let s = match plan { + LogicalPlan::Projection { .. } => "pre_visit Projection", + LogicalPlan::Filter { .. } => "pre_visit Filter", + LogicalPlan::TableScan { .. 
} => "pre_visit TableScan", + _ => unimplemented!("unknown plan type"), + }; + + self.strings.push(s.into()); + Ok(true) + } + + fn post_visit( + &mut self, + plan: &LogicalPlan, + ) -> std::result::Result { + let s = match plan { + LogicalPlan::Projection { .. } => "post_visit Projection", + LogicalPlan::Filter { .. } => "post_visit Filter", + LogicalPlan::TableScan { .. } => "post_visit TableScan", + _ => unimplemented!("unknown plan type"), + }; + + self.strings.push(s.into()); + Ok(true) + } + } + + #[test] + fn visit_order() { + let mut visitor = OkVisitor::default(); + let plan = test_plan(); + let res = plan.accept(&mut visitor); + assert!(res.is_ok()); + + assert_eq!( + visitor.strings, + vec![ + "pre_visit Projection", + "pre_visit Filter", + "pre_visit TableScan", + "post_visit TableScan", + "post_visit Filter", + "post_visit Projection" + ] + ); + } + + #[derive(Debug, Default)] + /// Counter than counts to zero and returns true when it gets there + struct OptionalCounter { + val: Option, + } + impl OptionalCounter { + fn new(val: usize) -> Self { + Self { val: Some(val) } + } + // Decrements the counter by 1, if any, returning true if it hits zero + fn dec(&mut self) -> bool { + if Some(0) == self.val { + true + } else { + self.val = self.val.take().map(|i| i - 1); + false + } + } + } + + #[derive(Debug, Default)] + /// Visitor that returns false after some number of visits + struct StoppingVisitor { + inner: OkVisitor, + /// When Some(0) returns false from pre_visit + return_false_from_pre_in: OptionalCounter, + /// When Some(0) returns false from post_visit + return_false_from_post_in: OptionalCounter, + } + + impl PlanVisitor for StoppingVisitor { + type Error = String; + + fn pre_visit( + &mut self, + plan: &LogicalPlan, + ) -> std::result::Result { + if self.return_false_from_pre_in.dec() { + return Ok(false); + } + self.inner.pre_visit(plan) + } + + fn post_visit( + &mut self, + plan: &LogicalPlan, + ) -> std::result::Result { + if 
self.return_false_from_post_in.dec() { + return Ok(false); + } + + self.inner.post_visit(plan) + } + } + + /// test earliy stopping in pre-visit + #[test] + fn early_stoping_pre_visit() { + let mut visitor = StoppingVisitor::default(); + visitor.return_false_from_pre_in = OptionalCounter::new(2); + let plan = test_plan(); + let res = plan.accept(&mut visitor); + assert!(res.is_ok()); + + assert_eq!( + visitor.inner.strings, + vec!["pre_visit Projection", "pre_visit Filter",] + ); + } + + #[test] + fn early_stoping_post_visit() { + let mut visitor = StoppingVisitor::default(); + visitor.return_false_from_post_in = OptionalCounter::new(1); + let plan = test_plan(); + let res = plan.accept(&mut visitor); + assert!(res.is_ok()); + + assert_eq!( + visitor.inner.strings, + vec![ + "pre_visit Projection", + "pre_visit Filter", + "pre_visit TableScan", + "post_visit TableScan", + ] + ); + } + + #[derive(Debug, Default)] + /// Visitor that returns an error after some number of visits + struct ErrorVisitor { + inner: OkVisitor, + /// When Some(0) returns false from pre_visit + return_error_from_pre_in: OptionalCounter, + /// When Some(0) returns false from post_visit + return_error_from_post_in: OptionalCounter, + } + + impl PlanVisitor for ErrorVisitor { + type Error = String; + + fn pre_visit( + &mut self, + plan: &LogicalPlan, + ) -> std::result::Result { + if self.return_error_from_pre_in.dec() { + return Err("Error in pre_visit".into()); + } + + self.inner.pre_visit(plan) + } + + fn post_visit( + &mut self, + plan: &LogicalPlan, + ) -> std::result::Result { + if self.return_error_from_post_in.dec() { + return Err("Error in post_visit".into()); + } + + self.inner.post_visit(plan) + } + } + + #[test] + fn error_pre_visit() { + let mut visitor = ErrorVisitor::default(); + visitor.return_error_from_pre_in = OptionalCounter::new(2); + let plan = test_plan(); + let res = plan.accept(&mut visitor); + + if let Err(e) = res { + assert_eq!("Error in pre_visit", e); + } else { + 
panic!("Expected an error"); + } + + assert_eq!( + visitor.inner.strings, + vec!["pre_visit Projection", "pre_visit Filter",] + ); + } + + #[test] + fn error_post_visit() { + let mut visitor = ErrorVisitor::default(); + visitor.return_error_from_post_in = OptionalCounter::new(1); + let plan = test_plan(); + let res = plan.accept(&mut visitor); + if let Err(e) = res { + assert_eq!("Error in post_visit", e); + } else { + panic!("Expected an error"); + } + + assert_eq!( + visitor.inner.strings, + vec![ + "pre_visit Projection", + "pre_visit Filter", + "pre_visit TableScan", + "post_visit TableScan", + ] + ); + } + + fn test_plan() -> LogicalPlan { + let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + + LogicalPlanBuilder::scan("default", "employee.csv", &schema, Some(vec![0])) + .unwrap() + .filter(col("state").eq(lit("CO"))) + .unwrap() + .project(vec![col("id")]) + .unwrap() + .build() + .unwrap() + } }