diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs
index 48d648c89a35..ec1bf8fd4903 100644
--- a/datafusion/physical-plan/src/joins/hash_join.rs
+++ b/datafusion/physical-plan/src/joins/hash_join.rs
@@ -17,12 +17,6 @@
 //! [`HashJoinExec`] Partitioned Hash Join Operator
 
-use std::fmt;
-use std::sync::atomic::{AtomicUsize, Ordering};
-use std::sync::Arc;
-use std::task::Poll;
-use std::{any::Any, vec};
-
 use super::utils::asymmetric_join_output_partitioning;
 use super::{
     utils::{OnceAsync, OnceFut},
@@ -47,6 +41,12 @@ use crate::{
     Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream,
     Statistics,
 };
+use std::fmt;
+use std::ops::{Deref, DerefMut};
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Arc;
+use std::task::{Context, Poll};
+use std::{any::Any, vec};
 
 use arrow::array::{
     Array, ArrayRef, BooleanArray, BooleanBufferBuilder, UInt32Array, UInt64Array,
@@ -71,11 +71,68 @@ use datafusion_physical_expr::equivalence::{
 use datafusion_physical_expr::PhysicalExprRef;
 
 use ahash::RandomState;
+use arrow_buffer::BooleanBuffer;
 use datafusion_expr::Operator;
 use datafusion_physical_expr_common::datum::compare_op_for_nested;
 use futures::{ready, Stream, StreamExt, TryStreamExt};
 use parking_lot::Mutex;
 
+/// `SharedJoinState` provides an extension point that allows
+/// `HashJoinStream` to share the `visited_indices_bitmap` of the build side of a join
+/// across probe tasks without shared memory.
+///
+/// This can be used, for example, to implement a left outer join efficiently as a
+/// broadcast join when the left side is small.
+pub struct SharedJoinState {
+    /// `SharedJoinState` is just a wrapper around a user-defined implementation of `SharedJoinStateImpl`,
+    /// so it can be passed to the operator as an extension in the `TaskContext`.
+    state_impl: Arc<dyn SharedJoinStateImpl>,
+}
+
+impl SharedJoinState {
+    pub fn new(state_impl: Arc<dyn SharedJoinStateImpl>) -> Self {
+        Self { state_impl }
+    }
+
+    /// Number of probe threads in the current task
+    fn num_task_partitions(&self) -> usize {
+        self.state_impl.num_task_partitions()
+    }
+
+    /// Once all probe threads in the current task have completed, this method polls the shared state
+    fn poll_probe_completed(
+        &self,
+        visited_indices_bitmap: &BooleanBufferBuilder,
+        cx: &mut Context<'_>,
+    ) -> Poll<Result<SharedProbeState>> {
+        self.state_impl
+            .poll_probe_completed(visited_indices_bitmap, cx)
+    }
+}
+
+/// Represents the result of polling the shared join state
+#[derive(Clone)]
+pub enum SharedProbeState {
+    /// Indicates that one or more probe tasks are still running, so the current task
+    /// can complete without emitting the unmatched rows required for an outer join
+    Continue,
+    /// All probe tasks are completed. Contains their combined `visited_indices_bitmap`
+    Ready(BooleanBuffer),
+}
+
+/// Trait which provides the user-defined implementation of the shared join state
+pub trait SharedJoinStateImpl: Send + Sync + 'static {
+    /// Number of probe threads in the current task
+    fn num_task_partitions(&self) -> usize;
+
+    /// Once all probe threads in the current task have completed, this method polls the shared state
+    fn poll_probe_completed(
+        &self,
+        visited_indices_bitmap: &BooleanBufferBuilder,
+        cx: &mut Context<'_>,
+    ) -> Poll<Result<SharedProbeState>>;
+}
+
 type SharedBitmapBuilder = Mutex<BooleanBufferBuilder>;
 
 /// HashTable and input data for the left (build side) of a join
@@ -89,6 +146,8 @@ struct JoinLeftData {
     /// Counter of running probe-threads, potentially
     /// able to update `visited_indices_bitmap`
     probe_threads_counter: AtomicUsize,
+    /// Shared join state, if one is provided in the `TaskContext`
+    shared_state: Option<Arc<SharedJoinState>>,
     /// Memory reservation that tracks memory used by `hash_map` hash table
     /// `batch`. Cleared on drop.
     #[allow(dead_code)]
@@ -102,6 +161,7 @@ impl JoinLeftData {
         batch: RecordBatch,
         visited_indices_bitmap: SharedBitmapBuilder,
         probe_threads_counter: AtomicUsize,
+        shared_state: Option<Arc<SharedJoinState>>,
         reservation: MemoryReservation,
     ) -> Self {
         Self {
@@ -109,6 +169,7 @@
             batch,
             visited_indices_bitmap,
             probe_threads_counter,
+            shared_state,
             reservation,
         }
     }
@@ -131,8 +192,29 @@ impl JoinLeftData {
     /// Decrements the counter of running threads, and returns `true`
     /// if caller is the last running thread
     fn report_probe_completed(&self) -> bool {
-        self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1
+        self.probe_threads_counter.load(Ordering::Relaxed) == 0
+            || self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1
+    }
+}
+
+/// Merges the shared bitmap `m2` into the local bitmap `m1` in place, returning an
+/// error if the two bitmaps have different lengths
+fn merge_bitmap(m1: &mut BooleanBufferBuilder, m2: BooleanBuffer) -> Result<()> {
+    if m1.len() != m2.len() {
+        return Err(DataFusionError::Execution(format!(
+            "local and shared indices bitmaps have different lengths: {} and {}",
+            m1.len(),
+            m2.len()
+        )));
+    }
+
+    for (b1, b2) in m1
+        .as_slice_mut()
+        .iter_mut()
+        .zip(m2.inner().as_slice().iter().copied())
+    {
+        *b1 |= b2;
     }
+
+    Ok(())
 }
 
 /// Join execution plan: Evaluates equijoin predicates in parallel on multiple
@@ -688,11 +770,21 @@ impl ExecutionPlan for HashJoinExec {
             );
         }
 
+        let shared_state = context.session_config().get_extension::<SharedJoinState>();
+
         let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
         let left_fut = match self.mode {
             PartitionMode::CollectLeft => self.left_fut.once(|| {
                 let reservation =
                     MemoryConsumer::new("HashJoinInput").register(context.memory_pool());
+
+                let probe_threads = shared_state
+                    .as_ref()
+                    .map(|s| s.num_task_partitions())
+                    .unwrap_or_else(|| {
+                        self.right().output_partitioning().partition_count()
+                    });
+
                 collect_left_input(
                     None,
                     self.random_state.clone(),
@@ -702,7 +794,8 @@ impl ExecutionPlan for HashJoinExec {
                     join_metrics.clone(),
                     reservation,
                     need_produce_result_in_final(self.join_type),
-                    self.right().output_partitioning().partition_count(),
+                    probe_threads,
+                    shared_state,
                 )
             }),
             PartitionMode::Partitioned => {
@@ -720,6 +813,7 @@ impl ExecutionPlan for HashJoinExec {
                     reservation,
                     need_produce_result_in_final(self.join_type),
                     1,
+                    None,
                 ))
             }
             PartitionMode::Auto => {
@@ -806,6 +900,7 @@ async fn collect_left_input(
     reservation: MemoryReservation,
     with_visited_indices_bitmap: bool,
     probe_threads_count: usize,
+    shared_state: Option<Arc<SharedJoinState>>,
 ) -> Result<JoinLeftData> {
     let schema = left.schema();
 
@@ -892,6 +987,7 @@
         single_batch,
         Mutex::new(visited_indices_bitmap),
         AtomicUsize::new(probe_threads_count),
+        shared_state,
         reservation,
     );
 
@@ -1266,7 +1362,7 @@ impl HashJoinStream {
     /// that partial borrows work correctly
     fn poll_next_impl(
         &mut self,
-        cx: &mut std::task::Context<'_>,
+        cx: &mut Context<'_>,
     ) -> Poll<Option<Result<RecordBatch>>> {
         loop {
             return match self.state {
@@ -1280,7 +1376,7 @@ impl HashJoinStream {
                     handle_state!(self.process_probe_batch())
                 }
                 HashJoinStreamState::ExhaustedProbeSide => {
-                    handle_state!(self.process_unmatched_build_batch())
+                    handle_state!(ready!(self.process_unmatched_build_batch(cx)))
                 }
                 HashJoinStreamState::Completed => Poll::Ready(None),
             };
@@ -1292,7 +1388,7 @@ impl HashJoinStream {
     /// Updates build-side to `Ready`, and state to `FetchProbeSide`
     fn collect_build_side(
         &mut self,
-        cx: &mut std::task::Context<'_>,
+        cx: &mut Context<'_>,
     ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
         let build_timer = self.join_metrics.build_time.timer();
         // build hash table from left (build) side, if not yet done
@@ -1466,18 +1562,35 @@ impl HashJoinStream {
     /// Updates state to `Completed`
     fn process_unmatched_build_batch(
         &mut self,
-    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
+        cx: &mut Context<'_>,
+    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
         let timer = self.join_metrics.join_time.timer();
 
         if !need_produce_result_in_final(self.join_type) {
             self.state = HashJoinStreamState::Completed;
-            return Ok(StatefulStreamResult::Continue);
+            return Poll::Ready(Ok(StatefulStreamResult::Continue));
         }
 
         let build_side = self.build_side.try_as_ready()?;
         if !build_side.left_data.report_probe_completed() {
             self.state = HashJoinStreamState::Completed;
-            return Ok(StatefulStreamResult::Continue);
+            return Poll::Ready(Ok(StatefulStreamResult::Continue));
+        }
+
+        if let Some(shared_state) = build_side.left_data.shared_state.as_ref() {
+            let mut guard = build_side.left_data.visited_indices_bitmap().lock();
+            match ready!(shared_state.poll_probe_completed(guard.deref(), cx)) {
+                Ok(SharedProbeState::Continue) => {
+                    self.state = HashJoinStreamState::Completed;
+                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
+                }
+                Ok(SharedProbeState::Ready(shared_mask)) => {
+                    if let Err(e) = merge_bitmap(guard.deref_mut(), shared_mask) {
+                        return Poll::Ready(Err(e));
+                    }
+                }
+                Err(err) => return Poll::Ready(Err(err)),
+            }
         }
 
         // use the global left bitmap to produce the left indices and right indices
@@ -1508,7 +1621,7 @@ impl HashJoinStream {
 
         self.state = HashJoinStreamState::Completed;
 
-        Ok(StatefulStreamResult::Ready(Some(result?)))
+        Poll::Ready(Ok(StatefulStreamResult::Ready(Some(result?))))
     }
 }
 
@@ -1534,7 +1647,7 @@ mod tests {
     use arrow::array::{Date32Array, Int32Array};
     use arrow::datatypes::{DataType, Field};
     use arrow_array::StructArray;
-    use arrow_buffer::NullBuffer;
+    use arrow_buffer::{MutableBuffer, NullBuffer};
     use datafusion_common::{
         assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err,
         ScalarValue,
@@ -1545,6 +1658,8 @@ mod tests {
     use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
     use datafusion_physical_expr::PhysicalExpr;
 
+    use crate::stream::RecordBatchReceiverStreamBuilder;
+    use datafusion_common_runtime::SpawnedTask;
     use hashbrown::raw::RawTable;
     use rstest::*;
     use rstest_reuse::*;
@@ -1562,6 +1677,16 @@ mod tests {
         Arc::new(TaskContext::default().with_session_config(session_config))
     }
 
+    fn prepare_task_ctx_with_shared_state(
+        batch_size: usize,
+        shared_state: Arc<SharedJoinState>,
+    ) -> Arc<TaskContext> {
+        let session_config = SessionConfig::default()
+            .with_batch_size(batch_size)
+            .with_extension(shared_state);
+        Arc::new(TaskContext::default().with_session_config(session_config))
+    }
+
     fn build_table(
         a: (&str, &Vec<i32>),
         b: (&str, &Vec<i32>),
@@ -1572,6 +1697,11 @@
         Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
     }
 
+    fn build_partitioned_table(partitions: Vec<RecordBatch>) -> Arc<dyn ExecutionPlan> {
+        let schema = partitions.first().expect("no partitions").schema();
+        Arc::new(MemoryExec::try_new(&[partitions], schema, None).unwrap())
+    }
+
     fn join(
         left: Arc<dyn ExecutionPlan>,
         right: Arc<dyn ExecutionPlan>,
@@ -2224,6 +2354,179 @@
         assert_batches_sorted_eq!(expected, &batches);
     }
 
+    struct Coordinator {
+        probe_tx: tokio::sync::mpsc::Sender<(
+            BooleanBuffer,
+            tokio::sync::oneshot::Sender<SharedProbeState>,
+        )>,
+        request: Mutex<Option<OnceFut<SharedProbeState>>>,
+    }
+
+    impl SharedJoinStateImpl for Coordinator {
+        fn num_task_partitions(&self) -> usize {
+            1
+        }
+
+        fn poll_probe_completed(
+            &self,
+            visited_indices_bitmap: &BooleanBufferBuilder,
+            cx: &mut Context<'_>,
+        ) -> Poll<Result<SharedProbeState>> {
+            let mut guard = self.request.lock();
+
+            loop {
+                match guard.deref_mut() {
+                    Some(fut) => {
+                        return match ready!(fut.get(cx)) {
+                            Ok(state) => Poll::Ready(Ok(state.clone())),
+                            Err(e) => Poll::Ready(Err(e)),
+                        }
+                    }
+                    None => {
+                        let (result_tx, result_rx) = tokio::sync::oneshot::channel();
+                        self.probe_tx
+                            .try_send((visited_indices_bitmap.finish_cloned(), result_tx))
+                            .expect("request channel full");
+
+                        let result_fut = async move {
+                            result_rx.await.map_err(|_| {
+                                DataFusionError::Internal("sender dropped".to_string())
+                            })
+                        };
+
+                        *guard = Some(OnceFut::new(result_fut));
+                    }
+                }
+            }
+        }
+    }
+
+    #[apply(batch_sizes)]
+    #[tokio::test]
+    async fn join_left_distributed_multi_batch(batch_size: usize) {
+        let left = build_table(
+            ("a1", &vec![1, 2, 3, 4, 5, 6]),
+            ("b1", &vec![4, 5, 7, 9, 11, 13]), // 9 and 13 do not exist on the right
+            ("c1", &vec![7, 8, 9, 10, 11, 12]),
+        );
+
+        let right = build_partitioned_table(vec![
+            build_table_i32(
+                ("a2", &vec![10, 20, 30]),
+                ("b1", &vec![4, 5, 6]),
+                ("c2", &vec![70, 80, 90]),
+            ),
+            build_table_i32(
+                ("a2", &vec![10, 20, 30]),
+                ("b1", &vec![7, 11, 12]),
+                ("c2", &vec![70, 80, 90]),
+            ),
+            build_table_i32(
+                ("a2", &vec![10, 20, 30]),
+                ("b1", &vec![14, 15, 16]),
+                ("c2", &vec![70, 80, 90]),
+            ),
+        ]);
+
+        let on = vec![(
+            Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) as _,
+        )];
+
+        let join = Arc::new(join(left, right, on, &JoinType::Left, false).unwrap());
+
+        let partitions = join.properties().output_partitioning().partition_count();
+
+        let (probe_tx, mut probe_rx) = tokio::sync::mpsc::channel(partitions);
+
+        let coordinator = Coordinator {
+            probe_tx,
+            request: Default::default(),
+        };
+
+        let _coordinator_task = SpawnedTask::spawn(async move {
+            let mut shared_mask: Option<BooleanBufferBuilder> = None;
+            let mut partitions = partitions;
+            while let Some((mask, response_tx)) = probe_rx.recv().await {
+                partitions -= 1;
+                let state = if partitions == 0 {
+                    SharedProbeState::Ready(
+                        shared_mask
+                            .as_ref()
+                            .map(|b| b.finish_cloned())
+                            .unwrap_or(mask),
+                    )
+                } else {
+                    match shared_mask.as_mut() {
+                        Some(shared_mask) => {
+                            for (b1, b2) in shared_mask
+                                .as_slice_mut()
+                                .iter_mut()
+                                .zip(mask.inner().as_slice())
+                            {
+                                *b1 |= *b2;
+                            }
+                        }
+                        None => {
+                            let inner = MutableBuffer::from_iter(
+                                mask.inner().as_slice().iter().copied(),
+                            );
+                            let shared =
+                                BooleanBufferBuilder::new_from_buffer(inner, mask.len());
+                            shared_mask = Some(shared);
+                        }
+                    }
+                    SharedProbeState::Continue
+                };
+                let _ = response_tx.send(state);
+            }
+        });
+
+        let task_ctx = prepare_task_ctx_with_shared_state(
+            batch_size,
+            Arc::new(SharedJoinState::new(Arc::new(coordinator))),
+        );
+
+        let columns = columns(&join.schema());
+        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
+
+        let mut builder = RecordBatchReceiverStreamBuilder::new(join.schema(), 8);
+
+        for i in 0..partitions {
+            let join = Arc::clone(&join);
+            let tx = builder.tx();
+            let task_ctx = Arc::clone(&task_ctx);
+            builder.spawn(async move {
+                let mut stream = join.execute(i, task_ctx).unwrap();
+
+                while let Some(batch) = stream.next().await {
+                    let _ = tx.send(batch).await;
+                }
+
+                Ok(())
+            });
+        }
+
+        let stream = builder.build();
+
+        let batches = common::collect(stream).await.unwrap();
+
+        let expected = [
+            "+----+----+----+----+----+----+",
+            "| a1 | b1 | c1 | a2 | b1 | c2 |",
+            "+----+----+----+----+----+----+",
+            "| 1  | 4  | 7  | 10 | 4  | 70 |",
+            "| 2  | 5  | 8  | 20 | 5  | 80 |",
+            "| 3  | 7  | 9  | 10 | 7  | 70 |",
+            "| 4  | 9  | 10 |    |    |    |",
+            "| 5  | 11 | 11 | 20 | 11 | 80 |",
+            "| 6  | 13 | 12 |    |    |    |",
+            "+----+----+----+----+----+----+",
+        ];
+
+        assert_batches_sorted_eq!(expected, &batches);
+    }
+
     #[apply(batch_sizes)]
     #[tokio::test]
     async fn join_full_multi_batch(batch_size: usize) {
diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs
index 6ddf19c51193..221f664f0e34 100644
--- a/datafusion/physical-plan/src/joins/mod.rs
+++ b/datafusion/physical-plan/src/joins/mod.rs
@@ -18,7 +18,9 @@
 //! DataFusion Join implementations
 
 pub use cross_join::CrossJoinExec;
-pub use hash_join::HashJoinExec;
+pub use hash_join::{
+    HashJoinExec, SharedJoinState, SharedJoinStateImpl, SharedProbeState,
+};
 pub use nested_loop_join::NestedLoopJoinExec;
 // Note: SortMergeJoin is not used in plans yet
 pub use sort_merge_join::SortMergeJoinExec;
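
For context, here is a minimal single-process sketch of a `SharedJoinStateImpl`. It is not part of this patch: `InProcessCoordinator`, `SharedBitmapState`, and `for_tasks` are hypothetical names, and the sketch assumes one `SharedJoinState` is registered per task, with every handle pointing at the same accumulator. Like the `Coordinator` in the unit test, it relies on `HashJoinStream` ORing its own local bitmap into the returned `Ready` mask in `process_unmatched_build_batch`, so the last task's bits need not be folded in here.

```rust
use std::sync::Arc;
use std::task::{Context, Poll};

use arrow_buffer::{BooleanBufferBuilder, MutableBuffer};
use datafusion_common::Result;
use datafusion_physical_plan::joins::{SharedJoinStateImpl, SharedProbeState};
use parking_lot::Mutex;

/// Accumulator shared by every task participating in the join.
struct SharedBitmapState {
    /// Tasks that have not yet reported their local bitmap
    remaining_tasks: usize,
    /// OR-accumulation of the bitmaps reported so far
    merged: Option<BooleanBufferBuilder>,
}

/// Hypothetical per-task handle; each task wraps its own copy in
/// `SharedJoinState::new(...)` and registers it as a `SessionConfig` extension.
pub struct InProcessCoordinator {
    partitions_in_task: usize,
    shared: Arc<Mutex<SharedBitmapState>>,
}

impl InProcessCoordinator {
    /// Creates one handle per task, all sharing the same accumulator state.
    pub fn for_tasks(tasks: usize, partitions_in_task: usize) -> Vec<Self> {
        let shared = Arc::new(Mutex::new(SharedBitmapState {
            remaining_tasks: tasks,
            merged: None,
        }));
        (0..tasks)
            .map(|_| Self {
                partitions_in_task,
                shared: Arc::clone(&shared),
            })
            .collect()
    }
}

impl SharedJoinStateImpl for InProcessCoordinator {
    fn num_task_partitions(&self) -> usize {
        self.partitions_in_task
    }

    fn poll_probe_completed(
        &self,
        visited_indices_bitmap: &BooleanBufferBuilder,
        _cx: &mut Context<'_>,
    ) -> Poll<Result<SharedProbeState>> {
        // Only the last stream of a task reaches this point (gated by
        // `report_probe_completed`), and this impl never returns `Pending`,
        // so each task reports exactly once.
        let mut state = self.shared.lock();
        state.remaining_tasks -= 1;

        if state.remaining_tasks == 0 {
            // Last task overall: hand back everything merged so far; the stream
            // ORs its own local bitmap into this mask before emitting the
            // unmatched build-side rows.
            let mask = match state.merged.as_ref() {
                Some(acc) => acc.finish_cloned(),
                None => visited_indices_bitmap.finish_cloned(),
            };
            Poll::Ready(Ok(SharedProbeState::Ready(mask)))
        } else {
            // Other tasks are still probing: fold this task's bitmap into the
            // accumulator and finish without emitting any unmatched rows.
            let local = visited_indices_bitmap.finish_cloned();
            match state.merged.as_mut() {
                Some(acc) => {
                    for (a, b) in
                        acc.as_slice_mut().iter_mut().zip(local.inner().as_slice())
                    {
                        *a |= *b;
                    }
                }
                None => {
                    let bytes = MutableBuffer::from_iter(
                        local.inner().as_slice().iter().copied(),
                    );
                    state.merged =
                        Some(BooleanBufferBuilder::new_from_buffer(bytes, local.len()));
                }
            }
            Poll::Ready(Ok(SharedProbeState::Continue))
        }
    }
}
```

A distributed implementation would differ mainly in `poll_probe_completed`: it would return `Poll::Pending` and arrange for `cx.waker()` to be woken while waiting on remote tasks, as the test's channel-based `Coordinator` does via `OnceFut`.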