diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index c7f0b5a4f4881..f97276e3c3761 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -30,7 +30,7 @@ use std::fmt; pub struct Column { /// relation/table reference. pub relation: Option, - /// field/column name. + /// Field/column name. pub name: String, /// Original source code location, if known pub spans: Spans, diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs index 5f1971f649d2c..c8912ab639d8e 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs @@ -63,15 +63,13 @@ use datafusion_physical_plan::{ sorts::sort::SortExec, }; +use super::pushdown_utils::{ + OptimizationTest, TestNode, TestScanBuilder, TestSource, format_plan_for_test, +}; use datafusion_physical_plan::union::UnionExec; use futures::StreamExt; use object_store::{ObjectStore, memory::InMemory}; use regex::Regex; -use util::{OptimizationTest, TestNode, TestScanBuilder, format_plan_for_test}; - -use crate::physical_optimizer::filter_pushdown::util::TestSource; - -mod util; #[test] fn test_pushdown_into_scan() { diff --git a/datafusion/core/tests/physical_optimizer/mod.rs b/datafusion/core/tests/physical_optimizer/mod.rs index d11322cd26be9..cf179cb727cf1 100644 --- a/datafusion/core/tests/physical_optimizer/mod.rs +++ b/datafusion/core/tests/physical_optimizer/mod.rs @@ -24,7 +24,6 @@ mod combine_partial_final_agg; mod enforce_distribution; mod enforce_sorting; mod enforce_sorting_monotonicity; -#[expect(clippy::needless_pass_by_value)] mod filter_pushdown; mod join_selection; #[expect(clippy::needless_pass_by_value)] @@ -38,3 +37,5 @@ mod sanity_checker; #[expect(clippy::needless_pass_by_value)] mod test_utils; mod window_optimize; + +mod pushdown_utils; diff --git a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs index d9b36dc4b87ce..902e0f785199f 100644 --- a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs @@ -18,12 +18,15 @@ use std::any::Any; use std::sync::Arc; +use arrow::array::{Int32Array, RecordBatch, StructArray}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow_schema::Fields; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::physical_plan::CsvSource; use datafusion::datasource::source::DataSourceExec; +use datafusion::prelude::get_field; use datafusion_common::config::{ConfigOptions, CsvOptions}; use datafusion_common::{JoinSide, JoinType, NullEquality, Result, ScalarValue}; use datafusion_datasource::TableSchema; @@ -31,12 +34,13 @@ use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::{ - Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, + Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, lit, }; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_physical_expr::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, binary, cast, col, }; +use 
datafusion_physical_expr::planner::logical2physical; use datafusion_physical_expr::{Distribution, Partitioning, ScalarFunctionExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{ @@ -64,6 +68,8 @@ use datafusion_physical_plan::{ExecutionPlan, displayable}; use insta::assert_snapshot; use itertools::Itertools; +use crate::physical_optimizer::pushdown_utils::TestScanBuilder; + /// Mocked UDF #[derive(Debug, PartialEq, Eq, Hash)] struct DummyUDF { @@ -1723,3 +1729,87 @@ fn test_cooperative_exec_after_projection() -> Result<()> { Ok(()) } + +#[test] +fn test_pushdown_projection_through_repartition_filter() { + let struct_fields = Fields::from(vec![Field::new("a", DataType::Int32, false)]); + let array = StructArray::new( + struct_fields.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], + None, + ); + let batches = vec![ + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "struct", + DataType::Struct(struct_fields.clone()), + true, + )])), + vec![Arc::new(array)], + ) + .unwrap(), + ]; + let build_side_schema = Arc::new(Schema::new(vec![Field::new( + "struct", + DataType::Struct(struct_fields), + true, + )])); + + let scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(batches) + .build(); + let scan_schema = scan.schema(); + let struct_access = get_field(datafusion_expr::col("struct"), "a"); + let filter = struct_access.clone().gt(lit(2)); + let repartition = + RepartitionExec::try_new(scan, Partitioning::RoundRobinBatch(32)).unwrap(); + let filter_exec = FilterExec::try_new( + logical2physical(&filter, &scan_schema), + Arc::new(repartition), + ) + .unwrap(); + let projection: Arc = Arc::new( + ProjectionExec::try_new( + vec![ProjectionExpr::new( + logical2physical(&struct_access, &scan_schema), + "a", + )], + Arc::new(filter_exec), + ) + .unwrap(), + ) as _; + + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[get_field(struct@0, a) as a] + FilterExec: get_field(struct@0, a) > 2 + RepartitionExec: partitioning=RoundRobinBatch(32), input_partitions=1 + DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[struct], file_type=test, pushdown_supported=true + " + ); + + let after_optimize = ProjectionPushdown::new() + .optimize(projection, &ConfigOptions::new()) + .unwrap(); + + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + // Projection should be pushed all the way down to the DataSource, and + // filter predicate should be rewritten to reference projection's output column + assert_snapshot!( + actual, + @r" + FilterExec: a@0 > 2 + RepartitionExec: partitioning=RoundRobinBatch(32), input_partitions=1 + DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[get_field(struct@0, a) as a], file_type=test, pushdown_supported=true + " + ); +} diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs b/datafusion/core/tests/physical_optimizer/pushdown_utils.rs similarity index 92% rename from datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs rename to datafusion/core/tests/physical_optimizer/pushdown_utils.rs index 1afdc4823f0a4..524d33ae6edb6 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs +++ b/datafusion/core/tests/physical_optimizer/pushdown_utils.rs 
@@ -24,6 +24,7 @@ use datafusion_datasource::{ file_scan_config::FileScanConfigBuilder, file_stream::FileOpenFuture, file_stream::FileOpener, source::DataSourceExec, }; +use datafusion_physical_expr::projection::ProjectionExprs; use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::filter::batch_filter; @@ -50,7 +51,7 @@ use std::{ pub struct TestOpener { batches: Vec, batch_size: Option, - projection: Option>, + projection: Option, predicate: Option>, } @@ -60,6 +61,7 @@ impl FileOpener for TestOpener { if self.batches.is_empty() { return Ok((async { Ok(TestStream::new(vec![]).boxed()) }).boxed()); } + let schema = self.batches[0].schema(); if let Some(batch_size) = self.batch_size { let batch = concat_batches(&batches[0].schema(), &batches)?; let mut new_batches = Vec::new(); @@ -83,9 +85,10 @@ impl FileOpener for TestOpener { batches = new_batches; if let Some(projection) = &self.projection { + let projector = projection.make_projector(&schema)?; batches = batches .into_iter() - .map(|batch| batch.project(projection).unwrap()) + .map(|batch| projector.project_batch(&batch).unwrap()) .collect(); } @@ -103,14 +106,13 @@ pub struct TestSource { batch_size: Option, batches: Vec, metrics: ExecutionPlanMetricsSet, - projection: Option>, + projection: Option, table_schema: datafusion_datasource::TableSchema, } impl TestSource { pub fn new(schema: SchemaRef, support: bool, batches: Vec) -> Self { - let table_schema = - datafusion_datasource::TableSchema::new(Arc::clone(&schema), vec![]); + let table_schema = datafusion_datasource::TableSchema::new(schema, vec![]); Self { support, metrics: ExecutionPlanMetricsSet::new(), @@ -210,6 +212,30 @@ impl FileSource for TestSource { } } + fn try_pushdown_projection( + &self, + projection: &ProjectionExprs, + ) -> Result>> { + if let Some(existing_projection) = &self.projection { + // Combine existing projection with new projection + let combined_projection = existing_projection.try_merge(projection)?; + Ok(Some(Arc::new(TestSource { + projection: Some(combined_projection), + table_schema: self.table_schema.clone(), + ..self.clone() + }))) + } else { + Ok(Some(Arc::new(TestSource { + projection: Some(projection.clone()), + ..self.clone() + }))) + } + } + + fn projection(&self) -> Option<&ProjectionExprs> { + self.projection.as_ref() + } + fn table_schema(&self) -> &datafusion_datasource::TableSchema { &self.table_schema } @@ -332,6 +358,7 @@ pub struct OptimizationTest { } impl OptimizationTest { + #[expect(clippy::needless_pass_by_value)] pub fn new( input_plan: Arc, opt: O, diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index fa248c448683b..026b25170ce2e 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -996,10 +996,9 @@ async fn parquet_recursive_projection_pushdown() -> Result<()> { SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] RecursiveQueryExec: name=number_series, is_distinct=false CoalescePartitionsExec - ProjectionExec: expr=[id@0 as id, 1 as level] - FilterExec: id@0 = 1 - RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1 - DataSourceExec: file_groups={1 group: [[TMP_DIR/hierarchy.parquet]]}, projection=[id], file_type=parquet, predicate=id@0 = 1, pruning_predicate=id_null_count@2 != row_count@3 AND id_min@0 <= 1 AND 1 <= id_max@1, required_guarantees=[id in (1)] + FilterExec: id@0 = 
level@1
+        RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1
+        DataSourceExec: file_groups={1 group: [[TMP_DIR/hierarchy.parquet]]}, projection=[id, 1 as level], file_type=parquet, predicate=id@0 = 1, pruning_predicate=id_null_count@2 != row_count@3 AND id_min@0 <= 1 AND 1 <= id_max@1, required_guarantees=[id in (1)]
     CoalescePartitionsExec
       ProjectionExec: expr=[id@0 + 1 as ns.id + Int64(1), level@1 + 1 as ns.level + Int64(1)]
         FilterExec: id@0 < 10
diff --git a/datafusion/expr-common/src/lib.rs b/datafusion/expr-common/src/lib.rs
index 2be066beaad24..cd91f44b5a7e5 100644
--- a/datafusion/expr-common/src/lib.rs
+++ b/datafusion/expr-common/src/lib.rs
@@ -44,4 +44,7 @@ pub mod operator;
 pub mod signature;
 pub mod sort_properties;
 pub mod statistics;
+pub mod triviality;
 pub mod type_coercion;
+
+pub use triviality::ArgTriviality;
diff --git a/datafusion/expr-common/src/triviality.rs b/datafusion/expr-common/src/triviality.rs
new file mode 100644
index 0000000000000..8944f56d0e8ec
--- /dev/null
+++ b/datafusion/expr-common/src/triviality.rs
@@ -0,0 +1,57 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Triviality classification for expressions and function arguments.
+
+/// Classification of argument triviality for scalar functions.
+///
+/// This enum is used by [`ScalarUDFImpl::triviality_with_args`] to allow
+/// functions to make context-dependent decisions about whether they are
+/// trivial based on the nature of their arguments.
+///
+/// For example, `get_field(struct_col, 'field_name')` is trivial (static field
+/// lookup), but `get_field(struct_col, key_col)` is not (dynamic per-row lookup).
+///
+/// [`ScalarUDFImpl::triviality_with_args`]: crate::ScalarUDFImpl::triviality_with_args
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ArgTriviality {
+    /// Argument is a literal constant value or an expression that can be
+    /// evaluated to a constant at planning time.
+    Literal,
+    /// Argument is a simple column reference.
+    Column,
+    /// Argument is a complex expression that declares itself trivial.
+    /// For example, if `get_field(struct_col, 'field_name')` is implemented as a
+    /// trivial expression, then it would return this variant.
+    /// Then `other_trivial_function(get_field(...), 42)` could also be classified as
+    /// a trivial expression using the knowledge that `get_field(...)` is trivial.
+    TrivialExpr,
+    /// Argument is a complex expression that declares itself non-trivial.
+    /// For example, `min(col1 + col2)` is non-trivial because it requires per-row computation.
+    NonTrivial,
+}
+
+impl ArgTriviality {
+    /// Returns true if this triviality classification indicates a trivial
+    /// (cheap to evaluate) expression.
+    ///
+    /// Trivial expressions include literals, column references, and trivial
+    /// composite expressions like nested field accessors.
+    pub fn is_trivial(&self) -> bool {
+        !matches!(self, ArgTriviality::NonTrivial)
+    }
+}
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 454839fdb75ac..ad2b69680542e 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -27,7 +27,7 @@ use std::sync::Arc;
 use crate::expr_fn::binary_expr;
 use crate::function::WindowFunctionSimplification;
 use crate::logical_plan::Subquery;
-use crate::{AggregateUDF, Volatility};
+use crate::{AggregateUDF, ArgTriviality, Volatility};
 use crate::{ExprSchemable, Operator, Signature, WindowFrame, WindowUDF};
 use arrow::datatypes::{DataType, Field, FieldRef};
@@ -1933,6 +1933,32 @@ impl Expr {
         }
     }
 
+    /// Returns the triviality classification of this expression.
+    ///
+    /// Trivial expressions include column references, literals, and nested
+    /// field access via `get_field`.
+    ///
+    /// # Example
+    /// ```
+    /// # use datafusion_expr::{col, ArgTriviality};
+    /// let expr = col("foo");
+    /// assert!(expr.triviality().is_trivial());
+    /// ```
+    pub fn triviality(&self) -> ArgTriviality {
+        match self {
+            Expr::Column(_) => ArgTriviality::Column,
+            Expr::Literal(_, _) => ArgTriviality::Literal,
+            Expr::ScalarFunction(func) => {
+                // Classify each argument's triviality for context-aware decision making
+                let arg_trivialities: Vec<ArgTriviality> =
+                    func.args.iter().map(|arg| arg.triviality()).collect();
+
+                func.func.triviality_with_args(&arg_trivialities)
+            }
+            _ => ArgTriviality::NonTrivial,
+        }
+    }
+
     /// Return all references to columns in this expression.
     ///
     /// # Example
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index 4fb78933d7a5c..725fc07a47ca8 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -91,6 +91,7 @@ pub use datafusion_doc::{
     DocSection, Documentation, DocumentationBuilder, aggregate_doc_sections,
     scalar_doc_sections, window_doc_sections,
 };
+pub use datafusion_expr_common::ArgTriviality;
 pub use datafusion_expr_common::accumulator::Accumulator;
 pub use datafusion_expr_common::columnar_value::ColumnarValue;
 pub use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};
diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index 0654370ac7ebf..25d387783f591 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -30,6 +30,7 @@ use datafusion_common::config::ConfigOptions;
 use datafusion_common::{ExprSchema, Result, ScalarValue, not_impl_err};
 use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
 use datafusion_expr_common::interval_arithmetic::Interval;
+use datafusion_expr_common::triviality::ArgTriviality;
 use std::any::Any;
 use std::cmp::Ordering;
 use std::fmt::Debug;
@@ -122,6 +123,17 @@ impl ScalarUDF {
         Self { inner: fun }
     }
+
+    /// Returns the triviality classification of this function given its arguments' triviality.
+    ///
+    /// This allows functions to make context-dependent decisions about triviality.
+    /// For example, `get_field(struct_col, 'field_name')` is trivial (static field
+    /// lookup), but `get_field(struct_col, key_col)` is not (dynamic per-row lookup).
+    ///
+    /// See [`ScalarUDFImpl::triviality`] for more details.
+ pub fn triviality_with_args(&self, args: &[ArgTriviality]) -> ArgTriviality { + self.inner.triviality(args) + } + /// Return the underlying [`ScalarUDFImpl`] trait object for this function pub fn inner(&self) -> &Arc { &self.inner @@ -846,6 +858,32 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync { fn documentation(&self) -> Option<&Documentation> { None } + + /// Returns the triviality classification of this function given its arguments' triviality. + /// + /// This method allows functions to make context-dependent decisions about + /// whether they are trivial. The default implementation returns `NonTrivial` + /// (conservative default). + /// + /// Trivial functions are lightweight accessor functions like `get_field` + /// (struct field access) that simply access nested data within a column + /// without significant computation. + /// + /// This is used to identify expressions that are cheap to duplicate or + /// don't benefit from caching/partitioning optimizations. + /// + /// # Example + /// + /// `get_field(struct_col, 'field_name')` with a literal key is trivial (static + /// field lookup), but `get_field(struct_col, key_col)` with a column key is + /// not trivial (dynamic per-row lookup). + /// + /// # Arguments + /// + /// * `args` - Classification of each argument's triviality + fn triviality(&self, _args: &[ArgTriviality]) -> ArgTriviality { + ArgTriviality::NonTrivial + } } /// ScalarUDF that adds an alias to the underlying function. It is better to @@ -964,6 +1002,10 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { fn documentation(&self) -> Option<&Documentation> { self.inner.documentation() } + + fn triviality(&self, args: &[ArgTriviality]) -> ArgTriviality { + self.inner.triviality(args) + } } #[cfg(test)] diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 47a903639dde5..30d68fa071b6b 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -33,8 +33,8 @@ use datafusion_common::{ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, Signature, Volatility, + ArgTriviality, ColumnarValue, Documentation, Expr, ReturnFieldArgs, + ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; @@ -499,6 +499,27 @@ impl ScalarUDFImpl for GetFieldFunc { fn documentation(&self) -> Option<&Documentation> { self.doc() } + + fn triviality(&self, args: &[ArgTriviality]) -> ArgTriviality { + // get_field is only trivial if: + // 1. The struct/map argument is trivial (column, literal, or trivial expression) + // 2. 
All key arguments are literals (static field access, not dynamic per-row lookup) + if args.is_empty() { + return ArgTriviality::NonTrivial; + } + + // Check if the base (struct/map) argument is trivial + let base_trivial = args[0].is_trivial(); + + // All key arguments (after the first) must be literals for static field access + let keys_literal = args.iter().skip(1).all(|a| *a == ArgTriviality::Literal); + + if base_trivial && keys_literal { + ArgTriviality::TrivialExpr + } else { + ArgTriviality::NonTrivial + } + } } #[cfg(test)] @@ -542,4 +563,71 @@ mod tests { Ok(()) } + + #[test] + fn test_triviality_with_args_literal_key() { + let func = GetFieldFunc::new(); + + // get_field(col, 'literal') -> trivial (static field access) + let args = vec![ArgTriviality::Column, ArgTriviality::Literal]; + assert_eq!(func.triviality(&args), ArgTriviality::TrivialExpr); + + // get_field(col, 'a', 'b') -> trivial (nested static field access) + let args = vec![ + ArgTriviality::Column, + ArgTriviality::Literal, + ArgTriviality::Literal, + ]; + assert_eq!(func.triviality(&args), ArgTriviality::TrivialExpr); + + // get_field(get_field(col, 'a'), 'b') represented as TrivialExpr for base + let args = vec![ArgTriviality::TrivialExpr, ArgTriviality::Literal]; + assert_eq!(func.triviality(&args), ArgTriviality::TrivialExpr); + } + + #[test] + fn test_triviality_with_args_column_key() { + let func = GetFieldFunc::new(); + + // get_field(col, other_col) -> NOT trivial (dynamic per-row lookup) + let args = vec![ArgTriviality::Column, ArgTriviality::Column]; + assert_eq!(func.triviality(&args), ArgTriviality::NonTrivial); + + // get_field(col, 'a', other_col) -> NOT trivial (dynamic nested lookup) + let args = vec![ + ArgTriviality::Column, + ArgTriviality::Literal, + ArgTriviality::Column, + ]; + assert_eq!(func.triviality(&args), ArgTriviality::NonTrivial); + } + + #[test] + fn test_triviality_with_args_non_trivial() { + let func = GetFieldFunc::new(); + + // get_field(non_trivial_expr, 'literal') -> NOT trivial + let args = vec![ArgTriviality::NonTrivial, ArgTriviality::Literal]; + assert_eq!(func.triviality(&args), ArgTriviality::NonTrivial); + + // get_field(col, non_trivial_expr) -> NOT trivial + let args = vec![ArgTriviality::Column, ArgTriviality::NonTrivial]; + assert_eq!(func.triviality(&args), ArgTriviality::NonTrivial); + } + + #[test] + fn test_triviality_with_args_edge_cases() { + let func = GetFieldFunc::new(); + + // Empty args -> NOT trivial + assert_eq!(func.triviality(&[]), ArgTriviality::NonTrivial); + + // Just base, no key -> TrivialExpr (not a valid call but should handle gracefully) + let args = vec![ArgTriviality::Column]; + assert_eq!(func.triviality(&args), ArgTriviality::TrivialExpr); + + // Literal base with literal key -> trivial + let args = vec![ArgTriviality::Literal, ArgTriviality::Literal]; + assert_eq!(func.triviality(&args), ArgTriviality::TrivialExpr); + } } diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index f97b05ea68fbd..f01ca247dd906 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -530,10 +530,9 @@ fn merge_consecutive_projections(proj: Projection) -> Result 1 - && !is_expr_trivial( - &prev_projection.expr - [prev_projection.schema.index_of_column(col).unwrap()], - ) + && !prev_projection.expr[prev_projection.schema.index_of_column(col).unwrap()] + .triviality() + .is_trivial() }) { // no change return 
Projection::try_new_with_schema(expr, input, schema).map(Transformed::no); @@ -586,11 +585,6 @@ fn merge_consecutive_projections(proj: Projection) -> Result bool { - matches!(expr, Expr::Column(_) | Expr::Literal(_, _)) -} - /// Rewrites a projection expression using the projection before it (i.e. its input) /// This is a subroutine to the `merge_consecutive_projections` function. /// diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 2358a21940912..5a6e8d0c2ac09 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -37,6 +37,7 @@ use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_expr_common::sort_properties::ExprProperties; use datafusion_expr_common::statistics::Distribution; +use datafusion_expr_common::triviality::ArgTriviality; use itertools::izip; @@ -430,6 +431,20 @@ pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash { fn is_volatile_node(&self) -> bool { false } + + /// Returns the triviality classification of this expression. + /// + /// Trivial expressions include: + /// - Column references (`ArgTriviality::Column`) + /// - Literal values (`ArgTriviality::Literal`) + /// - Struct field access via `get_field` (`ArgTriviality::TrivialExpr`) + /// - Nested combinations of field accessors (e.g., `col['a']['b']`) + /// + /// This is used to identify expressions that are cheap to duplicate or + /// don't benefit from caching/partitioning optimizations. + fn triviality(&self) -> ArgTriviality { + ArgTriviality::NonTrivial + } } #[deprecated( diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 8c7e8c319fff4..c34cc31ad3d4e 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -29,7 +29,7 @@ use arrow::{ }; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{Result, internal_err, plan_err}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ArgTriviality, ColumnarValue}; /// Represents the column at a given index in a RecordBatch /// @@ -146,6 +146,10 @@ impl PhysicalExpr for Column { fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.name) } + + fn triviality(&self) -> ArgTriviality { + ArgTriviality::Column + } } impl Column { diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 1f3fefc60b7ad..9cb5acff550aa 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -30,7 +30,7 @@ use arrow::{ }; use datafusion_common::metadata::FieldMetadata; use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::Expr; +use datafusion_expr::{ArgTriviality, Expr}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_expr_common::sort_properties::{ExprProperties, SortProperties}; @@ -134,6 +134,10 @@ impl PhysicalExpr for Literal { fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self, f) } + + fn triviality(&self) -> ArgTriviality { + ArgTriviality::Literal + } } /// Create a literal expression diff --git a/datafusion/physical-expr/src/projection.rs 
b/datafusion/physical-expr/src/projection.rs index 540fd620c92ce..65430ae90791b 100644 --- a/datafusion/physical-expr/src/projection.rs +++ b/datafusion/physical-expr/src/projection.rs @@ -37,6 +37,7 @@ use datafusion_physical_expr_common::metrics::ExpressionEvaluatorMetrics; use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays_with_metrics; +use hashbrown::HashSet; use indexmap::IndexMap; use itertools::Itertools; @@ -128,6 +129,49 @@ pub struct ProjectionExprs { exprs: Vec, } +/// Classification of how beneficial a projection expression is for pushdown. +/// +/// This is used to determine whether an expression should be pushed down +/// below other operators in the query plan. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PushdownBenefit { + /// Field accessors - reduce data size, should be pushed down. + /// Examples: `struct_col['field']`, nested field access. + Beneficial, + /// Column references - neutral, can be pushed if needed. + /// Examples: `col_a`, `col_b@1` + Neutral, + /// Literals and computed expressions - add data, should NOT be pushed down. + /// Examples: `42`, `'hello'`, `a + b` + NonBeneficial, +} + +/// Result of splitting a projection for optimized pushdown. +/// +/// When a projection contains a mix of beneficial and non-beneficial expressions, +/// we can split it into two projections: +/// - `inner`: Pushed down below operators (contains beneficial exprs + needed columns) +/// - `outer`: Stays above operators (references inner outputs + adds non-beneficial exprs) +/// +/// Example: +/// ```text +/// -- Original: +/// Projection: struct['a'] AS f1, 42 AS const, col_b +/// FilterExec: predicate +/// +/// -- After splitting: +/// Projection: f1, 42 AS const, col_b <- outer (keeps literal, refs inner) +/// Projection: struct['a'] AS f1, col_b <- inner (pushed down) +/// FilterExec: predicate +/// ``` +#[derive(Debug, Clone)] +pub struct ProjectionSplit { + /// The inner projection to be pushed down (beneficial exprs + columns needed by outer) + pub inner: ProjectionExprs, + /// The outer projection to keep above (refs to inner outputs + non-beneficial exprs) + pub outer: ProjectionExprs, +} + impl std::fmt::Display for ProjectionExprs { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let exprs: Vec = self.exprs.iter().map(|e| e.to_string()).collect(); @@ -233,6 +277,277 @@ impl ProjectionExprs { self.exprs.iter() } + /// Checks if all of the projection expressions are trivial. + pub fn is_trivial(&self) -> bool { + self.exprs.iter().all(|p| p.expr.triviality().is_trivial()) + } + + /// Classifies a single expression for pushdown benefit. 
+ /// + /// - Literals are `NonBeneficial` (they add data) + /// - Non-trivial expressions are `NonBeneficial` (they add computation) + /// - Column references are `Neutral` (no cost/benefit) + /// - Trivial non-column expressions (e.g., field accessors) are `Beneficial` + fn classify_expr(expr: &Arc) -> PushdownBenefit { + // Literals add data size downstream - don't push + if expr.as_any().is::() { + return PushdownBenefit::NonBeneficial; + } + // Column references are neutral + if expr.as_any().is::() { + return PushdownBenefit::Neutral; + } + + if expr.triviality().is_trivial() { + // Trivial non-column expressions (field accessors) reduce data - push + PushdownBenefit::Beneficial + } else { + // Any other expression is considered non-beneficial + PushdownBenefit::NonBeneficial + } + } + + /// Check if this projection benefits from being pushed down lower in the plan. + /// + /// The goal is to push down projections that reduce the amount of data processed + /// by subsequent operations and are "compute neutral" (i.e., do not add computation overhead). + /// + /// Some "compute neutral" projections include: + /// - Dropping unneeded columns (e.g., `SELECT a, b` from `a, b, c`). Subsequent filters or joins + /// now process fewer columns. + /// - Struct field access (e.g., `SELECT struct_col.field1 AS f1`). This reduces the data size + /// processed downstream and is a cheap reference / metadata only clone of the inner array. + /// + /// Examples of projections that do NOT benefit from pushdown: + /// - Literal projections (e.g., `SELECT 42 AS const_col`). These add constant columns + /// that increase data size downstream. + /// - Computed expressions (e.g., `SELECT a + b AS sum_ab`). These add computation overhead + /// and may increase data size downstream, so they should be applied after filters or joins. + pub fn benefits_from_pushdown(&self) -> bool { + // All expressions must be trivial for pushdown to be beneficial + if !self.is_trivial() { + // Contains computed expressions, function calls, etc. - do not push down + return false; + } + + // Check if all expressions are just columns or literals (no field accessors) + // If so, there's no benefit to pushing down because: + // - Columns are just references (no data reduction) + // - Literals add data (definitely don't want to push down) + let all_columns_or_literals = self + .exprs + .iter() + .all(|p| p.expr.as_any().is::() || p.expr.as_any().is::()); + + // Only benefit from pushdown if we have field accessors (or other beneficial trivial exprs) + !all_columns_or_literals + } + + /// Determines whether this projection should be pushed through an operator. + /// + /// A projection should be pushed through when it is: + /// 1. Trivial (no expensive computations to duplicate) + /// 2. AND provides some benefit: + /// - Either narrows the schema (fewer output columns than input columns) + /// - Or has beneficial expressions like field accessors that reduce data size + /// - Or has literal expressions that can be absorbed by the datasource + /// + /// Column-only projections that just rename without narrowing the schema are NOT + /// pushed through, as they provide no benefit. 
+ /// + /// # Arguments + /// * `input_field_count` - Number of fields in the input schema + pub fn should_push_through_operator(&self, input_field_count: usize) -> bool { + // Must be trivial (no expensive computations) + if !self.is_trivial() { + return false; + } + + // Must provide some benefit: + // - Either narrows schema (fewer output columns than input columns) + // - Or has field accessors that reduce data size + // Note: literals are NOT pushed through because they expand into full arrays + // downstream, increasing data size. However, literals as arguments to trivial + // functions (like get_field keys) are fine since they're just accessors. + let narrows_schema = self.exprs.len() < input_field_count; + let has_beneficial_exprs = self.benefits_from_pushdown(); + + narrows_schema || has_beneficial_exprs + } + + /// Attempts to split this projection into beneficial and non-beneficial parts. + /// + /// When a projection contains both beneficial expressions (field accessors) and + /// non-beneficial expressions (literals), this method splits it so that the + /// beneficial parts can be pushed down while non-beneficial parts stay above. + /// + /// # Returns + /// - `Ok(Some(split))` - The projection was split successfully + /// - `Ok(None)` - No split needed (all expressions are the same category) + /// + /// # Arguments + /// * `input_schema` - The schema of the input to this projection + pub fn split_for_pushdown( + &self, + input_schema: &Schema, + ) -> Result> { + // Classify all expressions + let classifications: Vec<_> = self + .exprs + .iter() + .map(|p| (p, Self::classify_expr(&p.expr))) + .collect(); + + let has_beneficial = classifications + .iter() + .any(|(_, c)| *c == PushdownBenefit::Beneficial); + let has_non_beneficial = classifications + .iter() + .any(|(_, c)| *c == PushdownBenefit::NonBeneficial); + + // If no beneficial expressions, nothing to push down + if !has_beneficial { + return Ok(None); + } + + // If no non-beneficial expressions, push the entire projection (no split needed) + if !has_non_beneficial { + return Ok(None); + } + + // We need to split: beneficial + columns needed by non-beneficial go to inner, + // references to inner + non-beneficial expressions go to outer + + // Collect columns needed by non-beneficial expressions + let mut columns_needed_by_outer: HashSet = HashSet::new(); + for (proj, class) in &classifications { + if *class == PushdownBenefit::NonBeneficial { + for col in collect_columns(&proj.expr) { + columns_needed_by_outer.insert(col.index()); + } + } + } + + // Build inner projection: beneficial exprs + columns needed by outer + let mut inner_exprs: Vec = Vec::new(); + // Track where each original expression ends up in inner projection + // Maps original index -> inner index (if it goes to inner) + let mut original_to_inner: IndexMap = IndexMap::new(); + + // First add beneficial expressions + for (orig_idx, (proj, class)) in classifications.iter().enumerate() { + if *class == PushdownBenefit::Beneficial || *class == PushdownBenefit::Neutral + { + original_to_inner.insert(orig_idx, inner_exprs.len()); + inner_exprs.push((*proj).clone()); + } + } + + // Add columns needed by non-beneficial expressions (if not already present) + // Build mapping from input column index -> inner projection index + let mut col_index_to_inner: IndexMap = IndexMap::new(); + for (proj_idx, (proj, _)) in classifications.iter().enumerate() { + if let Some(col) = proj.expr.as_any().downcast_ref::() + && let Some(&inner_idx) = 
original_to_inner.get(&proj_idx) + { + col_index_to_inner.insert(col.index(), inner_idx); + } + } + + // Track columns we need to add to inner for outer's non-beneficial exprs + for col_idx in &columns_needed_by_outer { + if !col_index_to_inner.contains_key(col_idx) { + // Add this column to inner + let field = input_schema.field(*col_idx); + let col_expr = ProjectionExpr::new( + Arc::new(Column::new(field.name(), *col_idx)), + field.name().clone(), + ); + col_index_to_inner.insert(*col_idx, inner_exprs.len()); + inner_exprs.push(col_expr); + } + } + + // Build inner schema (for rewriting outer expressions) + let inner_schema = self.build_schema_for_exprs(&inner_exprs, input_schema)?; + + // Build outer projection: references to inner outputs + non-beneficial exprs + let mut outer_exprs: Vec = Vec::new(); + + for (orig_idx, (proj, class)) in classifications.iter().enumerate() { + match class { + PushdownBenefit::Beneficial | PushdownBenefit::Neutral => { + // Reference the inner projection output + let inner_idx = original_to_inner[&orig_idx]; + let col_expr = ProjectionExpr::new( + Arc::new(Column::new(&proj.alias, inner_idx)), + proj.alias.clone(), + ); + outer_exprs.push(col_expr); + } + PushdownBenefit::NonBeneficial => { + // Keep the expression but rewrite column references to point to inner + let rewritten_expr = self.rewrite_columns_for_inner( + &proj.expr, + &col_index_to_inner, + &inner_schema, + )?; + outer_exprs + .push(ProjectionExpr::new(rewritten_expr, proj.alias.clone())); + } + } + } + + Ok(Some(ProjectionSplit { + inner: ProjectionExprs::new(inner_exprs), + outer: ProjectionExprs::new(outer_exprs), + })) + } + + /// Helper to build a schema from projection expressions + fn build_schema_for_exprs( + &self, + exprs: &[ProjectionExpr], + input_schema: &Schema, + ) -> Result { + let fields: Result> = exprs + .iter() + .map(|p| { + let field = p.expr.return_field(input_schema)?; + Ok(Field::new( + &p.alias, + field.data_type().clone(), + field.is_nullable(), + )) + }) + .collect(); + Ok(Schema::new(fields?)) + } + + /// Rewrite column references in an expression to point to inner projection outputs + fn rewrite_columns_for_inner( + &self, + expr: &Arc, + col_index_to_inner: &IndexMap, + inner_schema: &Schema, + ) -> Result> { + Arc::clone(expr) + .transform(|e| { + if let Some(col) = e.as_any().downcast_ref::() + && let Some(&inner_idx) = col_index_to_inner.get(&col.index()) + { + let inner_field = inner_schema.field(inner_idx); + return Ok(Transformed::yes(Arc::new(Column::new( + inner_field.name(), + inner_idx, + )) + as Arc)); + } + Ok(Transformed::no(e)) + }) + .data() + } + /// Creates a ProjectionMapping from this projection pub fn projection_mapping( &self, @@ -795,7 +1110,7 @@ pub fn update_expr( projected_exprs: &[ProjectionExpr], sync_with_child: bool, ) -> Result>> { - #[derive(Debug, PartialEq)] + #[derive(PartialEq)] enum RewriteState { /// The expression is unchanged. Unchanged, @@ -806,10 +1121,48 @@ pub fn update_expr( RewrittenInvalid, } + // Track Arc pointers of columns created by pass 1. + // These should not be modified by pass 2. + // We use Arc pointer addresses (not name/index) to distinguish pass-1-created columns + // from original columns that happen to have the same name and index. + let mut pass1_created: HashSet = HashSet::new(); + + // First pass: try to rewrite the expression in terms of the projected expressions. 
+ // For example, if the expression is `a + b > 5` and the projection is `a + b AS sum_ab`, + // we can rewrite the expression to `sum_ab > 5` directly. + // + // This optimization only applies when sync_with_child=false, meaning we want the + // expression to use OUTPUT references (e.g., when pushing projection down and the + // expression will be above the projection). Pass 1 creates OUTPUT column references. + // + // When sync_with_child=true, we want INPUT references (expanding OUTPUT to INPUT), + // so pass 1 doesn't apply. + let new_expr = if !sync_with_child { + Arc::clone(expr) + .transform_down(&mut |expr: Arc| { + // If expr is equal to one of the projected expressions, we can short-circuit the rewrite: + for (idx, projected_expr) in projected_exprs.iter().enumerate() { + if expr.eq(&projected_expr.expr) { + // Create new column and track its Arc pointer so pass 2 doesn't modify it + let new_col = Arc::new(Column::new(&projected_expr.alias, idx)) + as Arc; + // Use data pointer for trait object (ignores vtable) + pass1_created.insert(Arc::as_ptr(&new_col) as *const () as usize); + return Ok(Transformed::yes(new_col)); + } + } + Ok(Transformed::no(expr)) + })? + .data + } else { + Arc::clone(expr) + }; + + // Second pass: rewrite remaining column references based on the projection. + // Skip columns that were introduced by pass 1. let mut state = RewriteState::Unchanged; - - let new_expr = Arc::clone(expr) - .transform_up(|expr| { + let new_expr = new_expr + .transform_up(&mut |expr: Arc| { if state == RewriteState::RewrittenInvalid { return Ok(Transformed::no(expr)); } @@ -817,6 +1170,16 @@ pub fn update_expr( let Some(column) = expr.as_any().downcast_ref::() else { return Ok(Transformed::no(expr)); }; + + // Skip columns introduced by pass 1 - they're already valid OUTPUT references. + // Mark state as valid since pass 1 successfully handled this column. + // We check the Arc pointer address to distinguish pass-1-created columns from + // original columns that might have the same name and index. + if pass1_created.contains(&(Arc::as_ptr(&expr) as *const () as usize)) { + state = RewriteState::RewrittenValid; + return Ok(Transformed::no(expr)); + } + if sync_with_child { state = RewriteState::RewrittenValid; // Update the index of `column`: @@ -2425,6 +2788,291 @@ pub(crate) mod tests { Ok(()) } + #[test] + fn test_update_expr_matches_projected_expr() -> Result<()> { + // Test that when filter expression exactly matches a projected expression, + // update_expr short-circuits and rewrites to use the projected column. 
+ // e.g., projection: a * 2 AS a_times_2, filter: a * 2 > 4 + // should become: a_times_2 > 4 + + // Create the computed expression: a@0 * 2 + let computed_expr: Arc = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Multiply, + Arc::new(Literal::new(ScalarValue::Int32(Some(2)))), + )); + + // Create projection with the computed expression aliased as "a_times_2" + let projection = vec![ProjectionExpr { + expr: Arc::clone(&computed_expr), + alias: "a_times_2".to_string(), + }]; + + // Create filter predicate: a * 2 > 4 (same expression as projection) + let filter_predicate: Arc = Arc::new(BinaryExpr::new( + Arc::clone(&computed_expr), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(4)))), + )); + + // Update the expression - should rewrite a * 2 to a_times_2@0 + // sync_with_child=false because we want OUTPUT references (filter will be above projection) + let result = update_expr(&filter_predicate, &projection, false)?; + assert!(result.is_some(), "Filter predicate should be valid"); + + let result_expr = result.unwrap(); + let binary = result_expr + .as_any() + .downcast_ref::() + .expect("Should be a BinaryExpr"); + // Left side should now be a column reference to a_times_2@0 + let left_col = binary + .left() + .as_any() + .downcast_ref::() + .expect("Left should be rewritten to a Column"); + assert_eq!(left_col.name(), "a_times_2"); + assert_eq!(left_col.index(), 0); + + // Right side should still be the literal 4 + let right_lit = binary + .right() + .as_any() + .downcast_ref::() + .expect("Right should be a Literal"); + assert_eq!(right_lit.value(), &ScalarValue::Int32(Some(4))); + + Ok(()) + } + + #[test] + fn test_update_expr_partial_match() -> Result<()> { + // Test that when only part of an expression matches, we still handle + // the rest correctly. e.g., `a + b > 2 AND c > 3` with projection + // `a + b AS sum_ab, c AS c_out` should become `sum_ab > 2 AND c_out > 3` + + // Create computed expression: a@0 + b@1 + let sum_expr: Arc = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Plus, + Arc::new(Column::new("b", 1)), + )); + + // Projection: [a + b AS sum_ab, c AS c_out] + let projection = vec![ + ProjectionExpr { + expr: Arc::clone(&sum_expr), + alias: "sum_ab".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("c", 2)), + alias: "c_out".to_string(), + }, + ]; + + // Filter: (a + b > 2) AND (c > 3) + let filter_predicate: Arc = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::clone(&sum_expr), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(2)))), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(3)))), + )), + )); + + // With sync_with_child=false: columns reference input schema, need to map to output + let result = update_expr(&filter_predicate, &projection, false)?; + assert!(result.is_some(), "Filter predicate should be valid"); + + let result_expr = result.unwrap(); + // Should be: sum_ab@0 > 2 AND c_out@1 > 3 + assert_eq!(result_expr.to_string(), "sum_ab@0 > 2 AND c_out@1 > 3"); + + Ok(()) + } + + #[test] + fn test_update_expr_partial_match_with_unresolved_column() -> Result<()> { + // Test that when part of an expression matches but other columns can't be + // resolved, we return None. e.g., `a + b > 2 AND c > 3` with projection + // `a + b AS sum_ab` (note: no 'c' column!) should return None. 
+ + // Create computed expression: a@0 + b@1 + let sum_expr: Arc = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Plus, + Arc::new(Column::new("b", 1)), + )); + + // Projection: [a + b AS sum_ab] - note: NO 'c' column! + let projection = vec![ProjectionExpr { + expr: Arc::clone(&sum_expr), + alias: "sum_ab".to_string(), + }]; + + // Filter: (a + b > 2) AND (c > 3) - 'c' is not in projection! + let filter_predicate: Arc = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::clone(&sum_expr), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(2)))), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(3)))), + )), + )); + + // With sync_with_child=false: should return None because 'c' can't be mapped + let result = update_expr(&filter_predicate, &projection, false)?; + assert!( + result.is_none(), + "Should return None when some columns can't be resolved" + ); + + // On the other hand if the projection is `c AS c_out, a + b AS sum_ab` we should succeed + let projection = vec![ + ProjectionExpr { + expr: Arc::new(Column::new("c", 2)), + alias: "c_out".to_string(), + }, + ProjectionExpr { + expr: Arc::clone(&sum_expr), + alias: "sum_ab".to_string(), + }, + ]; + let result = update_expr(&filter_predicate, &projection, false)?; + assert!(result.is_some(), "Filter predicate should be valid now"); + + Ok(()) + } + + #[test] + fn test_update_expr_nested_match() -> Result<()> { + // Test matching a sub-expression within a larger expression. + // e.g., `(a + b) * 2 > 10` with projection `a + b AS sum_ab` + // should become `sum_ab * 2 > 10` + + // Create computed expression: a@0 + b@1 + let sum_expr: Arc = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Plus, + Arc::new(Column::new("b", 1)), + )); + + // Projection: [a + b AS sum_ab] + let projection = vec![ProjectionExpr { + expr: Arc::clone(&sum_expr), + alias: "sum_ab".to_string(), + }]; + + // Filter: (a + b) * 2 > 10 + let filter_predicate: Arc = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::clone(&sum_expr), + Operator::Multiply, + Arc::new(Literal::new(ScalarValue::Int32(Some(2)))), + )), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )); + + // With sync_with_child=false: should rewrite a+b to sum_ab + let result = update_expr(&filter_predicate, &projection, false)?; + assert!(result.is_some(), "Filter predicate should be valid"); + + let result_expr = result.unwrap(); + // Should be: sum_ab@0 * 2 > 10 + assert_eq!(result_expr.to_string(), "sum_ab@0 * 2 > 10"); + + Ok(()) + } + + #[test] + fn test_update_expr_no_match_returns_none() -> Result<()> { + // Test that when columns can't be resolved, we return None (with sync_with_child=false) + + // Projection: [a AS a_out] + let projection = vec![ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "a_out".to_string(), + }]; + + // Filter references column 'd' which is not in projection + let filter_predicate: Arc = Arc::new(BinaryExpr::new( + Arc::new(Column::new("d", 3)), // Not in projection + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )); + + // With sync_with_child=false: should return None because 'd' can't be mapped + let result = update_expr(&filter_predicate, &projection, false)?; + assert!( + result.is_none(), + "Should return None when column can't be resolved" + ); + + Ok(()) + } + + #[test] + fn test_update_expr_column_name_collision() -> 
Result<()> { + // Regression test for a bug where an original column with the same (name, index) + // as a pass-1-rewritten column would incorrectly be considered "already handled". + // + // Example from SQLite tests: + // - Input schema: [col0, col1, col2] + // - Projection: [col2 AS col0] (col2@2 becomes col0@0) + // - Filter: col0 - col2 <= col2 / col2 + // + // The bug: when pass 1 rewrites col2@2 to col0@0, it added ("col0", 0) to + // valid_columns. Then in pass 2, the ORIGINAL col0@0 in the filter would + // match ("col0", 0) and be incorrectly skipped, resulting in: + // col0 - col0 <= col0 / col0 = 0 - 0 <= 0 / 0 = always true (or NaN) + // instead of flagging the expression as invalid. + + // Projection: [col2 AS col0] - note the alias matches another input column's name! + let projection = vec![ProjectionExpr { + expr: Arc::new(Column::new("col2", 2)), + alias: "col0".to_string(), // Alias collides with original col0! + }]; + + // Filter: col0@0 - col2@2 <= col2@2 / col2@2 + // After correct rewrite, col2@2 becomes col0@0 (via pass 1 match) + // But col0@0 (the original) can't be resolved since the projection + // doesn't include it - should return None + let filter_predicate: Arc = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("col0", 0)), // Original col0 - NOT in projection! + Operator::Minus, + Arc::new(Column::new("col2", 2)), // This will be rewritten to col0@0 + )), + Operator::LtEq, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("col2", 2)), + Operator::Divide, + Arc::new(Column::new("col2", 2)), + )), + )); + + // With sync_with_child=false: should return None because original col0@0 + // can't be resolved (only col2 is in projection, aliased as col0) + let result = update_expr(&filter_predicate, &projection, false)?; + assert!( + result.is_none(), + "Should return None when original column collides with rewritten alias but isn't in projection" + ); + + Ok(()) + } + #[test] fn test_project_schema_simple_columns() -> Result<()> { // Input schema: [col0: Int64, col1: Utf8, col2: Float32] diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index aa090743ad441..ee5feae50b53c 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -45,8 +45,8 @@ use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::fields_with_udf; use datafusion_expr::{ - ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, Volatility, - expr_vec_fmt, + ArgTriviality, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, + Volatility, expr_vec_fmt, }; /// Physical expression of a scalar function @@ -362,6 +362,14 @@ impl PhysicalExpr for ScalarFunctionExpr { fn is_volatile_node(&self) -> bool { self.fun.signature().volatility == Volatility::Volatile } + + fn triviality(&self) -> ArgTriviality { + // Classify each argument's triviality for context-aware decision making + let arg_trivialities: Vec = + self.args.iter().map(|arg| arg.triviality()).collect(); + + self.fun.triviality_with_args(&arg_trivialities) + } } #[cfg(test)] diff --git a/datafusion/physical-optimizer/src/output_requirements.rs b/datafusion/physical-optimizer/src/output_requirements.rs index 0dc6a25fbc0b7..f2b4164bc50d9 100644 --- a/datafusion/physical-optimizer/src/output_requirements.rs +++ b/datafusion/physical-optimizer/src/output_requirements.rs @@ -256,9 +256,12 @@ 
impl ExecutionPlan for OutputRequirementExec { &self, projection: &ProjectionExec, ) -> Result>> { - // If the projection does not narrow the schema, we should not try to push it down: - let proj_exprs = projection.expr(); - if proj_exprs.len() >= projection.input().schema().fields().len() { + // Only push down projections that are trivial AND provide benefit (narrow schema or have field accessors) + let input_field_count = projection.input().schema().fields().len(); + if !projection + .projection_expr() + .should_push_through_operator(input_field_count) + { return Ok(None); } @@ -267,7 +270,8 @@ impl ExecutionPlan for OutputRequirementExec { let mut updated_reqs = vec![]; let (lexes, soft) = reqs.into_alternatives(); for lex in lexes.into_iter() { - let Some(updated_lex) = update_ordering_requirement(lex, proj_exprs)? + let Some(updated_lex) = + update_ordering_requirement(lex, projection.expr())? else { return Ok(None); }; diff --git a/datafusion/physical-optimizer/src/projection_pushdown.rs b/datafusion/physical-optimizer/src/projection_pushdown.rs index 281d61aecf538..df7fec5b53d22 100644 --- a/datafusion/physical-optimizer/src/projection_pushdown.rs +++ b/datafusion/physical-optimizer/src/projection_pushdown.rs @@ -76,6 +76,11 @@ impl PhysicalOptimizerRule for ProjectionPushdown { }) .map(|t| t.data)?; + // First, try to split mixed projections (beneficial + non-beneficial expressions) + // This allows the beneficial parts to be pushed down while keeping non-beneficial parts above. + let plan = plan.transform_down(try_split_projection).map(|t| t.data)?; + + // Then apply the normal projection pushdown logic plan.transform_down(remove_unnecessary_projections).data() } @@ -88,6 +93,44 @@ impl PhysicalOptimizerRule for ProjectionPushdown { } } +/// Tries to split a projection that contains a mix of beneficial and non-beneficial expressions. +/// +/// Beneficial expressions (like field accessors) should be pushed down, while non-beneficial +/// expressions (like literals) should stay above. This function splits the projection into +/// two parts: +/// - Inner projection: Contains beneficial expressions + columns needed by outer +/// - Outer projection: Contains references to inner + non-beneficial expressions +/// +/// This enables the inner projection to be pushed down while the outer stays in place. +fn try_split_projection( + plan: Arc, +) -> Result>> { + let Some(projection) = plan.as_any().downcast_ref::() else { + return Ok(Transformed::no(plan)); + }; + + let input_schema = projection.input().schema(); + let split = projection + .projection_expr() + .split_for_pushdown(input_schema.as_ref())?; + + let Some(split) = split else { + // No split needed - either all beneficial (push whole thing) or all non-beneficial (don't push) + return Ok(Transformed::no(plan)); + }; + + // Create the inner projection (to be pushed down) + let inner = ProjectionExec::try_new( + split.inner.as_ref().to_vec(), + Arc::clone(projection.input()), + )?; + + // Create the outer projection (stays above) + let outer = ProjectionExec::try_new(split.outer.as_ref().to_vec(), Arc::new(inner))?; + + Ok(Transformed::yes(Arc::new(outer))) +} + /// Tries to push down parts of the filter. /// /// See [JoinFilterRewriter] for details. 
diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs index 22dcc85d6ea3a..f6840d16002b6 100644 --- a/datafusion/physical-plan/src/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -249,8 +249,12 @@ impl ExecutionPlan for CoalescePartitionsExec { &self, projection: &ProjectionExec, ) -> Result>> { - // If the projection does not narrow the schema, we should not try to push it down: - if projection.expr().len() >= projection.input().schema().fields().len() { + // Only push down projections that are trivial AND provide benefit (narrow schema or have field accessors) + let input_field_count = projection.input().schema().fields().len(); + if !projection + .projection_expr() + .should_push_through_operator(input_field_count) + { return Ok(None); } // CoalescePartitionsExec always has a single child, so zero indexing is safe. diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 1edf96fe0c794..9333f235fb40e 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -562,8 +562,12 @@ impl ExecutionPlan for FilterExec { &self, projection: &ProjectionExec, ) -> Result>> { - // If the projection does not narrow the schema, we should not try to push it down: - if projection.expr().len() < projection.input().schema().fields().len() { + // Only push down projections that are trivial AND provide benefit (narrow schema or have field accessors) + let input_field_count = projection.input().schema().fields().len(); + if projection + .projection_expr() + .should_push_through_operator(input_field_count) + { // Each column in the predicate expression must exist after the projection. if let Some(new_predicate) = update_expr(self.predicate(), projection.expr(), false)? diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index 8f2f2219f4338..53d18de825f70 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -20,7 +20,7 @@ //! of a projection on table `t1` where the expressions `a`, `b`, and `a+b` are the //! projection expressions. `SELECT` without `FROM` will only evaluate expressions. -use super::expressions::{Column, Literal}; +use super::expressions::Column; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{ DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream, @@ -255,18 +255,10 @@ impl ExecutionPlan for ProjectionExec { } fn benefits_from_input_partitioning(&self) -> Vec { - let all_simple_exprs = - self.projector - .projection() - .as_ref() - .iter() - .all(|proj_expr| { - proj_expr.expr.as_any().is::() - || proj_expr.expr.as_any().is::() - }); - // If expressions are all either column_expr or Literal, then all computations in this projection are reorder or rename, - // and projection would not benefit from the repartition, benefits_from_input_partitioning will return false. - vec![!all_simple_exprs] + // If expressions are all trivial (columns, literals, or field accessors), + // then all computations in this projection are reorder or rename, + // and projection would not benefit from the repartition. + vec![!self.projection_expr().is_trivial()] } fn children(&self) -> Vec<&Arc> { @@ -713,13 +705,6 @@ pub fn make_with_child( .map(|e| Arc::new(e) as _) } -/// Returns `true` if all the expressions in the argument are `Column`s. 
-pub fn all_columns(exprs: &[ProjectionExpr]) -> bool {
-    exprs
-        .iter()
-        .all(|proj_expr| proj_expr.expr.as_any().is::<Column>())
-}
-
 /// Updates the given lexicographic ordering according to given projected
 /// expressions using the [`update_expr`] function.
 pub fn update_ordering(
@@ -962,7 +947,7 @@ fn try_unifying_projections(
     // beneficial as caching mechanism for non-trivial computations.
     // See discussion in: https://github.com/apache/datafusion/issues/8296
     if column_ref_map.iter().any(|(column, count)| {
-        *count > 1 && !is_expr_trivial(&Arc::clone(&child.expr()[column.index()].expr))
+        *count > 1 && !child.expr()[column.index()].expr.triviality().is_trivial()
     }) {
         return Ok(None);
     }
@@ -1072,13 +1057,6 @@ fn new_columns_for_join_on(
     (new_columns.len() == hash_join_on.len()).then_some(new_columns)
 }
 
-/// Checks if the given expression is trivial.
-/// An expression is considered trivial if it is either a `Column` or a `Literal`.
-fn is_expr_trivial(expr: &Arc<dyn PhysicalExpr>) -> bool {
-    expr.as_any().downcast_ref::<Column>().is_some()
-        || expr.as_any().downcast_ref::<Literal>().is_some()
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs
index 612c7bb27ddf4..547ceabddf6d2 100644
--- a/datafusion/physical-plan/src/repartition/mod.rs
+++ b/datafusion/physical-plan/src/repartition/mod.rs
@@ -34,7 +34,7 @@ use crate::coalesce::LimitedBatchCoalescer;
 use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
 use crate::hash_utils::create_hashes;
 use crate::metrics::{BaselineMetrics, SpillMetrics};
-use crate::projection::{ProjectionExec, all_columns, make_with_child, update_expr};
+use crate::projection::{ProjectionExec, make_with_child, update_expr};
 use crate::sorts::streaming_merge::StreamingMergeBuilder;
 use crate::spill::spill_manager::SpillManager;
 use crate::spill::spill_pool::{self, SpillPoolWriter};
@@ -1123,14 +1123,12 @@ impl ExecutionPlan for RepartitionExec {
         &self,
         projection: &ProjectionExec,
     ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
-        // If the projection does not narrow the schema, we should not try to push it down.
-        if projection.expr().len() >= projection.input().schema().fields().len() {
-            return Ok(None);
-        }
-        // If pushdown is not beneficial or applicable, break it.
+        let input_field_count = projection.input().schema().fields().len();
 
         if projection.benefits_from_input_partitioning()[0]
-            || !all_columns(projection.expr())
+            || !projection
+                .projection_expr()
+                .should_push_through_operator(input_field_count)
         {
             return Ok(None);
         }
diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs
index 3e8fdf1f3ed7e..acab85f606992 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -1391,8 +1391,12 @@ impl ExecutionPlan for SortExec {
         &self,
         projection: &ProjectionExec,
     ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
-        // If the projection does not narrow the schema, we should not try to push it down.
-        if projection.expr().len() >= projection.input().schema().fields().len() {
+        // Only push down projections that are trivial AND provide benefit (narrow schema or have field accessors)
+        let input_field_count = projection.input().schema().fields().len();
+        if !projection
+            .projection_expr()
+            .should_push_through_operator(input_field_count)
+        {
             return Ok(None);
         }
diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
index 68c457a0d8a3c..d63d351ad53ee 100644
--- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
+++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
@@ -391,8 +391,12 @@ impl ExecutionPlan for SortPreservingMergeExec {
         &self,
         projection: &ProjectionExec,
     ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
-        // If the projection does not narrow the schema, we should not try to push it down.
-        if projection.expr().len() >= projection.input().schema().fields().len() {
+        // Only push down projections that are trivial AND provide benefit (narrow schema or have field accessors)
+        let input_field_count = projection.input().schema().fields().len();
+        if !projection
+            .projection_expr()
+            .should_push_through_operator(input_field_count)
+        {
             return Ok(None);
         }
diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt
index f939cd0154a82..1a6b82020c667 100644
--- a/datafusion/sqllogictest/test_files/unnest.slt
+++ b/datafusion/sqllogictest/test_files/unnest.slt
@@ -673,8 +673,8 @@ logical_plan
 physical_plan
 01)ProjectionExec: expr=[__unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1],depth=2)@0 as UNNEST(UNNEST(UNNEST(recursive_unnest_table.column3)[c1])), column3@1 as column3]
 02)--UnnestExec
-03)----ProjectionExec: expr=[get_field(__unnest_placeholder(recursive_unnest_table.column3,depth=1)@0, c1) as __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), column3@1 as column3]
-04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+04)------ProjectionExec: expr=[get_field(__unnest_placeholder(recursive_unnest_table.column3,depth=1)@0, c1) as __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), column3@1 as column3]
 05)--------UnnestExec
 06)----------ProjectionExec: expr=[column3@0 as __unnest_placeholder(recursive_unnest_table.column3), column3@0 as column3]
 07)------------DataSourceExec: partitions=1, partition_sizes=[1]
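All of the operator-level changes above funnel into the same predicate, `should_push_through_operator`. Its definition is not part of these hunks, but the repeated comment ("trivial AND provide benefit: narrow schema or have field accessors") suggests roughly the following shape. The sketch below is a hedged, standalone model built only from those comments; the names `ExprKind`, `is_trivial`, and the free-function signature are illustrative assumptions, not the actual DataFusion implementation.

```rust
// Hypothetical standalone model of the pushdown test hinted at by the comments above;
// the real predicate lives on DataFusion's ProjectionExpr and may differ in detail.

/// Toy classification of a projection expression.
#[derive(Clone, Copy, Debug)]
enum ExprKind {
    Column,        // bare column reference
    FieldAccessor, // e.g. get_field(struct_col, "a")
    Other,         // literals, arithmetic, UDFs, ...
}

/// "Trivial": evaluating the projection is essentially a reorder/rename, so it is
/// safe to re-evaluate it below another operator without duplicating real work.
fn is_trivial(exprs: &[ExprKind]) -> bool {
    exprs
        .iter()
        .all(|k| matches!(k, ExprKind::Column | ExprKind::FieldAccessor))
}

/// Pushing through an operator should be both safe (trivial) and useful: either the
/// projection narrows the schema, or it contains field accessors that let the scan
/// prune nested data.
fn should_push_through_operator(exprs: &[ExprKind], input_field_count: usize) -> bool {
    let narrows_schema = exprs.len() < input_field_count;
    let has_field_accessor = exprs.iter().any(|k| matches!(k, ExprKind::FieldAccessor));
    is_trivial(exprs) && (narrows_schema || has_field_accessor)
}

fn main() {
    // `SELECT struct_col['a'] FROM t` over a 1-column input: does not narrow the
    // schema, but the field accessor still makes pushdown worthwhile.
    assert!(should_push_through_operator(&[ExprKind::FieldAccessor], 1));
    // A pure column reorder over the same width brings no benefit.
    assert!(!should_push_through_operator(&[ExprKind::Column], 1));
    // A computed expression is not trivial, so it is never pushed through.
    assert!(!should_push_through_operator(&[ExprKind::Other], 2));
    println!("pushdown predicate sketch behaves as expected");
}
```

Under that reading, the unnest.slt plan change at the end is the expected outcome: the `get_field` projection is trivial and contains a field accessor, so it is allowed to sink below the `RepartitionExec` even though it does not narrow the schema.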