diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 639fae7615af0..1c83ac47cd395 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -22,8 +22,8 @@ use std::{any::Any, sync::Arc, task::Poll}; use super::utils::{ adjust_right_output_partitioning, reorder_output_after_swap, BatchSplitter, - BatchTransformer, BuildProbeJoinMetrics, NoopBatchTransformer, OnceAsync, OnceFut, - StatefulStreamResult, + BatchTransformer, BuildProbeJoinMetrics, NoopBatchTransformer, SharedResultOnceCell, + SharedResultPending, StatefulStreamResult, }; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::execution_plan::{boundedness_from_children, EmissionType}; @@ -48,7 +48,7 @@ use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::join_equivalence_properties; use async_trait::async_trait; -use futures::{ready, Stream, StreamExt, TryStreamExt}; +use futures::{ready, FutureExt, Stream, StreamExt, TryStreamExt}; /// Data of the left side that is buffered into memory #[derive(Debug)] @@ -72,9 +72,10 @@ struct JoinLeftData { /// /// # Clone / Shared State /// -/// Note this structure includes a [`OnceAsync`] that is used to coordinate the -/// loading of the left side with the processing in each output stream. -/// Therefore it can not be [`Clone`] +/// Note this structure includes a shared [`SharedResultOnceCell`] that is +/// used to coordinate the loading of the left side with the processing in +/// each output stream. +/// Therefore it does not implement [`Clone`]. #[derive(Debug)] pub struct CrossJoinExec { /// left (build) side which gets loaded in memory @@ -83,13 +84,11 @@ pub struct CrossJoinExec { pub right: Arc, /// The schema once the join is applied schema: SchemaRef, - /// Buffered copy of left (build) side in memory. + /// Shared thread-safe OnceCell coordinating the loading of the left (build) side. 
/// - /// This structure is *shared* across all output streams. - /// - /// Each output stream waits on the `OnceAsync` to signal the completion of - /// the left side loading. - left_fut: OnceAsync, + /// All output streams share this cell to ensure the left side data is loaded exactly once, + /// with the first accessing stream handling the initialization. + left_once_cell: Arc>, /// Execution plan metrics metrics: ExecutionPlanMetricsSet, /// Properties such as schema, equivalence properties, ordering, partitioning, etc. @@ -122,7 +121,7 @@ impl CrossJoinExec { left, right, schema, - left_fut: Default::default(), + left_once_cell: Default::default(), metrics: ExecutionPlanMetricsSet::default(), cache, } @@ -303,14 +302,14 @@ impl ExecutionPlan for CrossJoinExec { let enforce_batch_size_in_joins = context.session_config().enforce_batch_size_in_joins(); - let left_fut = self.left_fut.once(|| { - load_left_input( + let left_fut = Arc::clone(&self.left_once_cell) + .get_or_init(load_left_input( Arc::clone(&self.left), context, join_metrics.clone(), reservation, - ) - }); + )) + .boxed(); if enforce_batch_size_in_joins { Ok(Box::pin(CrossJoinStream { @@ -459,7 +458,7 @@ struct CrossJoinStream { /// Input schema schema: Arc, /// Future for data from left side - left_fut: OnceFut, + left_fut: SharedResultPending, /// Right side stream right: SendableRecordBatchStream, /// Current value on the left @@ -568,10 +567,7 @@ impl CrossJoinStream { cx: &mut std::task::Context<'_>, ) -> Poll>>> { let build_timer = self.join_metrics.build_time.timer(); - let left_data = match ready!(self.left_fut.get(cx)) { - Ok(left_data) => left_data, - Err(e) => return Poll::Ready(Err(e)), - }; + let left_data = ready!(self.left_fut.poll_unpin(cx))?; build_timer.done(); let left_data = left_data.merged_batch.clone(); diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 376c3590b88f0..f5fee0ba1c4d5 100644 --- 
a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -26,12 +26,10 @@ use std::{any::Any, vec}; use super::utils::{ asymmetric_join_output_partitioning, get_final_indices_from_shared_bitmap, - reorder_output_after_swap, swap_join_projection, -}; -use super::{ - utils::{OnceAsync, OnceFut}, - PartitionMode, SharedBitmapBuilder, + reorder_output_after_swap, swap_join_projection, SharedResultOnceCell, + SharedResultPending, }; +use super::{PartitionMode, SharedBitmapBuilder}; use crate::execution_plan::{boundedness_from_children, EmissionType}; use crate::projection::{ try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData, @@ -83,7 +81,7 @@ use datafusion_physical_expr_common::datum::compare_op_for_nested; use ahash::RandomState; use datafusion_physical_expr_common::physical_expr::fmt_sql; -use futures::{ready, Stream, StreamExt, TryStreamExt}; +use futures::{ready, FutureExt, Stream, StreamExt, TryStreamExt}; use parking_lot::Mutex; /// HashTable and input data for the left (build side) of a join @@ -314,9 +312,10 @@ impl JoinLeftData { /// /// # Clone / Shared State /// -/// Note this structure includes a [`OnceAsync`] that is used to coordinate the -/// loading of the left side with the processing in each output stream. -/// Therefore it can not be [`Clone`] +/// Note this structure includes a shared [`SharedResultOnceCell`] that is +/// used to coordinate the loading of the left side with the processing in +/// each output stream. +/// Therefore it does not implement [`Clone`]. #[derive(Debug)] pub struct HashJoinExec { /// left (build) side which gets hashed @@ -332,13 +331,12 @@ pub struct HashJoinExec { /// The schema after join. Please be careful when using this schema, /// if there is a projection, the schema isn't the same as the output schema. 
join_schema: SchemaRef, - /// Future that consumes left input and builds the hash table - /// - /// For CollectLeft partition mode, this structure is *shared* across all output streams. + /// Shared OnceCell coordinating the loading of the left (build) side. /// - /// Each output stream waits on the `OnceAsync` to signal the completion of - /// the hash table creation. - left_fut: OnceAsync, + /// For CollectLeft partition mode, all output streams share this cell to + /// ensure the build side is computed exactly once, with the first accessing + /// stream handling the initialization. + left_once_cell: Arc>, /// Shared the `RandomState` for the hashing algorithm random_state: RandomState, /// Partitioning mode to use @@ -409,7 +407,7 @@ impl HashJoinExec { filter, join_type: *join_type, join_schema, - left_fut: Default::default(), + left_once_cell: Default::default(), random_state, mode: partition_mode, metrics: ExecutionPlanMetricsSet::new(), @@ -793,27 +791,41 @@ impl ExecutionPlan for HashJoinExec { let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics); let left_fut = match self.mode { - PartitionMode::CollectLeft => self.left_fut.once(|| { - let reservation = - MemoryConsumer::new("HashJoinInput").register(context.memory_pool()); - collect_left_input( - None, - self.random_state.clone(), - Arc::clone(&self.left), - on_left.clone(), - Arc::clone(&context), - join_metrics.clone(), - reservation, - need_produce_result_in_final(self.join_type), - self.right().output_partitioning().partition_count(), - ) - }), + PartitionMode::CollectLeft => { + let memory_pool = Arc::clone(context.memory_pool()); + let random_state = self.random_state.clone(); + let left = Arc::clone(&self.left); + let context = Arc::clone(&context); + let join_metrics = join_metrics.clone(); + let join_type = self.join_type; + + let fut = Arc::clone(&self.left_once_cell).get_or_init(async move { + // We create the reservation inside this closure to ensure it is only created once 
+ let reservation = + MemoryConsumer::new("HashJoinInput").register(&memory_pool); + + collect_left_input( + None, + random_state, + left, + on_left, + context, + join_metrics, + reservation, + need_produce_result_in_final(join_type), + right_partitions, + ) + .await + }); + + fut.boxed() + } PartitionMode::Partitioned => { let reservation = MemoryConsumer::new(format!("HashJoinInput[{partition}]")) .register(context.memory_pool()); - OnceFut::new(collect_left_input( + collect_left_input( Some(partition), self.random_state.clone(), Arc::clone(&self.left), @@ -823,7 +835,9 @@ impl ExecutionPlan for HashJoinExec { reservation, need_produce_result_in_final(self.join_type), 1, - )) + ) + .map(|res| res.map(Arc::new)) + .boxed() } PartitionMode::Auto => { return plan_err!( @@ -1099,7 +1113,7 @@ enum BuildSide { /// Container for BuildSide::Initial related data struct BuildSideInitialState { /// Future for building hash table from build-side input - left_fut: OnceFut, + left_fut: SharedResultPending, } /// Container for BuildSide::Ready related data @@ -1423,7 +1437,8 @@ impl HashJoinStream { .build_side .try_as_initial_mut()? 
.left_fut - .get_shared(cx))?; + .poll_unpin(cx))?; + build_timer.done(); self.state = HashJoinStreamState::FetchProbeBatch; diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index cdd2eaeca8997..458c004d4a2c7 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -26,7 +26,8 @@ use std::task::Poll; use super::utils::{ asymmetric_join_output_partitioning, get_final_indices_from_shared_bitmap, need_produce_result_in_final, reorder_output_after_swap, swap_join_projection, - BatchSplitter, BatchTransformer, NoopBatchTransformer, StatefulStreamResult, + BatchSplitter, BatchTransformer, NoopBatchTransformer, SharedResultOnceCell, + SharedResultPending, StatefulStreamResult, }; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::common::can_project; @@ -34,7 +35,7 @@ use crate::execution_plan::{boundedness_from_children, EmissionType}; use crate::joins::utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, check_join_is_valid, estimate_join_statistics, - BuildProbeJoinMetrics, ColumnIndex, JoinFilter, OnceAsync, OnceFut, + BuildProbeJoinMetrics, ColumnIndex, JoinFilter, }; use crate::joins::SharedBitmapBuilder; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; @@ -62,7 +63,7 @@ use datafusion_physical_expr::equivalence::{ join_equivalence_properties, ProjectionMapping, }; -use futures::{ready, Stream, StreamExt, TryStreamExt}; +use futures::{ready, FutureExt, Stream, StreamExt, TryStreamExt}; use parking_lot::Mutex; /// Left (build-side) data @@ -148,9 +149,10 @@ impl JoinLeftData { /// /// # Clone / Shared State /// -/// Note this structure includes a [`OnceAsync`] that is used to coordinate the -/// loading of the left side with the processing in each output stream. 
-/// Therefore it can not be [`Clone`] +/// Note this structure includes a shared [`SharedResultOnceCell`] that is +/// used to coordinate the loading of the inner table with the processing in +/// each output stream. +/// Therefore it does not implement [`Clone`]. #[derive(Debug)] pub struct NestedLoopJoinExec { /// left side @@ -163,13 +165,11 @@ pub struct NestedLoopJoinExec { pub(crate) join_type: JoinType, /// The schema once the join is applied join_schema: SchemaRef, - /// Future that consumes left input and buffers it in memory + /// Shared OnceCell coordinating the loading of the inner table /// - /// This structure is *shared* across all output streams. - /// - /// Each output stream waits on the `OnceAsync` to signal the completion of - /// the hash table creation. - inner_table: OnceAsync, + /// All output streams share this cell to ensure the inner table is loaded exactly once, + /// with the first accessing stream handling the initialization. + inner_table: Arc>, /// Information of index and left / right placement of columns column_indices: Vec, /// Projection to apply to the output of the join @@ -490,16 +490,16 @@ impl ExecutionPlan for NestedLoopJoinExec { MemoryConsumer::new(format!("NestedLoopJoinLoad[{partition}]")) .register(context.memory_pool()); - let inner_table = self.inner_table.once(|| { - collect_left_input( + let inner_table = Arc::clone(&self.inner_table) + .get_or_init(collect_left_input( Arc::clone(&self.left), Arc::clone(&context), join_metrics.clone(), load_reservation, need_produce_result_in_final(self.join_type), self.right().output_partitioning().partition_count(), - ) - }); + )) + .boxed(); let batch_size = context.session_config().batch_size(); let enforce_batch_size_in_joins = @@ -708,7 +708,7 @@ struct NestedLoopJoinStream { /// the outer table data of the nested loop join outer_table: SendableRecordBatchStream, /// the inner table data of the nested loop join - inner_table: OnceFut, + inner_table: SharedResultPending, /// 
Information of index and left / right placement of columns column_indices: Vec, // TODO: support null aware equal @@ -833,7 +833,7 @@ impl NestedLoopJoinStream { ) -> Poll>>> { let build_timer = self.join_metrics.build_time.timer(); // build hash table from left (build) side, if not yet done - self.left_data = Some(ready!(self.inner_table.get_shared(cx))?); + self.left_data = Some(ready!(self.inner_table.poll_unpin(cx))?); build_timer.done(); self.state = NestedLoopJoinStreamState::FetchProbeBatch; diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index cffc4b4bff8ee..829415d6fa1a2 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -23,7 +23,6 @@ use std::future::Future; use std::iter::once; use std::ops::{IndexMut, Range}; use std::sync::Arc; -use std::task::{Context, Poll}; use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder}; use crate::{ @@ -56,12 +55,11 @@ use datafusion_physical_expr::{ }; use hashbrown::hash_table::Entry::{Occupied, Vacant}; use hashbrown::HashTable; +use tokio::sync::OnceCell; use crate::joins::SharedBitmapBuilder; use crate::projection::ProjectionExec; -use futures::future::{BoxFuture, Shared}; -use futures::{ready, FutureExt}; -use parking_lot::Mutex; +use futures::future::BoxFuture; /// Maps a `u64` hash value based on the build side ["on" values] to a list of indices with this key's value. /// @@ -649,72 +647,54 @@ pub fn build_join_schema( (fields.finish().with_metadata(metadata), column_indices) } -/// A [`OnceAsync`] runs an `async` closure once, where multiple calls to -/// [`OnceAsync::once`] return a [`OnceFut`] that resolves to the result of the -/// same computation. +/// A [`SharedResultOnceCell`] is a specialized [`OnceCell`] that stores the result of +/// the initialization as a [`SharedResult`] wrapped in [`Arc`]. 
/// -/// This is useful for joins where the results of one child are needed to proceed -/// with multiple output stream +/// It ensures the computation happens exactly once, caching the result for all callers. +/// This type is primarily used for lazy initialization of fallible operations +/// where multiple consumers need access to the same result. /// -/// -/// For example, in a hash join, one input is buffered and shared across -/// potentially multiple output partitions. Each output partition must wait for -/// the hash table to be built before proceeding. -/// -/// Each output partition waits on the same `OnceAsync` before proceeding. -pub(crate) struct OnceAsync { - fut: Mutex>>, -} +/// See [`OnceCell`] for more details. +pub(crate) struct SharedResultOnceCell(OnceCell>>); -impl Default for OnceAsync { +impl Default for SharedResultOnceCell { fn default() -> Self { - Self { - fut: Mutex::new(None), - } + Self(OnceCell::new()) } } -impl Debug for OnceAsync { +impl Debug for SharedResultOnceCell { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "OnceAsync") + write!(f, "SharedResultOnceCell") } } -impl OnceAsync { - /// If this is the first call to this function on this object, will invoke - /// `f` to obtain a future and return a [`OnceFut`] referring to this +impl SharedResultOnceCell { + /// Gets a reference to the contained value, initializing it if necessary. + /// + /// See [`OnceCell::get_or_init`] for more details. 
/// - /// If this is not the first call, will return a [`OnceFut`] referring - /// to the same future as was returned by the first call - pub(crate) fn once(&self, f: F) -> OnceFut + /// # Parameters + /// - `self`: Must be wrapped in an `Arc` to avoid lifetime issues + /// - `f`: Async function that returns `Result` to be cached + /// + /// # Returns + /// - `Ok(Arc)`: Success result wrapped in `Arc` + /// - `Err(DataFusionError::Shared)`: Original error from `f`, appropriately wrapped + pub(crate) async fn get_or_init(self: Arc, f: Fut) -> Result> where - F: FnOnce() -> Fut, Fut: Future> + Send + 'static, { - self.fut - .lock() - .get_or_insert_with(|| OnceFut::new(f())) + self.0 + .get_or_init(|| async { f.await.map(Arc::new).map_err(Arc::new) }) + .await .clone() + .map_err(DataFusionError::Shared) } } -/// The shared future type used internally within [`OnceAsync`] -type OnceFutPending = Shared>>>; - -/// A [`OnceFut`] represents a shared asynchronous computation, that will be evaluated -/// once for all [`Clone`]'s, with [`OnceFut::get`] providing a non-consuming interface -/// to drive the underlying [`Future`] to completion -pub(crate) struct OnceFut { - state: OnceFutState, -} - -impl Clone for OnceFut { - fn clone(&self) -> Self { - Self { - state: self.state.clone(), - } - } -} +/// Type alias for a boxed future that resolves to a shared result. +pub(crate) type SharedResultPending = BoxFuture<'static, Result>>; /// A shared state between statistic aggregators for a join /// operation. 
@@ -1038,69 +1018,6 @@ fn max_distinct_count( } } -enum OnceFutState { - Pending(OnceFutPending), - Ready(SharedResult>), -} - -impl Clone for OnceFutState { - fn clone(&self) -> Self { - match self { - Self::Pending(p) => Self::Pending(p.clone()), - Self::Ready(r) => Self::Ready(r.clone()), - } - } -} - -impl OnceFut { - /// Create a new [`OnceFut`] from a [`Future`] - pub(crate) fn new(fut: Fut) -> Self - where - Fut: Future> + Send + 'static, - { - Self { - state: OnceFutState::Pending( - fut.map(|res| res.map(Arc::new).map_err(Arc::new)) - .boxed() - .shared(), - ), - } - } - - /// Get the result of the computation if it is ready, without consuming it - pub(crate) fn get(&mut self, cx: &mut Context<'_>) -> Poll> { - if let OnceFutState::Pending(fut) = &mut self.state { - let r = ready!(fut.poll_unpin(cx)); - self.state = OnceFutState::Ready(r); - } - - // Cannot use loop as this would trip up the borrow checker - match &self.state { - OnceFutState::Pending(_) => unreachable!(), - OnceFutState::Ready(r) => Poll::Ready( - r.as_ref() - .map(|r| r.as_ref()) - .map_err(DataFusionError::from), - ), - } - } - - /// Get shared reference to the result of the computation if it is ready, without consuming it - pub(crate) fn get_shared(&mut self, cx: &mut Context<'_>) -> Poll>> { - if let OnceFutState::Pending(fut) = &mut self.state { - let r = ready!(fut.poll_unpin(cx)); - self.state = OnceFutState::Ready(r); - } - - match &self.state { - OnceFutState::Pending(_) => unreachable!(), - OnceFutState::Ready(r) => { - Poll::Ready(r.clone().map_err(DataFusionError::Shared)) - } - } - } -} - /// Some type `join_type` of join need to maintain the matched indices bit map for the left side, and /// use the bit map to generate the part of result of the join. 
/// @@ -1824,14 +1741,13 @@ pub(super) fn swap_join_projection( #[cfg(test)] mod tests { use super::*; - use std::pin::Pin; use arrow::array::Int32Array; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Fields}; - use arrow::error::{ArrowError, Result as ArrowResult}; + use datafusion_common::stats::Precision::{Absent, Exact, Inexact}; - use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue}; + use datafusion_common::ScalarValue; use rstest::rstest; @@ -1876,39 +1792,6 @@ mod tests { assert!(check(&left, &right, on).is_err()); } - #[tokio::test] - async fn check_error_nesting() { - let once_fut = OnceFut::<()>::new(async { - arrow_err!(ArrowError::CsvError("some error".to_string())) - }); - - struct TestFut(OnceFut<()>); - impl Future for TestFut { - type Output = ArrowResult<()>; - - fn poll( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll { - match ready!(self.0.get(cx)) { - Ok(()) => Poll::Ready(Ok(())), - Err(e) => Poll::Ready(Err(e.into())), - } - } - } - - let res = TestFut(once_fut).await; - let arrow_err_from_fut = res.expect_err("once_fut always return error"); - - let wrapped_err = DataFusionError::from(arrow_err_from_fut); - let root_err = wrapped_err.find_root(); - - let _expected = - arrow_datafusion_err!(ArrowError::CsvError("some error".to_owned())); - - assert!(matches!(root_err, _expected)) - } - #[test] fn check_not_in_left() { let left = vec![Column::new("b", 0)]; @@ -2822,4 +2705,23 @@ mod tests { assert_eq!(col.name(), name); assert_eq!(col.index(), index); } + + #[tokio::test] + async fn test_shared_result_once_cell() { + let cell1 = Arc::new(SharedResultOnceCell::::default()); + let cell2 = Arc::clone(&cell1); + + let fut1 = cell1.get_or_init(async { Ok(1) }); + let fut2 = cell2.get_or_init(async { Ok(2) }); + + let val1 = fut1.await.unwrap(); + assert_eq!(val1, Arc::new(1)); + + let val2 = fut2.await.unwrap(); + assert_eq!( + val2, + Arc::new(1), + "SharedResultOnceCell should not be 
re-initialized" + ); + } }