Skip to content

Commit 11a25cb

Browse files
committed
Fix column picks for LeftMark and RightMark
1 parent cfaeee5 commit 11a25cb

File tree

2 files changed

+148
-26
lines changed

2 files changed

+148
-26
lines changed

datafusion/physical-plan/src/joins/sort_merge_join/stream.rs

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1559,29 +1559,32 @@ impl SortMergeJoinStream {
15591559
null_joined_batch.num_rows(),
15601560
);
15611561

1562-
let columns = if !matches!(self.join_type, JoinType::Right) {
1563-
// For left joins, the first columns are the left columns.
1564-
// Critical: There is a bug here still because the match directions.
1565-
let mut left_columns = null_joined_batch
1566-
.columns()
1567-
.iter()
1568-
.take(left_columns_length)
1569-
.cloned()
1570-
.collect::<Vec<_>>();
1571-
1572-
left_columns.extend(right_columns);
1573-
left_columns
1574-
} else {
1575-
// For right joins, the first columns are the right columns.
1576-
let left_columns = null_joined_batch
1577-
.columns()
1578-
.iter()
1579-
.skip(right_columns_length)
1580-
.cloned()
1581-
.collect::<Vec<_>>();
1582-
1583-
right_columns.extend(left_columns);
1584-
right_columns
1562+
let columns = match self.join_type {
1563+
JoinType::Right => {
1564+
// The first columns are the right columns.
1565+
let left_columns = null_joined_batch
1566+
.columns()
1567+
.iter()
1568+
.skip(right_columns_length)
1569+
.cloned()
1570+
.collect::<Vec<_>>();
1571+
1572+
right_columns.extend(left_columns);
1573+
right_columns
1574+
}
1575+
JoinType::Left | JoinType::LeftMark | JoinType::RightMark => {
1576+
// The first columns are the left columns.
1577+
let mut left_columns = null_joined_batch
1578+
.columns()
1579+
.iter()
1580+
.take(left_columns_length)
1581+
.cloned()
1582+
.collect::<Vec<_>>();
1583+
1584+
left_columns.extend(right_columns);
1585+
left_columns
1586+
}
1587+
_ => exec_err!("Did not expect join type {}", self.join_type)?,
15851588
};
15861589

15871590
// Push the streamed/buffered batch joined nulls to the output

datafusion/physical-plan/src/joins/sort_merge_join/tests.rs

Lines changed: 122 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -644,10 +644,8 @@ async fn join_right_one() -> Result<()> {
644644
Ok(())
645645
}
646646

647-
648647
#[tokio::test]
649648
async fn join_right_different_columns_count_with_filter() -> Result<()> {
650-
651649
// select *
652650
// from t1
653651
// right join t2 on t1.b1 = t2.b1 and t1.a1 > t2.a2
@@ -707,7 +705,6 @@ async fn join_right_different_columns_count_with_filter() -> Result<()> {
707705

708706
#[tokio::test]
709707
async fn join_left_different_columns_count_with_filter() -> Result<()> {
710-
711708
// select *
712709
// from t2
713710
// left join t1 on t2.b1 = t1.b1 and t2.a2 > t1.a1
@@ -765,6 +762,128 @@ async fn join_left_different_columns_count_with_filter() -> Result<()> {
765762
Ok(())
766763
}
767764

765+
#[tokio::test]
766+
async fn join_left_mark_different_columns_count_with_filter() -> Result<()> {
767+
// select *
768+
// from t2
769+
// left mark join t1 on t2.b1 = t1.b1 and t2.a2 > t1.a1
770+
771+
let left = build_table_two_cols(
772+
("a2", &vec![10, 20, 30]),
773+
("b1", &vec![4, 5, 6]), // 6 does not exist on the right
774+
);
775+
776+
let right = build_table(
777+
("a1", &vec![1, 21, 3]), // 20(t2.a2) > 1(t1.a1)
778+
("b1", &vec![4, 5, 7]),
779+
("c1", &vec![7, 8, 9]),
780+
);
781+
782+
let on = vec![(
783+
Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
784+
Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
785+
)];
786+
787+
let filter = JoinFilter::new(
788+
Arc::new(BinaryExpr::new(
789+
Arc::new(Column::new("a2", 0)),
790+
Operator::Gt,
791+
Arc::new(Column::new("a1", 1)),
792+
)),
793+
vec![
794+
ColumnIndex {
795+
index: 0,
796+
side: JoinSide::Left,
797+
},
798+
ColumnIndex {
799+
index: 0,
800+
side: JoinSide::Right,
801+
},
802+
],
803+
Arc::new(Schema::new(vec![
804+
Field::new("a2", DataType::Int32, true),
805+
Field::new("a1", DataType::Int32, true),
806+
])),
807+
);
808+
809+
let (_, batches) =
810+
join_collect_with_filter(left, right, on, filter, LeftMark).await?;
811+
812+
// The output order is important as SMJ preserves sortedness
813+
// LeftMark returns all left rows with a boolean mark column
814+
assert_snapshot!(batches_to_string(&batches), @r#"
815+
+----+----+-------+
816+
| a2 | b1 | mark |
817+
+----+----+-------+
818+
| 10 | 4 | true |
819+
| 20 | 5 | false |
820+
| 30 | 6 | false |
821+
+----+----+-------+
822+
"#);
823+
Ok(())
824+
}
825+
826+
#[tokio::test]
827+
async fn join_right_mark_different_columns_count_with_filter() -> Result<()> {
828+
// select *
829+
// from t1
830+
// right mark join t2 on t1.b1 = t2.b1 and t1.a1 > t2.a2
831+
832+
let left = build_table(
833+
("a1", &vec![1, 21, 3]), // 21(t1.a1) > 20(t2.a2)
834+
("b1", &vec![4, 5, 7]),
835+
("c1", &vec![7, 8, 9]),
836+
);
837+
838+
let right = build_table_two_cols(
839+
("a2", &vec![10, 20, 30]),
840+
("b1", &vec![4, 5, 6]), // 6 does not exist on the left
841+
);
842+
843+
let on = vec![(
844+
Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
845+
Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
846+
)];
847+
848+
let filter = JoinFilter::new(
849+
Arc::new(BinaryExpr::new(
850+
Arc::new(Column::new("a1", 0)),
851+
Operator::Gt,
852+
Arc::new(Column::new("a2", 1)),
853+
)),
854+
vec![
855+
ColumnIndex {
856+
index: 0,
857+
side: JoinSide::Left,
858+
},
859+
ColumnIndex {
860+
index: 0,
861+
side: JoinSide::Right,
862+
},
863+
],
864+
Arc::new(Schema::new(vec![
865+
Field::new("a1", DataType::Int32, true),
866+
Field::new("a2", DataType::Int32, true),
867+
])),
868+
);
869+
870+
let (_, batches) =
871+
join_collect_with_filter(left, right, on, filter, RightMark).await?;
872+
873+
// The output order is important as SMJ preserves sortedness
874+
// RightMark returns all right rows with a boolean mark column
875+
assert_snapshot!(batches_to_string(&batches), @r#"
876+
+----+----+-------+
877+
| a2 | b1 | mark |
878+
+----+----+-------+
879+
| 10 | 4 | false |
880+
| 20 | 5 | true |
881+
| 30 | 6 | false |
882+
+----+----+-------+
883+
"#);
884+
Ok(())
885+
}
886+
768887
#[tokio::test]
769888
async fn join_full_one() -> Result<()> {
770889
let left = build_table(

0 commit comments

Comments
 (0)