diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
index dafd0bfd4940..c6214e8986b9 100644
--- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
@@ -41,6 +41,7 @@ use arrow::array::{
 use arrow::compute::concat_batches;
 use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
+use futures::stream::{select, BoxStream};
 use futures::{Stream, StreamExt};
 use hashbrown::{raw::RawTable, HashSet};
 use parking_lot::Mutex;
@@ -569,12 +570,24 @@ impl ExecutionPlan for SymmetricHashJoinExec {
         let right_side_joiner =
             OneSideHashJoiner::new(JoinSide::Right, on_right, self.right.schema());
-        let left_stream = self.left.execute(partition, context.clone())?;
-        let right_stream = self.right.execute(partition, context)?;
+        let left_stream = self
+            .left
+            .execute(partition, context.clone())?
+            .map(|val| (JoinSide::Left, val));
+
+        let right_stream = self
+            .right
+            .execute(partition, context)?
+            .map(|val| (JoinSide::Right, val));
+        // `select` will attempt to pull items from both streams. Each stream
+        // will be polled in a round-robin fashion, and whenever a stream is
+        // ready to yield an item, that item is yielded. After one of the two
+        // input streams completes, the remaining one is polled exclusively.
+        // The returned stream completes when both input streams have completed.
+        let input_stream = select(left_stream, right_stream).boxed();
         Ok(Box::pin(SymmetricHashJoinStream {
-            left_stream,
-            right_stream,
+            input_stream,
             schema: self.schema(),
             filter: self.filter.clone(),
             join_type: self.join_type,
@@ -588,17 +601,14 @@ impl ExecutionPlan for SymmetricHashJoinExec {
             right_sorted_filter_expr,
             null_equals_null: self.null_equals_null,
             final_result: false,
-            probe_side: JoinSide::Left,
         }))
     }
 }
 
 /// A stream that issues [RecordBatch]es as they arrive from the right of the join.
 struct SymmetricHashJoinStream {
-    /// Left stream
-    left_stream: SendableRecordBatchStream,
-    /// right stream
-    right_stream: SendableRecordBatchStream,
+    /// Input stream
+    input_stream: BoxStream<'static, (JoinSide, Result<RecordBatch>)>,
     /// Input schema
     schema: Arc<Schema>,
     /// join filter
@@ -625,8 +635,6 @@ struct SymmetricHashJoinStream {
     metrics: SymmetricHashJoinMetrics,
     /// Flag indicating whether there is nothing to process anymore
     final_result: bool,
-    /// The current probe side. We choose build and probe side according to this attribute.
-    probe_side: JoinSide,
 }
 
 impl RecordBatchStream for SymmetricHashJoinStream {
@@ -641,7 +649,7 @@ impl Stream for SymmetricHashJoinStream {
     fn poll_next(
         mut self: std::pin::Pin<&mut Self>,
         cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Option<Self::Item>> {
+    ) -> Poll<Option<Self::Item>> {
         self.poll_next_impl(cx)
     }
 }
@@ -1026,8 +1034,6 @@ struct OneSideHashJoiner {
     offset: usize,
     /// Deleted offset
     deleted_offset: usize,
-    /// Side is exhausted
-    exhausted: bool,
 }
 
 impl OneSideHashJoiner {
@@ -1042,7 +1048,6 @@ impl OneSideHashJoiner {
             visited_rows: HashSet::new(),
             offset: 0,
             deleted_offset: 0,
-            exhausted: false,
         }
     }
@@ -1341,77 +1346,35 @@ impl SymmetricHashJoinStream {
         cx: &mut std::task::Context<'_>,
     ) -> Poll<Option<Result<RecordBatch>>> {
         loop {
-            // If the final result has already been obtained, return `Poll::Ready(None)`:
-            if self.final_result {
-                return Poll::Ready(None);
-            }
-            // If both streams have been exhausted, return the final result:
-            if self.right.exhausted && self.left.exhausted {
-                // Get left side results:
-                let left_result = self.left.build_side_determined_results(
-                    &self.schema,
-                    self.left.input_buffer.num_rows(),
-                    self.right.input_buffer.schema(),
-                    self.join_type,
-                    &self.column_indices,
-                )?;
-                // Get right side results:
-                let right_result = self.right.build_side_determined_results(
-                    &self.schema,
-                    self.right.input_buffer.num_rows(),
-                    self.left.input_buffer.schema(),
-                    self.join_type,
-                    &self.column_indices,
-                )?;
-                self.final_result = true;
-                // Combine results:
-                let result =
-                    combine_two_batches(&self.schema, left_result, right_result)?;
-                // Update the metrics if we have a batch; otherwise, continue the loop.
-                if let Some(batch) = &result {
-                    self.metrics.output_batches.add(1);
-                    self.metrics.output_rows.add(batch.num_rows());
-                    return Poll::Ready(Ok(result).transpose());
-                } else {
-                    continue;
-                }
-            }
-
-            // Determine which stream should be polled next. The side the
-            // RecordBatch comes from becomes the probe side.
-            let (
-                input_stream,
-                probe_hash_joiner,
-                build_hash_joiner,
-                probe_side_sorted_filter_expr,
-                build_side_sorted_filter_expr,
-                build_join_side,
-                probe_side_metrics,
-            ) = if self.probe_side.eq(&JoinSide::Left) {
-                (
-                    &mut self.left_stream,
-                    &mut self.left,
-                    &mut self.right,
-                    &mut self.left_sorted_filter_expr,
-                    &mut self.right_sorted_filter_expr,
-                    JoinSide::Right,
-                    &mut self.metrics.left,
-                )
-            } else {
-                (
-                    &mut self.right_stream,
-                    &mut self.right,
-                    &mut self.left,
-                    &mut self.right_sorted_filter_expr,
-                    &mut self.left_sorted_filter_expr,
-                    JoinSide::Left,
-                    &mut self.metrics.right,
-                )
-            };
             // Poll the next batch from `input_stream`:
-            match input_stream.poll_next_unpin(cx) {
+            match self.input_stream.poll_next_unpin(cx) {
                 // Batch is available
-                Poll::Ready(Some(Ok(probe_batch))) => {
+                Poll::Ready(Some((side, Ok(probe_batch)))) => {
+                    // The side that the incoming batch comes from is the
+                    // probe side; the other side becomes the build side.
+                    let (
+                        probe_hash_joiner,
+                        build_hash_joiner,
+                        probe_side_sorted_filter_expr,
+                        build_side_sorted_filter_expr,
+                        probe_side_metrics,
+                    ) = if side.eq(&JoinSide::Left) {
+                        (
+                            &mut self.left,
+                            &mut self.right,
+                            &mut self.left_sorted_filter_expr,
+                            &mut self.right_sorted_filter_expr,
+                            &mut self.metrics.left,
+                        )
+                    } else {
+                        (
+                            &mut self.right,
+                            &mut self.left,
+                            &mut self.right_sorted_filter_expr,
+                            &mut self.left_sorted_filter_expr,
+                            &mut self.metrics.right,
+                        )
+                    };
                     // Update the metrics for the stream that was polled:
                     probe_side_metrics.input_batches.add(1);
                     probe_side_metrics.input_rows.add(probe_batch.num_rows());
@@ -1475,11 +1438,6 @@ impl SymmetricHashJoinStream {
                     // Combine results:
                     let result =
                         combine_two_batches(&self.schema, equal_result, anti_result)?;
-                    // Choose next poll side. If the other side is not exhausted,
-                    // switch the probe side before returning the result.
-                    if !build_hash_joiner.exhausted {
-                        self.probe_side = build_join_side;
-                    }
                     // Update the metrics if we have a batch; otherwise, continue the loop.
                     if let Some(batch) = &result {
                         self.metrics.output_batches.add(1);
@@ -1487,20 +1445,44 @@ impl SymmetricHashJoinStream {
                         return Poll::Ready(Ok(result).transpose());
                     }
                 }
-                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
+                Poll::Ready(Some((_, Err(e)))) => return Poll::Ready(Some(Err(e))),
                 Poll::Ready(None) => {
-                    // Mark the probe side exhausted:
-                    probe_hash_joiner.exhausted = true;
-                    // Change the probe side:
-                    self.probe_side = build_join_side;
-                }
-                Poll::Pending => {
-                    if !build_hash_joiner.exhausted {
-                        self.probe_side = build_join_side;
-                    } else {
-                        return Poll::Pending;
+                    // If the final result has already been obtained, return `Poll::Ready(None)`:
+                    if self.final_result {
+                        return Poll::Ready(None);
+                    }
+                    self.final_result = true;
+                    // Get the left side results:
+                    let left_result = self.left.build_side_determined_results(
+                        &self.schema,
+                        self.left.input_buffer.num_rows(),
+                        self.right.input_buffer.schema(),
+                        self.join_type,
+                        &self.column_indices,
+                    )?;
+                    // Get the right side results:
+                    let right_result = self.right.build_side_determined_results(
+                        &self.schema,
+                        self.right.input_buffer.num_rows(),
+                        self.left.input_buffer.schema(),
+                        self.join_type,
+                        &self.column_indices,
+                    )?;
+
+                    // Combine the left and right results:
+                    let result =
+                        combine_two_batches(&self.schema, left_result, right_result)?;
+
+                    // Update the metrics and return the result:
+                    if let Some(batch) = &result {
+                        // Update the metrics:
+                        self.metrics.output_batches.add(1);
+                        self.metrics.output_rows.add(batch.num_rows());
+
+                        return Poll::Ready(Ok(result).transpose());
                     }
                 }
+                Poll::Pending => return Poll::Pending,
             }
         }
     }
@@ -1530,7 +1512,7 @@ mod tests {
         collect, common, memory::MemoryExec, repartition::RepartitionExec,
     };
     use crate::prelude::{CsvReadOptions, SessionConfig, SessionContext};
-    use crate::test_util;
+    use crate::test_util::register_unbounded_file_with_ordering;
 
     use super::*;
@@ -2207,18 +2189,40 @@
         let tmp_dir = TempDir::new().unwrap();
         let left_file_path = tmp_dir.path().join("left.csv");
         File::create(left_file_path.clone()).unwrap();
-        test_util::test_create_unbounded_sorted_file(
+        // Create schema
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a1", DataType::UInt32, false),
+            Field::new("a2", DataType::UInt32, false),
+        ]));
+        // Specify the ordering:
+        let file_sort_order = Some(
+            [datafusion_expr::col("a1")]
+                .into_iter()
+                .map(|e| {
+                    let ascending = true;
+                    let nulls_first = false;
+                    e.sort(ascending, nulls_first)
+                })
+                .collect::<Vec<_>>(),
+        );
+        register_unbounded_file_with_ordering(
             &ctx,
-            left_file_path.clone(),
+            schema.clone(),
+            &left_file_path,
             "left",
+            file_sort_order.clone(),
+            true,
         )
         .await?;
         let right_file_path = tmp_dir.path().join("right.csv");
         File::create(right_file_path.clone()).unwrap();
-        test_util::test_create_unbounded_sorted_file(
+        register_unbounded_file_with_ordering(
             &ctx,
-            right_file_path.clone(),
+            schema,
+            &right_file_path,
             "right",
+            file_sort_order,
+            true,
        )
         .await?;
         let df = ctx.sql("EXPLAIN SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10").await?;
diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs
index 982a7a83002e..d42379b82a35 100644
--- a/datafusion/core/src/test_util/mod.rs
+++ b/datafusion/core/src/test_util/mod.rs
@@ -21,6 +21,7 @@ pub mod parquet;
 
 use std::any::Any;
 use std::collections::HashMap;
+use std::path::Path;
 use std::pin::Pin;
 use std::task::{Context, Poll};
 use std::{env, error::Error, path::PathBuf, sync::Arc};
@@ -512,34 +513,22 @@ mod tests {
 }
 
 /// This function creates an unbounded sorted file for testing purposes.
-pub async fn test_create_unbounded_sorted_file(
+pub async fn register_unbounded_file_with_ordering(
     ctx: &SessionContext,
-    file_path: PathBuf,
+    schema: SchemaRef,
+    file_path: &Path,
     table_name: &str,
+    file_sort_order: Option<Vec<Expr>>,
+    with_unbounded_execution: bool,
 ) -> Result<()> {
-    // Create schema:
-    let schema = Arc::new(Schema::new(vec![
-        Field::new("a1", DataType::UInt32, false),
-        Field::new("a2", DataType::UInt32, false),
-    ]));
-    // Specify the ordering:
-    let file_sort_order = [datafusion_expr::col("a1")]
-        .into_iter()
-        .map(|e| {
-            let ascending = true;
-            let nulls_first = false;
-            e.sort(ascending, nulls_first)
-        })
-        .collect::<Vec<_>>();
     // Mark infinite and provide schema:
     let fifo_options = CsvReadOptions::new()
         .schema(schema.as_ref())
-        .has_header(false)
-        .mark_infinite(true);
+        .mark_infinite(with_unbounded_execution);
     // Get listing options:
     let options_sort = fifo_options
         .to_listing_options(&ctx.copied_config())
-        .with_file_sort_order(Some(file_sort_order));
+        .with_file_sort_order(file_sort_order);
     // Register table:
     ctx.register_listing_table(
         table_name,
diff --git a/datafusion/core/tests/fifo.rs b/datafusion/core/tests/fifo.rs
index a3dfe42e21a5..5c12045c4398 100644
--- a/datafusion/core/tests/fifo.rs
+++ b/datafusion/core/tests/fifo.rs
@@ -22,13 +22,12 @@ mod unix_test {
     use arrow::array::Array;
     use arrow::datatypes::{DataType, Field, Schema};
-    use datafusion::execution::options::ReadOptions;
+    use datafusion::test_util::register_unbounded_file_with_ordering;
     use datafusion::{
         prelude::{CsvReadOptions, SessionConfig, SessionContext},
         test_util::{aggr_test_schema, arrow_test_data},
     };
     use datafusion_common::{DataFusionError, Result};
-    use datafusion_expr::Expr;
     use futures::StreamExt;
     use itertools::enumerate;
     use nix::sys::stat;
@@ -36,7 +35,6 @@ mod unix_test {
     use rstest::*;
     use std::fs::{File, OpenOptions};
     use std::io::Write;
-    use std::path::Path;
     use std::path::PathBuf;
     use std::sync::atomic::{AtomicBool, Ordering};
     use std::sync::Arc;
@@ -199,35 +197,6 @@ mod unix_test {
         })
     }
 
-    /// This function creates an unbounded sorted file for testing purposes.
-    pub async fn register_unbounded_file_with_ordering(
-        ctx: &SessionContext,
-        schema: arrow::datatypes::SchemaRef,
-        file_path: &Path,
-        table_name: &str,
-        file_sort_order: Option<Vec<Expr>>,
-        with_unbounded_execution: bool,
-    ) -> Result<()> {
-        // Mark infinite and provide schema:
-        let fifo_options = CsvReadOptions::new()
-            .schema(schema.as_ref())
-            .mark_infinite(with_unbounded_execution);
-        // Get listing options:
-        let options_sort = fifo_options
-            .to_listing_options(&ctx.copied_config())
-            .with_file_sort_order(file_sort_order);
-        // Register table:
-        ctx.register_listing_table(
-            table_name,
-            file_path.as_os_str().to_str().unwrap(),
-            options_sort,
-            Some(schema),
-            None,
-        )
-        .await?;
-        Ok(())
-    }
-
     // This test provides a relatively realistic end-to-end scenario where
     // we change the join into a [SymmetricHashJoin] to accommodate two
     // unbounded (FIFO) sources.
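Note on the core technique: the patch replaces the hand-rolled probe-side switching with `futures::stream::select`, tagging each child stream with its `JoinSide` before merging so the consumer can tell which input produced each item (`select` alone would lose that information). The following standalone sketch is not part of the patch; the `Side` enum and integer payloads are illustrative stand-ins for `JoinSide` and `Result<RecordBatch>`, and it assumes only the `futures` crate:

use futures::executor::block_on;
use futures::stream::{self, select, StreamExt};

#[derive(Debug, Clone, Copy)]
enum Side {
    Left,
    Right,
}

fn main() {
    block_on(async {
        // Tag each input stream so the consumer knows where an item came from.
        let left = stream::iter(vec![1, 2, 3]).map(|batch| (Side::Left, batch));
        let right = stream::iter(vec![10, 20]).map(|batch| (Side::Right, batch));

        // `select` polls both inputs fairly (round-robin), yields items as
        // they become ready, and finishes only after both inputs finish.
        let mut merged = select(left, right);

        while let Some((side, batch)) = merged.next().await {
            // In the join operator, `side` selects the probe/build roles.
            println!("{side:?}: {batch}");
        }
    });
}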