Skip to content

Commit bd95a6b

Browse files
feat: Support swap for RightMark Join (#17651)
* feat: Support swap for `RightMark` Join * add flag * fmt * add comment + fix test * Update datafusion/physical-plan/src/joins/sort_merge_join/tests.rs Co-authored-by: Oleks V <comphead@users.noreply.github.com> --------- Co-authored-by: Oleks V <comphead@users.noreply.github.com>
1 parent 6a61304 commit bd95a6b

File tree

14 files changed

+249
-47
lines changed

14 files changed

+249
-47
lines changed

datafusion/common/src/join_type.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ impl JoinType {
109109
| JoinType::RightSemi
110110
| JoinType::LeftAnti
111111
| JoinType::RightAnti
112+
| JoinType::LeftMark
113+
| JoinType::RightMark
112114
)
113115
}
114116
}

datafusion/core/tests/fuzz_cases/join_fuzz.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ async fn test_right_mark_join_1k() {
314314
JoinType::RightMark,
315315
None,
316316
)
317-
.run_test(&[NljHj], false)
317+
.run_test(&[HjSmj, NljHj], false)
318318
.await
319319
}
320320

@@ -326,7 +326,7 @@ async fn test_right_mark_join_1k_filtered() {
326326
JoinType::RightMark,
327327
Some(Box::new(col_lt_col_filter)),
328328
)
329-
.run_test(&[NljHj], false)
329+
.run_test(&[HjSmj, NljHj], false)
330330
.await
331331
}
332332

@@ -555,7 +555,7 @@ async fn test_right_mark_join_1k_binary() {
555555
JoinType::RightMark,
556556
None,
557557
)
558-
.run_test(&[NljHj], false)
558+
.run_test(&[HjSmj, NljHj], false)
559559
.await
560560
}
561561

@@ -567,7 +567,7 @@ async fn test_right_mark_join_1k_binary_filtered() {
567567
JoinType::RightMark,
568568
Some(Box::new(col_lt_col_filter)),
569569
)
570-
.run_test(&[NljHj], false)
570+
.run_test(&[HjSmj, NljHj], false)
571571
.await
572572
}
573573

datafusion/core/tests/physical_optimizer/join_selection.rs

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,61 @@ async fn test_join_with_swap_semi() {
369369
}
370370
}
371371

372+
#[tokio::test]
373+
async fn test_join_with_swap_mark() {
374+
let join_types = [JoinType::LeftMark, JoinType::RightMark];
375+
for join_type in join_types {
376+
let (big, small) = create_big_and_small();
377+
378+
let join = HashJoinExec::try_new(
379+
Arc::clone(&big),
380+
Arc::clone(&small),
381+
vec![(
382+
Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()),
383+
Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()),
384+
)],
385+
None,
386+
&join_type,
387+
None,
388+
PartitionMode::Partitioned,
389+
NullEquality::NullEqualsNothing,
390+
)
391+
.unwrap();
392+
393+
let original_schema = join.schema();
394+
395+
let optimized_join = JoinSelection::new()
396+
.optimize(Arc::new(join), &ConfigOptions::new())
397+
.unwrap();
398+
399+
let swapped_join = optimized_join
400+
.as_any()
401+
.downcast_ref::<HashJoinExec>()
402+
.expect(
403+
"A proj is not required to swap columns back to their original order",
404+
);
405+
406+
assert_eq!(swapped_join.schema().fields().len(), 2);
407+
assert_eq!(
408+
swapped_join
409+
.left()
410+
.partition_statistics(None)
411+
.unwrap()
412+
.total_byte_size,
413+
Precision::Inexact(8192)
414+
);
415+
assert_eq!(
416+
swapped_join
417+
.right()
418+
.partition_statistics(None)
419+
.unwrap()
420+
.total_byte_size,
421+
Precision::Inexact(2097152)
422+
);
423+
assert_eq!(original_schema, swapped_join.schema());
424+
}
425+
}
426+
372427
/// Compare the input plan with the plan after running the probe order optimizer.
373428
macro_rules! assert_optimized {
374429
($EXPECTED_LINES: expr, $PLAN: expr) => {
@@ -576,7 +631,8 @@ async fn test_nl_join_with_swap(join_type: JoinType) {
576631
case::left_semi(JoinType::LeftSemi),
577632
case::left_anti(JoinType::LeftAnti),
578633
case::right_semi(JoinType::RightSemi),
579-
case::right_anti(JoinType::RightAnti)
634+
case::right_anti(JoinType::RightAnti),
635+
case::right_mark(JoinType::RightMark)
580636
)]
581637
#[tokio::test]
582638
async fn test_nl_join_with_swap_no_proj(join_type: JoinType) {

datafusion/expr/src/logical_plan/builder.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1696,7 +1696,10 @@ pub fn build_join_schema(
16961696
);
16971697

16981698
let (schema1, schema2) = match join_type {
1699-
JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => (left, right),
1699+
JoinType::Right
1700+
| JoinType::RightSemi
1701+
| JoinType::RightAnti
1702+
| JoinType::RightMark => (left, right),
17001703
_ => (right, left),
17011704
};
17021705

datafusion/optimizer/src/decorrelate_predicate_subquery.rs

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ use datafusion_common::{internal_err, plan_err, Column, Result};
3131
use datafusion_expr::expr::{Exists, InSubquery};
3232
use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
3333
use datafusion_expr::logical_plan::{JoinType, Subquery};
34-
use datafusion_expr::utils::{conjunction, split_conjunction_owned};
34+
use datafusion_expr::utils::{conjunction, expr_to_columns, split_conjunction_owned};
3535
use datafusion_expr::{
3636
exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter,
3737
LogicalPlan, LogicalPlanBuilder, Operator,
@@ -342,7 +342,7 @@ fn build_join(
342342
replace_qualified_name(filter, &all_correlated_cols, &alias).map(Some)
343343
})?;
344344

345-
let join_filter = match (join_filter_opt, in_predicate_opt) {
345+
let join_filter = match (join_filter_opt, in_predicate_opt.clone()) {
346346
(
347347
Some(join_filter),
348348
Some(Expr::BinaryExpr(BinaryExpr {
@@ -371,6 +371,51 @@ fn build_join(
371371
(None, None) => lit(true),
372372
_ => return Ok(None),
373373
};
374+
375+
if matches!(join_type, JoinType::LeftMark | JoinType::RightMark) {
376+
let right_schema = sub_query_alias.schema();
377+
378+
// Gather all columns needed for the join filter + predicates
379+
let mut needed = std::collections::HashSet::new();
380+
expr_to_columns(&join_filter, &mut needed)?;
381+
if let Some(ref in_pred) = in_predicate_opt {
382+
expr_to_columns(in_pred, &mut needed)?;
383+
}
384+
385+
// Keep only columns that actually belong to the RIGHT child, and sort by their
386+
// position in the right schema for deterministic order.
387+
let mut right_cols_idx_and_col: Vec<(usize, Column)> = needed
388+
.into_iter()
389+
.filter_map(|c| right_schema.index_of_column(&c).ok().map(|idx| (idx, c)))
390+
.collect();
391+
392+
right_cols_idx_and_col.sort_by_key(|(idx, _)| *idx);
393+
394+
let right_proj_exprs: Vec<Expr> = right_cols_idx_and_col
395+
.into_iter()
396+
.map(|(_, c)| Expr::Column(c))
397+
.collect();
398+
399+
let right_projected = if !right_proj_exprs.is_empty() {
400+
LogicalPlanBuilder::from(sub_query_alias.clone())
401+
.project(right_proj_exprs)?
402+
.build()?
403+
} else {
404+
// Degenerate case: no right columns referenced by the predicate(s)
405+
sub_query_alias.clone()
406+
};
407+
let new_plan = LogicalPlanBuilder::from(left.clone())
408+
.join_on(right_projected, join_type, Some(join_filter))?
409+
.build()?;
410+
411+
debug!(
412+
"predicate subquery optimized:\n{}",
413+
new_plan.display_indent()
414+
);
415+
416+
return Ok(Some(new_plan));
417+
}
418+
374419
// join our sub query into the main plan
375420
let new_plan = LogicalPlanBuilder::from(left.clone())
376421
.join_on(sub_query_alias, join_type, Some(join_filter))?

datafusion/physical-optimizer/src/join_selection.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,11 @@ pub(crate) fn swap_join_according_to_unboundedness(
514514
match (*partition_mode, *join_type) {
515515
(
516516
_,
517-
JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | JoinType::Full,
517+
JoinType::Right
518+
| JoinType::RightSemi
519+
| JoinType::RightAnti
520+
| JoinType::RightMark
521+
| JoinType::Full,
518522
) => internal_err!("{join_type} join cannot be swapped for unbounded input."),
519523
(PartitionMode::Partitioned, _) => {
520524
hash_join.swap_inputs(PartitionMode::Partitioned)

datafusion/physical-plan/src/joins/hash_join/exec.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,8 @@ impl HashJoinExec {
690690
| JoinType::RightSemi
691691
| JoinType::LeftAnti
692692
| JoinType::RightAnti
693+
| JoinType::LeftMark
694+
| JoinType::RightMark
693695
) || self.projection.is_some()
694696
{
695697
Ok(Arc::new(new_join))

datafusion/physical-plan/src/joins/nested_loop_join.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,8 @@ impl NestedLoopJoinExec {
379379
| JoinType::RightSemi
380380
| JoinType::LeftAnti
381381
| JoinType::RightAnti
382+
| JoinType::LeftMark
383+
| JoinType::RightMark
382384
) || self.projection.is_some()
383385
{
384386
Arc::new(new_join)

datafusion/physical-plan/src/joins/sort_merge_join/stream.rs

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ pub(super) fn get_corrected_filter_mask(
430430
corrected_mask.append_n(expected_size - corrected_mask.len(), false);
431431
Some(corrected_mask.finish())
432432
}
433-
JoinType::LeftMark => {
433+
JoinType::LeftMark | JoinType::RightMark => {
434434
for i in 0..row_indices_length {
435435
let last_index =
436436
last_index_for_row(i, row_indices, batch_ids, row_indices_length);
@@ -582,6 +582,7 @@ impl Stream for SortMergeJoinStream {
582582
| JoinType::LeftMark
583583
| JoinType::Right
584584
| JoinType::RightSemi
585+
| JoinType::RightMark
585586
| JoinType::LeftAnti
586587
| JoinType::RightAnti
587588
| JoinType::Full
@@ -691,6 +692,7 @@ impl Stream for SortMergeJoinStream {
691692
| JoinType::LeftAnti
692693
| JoinType::RightAnti
693694
| JoinType::LeftMark
695+
| JoinType::RightMark
694696
| JoinType::Full
695697
)
696698
{
@@ -718,6 +720,7 @@ impl Stream for SortMergeJoinStream {
718720
| JoinType::RightAnti
719721
| JoinType::Full
720722
| JoinType::LeftMark
723+
| JoinType::RightMark
721724
)
722725
{
723726
let record_batch = self.filter_joined_batch()?;
@@ -1042,16 +1045,23 @@ impl SortMergeJoinStream {
10421045
| JoinType::LeftAnti
10431046
| JoinType::RightAnti
10441047
| JoinType::LeftMark
1048+
| JoinType::RightMark
10451049
) {
10461050
join_streamed = !self.streamed_joined;
10471051
}
10481052
}
10491053
Ordering::Equal => {
10501054
if matches!(
10511055
self.join_type,
1052-
JoinType::LeftSemi | JoinType::LeftMark | JoinType::RightSemi
1056+
JoinType::LeftSemi
1057+
| JoinType::LeftMark
1058+
| JoinType::RightSemi
1059+
| JoinType::RightMark
10531060
) {
1054-
mark_row_as_match = matches!(self.join_type, JoinType::LeftMark);
1061+
mark_row_as_match = matches!(
1062+
self.join_type,
1063+
JoinType::LeftMark | JoinType::RightMark
1064+
);
10551065
// if the join filter is specified then its needed to output the streamed index
10561066
// only if it has not been emitted before
10571067
// the `join_filter_matched_idxs` keeps track on if streamed index has a successful
@@ -1266,31 +1276,32 @@ impl SortMergeJoinStream {
12661276

12671277
// The row indices of joined buffered batch
12681278
let right_indices: UInt64Array = chunk.buffered_indices.finish();
1269-
let mut right_columns = if matches!(self.join_type, JoinType::LeftMark) {
1270-
vec![Arc::new(is_not_null(&right_indices)?) as ArrayRef]
1271-
} else if matches!(
1272-
self.join_type,
1273-
JoinType::LeftSemi
1274-
| JoinType::LeftAnti
1275-
| JoinType::RightAnti
1276-
| JoinType::RightSemi
1277-
) {
1278-
vec![]
1279-
} else if let Some(buffered_idx) = chunk.buffered_batch_idx {
1280-
fetch_right_columns_by_idxs(
1281-
&self.buffered_data,
1282-
buffered_idx,
1283-
&right_indices,
1284-
)?
1285-
} else {
1286-
// If buffered batch none, meaning it is null joined batch.
1287-
// We need to create null arrays for buffered columns to join with streamed rows.
1288-
create_unmatched_columns(
1279+
let mut right_columns =
1280+
if matches!(self.join_type, JoinType::LeftMark | JoinType::RightMark) {
1281+
vec![Arc::new(is_not_null(&right_indices)?) as ArrayRef]
1282+
} else if matches!(
12891283
self.join_type,
1290-
&self.buffered_schema,
1291-
right_indices.len(),
1292-
)
1293-
};
1284+
JoinType::LeftSemi
1285+
| JoinType::LeftAnti
1286+
| JoinType::RightAnti
1287+
| JoinType::RightSemi
1288+
) {
1289+
vec![]
1290+
} else if let Some(buffered_idx) = chunk.buffered_batch_idx {
1291+
fetch_right_columns_by_idxs(
1292+
&self.buffered_data,
1293+
buffered_idx,
1294+
&right_indices,
1295+
)?
1296+
} else {
1297+
// If buffered batch none, meaning it is null joined batch.
1298+
// We need to create null arrays for buffered columns to join with streamed rows.
1299+
create_unmatched_columns(
1300+
self.join_type,
1301+
&self.buffered_schema,
1302+
right_indices.len(),
1303+
)
1304+
};
12941305

12951306
// Prepare the columns we apply join filter on later.
12961307
// Only for joined rows between streamed and buffered.
@@ -1309,7 +1320,7 @@ impl SortMergeJoinStream {
13091320
get_filter_column(&self.filter, &left_columns, &right_cols)
13101321
} else if matches!(
13111322
self.join_type,
1312-
JoinType::RightAnti | JoinType::RightSemi
1323+
JoinType::RightAnti | JoinType::RightSemi | JoinType::RightMark
13131324
) {
13141325
let right_cols = fetch_right_columns_by_idxs(
13151326
&self.buffered_data,
@@ -1375,6 +1386,7 @@ impl SortMergeJoinStream {
13751386
| JoinType::LeftAnti
13761387
| JoinType::RightAnti
13771388
| JoinType::LeftMark
1389+
| JoinType::RightMark
13781390
| JoinType::Full
13791391
) {
13801392
self.staging_output_record_batches
@@ -1475,6 +1487,7 @@ impl SortMergeJoinStream {
14751487
| JoinType::LeftAnti
14761488
| JoinType::RightAnti
14771489
| JoinType::LeftMark
1490+
| JoinType::RightMark
14781491
| JoinType::Full
14791492
))
14801493
{
@@ -1537,7 +1550,7 @@ impl SortMergeJoinStream {
15371550

15381551
if matches!(
15391552
self.join_type,
1540-
JoinType::Left | JoinType::LeftMark | JoinType::Right
1553+
JoinType::Left | JoinType::LeftMark | JoinType::Right | JoinType::RightMark
15411554
) {
15421555
let null_mask = compute::not(corrected_mask)?;
15431556
let null_joined_batch = filter_record_batch(&record_batch, &null_mask)?;
@@ -1658,7 +1671,7 @@ fn create_unmatched_columns(
16581671
schema: &SchemaRef,
16591672
size: usize,
16601673
) -> Vec<ArrayRef> {
1661-
if matches!(join_type, JoinType::LeftMark) {
1674+
if matches!(join_type, JoinType::LeftMark | JoinType::RightMark) {
16621675
vec![Arc::new(BooleanArray::from(vec![false; size])) as ArrayRef]
16631676
} else {
16641677
schema

0 commit comments

Comments
 (0)