diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 2996152fb924..a9201f435ad8 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -29,7 +29,7 @@ use super::{ }; use crate::{ metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, - Column, DisplayFormatType, ExecutionPlan, + DisplayFormatType, ExecutionPlan, }; use arrow::compute::filter_record_batch; @@ -192,9 +192,7 @@ impl FilterExec { let mut eq_properties = input.equivalence_properties().clone(); let (equal_pairs, _) = collect_columns_from_predicate(predicate); for (lhs, rhs) in equal_pairs { - let lhs_expr = Arc::new(lhs.clone()) as _; - let rhs_expr = Arc::new(rhs.clone()) as _; - eq_properties.add_equal_conditions(&lhs_expr, &rhs_expr) + eq_properties.add_equal_conditions(lhs, rhs) } // Add the columns that have only one viable value (singleton) after // filtering to constants. @@ -405,34 +403,33 @@ impl RecordBatchStream for FilterExecStream { /// Return the equals Column-Pairs and Non-equals Column-Pairs fn collect_columns_from_predicate(predicate: &Arc) -> EqualAndNonEqual { - let mut eq_predicate_columns = Vec::<(&Column, &Column)>::new(); - let mut ne_predicate_columns = Vec::<(&Column, &Column)>::new(); + let mut eq_predicate_columns = Vec::::new(); + let mut ne_predicate_columns = Vec::::new(); let predicates = split_conjunction(predicate); predicates.into_iter().for_each(|p| { if let Some(binary) = p.as_any().downcast_ref::() { - if let (Some(left_column), Some(right_column)) = ( - binary.left().as_any().downcast_ref::(), - binary.right().as_any().downcast_ref::(), - ) { - match binary.op() { - Operator::Eq => { - eq_predicate_columns.push((left_column, right_column)) - } - Operator::NotEq => { - ne_predicate_columns.push((left_column, right_column)) - } - _ => {} + match binary.op() { + Operator::Eq => { + eq_predicate_columns.push((binary.left(), binary.right())) + } + Operator::NotEq => { + ne_predicate_columns.push((binary.left(), binary.right())) } + _ => {} } } }); (eq_predicate_columns, ne_predicate_columns) } + +/// Pair of `Arc`s +pub type PhysicalExprPairRef<'a> = (&'a Arc, &'a Arc); + /// The equals Column-Pairs and Non-equals Column-Pairs in the Predicates pub type EqualAndNonEqual<'a> = - (Vec<(&'a Column, &'a Column)>, Vec<(&'a Column, &'a Column)>); + (Vec>, Vec>); #[cfg(test)] mod tests { @@ -482,14 +479,16 @@ mod tests { )?; let (equal_pairs, ne_pairs) = collect_columns_from_predicate(&predicate); + assert_eq!(2, equal_pairs.len()); + assert!(equal_pairs[0].0.eq(&col("c2", &schema)?)); + assert!(equal_pairs[0].1.eq(&lit(4u32))); - assert_eq!(1, equal_pairs.len()); - assert_eq!(equal_pairs[0].0.name(), "c2"); - assert_eq!(equal_pairs[0].1.name(), "c9"); + assert!(equal_pairs[1].0.eq(&col("c2", &schema)?)); + assert!(equal_pairs[1].1.eq(&col("c9", &schema)?)); assert_eq!(1, ne_pairs.len()); - assert_eq!(ne_pairs[0].0.name(), "c1"); - assert_eq!(ne_pairs[0].1.name(), "c13"); + assert!(ne_pairs[0].0.eq(&col("c1", &schema)?)); + assert!(ne_pairs[0].1.eq(&col("c13", &schema)?)); Ok(()) } diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 4b4b37f8b51b..3e8e439c9a38 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -33,7 +33,6 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::utils::DataPtr; use datafusion_common::Result; use datafusion_execution::TaskContext; -use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{ EquivalenceProperties, LexOrdering, PhysicalSortExpr, PhysicalSortRequirement, }; diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index 3a5c6497ebd4..ad4b0df1a546 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -1386,6 +1386,27 @@ AggregateExec: mode=FinalPartitioned, gby=[c2@0 as c2], aggr=[COUNT(*)] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2], has_header=true +# FilterExec can track equality of non-column expressions. +# plan below shouldn't have a SortExec because given column 'a' is ordered. +# 'CAST(ROUND(b) as INT)' is also ordered. After filter is applied. +query TT +EXPLAIN SELECT * +FROM annotated_data_finite2 +WHERE CAST(ROUND(b) as INT) = a +ORDER BY CAST(ROUND(b) as INT); +---- +logical_plan +Sort: CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) ASC NULLS LAST +--Filter: CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) = annotated_data_finite2.a +----TableScan: annotated_data_finite2 projection=[a0, a, b, c, d], partial_filters=[CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) = annotated_data_finite2.a] +physical_plan +SortPreservingMergeExec: [CAST(round(CAST(b@2 AS Float64)) AS Int32) ASC NULLS LAST] +--CoalesceBatchesExec: target_batch_size=8192 +----FilterExec: CAST(round(CAST(b@2 AS Float64)) AS Int32) = a@1 +------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true + + statement ok drop table annotated_data_finite2;