-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix: Sort Merge Join LeftSemi issues when JoinFilter is set #10304
Changes from 19 commits
ca564ce
fd21ccf
64d7e5c
8c6010e
f9e1133
ed0035b
4c2c8f3
4052b0d
9da2c45
9c71eef
fe0bb60
c0fd73e
f993b3c
1354f83
c129846
22c61fc
30f28fe
823f396
f0e60da
a06acaa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -30,22 +30,13 @@ use std::pin::Pin; | |||||||
use std::sync::Arc; | ||||||||
use std::task::{Context, Poll}; | ||||||||
|
||||||||
use crate::expressions::PhysicalSortExpr; | ||||||||
use crate::joins::utils::{ | ||||||||
build_join_schema, check_join_is_valid, estimate_join_statistics, | ||||||||
partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, | ||||||||
}; | ||||||||
use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; | ||||||||
use crate::{ | ||||||||
execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, | ||||||||
ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, | ||||||||
RecordBatchStream, SendableRecordBatchStream, Statistics, | ||||||||
}; | ||||||||
|
||||||||
use arrow::array::*; | ||||||||
use arrow::compute::{self, concat_batches, take, SortOptions}; | ||||||||
use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; | ||||||||
use arrow::error::ArrowError; | ||||||||
use futures::{Stream, StreamExt}; | ||||||||
use hashbrown::HashSet; | ||||||||
|
||||||||
use datafusion_common::{ | ||||||||
internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result, | ||||||||
}; | ||||||||
|
@@ -54,7 +45,17 @@ use datafusion_execution::TaskContext; | |||||||
use datafusion_physical_expr::equivalence::join_equivalence_properties; | ||||||||
use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; | ||||||||
|
||||||||
use futures::{Stream, StreamExt}; | ||||||||
use crate::expressions::PhysicalSortExpr; | ||||||||
use crate::joins::utils::{ | ||||||||
build_join_schema, check_join_is_valid, estimate_join_statistics, | ||||||||
partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, | ||||||||
}; | ||||||||
use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; | ||||||||
use crate::{ | ||||||||
execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, | ||||||||
ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, | ||||||||
RecordBatchStream, SendableRecordBatchStream, Statistics, | ||||||||
}; | ||||||||
|
||||||||
/// join execution plan executes partitions in parallel and combines them into a set of | ||||||||
/// partitions. | ||||||||
|
@@ -491,6 +492,10 @@ struct StreamedBatch { | |||||||
pub output_indices: Vec<StreamedJoinedChunk>, | ||||||||
/// Index of currently scanned batch from buffered data | ||||||||
pub buffered_batch_idx: Option<usize>, | ||||||||
/// Indices that found a match for the given join filter | ||||||||
/// Used for semi joins to keep track the streaming index which got a join filter match | ||||||||
/// and already emitted to the output. | ||||||||
pub join_filter_matched_idxs: HashSet<u64>, | ||||||||
} | ||||||||
|
||||||||
impl StreamedBatch { | ||||||||
|
@@ -502,6 +507,7 @@ impl StreamedBatch { | |||||||
join_arrays, | ||||||||
output_indices: vec![], | ||||||||
buffered_batch_idx: None, | ||||||||
join_filter_matched_idxs: HashSet::new(), | ||||||||
} | ||||||||
} | ||||||||
|
||||||||
|
@@ -512,6 +518,7 @@ impl StreamedBatch { | |||||||
join_arrays: vec![], | ||||||||
output_indices: vec![], | ||||||||
buffered_batch_idx: None, | ||||||||
join_filter_matched_idxs: HashSet::new(), | ||||||||
} | ||||||||
} | ||||||||
|
||||||||
|
@@ -989,8 +996,21 @@ impl SMJStream { | |||||||
} | ||||||||
} | ||||||||
Ordering::Equal => { | ||||||||
if matches!(self.join_type, JoinType::LeftSemi) { | ||||||||
if matches!(self.join_type, JoinType::LeftSemi) && self.filter.is_some() { | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can combine this and below which both are for |
||||||||
join_streamed = !self | ||||||||
.streamed_batch | ||||||||
.join_filter_matched_idxs | ||||||||
.contains(&(self.streamed_batch.idx as u64)) | ||||||||
&& !self.streamed_joined; | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If If I'm not sure what this is added to check. I don't see it addresses the issue in https://github.com/apache/datafusion/pull/10304/files#r1601943239. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, I got what you want to do here. It is better to add a comment for later readers. |
||||||||
// if the join filter specified there can be references to buffered columns | ||||||||
// so buffered columns are needed to access them | ||||||||
join_buffered = join_streamed; | ||||||||
} | ||||||||
if matches!(self.join_type, JoinType::LeftSemi) && self.filter.is_none() { | ||||||||
join_streamed = !self.streamed_joined; | ||||||||
// if the join filter specified there can be references to buffered columns | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This block knows |
||||||||
// so buffered columns are needed to access them | ||||||||
join_buffered = self.filter.is_some(); | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LeftSemi doesn't join buffered side, why we want to do this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, I see. As the filter uses buffered columns, we need to access to it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But the comment doesn't look correct as we don't actually join buffered columns. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, but if there are more rows at buffered side are matched on keys, won't it add additional joined pairs with nulls and buffered rows? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I hope I got the concern right and wrapped it into SQL query
it passes, I can add it to slt file as well There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not able to run such test just because of other SMJ issue not related to this PR:
I'll file a separate issue for this one, but perhaps we can go with this PR because potential problem you talking about cannot ever be hit because of the issue above There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for creating the ticket. I think that current fix to the LeftSemi isn't correct due to the additional joined pairs of nulls and buffered rows. I don't think we should move forward with it because of that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm putting the PR to draft until we can check if PR requires modifications to avoid addition join pairs of nulls and buffered rows. It depends on #10491 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have checked the case, it doesn't fail, however it produces more rows than expected, looking into this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yea, what I meant is actually it will produce some rows that are not correct results (due to the additional joined pairs of nulls and buffered rows). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
} | ||||||||
if matches!( | ||||||||
self.join_type, | ||||||||
|
@@ -1134,17 +1154,15 @@ impl SMJStream { | |||||||
.collect::<Result<Vec<_>, ArrowError>>()?; | ||||||||
|
||||||||
let buffered_indices: UInt64Array = chunk.buffered_indices.finish(); | ||||||||
|
||||||||
comphead marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
let mut buffered_columns = | ||||||||
if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { | ||||||||
vec![] | ||||||||
} else if let Some(buffered_idx) = chunk.buffered_batch_idx { | ||||||||
self.buffered_data.batches[buffered_idx] | ||||||||
.batch | ||||||||
.columns() | ||||||||
.iter() | ||||||||
.map(|column| take(column, &buffered_indices, None)) | ||||||||
.collect::<Result<Vec<_>, ArrowError>>()? | ||||||||
get_buffered_columns( | ||||||||
&self.buffered_data, | ||||||||
buffered_idx, | ||||||||
&buffered_indices, | ||||||||
)? | ||||||||
} else { | ||||||||
self.buffered_schema | ||||||||
.fields() | ||||||||
|
@@ -1161,6 +1179,15 @@ impl SMJStream { | |||||||
let filter_columns = if chunk.buffered_batch_idx.is_some() { | ||||||||
if matches!(self.join_type, JoinType::Right) { | ||||||||
get_filter_column(&self.filter, &buffered_columns, &streamed_columns) | ||||||||
} else if matches!(self.join_type, JoinType::LeftSemi) { | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if this should also check for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have more to go, LeftAnti is first as it prevents TPCH to run and then double check RightSemi as well, good point |
||||||||
// unwrap is safe here as we check is_some on top of if statement | ||||||||
let buffered_columns = get_buffered_columns( | ||||||||
&self.buffered_data, | ||||||||
chunk.buffered_batch_idx.unwrap(), | ||||||||
&buffered_indices, | ||||||||
)?; | ||||||||
|
||||||||
get_filter_column(&self.filter, &streamed_columns, &buffered_columns) | ||||||||
} else { | ||||||||
get_filter_column(&self.filter, &streamed_columns, &buffered_columns) | ||||||||
} | ||||||||
|
@@ -1195,7 +1222,17 @@ impl SMJStream { | |||||||
.into_array(filter_batch.num_rows())?; | ||||||||
|
||||||||
// The selection mask of the filter | ||||||||
let mask = datafusion_common::cast::as_boolean_array(&filter_result)?; | ||||||||
let mut mask = | ||||||||
datafusion_common::cast::as_boolean_array(&filter_result)?; | ||||||||
|
||||||||
let maybe_filtered_join_mask: Option<(BooleanArray, Vec<u64>)> = | ||||||||
get_filtered_join_mask(self.join_type, streamed_indices, mask); | ||||||||
if let Some(ref filtered_join_mask) = maybe_filtered_join_mask { | ||||||||
mask = &filtered_join_mask.0; | ||||||||
self.streamed_batch | ||||||||
.join_filter_matched_idxs | ||||||||
.extend(&filtered_join_mask.1); | ||||||||
} | ||||||||
|
||||||||
// Push the filtered batch to the output | ||||||||
let filtered_batch = | ||||||||
|
@@ -1365,6 +1402,69 @@ fn get_filter_column( | |||||||
filter_columns | ||||||||
} | ||||||||
|
||||||||
/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` | ||||||||
#[inline(always)] | ||||||||
fn get_buffered_columns( | ||||||||
buffered_data: &BufferedData, | ||||||||
buffered_batch_idx: usize, | ||||||||
buffered_indices: &UInt64Array, | ||||||||
) -> Result<Vec<ArrayRef>, ArrowError> { | ||||||||
buffered_data.batches[buffered_batch_idx] | ||||||||
.batch | ||||||||
.columns() | ||||||||
.iter() | ||||||||
.map(|column| take(column, &buffered_indices, None)) | ||||||||
.collect::<Result<Vec<_>, ArrowError>>() | ||||||||
} | ||||||||
|
||||||||
// Calculate join filter bit mask considering join type specifics | ||||||||
comphead marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
// `streamed_indices` - array of streamed datasource JOINED row indices | ||||||||
// `mask` - array booleans representing computed join filter expression eval result: | ||||||||
// true = the row index matches the join filter | ||||||||
// false = the row index doesn't match the join filter | ||||||||
// `streamed_indices` have the same length as `mask` | ||||||||
fn get_filtered_join_mask( | ||||||||
join_type: JoinType, | ||||||||
streamed_indices: UInt64Array, | ||||||||
mask: &BooleanArray, | ||||||||
) -> Option<(BooleanArray, Vec<u64>)> { | ||||||||
// for LeftSemi Join the filter mask should be calculated in its own way: | ||||||||
// if we find at least one matching row for specific streaming index | ||||||||
// we don't need to check any others for the same index | ||||||||
if matches!(join_type, JoinType::LeftSemi) { | ||||||||
comphead marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
// have we seen a filter match for a streaming index before | ||||||||
let mut seen_as_true: bool = false; | ||||||||
let streamed_indices_length = streamed_indices.len(); | ||||||||
let mut corrected_mask: BooleanBuilder = | ||||||||
BooleanBuilder::with_capacity(streamed_indices_length); | ||||||||
|
||||||||
let mut filter_matched_indices: Vec<u64> = vec![]; | ||||||||
|
||||||||
#[allow(clippy::needless_range_loop)] | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder why ignore clippy here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Clippy doesn't like for loops anymore .... |
||||||||
for i in 0..streamed_indices_length { | ||||||||
// LeftSemi respects only first true values for specific streaming index, | ||||||||
// others true values for the same index must be false | ||||||||
if mask.value(i) && !seen_as_true { | ||||||||
seen_as_true = true; | ||||||||
corrected_mask.append_value(true); | ||||||||
filter_matched_indices.push(streamed_indices.value(i)); | ||||||||
viirya marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
} else { | ||||||||
corrected_mask.append_value(false); | ||||||||
} | ||||||||
|
||||||||
// if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag | ||||||||
if i < streamed_indices_length - 1 | ||||||||
&& streamed_indices.value(i) != streamed_indices.value(i + 1) | ||||||||
{ | ||||||||
seen_as_true = false; | ||||||||
} | ||||||||
} | ||||||||
Some((corrected_mask.finish(), filter_matched_indices)) | ||||||||
} else { | ||||||||
None | ||||||||
} | ||||||||
} | ||||||||
|
||||||||
/// Buffered data contains all buffered batches with one unique join key | ||||||||
#[derive(Debug, Default)] | ||||||||
struct BufferedData { | ||||||||
|
@@ -1604,24 +1704,28 @@ fn is_join_arrays_equal( | |||||||
mod tests { | ||||||||
use std::sync::Arc; | ||||||||
|
||||||||
use crate::expressions::Column; | ||||||||
use crate::joins::utils::JoinOn; | ||||||||
use crate::joins::SortMergeJoinExec; | ||||||||
use crate::memory::MemoryExec; | ||||||||
use crate::test::build_table_i32; | ||||||||
use crate::{common, ExecutionPlan}; | ||||||||
|
||||||||
use arrow::array::{Date32Array, Date64Array, Int32Array}; | ||||||||
use arrow::compute::SortOptions; | ||||||||
use arrow::datatypes::{DataType, Field, Schema}; | ||||||||
use arrow::record_batch::RecordBatch; | ||||||||
use arrow_array::{BooleanArray, UInt64Array}; | ||||||||
|
||||||||
use datafusion_common::JoinType::LeftSemi; | ||||||||
use datafusion_common::{ | ||||||||
assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, | ||||||||
}; | ||||||||
use datafusion_execution::config::SessionConfig; | ||||||||
use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; | ||||||||
use datafusion_execution::TaskContext; | ||||||||
|
||||||||
use crate::expressions::Column; | ||||||||
use crate::joins::sort_merge_join::get_filtered_join_mask; | ||||||||
use crate::joins::utils::JoinOn; | ||||||||
use crate::joins::SortMergeJoinExec; | ||||||||
use crate::memory::MemoryExec; | ||||||||
use crate::test::build_table_i32; | ||||||||
use crate::{common, ExecutionPlan}; | ||||||||
|
||||||||
fn build_table( | ||||||||
a: (&str, &Vec<i32>), | ||||||||
b: (&str, &Vec<i32>), | ||||||||
|
@@ -2641,6 +2745,72 @@ mod tests { | |||||||
|
||||||||
Ok(()) | ||||||||
} | ||||||||
|
||||||||
#[tokio::test] | ||||||||
async fn left_semi_join_filtered_mask() -> Result<()> { | ||||||||
assert_eq!( | ||||||||
get_filtered_join_mask( | ||||||||
LeftSemi, | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe we should test a type other than LeftSemi as negative test coverage 🤔 |
||||||||
UInt64Array::from(vec![0, 0, 1, 1]), | ||||||||
&BooleanArray::from(vec![true, true, false, false]) | ||||||||
), | ||||||||
Some((BooleanArray::from(vec![true, false, false, false]), vec![0])) | ||||||||
); | ||||||||
|
||||||||
assert_eq!( | ||||||||
get_filtered_join_mask( | ||||||||
LeftSemi, | ||||||||
UInt64Array::from(vec![0, 1]), | ||||||||
&BooleanArray::from(vec![true, true]) | ||||||||
), | ||||||||
Some((BooleanArray::from(vec![true, true]), vec![0, 1])) | ||||||||
); | ||||||||
|
||||||||
assert_eq!( | ||||||||
get_filtered_join_mask( | ||||||||
LeftSemi, | ||||||||
UInt64Array::from(vec![0, 1]), | ||||||||
&BooleanArray::from(vec![false, true]) | ||||||||
), | ||||||||
Some((BooleanArray::from(vec![false, true]), vec![1])) | ||||||||
); | ||||||||
|
||||||||
assert_eq!( | ||||||||
get_filtered_join_mask( | ||||||||
LeftSemi, | ||||||||
UInt64Array::from(vec![0, 1]), | ||||||||
&BooleanArray::from(vec![true, false]) | ||||||||
), | ||||||||
Some((BooleanArray::from(vec![true, false]), vec![0])) | ||||||||
); | ||||||||
|
||||||||
assert_eq!( | ||||||||
get_filtered_join_mask( | ||||||||
LeftSemi, | ||||||||
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), | ||||||||
&BooleanArray::from(vec![false, true, true, true, true, true]) | ||||||||
), | ||||||||
Some(( | ||||||||
BooleanArray::from(vec![false, true, false, true, false, false]), | ||||||||
vec![0, 1] | ||||||||
)) | ||||||||
); | ||||||||
|
||||||||
assert_eq!( | ||||||||
get_filtered_join_mask( | ||||||||
LeftSemi, | ||||||||
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), | ||||||||
&BooleanArray::from(vec![false, false, false, false, false, true]) | ||||||||
), | ||||||||
Some(( | ||||||||
BooleanArray::from(vec![false, false, false, false, false, true]), | ||||||||
vec![1] | ||||||||
)) | ||||||||
); | ||||||||
|
||||||||
Ok(()) | ||||||||
} | ||||||||
|
||||||||
/// Returns the column names on the schema | ||||||||
fn columns(schema: &Schema) -> Vec<String> { | ||||||||
schema.fields().iter().map(|f| f.name().clone()).collect() | ||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this change is by formatter