Skip to content
70 changes: 34 additions & 36 deletions datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ use crate::projection::{
use crate::spill::get_record_batch_memory_size;
use crate::ExecutionPlanProperties;
use crate::{
coalesce_partitions::CoalescePartitionsExec,
common::can_project,
handle_state,
hash_utils::create_hashes,
Expand Down Expand Up @@ -791,34 +790,44 @@ impl ExecutionPlan for HashJoinExec {
);
}

if self.mode == PartitionMode::CollectLeft && left_partitions != 1 {
return internal_err!(
"Invalid HashJoinExec,the output partition count of the left child must be 1 in CollectLeft mode,\
consider using CoalescePartitionsExec"
);
}

let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
let left_fut = match self.mode {
PartitionMode::CollectLeft => self.left_fut.once(|| {
let reservation =
MemoryConsumer::new("HashJoinInput").register(context.memory_pool());
collect_left_input(
None,
self.random_state.clone(),
Arc::clone(&self.left),
on_left.clone(),
Arc::clone(&context),
join_metrics.clone(),
reservation,
need_produce_result_in_final(self.join_type),
self.right().output_partitioning().partition_count(),
)
}),
PartitionMode::CollectLeft => {
let left_stream = self.left.execute(0, Arc::clone(&context))?;

self.left_fut.once(|| {
let reservation = MemoryConsumer::new("HashJoinInput")
.register(context.memory_pool());

collect_left_input(
self.random_state.clone(),
left_stream,
on_left.clone(),
join_metrics.clone(),
reservation,
need_produce_result_in_final(self.join_type),
self.right().output_partitioning().partition_count(),
)
})
}
PartitionMode::Partitioned => {
let left_stream = self.left.execute(partition, Arc::clone(&context))?;

let reservation =
MemoryConsumer::new(format!("HashJoinInput[{partition}]"))
.register(context.memory_pool());

OnceFut::new(collect_left_input(
Some(partition),
self.random_state.clone(),
Arc::clone(&self.left),
left_stream,
on_left.clone(),
Arc::clone(&context),
join_metrics.clone(),
reservation,
need_produce_result_in_final(self.join_type),
Expand Down Expand Up @@ -929,36 +938,22 @@ impl ExecutionPlan for HashJoinExec {

/// Reads the left (build) side of the input, buffering it in memory, to build a
/// hash table (`JoinLeftData`)
#[allow(clippy::too_many_arguments)]
async fn collect_left_input(
partition: Option<usize>,
random_state: RandomState,
left: Arc<dyn ExecutionPlan>,
left_stream: SendableRecordBatchStream,
on_left: Vec<PhysicalExprRef>,
context: Arc<TaskContext>,
metrics: BuildProbeJoinMetrics,
reservation: MemoryReservation,
with_visited_indices_bitmap: bool,
probe_threads_count: usize,
) -> Result<JoinLeftData> {
let schema = left.schema();

let (left_input, left_input_partition) = if let Some(partition) = partition {
(left, partition)
} else if left.output_partitioning().partition_count() != 1 {
(Arc::new(CoalescePartitionsExec::new(left)) as _, 0)
} else {
(left, 0)
};

// Depending on partition argument load single partition or whole left side in memory
let stream = left_input.execute(left_input_partition, Arc::clone(&context))?;
let schema = left_stream.schema();

// This operation performs 2 steps at once:
// 1. creates a [JoinHashMap] of all batches from the stream
// 2. stores the batches in a vector.
let initial = (Vec::new(), 0, metrics, reservation);
let (batches, num_rows, metrics, mut reservation) = stream
let (batches, num_rows, metrics, mut reservation) = left_stream
.try_fold(initial, |mut acc, batch| async {
let batch_size = get_record_batch_memory_size(&batch);
// Reserve memory for incoming batch
Expand Down Expand Up @@ -1654,6 +1649,7 @@ impl EmbeddedProjection for HashJoinExec {
#[cfg(test)]
mod tests {
use super::*;
use crate::coalesce_partitions::CoalescePartitionsExec;
use crate::test::TestMemoryExec;
use crate::{
common, expressions::Column, repartition::RepartitionExec, test::build_table_i32,
Expand Down Expand Up @@ -2101,6 +2097,7 @@ mod tests {
let left =
TestMemoryExec::try_new_exec(&[vec![batch1], vec![batch2]], schema, None)
.unwrap();
let left = Arc::new(CoalescePartitionsExec::new(left));

let right = build_table(
("a1", &vec![1, 2, 3]),
Expand Down Expand Up @@ -2173,6 +2170,7 @@ mod tests {
let left =
TestMemoryExec::try_new_exec(&[vec![batch1], vec![batch2]], schema, None)
.unwrap();
let left = Arc::new(CoalescePartitionsExec::new(left));
let right = build_table(
("a2", &vec![20, 30, 10]),
("b2", &vec![5, 6, 4]),
Expand Down