diff --git a/Cargo.lock b/Cargo.lock index 80425a87afd1..03934698dedf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2270,6 +2270,7 @@ dependencies = [ name = "datafusion-functions" version = "50.0.0" dependencies = [ + "ahash 0.8.12", "arrow", "arrow-buffer", "base64 0.22.1", @@ -2512,6 +2513,7 @@ dependencies = [ "datafusion-common-runtime", "datafusion-execution", "datafusion-expr", + "datafusion-functions", "datafusion-functions-aggregate", "datafusion-functions-aggregate-common", "datafusion-functions-window", diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs index 9f588519ecac..b2ead0e0a7d1 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs @@ -24,7 +24,7 @@ use arrow::{ }; use arrow_schema::SortOptions; use datafusion::{ - assert_batches_eq, + assert_batches_eq, assert_batches_sorted_eq, logical_expr::Operator, physical_plan::{ expressions::{BinaryExpr, Column, Literal}, @@ -60,6 +60,8 @@ use futures::StreamExt; use object_store::{memory::InMemory, ObjectStore}; use util::{format_plan_for_test, OptimizationTest, TestNode, TestScanBuilder}; +use crate::physical_optimizer::filter_pushdown::util::SlowPartitionNode; + mod util; #[test] @@ -1199,13 +1201,6 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { let result = format!("{}", pretty_format_batches(&batches).unwrap()); - let probe_scan_metrics = probe_scan.metrics().unwrap(); - - // The probe side had 4 rows, but after applying the dynamic filter only 2 rows should remain. - // The number of output rows from the probe side scan should stay consistent across executions. - // Issue: https://github.com/apache/datafusion/issues/17451 - assert_eq!(probe_scan_metrics.output_rows().unwrap(), 2); - insta::assert_snapshot!( result, @r" @@ -1219,6 +1214,272 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { ); } +/// Test demonstrating progressive dynamic filter evolution in partitioned hash joins +/// +/// This test validates that instead of waiting for all +/// build-side partitions to complete before applying any filters, we apply partial filters +/// immediately as each partition completes. +/// To be able to evaluate partial filters we need to know which partition each probe row belongs to, +/// so we push down a hash function to the probe side that computes the same hash as repartitioning will later on. +/// +/// ## Test Scenario Setup +/// +/// - **Build side**: Values [1, 2, 3, 4] distributed across 3 hash partitions +/// - **Probe side**: Values [2, 3] that need to be filtered +/// - **Partition 1 is artificially slowed**: Simulates real-world partition skew +/// +/// ## Progressive Filter Evolution Demonstration +/// +/// The test shows how the dynamic filter evolves through three distinct phases: +/// +/// ### Phase 1: Initial State (All Partitions Building) +/// ```sql +/// -- No filter applied yet +/// predicate=DynamicFilterPhysicalExpr [ true ] +/// ``` +/// → All probe-side data passes through unfiltered +/// +/// ### Phase 2: Progressive Filtering (Some Partitions Complete) +/// ```sql +/// -- Hash-based progressive filter after partition 0 completes +/// predicate=DynamicFilterPhysicalExpr [ +/// CASE repartition_hash(id@0) % 3 +/// WHEN 0 THEN id@0 >= 3 AND id@0 <= 3 -- Only partition 0 bounds known +/// ELSE true -- Pass through partitions 1,2 data +/// END +/// ] +/// ``` +/// → Filters probe data for partition 0, passes through everything else safely +/// +/// ### Phase 3: Final Optimization (All Partitions Complete) +/// ```sql +/// -- Optimized bounds-only filter +/// predicate=DynamicFilterPhysicalExpr [ +/// id@0 >= 3 AND id@0 <= 3 OR -- Partition 0 bounds +/// id@0 >= 2 AND id@0 <= 2 OR -- Partition 1 bounds +/// id@0 >= 1 AND id@0 <= 4 -- Partition 2 bounds +/// ] +/// ``` +/// → Bounds filter with no hash computation overhead +/// +/// ## Correctness Validation +/// +/// The test verifies: +/// 1. **No False Negatives**: All valid join results [2,3] are preserved throughout +/// 2. **Progressive Improvement**: Filter selectivity increases as partitions complete +/// 3. **Final Optimization**: Hash-based expressions are removed when all partitions finish +/// 4. **Partition Isolation**: Each partition's filter only affects its own hash bucket +/// +/// ## Real-World Impact +/// +/// This optimization addresses common production scenarios where: +/// - Some partitions finish much faster than others (data skew) +/// - Waiting for large build sides before starting the probe sides increases latency +#[tokio::test] +#[cfg(not(feature = "force_hash_collisions"))] // this test relies on hash partitioning to separate rows +async fn test_hashjoin_progressive_filter_reporting() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + use crate::physical_optimizer::filter_pushdown::util::TestRepartitionHash; + + // Create build side with limited values + let build_batches_1 = vec![record_batch!(("id", UInt64, [1, 2, 3, 4])).unwrap()]; + let build_side_schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::UInt64, false)])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + // Add two batches -> creates 2 partitions + .with_batches(build_batches_1) + .build(); + + // Create probe side with more values + let probe_batches_1 = vec![record_batch!(("id", UInt64, [2, 3])).unwrap()]; + let probe_side_schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::UInt64, false)])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches_1) + .build(); + + // Create RepartitionExec nodes for both sides with hash partitioning on join keys + let partition_count = 3; + + // Build side: DataSource -> RepartitionExec (Hash) -> CoalesceBatchesExec + let build_hash_exprs = vec![ + col("id", &build_side_schema).unwrap(), + col("id", &build_side_schema).unwrap(), + ]; + let build_side = Arc::new( + RepartitionExec::try_new( + build_scan, + Partitioning::Hash(build_hash_exprs, partition_count), + ) + .unwrap() + .with_hash_function(ScalarUDF::new_from_impl(TestRepartitionHash::new())), + ); + let build_side = Arc::new(SlowPartitionNode::new(build_side, vec![1])); + + // Probe side: DataSource -> RepartitionExec (Hash) -> CoalesceBatchesExec + let probe_hash_exprs = vec![ + col("id", &probe_side_schema).unwrap(), + col("id", &probe_side_schema).unwrap(), + ]; + let probe_side = Arc::new( + RepartitionExec::try_new( + Arc::clone(&probe_scan), + Partitioning::Hash(probe_hash_exprs, partition_count), + ) + .unwrap() + .with_hash_function(ScalarUDF::new_from_impl(TestRepartitionHash::new())), + ); + + // Create HashJoinExec with partitioned inputs + let on = vec![( + col("id", &build_side_schema).unwrap(), + col("id", &probe_side_schema).unwrap(), + )]; + let plan = Arc::new( + HashJoinExec::try_new( + Arc::clone(&build_side) as Arc, + probe_side, + on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + ) + .unwrap(), + ) as Arc; + + // Verify the initial optimization - should show DynamicFilterPhysicalExpr is set up + // but not yet populated with any bounds (shows as "true" initially) + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), + @r" + OptimizationTest: + input: + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(id@0, id@0)] + - SlowPartitionNode + - RepartitionExec: partitioning=Hash([id@0, id@0], 3), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([id@0, id@0], 3), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id], file_type=test, pushdown_supported=true + output: + Ok: + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(id@0, id@0)] + - SlowPartitionNode + - RepartitionExec: partitioning=Hash([id@0, id@0], 3), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([id@0, id@0], 3), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id], file_type=test, pushdown_supported=true, predicate=DynamicFilterPhysicalExpr [ true ] + " + ); + + // Actually apply the optimization to the plan and execute to see the filter in action + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_dynamic_filter_pushdown = true; + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + let config = SessionConfig::new().with_batch_size(10); + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let mut batches = Vec::new(); + + // Execute partition 0 directly to test progressive behavior + // This partition should complete while partition 1 is blocked + let mut stream_0 = plan.execute(0, Arc::clone(&task_ctx)).unwrap(); + + // Pull batches from partition 0 (should work even while partition 1 is blocked) + while let Some(batch_result) = stream_0.next().await { + let batch = batch_result.unwrap(); + if batch.num_rows() > 0 { + batches.push(batch); + } + } + + // CRITICAL VALIDATION: This snapshot shows the progressive filter in action! + // After partition 0 completes (but partition 1 is still blocked), we see: + // - CASE repartition_hash(id@0) % 3 WHEN 0 THEN id@0 >= 3 AND id@0 <= 3 ELSE true END + // This means: + // - For rows that hash to partition 0: Apply bounds check (id >= 3 AND id <= 3) + // - For rows that hash to partitions 1,2: Pass everything through (ELSE true) + // This is the core of progressive filtering - partial filtering without false negatives! + #[cfg(not(feature = "force_hash_collisions"))] + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(id@0, id@0)] + - SlowPartitionNode + - RepartitionExec: partitioning=Hash([id@0, id@0], 3), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([id@0, id@0], 3), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id], file_type=test, pushdown_supported=true, predicate=DynamicFilterPhysicalExpr [ CASE repartition_hash(id@0) % 3 WHEN 0 THEN id@0 >= 3 AND id@0 <= 3 ELSE true END ] + " + ); + + #[rustfmt::skip] + let expected = [ + "+----+----+", + "| id | id |", + "+----+----+", + "| 3 | 3 |", + "+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + + // Wake the slow build side partition 0 to allow it to complete + build_side.unblock(); + + // Pull remaining batches + let mut stream_1 = plan.execute(1, Arc::clone(&task_ctx)).unwrap(); + while let Some(batch) = stream_1.next().await { + batches.push(batch.unwrap()); + } + let mut stream_2 = plan.execute(2, Arc::clone(&task_ctx)).unwrap(); + while let Some(batch) = stream_2.next().await { + batches.push(batch.unwrap()); + } + + // FINAL OPTIMIZATION VALIDATION: All partitions complete - filter is now optimized! + // The hash-based CASE expression has been replaced with a simple OR of bounds: + // - id@0 >= 3 AND id@0 <= 3 OR id@0 >= 2 AND id@0 <= 2 OR id@0 >= 1 AND id@0 <= 4 + // This is much more efficient - no hash computation needed, just bounds checks. + // Each OR clause represents one partition's bounds: [3,3], [2,2], [1,4] + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(id@0, id@0)] + - SlowPartitionNode + - RepartitionExec: partitioning=Hash([id@0, id@0], 3), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([id@0, id@0], 3), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id], file_type=test, pushdown_supported=true, predicate=DynamicFilterPhysicalExpr [ id@0 >= 3 AND id@0 <= 3 OR id@0 >= 2 AND id@0 <= 2 OR id@0 >= 1 AND id@0 <= 4 ] + " + ); + + // Look at the final results + #[rustfmt::skip] + let expected = [ + "+----+----+", + "| id | id |", + "+----+----+", + "| 2 | 2 |", + "| 3 | 3 |", + "+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); +} + #[tokio::test] async fn test_hashjoin_dynamic_filter_pushdown_collect_left() { use datafusion_common::JoinType; diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs index 2fe705b14921..ffd428c9c2d0 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs @@ -15,8 +15,10 @@ // specific language governing permissions and limitations // under the License. +use arrow::array::UInt64Array; use arrow::datatypes::SchemaRef; use arrow::{array::RecordBatch, compute::concat_batches}; +use arrow_schema::DataType; use datafusion::{datasource::object_store::ObjectStoreUrl, physical_plan::PhysicalExpr}; use datafusion_common::{config::ConfigOptions, internal_err, Result, Statistics}; use datafusion_datasource::{ @@ -25,6 +27,8 @@ use datafusion_datasource::{ file_stream::FileOpener, schema_adapter::DefaultSchemaAdapterFactory, schema_adapter::SchemaAdapterFactory, source::DataSourceExec, PartitionedFile, }; +use datafusion_execution::RecordBatchStream; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::filter::batch_filter; @@ -42,6 +46,7 @@ use datafusion_physical_plan::{ use futures::StreamExt; use futures::{FutureExt, Stream}; use object_store::ObjectStore; +use std::future::Future; use std::{ any::Any, fmt::{Display, Formatter}, @@ -576,3 +581,263 @@ impl ExecutionPlan for TestNode { } } } + +pub struct Flag { + sender: tokio::sync::watch::Sender, +} + +impl Flag { + /// Creates a new flag object. + pub fn new() -> Self { + Self { + sender: tokio::sync::watch::channel(false).0, + } + } + + /// Enables the flag. + pub fn enable(&self) { + self.sender.send_if_modified(|value| { + if *value { + false + } else { + *value = true; + + true + } + }); + } + + /// Waits the flag to become enabled. + pub async fn wait_enabled(&self) { + if !*self.sender.borrow() { + let mut receiver = self.sender.subscribe(); + + if !*receiver.borrow() { + receiver.changed().await.ok(); + } + } + } +} + +/// An execution plan node that waits for a notification before yielding any data from designated partitions. +pub struct SlowPartitionNode { + input: Arc, + flag: Arc, + slow_partitions: Vec, +} + +impl SlowPartitionNode { + pub fn new(input: Arc, slow_partitions: Vec) -> Self { + Self { + input, + flag: Arc::new(Flag::new()), + slow_partitions, + } + } + + pub fn unblock(&self) { + self.flag.enable(); + } +} + +impl std::fmt::Debug for SlowPartitionNode { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "SlowPartitionNode") + } +} + +impl DisplayAs for SlowPartitionNode { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + write!(f, "SlowPartitionNode") + } +} + +impl ExecutionPlan for SlowPartitionNode { + fn name(&self) -> &str { + "SlowPartitionNode" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + self.input.properties() + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + assert!(children.len() == 1); + Ok(Arc::new(SlowPartitionNode { + input: children[0].clone(), + flag: Arc::clone(&self.flag), + slow_partitions: self.slow_partitions.clone(), + })) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + if self.slow_partitions.contains(&partition) { + let stream = self.input.execute(partition, context)?; + let waiter_stream = WaiterStream { + inner: stream, + flag: Arc::clone(&self.flag), + flag_checked: false, + }; + Ok(Box::pin(waiter_stream) + as datafusion_execution::SendableRecordBatchStream) + } else { + self.input.execute(partition, context) + } + } +} + +/// Stream that waits for a notification before yielding the first batch +struct WaiterStream { + inner: datafusion_execution::SendableRecordBatchStream, + flag: Arc, + flag_checked: bool, +} + +impl Stream for WaiterStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + // If we haven't checked the flag yet, wait for it to be enabled + if !this.flag_checked { + let flag = Arc::clone(&this.flag); + let wait_future = flag.wait_enabled(); + futures::pin_mut!(wait_future); + + match wait_future.poll(cx) { + Poll::Ready(()) => { + // Flag is now enabled, mark as checked and continue + this.flag_checked = true; + } + Poll::Pending => { + // Still waiting for flag to be enabled + return Poll::Pending; + } + } + } + + // Flag has been checked and is enabled, delegate to inner stream + Pin::new(&mut this.inner).poll_next(cx) + } +} + +impl RecordBatchStream for WaiterStream { + fn schema(&self) -> SchemaRef { + self.inner.schema() + } +} + +/// A hash repartition implementation that only accepts integers and hashes them to themselves. +#[derive(Debug)] +pub struct TestRepartitionHash { + signature: datafusion_expr::Signature, +} + +impl TestRepartitionHash { + pub fn new() -> Self { + Self { + signature: datafusion_expr::Signature::one_of( + vec![datafusion_expr::TypeSignature::VariadicAny], + datafusion_expr::Volatility::Immutable, + ), + } + } +} + +impl PartialEq for TestRepartitionHash { + fn eq(&self, other: &Self) -> bool { + // RandomState doesn't implement PartialEq, so we just compare signatures + self.signature == other.signature + } +} + +impl Eq for TestRepartitionHash {} + +impl std::hash::Hash for TestRepartitionHash { + fn hash(&self, state: &mut H) { + // Only hash the signature since RandomState doesn't implement Hash + self.signature.hash(state); + } +} + +impl ScalarUDFImpl for TestRepartitionHash { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "test_repartition_hash" + } + + fn signature(&self) -> &datafusion_expr::Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + // Always return UInt64Array regardless of input types + Ok(DataType::UInt64) + } + + fn invoke_with_args( + &self, + args: datafusion_expr::ScalarFunctionArgs, + ) -> Result { + // All inputs must be arrays of UInt64 + let arrays: Vec = args + .args + .iter() + .map(|cv| { + let ColumnarValue::Array(array) = cv else { + panic!("Expected array input"); + }; + let Some(array) = array.as_any().downcast_ref::() else { + panic!("Expected UInt64Array input"); + }; + array.clone() + }) + .collect(); + // We accept only 1 array + if arrays.is_empty() { + return Err(datafusion_common::DataFusionError::Internal( + "Expected at least one argument".to_string(), + )); + } + + let num_rows = arrays[0].len(); + let mut result_values = Vec::with_capacity(num_rows); + + // Add together all the integer values from all input arrays + for row_idx in 0..num_rows { + let mut sum = 0u64; + for array in &arrays { + let value = array.value(row_idx); + sum = sum.wrapping_add(value); + } + result_values.push(sum); + } + // Return the summed values as a UInt64Array + Ok(ColumnarValue::Array(Arc::new(UInt64Array::from( + result_values, + )))) + } + + fn documentation(&self) -> Option<&datafusion_expr::Documentation> { + None + } +} diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 90331fbccaf0..93c600c76650 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -62,6 +62,7 @@ unicode_expressions = ["unicode-segmentation"] name = "datafusion_functions" [dependencies] +ahash = { workspace = true } arrow = { workspace = true } arrow-buffer = { workspace = true } base64 = { version = "0.22", optional = true } diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 607224782fc4..c8e9efac5e53 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -53,6 +53,7 @@ datafusion-common = { workspace = true } datafusion-common-runtime = { workspace = true, default-features = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } +datafusion-functions = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-functions-window-common = { workspace = true } datafusion-physical-expr = { workspace = true, default-features = true } diff --git a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs index 25f7a0de31ac..ecef9cca6e50 100644 --- a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs +++ b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs @@ -18,10 +18,12 @@ //! Utilities for shared bounds. Used in dynamic filter pushdown in Hash Joins. // TODO: include the link to the Dynamic Filter blog post. +use std::collections::HashMap; use std::fmt; use std::sync::Arc; use crate::joins::PartitionMode; +use crate::repartition::hash::repartition_hash; use crate::ExecutionPlan; use crate::ExecutionPlanProperties; @@ -29,10 +31,8 @@ use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{lit, BinaryExpr, DynamicFilterPhysicalExpr}; use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; - use itertools::Itertools; use parking_lot::Mutex; -use tokio::sync::Barrier; /// Represents the minimum and maximum values for a specific column. /// Used in dynamic filter pushdown to establish value boundaries. @@ -78,24 +78,72 @@ impl PartitionBounds { } } -/// Coordinates dynamic filter bounds collection across multiple partitions +/// Coordinates dynamic filter bounds collection across multiple partitions with progressive filtering +/// +/// This structure applies dynamic filters progressively as each partition completes, rather than +/// waiting for all partitions. This reduces probe-side scan overhead and improves performance. +/// +/// ## Progressive Filtering Strategy +/// +/// The key insight is that we can apply partial filters immediately without waiting for all +/// partitions to complete, while maintaining correctness through hash-based expressions. +/// +/// 1. **Immediate Filter Injection**: Each partition computes bounds from its build-side data +/// and immediately injects a progressive filter +/// 2. **Hash-Based Correctness**: Filters use hash expressions to ensure no false negatives: +/// `CASE hash(cols) % num_partitions WHEN partition_id THEN (col >= min AND col <= max) ELSE true END` +/// 3. **Incremental Improvement**: As partitions complete, filter selectivity increases +/// 4. **Final Optimization**: When all partitions complete, hash checks are removed +/// +/// ## Concrete Example /// -/// This structure ensures that dynamic filters are built with complete information from all -/// relevant partitions before being applied to probe-side scans. Incomplete filters would -/// incorrectly eliminate valid join results. +/// Consider a 3-partition hash join on column `id` with build-side values: +/// - Partition 0: id ∈ [10, 20] (completes first) +/// - Partition 1: id ∈ [30, 40] (completes second) +/// - Partition 2: id ∈ [50, 60] (completes last) /// -/// ## Synchronization Strategy +/// ### Progressive Phase Filters: /// -/// 1. Each partition computes bounds from its build-side data -/// 2. Bounds are stored in the shared vector -/// 3. A barrier tracks how many partitions have reported their bounds -/// 4. When the last partition reports, bounds are merged and the filter is updated exactly once +/// **After Partition 0 completes:** +/// ```sql +/// CASE hash(id) % 3 +/// WHEN 0 THEN id >= 10 AND id <= 20 +/// ELSE true +/// END +/// ``` +/// → Filters partition-0 data immediately, passes through all other data /// -/// ## Partition Counting +/// **After Partition 1 completes:** +/// ```sql +/// CASE hash(id) % 3 +/// WHEN 0 THEN id >= 10 AND id <= 20 +/// WHEN 1 THEN id >= 30 AND id <= 40 +/// ELSE true +/// END +/// ``` +/// → Now filters both partition-0 and partition-1 data /// -/// The `total_partitions` count represents how many times `collect_build_side` will be called: -/// - **CollectLeft**: Number of output partitions (each accesses shared build data) -/// - **Partitioned**: Number of input partitions (each builds independently) +/// ### Final Phase Filter: +/// **After all partitions complete:** +/// ```sql +/// (id >= 10 AND id <= 20) OR (id >= 30 AND id <= 40) OR (id >= 50 AND id <= 60) +/// ``` +/// → Optimized bounds-only filter, no hash computation needed +/// +/// ## Correctness Guarantee +/// +/// The hash-based approach ensures **no false negatives**: +/// - For rows belonging to completed partitions: bounds check filters correctly +/// - For rows belonging to incomplete partitions: `ELSE true` passes everything through +/// - Hash function matches the partitioning scheme, ensuring correct partition assignment +/// +/// ## Performance Benefits +/// +/// - **Early Filtering**: Probe-side scans start filtering immediately, not after barrier +/// - **Progressive Improvement**: Filter selectivity increases with each completed partition +/// - **Reduced I/O**: Less data read from probe-side sources as partitions complete +/// - **No Coordination Overhead**: Eliminates barrier synchronization between partitions +/// - **Final Optimization**: Removes hash computation cost when all partitions are done /// /// ## Thread Safety /// @@ -104,7 +152,8 @@ impl PartitionBounds { pub(crate) struct SharedBoundsAccumulator { /// Shared state protected by a single mutex to avoid ordering concerns inner: Mutex, - barrier: Barrier, + /// Total number of partitions expected to report + total_partitions: usize, /// Dynamic filter for pushdown to probe side dynamic_filter: Arc, /// Right side join expressions needed for creating filter bounds @@ -114,8 +163,12 @@ pub(crate) struct SharedBoundsAccumulator { /// State protected by SharedBoundsAccumulator's mutex struct SharedBoundsState { /// Bounds from completed partitions. - /// Each element represents the column bounds computed by one partition. - bounds: Vec, + /// Each value represents the column bounds computed by one partition and the keys are the partition ids. + bounds: HashMap, + /// Whether we've optimized the filter to remove hash checks + filter_optimized: bool, + /// Number of partitions that have reported completion (for tracking when all are done) + completed_count: usize, } impl SharedBoundsAccumulator { @@ -153,7 +206,7 @@ impl SharedBoundsAccumulator { ) -> Self { // Troubleshooting: If partition counts are incorrect, verify this logic matches // the actual execution pattern in collect_build_side() - let expected_calls = match partition_mode { + let total_partitions = match partition_mode { // Each output partition accesses shared build data PartitionMode::CollectLeft => { right_child.output_partitioning().partition_count() @@ -167,99 +220,248 @@ impl SharedBoundsAccumulator { }; Self { inner: Mutex::new(SharedBoundsState { - bounds: Vec::with_capacity(expected_calls), + bounds: HashMap::with_capacity(total_partitions), + filter_optimized: false, + completed_count: 0, }), - barrier: Barrier::new(expected_calls), + total_partitions, dynamic_filter, on_right, } } - /// Create a filter expression from individual partition bounds using OR logic. + /// Create a bounds predicate for a single partition: (col >= min AND col <= max) for all columns. + /// This is used in both progressive and final filter creation. + /// Returns None if no bounds are available for this partition. + fn create_partition_bounds_predicate( + &self, + partition_bounds: &PartitionBounds, + ) -> Result>> { + let mut column_predicates = Vec::with_capacity(partition_bounds.len()); + + for (col_idx, right_expr) in self.on_right.iter().enumerate() { + if let Some(column_bounds) = partition_bounds.get_column_bounds(col_idx) { + // Create predicate: col >= min AND col <= max + let min_expr = Arc::new(BinaryExpr::new( + Arc::clone(right_expr), + Operator::GtEq, + lit(column_bounds.min.clone()), + )) as Arc; + let max_expr = Arc::new(BinaryExpr::new( + Arc::clone(right_expr), + Operator::LtEq, + lit(column_bounds.max.clone()), + )) as Arc; + let range_expr = + Arc::new(BinaryExpr::new(min_expr, Operator::And, max_expr)) + as Arc; + column_predicates.push(range_expr); + } else { + // Missing bounds for this column, the created predicate will have lower selectivity but will still be correct + continue; + } + } + + // Combine all column predicates for this partition with AND + Ok(column_predicates.into_iter().reduce(|acc, pred| { + Arc::new(BinaryExpr::new(acc, Operator::And, pred)) as Arc + })) + } + + /// Create progressive filter using hash-based expressions to avoid false negatives. + /// + /// This is the heart of progressive filtering. It creates a CASE expression that applies + /// bounds filtering only to rows belonging to completed partitions, while safely passing + /// through all data from incomplete partitions. /// - /// This creates a filter where each partition's bounds form a conjunction (AND) - /// of column range predicates, and all partitions are combined with OR. + /// ## Generated Expression Structure: + /// ```sql + /// CASE hash(cols) % num_partitions + /// WHEN 0 THEN (col1 >= min1 AND col1 <= max1 AND col2 >= min2 AND col2 <= max2) + /// WHEN 1 THEN (col1 >= min3 AND col1 <= max3 AND col2 >= min4 AND col2 <= max4) + /// ... + /// ELSE true -- Critical: ensures no false negatives for incomplete partitions + /// END + /// ``` /// - /// For example, with 2 partitions and 2 columns: - /// ((col0 >= p0_min0 AND col0 <= p0_max0 AND col1 >= p0_min1 AND col1 <= p0_max1) - /// OR - /// (col0 >= p1_min0 AND col0 <= p1_max0 AND col1 >= p1_min1 AND col1 <= p1_max1)) - pub(crate) fn create_filter_from_partition_bounds( + /// ## Correctness Key Points: + /// - **Hash Function**: Uses the same hash as the join's partitioning scheme + /// - **Modulo Operation**: Maps hash values to partition IDs (0 to num_partitions-1) + /// - **WHEN Clauses**: Only created for partitions that have completed and reported bounds + /// - **ELSE true**: Ensures rows from incomplete partitions are never filtered out + /// - **Single Hash**: Hash is computed once per row, regardless of how many partitions completed + pub(crate) fn create_progressive_filter_from_partition_bounds( &self, - bounds: &[PartitionBounds], + bounds: &HashMap, ) -> Result> { - if bounds.is_empty() { - return Ok(lit(true)); - } + // Step 1: Create the partition assignment expression: hash(join_cols) % num_partitions + // This must match the hash function used by RepartitionExec for correctness + let hash_expr = repartition_hash(self.on_right.clone())?; + let total_partitions_expr = + lit(ScalarValue::UInt64(Some(self.total_partitions as u64))); + let modulo_expr = Arc::new(BinaryExpr::new( + hash_expr, + Operator::Modulo, + total_partitions_expr, + )) as Arc; - // Create a predicate for each partition - let mut partition_predicates = Vec::with_capacity(bounds.len()); - - for partition_bounds in bounds.iter().sorted_by_key(|b| b.partition) { - // Create range predicates for each join key in this partition - let mut column_predicates = Vec::with_capacity(partition_bounds.len()); - - for (col_idx, right_expr) in self.on_right.iter().enumerate() { - if let Some(column_bounds) = partition_bounds.get_column_bounds(col_idx) { - // Create predicate: col >= min AND col <= max - let min_expr = Arc::new(BinaryExpr::new( - Arc::clone(right_expr), - Operator::GtEq, - lit(column_bounds.min.clone()), - )) as Arc; - let max_expr = Arc::new(BinaryExpr::new( - Arc::clone(right_expr), - Operator::LtEq, - lit(column_bounds.max.clone()), - )) as Arc; - let range_expr = - Arc::new(BinaryExpr::new(min_expr, Operator::And, max_expr)) - as Arc; - column_predicates.push(range_expr); + // Step 2: Build WHEN clauses for each completed partition + // Format: WHEN partition_id THEN (bounds_predicate) + let when_thens = bounds.values().sorted_by_key(|b| b.partition).try_fold( + Vec::new(), + |mut acc, partition_bounds| { + // Create literal for partition ID (e.g., WHEN 0, WHEN 1, etc.) + let when_value = + lit(ScalarValue::UInt64(Some(partition_bounds.partition as u64))); + + // Create bounds predicate for this partition (e.g., col >= min AND col <= max) + if let Some(then_predicate) = + self.create_partition_bounds_predicate(partition_bounds)? + { + acc.push((when_value, then_predicate)); } - } + Ok::<_, datafusion_common::DataFusionError>(acc) + }, + )?; - // Combine all column predicates for this partition with AND - if !column_predicates.is_empty() { - let partition_predicate = column_predicates - .into_iter() - .reduce(|acc, pred| { - Arc::new(BinaryExpr::new(acc, Operator::And, pred)) - as Arc - }) - .unwrap(); - partition_predicates.push(partition_predicate); + // Step 3: Build the complete CASE expression + use datafusion_physical_expr::expressions::case; + let expr = if when_thens.is_empty() { + // Edge case: No partitions have completed yet - pass everything through + lit(ScalarValue::Boolean(Some(true))) + } else { + // Create CASE expression with critical ELSE true clause + // The ELSE true ensures we never filter out rows from incomplete partitions + case( + Some(modulo_expr), // CASE hash(cols) % num_partitions + when_thens, // WHEN clauses for completed partitions + Some(lit(ScalarValue::Boolean(Some(true)))), // ELSE true - no false negatives! + )? + }; + + Ok(expr) + } + + /// Create final optimized filter when all partitions have completed + /// + /// This method represents the performance optimization phase of progressive filtering. + /// Once all partitions have reported their bounds, we can eliminate the hash-based + /// CASE expression and use a simpler, more efficient bounds-only filter. + /// + /// ## Optimization Benefits: + /// 1. **No Hash Computation**: Eliminates expensive hash calculations per row + /// 2. **Simpler Expression**: OR-based bounds are faster to evaluate than CASE expressions + /// 3. **Better Vectorization**: Simple bounds comparisons optimize better in Arrow + /// 4. **Reduced CPU Overhead**: Significant performance improvement for large datasets + /// + /// ## Generated Expression Structure: + /// ```sql + /// (col1 >= min1 AND col1 <= max1) OR -- Partition 0 bounds + /// (col1 >= min2 AND col1 <= max2) OR -- Partition 1 bounds + /// ... + /// (col1 >= minN AND col1 <= maxN) -- Partition N bounds + /// ``` + /// + /// ## Correctness Maintained: + /// - Each OR clause represents the exact bounds from one partition + /// - Union of all partition bounds = complete build-side value range + /// - No false negatives: if a value exists in build side, it passes this filter + /// - Same filtering effect as progressive filter, but much more efficient + /// + /// This transformation is only applied when ALL partitions have completed to ensure + /// we have complete bounds information. + pub(crate) fn create_optimized_filter_from_partition_bounds( + &self, + bounds: &HashMap, + ) -> Result>> { + // Build individual partition predicates - each becomes one OR clause + let mut partition_filters = Vec::with_capacity(bounds.len()); + + for partition_bounds in bounds.values().sorted_by_key(|b| b.partition) { + if let Some(filter) = + self.create_partition_bounds_predicate(partition_bounds)? + { + // This partition contributed bounds - include in optimized filter + partition_filters.push(filter); } + // Skip empty partitions gracefully - they don't contribute bounds but + // shouldn't prevent the optimization from proceeding } - // Combine all partition predicates with OR - let combined_predicate = partition_predicates - .into_iter() - .reduce(|acc, pred| { - Arc::new(BinaryExpr::new(acc, Operator::Or, pred)) - as Arc - }) - .unwrap_or_else(|| lit(true)); - - Ok(combined_predicate) + // Create the final OR expression: bounds_0 OR bounds_1 OR ... OR bounds_N + // This replaces the hash-based CASE expression with a much faster bounds-only check + Ok(partition_filters.into_iter().reduce(|acc, filter| { + Arc::new(BinaryExpr::new(acc, Operator::Or, filter)) as Arc + })) } - /// Report bounds from a completed partition and update dynamic filter if all partitions are done + /// Report bounds from a completed partition and immediately update the dynamic filter + /// + /// This is the core method that implements progressive filtering. Unlike traditional approaches + /// that wait for all partitions to complete, this method immediately applies a partial filter + /// as soon as each partition finishes building its hash table. + /// + /// ## Progressive Filter Logic + /// + /// The method maintains correctness through careful filter design: + /// + /// **Key Insight**: We can safely filter rows that belong to completed partitions while + /// letting all other rows pass through, because the hash function determines partition + /// membership deterministically. + /// + /// ## Filter Evolution Example + /// + /// Consider a 2-partition join on column `price`: + /// + /// **Initial state**: No filter applied + /// ```sql + /// -- All probe-side rows pass through + /// SELECT * FROM probe_table -- No filtering + /// ``` + /// + /// **After Partition 0 completes** (found price range [100, 200]): + /// ```sql + /// -- Progressive filter applied + /// SELECT * FROM probe_table + /// WHERE CASE hash(price) % 2 + /// WHEN 0 THEN price >= 100 AND price <= 200 -- Filter partition-0 data + /// ELSE true -- Pass through partition-1 data + /// END + /// ``` + /// → Filters out probe rows with price ∉ [100, 200] that hash to partition 0 + /// + /// **After Partition 1 completes** (found price range [500, 600]): + /// ```sql + /// -- Final optimized filter + /// SELECT * FROM probe_table + /// WHERE (price >= 100 AND price <= 200) OR (price >= 500 AND price <= 600) + /// ``` + /// → Clean bounds-only filter, no hash computation needed + /// + /// ## Correctness Guarantee /// - /// This method coordinates the dynamic filter updates across all partitions. It stores the - /// bounds from the current partition, increments the completion counter, and when all - /// partitions have reported, creates an OR'd filter from individual partition bounds. + /// This approach ensures **zero false negatives** (never incorrectly excludes valid joins): /// - /// This method is async and uses a [`tokio::sync::Barrier`] to wait for all partitions - /// to report their bounds. Once that occurs, the method will resolve for all callers and the - /// dynamic filter will be updated exactly once. + /// 1. **Completed Partitions**: Rows are filtered by actual build-side bounds + /// 2. **Incomplete Partitions**: All rows pass through (`ELSE true`) + /// 3. **Partition Assignment**: Hash function matches the join's partitioning scheme exactly + /// 4. **Bounds Accuracy**: Min/max values computed from actual build-side data /// - /// # Note + /// The filter may have **false positives** (includes rows that won't join) during the + /// progressive phase, but these are eliminated during the actual join operation. /// - /// As barriers are reusable, it is likely an error to call this method more times than the - /// total number of partitions - as it can lead to pending futures that never resolve. We rely - /// on correct usage from the caller rather than imposing additional checks here. If this is a concern, - /// consider making the resulting future shared so the ready result can be reused. + /// ## Concurrency Handling + /// + /// - **Thread Safety**: Uses mutex to coordinate between concurrent partition executions + /// - **Deduplication**: Handles multiple reports from same partition (CollectLeft mode) + /// - **Atomic Updates**: Filter updates are applied atomically to avoid inconsistent states + /// + /// ## Performance Impact + /// + /// - **Immediate Benefit**: Probe-side filtering starts as soon as first partition completes + /// - **I/O Reduction**: Less data read from storage/network as build partitions complete + /// - **CPU Optimization**: Final filter removes hash computation overhead + /// - **Scalability**: No barrier synchronization delays between partitions /// /// # Arguments /// * `left_side_partition_id` - The identifier for the **left-side** partition reporting its bounds @@ -267,39 +469,52 @@ impl SharedBoundsAccumulator { /// /// # Returns /// * `Result<()>` - Ok if successful, Err if filter update failed - pub(crate) async fn report_partition_bounds( + pub(crate) fn report_partition_bounds( &self, left_side_partition_id: usize, partition_bounds: Option>, ) -> Result<()> { - // Store bounds in the accumulator - this runs once per partition - if let Some(bounds) = partition_bounds { - let mut guard = self.inner.lock(); - - let should_push = if let Some(last_bound) = guard.bounds.last() { - // In `PartitionMode::CollectLeft`, all streams on the left side share the same partition id (0). - // Since this function can be called multiple times for that same partition, we must deduplicate - // by checking against the last recorded bound. - last_bound.partition != left_side_partition_id - } else { - true - }; + let mut inner = self.inner.lock(); - if should_push { - guard - .bounds - .push(PartitionBounds::new(left_side_partition_id, bounds)); - } - } + // Always increment completion counter - every partition reports exactly once + inner.completed_count += 1; - if self.barrier.wait().await.is_leader() { - // All partitions have reported, so we can update the filter - let inner = self.inner.lock(); - if !inner.bounds.is_empty() { - let filter_expr = - self.create_filter_from_partition_bounds(&inner.bounds)?; - self.dynamic_filter.update(filter_expr)?; - } + // Store bounds from this partition (avoid duplicates) + // In CollectLeft mode, multiple streams may report the same partition_id, + // but we only want to store bounds once + inner + .bounds + .entry(left_side_partition_id) + .or_insert_with(|| { + if let Some(bounds) = partition_bounds { + PartitionBounds::new(left_side_partition_id, bounds) + } else { + // Insert an empty bounds entry to track this partition + PartitionBounds::new(left_side_partition_id, vec![]) + } + }); + + let completed = inner.completed_count; + let total = self.total_partitions; + + let all_partitions_complete = completed == total; + + // Create the appropriate filter based on completion status + let filter_expr = if all_partitions_complete && !inner.filter_optimized { + // All partitions complete - use optimized filter without hash checks + inner.filter_optimized = true; + self.create_optimized_filter_from_partition_bounds(&inner.bounds)? + } else { + // Progressive phase - use hash-based filter + Some(self.create_progressive_filter_from_partition_bounds(&inner.bounds)?) + }; + + // Release lock before updating filter to avoid holding it during the update + drop(inner); + + // Update the dynamic filter + if let Some(filter_expr) = filter_expr { + self.dynamic_filter.update(filter_expr)?; } Ok(()) @@ -311,3 +526,140 @@ impl fmt::Debug for SharedBoundsAccumulator { write!(f, "SharedBoundsAccumulator") } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Array; + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion_common::cast::as_boolean_array; + use std::sync::Arc; + + #[test] + fn test_create_optimized_filter_case_expr() -> Result<()> { + // Create a simple test setup + let schema = Schema::new(vec![ + Field::new("col1", DataType::Int32, false), + Field::new("col2", DataType::Int32, false), + ]); + + // Create test partition bounds + let bounds1 = PartitionBounds::new( + 0, + vec![ + ColumnBounds::new( + ScalarValue::Int32(Some(1)), + ScalarValue::Int32(Some(10)), + ), + ColumnBounds::new( + ScalarValue::Int32(Some(5)), + ScalarValue::Int32(Some(15)), + ), + ], + ); + let bounds2 = PartitionBounds::new( + 1, + vec![ + ColumnBounds::new( + ScalarValue::Int32(Some(20)), + ScalarValue::Int32(Some(30)), + ), + ColumnBounds::new( + ScalarValue::Int32(Some(25)), + ScalarValue::Int32(Some(35)), + ), + ], + ); + + // Create mock accumulator (we only need the parts that create_optimized_filter uses) + let on_right = vec![ + Arc::new(datafusion_physical_expr::expressions::Column::new( + "col1", 0, + )) as Arc, + Arc::new(datafusion_physical_expr::expressions::Column::new( + "col2", 1, + )) as Arc, + ]; + + let accumulator = SharedBoundsAccumulator { + inner: Mutex::new(SharedBoundsState { + bounds: HashMap::new(), + filter_optimized: false, + completed_count: 0, + }), + total_partitions: 2, + dynamic_filter: Arc::new(DynamicFilterPhysicalExpr::new( + on_right.clone(), + Arc::new(datafusion_physical_expr::expressions::Literal::new( + ScalarValue::Boolean(Some(true)), + )), + )), + on_right, + }; + + // Test the optimized filter creation + let bounds = HashMap::from([(0, bounds1.clone()), (1, bounds2.clone())]); + let filter = accumulator + .create_optimized_filter_from_partition_bounds(&bounds)? + .unwrap(); + + // Verify the filter is a CaseExpr (indirectly by checking it doesn't panic and has reasonable behavior) + let test_batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Int32Array::from(vec![5, 25, 100])), // col1 values + Arc::new(Int32Array::from(vec![10, 30, 200])), // col2 values + ], + )?; + + let result = filter.evaluate(&test_batch)?; + let result_array = result.into_array(test_batch.num_rows())?; + let result_array = as_boolean_array(&result_array)?; + + // Should have 3 results + assert_eq!(result_array.len(), 3); + + // The exact results depend on hash values, but we should get boolean results + for i in 0..3 { + // Just verify we get boolean values (true/false, not null for this simple case) + assert!( + !result_array.is_null(i), + "Result should not be null at index {i}" + ); + } + + Ok(()) + } + + #[test] + fn test_empty_bounds() -> Result<()> { + let on_right = vec![Arc::new(datafusion_physical_expr::expressions::Column::new( + "col1", 0, + )) as Arc]; + + let accumulator = SharedBoundsAccumulator { + inner: Mutex::new(SharedBoundsState { + bounds: HashMap::new(), + filter_optimized: false, + completed_count: 0, + }), + total_partitions: 2, + dynamic_filter: Arc::new(DynamicFilterPhysicalExpr::new( + on_right.clone(), + Arc::new(datafusion_physical_expr::expressions::Literal::new( + ScalarValue::Boolean(Some(true)), + )), + )), + on_right, + }; + + // Test with empty bounds + let res = + accumulator.create_optimized_filter_from_partition_bounds(&HashMap::new())?; + assert!(res.is_none(), "Expected None for empty bounds"); + + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index adc00d9fe75e..14481704ac54 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -121,8 +121,6 @@ impl BuildSide { pub(super) enum HashJoinStreamState { /// Initial state for HashJoinStream indicating that build-side data not collected yet WaitBuildSide, - /// Waiting for bounds to be reported by all partitions - WaitPartitionBoundsReport, /// Indicates that build-side has been collected, and stream is ready for fetching probe-side FetchProbeBatch, /// Indicates that non-empty batch has been fetched from probe-side, and is ready to be processed @@ -208,9 +206,6 @@ pub(super) struct HashJoinStream { right_side_ordered: bool, /// Shared bounds accumulator for coordinating dynamic filter updates (optional) bounds_accumulator: Option>, - /// Optional future to signal when bounds have been reported by all partitions - /// and the dynamic filter has been updated - bounds_waiter: Option>, /// Partitioning mode to use mode: PartitionMode, @@ -335,7 +330,6 @@ impl HashJoinStream { hashes_buffer, right_side_ordered, bounds_accumulator, - bounds_waiter: None, mode, } } @@ -351,9 +345,6 @@ impl HashJoinStream { HashJoinStreamState::WaitBuildSide => { handle_state!(ready!(self.collect_build_side(cx))) } - HashJoinStreamState::WaitPartitionBoundsReport => { - handle_state!(ready!(self.wait_for_partition_bounds_report(cx))) - } HashJoinStreamState::FetchProbeBatch => { handle_state!(ready!(self.fetch_probe_batch(cx))) } @@ -370,26 +361,6 @@ impl HashJoinStream { } } - /// Optional step to wait until bounds have been reported by all partitions. - /// This state is only entered if a bounds accumulator is present. - /// - /// ## Why wait? - /// - /// The dynamic filter is only built once all partitions have reported their bounds. - /// If we do not wait here, the probe-side scan may start before the filter is ready. - /// This can lead to the probe-side scan missing the opportunity to apply the filter - /// and skip reading unnecessary data. - fn wait_for_partition_bounds_report( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> Poll>>> { - if let Some(ref mut fut) = self.bounds_waiter { - ready!(fut.get_shared(cx))?; - } - self.state = HashJoinStreamState::FetchProbeBatch; - Poll::Ready(Ok(StatefulStreamResult::Continue)) - } - /// Collects build-side data by polling `OnceFut` future from initialized build-side /// /// Updates build-side to `Ready`, and state to `FetchProbeSide` @@ -408,28 +379,22 @@ impl HashJoinStream { // Handle dynamic filter bounds accumulation // - // Dynamic filter coordination between partitions: - // Report bounds to the accumulator which will handle synchronization and filter updates + // Progressive dynamic filter coordination: + // Report bounds immediately to inject progressive filters if let Some(ref bounds_accumulator) = self.bounds_accumulator { - let bounds_accumulator = Arc::clone(bounds_accumulator); - let left_side_partition_id = match self.mode { PartitionMode::Partitioned => self.partition, PartitionMode::CollectLeft => 0, PartitionMode::Auto => unreachable!("PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!"), }; - let left_data_bounds = left_data.bounds.clone(); - self.bounds_waiter = Some(OnceFut::new(async move { - bounds_accumulator - .report_partition_bounds(left_side_partition_id, left_data_bounds) - .await - })); - self.state = HashJoinStreamState::WaitPartitionBoundsReport; - } else { - self.state = HashJoinStreamState::FetchProbeBatch; + bounds_accumulator.report_partition_bounds( + left_side_partition_id, + left_data.bounds.clone(), + )?; } + self.state = HashJoinStreamState::FetchProbeBatch; self.build_side = BuildSide::Ready(BuildSideReadyState { left_data }); Poll::Ready(Ok(StatefulStreamResult::Continue)) } @@ -533,7 +498,7 @@ impl HashJoinStream { if need_produce_result_in_final(self.join_type) { let mut bitmap = build_side.left_data.visited_indices_bitmap().lock(); left_indices.iter().flatten().for_each(|x| { - bitmap.set_bit(x as usize, true); + bitmap.set_bit(usize::try_from(x).expect("fits in a usize"), true); }); } diff --git a/datafusion/physical-plan/src/repartition/hash.rs b/datafusion/physical-plan/src/repartition/hash.rs new file mode 100644 index 000000000000..81299f999db2 --- /dev/null +++ b/datafusion/physical-plan/src/repartition/hash.rs @@ -0,0 +1,154 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Hash utilities for repartitioning data + +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field}; +use datafusion_common::{config::ConfigOptions, Result}; +use datafusion_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl}; + +use ahash::RandomState; +use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef, ScalarFunctionExpr}; + +/// Internal hash function used for repartitioning inputs. +/// This is used for partitioned HashJoinExec and partitioned GroupByExec. +/// Currently we use AHash with fixed seeds, but this is subject to change. +/// We make no promises about stability of this function across versions. +/// Currently this is *not* stable across machines since AHash is not stable across platforms, +/// thus this should only be used in a single node context. +#[derive(Debug)] +pub(crate) struct RepartitionHash { + signature: datafusion_expr::Signature, + /// RandomState for consistent hashing - using the same seed as hash joins + random_state: RandomState, +} + +impl PartialEq for RepartitionHash { + fn eq(&self, other: &Self) -> bool { + // RandomState doesn't implement PartialEq, so we just compare signatures + self.signature == other.signature + } +} + +impl Eq for RepartitionHash {} + +impl std::hash::Hash for RepartitionHash { + fn hash(&self, state: &mut H) { + // Only hash the signature since RandomState doesn't implement Hash + self.signature.hash(state); + } +} + +impl RepartitionHash { + /// Create a new RepartitionHash + pub(crate) fn new() -> Self { + Self { + signature: datafusion_expr::Signature::one_of( + vec![datafusion_expr::TypeSignature::VariadicAny], + datafusion_expr::Volatility::Immutable, + ), + random_state: REPARTITION_RANDOM_STATE, + } + } +} + +pub(crate) fn repartition_hash( + args: Vec, +) -> Result> { + let hash = ScalarUDF::new_from_impl(RepartitionHash::new()); + let name = hash.name().to_string(); + Ok(Arc::new(ScalarFunctionExpr::new( + &name, + Arc::new(hash), + args, + Arc::new(Field::new(&name, DataType::UInt64, false)), + Arc::new(ConfigOptions::default()), + ))) +} + +impl ScalarUDFImpl for RepartitionHash { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "repartition_hash" + } + + fn signature(&self) -> &datafusion_expr::Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + // Always return UInt64Array regardless of input types + Ok(DataType::UInt64) + } + + fn invoke_with_args( + &self, + args: datafusion_expr::ScalarFunctionArgs, + ) -> Result { + use arrow::array::{Array, UInt64Array}; + use datafusion_common::hash_utils::create_hashes; + use std::sync::Arc; + + if args.args.is_empty() { + return datafusion_common::plan_err!("hash requires at least one argument"); + } + + // Convert all arguments to arrays + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + + // Check that all arrays have the same length + let array_len = arrays[0].len(); + for (i, array) in arrays.iter().enumerate() { + if array.len() != array_len { + return datafusion_common::plan_err!( + "All input arrays must have the same length. Array 0 has length {}, but array {} has length {}", + array_len, i, array.len() + ); + } + } + + // If no rows, return an empty UInt64Array + if array_len == 0 { + return Ok(ColumnarValue::Array(Arc::new(UInt64Array::from( + Vec::::new(), + )))); + } + + // Create hash buffer and compute hashes using DataFusion's internal algorithm + let mut hashes_buffer = vec![0u64; array_len]; + create_hashes(&arrays, &self.random_state, &mut hashes_buffer)?; + + // Return the hash values as a UInt64Array + Ok(ColumnarValue::Array(Arc::new(UInt64Array::from( + hashes_buffer, + )))) + } + + fn documentation(&self) -> Option<&datafusion_expr::Documentation> { + None + } +} + +/// RandomState used by RepartitionExec for consistent hash partitioning +/// This must match the seeds used in RepartitionExec to ensure our hash-based +/// filter expressions compute the same partition assignments as the actual partitioning +const REPARTITION_RANDOM_STATE: RandomState = RandomState::with_seeds(0, 0, 0, 0); diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 22bc1b5cf924..6b49989e3063 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -31,19 +31,20 @@ use super::{ DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream, }; use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType}; -use crate::hash_utils::create_hashes; use crate::metrics::BaselineMetrics; use crate::projection::{all_columns, make_with_child, update_expr, ProjectionExec}; use crate::repartition::distributor_channels::{ channels, partition_aware_channels, DistributionReceiver, DistributionSender, }; +use crate::repartition::hash::RepartitionHash; use crate::sorts::streaming_merge::StreamingMergeBuilder; use crate::stream::RecordBatchStreamAdapter; use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics}; -use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions}; +use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions, UInt64Array}; use arrow::compute::take_arrays; use arrow::datatypes::{SchemaRef, UInt32Type}; +use arrow_schema::{DataType, Field}; use datafusion_common::config::ConfigOptions; use datafusion_common::stats::Precision; use datafusion_common::utils::transpose; @@ -52,7 +53,8 @@ use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr}; +use datafusion_expr::{ColumnarValue, ScalarUDF}; +use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr, ScalarFunctionExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; use crate::filter_pushdown::{ @@ -65,6 +67,7 @@ use log::trace; use parking_lot::Mutex; mod distributor_channels; +pub mod hash; type MaybeBatch = Option>; type InputPartitionsToCurrentPartitionSender = Vec>; @@ -149,6 +152,7 @@ impl RepartitionExecState { Ok(()) } + #[expect(clippy::too_many_arguments)] fn consume_input_streams( &mut self, input: Arc, @@ -157,6 +161,7 @@ impl RepartitionExecState { preserve_order: bool, name: String, context: Arc, + hash: ScalarUDF, ) -> Result<&mut ConsumingInputStreamsState> { let streams_and_metrics = match self { RepartitionExecState::NotInitialized => { @@ -227,6 +232,7 @@ impl RepartitionExecState { txs.clone(), partitioning.clone(), metrics, + hash.clone(), )); // In a separate task, wait for each input to be done @@ -258,10 +264,9 @@ pub struct BatchPartitioner { enum BatchPartitionerState { Hash { - random_state: ahash::RandomState, - exprs: Vec>, + hash: Arc, num_partitions: usize, - hash_buffer: Vec, + exprs: Vec>, }, RoundRobin { num_partitions: usize, @@ -281,19 +286,41 @@ impl BatchPartitioner { next_idx: 0, } } - Partitioning::Hash(exprs, num_partitions) => BatchPartitionerState::Hash { - exprs, - num_partitions, - // Use fixed random hash - random_state: ahash::RandomState::with_seeds(0, 0, 0, 0), - hash_buffer: vec![], - }, + Partitioning::Hash(exprs, num_partitions) => { + let hash = ScalarUDF::new_from_impl(RepartitionHash::new()); + let name = hash.name().to_string(); + BatchPartitionerState::Hash { + hash: Arc::new(ScalarFunctionExpr::new( + &name, + Arc::new(hash), + exprs.clone(), + Arc::new(Field::new(&name, DataType::UInt64, false)), + Arc::new(ConfigOptions::default()), + )), + num_partitions, + exprs, + } + } other => return not_impl_err!("Unsupported repartitioning scheme {other:?}"), }; - Ok(Self { state, timer }) } + /// Set the hash function to use for hash partitioning. + pub(crate) fn with_hash_function(mut self, hash: ScalarUDF) -> Self { + if let BatchPartitionerState::Hash { hash: h, exprs, .. } = &mut self.state { + let name = hash.name().to_string(); + *h = Arc::new(ScalarFunctionExpr::new( + &name, + Arc::new(hash), + exprs.clone(), + Arc::new(Field::new(&name, DataType::UInt64, false)), + Arc::new(ConfigOptions::default()), + )); + } + self + } + /// Partition the provided [`RecordBatch`] into one or more partitioned [`RecordBatch`] /// based on the [`Partitioning`] specified on construction /// @@ -333,30 +360,36 @@ impl BatchPartitioner { Box::new(std::iter::once(Ok((idx, batch)))) } BatchPartitionerState::Hash { - random_state, - exprs, + hash, num_partitions: partitions, - hash_buffer, + .. } => { // Tracking time required for distributing indexes across output partitions let timer = self.timer.timer(); - let arrays = exprs - .iter() - .map(|expr| expr.evaluate(&batch)?.into_array(batch.num_rows())) - .collect::>>()?; - - hash_buffer.clear(); - hash_buffer.resize(batch.num_rows(), 0); - - create_hashes(&arrays, random_state, hash_buffer)?; + let ColumnarValue::Array(hashes) = hash.evaluate(&batch)? else { + return internal_err!( + "Hash partitioning expression did not return an array" + ); + }; + let Some(hashes) = hashes.as_any().downcast_ref::() + else { + return internal_err!( + "Hash partitioning expression did not return a UInt64Array" + ); + }; let mut indices: Vec<_> = (0..*partitions) .map(|_| Vec::with_capacity(batch.num_rows())) .collect(); - for (index, hash) in hash_buffer.iter().enumerate() { - indices[(*hash % *partitions as u64) as usize].push(index as u32); + for (index, hash) in hashes.iter().enumerate() { + let Some(hash) = hash else { + return internal_err!( + "Hash partitioning expression returned null value" + ); + }; + indices[(hash % *partitions as u64) as usize].push(index as u32); } // Finished building index-arrays for output partitions @@ -486,6 +519,8 @@ pub struct RepartitionExec { preserve_order: bool, /// Cache holding plan properties like equivalences, output partitioning etc. cache: PlanProperties, + /// Hash function used for hash partitioning + hash: ScalarUDF, } #[derive(Debug, Clone)] @@ -624,7 +659,8 @@ impl ExecutionPlan for RepartitionExec { let mut repartition = RepartitionExec::try_new( children.swap_remove(0), self.partitioning().clone(), - )?; + )? + .with_hash_function(self.hash.clone()); if self.preserve_order { repartition = repartition.with_preserve_order(); } @@ -657,6 +693,7 @@ impl ExecutionPlan for RepartitionExec { let name = self.name().to_owned(); let schema = self.schema(); let schema_captured = Arc::clone(&schema); + let hash = self.hash.clone(); // Get existing ordering to use for merging let sort_exprs = self.sort_exprs().cloned(); @@ -685,6 +722,7 @@ impl ExecutionPlan for RepartitionExec { preserve_order, name.clone(), Arc::clone(&context), + hash.clone(), )?; // now return stream for the specified *output* partition which will @@ -877,9 +915,16 @@ impl RepartitionExec { metrics: ExecutionPlanMetricsSet::new(), preserve_order, cache, + hash: ScalarUDF::new_from_impl(RepartitionHash::new()), }) } + /// Set a custom hash function to use for hash partitioning. + pub fn with_hash_function(mut self, hash: ScalarUDF) -> Self { + self.hash = hash; + self + } + fn maintains_input_order_helper( input: &Arc, preserve_order: bool, @@ -962,9 +1007,11 @@ impl RepartitionExec { >, partitioning: Partitioning, metrics: RepartitionMetrics, + hash: ScalarUDF, ) -> Result<()> { let mut partitioner = - BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())?; + BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())? + .with_hash_function(hash); // While there are still outputs to send to, keep pulling inputs let mut batches_until_yield = partitioner.num_partitions();