Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Sort Merge Join LeftSemi issues when JoinFilter is set #10304

Merged
merged 20 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 199 additions & 29 deletions datafusion/physical-plan/src/joins/sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,13 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use crate::expressions::PhysicalSortExpr;
use crate::joins::utils::{
build_join_schema, check_join_is_valid, estimate_join_statistics,
partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef,
};
use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use crate::{
execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution,
ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties,
RecordBatchStream, SendableRecordBatchStream, Statistics,
};

use arrow::array::*;
use arrow::compute::{self, concat_batches, take, SortOptions};
use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use arrow::error::ArrowError;
use futures::{Stream, StreamExt};
use hashbrown::HashSet;

use datafusion_common::{
internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result,
};
Expand All @@ -54,7 +45,17 @@ use datafusion_execution::TaskContext;
use datafusion_physical_expr::equivalence::join_equivalence_properties;
use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement};

use futures::{Stream, StreamExt};
use crate::expressions::PhysicalSortExpr;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this change is by formatter

use crate::joins::utils::{
build_join_schema, check_join_is_valid, estimate_join_statistics,
partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef,
};
use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use crate::{
execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution,
ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties,
RecordBatchStream, SendableRecordBatchStream, Statistics,
};

/// join execution plan executes partitions in parallel and combines them into a set of
/// partitions.
Expand Down Expand Up @@ -491,6 +492,10 @@ struct StreamedBatch {
pub output_indices: Vec<StreamedJoinedChunk>,
/// Index of currently scanned batch from buffered data
pub buffered_batch_idx: Option<usize>,
/// Indices that found a match for the given join filter
/// Used for semi joins to keep track the streaming index which got a join filter match
/// and already emitted to the output.
pub join_filter_matched_idxs: HashSet<u64>,
}

impl StreamedBatch {
Expand All @@ -502,6 +507,7 @@ impl StreamedBatch {
join_arrays,
output_indices: vec![],
buffered_batch_idx: None,
join_filter_matched_idxs: HashSet::new(),
}
}

Expand All @@ -512,6 +518,7 @@ impl StreamedBatch {
join_arrays: vec![],
output_indices: vec![],
buffered_batch_idx: None,
join_filter_matched_idxs: HashSet::new(),
}
}

Expand Down Expand Up @@ -989,8 +996,21 @@ impl SMJStream {
}
}
Ordering::Equal => {
if matches!(self.join_type, JoinType::LeftSemi) {
if matches!(self.join_type, JoinType::LeftSemi) && self.filter.is_some() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can combine this and below which both are for JoinType::LeftSemi under single if block.

join_streamed = !self
.streamed_batch
.join_filter_matched_idxs
.contains(&(self.streamed_batch.idx as u64))
&& !self.streamed_joined;
Copy link
Member

@viirya viirya May 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If self.streamed_joined is false, join_filter_matched_idxs always doesn't contain self.streamed_batch.idx, so the two conditions are duplicated as they are both true.

If self.streamed_joined is true, this and check is failed, the another condition doesn't matter.

I'm not sure what this is added to check.

I don't see it addresses the issue in https://github.com/apache/datafusion/pull/10304/files#r1601943239.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I got what you want to do here. It is better to add a comment for later readers.

// if the join filter specified there can be references to buffered columns
// so buffered columns are needed to access them
join_buffered = join_streamed;
}
if matches!(self.join_type, JoinType::LeftSemi) && self.filter.is_none() {
join_streamed = !self.streamed_joined;
// if the join filter specified there can be references to buffered columns
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This block knows self.filter.is_none(), why you still do join_buffered = self.filter.is_some();?

// so buffered columns are needed to access them
join_buffered = self.filter.is_some();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LeftSemi doesn't join buffered side, why we want to do this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see. As the filter uses buffered columns, we need to access to it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But the comment doesn't look correct as we don't actually join buffered columns.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, but if there are more rows at buffered side are matched on keys, won't it add additional joined pairs with nulls and buffered rows?

Copy link
Contributor Author

@comphead comphead May 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hope I got the concern right and wrapped it into SQL query

query II
select * from (
with
t1 as (
    select 11 a, 12 b union all
    select 11 a, 13 b),
t2 as (
    select 11 a, 12 b union all
    select 11 a, 12 b union all
    select 11 a, 14 b
    )
select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b)
) order by 1, 2;
----
11 12
11 13

it passes, I can add it to slt file as well

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not able to run such test just because of other SMJ issue not related to this PR:
If the join filter is set and for the same streaming index there are matching rows more or equal to a batch size then the query just stuck. Likely the problem is in polling state and it can be easily reproduced on main branch.

  #[tokio::test]
    async fn test_11() -> Result<()> {
        let ctx: SessionContext = SessionContext::new();

        let sql = "set datafusion.optimizer.prefer_hash_join = false;";
        let _ = ctx.sql(sql).await?.collect().await?;

        let sql = "set datafusion.execution.batch_size = 1";
        let _ = ctx.sql(sql).await?.collect().await?;

        let sql = "
        select * from (
        with
        t1 as (
            select 12 a, 12 b
            ),
        t2 as (
            select 12 a, 12 b
            )
            select t1.* from t1 join t2 on t1.a = t2.b where t1.a > t2.b
        ) order by 1, 2;
        ";

        let actual = ctx.sql(sql).await?.collect().await?;


        Ok(())
    }

I'll file a separate issue for this one, but perhaps we can go with this PR because potential problem you talking about cannot ever be hit because of the issue above

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for creating the ticket.

I think that current fix to the LeftSemi isn't correct due to the additional joined pairs of nulls and buffered rows. I don't think we should move forward with it because of that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm putting the PR to draft until we can check if PR requires modifications to avoid addition join pairs of nulls and buffered rows. It depends on #10491

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have checked the case, it doesn't fail, however it produces more rows than expected, looking into this

Copy link
Member

@viirya viirya May 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, what I meant is actually it will produce some rows that are not correct results (due to the additional joined pairs of nulls and buffered rows).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// if the join filter specified there can be references to buffered columns
// so buffered columns are needed to access them
join_buffered = self.filter.is_some();

}
if matches!(
self.join_type,
Expand Down Expand Up @@ -1134,17 +1154,15 @@ impl SMJStream {
.collect::<Result<Vec<_>, ArrowError>>()?;

let buffered_indices: UInt64Array = chunk.buffered_indices.finish();

comphead marked this conversation as resolved.
Show resolved Hide resolved
let mut buffered_columns =
if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
vec![]
} else if let Some(buffered_idx) = chunk.buffered_batch_idx {
self.buffered_data.batches[buffered_idx]
.batch
.columns()
.iter()
.map(|column| take(column, &buffered_indices, None))
.collect::<Result<Vec<_>, ArrowError>>()?
get_buffered_columns(
&self.buffered_data,
buffered_idx,
&buffered_indices,
)?
} else {
self.buffered_schema
.fields()
Expand All @@ -1161,6 +1179,15 @@ impl SMJStream {
let filter_columns = if chunk.buffered_batch_idx.is_some() {
if matches!(self.join_type, JoinType::Right) {
get_filter_column(&self.filter, &buffered_columns, &streamed_columns)
} else if matches!(self.join_type, JoinType::LeftSemi) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this should also check for JoinType::Left (and the clause above also check for JoinType::RightSemi 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have more to go, LeftAnti is first as it prevents TPCH to run and then double check RightSemi as well, good point

// unwrap is safe here as we check is_some on top of if statement
let buffered_columns = get_buffered_columns(
&self.buffered_data,
chunk.buffered_batch_idx.unwrap(),
&buffered_indices,
)?;

get_filter_column(&self.filter, &streamed_columns, &buffered_columns)
} else {
get_filter_column(&self.filter, &streamed_columns, &buffered_columns)
}
Expand Down Expand Up @@ -1195,7 +1222,17 @@ impl SMJStream {
.into_array(filter_batch.num_rows())?;

// The selection mask of the filter
let mask = datafusion_common::cast::as_boolean_array(&filter_result)?;
let mut mask =
datafusion_common::cast::as_boolean_array(&filter_result)?;

let maybe_filtered_join_mask: Option<(BooleanArray, Vec<u64>)> =
get_filtered_join_mask(self.join_type, streamed_indices, mask);
if let Some(ref filtered_join_mask) = maybe_filtered_join_mask {
mask = &filtered_join_mask.0;
self.streamed_batch
.join_filter_matched_idxs
.extend(&filtered_join_mask.1);
}

// Push the filtered batch to the output
let filtered_batch =
Expand Down Expand Up @@ -1365,6 +1402,69 @@ fn get_filter_column(
filter_columns
}

/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]`
#[inline(always)]
fn get_buffered_columns(
buffered_data: &BufferedData,
buffered_batch_idx: usize,
buffered_indices: &UInt64Array,
) -> Result<Vec<ArrayRef>, ArrowError> {
buffered_data.batches[buffered_batch_idx]
.batch
.columns()
.iter()
.map(|column| take(column, &buffered_indices, None))
.collect::<Result<Vec<_>, ArrowError>>()
}

// Calculate join filter bit mask considering join type specifics
comphead marked this conversation as resolved.
Show resolved Hide resolved
// `streamed_indices` - array of streamed datasource JOINED row indices
// `mask` - array booleans representing computed join filter expression eval result:
// true = the row index matches the join filter
// false = the row index doesn't match the join filter
// `streamed_indices` have the same length as `mask`
fn get_filtered_join_mask(
join_type: JoinType,
streamed_indices: UInt64Array,
mask: &BooleanArray,
) -> Option<(BooleanArray, Vec<u64>)> {
// for LeftSemi Join the filter mask should be calculated in its own way:
// if we find at least one matching row for specific streaming index
// we don't need to check any others for the same index
if matches!(join_type, JoinType::LeftSemi) {
comphead marked this conversation as resolved.
Show resolved Hide resolved
// have we seen a filter match for a streaming index before
let mut seen_as_true: bool = false;
let streamed_indices_length = streamed_indices.len();
let mut corrected_mask: BooleanBuilder =
BooleanBuilder::with_capacity(streamed_indices_length);

let mut filter_matched_indices: Vec<u64> = vec![];

#[allow(clippy::needless_range_loop)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder why ignore clippy here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clippy doesn't like for loops anymore ....

for i in 0..streamed_indices_length {
// LeftSemi respects only first true values for specific streaming index,
// others true values for the same index must be false
if mask.value(i) && !seen_as_true {
seen_as_true = true;
corrected_mask.append_value(true);
filter_matched_indices.push(streamed_indices.value(i));
viirya marked this conversation as resolved.
Show resolved Hide resolved
} else {
corrected_mask.append_value(false);
}

// if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag
if i < streamed_indices_length - 1
&& streamed_indices.value(i) != streamed_indices.value(i + 1)
{
seen_as_true = false;
}
}
Some((corrected_mask.finish(), filter_matched_indices))
} else {
None
}
}

/// Buffered data contains all buffered batches with one unique join key
#[derive(Debug, Default)]
struct BufferedData {
Expand Down Expand Up @@ -1604,24 +1704,28 @@ fn is_join_arrays_equal(
mod tests {
use std::sync::Arc;

use crate::expressions::Column;
use crate::joins::utils::JoinOn;
use crate::joins::SortMergeJoinExec;
use crate::memory::MemoryExec;
use crate::test::build_table_i32;
use crate::{common, ExecutionPlan};

use arrow::array::{Date32Array, Date64Array, Int32Array};
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use arrow_array::{BooleanArray, UInt64Array};

use datafusion_common::JoinType::LeftSemi;
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result,
};
use datafusion_execution::config::SessionConfig;
use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion_execution::TaskContext;

use crate::expressions::Column;
use crate::joins::sort_merge_join::get_filtered_join_mask;
use crate::joins::utils::JoinOn;
use crate::joins::SortMergeJoinExec;
use crate::memory::MemoryExec;
use crate::test::build_table_i32;
use crate::{common, ExecutionPlan};

fn build_table(
a: (&str, &Vec<i32>),
b: (&str, &Vec<i32>),
Expand Down Expand Up @@ -2641,6 +2745,72 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn left_semi_join_filtered_mask() -> Result<()> {
assert_eq!(
get_filtered_join_mask(
LeftSemi,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we should test a type other than LeftSemi as negative test coverage 🤔

UInt64Array::from(vec![0, 0, 1, 1]),
&BooleanArray::from(vec![true, true, false, false])
),
Some((BooleanArray::from(vec![true, false, false, false]), vec![0]))
);

assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![true, true])
),
Some((BooleanArray::from(vec![true, true]), vec![0, 1]))
);

assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![false, true])
),
Some((BooleanArray::from(vec![false, true]), vec![1]))
);

assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![true, false])
),
Some((BooleanArray::from(vec![true, false]), vec![0]))
);

assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
&BooleanArray::from(vec![false, true, true, true, true, true])
),
Some((
BooleanArray::from(vec![false, true, false, true, false, false]),
vec![0, 1]
))
);

assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
&BooleanArray::from(vec![false, false, false, false, false, true])
),
Some((
BooleanArray::from(vec![false, false, false, false, false, true]),
vec![1]
))
);

Ok(())
}

/// Returns the column names on the schema
fn columns(schema: &Schema) -> Vec<String> {
schema.fields().iter().map(|f| f.name().clone()).collect()
Expand Down
Loading