Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SMJ: Add more tests and improve comments #10784

Merged
merged 4 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 9 additions & 22 deletions datafusion/physical-plan/src/joins/sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1244,8 +1244,7 @@ impl SMJStream {
streamed_indices,
mask,
&self.streamed_batch.join_filter_matched_idxs,
&self.buffered_data.scanning_batch_idx,
&self.buffered_data.batches.len(),
&self.buffered_data.scanning_offset,
);

if let Some(ref filtered_join_mask) = maybe_filtered_join_mask {
Expand Down Expand Up @@ -1445,15 +1444,13 @@ fn get_buffered_columns(
/// false = the row index doesn't match the join filter
/// `streamed_indices` have the same length as `mask`
/// `matched_indices` array of streaming indices that already has a join filter match
/// `scanning_batch_idx` current buffered batch
/// `buffered_batches_len` how many batches are in buffered data
/// `scanning_buffered_offset` current buffered offset across batches
fn get_filtered_join_mask(
join_type: JoinType,
streamed_indices: UInt64Array,
mask: &BooleanArray,
matched_indices: &HashSet<u64>,
scanning_buffered_batch_idx: &usize,
buffered_batches_len: &usize,
scanning_buffered_offset: &usize,
) -> Option<(BooleanArray, Vec<u64>)> {
let mut seen_as_true: bool = false;
let streamed_indices_length = streamed_indices.len();
Expand Down Expand Up @@ -1489,8 +1486,8 @@ fn get_filtered_join_mask(
}
Some((corrected_mask.finish(), filter_matched_indices))
}
// LeftAnti semantics: return true if for every x in the collection, p(x) is false.
// the true(if any) flag needs to be set only once per streaming index
// LeftAnti semantics: return true if for every x in the collection the join matching filter is false.
// `filter_matched_indices` needs to be set once per streaming index
// to prevent duplicates in the output
JoinType::LeftAnti => {
// have we seen a filter match for a streaming index before
Expand All @@ -1500,11 +1497,13 @@ fn get_filtered_join_mask(
filter_matched_indices.push(streamed_indices.value(i));
}

// if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag
// Reset `seen_as_true` flag and calculate mask for the current streaming index
// - if within the batch it switched to next streaming index(e.g. from 0 to 1, or from 1 to 2)
// - if it is at the end of all buffered batches for the given streaming index, 0 index comes last
if (i < streamed_indices_length - 1
&& streamed_indices.value(i) != streamed_indices.value(i + 1))
|| (i == streamed_indices_length - 1
&& *scanning_buffered_batch_idx == buffered_batches_len - 1)
&& *scanning_buffered_offset == 0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, why do you use `scanning_buffered_offset == 0` as the condition? It is set to zero when moving to the next buffered batch, or when all buffered batches have been scanned for the current streamed batch. It doesn't say anything about the position of the streamed indices within the streamed batch.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @viirya for the review. That is exactly the condition we need to check: that all buffered batches have been scanned for the current streamed batch. This is because LeftAnti doesn't know whether a row matches or not until the very last buffered row comes in. This scenario is already tested in the slt file:

query II
select * from (
with
t1 as (
    select 11 a, 12 b),
t2 as (
    select 11 a, 12 c union all
    select 11 a, 11 c union all
    select 11 a, 15 c
    )
select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c)
) order by 1, 2
----

It works for both small and large batch sizes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We exactly need to check the condition where all buffered batches are scanned for current streamed batch.

This is correct, but what I meant is that the `scanning_buffered_offset == 0` condition could also be true when it moves to the next buffered batch and the output size is equal to the batch size. When that happens, SMJ also goes on to output batches, but obviously not all buffered batches have been scanned for the current row in the streamed batch.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll check that

{
corrected_mask.append_value(
!matched_indices.contains(&streamed_indices.value(i))
Expand Down Expand Up @@ -2813,7 +2812,6 @@ mod tests {
&BooleanArray::from(vec![true, true, false, false]),
&HashSet::new(),
&0,
&0
),
Some((BooleanArray::from(vec![true, false, false, false]), vec![0]))
);
Expand All @@ -2825,7 +2823,6 @@ mod tests {
&BooleanArray::from(vec![true, true]),
&HashSet::new(),
&0,
&0
),
Some((BooleanArray::from(vec![true, true]), vec![0, 1]))
);
Expand All @@ -2837,7 +2834,6 @@ mod tests {
&BooleanArray::from(vec![false, true]),
&HashSet::new(),
&0,
&0
),
Some((BooleanArray::from(vec![false, true]), vec![1]))
);
Expand All @@ -2849,7 +2845,6 @@ mod tests {
&BooleanArray::from(vec![true, false]),
&HashSet::new(),
&0,
&0
),
Some((BooleanArray::from(vec![true, false]), vec![0]))
);
Expand All @@ -2861,7 +2856,6 @@ mod tests {
&BooleanArray::from(vec![false, true, true, true, true, true]),
&HashSet::new(),
&0,
&0
),
Some((
BooleanArray::from(vec![false, true, false, true, false, false]),
Expand All @@ -2876,7 +2870,6 @@ mod tests {
&BooleanArray::from(vec![false, false, false, false, false, true]),
&HashSet::new(),
&0,
&0
),
Some((
BooleanArray::from(vec![false, false, false, false, false, true]),
Expand All @@ -2896,7 +2889,6 @@ mod tests {
&BooleanArray::from(vec![true, true, false, false]),
&HashSet::new(),
&0,
&1
),
Some((BooleanArray::from(vec![false, false, false, true]), vec![0]))
);
Expand All @@ -2908,7 +2900,6 @@ mod tests {
&BooleanArray::from(vec![true, true]),
&HashSet::new(),
&0,
&1
),
Some((BooleanArray::from(vec![false, false]), vec![0, 1]))
);
Expand All @@ -2920,7 +2911,6 @@ mod tests {
&BooleanArray::from(vec![false, true]),
&HashSet::new(),
&0,
&1
),
Some((BooleanArray::from(vec![true, false]), vec![1]))
);
Expand All @@ -2932,7 +2922,6 @@ mod tests {
&BooleanArray::from(vec![true, false]),
&HashSet::new(),
&0,
&1
),
Some((BooleanArray::from(vec![false, true]), vec![0]))
);
Expand All @@ -2944,7 +2933,6 @@ mod tests {
&BooleanArray::from(vec![false, true, true, true, true, true]),
&HashSet::new(),
&0,
&1
),
Some((
BooleanArray::from(vec![false, false, false, false, false, false]),
Expand All @@ -2959,7 +2947,6 @@ mod tests {
&BooleanArray::from(vec![false, false, false, false, false, true]),
&HashSet::new(),
&0,
&1
),
Some((
BooleanArray::from(vec![false, false, true, false, false, false]),
Expand Down
101 changes: 101 additions & 0 deletions datafusion/sqllogictest/test_files/sort_merge_join.slt
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,64 @@ select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.
11 12 1
11 13 2

query II
select * from (
with
t1 as (
select 11 a, 12 b),
t2 as (
select 11 a, 13 c union all
select 11 a, 14 c union all
select 11 a, 15 c
)
select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c)
) order by 1, 2
----
11 12

query II
select * from (
with
t1 as (
select 11 a, 12 b),
t2 as (
select 11 a, 11 c union all
select 11 a, 14 c union all
select 11 a, 15 c
)
select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c)
) order by 1, 2
----

query II
select * from (
with
t1 as (
select 11 a, 12 b),
t2 as (
select 11 a, 12 c union all
select 11 a, 11 c union all
select 11 a, 15 c
)
select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c)
) order by 1, 2
----

query II
select * from (
with
t1 as (
select 11 a, 12 b),
t2 as (
select 11 a, 12 c union all
select 11 a, 14 c union all
select 11 a, 11 c
)
select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c)
) order by 1, 2
----


# Test LEFT ANTI with cross batch data distribution
statement ok
set datafusion.execution.batch_size = 1;
Expand Down Expand Up @@ -512,6 +570,49 @@ select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.
11 12 1
11 13 2

query II
select * from (
with
t1 as (
select 11 a, 12 b),
t2 as (
select 11 a, 13 c union all
select 11 a, 14 c union all
select 11 a, 15 c
)
select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c)
) order by 1, 2
----
11 12

query II
select * from (
with
t1 as (
select 11 a, 12 b),
t2 as (
select 11 a, 12 c union all
select 11 a, 11 c union all
select 11 a, 15 c
)
select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c)
) order by 1, 2
----

query II
select * from (
with
t1 as (
select 11 a, 12 b),
t2 as (
select 11 a, 12 c union all
select 11 a, 14 c union all
select 11 a, 11 c
)
select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c)
) order by 1, 2
----

# return sql params back to default values
statement ok
set datafusion.optimizer.prefer_hash_join = true;
Expand Down