Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance Asynchronous Performance of SHJ Implementation #5864

Merged
merged 3 commits into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 109 additions & 105 deletions datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ use arrow::array::{
use arrow::compute::concat_batches;
use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use futures::stream::{select, BoxStream};
use futures::{Stream, StreamExt};
use hashbrown::{raw::RawTable, HashSet};
use parking_lot::Mutex;
Expand Down Expand Up @@ -569,12 +570,24 @@ impl ExecutionPlan for SymmetricHashJoinExec {
let right_side_joiner =
OneSideHashJoiner::new(JoinSide::Right, on_right, self.right.schema());

let left_stream = self.left.execute(partition, context.clone())?;
let right_stream = self.right.execute(partition, context)?;
let left_stream = self
.left
.execute(partition, context.clone())?
.map(|val| (JoinSide::Left, val));

let right_stream = self
.right
.execute(partition, context)?
.map(|val| (JoinSide::Right, val));
// This function will attempt to pull items from both streams.
// Each stream will be polled in a round-robin fashion, and whenever a stream is
// ready to yield an item, that item is yielded.
// After one of the two input streams completes, the remaining one will be polled exclusively.
// The returned stream completes when both input streams have completed.
let input_stream = select(left_stream, right_stream).boxed();

Ok(Box::pin(SymmetricHashJoinStream {
left_stream,
right_stream,
input_stream,
schema: self.schema(),
filter: self.filter.clone(),
join_type: self.join_type,
Expand All @@ -588,17 +601,14 @@ impl ExecutionPlan for SymmetricHashJoinExec {
right_sorted_filter_expr,
null_equals_null: self.null_equals_null,
final_result: false,
probe_side: JoinSide::Left,
}))
}
}

/// A stream that issues [RecordBatch]es as they arrive from either side of the join.
struct SymmetricHashJoinStream {
/// Left stream
left_stream: SendableRecordBatchStream,
/// right stream
right_stream: SendableRecordBatchStream,
/// Input stream
input_stream: BoxStream<'static, (JoinSide, Result<RecordBatch>)>,
/// Input schema
schema: Arc<Schema>,
/// join filter
Expand All @@ -625,8 +635,6 @@ struct SymmetricHashJoinStream {
metrics: SymmetricHashJoinMetrics,
/// Flag indicating whether there is nothing to process anymore
final_result: bool,
/// The current probe side. We choose build and probe side according to this attribute.
probe_side: JoinSide,
}

impl RecordBatchStream for SymmetricHashJoinStream {
Expand All @@ -641,7 +649,7 @@ impl Stream for SymmetricHashJoinStream {
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
) -> Poll<Option<Self::Item>> {
self.poll_next_impl(cx)
}
}
Expand Down Expand Up @@ -1026,8 +1034,6 @@ struct OneSideHashJoiner {
offset: usize,
/// Deleted offset
deleted_offset: usize,
/// Side is exhausted
exhausted: bool,
}

impl OneSideHashJoiner {
Expand All @@ -1042,7 +1048,6 @@ impl OneSideHashJoiner {
visited_rows: HashSet::new(),
offset: 0,
deleted_offset: 0,
exhausted: false,
}
}

Expand Down Expand Up @@ -1341,77 +1346,35 @@ impl SymmetricHashJoinStream {
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Result<RecordBatch>>> {
loop {
// If the final result has already been obtained, return `Poll::Ready(None)`:
if self.final_result {
return Poll::Ready(None);
}
// If both streams have been exhausted, return the final result:
if self.right.exhausted && self.left.exhausted {
// Get left side results:
let left_result = self.left.build_side_determined_results(
&self.schema,
self.left.input_buffer.num_rows(),
self.right.input_buffer.schema(),
self.join_type,
&self.column_indices,
)?;
// Get right side results:
let right_result = self.right.build_side_determined_results(
&self.schema,
self.right.input_buffer.num_rows(),
self.left.input_buffer.schema(),
self.join_type,
&self.column_indices,
)?;
self.final_result = true;
// Combine results:
let result =
combine_two_batches(&self.schema, left_result, right_result)?;
// Update the metrics if we have a batch; otherwise, continue the loop.
if let Some(batch) = &result {
self.metrics.output_batches.add(1);
self.metrics.output_rows.add(batch.num_rows());
return Poll::Ready(Ok(result).transpose());
} else {
continue;
}
}

// Determine which stream should be polled next. The side the
// RecordBatch comes from becomes the probe side.
let (
input_stream,
probe_hash_joiner,
build_hash_joiner,
probe_side_sorted_filter_expr,
build_side_sorted_filter_expr,
build_join_side,
probe_side_metrics,
) = if self.probe_side.eq(&JoinSide::Left) {
(
&mut self.left_stream,
&mut self.left,
&mut self.right,
&mut self.left_sorted_filter_expr,
&mut self.right_sorted_filter_expr,
JoinSide::Right,
&mut self.metrics.left,
)
} else {
(
&mut self.right_stream,
&mut self.right,
&mut self.left,
&mut self.right_sorted_filter_expr,
&mut self.left_sorted_filter_expr,
JoinSide::Left,
&mut self.metrics.right,
)
};
// Poll the next batch from `input_stream`:
match input_stream.poll_next_unpin(cx) {
match self.input_stream.poll_next_unpin(cx) {
// Batch is available
Poll::Ready(Some(Ok(probe_batch))) => {
Poll::Ready(Some((side, Ok(probe_batch)))) => {
// The side the RecordBatch comes from becomes the probe side;
// the opposite side serves as the build side.
let (
probe_hash_joiner,
build_hash_joiner,
probe_side_sorted_filter_expr,
build_side_sorted_filter_expr,
probe_side_metrics,
) = if side.eq(&JoinSide::Left) {
(
&mut self.left,
&mut self.right,
&mut self.left_sorted_filter_expr,
&mut self.right_sorted_filter_expr,
&mut self.metrics.left,
)
} else {
(
&mut self.right,
&mut self.left,
&mut self.right_sorted_filter_expr,
&mut self.left_sorted_filter_expr,
&mut self.metrics.right,
)
};
// Update the metrics for the stream that was polled:
probe_side_metrics.input_batches.add(1);
probe_side_metrics.input_rows.add(probe_batch.num_rows());
Expand Down Expand Up @@ -1475,32 +1438,51 @@ impl SymmetricHashJoinStream {
// Combine results:
let result =
combine_two_batches(&self.schema, equal_result, anti_result)?;
// Choose next poll side. If the other side is not exhausted,
// switch the probe side before returning the result.
if !build_hash_joiner.exhausted {
self.probe_side = build_join_side;
}
// Update the metrics if we have a batch; otherwise, continue the loop.
if let Some(batch) = &result {
self.metrics.output_batches.add(1);
self.metrics.output_rows.add(batch.num_rows());
return Poll::Ready(Ok(result).transpose());
}
}
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
Poll::Ready(Some((_, Err(e)))) => return Poll::Ready(Some(Err(e))),
Poll::Ready(None) => {
// Mark the probe side exhausted:
probe_hash_joiner.exhausted = true;
// Change the probe side:
self.probe_side = build_join_side;
}
Poll::Pending => {
if !build_hash_joiner.exhausted {
self.probe_side = build_join_side;
} else {
return Poll::Pending;
// If the final result has already been obtained, return `Poll::Ready(None)`:
if self.final_result {
return Poll::Ready(None);
}
self.final_result = true;
// Get the left side results:
let left_result = self.left.build_side_determined_results(
&self.schema,
self.left.input_buffer.num_rows(),
self.right.input_buffer.schema(),
self.join_type,
&self.column_indices,
)?;
// Get the right side results:
let right_result = self.right.build_side_determined_results(
&self.schema,
self.right.input_buffer.num_rows(),
self.left.input_buffer.schema(),
self.join_type,
&self.column_indices,
)?;

// Combine the left and right results:
let result =
combine_two_batches(&self.schema, left_result, right_result)?;

// Update the metrics and return the result:
if let Some(batch) = &result {
// Update the metrics:
self.metrics.output_batches.add(1);
self.metrics.output_rows.add(batch.num_rows());

return Poll::Ready(Ok(result).transpose());
}
}
Poll::Pending => return Poll::Pending,
}
}
}
Expand Down Expand Up @@ -1530,7 +1512,7 @@ mod tests {
collect, common, memory::MemoryExec, repartition::RepartitionExec,
};
use crate::prelude::{CsvReadOptions, SessionConfig, SessionContext};
use crate::test_util;
use crate::test_util::register_unbounded_file_with_ordering;

use super::*;

Expand Down Expand Up @@ -2207,18 +2189,40 @@ mod tests {
let tmp_dir = TempDir::new().unwrap();
let left_file_path = tmp_dir.path().join("left.csv");
File::create(left_file_path.clone()).unwrap();
test_util::test_create_unbounded_sorted_file(
// Create schema
let schema = Arc::new(Schema::new(vec![
Field::new("a1", DataType::UInt32, false),
Field::new("a2", DataType::UInt32, false),
]));
// Specify the ordering:
let file_sort_order = Some(
[datafusion_expr::col("a1")]
.into_iter()
.map(|e| {
let ascending = true;
let nulls_first = false;
e.sort(ascending, nulls_first)
})
.collect::<Vec<_>>(),
);
register_unbounded_file_with_ordering(
&ctx,
left_file_path.clone(),
schema.clone(),
&left_file_path,
"left",
file_sort_order.clone(),
true,
)
.await?;
let right_file_path = tmp_dir.path().join("right.csv");
File::create(right_file_path.clone()).unwrap();
test_util::test_create_unbounded_sorted_file(
register_unbounded_file_with_ordering(
&ctx,
right_file_path.clone(),
schema,
&right_file_path,
"right",
file_sort_order,
true,
)
.await?;
let df = ctx.sql("EXPLAIN SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10").await?;
Expand Down
27 changes: 8 additions & 19 deletions datafusion/core/src/test_util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub mod parquet;

use std::any::Any;
use std::collections::HashMap;
use std::path::Path;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{env, error::Error, path::PathBuf, sync::Arc};
Expand Down Expand Up @@ -512,34 +513,22 @@ mod tests {
}

/// This function creates an unbounded sorted file for testing purposes.
pub async fn test_create_unbounded_sorted_file(
pub async fn register_unbounded_file_with_ordering(
ctx: &SessionContext,
file_path: PathBuf,
schema: SchemaRef,
file_path: &Path,
table_name: &str,
file_sort_order: Option<Vec<Expr>>,
with_unbounded_execution: bool,
) -> Result<()> {
// Create schema:
let schema = Arc::new(Schema::new(vec![
Field::new("a1", DataType::UInt32, false),
Field::new("a2", DataType::UInt32, false),
]));
// Specify the ordering:
let file_sort_order = [datafusion_expr::col("a1")]
.into_iter()
.map(|e| {
let ascending = true;
let nulls_first = false;
e.sort(ascending, nulls_first)
})
.collect::<Vec<_>>();
// Mark infinite and provide schema:
let fifo_options = CsvReadOptions::new()
.schema(schema.as_ref())
.has_header(false)
.mark_infinite(true);
.mark_infinite(with_unbounded_execution);
// Get listing options:
let options_sort = fifo_options
.to_listing_options(&ctx.copied_config())
.with_file_sort_order(Some(file_sort_order));
.with_file_sort_order(file_sort_order);
// Register table:
ctx.register_listing_table(
table_name,
Expand Down
Loading