diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index d2124a38c9d4e..334dcdbcb4287 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.execution
 
 import scala.collection.JavaConverters._
+import scala.concurrent.duration.Duration
 import scala.util.{Failure, Success}
 
 import com.google.protobuf.ByteString
 
@@ -153,23 +154,23 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
           ()
         }
 
-        val future = spark.sparkContext.submitJob(
-          rdd = batches,
-          processPartition = (iter: Iterator[Batch]) => iter.toArray,
-          partitions = Seq.range(0, numPartitions),
-          resultHandler = resultHandler,
-          resultFunc = () => ())
-
-        // Collect errors and propagate them to the main thread.
-        future.onComplete {
-          case Success(_) =>
-            executePlan.eventsManager.postFinished()
-          case Failure(throwable) =>
-            signal.synchronized {
-              error = Some(throwable)
-              signal.notify()
-            }
-        }(ThreadUtils.sameThread)
+        val future = spark.sparkContext
+          .submitJob(
+            rdd = batches,
+            processPartition = (iter: Iterator[Batch]) => iter.toArray,
+            partitions = Seq.range(0, numPartitions),
+            resultHandler = resultHandler,
+            resultFunc = () => ())
+          // Collect errors and propagate them to the main thread.
+          .andThen {
+            case Success(_) =>
+              executePlan.eventsManager.postFinished()
+            case Failure(throwable) =>
+              signal.synchronized {
+                error = Some(throwable)
+                signal.notify()
+              }
+          }(ThreadUtils.sameThread)
 
         // The main thread will wait until 0-th partition is available,
         // then send it to client and wait for the next partition.
@@ -199,6 +200,9 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
           currentPartitionId += 1
         }
+
+        ThreadUtils.awaitReady(future, Duration.Inf)
+      } else {
+        executePlan.eventsManager.postFinished()
       }
     }
   }
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index cfa37b86cd41a..498084efb8f3f 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -564,8 +564,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with
     }
   }
 
-  // TODO(SPARK-44474): Reenable Test observe response at SparkConnectServiceSuite
-  ignore("Test observe response") {
+  test("Test observe response") {
     // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
     assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
     withTable("test") {