diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index 0f698d8aa6..ccf218cf6c 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -129,11 +129,16 @@ case class CometBroadcastExchangeExec(
       case AQEShuffleReadExec(s: ShuffleQueryStageExec, _) if s.plan.isInstanceOf[CometPlan] =>
         CometExec.getByteArrayRdd(s.plan.asInstanceOf[CometPlan]).collect()
+      case s: ShuffleQueryStageExec if s.plan.isInstanceOf[CometPlan] =>
+        CometExec.getByteArrayRdd(s.plan.asInstanceOf[CometPlan]).collect()
       case ReusedExchangeExec(_, plan) if plan.isInstanceOf[CometPlan] =>
         CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
       case AQEShuffleReadExec(ShuffleQueryStageExec(_, ReusedExchangeExec(_, plan), _), _)
           if plan.isInstanceOf[CometPlan] =>
         CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
+      case ShuffleQueryStageExec(_, ReusedExchangeExec(_, plan), _)
+          if plan.isInstanceOf[CometPlan] =>
+        CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
       case AQEShuffleReadExec(s: ShuffleQueryStageExec, _) =>
         throw new CometRuntimeException(
           "Child of CometBroadcastExchangeExec should be CometExec, " +
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index bfddd74d8e..f7749fc339 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -64,6 +64,22 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("ShuffleQueryStageExec could be direct child node of CometBroadcastExchangeExec") {
+    val table = "src"
+    withTable(table) {
+      withView("lv_noalias") {
+        sql(s"CREATE TABLE $table (key INT, value STRING) USING PARQUET")
+        sql(s"insert into $table values(238, 'val_238')")
+
+        sql(
+          "create view lv_noalias as SELECT myTab.* from src " +
+            "LATERAL VIEW explode(map('key1', 100, 'key2', 200)) myTab limit 2")
+        val df = sql("select * from lv_noalias a join lv_noalias b on a.key=b.key")
+        checkSparkAnswer(df)
+      }
+    }
+  }
+
   test("Sort on single struct should fallback to Spark") {
     withSQLConf(
       SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
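
For context on what the first hunk fixes: under adaptive query execution, the child of a CometBroadcastExchangeExec is not always wrapped in an AQEShuffleReadExec; a ShuffleQueryStageExec (possibly holding a ReusedExchangeExec) can appear as the direct child, and the old match fell through to the CometRuntimeException branch. The stand-alone Scala sketch below models that match with hypothetical toy case classes (Plan, ShuffleStage, AQERead, ReusedExchange are stand-ins, not the real Spark operators) to show how the two added arms catch the previously unmatched shapes; it is an illustration of the pattern-matching gap, not Comet code.

// Toy model of the broadcast-child match; every type here is a hypothetical
// stand-in for the corresponding Spark/Comet plan node.
object BroadcastChildMatchDemo extends App {
  sealed trait Plan
  trait CometPlan extends Plan
  case class CometScan() extends CometPlan
  case class ShuffleStage(plan: Plan) extends Plan
  case class AQERead(child: Plan) extends Plan
  case class ReusedExchange(child: Plan) extends Plan

  def describeChild(child: Plan): String = child match {
    case AQERead(s: ShuffleStage) if s.plan.isInstanceOf[CometPlan] =>
      "wrapped stage over a Comet plan"
    // Added arm 1: AQE can hand us the stage directly, without the read wrapper.
    case s: ShuffleStage if s.plan.isInstanceOf[CometPlan] =>
      "bare stage over a Comet plan"
    case ReusedExchange(plan) if plan.isInstanceOf[CometPlan] =>
      "reused exchange over a Comet plan"
    case AQERead(ShuffleStage(ReusedExchange(plan))) if plan.isInstanceOf[CometPlan] =>
      "wrapped stage over a reused Comet exchange"
    // Added arm 2: a reused exchange inside a bare stage.
    case ShuffleStage(ReusedExchange(plan)) if plan.isInstanceOf[CometPlan] =>
      "bare stage over a reused Comet exchange"
    case other => throw new RuntimeException(s"unsupported child: $other")
  }

  // Before the fix, both of these shapes hit the exception branch:
  println(describeChild(ShuffleStage(CometScan())))
  println(describeChild(ShuffleStage(ReusedExchange(CometScan()))))
}

The test added in the second hunk reproduces exactly this situation: the self-join of a LATERAL VIEW with a LIMIT gives AQE a reused shuffle stage to broadcast, which previously triggered the "Child of CometBroadcastExchangeExec should be CometExec" error.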