Skip to content

Commit

Permalink
Minor: Use ExprVisitor to find columns referenced by expr (#2471)
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb authored May 6, 2022
1 parent 9c1b462 commit de7f15b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 72 deletions.
25 changes: 24 additions & 1 deletion datafusion/core/src/sql/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ use crate::{
error::{DataFusionError, Result},
logical_plan::{Column, ExpressionVisitor, Recursion},
};
use datafusion_expr::expr::find_columns_referenced_by_expr;
use std::collections::HashMap;

/// Collect all deeply nested `Expr::AggregateFunction` and
Expand Down Expand Up @@ -86,6 +85,30 @@ where
})
}

/// Recursively find all columns referenced by an expression
#[derive(Debug, Default)]
struct ColumnCollector {
exprs: Vec<Column>,
}

impl ExpressionVisitor for ColumnCollector {
fn pre_visit(mut self, expr: &Expr) -> Result<Recursion<Self>> {
if let Expr::Column(c) = expr {
self.exprs.push(c.clone())
}
Ok(Recursion::Continue(self))
}
}

fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
// As the `ExpressionVisitor` impl above always returns Ok, this
// "can't" error
let ColumnCollector { exprs } = e
.accept(ColumnCollector::default())
.expect("Unexpected error");
exprs
}

// Visitor that find expressions that match a particular predicate
struct Finder<'a, F>
where
Expand Down
71 changes: 0 additions & 71 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,77 +251,6 @@ pub enum Expr {
QualifiedWildcard { qualifier: String },
}

/// Recursively find all columns referenced by an expression
pub fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
match e {
Expr::Alias(expr, _)
| Expr::Negative(expr)
| Expr::Cast { expr, .. }
| Expr::TryCast { expr, .. }
| Expr::Sort { expr, .. }
| Expr::InList { expr, .. }
| Expr::InSubquery { expr, .. }
| Expr::GetIndexedField { expr, .. }
| Expr::Not(expr)
| Expr::IsNotNull(expr)
| Expr::IsNull(expr) => find_columns_referenced_by_expr(expr),
Expr::Column(c) => vec![c.clone()],
Expr::BinaryExpr { left, right, .. } => {
let mut cols = vec![];
cols.extend(find_columns_referenced_by_expr(left.as_ref()));
cols.extend(find_columns_referenced_by_expr(right.as_ref()));
cols
}
Expr::Case {
expr,
when_then_expr,
else_expr,
} => {
let mut cols = vec![];
if let Some(expr) = expr {
cols.extend(find_columns_referenced_by_expr(expr.as_ref()));
}
for (w, t) in when_then_expr {
cols.extend(find_columns_referenced_by_expr(w.as_ref()));
cols.extend(find_columns_referenced_by_expr(t.as_ref()));
}
if let Some(else_expr) = else_expr {
cols.extend(find_columns_referenced_by_expr(else_expr.as_ref()));
}
cols
}
Expr::ScalarFunction { args, .. } => args
.iter()
.flat_map(find_columns_referenced_by_expr)
.collect(),
Expr::AggregateFunction { args, .. } => args
.iter()
.flat_map(find_columns_referenced_by_expr)
.collect(),
Expr::ScalarVariable(_, _)
| Expr::Exists { .. }
| Expr::Wildcard
| Expr::QualifiedWildcard { .. }
| Expr::ScalarSubquery(_)
| Expr::Literal(_) => vec![],
Expr::Between {
expr, low, high, ..
} => {
let mut cols = vec![];
cols.extend(find_columns_referenced_by_expr(expr.as_ref()));
cols.extend(find_columns_referenced_by_expr(low.as_ref()));
cols.extend(find_columns_referenced_by_expr(high.as_ref()));
cols
}
Expr::ScalarUDF { args, .. }
| Expr::WindowFunction { args, .. }
| Expr::AggregateUDF { args, .. } => args
.iter()
.flat_map(find_columns_referenced_by_expr)
.collect(),
}
}

/// Fixed seed for the hashing so that Ords are consistent across runs
const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0);

Expand Down

0 comments on commit de7f15b

Please sign in to comment.