Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,90 @@ class DatasetSuite extends QueryTest
1 -> "a", 2 -> "bc", 3 -> "d")
}

test("cogroup with complex key types") {
  // Both inputs are grouped by their first tuple element, which is a whole ClassData value,
  // so the cogroup key is a nested structure rather than a single primitive column.
  val leftSide = Seq(
    ClassData("x", 1) -> "left1",
    ClassData("x", 1) -> "left2",
    ClassData("y", 2) -> "left3"
  ).toDS()
  val rightSide = Seq(
    ClassData("x", 1) -> 100,
    ClassData("z", 3) -> 200
  ).toDS()

  val leftByKey = leftSide.groupByKey(_._1)
  val rightByKey = rightSide.groupByKey(_._1)

  // Emit one row per distinct key: the key's two fields followed by the number of rows
  // that each side contributed for that key.
  val perKeyCounts = leftByKey.cogroup(rightByKey) { (key, leftRows, rightRows) =>
    Iterator((key.a, key.b, leftRows.size, rightRows.size))
  }

  checkDatasetUnorderly(
    perKeyCounts,
    ("x", 1, 2, 1), // ClassData("x", 1) is on both sides: 2 left rows, 1 right row
    ("y", 2, 1, 0), // ClassData("y", 2) is left-only: 1 left row, 0 right rows
    ("z", 3, 0, 1) // ClassData("z", 3) is right-only: 0 left rows, 1 right row
  )
}

test("cogroup with null keys") {
  // Keys are Option[Int]; None stands in for a null key. The expectation below is that all
  // None-keyed rows from both sides end up in one shared group instead of being dropped.
  val leftSide = Seq[(Option[Int], String)](
    Some(1) -> "a",
    Some(1) -> "b",
    None -> "c",
    None -> "d",
    Some(2) -> "e"
  ).toDS()
  val rightSide = Seq[(Option[Int], Int)](
    Some(1) -> 10,
    None -> 20,
    Some(3) -> 30
  ).toDS()

  val leftByKey = leftSide.groupByKey(_._1)
  val rightByKey = rightSide.groupByKey(_._1)

  // One output row per key: the key itself plus the row count seen on each side.
  val perKeyCounts = leftByKey.cogroup(rightByKey) { (key, leftRows, rightRows) =>
    Iterator((key, leftRows.size, rightRows.size))
  }

  checkDatasetUnorderly(
    perKeyCounts,
    (Some(1), 2, 1), // key 1: left has "a" and "b", right has 10
    (None, 2, 1), // null key: left has "c" and "d", right has 20
    (Some(2), 1, 0), // key 2: left has "e", nothing on the right
    (Some(3), 0, 1) // key 3: nothing on the left, right has 30
  )
}

test("cogroup with empty datasets") {
  val populatedLeft = Seq(1 -> "a", 2 -> "b").toDS()
  val populatedRight = Seq(2 -> 100, 3 -> 200).toDS()
  val emptyStringValued = spark.emptyDataset[(Int, String)]
  val emptyLongValued = spark.emptyDataset[(Int, Long)]

  // Case 1 - nothing on the left: every key originates from the right side, so the
  // left iterator handed to the function has no elements for each key.
  val emptyLeftGroups = emptyStringValued.groupByKey(_._1)
  val populatedRightGroups = populatedRight.groupByKey(_._1)
  val leftEmptyCounts = emptyLeftGroups
    .cogroup(populatedRightGroups) { (key, left, right) =>
      Iterator((key, (left.size, right.size)))
    }
    .collect()
    .sortBy(_._1)
  assert(leftEmptyCounts === Array((2, (0, 1)), (3, (0, 1))))

  // Case 2 - nothing on the right: every key originates from the left side, so the
  // right iterator handed to the function has no elements for each key.
  val populatedLeftGroups = populatedLeft.groupByKey(_._1)
  val emptyRightGroups = emptyStringValued.groupByKey(_._1)
  val rightEmptyCounts = populatedLeftGroups
    .cogroup(emptyRightGroups) { (key, left, right) =>
      Iterator((key, (left.size, right.size)))
    }
    .collect()
    .sortBy(_._1)
  assert(rightEmptyCounts === Array((1, (1, 0)), (2, (1, 0))))

  // Case 3 - nothing on either side: there are no keys, so no rows are produced.
  val bothEmptyCounts = emptyStringValued
    .groupByKey(_._1)
    .cogroup(emptyLongValued.groupByKey(_._1)) { (key, left, right) =>
      Iterator((key, (left.size, right.size)))
    }
    .collect()
  assert(bothEmptyCounts.isEmpty)
}

test("cogroup with groupBy and sorted") {
val left = Seq(1 -> "a", 3 -> "xyz", 5 -> "hello", 3 -> "abc", 3 -> "ijk").toDS()
val right = Seq(2 -> "q", 3 -> "w", 5 -> "x", 5 -> "z", 3 -> "a", 5 -> "y").toDS()
Expand Down