Skip to content

Commit 5db8778

Browse files
Hisoka-Xcloud-fan
authored and committed
[SPARK-43781][SQL] Fix IllegalStateException when cogrouping two datasets derived from the same source
### What changes were proposed in this pull request? When cogrouping two datasets derived from the same source, e.g.: ```scala val inputType = StructType(Array(StructField("id", LongType, false), StructField("type", StringType, false))) val keyType = StructType(Array(StructField("id", LongType, false))) val inputRows = new java.util.ArrayList[Row]() inputRows.add(Row(1L, "foo")) inputRows.add(Row(1L, "bar")) inputRows.add(Row(2L, "foo")) val input = spark.createDataFrame(inputRows, inputType) val fooGroups = input.filter("type = 'foo'").groupBy("id").as(RowEncoder(keyType), RowEncoder(inputType)) val barGroups = input.filter("type = 'bar'").groupBy("id").as(RowEncoder(keyType), RowEncoder(inputType)) val result = fooGroups.cogroup(barGroups) { case (row, iterator, iterator1) => iterator.toSeq ++ iterator1.toSeq }(RowEncoder(inputType)).collect() ``` the following error is reported: ``` 21:03:27.651 ERROR org.apache.spark.executor.Executor: Exception in task 1.0 in stage 0.0 (TID 1) java.lang.IllegalStateException: Couldn't find id#19L in [id#0L,type#1] at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:80) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:73) ... ``` The reason is that `DeduplicateRelations` rewrites `LocalRelation` but can't rewrite `left(right)Group` and `left(right)Attr` in `CoGroup`. In fact, `Join` faces the same situation, but `Join` regenerates the plan when it is invoked, which avoids the problem. Please refer to https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L1089 This PR lets `DeduplicateRelations` handle the `CoGroup` case. ### Why are the changes needed? To fix the IllegalStateException raised when cogrouping two datasets derived from the same source. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? 
Added a new test. Closes #41554 from Hisoka-X/SPARK-43781_cogrouping_two_datasets. Authored-by: Jia Fan <fanjiaeminem@qq.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 3164ff5 commit 5db8778

File tree

2 files changed

+63
-2
lines changed

2 files changed

+63
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
1919

2020
import scala.collection.mutable
2121

22-
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, AttributeSet, NamedExpression, OuterReference, SubqueryExpression}
22+
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference, AttributeSet, Expression, NamedExpression, OuterReference, SubqueryExpression}
2323
import org.apache.spark.sql.catalyst.plans.logical._
2424
import org.apache.spark.sql.catalyst.rules.Rule
2525
import org.apache.spark.sql.catalyst.trees.TreePattern._
@@ -228,7 +228,42 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
228228
if (attrMap.isEmpty) {
229229
planWithNewChildren
230230
} else {
231-
planWithNewChildren.rewriteAttrs(attrMap)
231+
// Local helper: returns a copy of `exprs` in which every `AttributeReference` that has an
// entry in `attrMap` is replaced by its mapped attribute; references without an entry are
// left unchanged.
//
// @param exprs   expressions to rewrite
// @param attrMap substitution map, looked up with each `AttributeReference` found
// @return one rewritten expression per input expression, in the same order
def rewriteAttrs[T <: Expression](
232+
exprs: Seq[T],
233+
attrMap: Map[Attribute, Attribute]): Seq[T] = {
234+
exprs.map { expr =>
235+
// Only subtrees that report the ATTRIBUTE_REFERENCE pattern are visited.
expr.transformWithPruning(_.containsPattern(ATTRIBUTE_REFERENCE)) {
236+
case a: AttributeReference => attrMap.getOrElse(a, a)
237+
// Unchecked cast back to the caller's element type T.
// NOTE(review): presumably safe because an attribute is only swapped for another
// attribute, so the root expression's class does not change -- confirm.
}.asInstanceOf[T]
238+
}
239+
}
240+
241+
// Applies `attrMap` to the plan. `CoGroup` gets a hand-written rewrite in which each side's
// fields are rewritten only with the mappings belonging to that side; every other plan uses
// the plan's own `rewriteAttrs(attrMap)`.
planWithNewChildren match {
242+
// TODO (SPARK-44754): we should handle all special cases here.
243+
case c: CoGroup =>
244+
// SPARK-43781: CoGroup is a special case, `rewriteAttrs` will incorrectly update
245+
// some fields that do not need to be updated. We need to update the output
246+
// attributes of CoGroup manually.
247+
// Split the map by side: keep an entry only if its target attribute (`a._2`) is part
// of that child's output.
val leftAttrMap = attrMap.filter(a => c.left.output.contains(a._2))
248+
val rightAttrMap = attrMap.filter(a => c.right.output.contains(a._2))
249+
val newLeftAttr = rewriteAttrs(c.leftAttr, leftAttrMap)
250+
val newRightAttr = rewriteAttrs(c.rightAttr, rightAttrMap)
251+
val newLeftGroup = rewriteAttrs(c.leftGroup, leftAttrMap)
252+
val newRightGroup = rewriteAttrs(c.rightGroup, rightAttrMap)
253+
val newLeftOrder = rewriteAttrs(c.leftOrder, leftAttrMap)
254+
val newRightOrder = rewriteAttrs(c.rightOrder, rightAttrMap)
255+
// The deserializers are rebuilt with the rewritten attributes as their inputs; the key
// deserializer takes the left grouping attributes.
// NOTE(review): the three casts below assume the deserializers are still
// `UnresolvedDeserializer` when this code runs; any other expression type would throw a
// ClassCastException -- confirm this always holds here.
val newKeyDes = c.keyDeserializer.asInstanceOf[UnresolvedDeserializer]
256+
.copy(inputAttributes = newLeftGroup)
257+
val newLeftDes = c.leftDeserializer.asInstanceOf[UnresolvedDeserializer]
258+
.copy(inputAttributes = newLeftAttr)
259+
val newRightDes = c.rightDeserializer.asInstanceOf[UnresolvedDeserializer]
260+
.copy(inputAttributes = newRightAttr)
261+
c.copy(keyDeserializer = newKeyDes, leftDeserializer = newLeftDes,
262+
rightDeserializer = newRightDes, leftGroup = newLeftGroup,
263+
rightGroup = newRightGroup, leftAttr = newLeftAttr, rightAttr = newRightAttr,
264+
leftOrder = newLeftOrder, rightOrder = newRightOrder)
265+
// Every other plan: generic rewrite with the full, unsplit map.
case _ => planWithNewChildren.rewriteAttrs(attrMap)
266+
}
232267
}
233268
} else {
234269
planWithNewSubquery.withNewChildren(newChildren.toSeq)

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,32 @@ class DatasetSuite extends QueryTest
916916
}
917917
}
918918

919+
// Regression test: `cogroup` and `cogroupSorted` over two grouped datasets that are both
// filters of the same `input` DataFrame.
test("SPARK-43781: cogroup two datasets derived from the same source") {
920+
val inputType = StructType(Array(StructField("id", LongType, false),
921+
StructField("type", StringType, false)))
922+
val keyType = StructType(Array(StructField("id", LongType, false)))
923+
924+
val inputRows = new java.util.ArrayList[Row]()
925+
inputRows.add(Row(1L, "foo"))
926+
inputRows.add(Row(1L, "bar"))
927+
inputRows.add(Row(2L, "foo"))
928+
val input = spark.createDataFrame(inputRows, inputType)
929+
// Both sides are built from the same `input`, differing only in the filter predicate.
val fooGroups = input.filter("type = 'foo'").groupBy("id").as(ExpressionEncoder(keyType),
930+
ExpressionEncoder(inputType))
931+
val barGroups = input.filter("type = 'bar'").groupBy("id").as(ExpressionEncoder(keyType),
932+
ExpressionEncoder(inputType))
933+
934+
// The cogroup function just concatenates the rows of both sides. Each of the three input
// rows matches exactly one of the two filters, so 3 result rows are expected.
val result = fooGroups.cogroup(barGroups) { case (row, iterator, iterator1) =>
935+
iterator.toSeq ++ iterator1.toSeq
936+
}(ExpressionEncoder(inputType)).collect()
937+
assert(result.length == 3)
938+
939+
// Same check through `cogroupSorted`.
// NOTE(review): presumably the two `$"id"` argument lists are the sort columns for the
// left and right side respectively -- confirm against the API.
val result2 = fooGroups.cogroupSorted(barGroups)($"id")($"id") {
940+
case (row, iterator, iterator1) => iterator.toSeq ++ iterator1.toSeq
941+
}(ExpressionEncoder(inputType)).collect()
942+
assert(result2.length == 3)
943+
}
944+
919945
test("SPARK-34806: observation on datasets") {
920946
val namedObservation = Observation("named")
921947
val unnamedObservation = Observation()

0 commit comments

Comments
 (0)