anti joins now respect join filters (#2843)
andygrove authored Jul 6, 2022
1 parent 94646ac commit ae6dab0
Showing 1 changed file with 69 additions and 12 deletions.
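In essence: build_batch previously returned early for SEMI/ANTI joins before the join filter was applied, so a join filter on those join types was silently ignored; the early return now happens after apply_join_filter and returns the filtered left indices, and apply_join_filter treats SEMI/ANTI the same way as INNER/LEFT. For context, below is a minimal standalone sketch of the anti-join-with-filter semantics in plain Rust (illustrative only, not DataFusion code; the helper name anti_join_with_filter is invented for this sketch): a left row is emitted only if no right row satisfies both the equi-join key and the join filter.

// Standalone illustration (not DataFusion code): a left row survives an anti join
// with a filter only if no right row matches BOTH the join key AND the filter.
fn anti_join_with_filter<T: Copy>(
    left: &[(i32, T)],
    right: &[(i32, T)],
    filter: impl Fn(T, T) -> bool,
) -> Vec<(i32, T)> {
    left.iter()
        .copied()
        .filter(|&(lk, lv)| !right.iter().any(|&(rk, rv)| lk == rk && filter(lv, rv)))
        .collect()
}

fn main() {
    // Mirrors the new test below: right is a copy of left, the join is on the key,
    // and the filter requires the payloads to differ, so every key match is filtered
    // out and both left rows appear in the anti-join result.
    let left = [(1, 2), (3, 4)];
    let right = left;
    assert_eq!(
        anti_join_with_filter(&left, &right, |a, b| a != b),
        vec![(1, 2), (3, 4)]
    );
}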
81 changes: 69 additions & 12 deletions datafusion/core/src/physical_plan/hash_join.rs
@@ -643,13 +643,6 @@ fn build_batch(
)
.unwrap();

if matches!(join_type, JoinType::Semi | JoinType::Anti) {
return Ok((
RecordBatch::new_empty(Arc::new(schema.clone())),
left_indices,
));
}

let (left_filtered_indices, right_filtered_indices) = if let Some(filter) = filter {
apply_join_filter(
&left_data.1,
@@ -664,6 +657,13 @@ fn build_batch(
(left_indices, right_indices)
};

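// For SEMI/ANTI joins only the left-side indices are needed; returning here, after the
// join filter has been applied, means the returned indices respect the filter.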
if matches!(join_type, JoinType::Semi | JoinType::Anti) {
return Ok((
RecordBatch::new_empty(Arc::new(schema.clone())),
left_filtered_indices,
));
}

build_batch_from_indices(
schema,
&left_data.1,
@@ -857,7 +857,7 @@ fn apply_join_filter(
)?;

match join_type {
JoinType::Inner | JoinType::Left => {
JoinType::Inner | JoinType::Left | JoinType::Anti | JoinType::Semi => {
// For both INNER and LEFT joins, the input arrays contain only indices of matched rows.
// Because of this, it is correct to simply apply the filter to the intermediate batch and
// return the left/right indices of rows satisfying the filter predicate.
@@ -931,10 +931,6 @@ fn apply_join_filter(

Ok((left_rebuilt.finish(), right_rebuilt.finish()))
}
_ => Err(DataFusionError::NotImplemented(format!(
"Unexpected filter in {} join",
join_type
))),
}
}

@@ -2164,6 +2160,67 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn join_anti_with_filter() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let left = build_table(
("col1", &vec![1, 3]),
("col2", &vec![2, 4]),
("col3", &vec![3, 5]),
);
let right = left.clone();

// join on col1
let on = vec![(
Column::new_with_schema("col1", &left.schema())?,
Column::new_with_schema("col1", &right.schema())?,
)];

// build filter b.col2 <> a.col2
let column_indices = vec![
ColumnIndex {
index: 1,
side: JoinSide::Left,
},
ColumnIndex {
index: 1,
side: JoinSide::Right,
},
];
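// The two ColumnIndex entries map col2 from the left and right inputs into the
// intermediate filter schema below; the filter expression refers to these columns
// by index, so the duplicate field name "x" is not a problem.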
let intermediate_schema = Schema::new(vec![
Field::new("x", DataType::Int32, true),
Field::new("x", DataType::Int32, true),
]);
let filter_expression = Arc::new(BinaryExpr::new(
Arc::new(Column::new("x", 0)),
Operator::NotEq,
Arc::new(Column::new("x", 1)),
)) as Arc<dyn PhysicalExpr>;

let filter =
JoinFilter::new(filter_expression, column_indices, intermediate_schema);

let join = join_with_filter(left, right, on, filter, &JoinType::Anti, false)?;

let columns = columns(&join.schema());
assert_eq!(columns, vec!["col1", "col2", "col3"]);

let stream = join.execute(0, task_ctx)?;
let batches = common::collect(stream).await?;

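// Each left row's only col1 match is its identical counterpart in `right`, which fails
// the b.col2 <> a.col2 filter; with the filter respected, no match survives and both
// left rows appear in the anti-join output.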
let expected = vec![
"+------+------+------+",
"| col1 | col2 | col3 |",
"+------+------+------+",
"| 1    | 2    | 3    |",
"| 3    | 4    | 5    |",
"+------+------+------+",
];
assert_batches_sorted_eq!(expected, &batches);
Ok(())
}

#[tokio::test]
async fn join_right_one() -> Result<()> {
let session_ctx = SessionContext::new();