@@ -430,7 +430,7 @@ pub(super) fn get_corrected_filter_mask(
430430 corrected_mask. append_n ( expected_size - corrected_mask. len ( ) , false ) ;
431431 Some ( corrected_mask. finish ( ) )
432432 }
433- JoinType :: LeftMark => {
433+ JoinType :: LeftMark | JoinType :: RightMark => {
434434 for i in 0 ..row_indices_length {
435435 let last_index =
436436 last_index_for_row ( i, row_indices, batch_ids, row_indices_length) ;
@@ -582,6 +582,7 @@ impl Stream for SortMergeJoinStream {
582582 | JoinType :: LeftMark
583583 | JoinType :: Right
584584 | JoinType :: RightSemi
585+ | JoinType :: RightMark
585586 | JoinType :: LeftAnti
586587 | JoinType :: RightAnti
587588 | JoinType :: Full
@@ -691,6 +692,7 @@ impl Stream for SortMergeJoinStream {
691692 | JoinType :: LeftAnti
692693 | JoinType :: RightAnti
693694 | JoinType :: LeftMark
695+ | JoinType :: RightMark
694696 | JoinType :: Full
695697 )
696698 {
@@ -718,6 +720,7 @@ impl Stream for SortMergeJoinStream {
718720 | JoinType :: RightAnti
719721 | JoinType :: Full
720722 | JoinType :: LeftMark
723+ | JoinType :: RightMark
721724 )
722725 {
723726 let record_batch = self . filter_joined_batch ( ) ?;
@@ -1042,16 +1045,23 @@ impl SortMergeJoinStream {
10421045 | JoinType :: LeftAnti
10431046 | JoinType :: RightAnti
10441047 | JoinType :: LeftMark
1048+ | JoinType :: RightMark
10451049 ) {
10461050 join_streamed = !self . streamed_joined ;
10471051 }
10481052 }
10491053 Ordering :: Equal => {
10501054 if matches ! (
10511055 self . join_type,
1052- JoinType :: LeftSemi | JoinType :: LeftMark | JoinType :: RightSemi
1056+ JoinType :: LeftSemi
1057+ | JoinType :: LeftMark
1058+ | JoinType :: RightSemi
1059+ | JoinType :: RightMark
10531060 ) {
1054- mark_row_as_match = matches ! ( self . join_type, JoinType :: LeftMark ) ;
1061+ mark_row_as_match = matches ! (
1062+ self . join_type,
1063+ JoinType :: LeftMark | JoinType :: RightMark
1064+ ) ;
10551065 // if the join filter is specified then it is needed to output the streamed index
10561066 // only if it has not been emitted before
10571067 // the `join_filter_matched_idxs` keeps track on if streamed index has a successful
@@ -1266,31 +1276,32 @@ impl SortMergeJoinStream {
12661276
12671277 // The row indices of joined buffered batch
12681278 let right_indices: UInt64Array = chunk. buffered_indices . finish ( ) ;
1269- let mut right_columns = if matches ! ( self . join_type, JoinType :: LeftMark ) {
1270- vec ! [ Arc :: new( is_not_null( & right_indices) ?) as ArrayRef ]
1271- } else if matches ! (
1272- self . join_type,
1273- JoinType :: LeftSemi
1274- | JoinType :: LeftAnti
1275- | JoinType :: RightAnti
1276- | JoinType :: RightSemi
1277- ) {
1278- vec ! [ ]
1279- } else if let Some ( buffered_idx) = chunk. buffered_batch_idx {
1280- fetch_right_columns_by_idxs (
1281- & self . buffered_data ,
1282- buffered_idx,
1283- & right_indices,
1284- ) ?
1285- } else {
1286- // If buffered batch none, meaning it is null joined batch.
1287- // We need to create null arrays for buffered columns to join with streamed rows.
1288- create_unmatched_columns (
1279+ let mut right_columns =
1280+ if matches ! ( self . join_type, JoinType :: LeftMark | JoinType :: RightMark ) {
1281+ vec ! [ Arc :: new( is_not_null( & right_indices) ?) as ArrayRef ]
1282+ } else if matches ! (
12891283 self . join_type,
1290- & self . buffered_schema ,
1291- right_indices. len ( ) ,
1292- )
1293- } ;
1284+ JoinType :: LeftSemi
1285+ | JoinType :: LeftAnti
1286+ | JoinType :: RightAnti
1287+ | JoinType :: RightSemi
1288+ ) {
1289+ vec ! [ ]
1290+ } else if let Some ( buffered_idx) = chunk. buffered_batch_idx {
1291+ fetch_right_columns_by_idxs (
1292+ & self . buffered_data ,
1293+ buffered_idx,
1294+ & right_indices,
1295+ ) ?
1296+ } else {
1297+ // If buffered batch none, meaning it is null joined batch.
1298+ // We need to create null arrays for buffered columns to join with streamed rows.
1299+ create_unmatched_columns (
1300+ self . join_type ,
1301+ & self . buffered_schema ,
1302+ right_indices. len ( ) ,
1303+ )
1304+ } ;
12941305
12951306 // Prepare the columns we apply join filter on later.
12961307 // Only for joined rows between streamed and buffered.
@@ -1309,7 +1320,7 @@ impl SortMergeJoinStream {
13091320 get_filter_column ( & self . filter , & left_columns, & right_cols)
13101321 } else if matches ! (
13111322 self . join_type,
1312- JoinType :: RightAnti | JoinType :: RightSemi
1323+ JoinType :: RightAnti | JoinType :: RightSemi | JoinType :: RightMark
13131324 ) {
13141325 let right_cols = fetch_right_columns_by_idxs (
13151326 & self . buffered_data ,
@@ -1375,6 +1386,7 @@ impl SortMergeJoinStream {
13751386 | JoinType :: LeftAnti
13761387 | JoinType :: RightAnti
13771388 | JoinType :: LeftMark
1389+ | JoinType :: RightMark
13781390 | JoinType :: Full
13791391 ) {
13801392 self . staging_output_record_batches
@@ -1475,6 +1487,7 @@ impl SortMergeJoinStream {
14751487 | JoinType :: LeftAnti
14761488 | JoinType :: RightAnti
14771489 | JoinType :: LeftMark
1490+ | JoinType :: RightMark
14781491 | JoinType :: Full
14791492 ) )
14801493 {
@@ -1537,7 +1550,7 @@ impl SortMergeJoinStream {
15371550
15381551 if matches ! (
15391552 self . join_type,
1540- JoinType :: Left | JoinType :: LeftMark | JoinType :: Right
1553+ JoinType :: Left | JoinType :: LeftMark | JoinType :: Right | JoinType :: RightMark
15411554 ) {
15421555 let null_mask = compute:: not ( corrected_mask) ?;
15431556 let null_joined_batch = filter_record_batch ( & record_batch, & null_mask) ?;
@@ -1658,7 +1671,7 @@ fn create_unmatched_columns(
16581671 schema : & SchemaRef ,
16591672 size : usize ,
16601673) -> Vec < ArrayRef > {
1661- if matches ! ( join_type, JoinType :: LeftMark ) {
1674+ if matches ! ( join_type, JoinType :: LeftMark | JoinType :: RightMark ) {
16621675 vec ! [ Arc :: new( BooleanArray :: from( vec![ false ; size] ) ) as ArrayRef ]
16631676 } else {
16641677 schema
0 commit comments