Skip to content

Commit

Permalink
fix: fix CometNativeExec.doCanonicalize for ReusedExchangeExec
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed May 18, 2024
1 parent ec8da30 commit 32fd8be
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 1 deletion.
20 changes: 19 additions & 1 deletion spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -363,14 +363,32 @@ abstract class CometNativeExec extends CometExec {
}

override protected def doCanonicalize(): SparkPlan = {
val canonicalizedPlan = super.doCanonicalize().asInstanceOf[CometNativeExec]
val canonicalizedPlan = super
.doCanonicalize()
.asInstanceOf[CometNativeExec]
.canonicalizePlans()

if (serializedPlanOpt.isDefined) {
// If the plan is a boundary node, we should remove the serialized plan.
canonicalizedPlan.cleanBlock()
} else {
canonicalizedPlan
}
}

/**
* Canonicalizes the plans of Product parameters in Comet native operators.
*/
protected def canonicalizePlans(): CometNativeExec = {
def transform(arg: Any): AnyRef = arg match {
case sparkPlan: SparkPlan => sparkPlan.canonicalized
case other: AnyRef => other
case null => null
}

val newArgs = mapProductIterator(transform)
makeCopy(newArgs).asInstanceOf[CometNativeExec]
}
}

abstract class CometUnaryExec extends CometNativeExec with UnaryExecNode
Expand Down
30 changes: 30 additions & 0 deletions spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,36 @@ class CometExecSuite extends CometTestBase {
}
}

test("fix CometNativeExec.doCanonicalize for ReusedExchangeExec") {
assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+")
withSQLConf(
CometConf.COMET_EXEC_BROADCAST_FORCE_ENABLED.key -> "true",
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
withTable("td") {
testData
.withColumn("bucket", $"key" % 3)
.write
.mode(SaveMode.Overwrite)
.bucketBy(2, "bucket")
.format("parquet")
.saveAsTable("td")
val df = sql("""
|SELECT t1.key, t2.key, t3.key
|FROM td AS t1
|JOIN td AS t2 ON t2.key = t1.key
|JOIN td AS t3 ON t3.key = t2.key
|WHERE t1.bucket = 1 AND t2.bucket = 1 AND t3.bucket = 1
|""".stripMargin)
val reusedPlan = ReuseExchangeAndSubquery.apply(df.queryExecution.executedPlan)
val reusedExchanges = collect(reusedPlan) { case r: ReusedExchangeExec =>
r
}
assert(reusedExchanges.size == 1)
assert(reusedExchanges.head.child.isInstanceOf[CometBroadcastExchangeExec])
}
}
}

test("ReusedExchangeExec should work on CometBroadcastExchangeExec") {
assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+")
withSQLConf(
Expand Down

0 comments on commit 32fd8be

Please sign in to comment.