2 changes: 1 addition & 1 deletion datafusion/common/src/column.rs
@@ -30,7 +30,7 @@ use std::fmt;
pub struct Column {
/// relation/table reference.
pub relation: Option<TableReference>,
-/// field/column name.
+/// Field/column name.
pub name: String,
/// Original source code location, if known
pub spans: Spans,
@@ -63,15 +63,13 @@ use datafusion_physical_plan::{
sorts::sort::SortExec,
};

+use super::pushdown_utils::{
+OptimizationTest, TestNode, TestScanBuilder, TestSource, format_plan_for_test,
+};
use datafusion_physical_plan::union::UnionExec;
use futures::StreamExt;
use object_store::{ObjectStore, memory::InMemory};
use regex::Regex;
-use util::{OptimizationTest, TestNode, TestScanBuilder, format_plan_for_test};
-
-use crate::physical_optimizer::filter_pushdown::util::TestSource;
-
-mod util;

#[test]
fn test_pushdown_into_scan() {
3 changes: 2 additions & 1 deletion datafusion/core/tests/physical_optimizer/mod.rs
@@ -24,7 +24,6 @@ mod combine_partial_final_agg;
mod enforce_distribution;
mod enforce_sorting;
mod enforce_sorting_monotonicity;
-#[expect(clippy::needless_pass_by_value)]
mod filter_pushdown;
mod join_selection;
#[expect(clippy::needless_pass_by_value)]
@@ -38,3 +37,5 @@ mod sanity_checker;
#[expect(clippy::needless_pass_by_value)]
mod test_utils;
mod window_optimize;

mod pushdown_utils;
92 changes: 91 additions & 1 deletion datafusion/core/tests/physical_optimizer/projection_pushdown.rs
@@ -18,25 +18,29 @@
use std::any::Any;
use std::sync::Arc;

use arrow::array::{Int32Array, RecordBatch, StructArray};
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow_schema::Fields;
use datafusion::datasource::listing::PartitionedFile;
use datafusion::datasource::memory::MemorySourceConfig;
use datafusion::datasource::physical_plan::CsvSource;
use datafusion::datasource::source::DataSourceExec;
use datafusion::prelude::get_field;
use datafusion_common::config::{ConfigOptions, CsvOptions};
use datafusion_common::{JoinSide, JoinType, NullEquality, Result, ScalarValue};
use datafusion_datasource::TableSchema;
use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
use datafusion_execution::object_store::ObjectStoreUrl;
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::{
-Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
+Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, lit,
};
use datafusion_expr_common::columnar_value::ColumnarValue;
use datafusion_physical_expr::expressions::{
BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, binary, cast, col,
};
use datafusion_physical_expr::planner::logical2physical;
use datafusion_physical_expr::{Distribution, Partitioning, ScalarFunctionExpr};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::{
@@ -64,6 +68,8 @@ use datafusion_physical_plan::{ExecutionPlan, displayable};
use insta::assert_snapshot;
use itertools::Itertools;

use crate::physical_optimizer::pushdown_utils::TestScanBuilder;

/// Mocked UDF
#[derive(Debug, PartialEq, Eq, Hash)]
struct DummyUDF {
@@ -1723,3 +1729,87 @@ fn test_cooperative_exec_after_projection() -> Result<()> {

Ok(())
}

#[test]
fn test_pushdown_projection_through_repartition_filter() {
let struct_fields = Fields::from(vec![Field::new("a", DataType::Int32, false)]);
let array = StructArray::new(
struct_fields.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))],
None,
);
let batches = vec![
RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new(
"struct",
DataType::Struct(struct_fields.clone()),
true,
)])),
vec![Arc::new(array)],
)
.unwrap(),
];
let build_side_schema = Arc::new(Schema::new(vec![Field::new(
"struct",
DataType::Struct(struct_fields),
true,
)]));

let scan = TestScanBuilder::new(Arc::clone(&build_side_schema))
.with_support(true)
.with_batches(batches)
.build();
let scan_schema = scan.schema();
let struct_access = get_field(datafusion_expr::col("struct"), "a");
let filter = struct_access.clone().gt(lit(2));
let repartition =
RepartitionExec::try_new(scan, Partitioning::RoundRobinBatch(32)).unwrap();
let filter_exec = FilterExec::try_new(
logical2physical(&filter, &scan_schema),
Arc::new(repartition),
)
.unwrap();
let projection: Arc<dyn ExecutionPlan> = Arc::new(
ProjectionExec::try_new(
vec![ProjectionExpr::new(
logical2physical(&struct_access, &scan_schema),
"a",
)],
Arc::new(filter_exec),
)
.unwrap(),
) as _;

let initial = displayable(projection.as_ref()).indent(true).to_string();
let actual = initial.trim();

assert_snapshot!(
actual,
@r"
ProjectionExec: expr=[get_field(struct@0, a) as a]
FilterExec: get_field(struct@0, a) > 2
RepartitionExec: partitioning=RoundRobinBatch(32), input_partitions=1
DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[struct], file_type=test, pushdown_supported=true
"
);

let after_optimize = ProjectionPushdown::new()
.optimize(projection, &ConfigOptions::new())
.unwrap();

let after_optimize_string = displayable(after_optimize.as_ref())
.indent(true)
.to_string();
let actual = after_optimize_string.trim();

// Projection should be pushed all the way down to the DataSource, and
// filter predicate should be rewritten to reference projection's output column
assert_snapshot!(
actual,
@r"
FilterExec: a@0 > 2
RepartitionExec: partitioning=RoundRobinBatch(32), input_partitions=1
DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[get_field(struct@0, a) as a], file_type=test, pushdown_supported=true
"
);
}
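To spell out what the second snapshot shows: once `get_field(struct@0, a) as a` is pushed into the scan, the filter that referenced the same expression is rewritten against the scan's new output schema and becomes `a@0 > 2`. Below is a toy sketch of that substitution step — illustration only, not the optimizer's actual rewrite machinery:

// Minimal expression tree, for illustration only.
#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Column(usize),
    GetField(usize, &'static str),
    Gt(Box<Expr>, i64),
}

/// Replace occurrences of `target` with a reference to projected column `idx`.
fn substitute(expr: Expr, target: &Expr, idx: usize) -> Expr {
    if &expr == target {
        return Expr::Column(idx);
    }
    match expr {
        Expr::Gt(lhs, n) => Expr::Gt(Box::new(substitute(*lhs, target, idx)), n),
        other => other,
    }
}

fn main() {
    let access = Expr::GetField(0, "a"); // get_field(struct@0, a)
    let filter = Expr::Gt(Box::new(access.clone()), 2); // get_field(struct@0, a) > 2
    // The scan now emits the extracted field as output column 0 ("a"),
    // so the filter collapses to a plain column comparison: a@0 > 2.
    assert_eq!(
        substitute(filter, &access, 0),
        Expr::Gt(Box::new(Expr::Column(0)), 2)
    );
}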
@@ -24,6 +24,7 @@ use datafusion_datasource::{
file_scan_config::FileScanConfigBuilder, file_stream::FileOpenFuture,
file_stream::FileOpener, source::DataSourceExec,
};
use datafusion_physical_expr::projection::ProjectionExprs;
use datafusion_physical_expr_common::physical_expr::fmt_sql;
use datafusion_physical_optimizer::PhysicalOptimizerRule;
use datafusion_physical_plan::filter::batch_filter;
@@ -50,7 +51,7 @@ use std::{
pub struct TestOpener {
batches: Vec<RecordBatch>,
batch_size: Option<usize>,
-projection: Option<Vec<usize>>,
+projection: Option<ProjectionExprs>,
predicate: Option<Arc<dyn PhysicalExpr>>,
}

@@ -60,6 +61,7 @@ impl FileOpener for TestOpener {
if self.batches.is_empty() {
return Ok((async { Ok(TestStream::new(vec![]).boxed()) }).boxed());
}
let schema = self.batches[0].schema();
if let Some(batch_size) = self.batch_size {
let batch = concat_batches(&batches[0].schema(), &batches)?;
let mut new_batches = Vec::new();
@@ -83,9 +85,10 @@
batches = new_batches;

if let Some(projection) = &self.projection {
+let projector = projection.make_projector(&schema)?;
batches = batches
.into_iter()
-.map(|batch| batch.project(projection).unwrap())
+.map(|batch| projector.project_batch(&batch).unwrap())
.collect();
}

@@ -103,14 +106,13 @@ pub struct TestSource {
batch_size: Option<usize>,
batches: Vec<RecordBatch>,
metrics: ExecutionPlanMetricsSet,
-projection: Option<Vec<usize>>,
+projection: Option<ProjectionExprs>,
table_schema: datafusion_datasource::TableSchema,
}

impl TestSource {
pub fn new(schema: SchemaRef, support: bool, batches: Vec<RecordBatch>) -> Self {
-let table_schema =
-datafusion_datasource::TableSchema::new(Arc::clone(&schema), vec![]);
+let table_schema = datafusion_datasource::TableSchema::new(schema, vec![]);
Self {
support,
metrics: ExecutionPlanMetricsSet::new(),
@@ -210,6 +212,30 @@ impl FileSource for TestSource {
}
}

fn try_pushdown_projection(
&self,
projection: &ProjectionExprs,
) -> Result<Option<Arc<dyn FileSource>>> {
if let Some(existing_projection) = &self.projection {
// Combine existing projection with new projection
let combined_projection = existing_projection.try_merge(projection)?;
Ok(Some(Arc::new(TestSource {
projection: Some(combined_projection),
table_schema: self.table_schema.clone(),
..self.clone()
})))
} else {
Ok(Some(Arc::new(TestSource {
projection: Some(projection.clone()),
..self.clone()
})))
}
}

fn projection(&self) -> Option<&ProjectionExprs> {
self.projection.as_ref()
}

fn table_schema(&self) -> &datafusion_datasource::TableSchema {
&self.table_schema
}
@@ -332,6 +358,7 @@ pub struct OptimizationTest {
}

impl OptimizationTest {
#[expect(clippy::needless_pass_by_value)]
pub fn new<O>(
input_plan: Arc<dyn ExecutionPlan>,
opt: O,
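One note on `try_pushdown_projection` above: when the scan already carries a projection, the new one is composed with it via `ProjectionExprs::try_merge`, i.e. the outer projection is evaluated over the inner projection's output. A rough index-only analogue of that composition (hypothetical helper, not the DataFusion API):

/// Toy model: a projection as a list of input column indices.
/// Composing `outer` over `inner` picks `inner[i]` for each outer index.
fn merge(inner: &[usize], outer: &[usize]) -> Vec<usize> {
    outer.iter().map(|&i| inner[i]).collect()
}

fn main() {
    // File columns are [a, b, c] (indices 0, 1, 2).
    let inner = vec![2, 0, 1]; // scan already emits [c, a, b]
    let outer = vec![1, 0];    // downstream now asks for [a, c]
    // The merged projection reads [a, c] straight from the file.
    assert_eq!(merge(&inner, &outer), vec![0, 2]);
}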
7 changes: 3 additions & 4 deletions datafusion/core/tests/sql/explain_analyze.rs
@@ -996,10 +996,9 @@ async fn parquet_recursive_projection_pushdown() -> Result<()> {
SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false]
RecursiveQueryExec: name=number_series, is_distinct=false
CoalescePartitionsExec
-ProjectionExec: expr=[id@0 as id, 1 as level]
-FilterExec: id@0 = 1
-RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1
-DataSourceExec: file_groups={1 group: [[TMP_DIR/hierarchy.parquet]]}, projection=[id], file_type=parquet, predicate=id@0 = 1, pruning_predicate=id_null_count@2 != row_count@3 AND id_min@0 <= 1 AND 1 <= id_max@1, required_guarantees=[id in (1)]
+FilterExec: id@0 = level@1
+RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1
+DataSourceExec: file_groups={1 group: [[TMP_DIR/hierarchy.parquet]]}, projection=[id, 1 as level], file_type=parquet, predicate=id@0 = 1, pruning_predicate=id_null_count@2 != row_count@3 AND id_min@0 <= 1 AND 1 <= id_max@1, required_guarantees=[id in (1)]
Contributor comment:
This is somewhat interesting in that it materializes the constant in the scan. This is probably ok, but it does mean that constant may now get carried as a constant record batch up through the plan many times 🤔

CoalescePartitionsExec
ProjectionExec: expr=[id@0 + 1 as ns.id + Int64(1), level@1 + 1 as ns.level + Int64(1)]
FilterExec: id@0 < 10
3 changes: 3 additions & 0 deletions datafusion/expr-common/src/lib.rs
@@ -44,4 +44,7 @@ pub mod operator;
pub mod signature;
pub mod sort_properties;
pub mod statistics;
pub mod triviality;
pub mod type_coercion;

pub use triviality::ArgTriviality;
57 changes: 57 additions & 0 deletions datafusion/expr-common/src/triviality.rs
@@ -0,0 +1,57 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Triviality classification for expressions and function arguments.

/// Classification of argument triviality for scalar functions.
///
/// This enum is used by [`ScalarUDFImpl::triviality_with_args`] to allow
/// functions to make context-dependent decisions about whether they are
/// trivial based on the nature of their arguments.
///
/// For example, `get_field(struct_col, 'field_name')` is trivial (static field
/// lookup), but `get_field(struct_col, key_col)` is not (dynamic per-row lookup).
///
/// [`ScalarUDFImpl::triviality_with_args`]: crate::ScalarUDFImpl::triviality_with_args
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ArgTriviality {
/// Argument is a literal constant value or an expression that can be
/// evaluated to a constant at planning time.
Literal,
/// Argument is a simple column reference.
Column,
/// Argument is a complex expression that declares itself trivial.
/// For example, if `get_field(struct_col, 'field_name')` is implemented as a
/// trivial expression, it would be classified with this variant.
/// `other_trivial_function(get_field(...), 42)` could then also be classified
/// as trivial using the knowledge that `get_field(...)` is trivial.
TrivialExpr,
/// Argument is a complex expression that declares itself non-trivial.
/// For example, `min(col1 + col2)` is non-trivial because it requires per-row computation.
NonTrivial,
}

impl ArgTriviality {
/// Returns true if this triviality classification indicates a trivial
/// (cheap to evaluate) expression.
///
/// Trivial expressions include literals, column references, and trivial
/// composite expressions like nested field accessors.
pub fn is_trivial(&self) -> bool {
!matches!(self, ArgTriviality::NonTrivial)
}
}
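A minimal usage sketch of the new enum, assuming the `ArgTriviality` re-export added to `datafusion_expr_common` above:

use datafusion_expr_common::ArgTriviality;

fn main() {
    // Literals, columns, and self-declared trivial expressions are all cheap.
    assert!(ArgTriviality::Literal.is_trivial());
    assert!(ArgTriviality::Column.is_trivial());
    assert!(ArgTriviality::TrivialExpr.is_trivial());
    // Only NonTrivial signals per-row computation.
    assert!(!ArgTriviality::NonTrivial.is_trivial());
}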
28 changes: 27 additions & 1 deletion datafusion/expr/src/expr.rs
@@ -27,7 +27,7 @@ use std::sync::Arc;
use crate::expr_fn::binary_expr;
use crate::function::WindowFunctionSimplification;
use crate::logical_plan::Subquery;
-use crate::{AggregateUDF, Volatility};
+use crate::{AggregateUDF, ArgTriviality, Volatility};
use crate::{ExprSchemable, Operator, Signature, WindowFrame, WindowUDF};

use arrow::datatypes::{DataType, Field, FieldRef};
@@ -1933,6 +1933,32 @@
}
}

/// Returns the triviality classification of this expression.
///
/// Trivial expressions include column references, literals, and nested
/// field access via `get_field`.
///
/// # Example
/// ```
/// # use datafusion_expr::{col, ArgTriviality};
/// let expr = col("foo");
/// assert!(expr.triviality().is_trivial());
/// ```
pub fn triviality(&self) -> ArgTriviality {
match self {
Expr::Column(_) => ArgTriviality::Column,
Expr::Literal(_, _) => ArgTriviality::Literal,
Expr::ScalarFunction(func) => {
// Classify each argument's triviality for context-aware decision making
let arg_trivialities: Vec<ArgTriviality> =
func.args.iter().map(|arg| arg.triviality()).collect();

func.func.triviality_with_args(&arg_trivialities)
}
_ => ArgTriviality::NonTrivial,
}
}

/// Return all references to columns in this expression.
///
/// # Example
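Putting the pieces together, a hedged sketch of how the new `Expr::triviality` composes, assuming the `get_field` UDF implements `triviality_with_args` as described in this PR (a literal field name makes the access trivial):

use datafusion::prelude::get_field;
use datafusion_expr::{col, lit};

fn main() {
    // Static field lookup: arguments classify as Column + Literal, and the
    // UDF can declare the whole expression trivial.
    let static_access = get_field(col("struct_col"), "a");
    assert!(static_access.triviality().is_trivial());

    // A binary expression falls through to NonTrivial per the match above.
    let computed = col("a") + lit(1);
    assert!(!computed.triviality().is_trivial());
}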
1 change: 1 addition & 0 deletions datafusion/expr/src/lib.rs
@@ -91,6 +91,7 @@ pub use datafusion_doc::{
DocSection, Documentation, DocumentationBuilder, aggregate_doc_sections,
scalar_doc_sections, window_doc_sections,
};
pub use datafusion_expr_common::ArgTriviality;
pub use datafusion_expr_common::accumulator::Accumulator;
pub use datafusion_expr_common::columnar_value::ColumnarValue;
pub use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};