Skip to content

Commit

Permalink
Add Expr::column_refs to find column references without copying (#10948)
Browse files Browse the repository at this point in the history
* Add Expr::column_refs to find column references without copying

migrate some uses of to_column

* Simplify condition
  • Loading branch information
alamb authored Jun 22, 2024
1 parent ea46e82 commit 98373ab
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 39 deletions.
48 changes: 45 additions & 3 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ use crate::{
use crate::{window_frame, Volatility};

use arrow::datatypes::{DataType, FieldRef};
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
use datafusion_common::{
internal_err, plan_err, Column, DFSchema, Result, ScalarValue, TableReference,
};
Expand Down Expand Up @@ -1333,6 +1335,46 @@ impl Expr {
Ok(using_columns)
}

/// Return all references to columns in this expression.
///
/// # Example
/// ```
/// # use std::collections::HashSet;
/// # use datafusion_common::Column;
/// # use datafusion_expr::col;
/// // For an expression `a + (b * a)`
/// let expr = col("a") + (col("b") * col("a"));
/// let refs = expr.column_refs();
/// // refs contains "a" and "b"
/// assert_eq!(refs.len(), 2);
/// assert!(refs.contains(&Column::new_unqualified("a")));
/// assert!(refs.contains(&Column::new_unqualified("b")));
/// ```
pub fn column_refs(&self) -> HashSet<&Column> {
let mut using_columns = HashSet::new();
self.add_column_refs(&mut using_columns);
using_columns
}

/// Adds references to all columns in this expression to the set
///
/// See [`Self::column_refs`] for details
pub fn add_column_refs<'a>(&'a self, set: &mut HashSet<&'a Column>) {
self.apply(|expr| {
if let Expr::Column(col) = expr {
set.insert(col);
}
Ok(TreeNodeRecursion::Continue)
})
.expect("traversal is infallable");
}

/// Returns true if there are any column references in this Expr
pub fn any_column_refs(&self) -> bool {
self.exists(|expr| Ok(matches!(expr, Expr::Column(_))))
.unwrap()
}

/// Return true when the expression contains out reference(correlated) expressions.
pub fn contains_outer(&self) -> bool {
self.exists(|expr| Ok(matches!(expr, Expr::OuterReferenceColumn { .. })))
Expand Down Expand Up @@ -2038,15 +2080,15 @@ mod test {
// single column
{
let expr = &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64));
let columns = expr.to_columns()?;
let columns = expr.column_refs();
assert_eq!(1, columns.len());
assert!(columns.contains(&Column::from_name("a")));
}

// multiple columns
{
let expr = col("a") + col("b") + lit(1);
let columns = expr.to_columns()?;
let columns = expr.column_refs();
assert_eq!(2, columns.len());
assert!(columns.contains(&Column::from_name("a")));
assert!(columns.contains(&Column::from_name("b")));
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ pub const COUNT_STAR_EXPANSION: ScalarValue = ScalarValue::Int64(Some(1));

/// Recursively walk a list of expression trees, collecting the unique set of columns
/// referenced in the expression
#[deprecated(since = "40.0.0", note = "Expr::add_column_refs instead")]
pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result<()> {
for e in expr {
expr_to_columns(e, accum)?;
Expand Down
21 changes: 10 additions & 11 deletions datafusion/optimizer/src/analyzer/subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,19 +300,17 @@ fn can_pullup_over_aggregation(expr: &Expr) -> bool {
}) = expr
{
match (left.deref(), right.deref()) {
(Expr::Column(_), right) if right.to_columns().unwrap().is_empty() => true,
(left, Expr::Column(_)) if left.to_columns().unwrap().is_empty() => true,
(Expr::Column(_), right) => !right.any_column_refs(),
(left, Expr::Column(_)) => !left.any_column_refs(),
(Expr::Cast(Cast { expr, .. }), right)
if matches!(expr.deref(), Expr::Column(_))
&& right.to_columns().unwrap().is_empty() =>
if matches!(expr.deref(), Expr::Column(_)) =>
{
true
!right.any_column_refs()
}
(left, Expr::Cast(Cast { expr, .. }))
if matches!(expr.deref(), Expr::Column(_))
&& left.to_columns().unwrap().is_empty() =>
if matches!(expr.deref(), Expr::Column(_)) =>
{
true
!left.any_column_refs()
}
(_, _) => false,
}
Expand All @@ -323,9 +321,10 @@ fn can_pullup_over_aggregation(expr: &Expr) -> bool {

/// Check whether the window expressions contain a mixture of out reference columns and inner columns
fn check_mixed_out_refer_in_window(window: &Window) -> Result<()> {
let mixed = window.window_expr.iter().any(|win_expr| {
win_expr.contains_outer() && !win_expr.to_columns().unwrap().is_empty()
});
let mixed = window
.window_expr
.iter()
.any(|win_expr| win_expr.contains_outer() && win_expr.any_column_refs());
if mixed {
plan_err!(
"Window expressions should not contain a mixed of outer references and inner columns"
Expand Down
11 changes: 7 additions & 4 deletions datafusion/optimizer/src/decorrelate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,11 +370,14 @@ impl PullUpCorrelatedExpr {
}
}
if let Some(pull_up_having) = &self.pull_up_having_expr {
let filter_apply_columns = pull_up_having.to_columns()?;
let filter_apply_columns = pull_up_having.column_refs();
for col in filter_apply_columns {
let col_expr = Expr::Column(col);
if !missing_exprs.contains(&col_expr) {
missing_exprs.push(col_expr)
// add to missing_exprs if not already there
let contains = missing_exprs
.iter()
.any(|expr| matches!(expr, Expr::Column(c) if c == col));
if !contains {
missing_exprs.push(Expr::Column(col.clone()))
}
}
}
Expand Down
16 changes: 8 additions & 8 deletions datafusion/optimizer/src/optimize_projections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -479,10 +479,10 @@ fn merge_consecutive_projections(proj: Projection) -> Result<Transformed<Project
};

// Count usages (referrals) of each projection expression in its input fields:
let mut column_referral_map = HashMap::<Column, usize>::new();
for columns in expr.iter().flat_map(|expr| expr.to_columns()) {
let mut column_referral_map = HashMap::<&Column, usize>::new();
for columns in expr.iter().map(|expr| expr.column_refs()) {
for col in columns.into_iter() {
*column_referral_map.entry(col.clone()).or_default() += 1;
*column_referral_map.entry(col).or_default() += 1;
}
}

Expand All @@ -493,7 +493,7 @@ fn merge_consecutive_projections(proj: Projection) -> Result<Transformed<Project
usage > 1
&& !is_expr_trivial(
&prev_projection.expr
[prev_projection.schema.index_of_column(&col).unwrap()],
[prev_projection.schema.index_of_column(col).unwrap()],
)
}) {
// no change
Expand Down Expand Up @@ -625,12 +625,12 @@ fn rewrite_expr(expr: Expr, input: &Projection) -> Result<Transformed<Expr>> {
/// * `expr` - The expression to analyze for outer-referenced columns.
/// * `columns` - A mutable reference to a `HashSet<Column>` where detected
/// columns are collected.
fn outer_columns(expr: &Expr, columns: &mut HashSet<Column>) {
fn outer_columns<'a>(expr: &'a Expr, columns: &mut HashSet<&'a Column>) {
// inspect_expr_pre doesn't handle subquery references, so find them explicitly
expr.apply(|expr| {
match expr {
Expr::OuterReferenceColumn(_, col) => {
columns.insert(col.clone());
columns.insert(col);
}
Expr::ScalarSubquery(subquery) => {
outer_columns_helper_multi(&subquery.outer_ref_columns, columns);
Expand Down Expand Up @@ -660,9 +660,9 @@ fn outer_columns(expr: &Expr, columns: &mut HashSet<Column>) {
/// * `exprs` - The expressions to analyze for outer-referenced columns.
/// * `columns` - A mutable reference to a `HashSet<Column>` where detected
/// columns are collected.
fn outer_columns_helper_multi<'a>(
fn outer_columns_helper_multi<'a, 'b>(
exprs: impl IntoIterator<Item = &'a Expr>,
columns: &mut HashSet<Column>,
columns: &'b mut HashSet<&'a Column>,
) {
exprs.into_iter().for_each(|e| outer_columns(e, columns));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,12 @@ impl RequiredIndicies {
/// * `expr`: An expression for which we want to find necessary field indices.
fn add_expr(&mut self, input_schema: &DFSchemaRef, expr: &Expr) -> Result<()> {
// TODO could remove these clones (and visit the expression directly)
let mut cols = expr.to_columns()?;
let mut cols = expr.column_refs();
// Get outer-referenced (subquery) columns:
outer_columns(expr, &mut cols);
self.indices.reserve(cols.len());
for col in cols {
if let Some(idx) = input_schema.maybe_index_of_column(&col) {
if let Some(idx) = input_schema.maybe_index_of_column(col) {
self.indices.push(idx);
}
}
Expand Down
11 changes: 4 additions & 7 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -561,12 +561,9 @@ fn infer_join_predicates(
.filter_map(|predicate| {
let mut join_cols_to_replace = HashMap::new();

let columns = match predicate.to_columns() {
Ok(columns) => columns,
Err(e) => return Some(Err(e)),
};
let columns = predicate.column_refs();

for col in columns.iter() {
for &col in columns.iter() {
for (l, r) in join_col_keys.iter() {
if col == *l {
join_cols_to_replace.insert(col, *r);
Expand Down Expand Up @@ -798,7 +795,7 @@ impl OptimizerRule for PushDownFilter {
let mut keep_predicates = vec![];
let mut push_predicates = vec![];
for expr in predicates {
let cols = expr.to_columns()?;
let cols = expr.column_refs();
if cols.iter().all(|c| group_expr_columns.contains(c)) {
push_predicates.push(expr);
} else {
Expand Down Expand Up @@ -899,7 +896,7 @@ impl OptimizerRule for PushDownFilter {
let predicate_push_or_keep = split_conjunction(&filter.predicate)
.iter()
.map(|expr| {
let cols = expr.to_columns()?;
let cols = expr.column_refs();
if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
Ok(false) // No push (keep)
} else {
Expand Down
6 changes: 3 additions & 3 deletions datafusion/optimizer/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ pub(crate) fn collect_subquery_cols(
) -> Result<BTreeSet<Column>> {
exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| {
let mut using_cols: Vec<Column> = vec![];
for col in expr.to_columns()?.into_iter() {
if subquery_schema.has_column(&col) {
using_cols.push(col);
for col in expr.column_refs().into_iter() {
if subquery_schema.has_column(col) {
using_cols.push(col.clone());
}
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
self.order_by_to_sort_expr(&expr, schema, planner_context, true, None)?;
// Verify that columns of all SortExprs exist in the schema:
for expr in expr_vec.iter() {
for column in expr.to_columns()?.iter() {
for column in expr.column_refs().iter() {
if !schema.has_column(column) {
// Return an error if any column is not in the schema:
return plan_err!("Column {column} is not in schema");
Expand Down

0 comments on commit 98373ab

Please sign in to comment.