11 changes: 9 additions & 2 deletions datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs
@@ -1065,7 +1065,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() {
];
let probe_repartition = Arc::new(
RepartitionExec::try_new(
probe_scan,
Arc::clone(&probe_scan),
Partitioning::Hash(probe_hash_exprs, partition_count),
)
.unwrap(),
@@ -1199,6 +1199,13 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() {

let result = format!("{}", pretty_format_batches(&batches).unwrap());

let probe_scan_metrics = probe_scan.metrics().unwrap();

// The probe side had 4 rows, but after applying the dynamic filter only 2 rows should remain.
// The number of output rows from the probe side scan should stay consistent across executions.
// Issue: https://github.com/apache/datafusion/issues/17451
assert_eq!(probe_scan_metrics.output_rows().unwrap(), 2);

insta::assert_snapshot!(
result,
@r"
@@ -1355,7 +1362,7 @@ async fn test_nested_hashjoin_dynamic_filter_pushdown() {
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, x], file_type=test, pushdown_supported=true
- HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, d@0)]
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[b, c, y], file_type=test, pushdown_supported=true, predicate=DynamicFilterPhysicalExpr [ b@0 >= aa AND b@0 <= ab ]
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, z], file_type=test, pushdown_supported=true, predicate=DynamicFilterPhysicalExpr [ d@0 >= ca AND d@0 <= ce ]
Review comment (Contributor Author):
@adriangb This update seems correct to me

Review comment (Contributor):
Hmm, the filter is more selective (generally a good thing), but I'm curious why it changed; it's not immediately obvious to me why waiting would change the value of the filter. Could you help me understand why that is?

Review comment (Contributor Author):
Ah, I believe this is simply because the filter is now applied in TestOpener. The bounds on the left side have changed because rows are filtered out, so we're now seeing the correct filter.

- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, z], file_type=test, pushdown_supported=true, predicate=DynamicFilterPhysicalExpr [ d@0 >= ca AND d@0 <= cb ]
"
);
}
@@ -27,6 +27,7 @@ use datafusion_datasource::{
};
use datafusion_physical_expr_common::physical_expr::fmt_sql;
use datafusion_physical_optimizer::PhysicalOptimizerRule;
use datafusion_physical_plan::filter::batch_filter;
use datafusion_physical_plan::filter_pushdown::{FilterPushdownPhase, PushedDown};
use datafusion_physical_plan::{
displayable,
@@ -53,6 +54,7 @@ pub struct TestOpener {
batch_size: Option<usize>,
schema: Option<SchemaRef>,
projection: Option<Vec<usize>>,
predicate: Option<Arc<dyn PhysicalExpr>>,
}

impl FileOpener for TestOpener {
@@ -77,6 +79,12 @@ impl FileOpener for TestOpener {
let (mapper, projection) = factory.map_schema(&batches[0].schema()).unwrap();
let mut new_batches = Vec::new();
for batch in batches {
let batch = if let Some(predicate) = &self.predicate {
batch_filter(&batch, predicate)?
Review comment (Contributor):
Nice I didn't know this little function existed!

} else {
batch
};

let batch = batch.project(&projection).unwrap();
let batch = mapper.map_batch(batch).unwrap();
new_batches.push(batch);
@@ -133,6 +141,7 @@ impl FileSource for TestSource {
batch_size: self.batch_size,
schema: self.schema.clone(),
projection: self.projection.clone(),
predicate: self.predicate.clone(),
})
}
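For context, the batch_filter helper imported above and called in TestOpener::open evaluates a physical predicate against a RecordBatch and keeps only the matching rows. A minimal standalone sketch of that usage, outside the scope of this PR; the schema, column, values, and predicate below are made up for illustration and only assume the signature shown in the diff:

use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{col, lit, BinaryExpr};
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_plan::filter::batch_filter;

fn filter_batch_example() -> Result<RecordBatch> {
    // Hypothetical single-column batch: a = [1, 2, 3, 4]
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let column: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
    let batch = RecordBatch::try_new(Arc::clone(&schema), vec![column])?;

    // Hypothetical predicate: a >= 3
    let predicate: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        col("a", &schema)?,
        Operator::GtEq,
        lit(3i32),
    ));

    // Rows that do not satisfy the predicate are dropped; the schema is unchanged.
    batch_filter(&batch, &predicate)
}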

59 changes: 31 additions & 28 deletions datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs
@@ -32,6 +32,7 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef};

use itertools::Itertools;
use parking_lot::Mutex;
use tokio::sync::Barrier;

/// Represents the minimum and maximum values for a specific column.
/// Used in dynamic filter pushdown to establish value boundaries.
@@ -86,9 +87,9 @@ impl PartitionBounds {
/// ## Synchronization Strategy
///
/// 1. Each partition computes bounds from its build-side data
/// 2. Bounds are stored in the shared HashMap (indexed by partition_id)
/// 3. A counter tracks how many partitions have reported their bounds
/// 4. When the last partition reports (completed == total), bounds are merged and filter is updated
/// 2. Bounds are stored in the shared vector
/// 3. A barrier tracks how many partitions have reported their bounds
/// 4. When the last partition reports, bounds are merged and the filter is updated exactly once
///
/// ## Partition Counting
///
@@ -103,10 +104,7 @@ impl PartitionBounds {
pub(crate) struct SharedBoundsAccumulator {
/// Shared state protected by a single mutex to avoid ordering concerns
inner: Mutex<SharedBoundsState>,
/// Total number of partitions.
/// Need to know this so that we can update the dynamic filter once we are done
/// building *all* of the hash tables.
total_partitions: usize,
barrier: Barrier,
/// Dynamic filter for pushdown to probe side
dynamic_filter: Arc<DynamicFilterPhysicalExpr>,
/// Right side join expressions needed for creating filter bounds
@@ -118,8 +116,6 @@ struct SharedBoundsState {
/// Bounds from completed partitions.
/// Each element represents the column bounds computed by one partition.
bounds: Vec<PartitionBounds>,
/// Number of partitions that have reported completion.
completed_partitions: usize,
}

impl SharedBoundsAccumulator {
@@ -170,9 +166,8 @@ impl SharedBoundsAccumulator {
Self {
inner: Mutex::new(SharedBoundsState {
bounds: Vec::with_capacity(expected_calls),
completed_partitions: 0,
}),
total_partitions: expected_calls,
barrier: Barrier::new(expected_calls),
dynamic_filter,
on_right,
}
@@ -253,36 +248,44 @@ impl SharedBoundsAccumulator {
/// bounds from the current partition, increments the completion counter, and when all
/// partitions have reported, creates an OR'd filter from individual partition bounds.
///
/// This method is async and uses a [`tokio::sync::Barrier`] to wait for all partitions
/// to report their bounds. Once that occurs, the method will resolve for all callers and the
/// dynamic filter will be updated exactly once.
///
/// # Note
///
/// As barriers are reusable, it is likely an error to call this method more times than the
/// total number of partitions - as it can lead to pending futures that never resolve. We rely
/// on correct usage from the caller rather than imposing additional checks here. If this is a concern,
/// consider making the resulting future shared so the ready result can be reused.
///
/// # Arguments
/// * `partition` - The partition identifier reporting its bounds
/// * `partition_bounds` - The bounds computed by this partition (if any)
///
/// # Returns
/// * `Result<()>` - Ok if successful, Err if filter update failed
pub(crate) fn report_partition_bounds(
pub(crate) async fn report_partition_bounds(
&self,
partition: usize,
partition_bounds: Option<Vec<ColumnBounds>>,
) -> Result<()> {
let mut inner = self.inner.lock();

// Store bounds in the accumulator - this runs once per partition
if let Some(bounds) = partition_bounds {
// Only push actual bounds if they exist
inner.bounds.push(PartitionBounds::new(partition, bounds));
self.inner
.lock()
.bounds
.push(PartitionBounds::new(partition, bounds));
}

// Increment the completion counter
// Even empty partitions must report to ensure proper termination
inner.completed_partitions += 1;
let completed = inner.completed_partitions;
let total_partitions = self.total_partitions;

// Critical synchronization point: Only update the filter when ALL partitions are complete
// Troubleshooting: If you see "completed > total_partitions", check partition
// count calculation in new_from_partition_mode() - it may not match actual execution calls
if completed == total_partitions && !inner.bounds.is_empty() {
let filter_expr = self.create_filter_from_partition_bounds(&inner.bounds)?;
self.dynamic_filter.update(filter_expr)?;
if self.barrier.wait().await.is_leader() {
// All partitions have reported, so we can update the filter
let inner = self.inner.lock();
if !inner.bounds.is_empty() {
let filter_expr =
self.create_filter_from_partition_bounds(&inner.bounds)?;
self.dynamic_filter.update(filter_expr)?;
}
}

Ok(())
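The "updated exactly once" behaviour above comes from tokio::sync::Barrier: every partition awaits the barrier, and exactly one caller is handed the leader role when the last one arrives. A minimal standalone sketch of that pattern, independent of DataFusion (the atomic counter here merely stands in for the dynamic-filter update, and the example assumes a tokio runtime with default features):

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use tokio::sync::Barrier;

#[tokio::main]
async fn main() {
    let partitions = 4;
    let barrier = Arc::new(Barrier::new(partitions));
    let updates = Arc::new(AtomicUsize::new(0));

    let mut handles = Vec::new();
    for _ in 0..partitions {
        let barrier = Arc::clone(&barrier);
        let updates = Arc::clone(&updates);
        handles.push(tokio::spawn(async move {
            // Every task parks here until all `partitions` tasks have arrived.
            if barrier.wait().await.is_leader() {
                // Exactly one task observes is_leader() == true, so the
                // "merge bounds and update the filter" step runs once.
                updates.fetch_add(1, Ordering::SeqCst);
            }
        }));
    }
    for handle in handles {
        handle.await.unwrap();
    }
    assert_eq!(updates.load(Ordering::SeqCst), 1);
}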
44 changes: 40 additions & 4 deletions datafusion/physical-plan/src/joins/hash_join/stream.rs
@@ -120,6 +120,8 @@ impl BuildSide {
pub(super) enum HashJoinStreamState {
/// Initial state for HashJoinStream indicating that build-side data not collected yet
WaitBuildSide,
/// Waiting for bounds to be reported by all partitions
WaitPartitionBoundsReport,
/// Indicates that build-side has been collected, and stream is ready for fetching probe-side
FetchProbeBatch,
/// Indicates that non-empty batch has been fetched from probe-side, and is ready to be processed
@@ -205,6 +207,9 @@ pub(super) struct HashJoinStream {
right_side_ordered: bool,
/// Shared bounds accumulator for coordinating dynamic filter updates (optional)
bounds_accumulator: Option<Arc<SharedBoundsAccumulator>>,
/// Optional future to signal when bounds have been reported by all partitions
/// and the dynamic filter has been updated
bounds_waiter: Option<OnceFut<()>>,
}

impl RecordBatchStream for HashJoinStream {
@@ -325,6 +330,7 @@ impl HashJoinStream {
hashes_buffer,
right_side_ordered,
bounds_accumulator,
bounds_waiter: None,
}
}

@@ -339,6 +345,9 @@ impl HashJoinStream {
HashJoinStreamState::WaitBuildSide => {
handle_state!(ready!(self.collect_build_side(cx)))
}
HashJoinStreamState::WaitPartitionBoundsReport => {
handle_state!(ready!(self.wait_for_partition_bounds_report(cx)))
}
HashJoinStreamState::FetchProbeBatch => {
handle_state!(ready!(self.fetch_probe_batch(cx)))
}
Expand All @@ -355,6 +364,26 @@ impl HashJoinStream {
}
}

/// Optional step to wait until bounds have been reported by all partitions.
/// This state is only entered if a bounds accumulator is present.
///
/// ## Why wait?
///
/// The dynamic filter is only built once all partitions have reported their bounds.
/// If we do not wait here, the probe-side scan may start before the filter is ready.
/// This can lead to the probe-side scan missing the opportunity to apply the filter
/// and skip reading unnecessary data.
fn wait_for_partition_bounds_report(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
if let Some(ref mut fut) = self.bounds_waiter {
ready!(fut.get_shared(cx))?;
}
self.state = HashJoinStreamState::FetchProbeBatch;
Poll::Ready(Ok(StatefulStreamResult::Continue))
}

/// Collects build-side data by polling `OnceFut` future from initialized build-side
///
/// Updates build-side to `Ready`, and state to `FetchProbeSide`
@@ -376,13 +405,20 @@ impl HashJoinStream {
// Dynamic filter coordination between partitions:
// Report bounds to the accumulator which will handle synchronization and filter updates
if let Some(ref bounds_accumulator) = self.bounds_accumulator {
bounds_accumulator
.report_partition_bounds(self.partition, left_data.bounds.clone())?;
let bounds_accumulator = Arc::clone(bounds_accumulator);
let partition = self.partition;
let left_data_bounds = left_data.bounds.clone();
self.bounds_waiter = Some(OnceFut::new(async move {
bounds_accumulator
.report_partition_bounds(partition, left_data_bounds)
.await
}));
self.state = HashJoinStreamState::WaitPartitionBoundsReport;
} else {
self.state = HashJoinStreamState::FetchProbeBatch;
}

self.state = HashJoinStreamState::FetchProbeBatch;
self.build_side = BuildSide::Ready(BuildSideReadyState { left_data });

Poll::Ready(Ok(StatefulStreamResult::Continue))
}
