-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix SMJ Left Anti Join when the join filter is set #10724
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -487,7 +487,6 @@ struct StreamedBatch { | |
/// The join key arrays of streamed batch which are used to compare with buffered batches | ||
/// and to produce output. They are produced by evaluating `on` expressions. | ||
pub join_arrays: Vec<ArrayRef>, | ||
|
||
/// Chunks of indices from buffered side (may be nulls) joined to streamed | ||
pub output_indices: Vec<StreamedJoinedChunk>, | ||
/// Index of currently scanned batch from buffered data | ||
|
@@ -1021,6 +1020,15 @@ impl SMJStream { | |
join_streamed = true; | ||
join_buffered = true; | ||
}; | ||
|
||
if matches!(self.join_type, JoinType::LeftAnti) && self.filter.is_some() { | ||
join_streamed = !self | ||
.streamed_batch | ||
.join_filter_matched_idxs | ||
.contains(&(self.streamed_batch.idx as u64)) | ||
&& !self.streamed_joined; | ||
join_buffered = join_streamed; | ||
} | ||
} | ||
Ordering::Greater => { | ||
if matches!(self.join_type, JoinType::Full) { | ||
|
@@ -1181,7 +1189,10 @@ impl SMJStream { | |
let filter_columns = if chunk.buffered_batch_idx.is_some() { | ||
if matches!(self.join_type, JoinType::Right) { | ||
get_filter_column(&self.filter, &buffered_columns, &streamed_columns) | ||
} else if matches!(self.join_type, JoinType::LeftSemi) { | ||
} else if matches!( | ||
self.join_type, | ||
JoinType::LeftSemi | JoinType::LeftAnti | ||
) { | ||
// unwrap is safe here as we check is_some on top of if statement | ||
let buffered_columns = get_buffered_columns( | ||
&self.buffered_data, | ||
|
@@ -1228,7 +1239,15 @@ impl SMJStream { | |
datafusion_common::cast::as_boolean_array(&filter_result)?; | ||
|
||
let maybe_filtered_join_mask: Option<(BooleanArray, Vec<u64>)> = | ||
get_filtered_join_mask(self.join_type, streamed_indices, mask); | ||
get_filtered_join_mask( | ||
self.join_type, | ||
streamed_indices, | ||
mask, | ||
&self.streamed_batch.join_filter_matched_idxs, | ||
&self.buffered_data.scanning_batch_idx, | ||
&self.buffered_data.batches.len(), | ||
); | ||
|
||
if let Some(ref filtered_join_mask) = maybe_filtered_join_mask { | ||
mask = &filtered_join_mask.0; | ||
self.streamed_batch | ||
|
@@ -1419,51 +1438,87 @@ fn get_buffered_columns( | |
.collect::<Result<Vec<_>, ArrowError>>() | ||
} | ||
|
||
// Calculate join filter bit mask considering join type specifics | ||
// `streamed_indices` - array of streamed datasource JOINED row indices | ||
// `mask` - array booleans representing computed join filter expression eval result: | ||
// true = the row index matches the join filter | ||
// false = the row index doesn't match the join filter | ||
// `streamed_indices` have the same length as `mask` | ||
/// Calculate join filter bit mask considering join type specifics | ||
/// `streamed_indices` - array of streamed datasource JOINED row indices | ||
/// `mask` - array booleans representing computed join filter expression eval result: | ||
/// true = the row index matches the join filter | ||
/// false = the row index doesn't match the join filter | ||
/// `streamed_indices` have the same length as `mask` | ||
/// `matched_indices` - set of streamed indices that already have a join filter match | ||
/// `scanning_buffered_batch_idx` - index of the buffered batch currently being scanned | ||
/// `buffered_batches_len` - how many batches are in buffered data | ||
fn get_filtered_join_mask( | ||
join_type: JoinType, | ||
streamed_indices: UInt64Array, | ||
mask: &BooleanArray, | ||
matched_indices: &HashSet<u64>, | ||
scanning_buffered_batch_idx: &usize, | ||
buffered_batches_len: &usize, | ||
) -> Option<(BooleanArray, Vec<u64>)> { | ||
// for LeftSemi Join the filter mask should be calculated in its own way: | ||
// if we find at least one matching row for specific streaming index | ||
// we don't need to check any others for the same index | ||
if matches!(join_type, JoinType::LeftSemi) { | ||
// have we seen a filter match for a streaming index before | ||
let mut seen_as_true: bool = false; | ||
let streamed_indices_length = streamed_indices.len(); | ||
let mut corrected_mask: BooleanBuilder = | ||
BooleanBuilder::with_capacity(streamed_indices_length); | ||
|
||
let mut filter_matched_indices: Vec<u64> = vec![]; | ||
|
||
#[allow(clippy::needless_range_loop)] | ||
for i in 0..streamed_indices_length { | ||
// LeftSemi respects only first true values for specific streaming index, | ||
// others true values for the same index must be false | ||
if mask.value(i) && !seen_as_true { | ||
seen_as_true = true; | ||
corrected_mask.append_value(true); | ||
filter_matched_indices.push(streamed_indices.value(i)); | ||
} else { | ||
corrected_mask.append_value(false); | ||
let mut seen_as_true: bool = false; | ||
let streamed_indices_length = streamed_indices.len(); | ||
let mut corrected_mask: BooleanBuilder = | ||
BooleanBuilder::with_capacity(streamed_indices_length); | ||
|
||
let mut filter_matched_indices: Vec<u64> = vec![]; | ||
|
||
#[allow(clippy::needless_range_loop)] | ||
match join_type { | ||
// for LeftSemi Join the filter mask should be calculated in its own way: | ||
// if we find at least one matching row for specific streaming index | ||
// we don't need to check any others for the same index | ||
JoinType::LeftSemi => { | ||
// have we seen a filter match for a streaming index before | ||
for i in 0..streamed_indices_length { | ||
// LeftSemi respects only first true values for specific streaming index, | ||
// others true values for the same index must be false | ||
if mask.value(i) && !seen_as_true { | ||
seen_as_true = true; | ||
corrected_mask.append_value(true); | ||
filter_matched_indices.push(streamed_indices.value(i)); | ||
} else { | ||
corrected_mask.append_value(false); | ||
} | ||
|
||
// if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag | ||
if i < streamed_indices_length - 1 | ||
&& streamed_indices.value(i) != streamed_indices.value(i + 1) | ||
{ | ||
seen_as_true = false; | ||
} | ||
} | ||
Some((corrected_mask.finish(), filter_matched_indices)) | ||
} | ||
// LeftAnti semantics: return true if for every x in the collection, p(x) is false. | ||
// the true(if any) flag needs to be set only once per streaming index | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "the true(if any) flag" -> "The |
||
// to prevent duplicates in the output | ||
Comment on lines
+1492
to
+1494
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is unclear what |
||
JoinType::LeftAnti => { | ||
// have we seen a filter match for a streaming index before | ||
for i in 0..streamed_indices_length { | ||
if mask.value(i) && !seen_as_true { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I vaguely remember there are subtle semantics with respect to NULLs and Anti joins. It looks like this does the right thing (looks for true values in the mask). However, I wonder if you should also check whether the mask value is null. This may not be necessary if the mask is always non-nullable. I am not sure. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There was a confusion between 0 and nulls, good point, but I'm not sure if it's applicable for booleans; I'll check it. UPD: there is a collision between nulls and default values, so for booleans a null value will be fetched as false, which is okay for this logic as we are seeking the true values, and a null index cannot be fetched as true. |
||
seen_as_true = true; | ||
filter_matched_indices.push(streamed_indices.value(i)); | ||
} | ||
|
||
// if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag | ||
if i < streamed_indices_length - 1 | ||
&& streamed_indices.value(i) != streamed_indices.value(i + 1) | ||
{ | ||
seen_as_true = false; | ||
// if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag | ||
if (i < streamed_indices_length - 1 | ||
&& streamed_indices.value(i) != streamed_indices.value(i + 1)) | ||
|| (i == streamed_indices_length - 1 | ||
&& *scanning_buffered_batch_idx == buffered_batches_len - 1) | ||
Comment on lines
+1506
to
+1507
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The second condition is not in the above comment. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, the second condition is for the last index in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Good point, I'll check that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
{ | ||
corrected_mask.append_value( | ||
!matched_indices.contains(&streamed_indices.value(i)) | ||
&& !seen_as_true, | ||
); | ||
seen_as_true = false; | ||
} else { | ||
corrected_mask.append_value(false); | ||
} | ||
} | ||
|
||
Some((corrected_mask.finish(), filter_matched_indices)) | ||
} | ||
Some((corrected_mask.finish(), filter_matched_indices)) | ||
} else { | ||
None | ||
_ => None, | ||
} | ||
} | ||
|
||
|
@@ -1711,8 +1766,9 @@ mod tests { | |
use arrow::datatypes::{DataType, Field, Schema}; | ||
use arrow::record_batch::RecordBatch; | ||
use arrow_array::{BooleanArray, UInt64Array}; | ||
use hashbrown::HashSet; | ||
|
||
use datafusion_common::JoinType::LeftSemi; | ||
use datafusion_common::JoinType::{LeftAnti, LeftSemi}; | ||
use datafusion_common::{ | ||
assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, | ||
}; | ||
|
@@ -2754,7 +2810,10 @@ mod tests { | |
get_filtered_join_mask( | ||
LeftSemi, | ||
UInt64Array::from(vec![0, 0, 1, 1]), | ||
&BooleanArray::from(vec![true, true, false, false]) | ||
&BooleanArray::from(vec![true, true, false, false]), | ||
&HashSet::new(), | ||
&0, | ||
&0 | ||
), | ||
Some((BooleanArray::from(vec![true, false, false, false]), vec![0])) | ||
); | ||
|
@@ -2763,7 +2822,10 @@ mod tests { | |
get_filtered_join_mask( | ||
LeftSemi, | ||
UInt64Array::from(vec![0, 1]), | ||
&BooleanArray::from(vec![true, true]) | ||
&BooleanArray::from(vec![true, true]), | ||
&HashSet::new(), | ||
&0, | ||
&0 | ||
), | ||
Some((BooleanArray::from(vec![true, true]), vec![0, 1])) | ||
); | ||
|
@@ -2772,7 +2834,10 @@ mod tests { | |
get_filtered_join_mask( | ||
LeftSemi, | ||
UInt64Array::from(vec![0, 1]), | ||
&BooleanArray::from(vec![false, true]) | ||
&BooleanArray::from(vec![false, true]), | ||
&HashSet::new(), | ||
&0, | ||
&0 | ||
), | ||
Some((BooleanArray::from(vec![false, true]), vec![1])) | ||
); | ||
|
@@ -2781,7 +2846,10 @@ mod tests { | |
get_filtered_join_mask( | ||
LeftSemi, | ||
UInt64Array::from(vec![0, 1]), | ||
&BooleanArray::from(vec![true, false]) | ||
&BooleanArray::from(vec![true, false]), | ||
&HashSet::new(), | ||
&0, | ||
&0 | ||
), | ||
Some((BooleanArray::from(vec![true, false]), vec![0])) | ||
); | ||
|
@@ -2790,7 +2858,10 @@ mod tests { | |
get_filtered_join_mask( | ||
LeftSemi, | ||
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), | ||
&BooleanArray::from(vec![false, true, true, true, true, true]) | ||
&BooleanArray::from(vec![false, true, true, true, true, true]), | ||
&HashSet::new(), | ||
&0, | ||
&0 | ||
), | ||
Some(( | ||
BooleanArray::from(vec![false, true, false, true, false, false]), | ||
|
@@ -2802,7 +2873,10 @@ mod tests { | |
get_filtered_join_mask( | ||
LeftSemi, | ||
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), | ||
&BooleanArray::from(vec![false, false, false, false, false, true]) | ||
&BooleanArray::from(vec![false, false, false, false, false, true]), | ||
&HashSet::new(), | ||
&0, | ||
&0 | ||
), | ||
Some(( | ||
BooleanArray::from(vec![false, false, false, false, false, true]), | ||
|
@@ -2813,6 +2887,89 @@ mod tests { | |
Ok(()) | ||
} | ||
|
||
#[tokio::test] | ||
async fn left_anti_join_filtered_mask() -> Result<()> { | ||
assert_eq!( | ||
get_filtered_join_mask( | ||
LeftAnti, | ||
UInt64Array::from(vec![0, 0, 1, 1]), | ||
&BooleanArray::from(vec![true, true, false, false]), | ||
&HashSet::new(), | ||
&0, | ||
&1 | ||
), | ||
Some((BooleanArray::from(vec![false, false, false, true]), vec![0])) | ||
); | ||
|
||
assert_eq!( | ||
get_filtered_join_mask( | ||
LeftAnti, | ||
UInt64Array::from(vec![0, 1]), | ||
&BooleanArray::from(vec![true, true]), | ||
&HashSet::new(), | ||
&0, | ||
&1 | ||
), | ||
Some((BooleanArray::from(vec![false, false]), vec![0, 1])) | ||
); | ||
|
||
assert_eq!( | ||
get_filtered_join_mask( | ||
LeftAnti, | ||
UInt64Array::from(vec![0, 1]), | ||
&BooleanArray::from(vec![false, true]), | ||
&HashSet::new(), | ||
&0, | ||
&1 | ||
), | ||
Some((BooleanArray::from(vec![true, false]), vec![1])) | ||
); | ||
|
||
assert_eq!( | ||
get_filtered_join_mask( | ||
LeftAnti, | ||
UInt64Array::from(vec![0, 1]), | ||
&BooleanArray::from(vec![true, false]), | ||
&HashSet::new(), | ||
&0, | ||
&1 | ||
), | ||
Some((BooleanArray::from(vec![false, true]), vec![0])) | ||
); | ||
|
||
assert_eq!( | ||
get_filtered_join_mask( | ||
LeftAnti, | ||
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), | ||
&BooleanArray::from(vec![false, true, true, true, true, true]), | ||
&HashSet::new(), | ||
&0, | ||
&1 | ||
), | ||
Some(( | ||
BooleanArray::from(vec![false, false, false, false, false, false]), | ||
vec![0, 1] | ||
)) | ||
); | ||
|
||
assert_eq!( | ||
get_filtered_join_mask( | ||
LeftAnti, | ||
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), | ||
&BooleanArray::from(vec![false, false, false, false, false, true]), | ||
&HashSet::new(), | ||
&0, | ||
&1 | ||
), | ||
Some(( | ||
BooleanArray::from(vec![false, false, true, false, false, false]), | ||
vec![1] | ||
)) | ||
); | ||
|
||
Ok(()) | ||
} | ||
|
||
/// Returns the column names on the schema | ||
fn columns(schema: &Schema) -> Vec<String> { | ||
schema.fields().iter().map(|f| f.name().clone()).collect() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if we have to check (in the lines above) for RightAnti and RightSemi? 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a separate ticket for Right* joins; however, I'm not sure how to build a right join at the SQL level, tbh. I will check it once we have Left* stabilized.