diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index ec83fe3f2af8..143a726d31b1 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -487,7 +487,6 @@ struct StreamedBatch { /// The join key arrays of streamed batch which are used to compare with buffered batches /// and to produce output. They are produced by evaluating `on` expressions. pub join_arrays: Vec, - /// Chunks of indices from buffered side (may be nulls) joined to streamed pub output_indices: Vec, /// Index of currently scanned batch from buffered data @@ -1021,6 +1020,15 @@ impl SMJStream { join_streamed = true; join_buffered = true; }; + + if matches!(self.join_type, JoinType::LeftAnti) && self.filter.is_some() { + join_streamed = !self + .streamed_batch + .join_filter_matched_idxs + .contains(&(self.streamed_batch.idx as u64)) + && !self.streamed_joined; + join_buffered = join_streamed; + } } Ordering::Greater => { if matches!(self.join_type, JoinType::Full) { @@ -1181,7 +1189,10 @@ impl SMJStream { let filter_columns = if chunk.buffered_batch_idx.is_some() { if matches!(self.join_type, JoinType::Right) { get_filter_column(&self.filter, &buffered_columns, &streamed_columns) - } else if matches!(self.join_type, JoinType::LeftSemi) { + } else if matches!( + self.join_type, + JoinType::LeftSemi | JoinType::LeftAnti + ) { // unwrap is safe here as we check is_some on top of if statement let buffered_columns = get_buffered_columns( &self.buffered_data, @@ -1228,7 +1239,15 @@ impl SMJStream { datafusion_common::cast::as_boolean_array(&filter_result)?; let maybe_filtered_join_mask: Option<(BooleanArray, Vec)> = - get_filtered_join_mask(self.join_type, streamed_indices, mask); + get_filtered_join_mask( + self.join_type, + streamed_indices, + mask, + &self.streamed_batch.join_filter_matched_idxs, + &self.buffered_data.scanning_batch_idx, + &self.buffered_data.batches.len(), + ); + if let Some(ref filtered_join_mask) = maybe_filtered_join_mask { mask = &filtered_join_mask.0; self.streamed_batch @@ -1419,51 +1438,87 @@ fn get_buffered_columns( .collect::, ArrowError>>() } -// Calculate join filter bit mask considering join type specifics -// `streamed_indices` - array of streamed datasource JOINED row indices -// `mask` - array booleans representing computed join filter expression eval result: -// true = the row index matches the join filter -// false = the row index doesn't match the join filter -// `streamed_indices` have the same length as `mask` +/// Calculate join filter bit mask considering join type specifics +/// `streamed_indices` - array of streamed datasource JOINED row indices +/// `mask` - array booleans representing computed join filter expression eval result: +/// true = the row index matches the join filter +/// false = the row index doesn't match the join filter +/// `streamed_indices` have the same length as `mask` +/// `matched_indices` array of streaming indices that already has a join filter match +/// `scanning_batch_idx` current buffered batch +/// `buffered_batches_len` how many batches are in buffered data fn get_filtered_join_mask( join_type: JoinType, streamed_indices: UInt64Array, mask: &BooleanArray, + matched_indices: &HashSet, + scanning_buffered_batch_idx: &usize, + buffered_batches_len: &usize, ) -> Option<(BooleanArray, Vec)> { - // for LeftSemi Join the filter mask should be calculated in its own way: - // if we find at least one matching row for specific streaming index - // we don't need to check any others for the same index - if matches!(join_type, JoinType::LeftSemi) { - // have we seen a filter match for a streaming index before - let mut seen_as_true: bool = false; - let streamed_indices_length = streamed_indices.len(); - let mut corrected_mask: BooleanBuilder = - BooleanBuilder::with_capacity(streamed_indices_length); - - let mut filter_matched_indices: Vec = vec![]; - - #[allow(clippy::needless_range_loop)] - for i in 0..streamed_indices_length { - // LeftSemi respects only first true values for specific streaming index, - // others true values for the same index must be false - if mask.value(i) && !seen_as_true { - seen_as_true = true; - corrected_mask.append_value(true); - filter_matched_indices.push(streamed_indices.value(i)); - } else { - corrected_mask.append_value(false); + let mut seen_as_true: bool = false; + let streamed_indices_length = streamed_indices.len(); + let mut corrected_mask: BooleanBuilder = + BooleanBuilder::with_capacity(streamed_indices_length); + + let mut filter_matched_indices: Vec = vec![]; + + #[allow(clippy::needless_range_loop)] + match join_type { + // for LeftSemi Join the filter mask should be calculated in its own way: + // if we find at least one matching row for specific streaming index + // we don't need to check any others for the same index + JoinType::LeftSemi => { + // have we seen a filter match for a streaming index before + for i in 0..streamed_indices_length { + // LeftSemi respects only first true values for specific streaming index, + // others true values for the same index must be false + if mask.value(i) && !seen_as_true { + seen_as_true = true; + corrected_mask.append_value(true); + filter_matched_indices.push(streamed_indices.value(i)); + } else { + corrected_mask.append_value(false); + } + + // if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag + if i < streamed_indices_length - 1 + && streamed_indices.value(i) != streamed_indices.value(i + 1) + { + seen_as_true = false; + } } + Some((corrected_mask.finish(), filter_matched_indices)) + } + // LeftAnti semantics: return true if for every x in the collection, p(x) is false. + // the true(if any) flag needs to be set only once per streaming index + // to prevent duplicates in the output + JoinType::LeftAnti => { + // have we seen a filter match for a streaming index before + for i in 0..streamed_indices_length { + if mask.value(i) && !seen_as_true { + seen_as_true = true; + filter_matched_indices.push(streamed_indices.value(i)); + } - // if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag - if i < streamed_indices_length - 1 - && streamed_indices.value(i) != streamed_indices.value(i + 1) - { - seen_as_true = false; + // if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag + if (i < streamed_indices_length - 1 + && streamed_indices.value(i) != streamed_indices.value(i + 1)) + || (i == streamed_indices_length - 1 + && *scanning_buffered_batch_idx == buffered_batches_len - 1) + { + corrected_mask.append_value( + !matched_indices.contains(&streamed_indices.value(i)) + && !seen_as_true, + ); + seen_as_true = false; + } else { + corrected_mask.append_value(false); + } } + + Some((corrected_mask.finish(), filter_matched_indices)) } - Some((corrected_mask.finish(), filter_matched_indices)) - } else { - None + _ => None, } } @@ -1711,8 +1766,9 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use arrow_array::{BooleanArray, UInt64Array}; + use hashbrown::HashSet; - use datafusion_common::JoinType::LeftSemi; + use datafusion_common::JoinType::{LeftAnti, LeftSemi}; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, }; @@ -2754,7 +2810,10 @@ mod tests { get_filtered_join_mask( LeftSemi, UInt64Array::from(vec![0, 0, 1, 1]), - &BooleanArray::from(vec![true, true, false, false]) + &BooleanArray::from(vec![true, true, false, false]), + &HashSet::new(), + &0, + &0 ), Some((BooleanArray::from(vec![true, false, false, false]), vec![0])) ); @@ -2763,7 +2822,10 @@ mod tests { get_filtered_join_mask( LeftSemi, UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![true, true]) + &BooleanArray::from(vec![true, true]), + &HashSet::new(), + &0, + &0 ), Some((BooleanArray::from(vec![true, true]), vec![0, 1])) ); @@ -2772,7 +2834,10 @@ mod tests { get_filtered_join_mask( LeftSemi, UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![false, true]) + &BooleanArray::from(vec![false, true]), + &HashSet::new(), + &0, + &0 ), Some((BooleanArray::from(vec![false, true]), vec![1])) ); @@ -2781,7 +2846,10 @@ mod tests { get_filtered_join_mask( LeftSemi, UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![true, false]) + &BooleanArray::from(vec![true, false]), + &HashSet::new(), + &0, + &0 ), Some((BooleanArray::from(vec![true, false]), vec![0])) ); @@ -2790,7 +2858,10 @@ mod tests { get_filtered_join_mask( LeftSemi, UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, true, true, true, true, true]) + &BooleanArray::from(vec![false, true, true, true, true, true]), + &HashSet::new(), + &0, + &0 ), Some(( BooleanArray::from(vec![false, true, false, true, false, false]), @@ -2802,7 +2873,10 @@ mod tests { get_filtered_join_mask( LeftSemi, UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, false, false, false, false, true]) + &BooleanArray::from(vec![false, false, false, false, false, true]), + &HashSet::new(), + &0, + &0 ), Some(( BooleanArray::from(vec![false, false, false, false, false, true]), @@ -2813,6 +2887,89 @@ mod tests { Ok(()) } + #[tokio::test] + async fn left_anti_join_filtered_mask() -> Result<()> { + assert_eq!( + get_filtered_join_mask( + LeftAnti, + UInt64Array::from(vec![0, 0, 1, 1]), + &BooleanArray::from(vec![true, true, false, false]), + &HashSet::new(), + &0, + &1 + ), + Some((BooleanArray::from(vec![false, false, false, true]), vec![0])) + ); + + assert_eq!( + get_filtered_join_mask( + LeftAnti, + UInt64Array::from(vec![0, 1]), + &BooleanArray::from(vec![true, true]), + &HashSet::new(), + &0, + &1 + ), + Some((BooleanArray::from(vec![false, false]), vec![0, 1])) + ); + + assert_eq!( + get_filtered_join_mask( + LeftAnti, + UInt64Array::from(vec![0, 1]), + &BooleanArray::from(vec![false, true]), + &HashSet::new(), + &0, + &1 + ), + Some((BooleanArray::from(vec![true, false]), vec![1])) + ); + + assert_eq!( + get_filtered_join_mask( + LeftAnti, + UInt64Array::from(vec![0, 1]), + &BooleanArray::from(vec![true, false]), + &HashSet::new(), + &0, + &1 + ), + Some((BooleanArray::from(vec![false, true]), vec![0])) + ); + + assert_eq!( + get_filtered_join_mask( + LeftAnti, + UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), + &BooleanArray::from(vec![false, true, true, true, true, true]), + &HashSet::new(), + &0, + &1 + ), + Some(( + BooleanArray::from(vec![false, false, false, false, false, false]), + vec![0, 1] + )) + ); + + assert_eq!( + get_filtered_join_mask( + LeftAnti, + UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), + &BooleanArray::from(vec![false, false, false, false, false, true]), + &HashSet::new(), + &0, + &1 + ), + Some(( + BooleanArray::from(vec![false, false, true, false, false, false]), + vec![1] + )) + ); + + Ok(()) + } + /// Returns the column names on the schema fn columns(schema: &Schema) -> Vec { schema.fields().iter().map(|f| f.name().clone()).collect() diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index 3a27d9693d00..babb7dc8fd6b 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -378,24 +378,6 @@ select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != 11 12 11 13 -#LEFTANTI tests -# returns no rows instead of correct result -#query III -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b, 1 c union all -# select 11 a, 13 b, 2 c), -#t2 as ( -# select 11 a, 12 b, 3 c union all -# select 11 a, 14 b, 4 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) -#) order by 1, 2; -#---- -#11 12 1 -#11 13 2 - # Set batch size to 1 for sort merge join to test scenario when data spread across multiple batches statement ok set datafusion.execution.batch_size = 1; @@ -431,5 +413,108 @@ select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != 11 12 11 13 +#LEFTANTI tests +statement ok +set datafusion.execution.batch_size = 10; + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 13 c union all + select 11 a, 14 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- +11 12 + +query III +select * from ( +with +t1 as ( + select 11 a, 12 b, 1 c union all + select 11 a, 13 b, 2 c), +t2 as ( + select 11 a, 12 b, 3 c union all + select 11 a, 14 b, 4 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +) order by 1, 2; +---- +11 12 1 +11 13 2 + +query III +select * from ( +with +t1 as ( + select 11 a, 12 b, 1 c union all + select 11 a, 13 b, 2 c), +t2 as ( + select 11 a, 12 b, 3 c where false + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +) order by 1, 2; +---- +11 12 1 +11 13 2 + +# Test LEFT ANTI with cross batch data distribution +statement ok +set datafusion.execution.batch_size = 1; + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 13 c union all + select 11 a, 14 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- +11 12 + +query III +select * from ( +with +t1 as ( + select 11 a, 12 b, 1 c union all + select 11 a, 13 b, 2 c), +t2 as ( + select 11 a, 12 b, 3 c union all + select 11 a, 14 b, 4 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +) order by 1, 2; +---- +11 12 1 +11 13 2 + +query III +select * from ( +with +t1 as ( + select 11 a, 12 b, 1 c union all + select 11 a, 13 b, 2 c), +t2 as ( + select 11 a, 12 b, 3 c where false + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +) order by 1, 2; +---- +11 12 1 +11 13 2 + +# return sql params back to default values statement ok set datafusion.optimizer.prefer_hash_join = true; + +statement ok +set datafusion.execution.batch_size = 8192; +