diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 20bdbf391c502..2b082578eb256 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -22,7 +22,6 @@ import scala.collection.mutable
 
 import com.google.common.collect.{Lists, Maps}
 import com.google.protobuf.{Any => ProtoAny, ByteString}
-import io.grpc.stub.StreamObserver
 import org.apache.commons.lang3.exception.ExceptionUtils
 
 import org.apache.spark.{Partition, SparkEnv, TaskContext}
@@ -2083,7 +2082,7 @@ class SparkConnectPlanner(val session: SparkSession) {
       command: proto.Command,
       userId: String,
       sessionId: String,
-      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+      responses: mutable.ArrayBuffer[ExecutePlanResponse] = mutable.ArrayBuffer.empty): Unit = {
     command.getCommandTypeCase match {
       case proto.Command.CommandTypeCase.REGISTER_FUNCTION =>
         handleRegisterUserDefinedFunction(command.getRegisterFunction)
@@ -2096,22 +2095,22 @@ class SparkConnectPlanner(val session: SparkSession) {
       case proto.Command.CommandTypeCase.EXTENSION => handleCommandPlugin(command.getExtension)
       case proto.Command.CommandTypeCase.SQL_COMMAND =>
-        handleSqlCommand(command.getSqlCommand, sessionId, responseObserver)
+        handleSqlCommand(command.getSqlCommand, sessionId, responses)
       case proto.Command.CommandTypeCase.WRITE_STREAM_OPERATION_START =>
         handleWriteStreamOperationStart(
           command.getWriteStreamOperationStart,
           userId,
           sessionId,
-          responseObserver)
+          responses)
       case proto.Command.CommandTypeCase.STREAMING_QUERY_COMMAND =>
-        handleStreamingQueryCommand(command.getStreamingQueryCommand, sessionId, responseObserver)
+        handleStreamingQueryCommand(command.getStreamingQueryCommand, sessionId, responses)
       case proto.Command.CommandTypeCase.STREAMING_QUERY_MANAGER_COMMAND =>
         handleStreamingQueryManagerCommand(
           command.getStreamingQueryManagerCommand,
           sessionId,
-          responseObserver)
+          responses)
       case proto.Command.CommandTypeCase.GET_RESOURCES_COMMAND =>
-        handleGetResourcesCommand(sessionId, responseObserver)
+        handleGetResourcesCommand(sessionId, responses)
       case _ => throw new UnsupportedOperationException(s"$command not supported.")
     }
   }
@@ -2119,7 +2118,7 @@ class SparkConnectPlanner(val session: SparkSession) {
   def handleSqlCommand(
       getSqlCommand: SqlCommand,
       sessionId: String,
-      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+      responses: mutable.ArrayBuffer[ExecutePlanResponse]): Unit = {
     // Eagerly execute commands of the provided SQL string.
     val df = session.sql(
       getSqlCommand.getSql,
@@ -2180,15 +2179,15 @@ class SparkConnectPlanner(val session: SparkSession) {
             .putAllArgs(getSqlCommand.getArgsMap)))
     }
 
     // Exactly one SQL Command Result Batch
-    responseObserver.onNext(
+    responses +=
       ExecutePlanResponse
         .newBuilder()
         .setSessionId(sessionId)
         .setSqlCommandResult(result)
-        .build())
+        .build()
 
     // Send Metrics
-    responseObserver.onNext(SparkConnectStreamHandler.createMetricsResponse(sessionId, df))
+    responses += SparkConnectStreamHandler.createMetricsResponse(sessionId, df)
   }
 
   private def handleRegisterUserDefinedFunction(
@@ -2408,7 +2407,7 @@ class SparkConnectPlanner(val session: SparkSession) {
       writeOp: WriteStreamOperationStart,
       userId: String,
       sessionId: String,
-      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+      responses: mutable.ArrayBuffer[ExecutePlanResponse]): Unit = {
     val plan = transformRelation(writeOp.getInput)
     val dataset = Dataset.ofRows(session, logicalPlan = plan)
 
@@ -2473,18 +2472,18 @@ class SparkConnectPlanner(val session: SparkSession) {
         .setName(Option(query.name).getOrElse(""))
         .build()
 
-    responseObserver.onNext(
+    responses +=
       ExecutePlanResponse
         .newBuilder()
         .setSessionId(sessionId)
         .setWriteStreamOperationStartResult(result)
-        .build())
+        .build()
   }
 
   def handleStreamingQueryCommand(
       command: StreamingQueryCommand,
       sessionId: String,
-      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+      responses: mutable.ArrayBuffer[ExecutePlanResponse]): Unit = {
     val id = command.getQueryId.getId
     val runId = command.getQueryId.getRunId
 
@@ -2589,12 +2588,12 @@ class SparkConnectPlanner(val session: SparkSession) {
         throw new IllegalArgumentException("Missing command in StreamingQueryCommand")
     }
 
-    responseObserver.onNext(
+    responses +=
       ExecutePlanResponse
         .newBuilder()
         .setSessionId(sessionId)
         .setStreamingQueryCommandResult(respBuilder.build())
-        .build())
+        .build()
   }
 
   private def buildStreamingQueryInstance(query: StreamingQuery): StreamingQueryInstance = {
@@ -2615,7 +2614,7 @@ class SparkConnectPlanner(val session: SparkSession) {
   def handleStreamingQueryManagerCommand(
       command: StreamingQueryManagerCommand,
       sessionId: String,
-      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+      responses: mutable.ArrayBuffer[ExecutePlanResponse]): Unit = {
     val respBuilder = StreamingQueryManagerCommandResult.newBuilder()
 
@@ -2650,18 +2649,18 @@ class SparkConnectPlanner(val session: SparkSession) {
         throw new IllegalArgumentException("Missing command in StreamingQueryManagerCommand")
     }
 
-    responseObserver.onNext(
+    responses +=
       ExecutePlanResponse
         .newBuilder()
         .setSessionId(sessionId)
         .setStreamingQueryManagerCommandResult(respBuilder.build())
-        .build())
+        .build()
   }
 
   def handleGetResourcesCommand(
       sessionId: String,
-      responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = {
-    responseObserver.onNext(
+      responses: mutable.ArrayBuffer[ExecutePlanResponse]): Unit = {
+    responses +=
       proto.ExecutePlanResponse
         .newBuilder()
         .setSessionId(sessionId)
@@ -2679,7 +2678,7 @@ class SparkConnectPlanner(val session: SparkSession) {
               .toMap
               .asJava)
           .build())
-        .build())
+        .build()
   }
 
   private val emptyLocalRelation = LocalRelation(
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 4958fd69b9de2..82709ec370534 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.service
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.util.control.NonFatal
 
 import com.google.protobuf.ByteString
@@ -113,11 +114,13 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
   private def handleCommand(session: SparkSession, request: ExecutePlanRequest): Unit = {
     val command = request.getPlan.getCommand
     val planner = new SparkConnectPlanner(session)
+    val responses = mutable.ArrayBuffer.empty[ExecutePlanResponse]
     planner.process(
       command = command,
       userId = request.getUserContext.getUserId,
       sessionId = request.getSessionId,
-      responseObserver = responseObserver)
+      responses = responses)
+    responses.foreach(responseObserver.onNext(_))
     responseObserver.onCompleted()
   }
 }
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 37d4bec9c8746..30991d8fe4163 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -20,11 +20,9 @@ package org.apache.spark.sql.connect.planner
 import scala.collection.JavaConverters._
 
 import com.google.protobuf.ByteString
-import io.grpc.stub.StreamObserver
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.ExecutePlanResponse
 import org.apache.spark.connect.proto.Expression.{Alias, ExpressionString, UnresolvedStar}
 import org.apache.spark.sql.{AnalysisException, Dataset, Row}
 import org.apache.spark.sql.catalyst.InternalRow
@@ -44,18 +42,12 @@ import org.apache.spark.unsafe.types.UTF8String
  */
 trait SparkConnectPlanTest extends SharedSparkSession {
 
-  class MockObserver extends StreamObserver[proto.ExecutePlanResponse] {
-    override def onNext(value: ExecutePlanResponse): Unit = {}
-    override def onError(t: Throwable): Unit = {}
-    override def onCompleted(): Unit = {}
-  }
-
   def transform(rel: proto.Relation): logical.LogicalPlan = {
     new SparkConnectPlanner(spark).transformRelation(rel)
   }
 
   def transform(cmd: proto.Command): Unit = {
-    new SparkConnectPlanner(spark).process(cmd, "clientId", "sessionId", new MockObserver())
+    new SparkConnectPlanner(spark).process(cmd, "clientId", "sessionId")
   }
 
   def readRel: proto.Relation =
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
index d61b54c67c259..b6a9daa8239b3 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
@@ -195,7 +195,7 @@ class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConne
             .build()))
         .build()
 
-    new SparkConnectPlanner(spark).process(plan, "clientId", "sessionId", new MockObserver())
+    new SparkConnectPlanner(spark).process(plan, "clientId", "sessionId")
     assert(spark.sparkContext.getLocalProperty("testingProperty").equals("Martin"))
   }
 }