Skip to content

Commit 3d68a39

Browse files
committed
implement for CollectLeft
1 parent e66969e commit 3d68a39

File tree

3 files changed

+104
-116
lines changed

3 files changed

+104
-116
lines changed

datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,7 +1078,7 @@ async fn test_hashjoin_dynamic_filter_pushdown() {
10781078
@r"
10791079
- HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)]
10801080
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
1081-
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb ]
1081+
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb AND hash_lookup ]
10821082
"
10831083
);
10841084
}
@@ -1503,7 +1503,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_collect_left() {
15031503
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
15041504
- CoalesceBatchesExec: target_batch_size=8192
15051505
- RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1
1506-
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb ]
1506+
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb AND hash_lookup ]
15071507
"
15081508
);
15091509

datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs

Lines changed: 100 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ use datafusion_physical_expr::expressions::{
3535
};
3636
use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef};
3737

38-
use itertools::Itertools;
3938
use parking_lot::Mutex;
4039
use tokio::sync::Barrier;
4140

@@ -78,10 +77,6 @@ impl PartitionBounds {
7877
}
7978
}
8079

81-
pub(crate) fn len(&self) -> usize {
82-
self.column_bounds.len()
83-
}
84-
8580
pub(crate) fn get_column_bounds(&self, index: usize) -> Option<&ColumnBounds> {
8681
self.column_bounds.get(index)
8782
}
@@ -138,6 +133,8 @@ struct SharedBuildState {
138133
/// Hash maps from completed partitions (used in Partitioned mode)
139134
/// Index corresponds to partition number
140135
hash_maps: Vec<Option<Arc<dyn JoinHashMapType>>>,
136+
/// Single hash map for CollectLeft mode (shared across all partitions)
137+
single_hash_map: Option<Arc<dyn JoinHashMapType>>,
141138
}
142139

143140
impl SharedBuildAccumulator {
@@ -200,6 +197,7 @@ impl SharedBuildAccumulator {
200197
inner: Mutex::new(SharedBuildState {
201198
bounds: Vec::with_capacity(expected_calls),
202199
hash_maps: vec![None; num_partitions],
200+
single_hash_map: None,
203201
}),
204202
barrier: Barrier::new(expected_calls),
205203
dynamic_filter,
@@ -208,129 +206,121 @@ impl SharedBuildAccumulator {
208206
}
209207
}
210208

211-
/// Create a filter expression from individual partition bounds using OR logic.
212-
///
213-
/// This creates a filter where each partition's bounds form a conjunction (AND)
214-
/// of column range predicates, and all partitions are combined with OR.
215-
///
216-
/// For example, with 2 partitions and 2 columns:
217-
/// ((col0 >= p0_min0 AND col0 <= p0_max0 AND col1 >= p0_min1 AND col1 <= p0_max1)
218-
/// OR
219-
/// (col0 >= p1_min0 AND col0 <= p1_max0 AND col1 >= p1_min1 AND col1 <= p1_max1))
220-
pub(crate) fn create_filter_from_partition_bounds(
221-
&self,
222-
bounds: &[PartitionBounds],
223-
) -> Result<Arc<dyn PhysicalExpr>> {
224-
if bounds.is_empty() {
225-
return Ok(lit(true));
226-
}
227-
228-
// Create a predicate for each partition
229-
let mut partition_predicates = Vec::with_capacity(bounds.len());
230-
231-
for partition_bounds in bounds.iter().sorted_by_key(|b| b.partition) {
232-
// Create range predicates for each join key in this partition
233-
let mut column_predicates = Vec::with_capacity(partition_bounds.len());
234-
235-
for (col_idx, right_expr) in self.on_right.iter().enumerate() {
236-
if let Some(column_bounds) = partition_bounds.get_column_bounds(col_idx) {
237-
// Create predicate: col >= min AND col <= max
238-
let min_expr = Arc::new(BinaryExpr::new(
239-
Arc::clone(right_expr),
240-
Operator::GtEq,
241-
lit(column_bounds.min.clone()),
242-
)) as Arc<dyn PhysicalExpr>;
243-
let max_expr = Arc::new(BinaryExpr::new(
244-
Arc::clone(right_expr),
245-
Operator::LtEq,
246-
lit(column_bounds.max.clone()),
247-
)) as Arc<dyn PhysicalExpr>;
248-
let range_expr =
249-
Arc::new(BinaryExpr::new(min_expr, Operator::And, max_expr))
250-
as Arc<dyn PhysicalExpr>;
251-
column_predicates.push(range_expr);
252-
}
253-
}
254-
255-
// Combine all column predicates for this partition with AND
256-
if !column_predicates.is_empty() {
257-
let partition_predicate = column_predicates
258-
.into_iter()
259-
.reduce(|acc, pred| {
260-
Arc::new(BinaryExpr::new(acc, Operator::And, pred))
261-
as Arc<dyn PhysicalExpr>
262-
})
263-
.unwrap();
264-
partition_predicates.push(partition_predicate);
265-
}
266-
}
267-
268-
// Combine all partition predicates with OR
269-
let combined_predicate = partition_predicates
270-
.into_iter()
271-
.reduce(|acc, pred| {
272-
Arc::new(BinaryExpr::new(acc, Operator::Or, pred))
273-
as Arc<dyn PhysicalExpr>
274-
})
275-
.unwrap_or_else(|| lit(true));
276-
277-
Ok(combined_predicate)
278-
}
279-
280-
/// Report bounds from a completed partition and update dynamic filter if all partitions are done
281-
///
282-
/// This method coordinates the dynamic filter updates across all partitions. It stores the
283-
/// bounds from the current partition, increments the completion counter, and when all
284-
/// partitions have reported, creates an OR'd filter from individual partition bounds.
285-
///
286-
/// This method is async and uses a [`tokio::sync::Barrier`] to wait for all partitions
287-
/// to report their bounds. Once that occurs, the method will resolve for all callers and the
288-
/// dynamic filter will be updated exactly once.
209+
/// Report hash map and bounds from CollectLeft mode (single hash table shared by all partitions)
289210
///
290-
/// # Note
211+
/// This method is used for `PartitionMode::CollectLeft` to collect the single shared hash map
212+
/// and bounds. When all partitions have reported (waited at barrier), it creates a simple filter
213+
/// expression that combines min/max range checks with hash table lookups.
291214
///
292-
/// As barriers are reusable, it is likely an error to call this method more times than the
293-
/// total number of partitions - as it can lead to pending futures that never resolve. We rely
294-
/// on correct usage from the caller rather than imposing additional checks here. If this is a concern,
295-
/// consider making the resulting future shared so the ready result can be reused.
215+
/// Unlike Partitioned mode, this creates a simpler filter without CASE expression or partition routing:
216+
/// `(col >= min AND col <= max AND ...) AND hash_lookup(hash_table, hash_join(join_keys))`
296217
///
297218
/// # Arguments
298-
/// * `left_side_partition_id` - The identifier for the **left-side** partition reporting its bounds
299-
/// * `partition_bounds` - The bounds computed by this partition (if any)
219+
/// * `hash_map` - Arc reference to the single shared hash table
220+
/// * `partition_bounds` - Min/max bounds for the build side
300221
///
301222
/// # Returns
302223
/// * `Result<()>` - Ok if successful, Err if filter update failed
303-
pub(crate) async fn report_partition_bounds(
224+
pub(crate) async fn report_single_hash_map_and_bounds(
304225
&self,
305-
left_side_partition_id: usize,
226+
hash_map: Arc<dyn JoinHashMapType>,
306227
partition_bounds: Option<Vec<ColumnBounds>>,
307228
) -> Result<()> {
308-
// Store bounds in the accumulator - this runs once per partition
309-
if let Some(bounds) = partition_bounds {
229+
// Store hash map and bounds in the accumulator
230+
{
310231
let mut guard = self.inner.lock();
311232

312-
let should_push = if let Some(last_bound) = guard.bounds.last() {
313-
// In `PartitionMode::CollectLeft`, all streams on the left side share the same partition id (0).
314-
// Since this function can be called multiple times for that same partition, we must deduplicate
315-
// by checking against the last recorded bound.
316-
last_bound.partition != left_side_partition_id
317-
} else {
318-
true
319-
};
233+
// Store the single hash map (only once, even though multiple partitions call this)
234+
if guard.single_hash_map.is_none() {
235+
guard.single_hash_map = Some(hash_map);
236+
}
320237

321-
if should_push {
322-
guard
323-
.bounds
324-
.push(PartitionBounds::new(left_side_partition_id, bounds));
238+
if let Some(bounds) = partition_bounds {
239+
// Use partition 0 for the single hash table
240+
let should_push = if let Some(last_bound) = guard.bounds.last() {
241+
// Deduplicate - all partitions report the same data in CollectLeft
242+
last_bound.partition != 0
243+
} else {
244+
true
245+
};
246+
247+
if should_push {
248+
guard.bounds.push(PartitionBounds::new(0, bounds));
249+
}
325250
}
326251
}
327252

253+
// Wait for all partitions to report
328254
if self.barrier.wait().await.is_leader() {
329-
// All partitions have reported, so we can update the filter
255+
// All partitions have reported, so we can create and update the filter
330256
let inner = self.inner.lock();
331-
if !inner.bounds.is_empty() {
332-
let filter_expr =
333-
self.create_filter_from_partition_bounds(&inner.bounds)?;
257+
258+
if let Some(ref hash_map) = inner.single_hash_map {
259+
// Create hash lookup expression
260+
let lookup_hash_expr = Arc::new(HashExpr::new(
261+
self.on_right.clone(),
262+
HASH_JOIN_SEED,
263+
"hash_join".to_string(),
264+
)) as Arc<dyn PhysicalExpr>;
265+
266+
let hash_lookup_expr = Arc::new(HashTableLookupExpr::new(
267+
lookup_hash_expr,
268+
Arc::clone(hash_map),
269+
"hash_lookup".to_string(),
270+
)) as Arc<dyn PhysicalExpr>;
271+
272+
// Create bounds check expression (if bounds available)
273+
let mut filter_expr = hash_lookup_expr;
274+
275+
if let Some(partition_bounds) = inner.bounds.first() {
276+
let mut column_predicates = Vec::new();
277+
278+
for (col_idx, right_expr) in self.on_right.iter().enumerate() {
279+
if let Some(column_bounds) =
280+
partition_bounds.get_column_bounds(col_idx)
281+
{
282+
// Create predicate: col >= min AND col <= max
283+
let min_expr = Arc::new(BinaryExpr::new(
284+
Arc::clone(right_expr),
285+
Operator::GtEq,
286+
lit(column_bounds.min.clone()),
287+
))
288+
as Arc<dyn PhysicalExpr>;
289+
let max_expr = Arc::new(BinaryExpr::new(
290+
Arc::clone(right_expr),
291+
Operator::LtEq,
292+
lit(column_bounds.max.clone()),
293+
))
294+
as Arc<dyn PhysicalExpr>;
295+
let range_expr = Arc::new(BinaryExpr::new(
296+
min_expr,
297+
Operator::And,
298+
max_expr,
299+
))
300+
as Arc<dyn PhysicalExpr>;
301+
column_predicates.push(range_expr);
302+
}
303+
}
304+
305+
// Combine all column range predicates with AND
306+
if !column_predicates.is_empty() {
307+
let bounds_expr = column_predicates
308+
.into_iter()
309+
.reduce(|acc, pred| {
310+
Arc::new(BinaryExpr::new(acc, Operator::And, pred))
311+
as Arc<dyn PhysicalExpr>
312+
})
313+
.unwrap();
314+
315+
// Combine bounds_expr AND hash_lookup_expr
316+
filter_expr = Arc::new(BinaryExpr::new(
317+
bounds_expr,
318+
Operator::And,
319+
filter_expr,
320+
)) as Arc<dyn PhysicalExpr>;
321+
}
322+
}
323+
334324
self.dynamic_filter.update(filter_expr)?;
335325
}
336326
}

datafusion/physical-plan/src/joins/hash_join/stream.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -437,13 +437,11 @@ impl HashJoinStream {
437437
}));
438438
}
439439
PartitionMode::CollectLeft => {
440+
let hash_map = left_data.hash_map_arc();
440441
let left_data_bounds = left_data.bounds.clone();
441442
self.build_waiter = Some(OnceFut::new(async move {
442443
build_accumulator
443-
.report_partition_bounds(
444-
left_side_partition_id,
445-
left_data_bounds,
446-
)
444+
.report_single_hash_map_and_bounds(hash_map, left_data_bounds)
447445
.await
448446
}));
449447
}

0 commit comments

Comments
 (0)