diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala index 83f84f45b317..e0c7d267c604 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala @@ -27,15 +27,15 @@ private[connect] object ProtoUtils { private val MAX_BYTES_SIZE = 8 private val MAX_STRING_SIZE = 1024 - def abbreviate(message: Message): Message = { + def abbreviate(message: Message, maxStringSize: Int = MAX_STRING_SIZE): Message = { val builder = message.toBuilder message.getAllFields.asScala.iterator.foreach { case (field: FieldDescriptor, string: String) if field.getJavaType == FieldDescriptor.JavaType.STRING && string != null => val size = string.size - if (size > MAX_STRING_SIZE) { - builder.setField(field, createString(string.take(MAX_STRING_SIZE), size)) + if (size > maxStringSize) { + builder.setField(field, createString(string.take(maxStringSize), size)) } else { builder.setField(field, string) } @@ -43,8 +43,8 @@ private[connect] object ProtoUtils { case (field: FieldDescriptor, byteString: ByteString) if field.getJavaType == FieldDescriptor.JavaType.BYTE_STRING && byteString != null => val size = byteString.size - if (size > MAX_BYTES_SIZE) { - val prefix = Array.tabulate(MAX_BYTES_SIZE)(byteString.byteAt) + if (size > maxStringSize) { + val prefix = Array.tabulate(maxStringSize)(byteString.byteAt) builder.setField(field, createByteString(prefix, size)) } else { builder.setField(field, byteString) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala index b7b3d2adf9f7..6c2ffa465474 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala @@ -78,7 +78,6 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends // and different exceptions like InterruptedException, ClosedByInterruptException etc. // could be thrown. if (interrupted) { - // Turn the interrupt into OPERATION_CANCELED error. throw new SparkSQLException("OPERATION_CANCELED", Map.empty) } else { // Rethrown the original error. 
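The maxStringSize parameter added to ProtoUtils.abbreviate above is what lets the new ExecuteEventsManager (introduced further down) cap statement text before it is attached to a listener event (MAX_STATEMENT_TEXT_SIZE = 65535). A hedged sketch of that intended usage, assuming caller code inside the connect package since ProtoUtils is private[connect]; the values and variable names are illustrative only, not part of this patch:

// Illustrative sketch only, not part of the diff.
import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.common.ProtoUtils

// A command whose SQL text is far larger than the limit we want in event payloads.
val sqlCommand = proto.SqlCommand
  .newBuilder()
  .setSql("select '" + ("x" * 100000) + "'")
  .build()

// Top-level string fields longer than maxStringSize are truncated; shorter fields pass through.
// The exact truncation marker depends on ProtoUtils.createString.
val abbreviated = ProtoUtils
  .abbreviate(sqlCommand, maxStringSize = 1024)
  .asInstanceOf[proto.SqlCommand]
println(abbreviated.getSql.length) // much smaller than the original 100000+ characters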
@@ -92,7 +91,9 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends "execute", executeHolder.responseObserver, executeHolder.sessionHolder.userId, - executeHolder.sessionHolder.sessionId) + executeHolder.sessionHolder.sessionId, + Some(executeHolder.eventsManager), + interrupted) } } @@ -148,9 +149,8 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends val planner = new SparkConnectPlanner(executeHolder.sessionHolder) planner.process( command = command, - userId = request.getUserContext.getUserId, - sessionId = request.getSessionId, - responseObserver = responseObserver) + responseObserver = responseObserver, + executeHolder = executeHolder) responseObserver.onCompleted() } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala index 74b4a5f65974..d2124a38c9d4 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.execution import scala.collection.JavaConverters._ +import scala.util.{Failure, Success} import com.google.protobuf.ByteString import io.grpc.stub.StreamObserver @@ -54,13 +55,15 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) throw new IllegalStateException( s"Illegal operation type ${request.getPlan.getOpTypeCase} to be handled here.") } - - // Extract the plan from the request and convert it to a logical plan val planner = new SparkConnectPlanner(sessionHolder) + val tracker = executeHolder.eventsManager.createQueryPlanningTracker val dataframe = - Dataset.ofRows(sessionHolder.session, planner.transformRelation(request.getPlan.getRoot)) + Dataset.ofRows( + sessionHolder.session, + planner.transformRelation(request.getPlan.getRoot), + tracker) responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema)) - processAsArrowBatches(request.getSessionId, dataframe, responseObserver) + processAsArrowBatches(dataframe, responseObserver, executeHolder) responseObserver.onNext( MetricGenerator.createMetricsResponse(request.getSessionId, dataframe)) if (dataframe.queryExecution.observedMetrics.nonEmpty) { @@ -87,10 +90,11 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) batches.map(b => b -> batches.rowCountInLastBatch) } - private def processAsArrowBatches( - sessionId: String, + def processAsArrowBatches( dataframe: DataFrame, - responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { + responseObserver: StreamObserver[ExecutePlanResponse], + executePlan: ExecuteHolder): Unit = { + val sessionId = executePlan.sessionHolder.sessionId val spark = dataframe.sparkSession val schema = dataframe.schema val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch @@ -120,6 +124,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) dataframe.queryExecution.executedPlan match { case LocalTableScanExec(_, rows) => + executePlan.eventsManager.postFinished() converter(rows.iterator).foreach { case (bytes, count) => sendBatch(bytes, count) } @@ -156,13 +161,14 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) resultFunc = () => ()) // Collect 
errors and propagate them to the main thread. - future.onComplete { result => - result.failed.foreach { throwable => + future.onComplete { + case Success(_) => + executePlan.eventsManager.postFinished() + case Failure(throwable) => signal.synchronized { error = Some(throwable) signal.notify() } - } }(ThreadUtils.sameThread) // The main thread will wait until 0-th partition is available, diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index f82cb6760404..39cb4c1b972b 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -57,8 +57,7 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket} import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry -import org.apache.spark.sql.connect.service.SessionHolder -import org.apache.spark.sql.connect.service.SparkConnectService +import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder, SparkConnectService} import org.apache.spark.sql.connect.utils.MetricGenerator import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.QueryExecution @@ -86,7 +85,11 @@ final case class InvalidCommandInput( class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { - def session: SparkSession = sessionHolder.session + private[connect] def session: SparkSession = sessionHolder.session + + private[connect] def userId: String = sessionHolder.userId + + private[connect] def sessionId: String = sessionHolder.sessionId private lazy val pythonExec = sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3")) @@ -2333,56 +2336,58 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { def process( command: proto.Command, - userId: String, - sessionId: String, - responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { + responseObserver: StreamObserver[ExecutePlanResponse], + executeHolder: ExecuteHolder): Unit = { command.getCommandTypeCase match { case proto.Command.CommandTypeCase.REGISTER_FUNCTION => - handleRegisterUserDefinedFunction(command.getRegisterFunction) + handleRegisterUserDefinedFunction(command.getRegisterFunction, executeHolder) case proto.Command.CommandTypeCase.REGISTER_TABLE_FUNCTION => - handleRegisterUserDefinedTableFunction(command.getRegisterTableFunction) + handleRegisterUserDefinedTableFunction(command.getRegisterTableFunction, executeHolder) case proto.Command.CommandTypeCase.WRITE_OPERATION => - handleWriteOperation(command.getWriteOperation) + handleWriteOperation(command.getWriteOperation, executeHolder) case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW => - handleCreateViewCommand(command.getCreateDataframeView) + handleCreateViewCommand(command.getCreateDataframeView, executeHolder) case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 => - handleWriteOperationV2(command.getWriteOperationV2) + handleWriteOperationV2(command.getWriteOperationV2, 
executeHolder) case proto.Command.CommandTypeCase.EXTENSION => - handleCommandPlugin(command.getExtension) + handleCommandPlugin(command.getExtension, executeHolder) case proto.Command.CommandTypeCase.SQL_COMMAND => - handleSqlCommand(command.getSqlCommand, sessionId, responseObserver) + handleSqlCommand(command.getSqlCommand, responseObserver, executeHolder) case proto.Command.CommandTypeCase.WRITE_STREAM_OPERATION_START => handleWriteStreamOperationStart( command.getWriteStreamOperationStart, - userId, - sessionId, - responseObserver) + responseObserver, + executeHolder) case proto.Command.CommandTypeCase.STREAMING_QUERY_COMMAND => - handleStreamingQueryCommand(command.getStreamingQueryCommand, sessionId, responseObserver) + handleStreamingQueryCommand( + command.getStreamingQueryCommand, + responseObserver, + executeHolder) case proto.Command.CommandTypeCase.STREAMING_QUERY_MANAGER_COMMAND => handleStreamingQueryManagerCommand( command.getStreamingQueryManagerCommand, - sessionId, - responseObserver) + responseObserver, + executeHolder) case proto.Command.CommandTypeCase.GET_RESOURCES_COMMAND => - handleGetResourcesCommand(sessionId, responseObserver) + handleGetResourcesCommand(responseObserver, executeHolder) case _ => throw new UnsupportedOperationException(s"$command not supported.") } } def handleSqlCommand( getSqlCommand: SqlCommand, - sessionId: String, - responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { + responseObserver: StreamObserver[ExecutePlanResponse], + executeHolder: ExecuteHolder): Unit = { // Eagerly execute commands of the provided SQL string. val args = getSqlCommand.getArgsMap val posArgs = getSqlCommand.getPosArgsList + val tracker = executeHolder.eventsManager.createQueryPlanningTracker val df = if (!args.isEmpty) { - session.sql(getSqlCommand.getSql, args.asScala.mapValues(transformLiteral).toMap) + session.sql(getSqlCommand.getSql, args.asScala.mapValues(transformLiteral).toMap, tracker) } else if (!posArgs.isEmpty) { - session.sql(getSqlCommand.getSql, posArgs.asScala.map(transformLiteral).toArray) + session.sql(getSqlCommand.getSql, posArgs.asScala.map(transformLiteral).toArray, tracker) } else { - session.sql(getSqlCommand.getSql) + session.sql(getSqlCommand.getSql, Map.empty[String, Any], tracker) } // Check if commands have been executed. 
val isCommand = df.queryExecution.commandExecuted.isInstanceOf[CommandResult] @@ -2430,6 +2435,9 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { .newBuilder() .setData(ByteString.copyFrom(bytes)))) } else { + // Trigger assertExecutedPlanPrepared so that ReadyForExecution is posted before Finished; + // executedPlan is otherwise only triggered later by createMetricsResponse below + df.queryExecution.assertExecutedPlanPrepared() result.setRelation( proto.Relation .newBuilder() @@ -2440,6 +2448,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { .putAllArgs(getSqlCommand.getArgsMap) .addAllPosArgs(getSqlCommand.getPosArgsList))) } + executeHolder.eventsManager.postFinished() // Exactly one SQL Command Result Batch responseObserver.onNext( ExecutePlanResponse @@ -2453,7 +2462,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { } private def handleRegisterUserDefinedFunction( - fun: proto.CommonInlineUserDefinedFunction): Unit = { + fun: proto.CommonInlineUserDefinedFunction, + executeHolder: ExecuteHolder): Unit = { fun.getFunctionCase match { case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF => handleRegisterPythonUDF(fun) @@ -2465,10 +2475,12 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { throw InvalidPlanInput( s"Function with ID: ${fun.getFunctionCase.getNumber} is not supported") } + executeHolder.eventsManager.postFinished() } private def handleRegisterUserDefinedTableFunction( - fun: proto.CommonInlineUserDefinedTableFunction): Unit = { + fun: proto.CommonInlineUserDefinedTableFunction, + executeHolder: ExecuteHolder): Unit = { fun.getFunctionCase match { case proto.CommonInlineUserDefinedTableFunction.FunctionCase.PYTHON_UDTF => val function = createPythonUserDefinedTableFunction(fun) @@ -2477,6 +2489,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { throw InvalidPlanInput( s"Function with ID: ${fun.getFunctionCase.getNumber} is not supported") } + executeHolder.eventsManager.postFinished() } private def createPythonUserDefinedTableFunction( @@ -2532,7 +2545,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { session.udf.register(fun.getFunctionName, udf) } - private def handleCommandPlugin(extension: ProtoAny): Unit = { + private def handleCommandPlugin(extension: ProtoAny, executeHolder: ExecuteHolder): Unit = { SparkConnectPluginRegistry.commandRegistry // Lazily traverse the collection.
.view @@ -2542,9 +2555,12 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { .find(_.nonEmpty) .flatten .getOrElse(throw InvalidPlanInput("No handler found for extension")) + executeHolder.eventsManager.postFinished() } - private def handleCreateViewCommand(createView: proto.CreateDataFrameViewCommand): Unit = { + private def handleCreateViewCommand( + createView: proto.CreateDataFrameViewCommand, + executeHolder: ExecuteHolder): Unit = { val viewType = if (createView.getIsGlobal) GlobalTempView else LocalTempView val tableIdentifier = @@ -2566,7 +2582,9 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { replace = createView.getReplace, viewType = viewType) - Dataset.ofRows(session, plan).queryExecution.commandExecuted + val tracker = executeHolder.eventsManager.createQueryPlanningTracker + Dataset.ofRows(session, plan, tracker).queryExecution.commandExecuted + executeHolder.eventsManager.postFinished() } /** @@ -2578,11 +2596,14 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { * * @param writeOperation */ - private def handleWriteOperation(writeOperation: proto.WriteOperation): Unit = { + private def handleWriteOperation( + writeOperation: proto.WriteOperation, + executeHolder: ExecuteHolder): Unit = { // Transform the input plan into the logical plan. val plan = transformRelation(writeOperation.getInput) // And create a Dataset from the plan. - val dataset = Dataset.ofRows(session, logicalPlan = plan) + val tracker = executeHolder.eventsManager.createQueryPlanningTracker + val dataset = Dataset.ofRows(session, plan, tracker) val w = dataset.write if (writeOperation.getMode != proto.WriteOperation.SaveMode.SAVE_MODE_UNSPECIFIED) { @@ -2637,6 +2658,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { "WriteOperation:SaveTypeCase not supported " + s"${writeOperation.getSaveTypeCase.getNumber}") } + executeHolder.eventsManager.postFinished() } /** @@ -2648,11 +2670,14 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { * * @param writeOperation */ - def handleWriteOperationV2(writeOperation: proto.WriteOperationV2): Unit = { + def handleWriteOperationV2( + writeOperation: proto.WriteOperationV2, + executeHolder: ExecuteHolder): Unit = { // Transform the input plan into the logical plan. val plan = transformRelation(writeOperation.getInput) // And create a Dataset from the plan. 
- val dataset = Dataset.ofRows(session, logicalPlan = plan) + val tracker = executeHolder.eventsManager.createQueryPlanningTracker + val dataset = Dataset.ofRows(session, plan, tracker) val w = dataset.writeTo(table = writeOperation.getTableName) @@ -2703,15 +2728,18 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { throw new UnsupportedOperationException( s"WriteOperationV2:ModeValue not supported ${writeOperation.getModeValue}") } + executeHolder.eventsManager.postFinished() } def handleWriteStreamOperationStart( writeOp: WriteStreamOperationStart, - userId: String, - sessionId: String, - responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { + responseObserver: StreamObserver[ExecutePlanResponse], + executeHolder: ExecuteHolder): Unit = { val plan = transformRelation(writeOp.getInput) - val dataset = Dataset.ofRows(session, logicalPlan = plan) + val tracker = executeHolder.eventsManager.createQueryPlanningTracker + val dataset = Dataset.ofRows(session, plan, tracker) + // Call manually as writeStream does not trigger ReadyForExecution + tracker.setReadyForExecution() val writer = dataset.writeStream @@ -2789,6 +2817,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { SparkConnectService.streamingSessionManager.registerNewStreamingQuery( sessionHolder = SessionHolder(userId = userId, sessionId = sessionId, session), query = query) + executeHolder.eventsManager.postFinished() val result = WriteStreamOperationStartResult .newBuilder() @@ -2811,8 +2840,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { def handleStreamingQueryCommand( command: StreamingQueryCommand, - sessionId: String, - responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { + responseObserver: StreamObserver[ExecutePlanResponse], + executeHolder: ExecuteHolder): Unit = { val id = command.getQueryId.getId val runId = command.getQueryId.getRunId @@ -2915,6 +2944,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { throw new IllegalArgumentException("Missing command in StreamingQueryCommand") } + executeHolder.eventsManager.postFinished() responseObserver.onNext( ExecutePlanResponse .newBuilder() @@ -2982,9 +3012,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { def handleStreamingQueryManagerCommand( command: StreamingQueryManagerCommand, - sessionId: String, - responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { - + responseObserver: StreamObserver[ExecutePlanResponse], + executeHolder: ExecuteHolder): Unit = { val respBuilder = StreamingQueryManagerCommandResult.newBuilder() command.getCommandCase match { @@ -3045,6 +3074,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { throw new IllegalArgumentException("Missing command in StreamingQueryManagerCommand") } + executeHolder.eventsManager.postFinished() responseObserver.onNext( ExecutePlanResponse .newBuilder() @@ -3054,8 +3084,9 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { } def handleGetResourcesCommand( - sessionId: String, - responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = { + responseObserver: StreamObserver[proto.ExecutePlanResponse], + executeHolder: ExecuteHolder): Unit = { + executeHolder.eventsManager.postFinished() responseObserver.onNext( proto.ExecutePlanResponse .newBuilder() diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala new file mode 100644 index 000000000000..0af54f034a25 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala @@ -0,0 +1,420 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.service + +import com.fasterxml.jackson.annotation.JsonIgnore +import com.google.protobuf.Message + +import org.apache.spark.connect.proto +import org.apache.spark.scheduler.SparkListenerEvent +import org.apache.spark.sql.catalyst.{QueryPlanningTracker, QueryPlanningTrackerCallback} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connect.common.ProtoUtils +import org.apache.spark.util.{Clock, Utils} + +object ExecuteEventsManager { + // TODO: Make this configurable + val MAX_STATEMENT_TEXT_SIZE = 65535 +} + +sealed abstract class ExecuteStatus(value: Int) + +object ExecuteStatus { + case object Pending extends ExecuteStatus(0) + case object Started extends ExecuteStatus(1) + case object Analyzed extends ExecuteStatus(2) + case object ReadyForExecution extends ExecuteStatus(3) + case object Finished extends ExecuteStatus(4) + case object Failed extends ExecuteStatus(5) + case object Canceled extends ExecuteStatus(6) + case object Closed extends ExecuteStatus(7) +} + +/** + * Post request Connect events to @link org.apache.spark.scheduler.LiveListenerBus. + * + * @param executeHolder: + * Request for which the events are generated. + * @param clock: + * Source of time for unit tests. 
+ */ +case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) { + + private def operationId = executeHolder.operationId + + private def jobTag = executeHolder.jobTag + + private def listenerBus = sessionHolder.session.sparkContext.listenerBus + + private def sessionHolder = executeHolder.sessionHolder + + private def sessionId = executeHolder.request.getSessionId + + private def sessionStatus = sessionHolder.eventManager.status + + private var _status: ExecuteStatus = ExecuteStatus.Pending + + private var error = Option.empty[Boolean] + + private var canceled = Option.empty[Boolean] + + /** + * @return + * Last event posted by the Connect request + */ + private[connect] def status: ExecuteStatus = _status + + /** + * @return + * True when the Connect request has posted @link + * org.apache.spark.sql.connect.service.SparkListenerConnectOperationCanceled + */ + private[connect] def hasCanceled: Option[Boolean] = canceled + + /** + * @return + * True when the Connect request has posted @link + * org.apache.spark.sql.connect.service.SparkListenerConnectOperationFailed + */ + private[connect] def hasError: Option[Boolean] = error + + /** + * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationStarted. + */ + def postStarted(): Unit = { + assertStatus(List(ExecuteStatus.Pending), ExecuteStatus.Started) + val request = executeHolder.request + val plan: Message = + request.getPlan.getOpTypeCase match { + case proto.Plan.OpTypeCase.COMMAND => request.getPlan.getCommand + case proto.Plan.OpTypeCase.ROOT => request.getPlan.getRoot + case _ => + throw new UnsupportedOperationException( + s"${request.getPlan.getOpTypeCase} not supported.") + } + + listenerBus.post( + SparkListenerConnectOperationStarted( + jobTag, + operationId, + clock.getTimeMillis(), + sessionId, + request.getUserContext.getUserId, + request.getUserContext.getUserName, + Utils.redact( + sessionHolder.session.sessionState.conf.stringRedactionPattern, + ProtoUtils.abbreviate(plan, ExecuteEventsManager.MAX_STATEMENT_TEXT_SIZE).toString), + Some(request))) + } + + /** + * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationAnalyzed. + * + * @param analyzedPlan + * The analyzed plan generated by the Connect request plan. None when the request does not + * generate a plan. + */ + def postAnalyzed(analyzedPlan: Option[LogicalPlan] = None): Unit = { + assertStatus(List(ExecuteStatus.Started, ExecuteStatus.Analyzed), ExecuteStatus.Analyzed) + val event = + SparkListenerConnectOperationAnalyzed(jobTag, operationId, clock.getTimeMillis()) + event.analyzedPlan = analyzedPlan + listenerBus.post(event) + } + + /** + * Post @link + * org.apache.spark.sql.connect.service.SparkListenerConnectOperationReadyForExecution. + */ + def postReadyForExecution(): Unit = { + assertStatus(List(ExecuteStatus.Analyzed), ExecuteStatus.ReadyForExecution) + listenerBus.post( + SparkListenerConnectOperationReadyForExecution(jobTag, operationId, clock.getTimeMillis())) + } + + /** + * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationCanceled. 
+ */ + def postCanceled(): Unit = { + assertStatus( + List( + ExecuteStatus.Started, + ExecuteStatus.Analyzed, + ExecuteStatus.ReadyForExecution, + ExecuteStatus.Finished, + ExecuteStatus.Failed), + ExecuteStatus.Canceled) + canceled = Some(true) + listenerBus + .post(SparkListenerConnectOperationCanceled(jobTag, operationId, clock.getTimeMillis())) + } + + /** + * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationFailed. + * + * @param errorMessage + * The message of the error thrown during the request. + */ + def postFailed(errorMessage: String): Unit = { + assertStatus( + List( + ExecuteStatus.Started, + ExecuteStatus.Analyzed, + ExecuteStatus.ReadyForExecution, + ExecuteStatus.Finished), + ExecuteStatus.Failed) + error = Some(true) + listenerBus.post( + SparkListenerConnectOperationFailed( + jobTag, + operationId, + clock.getTimeMillis(), + errorMessage)) + } + + /** + * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationFinished. + */ + def postFinished(): Unit = { + assertStatus( + List(ExecuteStatus.Started, ExecuteStatus.ReadyForExecution), + ExecuteStatus.Finished) + listenerBus + .post(SparkListenerConnectOperationFinished(jobTag, operationId, clock.getTimeMillis())) + } + + /** + * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationClosed. + */ + def postClosed(): Unit = { + assertStatus( + List(ExecuteStatus.Finished, ExecuteStatus.Failed, ExecuteStatus.Canceled), + ExecuteStatus.Closed) + listenerBus + .post(SparkListenerConnectOperationClosed(jobTag, operationId, clock.getTimeMillis())) + } + + /** + * @return + * A @link org.apache.spark.sql.catalyst.QueryPlanningTracker that calls postAnalyzed and + * postReadyForExecution after analysis and prior to execution, respectively. + */ + def createQueryPlanningTracker(): QueryPlanningTracker = { + new QueryPlanningTracker(Some(new QueryPlanningTrackerCallback { + def analyzed(tracker: QueryPlanningTracker, analyzedPlan: LogicalPlan): Unit = { + postAnalyzed(Some(analyzedPlan)) + } + + def readyForExecution(tracker: QueryPlanningTracker): Unit = postReadyForExecution + })) + } + + private[connect] def status_(executeStatus: ExecuteStatus): Unit = { + _status = executeStatus + } + + private def assertStatus( + validStatuses: List[ExecuteStatus], + eventStatus: ExecuteStatus): Unit = { + if (!validStatuses + .find(s => s == status) + .isDefined) { + throw new IllegalStateException(s""" + operationId: $operationId with status ${status} + is not within statuses $validStatuses for event $eventStatus + """) + } + if (sessionHolder.eventManager.status != SessionStatus.Started) { + throw new IllegalStateException(s""" + sessionId: $sessionId with status $sessionStatus + is not Started for event $eventStatus + """) + } + _status = eventStatus + } +} + +/** + * Event sent after reception of a Connect request (i.e. not queued), but prior to any analysis or + * execution. + * + * @param jobTag: + * Opaque Spark jobTag (@link org.apache.spark.SparkContext.addJobTag) assigned by Connect + * during a request. Designed to be unique across sessions and requests. + * @param operationId: + * 36 characters UUID assigned by Connect during a request. + * @param eventTime: + * The time in ms when the event was generated. + * @param sessionId: + * ID of the session the operation was executed on, assigned by the client or by Connect. + * @param userId: + * Opaque userId set in the Connect request. + * @param userName: + * Opaque userName set in the Connect request.
+ * @param statementText: + * The connect request plan converted to text. + * @param planRequest: + * The Connect request. None if the operation is not of type @link proto.ExecutePlanRequest + * @param extraTags: + * Additional metadata during the request. + */ +case class SparkListenerConnectOperationStarted( + jobTag: String, + operationId: String, + eventTime: Long, + sessionId: String, + userId: String, + userName: String, + statementText: String, + planRequest: Option[proto.ExecutePlanRequest], + extraTags: Map[String, String] = Map.empty) + extends SparkListenerEvent + +/** + * The event is sent after a Connect request has been analyzed (@link + * org.apache.spark.sql.catalyst.QueryPlanningTracker.ANALYSIS). + * + * @param jobTag: + * Opaque Spark jobTag (@link org.apache.spark.SparkContext.addJobTag) assigned by Connect + * during a request. Designed to be unique across sessions and requests. + * @param operationId: + * 36 characters UUID assigned by Connect during a request. + * @param eventTime: + * The time in ms when the event was generated. + * @param extraTags: + * Additional metadata during the request. + */ +case class SparkListenerConnectOperationAnalyzed( + jobTag: String, + operationId: String, + eventTime: Long, + extraTags: Map[String, String] = Map.empty) + extends SparkListenerEvent { + + /** + * Analyzed Spark plan generated by the Connect request. None when the Connect request does not + * generate a Spark plan. + */ + @JsonIgnore var analyzedPlan: Option[LogicalPlan] = None +} + +/** + * The event is sent after a Connect request is ready for execution. For eager commands this is + * after @link org.apache.spark.sql.catalyst.QueryPlanningTracker.ANALYSIS. For other requests it + * is after \@link org.apache.spark.sql.catalyst.QueryPlanningTracker.PLANNING + * + * @param jobTag: + * Opaque Spark jobTag (@link org.apache.spark.SparkContext.addJobTag) assigned by Connect + * during a request. Designed to be unique across sessions and requests. + * @param operationId: + * 36 characters UUID assigned by Connect during a request. + * @param eventTime: + * The time in ms when the event was generated. + * @param extraTags: + * Additional metadata during the request. + */ +case class SparkListenerConnectOperationReadyForExecution( + jobTag: String, + operationId: String, + eventTime: Long, + extraTags: Map[String, String] = Map.empty) + extends SparkListenerEvent + +/** + * Event sent after a Connect request has been canceled. + * + * @param jobTag: + * Opaque Spark jobTag (@link org.apache.spark.SparkContext.addJobTag) assigned by Connect + * during a request. Designed to be unique across sessions and requests. + * @param operationId: + * 36 characters UUID assigned by Connect during a request. + * @param eventTime: + * The time in ms when the event was generated. + * @param extraTags: + * Additional metadata during the request. + */ +case class SparkListenerConnectOperationCanceled( + jobTag: String, + operationId: String, + eventTime: Long, + extraTags: Map[String, String] = Map.empty) + extends SparkListenerEvent + +/** + * Event sent after a Connect request has failed. + * + * @param jobTag: + * Opaque Spark jobTag (@link org.apache.spark.SparkContext.addJobTag) assigned by Connect + * during a request. Designed to be unique across sessions and requests. + * @param operationId: + * 36 characters UUID assigned by Connect during a request. + * @param eventTime: + * The time in ms when the event was generated. 
+ * @param errorMessage: + * The message of the error thrown during the request. + * @param extraTags: + * Additional metadata during the request. + */ +case class SparkListenerConnectOperationFailed( + jobTag: String, + operationId: String, + eventTime: Long, + errorMessage: String, + extraTags: Map[String, String] = Map.empty) + extends SparkListenerEvent + +/** + * Event sent after a Connect request has finished executing, but prior to results being sent to + * the client. + * + * @param jobTag: + * Opaque Spark jobTag (@link org.apache.spark.SparkContext.addJobTag) assigned by Connect + * during a request. Designed to be unique across sessions and requests. + * @param operationId: + * 36 characters UUID assigned by Connect during a request. + * @param eventTime: + * The time in ms when the event was generated. + * @param extraTags: + * Additional metadata during the request. + */ +case class SparkListenerConnectOperationFinished( + jobTag: String, + operationId: String, + eventTime: Long, + extraTags: Map[String, String] = Map.empty) + extends SparkListenerEvent + +/** + * Event sent after a Connect request has finished executing and results have been sent to the + * client. + * + * @param jobTag: + * Opaque Spark jobTag (@link org.apache.spark.SparkContext.addJobTag) assigned by Connect + * during a request. Designed to be unique across sessions and requests. + * @param operationId: + * 36 characters UUID assigned by Connect during a request. + * @param eventTime: + * The time in ms when the event was generated. + * @param extraTags: + * Additional metadata during the request. + */ +case class SparkListenerConnectOperationClosed( + jobTag: String, + operationId: String, + eventTime: Long, + extraTags: Map[String, String] = Map.empty) + extends SparkListenerEvent diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala index 89aceaee1e4a..1f70973b60e0 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.service import org.apache.spark.connect.proto import org.apache.spark.internal.Logging import org.apache.spark.sql.connect.execution.{ExecuteGrpcResponseSender, ExecuteResponseObserver, ExecuteThreadRunner} +import org.apache.spark.util.SystemClock /** * Object used to hold the Spark Connect execution state. @@ -41,6 +42,8 @@ private[connect] class ExecuteHolder( val responseObserver: ExecuteResponseObserver[proto.ExecutePlanResponse] = new ExecuteResponseObserver[proto.ExecutePlanResponse]() + val eventsManager: ExecuteEventsManager = ExecuteEventsManager(this, new SystemClock()) + private val runner: ExecuteThreadRunner = new ExecuteThreadRunner(this) /** diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionEventsManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionEventsManager.scala new file mode 100644 index 000000000000..f275fab56bf5 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionEventsManager.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.service + +import org.apache.spark.scheduler.SparkListenerEvent +import org.apache.spark.util.{Clock} + +sealed abstract class SessionStatus(value: Int) + +object SessionStatus { + case object Pending extends SessionStatus(0) + case object Started extends SessionStatus(1) + case object Closed extends SessionStatus(2) +} + +/** + * Post session Connect events to @link org.apache.spark.scheduler.LiveListenerBus. + * + * @param sessionHolder: + * Session for which the events are generated. + * @param clock: + * Source of time for unit tests. + */ +case class SessionEventsManager(sessionHolder: SessionHolder, clock: Clock) { + + private def sessionId = sessionHolder.sessionId + + private var _status: SessionStatus = SessionStatus.Pending + + private[connect] def status_(sessionStatus: SessionStatus): Unit = { + _status = sessionStatus + } + + /** + * @return + * Last event posted by the Connect session + */ + def status: SessionStatus = _status + + /** + * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectSessionStarted. + */ + def postStarted(): Unit = { + assertStatus(List(SessionStatus.Pending), SessionStatus.Started) + sessionHolder.session.sparkContext.listenerBus + .post( + SparkListenerConnectSessionStarted( + sessionHolder.sessionId, + sessionHolder.userId, + clock.getTimeMillis())) + } + + /** + * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectSessionClosed. + */ + def postClosed(): Unit = { + assertStatus(List(SessionStatus.Started), SessionStatus.Closed) + sessionHolder.session.sparkContext.listenerBus + .post( + SparkListenerConnectSessionClosed( + sessionHolder.sessionId, + sessionHolder.userId, + clock.getTimeMillis())) + } + + private def assertStatus( + validStatuses: List[SessionStatus], + eventStatus: SessionStatus): Unit = { + if (!validStatuses + .find(s => s == status) + .isDefined) { + throw new IllegalStateException(s""" + sessionId: $sessionId with status ${status} + is not within statuses $validStatuses for event $eventStatus + """) + } + _status = eventStatus + } +} + +/** + * Event sent after a Connect session has been started. + * + * @param sessionId: + * ID of the session, assigned by the client or by Connect. + * @param userId: + * Opaque userId set in the Connect request. + * @param eventTime: + * The time in ms when the event was generated. + * @param extraTags: + * Additional metadata + */ +case class SparkListenerConnectSessionStarted( + sessionId: String, + userId: String, + eventTime: Long, + extraTags: Map[String, String] = Map.empty) + extends SparkListenerEvent + +/** + * Event sent after a Connect session has been closed. + * + * @param sessionId: + * ID of the session, assigned by the client or by Connect. + * @param userId: + * Opaque userId set in the Connect request. + * @param eventTime: + * The time in ms when the event was generated.
+ * @param extraTags: + * Additional metadata + */ +case class SparkListenerConnectSessionClosed( + sessionId: String, + userId: String, + eventTime: Long, + extraTags: Map[String, String] = Map.empty) + extends SparkListenerEvent diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 2f3bd1badcec..5ac4f6db82aa 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.streaming.StreamingQueryListener +import org.apache.spark.util.{SystemClock} import org.apache.spark.util.Utils /** @@ -44,6 +45,8 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio val executions: ConcurrentMap[String, ExecuteHolder] = new ConcurrentHashMap[String, ExecuteHolder]() + val eventManager: SessionEventsManager = SessionEventsManager(this, new SystemClock()) + // Mapping from relation ID (passed to client) to runtime dataframe. Used for callbacks like // foreachBatch() in Streaming. Lazy since most sessions don't need it. private lazy val dataFrameCache: ConcurrentMap[String, DataFrame] = new ConcurrentHashMap() @@ -60,6 +63,10 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio executePlanHolder } + private[connect] def executeHolder(operationId: String): Option[ExecuteHolder] = { + Option(executions.get(operationId)) + } + private[connect] def removeExecuteHolder(operationId: String): Unit = { executions.remove(operationId) } @@ -98,12 +105,17 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio */ def classloader: ClassLoader = artifactManager.classloader + private[connect] def initializeSession(): Unit = { + eventManager.postStarted() + } + /** * Expire this session and trigger state cleanup mechanisms. */ private[connect] def expireSession(): Unit = { logDebug(s"Expiring session with userId: $userId and sessionId: $sessionId") artifactManager.cleanUpResources() + eventManager.postClosed() } /** diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala index 50ca733b4391..b4e91c438359 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala @@ -32,6 +32,7 @@ class SparkConnectExecutePlanHandler(responseObserver: StreamObserver[proto.Exec val executeHolder = sessionHolder.createExecuteHolder(v) try { + executeHolder.eventsManager.postStarted() executeHolder.start() val responseSender = new ExecuteGrpcResponseSender[proto.ExecutePlanResponse](responseObserver) @@ -44,6 +45,7 @@ class SparkConnectExecutePlanHandler(responseObserver: StreamObserver[proto.Exec } finally { // TODO this will change with detachable execution. 
executeHolder.join() + executeHolder.eventsManager.postClosed() sessionHolder.removeExecuteHolder(executeHolder.operationId) } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index c38fbbdfcf97..ad40c94d5498 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -223,10 +223,19 @@ object SparkConnectService { userSessionMapping.get( (userId, sessionId), () => { - SessionHolder(userId, sessionId, newIsolatedSession()) + val holder = SessionHolder(userId, sessionId, newIsolatedSession()) + holder.initializeSession() + holder }) } + /** + * Used for testing + */ + private[connect] def invalidateAllSessions(): Unit = { + userSessionMapping.invalidateAll() + } + private def newIsolatedSession(): SparkSession = { SparkSession.active.newSession() } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala index d0f754827dad..326bdd0052c6 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala @@ -35,6 +35,7 @@ import org.apache.spark.{SparkEnv, SparkException, SparkThrowable} import org.apache.spark.api.python.PythonException import org.apache.spark.internal.Logging import org.apache.spark.sql.connect.config.Connect +import org.apache.spark.sql.connect.service.ExecuteEventsManager import org.apache.spark.sql.connect.service.SparkConnectService import org.apache.spark.sql.internal.SQLConf @@ -103,32 +104,42 @@ private[connect] object ErrorUtils extends Logging { opType: String, observer: StreamObserver[V], userId: String, - sessionId: String): PartialFunction[Throwable, Unit] = { + sessionId: String, + events: Option[ExecuteEventsManager] = None, + isInterrupted: Boolean = false): PartialFunction[Throwable, Unit] = { val session = SparkConnectService .getOrCreateIsolatedSession(userId, sessionId) .session val stackTraceEnabled = session.conf.get(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED) - { + val partial: PartialFunction[Throwable, (Throwable, Throwable)] = { case se: SparkException if isPythonExecutionException(se) => - logError(s"Error during: $opType. UserId: $userId. SessionId: $sessionId.", se) - observer.onError( + ( + se, StatusProto.toStatusRuntimeException( buildStatusFromThrowable(se.getCause, stackTraceEnabled))) case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) => - logError(s"Error during: $opType. UserId: $userId. SessionId: $sessionId.", e) - observer.onError( - StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, stackTraceEnabled))) + (e, StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, stackTraceEnabled))) case e: Throwable => - logError(s"Error during: $opType. UserId: $userId. SessionId: $sessionId.", e) - observer.onError( + ( + e, Status.UNKNOWN .withCause(e) .withDescription(StringUtils.abbreviate(e.getMessage, 2048)) .asRuntimeException()) } + partial + .andThen { case (original, wrapped) => + logError(s"Error during: $opType. UserId: $userId. 
SessionId: $sessionId.", original) + if (isInterrupted) { + events.foreach(_.postCanceled) + } else { + events.foreach(_.postFailed(wrapped.getMessage)) + } + observer.onError(wrapped) + } } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index a10540676b04..595f9d65c269 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto -import org.apache.spark.sql.connect.service.SessionHolder +import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteStatus, SessionHolder, SessionStatus} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} @@ -58,8 +58,9 @@ trait SparkConnectPlanTest extends SharedSparkSession { } def transform(cmd: proto.Command): Unit = { - new SparkConnectPlanner(SessionHolder.forTesting(spark)) - .process(cmd, "clientId", "sessionId", new MockObserver()) + val executeHolder = buildExecutePlanHolder(cmd) + new SparkConnectPlanner(executeHolder.sessionHolder) + .process(cmd, new MockObserver(), executeHolder) } def readRel: proto.Relation = @@ -114,6 +115,28 @@ trait SparkConnectPlanTest extends SharedSparkSession { localRelationBuilder.setData(ByteString.copyFrom(bytes)) proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build() } + + def buildExecutePlanHolder(command: proto.Command): ExecuteHolder = { + val sessionHolder = SessionHolder.forTesting(spark) + sessionHolder.eventManager.status_(SessionStatus.Started) + + val context = proto.UserContext + .newBuilder() + .setUserId(sessionHolder.userId) + .build() + val plan = proto.Plan + .newBuilder() + .setCommand(command) + .build() + val request = proto.ExecutePlanRequest + .newBuilder() + .setPlan(plan) + .setUserContext(context) + .build() + val executeHolder = sessionHolder.createExecuteHolder(request) + executeHolder.eventsManager.status_(ExecuteStatus.Started) + executeHolder + } } /** diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index bceaada9051e..498084efb8f3 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -16,30 +16,49 @@ */ package org.apache.spark.sql.connect.planner +import java.util.UUID +import java.util.concurrent.Semaphore + import scala.collection.JavaConverters._ import scala.collection.mutable +import com.google.protobuf +import com.google.protobuf.ByteString import io.grpc.StatusRuntimeException import io.grpc.stub.StreamObserver import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.{BigIntVector, Float8Vector} import 
org.apache.arrow.vector.ipc.ArrowStreamReader import org.apache.commons.lang3.{JavaVersion, SystemUtils} +import org.mockito.Mockito.when +import org.scalatest.Tag +import org.scalatestplus.mockito.MockitoSugar +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.CreateDataFrameViewCommand +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent} +import org.apache.spark.sql.connect.common.DataTypeProtoConverter +import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.dsl.MockRemoteSession import org.apache.spark.sql.connect.dsl.expressions._ import org.apache.spark.sql.connect.dsl.plans._ -import org.apache.spark.sql.connect.service.{SparkConnectAnalyzeHandler, SparkConnectService} -import org.apache.spark.sql.connect.service.SessionHolder +import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry +import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteStatus, SessionHolder, SessionStatus, SparkConnectAnalyzeHandler, SparkConnectService, SparkListenerConnectOperationStarted} +import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog +import org.apache.spark.sql.streaming.StreamingQuery import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.util.Utils /** * Testing Connect Service implementation. */ -class SparkConnectServiceSuite extends SharedSparkSession { +class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with Logging { private def sparkSessionHolder = SessionHolder.forTesting(spark) + private def DEFAULT_UUID = UUID.fromString("89ea6117-1f45-4c03-ae27-f47c6aded093") test("Test schema in analyze response") { withTable("test") { @@ -131,126 +150,365 @@ class SparkConnectServiceSuite extends SharedSparkSession { } test("SPARK-41224: collect data using arrow") { - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) - val instance = new SparkConnectService(false) - val connect = new MockRemoteSession() - val context = proto.UserContext - .newBuilder() - .setUserId("c1") - .build() - val plan = proto.Plan - .newBuilder() - .setRoot(connect.sql("select id, exp(id) as eid from range(0, 100, 1, 4)")) - .build() - val request = proto.ExecutePlanRequest - .newBuilder() - .setPlan(plan) - .setUserContext(context) - .build() - - // Execute plan. - @volatile var done = false - val responses = mutable.Buffer.empty[proto.ExecutePlanResponse] - instance.executePlan( - request, - new StreamObserver[proto.ExecutePlanResponse] { - override def onNext(v: proto.ExecutePlanResponse): Unit = responses += v - - override def onError(throwable: Throwable): Unit = throw throwable - - override def onCompleted(): Unit = done = true - }) - - // The current implementation is expected to be blocking. This is here to make sure it is. 
- assert(done) - - // 4 Partitions + Metrics - assert(responses.size == 6) - - // Make sure the first response is schema only - val head = responses.head - assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics) - - // Make sure the last response is metrics only - val last = responses.last - assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch) - - val allocator = new RootAllocator() - - // Check the 'data' batches - var expectedId = 0L - var previousEId = 0.0d - responses.tail.dropRight(1).foreach { response => - assert(response.hasArrowBatch) - val batch = response.getArrowBatch - assert(batch.getData != null) - assert(batch.getRowCount == 25) - - val reader = new ArrowStreamReader(batch.getData.newInput(), allocator) - while (reader.loadNextBatch()) { - val root = reader.getVectorSchemaRoot - val idVector = root.getVector(0).asInstanceOf[BigIntVector] - val eidVector = root.getVector(1).asInstanceOf[Float8Vector] - val numRows = root.getRowCount - var i = 0 - while (i < numRows) { - assert(idVector.get(i) == expectedId) - expectedId += 1 - val eid = eidVector.get(i) - assert(eid > previousEId) - previousEId = eid - i += 1 + withEvents { verifyEvents => + // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 + assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) + val instance = new SparkConnectService(false) + val connect = new MockRemoteSession() + val context = proto.UserContext + .newBuilder() + .setUserId("c1") + .build() + val plan = proto.Plan + .newBuilder() + .setRoot(connect.sql("select id, exp(id) as eid from range(0, 100, 1, 4)")) + .build() + val request = proto.ExecutePlanRequest + .newBuilder() + .setPlan(plan) + .setUserContext(context) + .build() + + // Execute plan. + @volatile var done = false + val responses = mutable.Buffer.empty[proto.ExecutePlanResponse] + instance.executePlan( + request, + new StreamObserver[proto.ExecutePlanResponse] { + override def onNext(v: proto.ExecutePlanResponse): Unit = { + responses += v + verifyEvents.onNext(v) + } + + override def onError(throwable: Throwable): Unit = { + verifyEvents.onError(throwable) + throw throwable + } + + override def onCompleted(): Unit = { + done = true + } + }) + verifyEvents.onCompleted() + // The current implementation is expected to be blocking. This is here to make sure it is. 
+ assert(done) + + // 4 Partitions + Metrics + assert(responses.size == 6) + + // Make sure the first response is schema only + val head = responses.head + assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics) + + // Make sure the last response is metrics only + val last = responses.last + assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch) + + val allocator = new RootAllocator() + + // Check the 'data' batches + var expectedId = 0L + var previousEId = 0.0d + responses.tail.dropRight(1).foreach { response => + assert(response.hasArrowBatch) + val batch = response.getArrowBatch + assert(batch.getData != null) + assert(batch.getRowCount == 25) + + val reader = new ArrowStreamReader(batch.getData.newInput(), allocator) + while (reader.loadNextBatch()) { + val root = reader.getVectorSchemaRoot + val idVector = root.getVector(0).asInstanceOf[BigIntVector] + val eidVector = root.getVector(1).asInstanceOf[Float8Vector] + val numRows = root.getRowCount + var i = 0 + while (i < numRows) { + assert(idVector.get(i) == expectedId) + expectedId += 1 + val eid = eidVector.get(i) + assert(eid > previousEId) + previousEId = eid + i += 1 + } } + reader.close() } - reader.close() + allocator.close() } - allocator.close() } - test("SPARK-41165: failures in the arrow collect path should not cause hangs") { - val instance = new SparkConnectService(false) + gridTest("SPARK-43923: commands send events")( + Seq( + proto.Command + .newBuilder() + .setSqlCommand(proto.SqlCommand.newBuilder().setSql("select 1").build()), + proto.Command + .newBuilder() + .setSqlCommand(proto.SqlCommand.newBuilder().setSql("show tables").build()), + proto.Command + .newBuilder() + .setWriteOperation( + proto.WriteOperation + .newBuilder() + .setInput( + proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1"))) + .setPath("my/test/path") + .setMode(proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE)), + proto.Command + .newBuilder() + .setWriteOperationV2( + proto.WriteOperationV2 + .newBuilder() + .setInput(proto.Relation.newBuilder.setRange( + proto.Range.newBuilder().setStart(0).setEnd(2).setStep(1L))) + .setTableName("testcat.testtable") + .setMode(proto.WriteOperationV2.Mode.MODE_CREATE)), + proto.Command + .newBuilder() + .setCreateDataframeView( + CreateDataFrameViewCommand + .newBuilder() + .setName("testview") + .setInput( + proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1")))), + proto.Command + .newBuilder() + .setGetResourcesCommand(proto.GetResourcesCommand.newBuilder()), + proto.Command + .newBuilder() + .setExtension( + protobuf.Any.pack( + proto.ExamplePluginCommand + .newBuilder() + .setCustomField("SPARK-43923") + .build())), + proto.Command + .newBuilder() + .setWriteStreamOperationStart( + proto.WriteStreamOperationStart + .newBuilder() + .setInput( + proto.Relation + .newBuilder() + .setRead(proto.Read + .newBuilder() + .setIsStreaming(true) + .setDataSource(proto.Read.DataSource.newBuilder().setFormat("rate").build()) + .build()) + .build()) + .setOutputMode("Append") + .setAvailableNow(true) + .setQueryName("test") + .setFormat("memory") + .putOptions("checkpointLocation", s"${UUID.randomUUID}") + .setPath("test-path") + .build()), + proto.Command + .newBuilder() + .setStreamingQueryCommand( + proto.StreamingQueryCommand + .newBuilder() + .setQueryId( + proto.StreamingQueryInstanceId + .newBuilder() + .setId(DEFAULT_UUID.toString) + .setRunId(DEFAULT_UUID.toString) + .build()) + .setStop(true)), + proto.Command + .newBuilder() + 
.setStreamingQueryManagerCommand(proto.StreamingQueryManagerCommand
+ .newBuilder()
+ .setListListeners(true)),
+ proto.Command
+ .newBuilder()
+ .setRegisterFunction(
+ proto.CommonInlineUserDefinedFunction
+ .newBuilder()
+ .setFunctionName("function")
+ .setPythonUdf(
+ proto.PythonUDF
+ .newBuilder()
+ .setEvalType(100)
+ .setOutputType(DataTypeProtoConverter.toConnectProtoType(IntegerType))
+ .setCommand(ByteString.copyFrom("command".getBytes()))
+ .setPythonVer("3.10")
+ .build())))) { command =>
+ withCommandTest { verifyEvents =>
+ val instance = new SparkConnectService(false)
+ val context = proto.UserContext
+ .newBuilder()
+ .setUserId("c1")
+ .build()
+ val plan = proto.Plan
+ .newBuilder()
+ .setCommand(command)
+ .build()
+
+ val request = proto.ExecutePlanRequest
+ .newBuilder()
+ .setPlan(plan)
+ .setSessionId("s1")
+ .setUserContext(context)
+ .build()
+
+ // Execute plan.
+ @volatile var done = false
+ val responses = mutable.Buffer.empty[proto.ExecutePlanResponse]
+ instance.executePlan(
+ request,
+ new StreamObserver[proto.ExecutePlanResponse] {
+ override def onNext(v: proto.ExecutePlanResponse): Unit = {
+ responses += v
+ verifyEvents.onNext(v)
+ }
+
+ override def onError(throwable: Throwable): Unit = {
+ verifyEvents.onError(throwable)
+ throw throwable
+ }
+
+ override def onCompleted(): Unit = {
+ done = true
+ }
+ })
+ verifyEvents.onCompleted()
+ // The current implementation is expected to be blocking.
+ // This is here to make sure it is.
+ assert(done)
- // Add an always crashing UDF
- val session = SparkConnectService.getOrCreateIsolatedSession("c1", "session").session
- val instaKill: Long => Long = { _ =>
- throw new Exception("Kaboom")
+ // Result + Metrics
+ if (responses.size > 1) {
+ assert(responses.size == 2)
+
+ // Make sure the first response is the SQL command result only
+ val head = responses.head
+ assert(head.hasSqlCommandResult && !head.hasMetrics)
+
+ // Make sure the last response is metrics only
+ val last = responses.last
+ assert(last.hasMetrics && !last.hasSqlCommandResult)
+ }
}
- session.udf.register("insta_kill", instaKill)
-
- val connect = new MockRemoteSession()
- val context = proto.UserContext
- .newBuilder()
- .setUserId("c1")
- .build()
- val plan = proto.Plan
- .newBuilder()
- .setRoot(connect.sql("select insta_kill(id) from range(10)"))
- .build()
- val request = proto.ExecutePlanRequest
- .newBuilder()
- .setPlan(plan)
- .setUserContext(context)
- .setSessionId("session")
- .build()
-
- // The observer is executed inside this thread. So
- // we can perform the checks inside the observer.
- instance.executePlan(
- request,
- new StreamObserver[proto.ExecutePlanResponse] {
- override def onNext(v: proto.ExecutePlanResponse): Unit = {
- fail("this should not receive responses")
- }
+ }

- override def onError(throwable: Throwable): Unit = {
- assert(throwable.isInstanceOf[StatusRuntimeException])
- }
+ test("SPARK-43923: canceled request send events") {
+ withEvents { verifyEvents =>
+ val instance = new SparkConnectService(false)
+
+ // Register a UDF that sleeps long enough for the interrupt request to arrive
+ val session = SparkConnectService.getOrCreateIsolatedSession("c1", "session").session
+ val sleep: Long => Long = { time =>
+ Thread.sleep(time)
+ time
+ }
+ session.udf.register("sleep", sleep)

- override def onCompleted(): Unit = {
- fail("this should not complete")
+ val connect = new MockRemoteSession()
+ val context = proto.UserContext
+ .newBuilder()
+ .setUserId("c1")
+ .build()
+ val plan = proto.Plan
+ .newBuilder()
+ .setRoot(connect.sql("select sleep(10000)"))
+ .build()
+ val request = proto.ExecutePlanRequest
+ .newBuilder()
+ .setPlan(plan)
+ .setUserContext(context)
+ .setSessionId("session")
+ .build()
+
+ val thread = new Thread {
+ override def run(): Unit = {
+ verifyEvents.listener.semaphoreStarted.acquire()
+ instance.interrupt(
+ proto.InterruptRequest
+ .newBuilder()
+ .setSessionId("session")
+ .setUserContext(context)
+ .setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL)
+ .build(),
+ new StreamObserver[proto.InterruptResponse] {
+ override def onNext(v: proto.InterruptResponse): Unit = {}
+
+ override def onError(throwable: Throwable): Unit = {}
+
+ override def onCompleted(): Unit = {}
+ })
}
- })
+ }
+ thread.start()
+ // The observer is executed inside this thread. So
+ // we can perform the checks inside the observer.
+ instance.executePlan(
+ request,
+ new StreamObserver[proto.ExecutePlanResponse] {
+ override def onNext(v: proto.ExecutePlanResponse): Unit = {
+ logInfo(s"$v")
+ }
+
+ override def onError(throwable: Throwable): Unit = {
+ verifyEvents.onCanceled()
+ }
+
+ override def onCompleted(): Unit = {
+ fail("this should not complete")
+ }
+ })
+ thread.join()
+ verifyEvents.onCompleted()
+ }
+ }
+
+ test("SPARK-41165: failures in the arrow collect path should not cause hangs") {
+ withEvents { verifyEvents =>
+ val instance = new SparkConnectService(false)
+
+ // Add an always crashing UDF
+ val session = SparkConnectService.getOrCreateIsolatedSession("c1", "session").session
+ val instaKill: Long => Long = { _ =>
+ throw new Exception("Kaboom")
+ }
+ session.udf.register("insta_kill", instaKill)
+
+ val connect = new MockRemoteSession()
+ val context = proto.UserContext
+ .newBuilder()
+ .setUserId("c1")
+ .build()
+ val plan = proto.Plan
+ .newBuilder()
+ .setRoot(connect.sql("select insta_kill(id) from range(10)"))
+ .build()
+ val request = proto.ExecutePlanRequest
+ .newBuilder()
+ .setPlan(plan)
+ .setUserContext(context)
+ .setSessionId("session")
+ .build()
+
+ // The observer is executed inside this thread. So
+ // we can perform the checks inside the observer.
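+ // The failing UDF is expected to surface as a StatusRuntimeException in onError;
+ // VerifyEvents.onError then checks that the events manager recorded a failure
+ // rather than a cancellation.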
+ instance.executePlan( + request, + new StreamObserver[proto.ExecutePlanResponse] { + override def onNext(v: proto.ExecutePlanResponse): Unit = { + fail("this should not receive responses") + } + + override def onError(throwable: Throwable): Unit = { + assert(throwable.isInstanceOf[StatusRuntimeException]) + verifyEvents.onError(throwable) + } + + override def onCompleted(): Unit = { + fail("this should not complete") + } + }) + verifyEvents.onCompleted() + } } test("Test explain mode in analyze response") { @@ -378,4 +636,108 @@ class SparkConnectServiceSuite extends SharedSparkSession { assert(valuesList.last.hasLong && valuesList.last.getLong == 99) } } + + protected def withCommandTest(f: VerifyEvents => Unit): Unit = { + withView("testview") { + withTable("testcat.testtable") { + withSparkConf( + "spark.sql.catalog.testcat" -> classOf[InMemoryPartitionTableCatalog].getName, + Connect.CONNECT_EXTENSIONS_COMMAND_CLASSES.key -> + "org.apache.spark.sql.connect.plugin.ExampleCommandPlugin") { + withEvents { verifyEvents => + val restartedQuery = mock[StreamingQuery] + when(restartedQuery.id).thenReturn(DEFAULT_UUID) + when(restartedQuery.runId).thenReturn(DEFAULT_UUID) + SparkConnectService.streamingSessionManager.registerNewStreamingQuery( + SparkConnectService.getOrCreateIsolatedSession("c1", "s1"), + restartedQuery) + f(verifyEvents) + } + } + } + } + } + + protected def withSparkConf(pairs: (String, String)*)(f: => Unit): Unit = { + val conf = SparkEnv.get.conf + pairs.foreach { kv => conf.set(kv._1, kv._2) } + try f + finally { + pairs.foreach { kv => conf.remove(kv._1) } + } + } + + protected def withEvents(f: VerifyEvents => Unit): Unit = { + val verifyEvents = new VerifyEvents(spark.sparkContext) + spark.sparkContext.addSparkListener(verifyEvents.listener) + Utils.tryWithSafeFinally({ + f(verifyEvents) + SparkConnectService.invalidateAllSessions() + verifyEvents.onSessionClosed() + }) { + verifyEvents.waitUntilEmpty() + spark.sparkContext.removeSparkListener(verifyEvents.listener) + SparkConnectService.invalidateAllSessions() + SparkConnectPluginRegistry.reset() + } + } + + protected def gridTest[A](testNamePrefix: String, testTags: Tag*)(params: Seq[A])( + testFun: A => Unit): Unit = { + for (param <- params) { + test(testNamePrefix + s" ($param)", testTags: _*)(testFun(param)) + } + } + + class VerifyEvents(val sparkContext: SparkContext) { + val listener: MockSparkListener = new MockSparkListener() + val listenerBus = sparkContext.listenerBus + val LISTENER_BUS_TIMEOUT = 30000 + def executeHolder: ExecuteHolder = { + assert(listener.executeHolder.isDefined) + listener.executeHolder.get + } + def onNext(v: proto.ExecutePlanResponse): Unit = { + if (v.hasSchema) { + assert(executeHolder.eventsManager.status == ExecuteStatus.Analyzed) + } + if (v.hasMetrics) { + assert(executeHolder.eventsManager.status == ExecuteStatus.Finished) + } + } + def onError(throwable: Throwable): Unit = { + assert(executeHolder.eventsManager.hasCanceled.isEmpty) + assert(executeHolder.eventsManager.hasError.isDefined) + } + def onCompleted(): Unit = { + assert(executeHolder.eventsManager.status == ExecuteStatus.Closed) + } + def onCanceled(): Unit = { + assert(executeHolder.eventsManager.hasCanceled.contains(true)) + assert(executeHolder.eventsManager.hasError.isEmpty) + } + def onSessionClosed(): Unit = { + assert(executeHolder.sessionHolder.eventManager.status == SessionStatus.Closed) + } + def onSessionStarted(): Unit = { + assert(executeHolder.sessionHolder.eventManager.status == 
SessionStatus.Started) + } + def waitUntilEmpty(): Unit = { + listenerBus.waitUntilEmpty(LISTENER_BUS_TIMEOUT) + } + } + class MockSparkListener() extends SparkListener { + val semaphoreStarted = new Semaphore(0) + var executeHolder = Option.empty[ExecuteHolder] + override def onOtherEvent(event: SparkListenerEvent): Unit = { + event match { + case e: SparkListenerConnectOperationStarted => + semaphoreStarted.release() + val sessionHolder = + SparkConnectService.getOrCreateIsolatedSession(e.userId, e.sessionId) + executeHolder = sessionHolder.executeHolder(e.operationId) + case _ => + } + } + } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala index 2bdabc7ccc21..fdb903237941 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.planner.{SparkConnectPlanner, SparkConnectPlanTest} -import org.apache.spark.sql.connect.service.SessionHolder import org.apache.spark.sql.test.SharedSparkSession class DummyPlugin extends RelationPlugin { @@ -196,8 +195,9 @@ class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConne .build())) .build() - new SparkConnectPlanner(SessionHolder.forTesting(spark)) - .process(plan, "clientId", "sessionId", new MockObserver()) + val executeHolder = buildExecutePlanHolder(plan) + new SparkConnectPlanner(executeHolder.sessionHolder) + .process(plan, new MockObserver(), executeHolder) assert(spark.sparkContext.getLocalProperty("testingProperty").equals("Martin")) } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala new file mode 100644 index 000000000000..365b17632a74 --- /dev/null +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connect.service + +import scala.util.matching.Regex + +import org.mockito.Mockito._ +import org.scalatestplus.mockito.MockitoSugar + +import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.{ExecutePlanRequest, Plan, UserContext} +import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connect.planner.SparkConnectPlanTest +import org.apache.spark.sql.internal.{SessionState, SQLConf} +import org.apache.spark.util.ManualClock + +class ExecuteEventsManagerSuite + extends SparkFunSuite + with MockitoSugar + with SparkConnectPlanTest { + + val DEFAULT_ERROR = "error" + val DEFAULT_CLOCK = new ManualClock() + val DEFAULT_NODE_NAME = "nodeName" + val DEFAULT_TEXT = """limit { + limit: 10 +} +""" + val DEFAULT_USER_ID = "1" + val DEFAULT_USER_NAME = "userName" + val DEFAULT_SESSION_ID = "2" + val DEFAULT_QUERY_ID = "3" + val DEFAULT_CLIENT_TYPE = "clientType" + + test("SPARK-43923: post started") { + val events = setupEvents(ExecuteStatus.Pending) + events.postStarted() + + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post(SparkListenerConnectOperationStarted( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis(), + DEFAULT_SESSION_ID, + DEFAULT_USER_ID, + DEFAULT_USER_NAME, + DEFAULT_TEXT, + Some(events.executeHolder.request), + Map.empty)) + } + + test("SPARK-43923: post analyzed with plan") { + val events = setupEvents(ExecuteStatus.Started) + + val mockPlan = mock[LogicalPlan] + events.postAnalyzed(Some(mockPlan)) + val event = SparkListenerConnectOperationAnalyzed( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis()) + event.analyzedPlan = Some(mockPlan) + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post(event) + } + + test("SPARK-43923: post analyzed with empty plan") { + val events = setupEvents(ExecuteStatus.Started) + events.postAnalyzed() + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post( + SparkListenerConnectOperationAnalyzed( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis())) + } + + test("SPARK-43923: post readyForExecution") { + val events = setupEvents(ExecuteStatus.Analyzed) + events.postReadyForExecution() + val event = SparkListenerConnectOperationReadyForExecution( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis()) + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post(event) + } + + test("SPARK-43923: post canceled") { + val events = setupEvents(ExecuteStatus.Started) + events.postCanceled() + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post( + SparkListenerConnectOperationCanceled( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis())) + } + + test("SPARK-43923: post failed") { + val events = setupEvents(ExecuteStatus.Started) + events.postFailed(DEFAULT_ERROR) + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post( + SparkListenerConnectOperationFailed( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis(), + DEFAULT_ERROR, + Map.empty[String, String])) + } + + test("SPARK-43923: post finished") { 
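+ // Finished is posted from the Started state here; exactly one
+ // SparkListenerConnectOperationFinished event should be posted to the listener bus.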
+ val events = setupEvents(ExecuteStatus.Started) + events.postFinished() + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post( + SparkListenerConnectOperationFinished( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis())) + } + + test("SPARK-43923: post closed") { + val events = setupEvents(ExecuteStatus.Finished) + events.postClosed() + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post( + SparkListenerConnectOperationClosed( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis())) + } + + test("SPARK-43923: Closed wrong order throws exception") { + val events = setupEvents(ExecuteStatus.Closed) + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + events.postAnalyzed() + } + assertThrows[IllegalStateException] { + events.postReadyForExecution() + } + assertThrows[IllegalStateException] { + events.postFinished() + } + assertThrows[IllegalStateException] { + events.postCanceled() + } + assertThrows[IllegalStateException] { + events.postClosed() + } + } + + test("SPARK-43923: Finished wrong order throws exception") { + val events = setupEvents(ExecuteStatus.Finished) + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + events.postAnalyzed() + } + assertThrows[IllegalStateException] { + events.postReadyForExecution() + } + assertThrows[IllegalStateException] { + events.postFinished() + } + } + + test("SPARK-43923: Failed wrong order throws exception") { + val events = setupEvents(ExecuteStatus.Finished) + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + events.postAnalyzed() + } + assertThrows[IllegalStateException] { + events.postReadyForExecution() + } + assertThrows[IllegalStateException] { + events.postFinished() + } + assertThrows[IllegalStateException] { + events.postFinished() + } + } + + test("SPARK-43923: Canceled wrong order throws exception") { + val events = setupEvents(ExecuteStatus.Canceled) + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + events.postAnalyzed() + } + assertThrows[IllegalStateException] { + events.postReadyForExecution() + } + assertThrows[IllegalStateException] { + events.postCanceled() + } + assertThrows[IllegalStateException] { + events.postFinished() + } + assertThrows[IllegalStateException] { + events.postFailed(DEFAULT_ERROR) + } + } + + test("SPARK-43923: ReadyForExecution wrong order throws exception") { + val events = setupEvents(ExecuteStatus.ReadyForExecution) + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + events.postAnalyzed() + } + assertThrows[IllegalStateException] { + events.postReadyForExecution() + } + assertThrows[IllegalStateException] { + events.postClosed() + } + } + + test("SPARK-43923: Analyzed wrong order throws exception") { + val events = setupEvents(ExecuteStatus.Analyzed) + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + events.postFinished() + } + assertThrows[IllegalStateException] { + events.postClosed() + } + } + + test("SPARK-43923: Started wrong order throws exception") { + val events = setupEvents(ExecuteStatus.Started) + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + 
events.postReadyForExecution() + } + assertThrows[IllegalStateException] { + events.postClosed() + } + } + + test("SPARK-43923: Started wrong session status") { + val events = setupEvents(ExecuteStatus.Started, SessionStatus.Pending) + assertThrows[IllegalStateException] { + events.postStarted() + } + } + + def setupEvents( + executeStatus: ExecuteStatus, + sessionStatus: SessionStatus = SessionStatus.Started): ExecuteEventsManager = { + val mockSession = mock[SparkSession] + val sessionHolder = SessionHolder(DEFAULT_USER_ID, DEFAULT_SESSION_ID, mockSession) + sessionHolder.eventManager.status_(sessionStatus) + val mockContext = mock[SparkContext] + val mockListenerBus = mock[LiveListenerBus] + val mockSessionState = mock[SessionState] + val mockConf = mock[SQLConf] + when(mockSession.sessionState).thenReturn(mockSessionState) + when(mockSessionState.conf).thenReturn(mockConf) + when(mockConf.stringRedactionPattern).thenReturn(Option.empty[Regex]) + when(mockContext.listenerBus).thenReturn(mockListenerBus) + when(mockSession.sparkContext).thenReturn(mockContext) + + val relation = proto.Relation.newBuilder + .setLimit(proto.Limit.newBuilder.setLimit(10)) + .build() + + val executePlanRequest = ExecutePlanRequest + .newBuilder() + .setPlan(Plan.newBuilder().setRoot(relation)) + .setUserContext( + UserContext + .newBuilder() + .setUserId(DEFAULT_USER_ID) + .setUserName(DEFAULT_USER_NAME)) + .setSessionId(DEFAULT_SESSION_ID) + .setClientType(DEFAULT_CLIENT_TYPE) + .build() + + val executeHolder = new ExecuteHolder(executePlanRequest, DEFAULT_QUERY_ID, sessionHolder) + + val eventsManager = ExecuteEventsManager(executeHolder, DEFAULT_CLOCK) + eventsManager.status_(executeStatus) + eventsManager + } +} diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SessionEventsManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SessionEventsManagerSuite.scala new file mode 100644 index 000000000000..7025146b0295 --- /dev/null +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SessionEventsManagerSuite.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connect.service + +import org.mockito.Mockito._ +import org.scalatestplus.mockito.MockitoSugar + +import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connect.planner.SparkConnectPlanTest +import org.apache.spark.util.ManualClock + +class SessionEventsManagerSuite + extends SparkFunSuite + with MockitoSugar + with SparkConnectPlanTest { + + val DEFAULT_ERROR = "error" + val DEFAULT_CLOCK = new ManualClock() + val DEFAULT_NODE_NAME = "nodeName" + val DEFAULT_TEXT = """limit { + limit: 10 +} +""" + val DEFAULT_USER_ID = "1" + val DEFAULT_USER_NAME = "userName" + val DEFAULT_SESSION_ID = "2" + val DEFAULT_QUERY_ID = "3" + val DEFAULT_CLIENT_TYPE = "clientType" + + test("SPARK-43923: post started") { + val events = setupEvents(SessionStatus.Pending) + events.postStarted() + + verify(events.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post( + SparkListenerConnectSessionStarted( + DEFAULT_SESSION_ID, + DEFAULT_USER_ID, + DEFAULT_CLOCK.getTimeMillis(), + Map.empty)) + } + + test("SPARK-43923: post closed") { + val events = setupEvents(SessionStatus.Started) + events.postClosed() + + verify(events.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post( + SparkListenerConnectSessionClosed( + DEFAULT_SESSION_ID, + DEFAULT_USER_ID, + DEFAULT_CLOCK.getTimeMillis(), + Map.empty)) + } + + test("SPARK-43923: Started wrong order throws exception") { + val events = setupEvents(SessionStatus.Started) + assertThrows[IllegalStateException] { + events.postStarted() + } + } + + test("SPARK-43923: Closed wrong order throws exception") { + val events = setupEvents(SessionStatus.Closed) + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + events.postClosed() + } + } + + def setupEvents(status: SessionStatus): SessionEventsManager = { + val mockSession = mock[SparkSession] + val sessionHolder = SessionHolder(DEFAULT_USER_ID, DEFAULT_SESSION_ID, mockSession) + val mockContext = mock[SparkContext] + val mockListenerBus = mock[LiveListenerBus] + when(mockContext.listenerBus).thenReturn(mockListenerBus) + when(mockSession.sparkContext).thenReturn(mockContext) + + val eventsManager = SessionEventsManager(sessionHolder, DEFAULT_CLOCK) + eventsManager.status_(status) + eventsManager + } +}