From fb0a15340bf2f62bbc431998939bb7822bbce1d2 Mon Sep 17 00:00:00 2001 From: Rohan Krishnaswamy Date: Fri, 5 Sep 2025 20:59:58 -0700 Subject: [PATCH 1/6] print check to observe indeterminate predicate propagation to file opener --- .../core/tests/physical_optimizer/filter_pushdown/util.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs index 7d0020b2e937..6d2a84b991e2 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs @@ -53,6 +53,7 @@ pub struct TestOpener { batch_size: Option, schema: Option, projection: Option>, + predicate: Option>, } impl FileOpener for TestOpener { @@ -61,6 +62,12 @@ impl FileOpener for TestOpener { _file_meta: FileMeta, _file: PartitionedFile, ) -> Result { + if let Some(predicate) = &self.predicate { + println!( + "Predicate when calling open: {}", + fmt_sql(predicate.as_ref()) + ); + } let mut batches = self.batches.clone(); if let Some(batch_size) = self.batch_size { let batch = concat_batches(&batches[0].schema(), &batches)?; @@ -133,6 +140,7 @@ impl FileSource for TestSource { batch_size: self.batch_size, schema: self.schema.clone(), projection: self.projection.clone(), + predicate: self.predicate.clone(), }) } From 02053aa524343dca9e426dece11f24b5b5c59770 Mon Sep 17 00:00:00 2001 From: Rohan Krishnaswamy Date: Fri, 5 Sep 2025 21:01:40 -0700 Subject: [PATCH 2/6] wait until partition bounds are reported by all partitions before polling right side --- .../src/joins/hash_join/shared_bounds.rs | 129 ++++++++++++++---- .../src/joins/hash_join/stream.rs | 31 ++++- 2 files changed, 131 insertions(+), 29 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs index 73e65be68683..36bcdbc26c81 100644 --- a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs +++ b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs @@ -19,6 +19,8 @@ // TODO: include the link to the Dynamic Filter blog post. use std::fmt; +use std::future::Future; +use std::sync::atomic::AtomicUsize; use std::sync::Arc; use crate::joins::PartitionMode; @@ -30,6 +32,7 @@ use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{lit, BinaryExpr, DynamicFilterPhysicalExpr}; use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; +use futures::task::AtomicWaker; use itertools::Itertools; use parking_lot::Mutex; @@ -119,7 +122,9 @@ struct SharedBoundsState { /// Each element represents the column bounds computed by one partition. bounds: Vec, /// Number of partitions that have reported completion. - completed_partitions: usize, + completed_partitions: Arc, + /// Cached wakers to wake when all partitions are complete + wakers: Vec>, } impl SharedBoundsAccumulator { @@ -170,7 +175,8 @@ impl SharedBoundsAccumulator { Self { inner: Mutex::new(SharedBoundsState { bounds: Vec::with_capacity(expected_calls), - completed_partitions: 0, + completed_partitions: Arc::new(AtomicUsize::new(0)), + wakers: Vec::new(), }), total_partitions: expected_calls, dynamic_filter, @@ -253,39 +259,69 @@ impl SharedBoundsAccumulator { /// bounds from the current partition, increments the completion counter, and when all /// partitions have reported, creates an OR'd filter from individual partition bounds. 
/// + /// It returns a [`BoundsWaiter`] future that can be awaited to ensure the filter has been + /// updated before proceeding. This is important to delay probe-side scans until the filter + /// is ready. + /// /// # Arguments + /// * `partition` - The partition identifier reporting its bounds /// * `partition_bounds` - The bounds computed by this partition (if any) /// /// # Returns - /// * `Result<()>` - Ok if successful, Err if filter update failed + /// * `Result>` - Ok if successful, Err if filter update failed pub(crate) fn report_partition_bounds( &self, partition: usize, partition_bounds: Option>, - ) -> Result<()> { - let mut inner = self.inner.lock(); + ) -> Result> { + // Scope for lock to avoid holding it across await points + let maybe_waiter = { + let mut inner = self.inner.lock(); - // Store bounds in the accumulator - this runs once per partition - if let Some(bounds) = partition_bounds { - // Only push actual bounds if they exist - inner.bounds.push(PartitionBounds::new(partition, bounds)); - } + // Store bounds in the accumulator - this runs once per partition + if let Some(bounds) = partition_bounds { + // Only push actual bounds if they exist + inner.bounds.push(PartitionBounds::new(partition, bounds)); + } - // Increment the completion counter - // Even empty partitions must report to ensure proper termination - inner.completed_partitions += 1; - let completed = inner.completed_partitions; - let total_partitions = self.total_partitions; - - // Critical synchronization point: Only update the filter when ALL partitions are complete - // Troubleshooting: If you see "completed > total_partitions", check partition - // count calculation in new_from_partition_mode() - it may not match actual execution calls - if completed == total_partitions && !inner.bounds.is_empty() { - let filter_expr = self.create_filter_from_partition_bounds(&inner.bounds)?; - self.dynamic_filter.update(filter_expr)?; - } + // Increment the completion counter + // Even empty partitions must report to ensure proper termination + inner + .completed_partitions + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); + let completed = inner + .completed_partitions + .load(std::sync::atomic::Ordering::SeqCst); + let total_partitions = self.total_partitions; + + // Critical synchronization point: Only update the filter when ALL partitions are complete + // Troubleshooting: If you see "completed > total_partitions", check partition + // count calculation in new_from_partition_mode() - it may not match actual execution calls + if completed == total_partitions { + if !inner.bounds.is_empty() { + let filter_expr = + self.create_filter_from_partition_bounds(&inner.bounds)?; + self.dynamic_filter.update(filter_expr)?; + } + + // Notify any waiters that the filter is ready + for waker in inner.wakers.drain(..) { + waker.wake(); + } + + None + } else { + let waker = Arc::new(AtomicWaker::new()); + inner.wakers.push(Arc::clone(&waker)); + Some(BoundsWaiter::new( + total_partitions, + Arc::clone(&inner.completed_partitions), + waker, + )) + } + }; - Ok(()) + Ok(maybe_waiter) } } @@ -294,3 +330,48 @@ impl fmt::Debug for SharedBoundsAccumulator { write!(f, "SharedBoundsAccumulator") } } + +/// Utility future to wait until all partitions have reported completion +/// and the dynamic filter has been updated. 
+#[derive(Clone)] +pub(crate) struct BoundsWaiter { + waker: Arc, + total: usize, + completed: Arc, +} + +impl BoundsWaiter { + pub fn new( + total: usize, + completed: Arc, + waker: Arc, + ) -> Self { + Self { + waker, + total, + completed, + } + } +} + +impl Future for BoundsWaiter { + type Output = (); + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + // Quick check to avoid registration if already complete + if self.completed.load(std::sync::atomic::Ordering::Relaxed) >= self.total { + return std::task::Poll::Ready(()); + } + + self.waker.register(cx.waker()); + + if self.completed.load(std::sync::atomic::Ordering::Relaxed) >= self.total { + std::task::Poll::Ready(()) + } else { + std::task::Poll::Pending + } + } +} diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index d368a9cf8ee2..4c6759ed9b71 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use std::task::Poll; use crate::joins::hash_join::exec::JoinLeftData; -use crate::joins::hash_join::shared_bounds::SharedBoundsAccumulator; +use crate::joins::hash_join::shared_bounds::{BoundsWaiter, SharedBoundsAccumulator}; use crate::joins::utils::{ equal_rows_arr, get_final_indices_from_shared_bitmap, OnceFut, }; @@ -50,7 +50,7 @@ use datafusion_common::{ use datafusion_physical_expr::PhysicalExprRef; use ahash::RandomState; -use futures::{ready, Stream, StreamExt}; +use futures::{ready, FutureExt, Stream, StreamExt}; /// Represents build-side of hash join. pub(super) enum BuildSide { @@ -120,6 +120,8 @@ impl BuildSide { pub(super) enum HashJoinStreamState { /// Initial state for HashJoinStream indicating that build-side data not collected yet WaitBuildSide, + /// Waiting for bounds to be reported by all partitions + WaitPartitionBoundsReport, /// Indicates that build-side has been collected, and stream is ready for fetching probe-side FetchProbeBatch, /// Indicates that non-empty batch has been fetched from probe-side, and is ready to be processed @@ -205,6 +207,9 @@ pub(super) struct HashJoinStream { right_side_ordered: bool, /// Shared bounds accumulator for coordinating dynamic filter updates (optional) bounds_accumulator: Option>, + /// Optional future to signal when bounds have been reported by all partitions + /// and the dynamic filter has been updated + bounds_waiter: Option, } impl RecordBatchStream for HashJoinStream { @@ -325,6 +330,7 @@ impl HashJoinStream { hashes_buffer, right_side_ordered, bounds_accumulator, + bounds_waiter: None, } } @@ -339,6 +345,9 @@ impl HashJoinStream { HashJoinStreamState::WaitBuildSide => { handle_state!(ready!(self.collect_build_side(cx))) } + HashJoinStreamState::WaitPartitionBoundsReport => { + handle_state!(ready!(self.wait_for_partition_bounds_report(cx))) + } HashJoinStreamState::FetchProbeBatch => { handle_state!(ready!(self.fetch_probe_batch(cx))) } @@ -355,6 +364,17 @@ impl HashJoinStream { } } + fn wait_for_partition_bounds_report( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + if let Some(ref mut fut) = self.bounds_waiter { + ready!(fut.poll_unpin(cx)); + } + self.state = HashJoinStreamState::FetchProbeBatch; + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + /// Collects build-side data by polling `OnceFut` future from initialized build-side /// /// Updates build-side to `Ready`, and state to `FetchProbeSide` @@ 
-376,13 +396,14 @@ impl HashJoinStream { // Dynamic filter coordination between partitions: // Report bounds to the accumulator which will handle synchronization and filter updates if let Some(ref bounds_accumulator) = self.bounds_accumulator { - bounds_accumulator + self.bounds_waiter = bounds_accumulator .report_partition_bounds(self.partition, left_data.bounds.clone())?; + self.state = HashJoinStreamState::WaitPartitionBoundsReport; + } else { + self.state = HashJoinStreamState::FetchProbeBatch; } - self.state = HashJoinStreamState::FetchProbeBatch; self.build_side = BuildSide::Ready(BuildSideReadyState { left_data }); - Poll::Ready(Ok(StatefulStreamResult::Continue)) } From ef2c7732e373bdd9024b05f2ed645f47e7e294ba Mon Sep 17 00:00:00 2001 From: Rohan Krishnaswamy Date: Sat, 6 Sep 2025 17:54:55 -0700 Subject: [PATCH 3/6] cleanup + clarifying docs --- .../physical-plan/src/joins/hash_join/shared_bounds.rs | 8 +++----- datafusion/physical-plan/src/joins/hash_join/stream.rs | 9 +++++++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs index 36bcdbc26c81..5c3108f8f9bd 100644 --- a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs +++ b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs @@ -361,11 +361,9 @@ impl Future for BoundsWaiter { self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll { - // Quick check to avoid registration if already complete - if self.completed.load(std::sync::atomic::Ordering::Relaxed) >= self.total { - return std::task::Poll::Ready(()); - } - + // Ensure we register the waker first so we do not fall victim + // to lost wakeups. This is a no-op if our current waker and our + // stored waker are the same. self.waker.register(cx.waker()); if self.completed.load(std::sync::atomic::Ordering::Relaxed) >= self.total { diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index 4c6759ed9b71..4bf0c53a927a 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -364,6 +364,15 @@ impl HashJoinStream { } } + /// Optional step to wait until bounds have been reported by all partitions. + /// This state is only entered if a bounds accumulator is present. + /// + /// ## Why wait? + /// + /// The dynamic filter is only built once all partitions have reported their bounds. + /// If we do not wait here, the probe-side scan may start before the filter is ready. + /// This can lead to the probe-side scan missing the opportunity to apply the filter + /// and skip reading unnecessary data. 
fn wait_for_partition_bounds_report( &mut self, cx: &mut std::task::Context<'_>, From 35da826de3ce5d241ef1d74fed029ad04039a61c Mon Sep 17 00:00:00 2001 From: Rohan Krishnaswamy Date: Sun, 7 Sep 2025 19:56:11 -0700 Subject: [PATCH 4/6] add check to verify predicate is always present when initiating probe side scan --- .../tests/physical_optimizer/filter_pushdown/mod.rs | 9 ++++++++- .../physical_optimizer/filter_pushdown/util.rs | 13 +++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs index eaf3be2b86ed..691216020db6 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs @@ -1065,7 +1065,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { ]; let probe_repartition = Arc::new( RepartitionExec::try_new( - probe_scan, + Arc::clone(&probe_scan), Partitioning::Hash(probe_hash_exprs, partition_count), ) .unwrap(), @@ -1199,6 +1199,13 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { let result = format!("{}", pretty_format_batches(&batches).unwrap()); + let probe_scan_metrics = probe_scan.metrics().unwrap(); + + // The probe side had 4 rows, but after applying the dynamic filter only 2 rows should remain. + // The number of output rows from the probe side scan should stay consistent across executions. + // Issue: https://github.com/apache/datafusion/issues/17451 + assert_eq!(probe_scan_metrics.output_rows().unwrap(), 2); + insta::assert_snapshot!( result, @r" diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs index 6d2a84b991e2..2fe705b14921 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs @@ -27,6 +27,7 @@ use datafusion_datasource::{ }; use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::filter::batch_filter; use datafusion_physical_plan::filter_pushdown::{FilterPushdownPhase, PushedDown}; use datafusion_physical_plan::{ displayable, @@ -62,12 +63,6 @@ impl FileOpener for TestOpener { _file_meta: FileMeta, _file: PartitionedFile, ) -> Result { - if let Some(predicate) = &self.predicate { - println!( - "Predicate when calling open: {}", - fmt_sql(predicate.as_ref()) - ); - } let mut batches = self.batches.clone(); if let Some(batch_size) = self.batch_size { let batch = concat_batches(&batches[0].schema(), &batches)?; @@ -84,6 +79,12 @@ impl FileOpener for TestOpener { let (mapper, projection) = factory.map_schema(&batches[0].schema()).unwrap(); let mut new_batches = Vec::new(); for batch in batches { + let batch = if let Some(predicate) = &self.predicate { + batch_filter(&batch, predicate)? 
+ } else { + batch + }; + let batch = batch.project(&projection).unwrap(); let batch = mapper.map_batch(batch).unwrap(); new_batches.push(batch); From c24ef7967aed41f96f441382b669b633d9a99a22 Mon Sep 17 00:00:00 2001 From: Rohan Krishnaswamy Date: Sun, 7 Sep 2025 20:38:37 -0700 Subject: [PATCH 5/6] update snap --- datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs index 691216020db6..118b860c5b18 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs @@ -1362,7 +1362,7 @@ async fn test_nested_hashjoin_dynamic_filter_pushdown() { - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, x], file_type=test, pushdown_supported=true - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, d@0)] - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[b, c, y], file_type=test, pushdown_supported=true, predicate=DynamicFilterPhysicalExpr [ b@0 >= aa AND b@0 <= ab ] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, z], file_type=test, pushdown_supported=true, predicate=DynamicFilterPhysicalExpr [ d@0 >= ca AND d@0 <= ce ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, z], file_type=test, pushdown_supported=true, predicate=DynamicFilterPhysicalExpr [ d@0 >= ca AND d@0 <= cb ] " ); } From 2a41cede553b0e3e9de34b1d8ac5faabe6649af4 Mon Sep 17 00:00:00 2001 From: Rohan Krishnaswamy Date: Tue, 9 Sep 2025 11:18:27 -0700 Subject: [PATCH 6/6] prefer barrier --- .../src/joins/hash_join/shared_bounds.rs | 146 +++++------------- .../src/joins/hash_join/stream.rs | 18 ++- 2 files changed, 47 insertions(+), 117 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs index 5c3108f8f9bd..40dc4ac2e5d1 100644 --- a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs +++ b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs @@ -19,8 +19,6 @@ // TODO: include the link to the Dynamic Filter blog post. use std::fmt; -use std::future::Future; -use std::sync::atomic::AtomicUsize; use std::sync::Arc; use crate::joins::PartitionMode; @@ -32,9 +30,9 @@ use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{lit, BinaryExpr, DynamicFilterPhysicalExpr}; use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; -use futures::task::AtomicWaker; use itertools::Itertools; use parking_lot::Mutex; +use tokio::sync::Barrier; /// Represents the minimum and maximum values for a specific column. /// Used in dynamic filter pushdown to establish value boundaries. @@ -89,9 +87,9 @@ impl PartitionBounds { /// ## Synchronization Strategy /// /// 1. Each partition computes bounds from its build-side data -/// 2. Bounds are stored in the shared HashMap (indexed by partition_id) -/// 3. A counter tracks how many partitions have reported their bounds -/// 4. When the last partition reports (completed == total), bounds are merged and filter is updated +/// 2. Bounds are stored in the shared vector +/// 3. A barrier tracks how many partitions have reported their bounds +/// 4. 
When the last partition reports, bounds are merged and the filter is updated exactly once /// /// ## Partition Counting /// @@ -106,10 +104,7 @@ impl PartitionBounds { pub(crate) struct SharedBoundsAccumulator { /// Shared state protected by a single mutex to avoid ordering concerns inner: Mutex, - /// Total number of partitions. - /// Need to know this so that we can update the dynamic filter once we are done - /// building *all* of the hash tables. - total_partitions: usize, + barrier: Barrier, /// Dynamic filter for pushdown to probe side dynamic_filter: Arc, /// Right side join expressions needed for creating filter bounds @@ -121,10 +116,6 @@ struct SharedBoundsState { /// Bounds from completed partitions. /// Each element represents the column bounds computed by one partition. bounds: Vec, - /// Number of partitions that have reported completion. - completed_partitions: Arc, - /// Cached wakers to wake when all partitions are complete - wakers: Vec>, } impl SharedBoundsAccumulator { @@ -175,10 +166,8 @@ impl SharedBoundsAccumulator { Self { inner: Mutex::new(SharedBoundsState { bounds: Vec::with_capacity(expected_calls), - completed_partitions: Arc::new(AtomicUsize::new(0)), - wakers: Vec::new(), }), - total_partitions: expected_calls, + barrier: Barrier::new(expected_calls), dynamic_filter, on_right, } @@ -259,69 +248,47 @@ impl SharedBoundsAccumulator { /// bounds from the current partition, increments the completion counter, and when all /// partitions have reported, creates an OR'd filter from individual partition bounds. /// - /// It returns a [`BoundsWaiter`] future that can be awaited to ensure the filter has been - /// updated before proceeding. This is important to delay probe-side scans until the filter - /// is ready. + /// This method is async and uses a [`tokio::sync::Barrier`] to wait for all partitions + /// to report their bounds. Once that occurs, the method will resolve for all callers and the + /// dynamic filter will be updated exactly once. + /// + /// # Note + /// + /// As barriers are reusable, it is likely an error to call this method more times than the + /// total number of partitions - as it can lead to pending futures that never resolve. We rely + /// on correct usage from the caller rather than imposing additional checks here. If this is a concern, + /// consider making the resulting future shared so the ready result can be reused. 
/// /// # Arguments /// * `partition` - The partition identifier reporting its bounds /// * `partition_bounds` - The bounds computed by this partition (if any) /// /// # Returns - /// * `Result>` - Ok if successful, Err if filter update failed - pub(crate) fn report_partition_bounds( + /// * `Result<()>` - Ok if successful, Err if filter update failed + pub(crate) async fn report_partition_bounds( &self, partition: usize, partition_bounds: Option>, - ) -> Result> { - // Scope for lock to avoid holding it across await points - let maybe_waiter = { - let mut inner = self.inner.lock(); - - // Store bounds in the accumulator - this runs once per partition - if let Some(bounds) = partition_bounds { - // Only push actual bounds if they exist - inner.bounds.push(PartitionBounds::new(partition, bounds)); - } - - // Increment the completion counter - // Even empty partitions must report to ensure proper termination - inner - .completed_partitions - .fetch_add(1, std::sync::atomic::Ordering::SeqCst); - let completed = inner - .completed_partitions - .load(std::sync::atomic::Ordering::SeqCst); - let total_partitions = self.total_partitions; - - // Critical synchronization point: Only update the filter when ALL partitions are complete - // Troubleshooting: If you see "completed > total_partitions", check partition - // count calculation in new_from_partition_mode() - it may not match actual execution calls - if completed == total_partitions { - if !inner.bounds.is_empty() { - let filter_expr = - self.create_filter_from_partition_bounds(&inner.bounds)?; - self.dynamic_filter.update(filter_expr)?; - } - - // Notify any waiters that the filter is ready - for waker in inner.wakers.drain(..) { - waker.wake(); - } + ) -> Result<()> { + // Store bounds in the accumulator - this runs once per partition + if let Some(bounds) = partition_bounds { + self.inner + .lock() + .bounds + .push(PartitionBounds::new(partition, bounds)); + } - None - } else { - let waker = Arc::new(AtomicWaker::new()); - inner.wakers.push(Arc::clone(&waker)); - Some(BoundsWaiter::new( - total_partitions, - Arc::clone(&inner.completed_partitions), - waker, - )) + if self.barrier.wait().await.is_leader() { + // All partitions have reported, so we can update the filter + let inner = self.inner.lock(); + if !inner.bounds.is_empty() { + let filter_expr = + self.create_filter_from_partition_bounds(&inner.bounds)?; + self.dynamic_filter.update(filter_expr)?; } - }; + } - Ok(maybe_waiter) + Ok(()) } } @@ -330,46 +297,3 @@ impl fmt::Debug for SharedBoundsAccumulator { write!(f, "SharedBoundsAccumulator") } } - -/// Utility future to wait until all partitions have reported completion -/// and the dynamic filter has been updated. -#[derive(Clone)] -pub(crate) struct BoundsWaiter { - waker: Arc, - total: usize, - completed: Arc, -} - -impl BoundsWaiter { - pub fn new( - total: usize, - completed: Arc, - waker: Arc, - ) -> Self { - Self { - waker, - total, - completed, - } - } -} - -impl Future for BoundsWaiter { - type Output = (); - - fn poll( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll { - // Ensure we register the waker first so we do not fall victim - // to lost wakeups. This is a no-op if our current waker and our - // stored waker are the same. 
- self.waker.register(cx.waker()); - - if self.completed.load(std::sync::atomic::Ordering::Relaxed) >= self.total { - std::task::Poll::Ready(()) - } else { - std::task::Poll::Pending - } - } -} diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index 4bf0c53a927a..4484eeabd326 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use std::task::Poll; use crate::joins::hash_join::exec::JoinLeftData; -use crate::joins::hash_join::shared_bounds::{BoundsWaiter, SharedBoundsAccumulator}; +use crate::joins::hash_join::shared_bounds::SharedBoundsAccumulator; use crate::joins::utils::{ equal_rows_arr, get_final_indices_from_shared_bitmap, OnceFut, }; @@ -50,7 +50,7 @@ use datafusion_common::{ use datafusion_physical_expr::PhysicalExprRef; use ahash::RandomState; -use futures::{ready, FutureExt, Stream, StreamExt}; +use futures::{ready, Stream, StreamExt}; /// Represents build-side of hash join. pub(super) enum BuildSide { @@ -209,7 +209,7 @@ pub(super) struct HashJoinStream { bounds_accumulator: Option>, /// Optional future to signal when bounds have been reported by all partitions /// and the dynamic filter has been updated - bounds_waiter: Option, + bounds_waiter: Option>, } impl RecordBatchStream for HashJoinStream { @@ -378,7 +378,7 @@ impl HashJoinStream { cx: &mut std::task::Context<'_>, ) -> Poll>>> { if let Some(ref mut fut) = self.bounds_waiter { - ready!(fut.poll_unpin(cx)); + ready!(fut.get_shared(cx))?; } self.state = HashJoinStreamState::FetchProbeBatch; Poll::Ready(Ok(StatefulStreamResult::Continue)) @@ -405,8 +405,14 @@ impl HashJoinStream { // Dynamic filter coordination between partitions: // Report bounds to the accumulator which will handle synchronization and filter updates if let Some(ref bounds_accumulator) = self.bounds_accumulator { - self.bounds_waiter = bounds_accumulator - .report_partition_bounds(self.partition, left_data.bounds.clone())?; + let bounds_accumulator = Arc::clone(bounds_accumulator); + let partition = self.partition; + let left_data_bounds = left_data.bounds.clone(); + self.bounds_waiter = Some(OnceFut::new(async move { + bounds_accumulator + .report_partition_bounds(partition, left_data_bounds) + .await + })); self.state = HashJoinStreamState::WaitPartitionBoundsReport; } else { self.state = HashJoinStreamState::FetchProbeBatch;
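
A minimal, self-contained sketch of the coordination pattern the final patch ("prefer barrier") settles on, using hypothetical names (`Bounds`, `Accumulator`, `report`, a single column `col`) rather than DataFusion's actual types: each partition pushes its bounds under a mutex and then waits on a `tokio::sync::Barrier`; the one caller for which `wait()` reports `is_leader()` builds the OR'd range predicate exactly once. This is only an illustration of the technique under those assumptions, not the PR's implementation, which wraps the async call in DataFusion's `OnceFut` and updates a `DynamicFilterPhysicalExpr`.

```rust
use std::sync::Arc;

use parking_lot::Mutex;
use tokio::sync::Barrier;

/// Per-partition min/max for a single (hypothetical) join key column.
#[derive(Debug, Clone, Copy)]
struct Bounds {
    min: i64,
    max: i64,
}

struct Accumulator {
    bounds: Mutex<Vec<Bounds>>,
    barrier: Barrier,
}

impl Accumulator {
    fn new(partitions: usize) -> Self {
        Self {
            bounds: Mutex::new(Vec::with_capacity(partitions)),
            barrier: Barrier::new(partitions),
        }
    }

    /// Called once per partition; resolves only after *all* partitions have reported.
    async fn report(&self, bounds: Option<Bounds>) {
        if let Some(b) = bounds {
            // Hold the lock only for the push, never across an await point.
            self.bounds.lock().push(b);
        }
        // Even empty partitions reach the barrier so every waiter is released.
        if self.barrier.wait().await.is_leader() {
            // Exactly one caller is the leader, so the combined filter is built once.
            let all = self.bounds.lock();
            let ranges: Vec<String> = all
                .iter()
                .map(|b| format!("(col >= {} AND col <= {})", b.min, b.max))
                .collect();
            if !ranges.is_empty() {
                println!("filter ready: {}", ranges.join(" OR "));
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let acc = Arc::new(Accumulator::new(3));
    let inputs = [
        Some(Bounds { min: 1, max: 5 }),
        Some(Bounds { min: 3, max: 9 }),
        None, // an empty partition still has to report
    ];
    let tasks: Vec<_> = inputs
        .into_iter()
        .map(|b| {
            let acc = Arc::clone(&acc);
            tokio::spawn(async move { acc.report(b).await })
        })
        .collect();
    for t in tasks {
        t.await.unwrap();
    }
}
```

Compared with the hand-rolled `AtomicWaker` future from the earlier commits, the barrier keeps the "wait for everyone, then act once" logic in a single well-tested primitive and the mutex is never held across an `await`; the trade-off called out in the patch's doc comment is that, because barriers are reusable, calling `report` more times than there are partitions becomes a silent hang rather than an error.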