diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 99a0a4ae4bad..e8cdaa6c63b3 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -2166,6 +2166,12 @@ "Number of given aliases does not match number of output columns. Function name: ; number of aliases: ; number of output columns: ." ] }, + "OPERATION_CANCELED" : { + "message" : [ + "Operation has been canceled." + ], + "sqlState" : "HY008" + }, "ORDER_BY_POS_OUT_OF_RANGE" : { "message" : [ "ORDER BY position is not in select list (valid range is [1, ])." diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala index 044f6a48cc8f..70eeb6c2c41d 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala @@ -49,7 +49,7 @@ class SparkSessionE2ESuite extends RemoteSparkSession { q1.onComplete { case Success(_) => error = Some("q1 shouldn't have finished!") - case Failure(t) if t.getMessage.contains("cancelled") => + case Failure(t) if t.getMessage.contains("OPERATION_CANCELED") => q1Interrupted = true case Failure(t) => error = Some("unexpected failure in q1: " + t.toString) @@ -57,7 +57,7 @@ class SparkSessionE2ESuite extends RemoteSparkSession { q2.onComplete { case Success(_) => error = Some("q2 shouldn't have finished!") - case Failure(t) if t.getMessage.contains("cancelled") => + case Failure(t) if t.getMessage.contains("OPERATION_CANCELED") => q2Interrupted = true case Failure(t) => error = Some("unexpected failure in q2: " + t.toString) @@ -89,11 +89,11 @@ class SparkSessionE2ESuite extends RemoteSparkSession { val e1 = intercept[SparkException] { spark.range(10).map(n => { Thread.sleep(30.seconds.toMillis); n }).collect() } - assert(e1.getMessage.contains("cancelled"), s"Unexpected exception: $e1") + assert(e1.getMessage.contains("OPERATION_CANCELED"), s"Unexpected exception: $e1") val e2 = intercept[SparkException] { spark.range(10).map(n => { Thread.sleep(30.seconds.toMillis); n }).collect() } - assert(e2.getMessage.contains("cancelled"), s"Unexpected exception: $e2") + assert(e2.getMessage.contains("OPERATION_CANCELED"), s"Unexpected exception: $e2") finished = true assert(ThreadUtils.awaitResult(interruptor, 10.seconds)) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/CachedStreamResponse.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/CachedStreamResponse.scala new file mode 100644 index 000000000000..28eedca9d007 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/CachedStreamResponse.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.execution + +private[execution] case class CachedStreamResponse[T]( + // the actual cached response + response: T, + // index of the response in the response stream. + // responses produced in the stream are numbered consecutively starting from 1. + streamIndex: Long) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala new file mode 100644 index 000000000000..2e5817fa504b --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.execution + +import io.grpc.stub.StreamObserver + +import org.apache.spark.internal.Logging + +/** + * ExecuteGrpcResponseSender sends responses to the GRPC stream. It runs on the RPC thread, and + * gets notified by ExecuteResponseObserver about available responses. It notifies the + * ExecuteResponseObserver back about cached responses that can be removed after being sent out. + * @param responseObserver + * the GRPC request StreamObserver + */ +private[connect] class ExecuteGrpcResponseSender[T](grpcObserver: StreamObserver[T]) + extends Logging { + + private var detached = false + + /** + * Detach this sender from executionObserver. Called only from executionObserver that this + * sender is attached to. executionObserver holds lock, and needs to notify after this call. + */ + def detach(): Unit = { + if (detached == true) { + throw new IllegalStateException("ExecuteGrpcResponseSender already detached!") + } + detached = true + } + + /** + * Attach to the executionObserver, consume responses from it, and send them to grpcObserver. + * @param lastConsumedStreamIndex + * the last index that was already consumed and sent. This sender will start from index after + * that. 0 means start from beginning (since first response has index 1) + * + * @return + * true if the execution was detached before stream completed. The caller needs to finish the + * grpcObserver stream false if stream was finished. In this case, grpcObserver stream is + * already completed. 
+ */ + def run( + executionObserver: ExecuteResponseObserver[T], + lastConsumedStreamIndex: Long): Boolean = { + // register to be notified about available responses. + executionObserver.attachConsumer(this) + + var nextIndex = lastConsumedStreamIndex + 1 + var finished = false + + while (!finished) { + var response: Option[CachedStreamResponse[T]] = None + // Get next available response. + // Wait until either this sender got detached or next response is ready, + // or the stream is complete and it had already sent all responses. + logDebug(s"Trying to get next response with index=$nextIndex.") + executionObserver.synchronized { + logDebug(s"Acquired lock.") + while (!detached && response.isEmpty && + executionObserver.getLastIndex().forall(nextIndex <= _)) { + logDebug(s"Try to get response with index=$nextIndex from observer.") + response = executionObserver.getResponse(nextIndex) + logDebug(s"Response index=$nextIndex from observer: ${response.isDefined}") + // If response is empty, release executionObserver lock and wait to get notified. + // The state of detached, response and lastIndex are change under lock in + // executionObserver, and will notify upon state change. + if (response.isEmpty) { + logDebug(s"Wait for response to become available.") + executionObserver.wait() + logDebug(s"Reacquired lock after waiting.") + } + } + logDebug( + s"Exiting loop: detached=$detached, response=$response," + + s"lastIndex=${executionObserver.getLastIndex()}") + } + + // Send next available response. + if (detached) { + // This sender got detached by the observer. + logDebug(s"Detached from observer at index ${nextIndex - 1}. Complete stream.") + finished = true + } else if (response.isDefined) { + // There is a response available to be sent. + grpcObserver.onNext(response.get.response) + logDebug(s"Sent response index=$nextIndex.") + nextIndex += 1 + } else if (executionObserver.getLastIndex().forall(nextIndex > _)) { + // Stream is finished and all responses have been sent + logDebug(s"Sent all responses up to index ${nextIndex - 1}.") + executionObserver.getError() match { + case Some(t) => grpcObserver.onError(t) + case None => grpcObserver.onCompleted() + } + finished = true + } + } + // Return true if stream finished, or false if was detached. + detached + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala new file mode 100644 index 000000000000..5aecbdfce163 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connect.execution + +import scala.collection.mutable + +import io.grpc.stub.StreamObserver + +import org.apache.spark.internal.Logging + +/** + * This StreamObserver is running on the execution thread. Execution pushes responses to it, it + * caches them. ExecuteResponseGRPCSender is the consumer of the responses ExecuteResponseObserver + * "produces". It waits on the monitor of ExecuteResponseObserver. New produced responses notify + * the monitor. + * @see + * getResponse. + * + * ExecuteResponseObserver controls how responses stay cached after being returned to consumer, + * @see + * removeCachedResponses. + * + * A single ExecuteResponseGRPCSender can be attached to the ExecuteResponseObserver. Attaching a + * new one will notify an existing one that it was detached. + * @see + * attachConsumer + */ +private[connect] class ExecuteResponseObserver[T]() extends StreamObserver[T] with Logging { + + /** + * Cached responses produced by the execution. Map from response index -> response. Response + * indexes are numbered consecutively starting from 1. + */ + private val responses: mutable.Map[Long, CachedStreamResponse[T]] = + new mutable.HashMap[Long, CachedStreamResponse[T]]() + + /** Cached error of the execution, if an error was thrown. */ + private var error: Option[Throwable] = None + + /** + * If execution stream is finished (completed or with error), the index of the final response. + */ + private var finalProducedIndex: Option[Long] = None // index of final response before completed. + + /** The index of the last response produced by execution. */ + private var lastProducedIndex: Long = 0 // first response will have index 1 + + /** + * Highest response index that was consumed. Keeps track of it to decide which responses needs + * to be cached, and to assert that all responses are consumed. + */ + private var highestConsumedIndex: Long = 0 + + /** + * Consumer that waits for available responses. There can be only one at a time, @see + * attachConsumer. + */ + private var responseSender: Option[ExecuteGrpcResponseSender[T]] = None + + def onNext(r: T): Unit = synchronized { + if (finalProducedIndex.nonEmpty) { + throw new IllegalStateException("Stream onNext can't be called after stream completed") + } + lastProducedIndex += 1 + responses += ((lastProducedIndex, CachedStreamResponse[T](r, lastProducedIndex))) + logDebug(s"Saved response with index=$lastProducedIndex") + notifyAll() + } + + def onError(t: Throwable): Unit = synchronized { + if (finalProducedIndex.nonEmpty) { + throw new IllegalStateException("Stream onError can't be called after stream completed") + } + error = Some(t) + finalProducedIndex = Some(lastProducedIndex) // no responses to be send after error. + logDebug(s"Error. Last stream index is $lastProducedIndex.") + notifyAll() + } + + def onCompleted(): Unit = synchronized { + if (finalProducedIndex.nonEmpty) { + throw new IllegalStateException("Stream onCompleted can't be called after stream completed") + } + finalProducedIndex = Some(lastProducedIndex) + logDebug(s"Completed. Last stream index is $lastProducedIndex.") + notifyAll() + } + + /** Attach a new consumer (ExecuteResponseGRPCSender). */ + def attachConsumer(newSender: ExecuteGrpcResponseSender[T]): Unit = synchronized { + // detach the current sender before attaching new one + // this.synchronized() needs to be held while detaching a sender, and the detached sender + // needs to be notified with notifyAll() afterwards. 
+ responseSender.foreach(_.detach()) + responseSender = Some(newSender) + notifyAll() // consumer + } + + /** Get response with a given index in the stream, if set. */ + def getResponse(index: Long): Option[CachedStreamResponse[T]] = synchronized { + // we index stream responses from 1, getting a lower index would be invalid. + assert(index >= 1) + // it would be invalid if consumer would skip a response + assert(index <= highestConsumedIndex + 1) + val ret = responses.get(index) + if (ret.isDefined) { + if (index > highestConsumedIndex) highestConsumedIndex = index + removeCachedResponses() + } + ret + } + + /** Get the stream error if there is one, otherwise None. */ + def getError(): Option[Throwable] = synchronized { + error + } + + /** If the stream is finished, the index of the last response, otherwise None. */ + def getLastIndex(): Option[Long] = synchronized { + finalProducedIndex + } + + /** Returns if the stream is finished. */ + def completed(): Boolean = synchronized { + finalProducedIndex.isDefined + } + + /** Consumer (ExecuteResponseGRPCSender) waits on the monitor of ExecuteResponseObserver. */ + private def notifyConsumer(): Unit = { + notifyAll() + } + + /** + * Remove cached responses after response with lastReturnedIndex is returned from getResponse. + * Remove according to caching policy: + * - if query is not reattachable, remove all responses up to and including + * highestConsumedIndex. + */ + private def removeCachedResponses() = { + var i = highestConsumedIndex + while (i >= 1 && responses.get(i).isDefined) { + responses.remove(i) + i -= 1 + } + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala new file mode 100644 index 000000000000..b7b3d2adf9f7 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.execution + +import scala.util.control.NonFatal + +import com.google.protobuf.Message +import org.apache.commons.lang3.StringUtils + +import org.apache.spark.SparkSQLException +import org.apache.spark.connect.proto +import org.apache.spark.internal.Logging +import org.apache.spark.sql.connect.common.ProtoUtils +import org.apache.spark.sql.connect.planner.SparkConnectPlanner +import org.apache.spark.sql.connect.service.ExecuteHolder +import org.apache.spark.sql.connect.utils.ErrorUtils +import org.apache.spark.util.Utils + +/** + * This class launches the actual execution in an execution thread. 
The execution pushes the + * responses to a ExecuteResponseObserver in executeHolder. + */ +private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends Logging { + + // The newly created thread will inherit all InheritableThreadLocals used by Spark, + // e.g. SparkContext.localProperties. If considering implementing a threadpool, + // forwarding of thread locals needs to be taken into account. + private var executionThread: Thread = new ExecutionThread() + + private var interrupted: Boolean = false + + /** Launches the execution in a background thread, returns immediately. */ + def start(): Unit = { + executionThread.start() + } + + /** Joins the background execution thread after it is finished. */ + def join(): Unit = { + executionThread.join() + } + + /** Interrupt the executing thread. */ + def interrupt(): Unit = { + synchronized { + interrupted = true + executionThread.interrupt() + } + } + + private def execute(): Unit = { + // Outer execute handles errors. + // Separate it from executeInternal to save on indent and improve readability. + try { + try { + executeInternal() + } catch { + // Need to catch throwable instead of NonFatal, because e.g. InterruptedException is fatal. + case e: Throwable => + logDebug(s"Exception in execute: $e") + // Always cancel all remaining execution after error. + executeHolder.sessionHolder.session.sparkContext.cancelJobsWithTag(executeHolder.jobTag) + // Rely on an internal interrupted flag, because Thread.interrupted() could be cleared, + // and different exceptions like InterruptedException, ClosedByInterruptException etc. + // could be thrown. + if (interrupted) { + // Turn the interrupt into OPERATION_CANCELED error. + throw new SparkSQLException("OPERATION_CANCELED", Map.empty) + } else { + // Rethrown the original error. + throw e + } + } finally { + executeHolder.sessionHolder.session.sparkContext.removeJobTag(executeHolder.jobTag) + } + } catch { + ErrorUtils.handleError( + "execute", + executeHolder.responseObserver, + executeHolder.sessionHolder.userId, + executeHolder.sessionHolder.sessionId) + } + } + + // Inner executeInternal is wrapped by execute() for error handling. + private def executeInternal() = { + // synchronized - check if already got interrupted while starting. + synchronized { + if (interrupted) { + throw new InterruptedException() + } + } + + // `withSession` ensures that session-specific artifacts (such as JARs and class files) are + // available during processing. + executeHolder.sessionHolder.withSession { session => + val debugString = requestString(executeHolder.request) + + // Set tag for query cancellation + session.sparkContext.addJobTag(executeHolder.jobTag) + session.sparkContext.setJobDescription( + s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}") + session.sparkContext.setInterruptOnCancel(true) + + // Add debug information to the query execution so that the jobs are traceable. 
+ session.sparkContext.setLocalProperty( + "callSite.short", + s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}") + session.sparkContext.setLocalProperty( + "callSite.long", + StringUtils.abbreviate(debugString, 2048)) + + executeHolder.request.getPlan.getOpTypeCase match { + case proto.Plan.OpTypeCase.COMMAND => handleCommand(executeHolder.request) + case proto.Plan.OpTypeCase.ROOT => handlePlan(executeHolder.request) + case _ => + throw new UnsupportedOperationException( + s"${executeHolder.request.getPlan.getOpTypeCase} not supported.") + } + } + } + + private def handlePlan(request: proto.ExecutePlanRequest): Unit = { + val responseObserver = executeHolder.responseObserver + + val execution = new SparkConnectPlanExecution(executeHolder) + execution.handlePlan(responseObserver) + } + + private def handleCommand(request: proto.ExecutePlanRequest): Unit = { + val responseObserver = executeHolder.responseObserver + + val command = request.getPlan.getCommand + val planner = new SparkConnectPlanner(executeHolder.sessionHolder) + planner.process( + command = command, + userId = request.getUserContext.getUserId, + sessionId = request.getSessionId, + responseObserver = responseObserver) + responseObserver.onCompleted() + } + + private def requestString(request: Message) = { + try { + Utils.redact( + executeHolder.sessionHolder.session.sessionState.conf.stringRedactionPattern, + ProtoUtils.abbreviate(request).toString) + } catch { + case NonFatal(e) => + logWarning("Fail to extract debug information", e) + "UNKNOWN" + } + } + + private class ExecutionThread extends Thread { + override def run(): Unit = { + execute() + } + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala similarity index 57% rename from connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala rename to connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala index d809833d0129..74b4a5f65974 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala @@ -15,115 +15,60 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.connect.service +package org.apache.spark.sql.connect.execution import scala.collection.JavaConverters._ -import scala.util.control.NonFatal import com.google.protobuf.ByteString import io.grpc.stub.StreamObserver -import org.apache.commons.lang3.StringUtils import org.apache.spark.SparkEnv import org.apache.spark.connect.proto -import org.apache.spark.connect.proto.{ExecutePlanRequest, ExecutePlanResponse} -import org.apache.spark.internal.Logging +import org.apache.spark.connect.proto.ExecutePlanResponse import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ProtoUtils} +import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE import org.apache.spark.sql.connect.planner.SparkConnectPlanner -import org.apache.spark.sql.connect.service.SparkConnectStreamHandler.processAsArrowBatches -import org.apache.spark.sql.execution.{LocalTableScanExec, SparkPlan, SQLExecution} -import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec} +import org.apache.spark.sql.connect.service.ExecuteHolder +import org.apache.spark.sql.connect.utils.MetricGenerator +import org.apache.spark.sql.execution.{LocalTableScanExec, SQLExecution} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.types.StructType -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.ThreadUtils -class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResponse]) - extends Logging { - - def handle(v: ExecutePlanRequest): Unit = { - val sessionHolder = - SparkConnectService - .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId) - // `withSession` ensures that session-specific artifacts (such as JARs and class files) are - // available during processing. - sessionHolder.withSession { session => - // Add debug information to the query execution so that the jobs are traceable. - val debugString = - try { - Utils.redact( - session.sessionState.conf.stringRedactionPattern, - ProtoUtils.abbreviate(v).toString) - } catch { - case NonFatal(e) => - logWarning("Fail to extract debug information", e) - "UNKNOWN" - } - - val executeHolder = sessionHolder.createExecutePlanHolder(v) - session.sparkContext.addJobTag(executeHolder.jobTag) - session.sparkContext.setInterruptOnCancel(true) +/** + * Handle ExecutePlanRequest where the operation to handle is of `Plan` type. + * proto.Plan.OpTypeCase.ROOT + * @param executeHolder + */ +private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) { - try { - // Add debug information to the query execution so that the jobs are traceable. 
- session.sparkContext.setLocalProperty( - "callSite.short", - s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}") - session.sparkContext.setLocalProperty( - "callSite.long", - StringUtils.abbreviate(debugString, 2048)) - } catch { - case NonFatal(e) => - logWarning("Fail to attach the debug information", e) - } + private val sessionHolder = executeHolder.sessionHolder + private val session = executeHolder.session - try { - v.getPlan.getOpTypeCase match { - case proto.Plan.OpTypeCase.COMMAND => handleCommand(sessionHolder, v) - case proto.Plan.OpTypeCase.ROOT => handlePlan(sessionHolder, v) - case _ => - throw new UnsupportedOperationException(s"${v.getPlan.getOpTypeCase} not supported.") - } - } finally { - session.sparkContext.removeJobTag(executeHolder.jobTag) - sessionHolder.removeExecutePlanHolder(executeHolder.operationId) - } + def handlePlan(responseObserver: ExecuteResponseObserver[proto.ExecutePlanResponse]): Unit = { + val request = executeHolder.request + if (request.getPlan.getOpTypeCase != proto.Plan.OpTypeCase.ROOT) { + throw new IllegalStateException( + s"Illegal operation type ${request.getPlan.getOpTypeCase} to be handled here.") } - } - private def handlePlan(sessionHolder: SessionHolder, request: ExecutePlanRequest): Unit = { // Extract the plan from the request and convert it to a logical plan val planner = new SparkConnectPlanner(sessionHolder) val dataframe = Dataset.ofRows(sessionHolder.session, planner.transformRelation(request.getPlan.getRoot)) - responseObserver.onNext( - SparkConnectStreamHandler.sendSchemaToResponse(request.getSessionId, dataframe.schema)) + responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema)) processAsArrowBatches(request.getSessionId, dataframe, responseObserver) responseObserver.onNext( - SparkConnectStreamHandler.createMetricsResponse(request.getSessionId, dataframe)) + MetricGenerator.createMetricsResponse(request.getSessionId, dataframe)) if (dataframe.queryExecution.observedMetrics.nonEmpty) { - responseObserver.onNext( - SparkConnectStreamHandler.sendObservedMetricsToResponse(request.getSessionId, dataframe)) + responseObserver.onNext(createObservedMetricsResponse(request.getSessionId, dataframe)) } responseObserver.onCompleted() } - private def handleCommand(sessionHolder: SessionHolder, request: ExecutePlanRequest): Unit = { - val command = request.getPlan.getCommand - val planner = new SparkConnectPlanner(sessionHolder) - planner.process( - command = command, - userId = request.getUserContext.getUserId, - sessionId = request.getSessionId, - responseObserver = responseObserver) - responseObserver.onCompleted() - } -} - -object SparkConnectStreamHandler { type Batch = (Array[Byte], Long) def rowToArrowConverter( @@ -142,7 +87,7 @@ object SparkConnectStreamHandler { batches.map(b => b -> batches.rowCountInLastBatch) } - def processAsArrowBatches( + private def processAsArrowBatches( sessionId: String, dataframe: DataFrame, responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { @@ -153,7 +98,7 @@ object SparkConnectStreamHandler { // Conservatively sets it 70% because the size is not accurate but estimated. 
val maxBatchSize = (SparkEnv.get.conf.get(CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong - val rowToArrowConverter = SparkConnectStreamHandler.rowToArrowConverter( + val converter = rowToArrowConverter( schema, maxRecordsPerBatch, maxBatchSize, @@ -175,7 +120,7 @@ object SparkConnectStreamHandler { dataframe.queryExecution.executedPlan match { case LocalTableScanExec(_, rows) => - rowToArrowConverter(rows.iterator).foreach { case (bytes, count) => + converter(rows.iterator).foreach { case (bytes, count) => sendBatch(bytes, count) } case _ => @@ -186,7 +131,7 @@ object SparkConnectStreamHandler { if (numPartitions > 0) { type Batch = (Array[Byte], Long) - val batches = rows.mapPartitionsInternal(rowToArrowConverter) + val batches = rows.mapPartitionsInternal(converter) val signal = new Object val partitions = new Array[Array[Batch]](numPartitions) @@ -263,7 +208,7 @@ object SparkConnectStreamHandler { } } - def sendSchemaToResponse(sessionId: String, schema: StructType): ExecutePlanResponse = { + private def createSchemaResponse(sessionId: String, schema: StructType): ExecutePlanResponse = { // Send the Spark data type ExecutePlanResponse .newBuilder() @@ -272,16 +217,7 @@ object SparkConnectStreamHandler { .build() } - def createMetricsResponse(sessionId: String, rows: DataFrame): ExecutePlanResponse = { - // Send a last batch with the metrics - ExecutePlanResponse - .newBuilder() - .setSessionId(sessionId) - .setMetrics(MetricGenerator.buildMetrics(rows.queryExecution.executedPlan)) - .build() - } - - def sendObservedMetricsToResponse( + private def createObservedMetricsResponse( sessionId: String, dataframe: DataFrame): ExecutePlanResponse = { val observedMetrics = dataframe.queryExecution.observedMetrics.map { case (name, row) => @@ -300,39 +236,3 @@ object SparkConnectStreamHandler { .build() } } - -object MetricGenerator extends AdaptiveSparkPlanHelper { - def buildMetrics(p: SparkPlan): ExecutePlanResponse.Metrics = { - val b = ExecutePlanResponse.Metrics.newBuilder - b.addAllMetrics(transformPlan(p, p.id).asJava) - b.build() - } - - private def transformChildren(p: SparkPlan): Seq[ExecutePlanResponse.Metrics.MetricObject] = { - allChildren(p).flatMap(c => transformPlan(c, p.id)) - } - - private def allChildren(p: SparkPlan): Seq[SparkPlan] = p match { - case a: AdaptiveSparkPlanExec => Seq(a.executedPlan) - case s: QueryStageExec => Seq(s.plan) - case _ => p.children - } - - private def transformPlan( - p: SparkPlan, - parentId: Int): Seq[ExecutePlanResponse.Metrics.MetricObject] = { - val mv = p.metrics.map(m => - m._1 -> ExecutePlanResponse.Metrics.MetricValue.newBuilder - .setName(m._2.name.getOrElse("")) - .setValue(m._2.value) - .setMetricType(m._2.metricType) - .build()) - val mo = ExecutePlanResponse.Metrics.MetricObject - .newBuilder() - .setName(p.nodeName) - .setPlanId(p.id) - .putAllExecutionMetrics(mv.asJava) - .build() - Seq(mo) ++ transformChildren(p) - } -} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index e0bee824195c..492396631f39 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -59,7 +59,7 @@ import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_ import 
org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry import org.apache.spark.sql.connect.service.SessionHolder import org.apache.spark.sql.connect.service.SparkConnectService -import org.apache.spark.sql.connect.service.SparkConnectStreamHandler +import org.apache.spark.sql.connect.utils.MetricGenerator import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.arrow.ArrowConverters @@ -2419,7 +2419,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { .build()) // Send Metrics - responseObserver.onNext(SparkConnectStreamHandler.createMetricsResponse(sessionId, df)) + responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionId, df)) } private def handleRegisterUserDefinedFunction( diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala new file mode 100644 index 000000000000..89aceaee1e4a --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.service + +import org.apache.spark.connect.proto +import org.apache.spark.internal.Logging +import org.apache.spark.sql.connect.execution.{ExecuteGrpcResponseSender, ExecuteResponseObserver, ExecuteThreadRunner} + +/** + * Object used to hold the Spark Connect execution state. + */ +private[connect] class ExecuteHolder( + val request: proto.ExecutePlanRequest, + val operationId: String, + val sessionHolder: SessionHolder) + extends Logging { + + val jobTag = + s"SparkConnect_Execute_" + + s"User_${sessionHolder.userId}_" + + s"Session_${sessionHolder.sessionId}_" + + s"Request_${operationId}" + + val session = sessionHolder.session + + val responseObserver: ExecuteResponseObserver[proto.ExecutePlanResponse] = + new ExecuteResponseObserver[proto.ExecutePlanResponse]() + + private val runner: ExecuteThreadRunner = new ExecuteThreadRunner(this) + + /** + * Start the execution. The execution is started in a background thread in ExecuteThreadRunner. + * Responses are produced and cached in ExecuteResponseObserver. A GRPC thread consumes the + * responses by attaching an ExecuteGrpcResponseSender, + * @see + * attachAndRunGrpcResponseSender. + */ + def start(): Unit = { + runner.start() + } + + /** + * Wait for the execution thread to finish and join it. + */ + def join(): Unit = { + runner.join() + } + + /** + * Attach an ExecuteGrpcResponseSender that will consume responses from the query and send them + * out on the Grpc response stream. 
+ * @param responseSender + * the ExecuteGrpcResponseSender + * @param lastConsumedStreamIndex + * the last index that was already consumed. The consumer will start from index after that. 0 + * means start from beginning (since first response has index 1) + * @return + * true if the sender got detached without completing the stream. false if the executing + * stream was completely sent out. + */ + def attachAndRunGrpcResponseSender( + responseSender: ExecuteGrpcResponseSender[proto.ExecutePlanResponse], + lastConsumedStreamIndex: Long): Boolean = { + responseSender.run(responseObserver, lastConsumedStreamIndex) + } + + /** + * Interrupt the execution. Interrupts the running thread, which cancels all running Spark Jobs + * and makes the execution throw an OPERATION_CANCELED error. + */ + def interrupt(): Unit = { + runner.interrupt() + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecutePlanHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecutePlanHolder.scala deleted file mode 100644 index 9bf9df07e017..000000000000 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecutePlanHolder.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connect.service - -import org.apache.spark.connect.proto - -/** - * Object used to hold the Spark Connect execution state. - */ -case class ExecutePlanHolder( - operationId: String, - sessionHolder: SessionHolder, - request: proto.ExecutePlanRequest) { - - val jobTag = - "SparkConnect_" + - s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}_Request_${operationId}" - - def interrupt(): Unit = { - // TODO/WIP: This only interrupts active Spark jobs that are actively running. - // This would then throw the error from ExecutePlan and terminate it. - // But if the query is not running a Spark job, but executing code on Spark driver, this - // would be a noop and the execution will keep running. 
- sessionHolder.session.sparkContext.cancelJobsWithTag(jobTag) - } - -} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index a24a9eb2fece..7361e370062f 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -41,8 +41,8 @@ import org.apache.spark.util.Utils case class SessionHolder(userId: String, sessionId: String, session: SparkSession) extends Logging { - val executePlanOperations: ConcurrentMap[String, ExecutePlanHolder] = - new ConcurrentHashMap[String, ExecutePlanHolder]() + val executions: ConcurrentMap[String, ExecuteHolder] = + new ConcurrentHashMap[String, ExecuteHolder]() // Mapping from relation ID (passed to client) to runtime dataframe. Used for callbacks like // foreachBatch() in Streaming. Lazy since most sessions don't need it. @@ -53,23 +53,22 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio private lazy val listenerCache: ConcurrentMap[String, StreamingQueryListener] = new ConcurrentHashMap() - private[connect] def createExecutePlanHolder( - request: proto.ExecutePlanRequest): ExecutePlanHolder = { - + private[connect] def createExecuteHolder(request: proto.ExecutePlanRequest): ExecuteHolder = { val operationId = UUID.randomUUID().toString - val executePlanHolder = ExecutePlanHolder(operationId, this, request) - assert(executePlanOperations.putIfAbsent(operationId, executePlanHolder) == null) + val executePlanHolder = new ExecuteHolder(request, operationId, this) + assert(executions.putIfAbsent(operationId, executePlanHolder) == null) executePlanHolder } - private[connect] def removeExecutePlanHolder(operationId: String): Unit = { - executePlanOperations.remove(operationId) + private[connect] def removeExecuteHolder(operationId: String): Unit = { + executions.remove(operationId) } private[connect] def interruptAll(): Unit = { - executePlanOperations.asScala.values.foreach { execute => + executions.asScala.values.foreach { execute => // Eat exception while trying to interrupt a given execution and move forward. try { + logDebug(s"Interrupting execution ${execute.operationId}") execute.interrupt() } catch { case NonFatal(e) => diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala new file mode 100644 index 000000000000..50ca733b4391 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.service + +import io.grpc.stub.StreamObserver + +import org.apache.spark.connect.proto +import org.apache.spark.internal.Logging +import org.apache.spark.sql.connect.execution.ExecuteGrpcResponseSender + +class SparkConnectExecutePlanHandler(responseObserver: StreamObserver[proto.ExecutePlanResponse]) + extends Logging { + + def handle(v: proto.ExecutePlanRequest): Unit = { + val sessionHolder = SparkConnectService + .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId) + val executeHolder = sessionHolder.createExecuteHolder(v) + + try { + executeHolder.start() + val responseSender = + new ExecuteGrpcResponseSender[proto.ExecutePlanResponse](responseObserver) + val detached = executeHolder.attachAndRunGrpcResponseSender(responseSender, 0) + if (detached) { + // Detached before execution finished. + // TODO this doesn't happen yet without reattachable execution. + responseObserver.onCompleted() + } + } finally { + // TODO this will change with detachable execution. + executeHolder.join() + sessionHolder.removeExecuteHolder(executeHolder.operationId) + } + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index 0f90bccaac8f..c38fbbdfcf97 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -19,32 +19,21 @@ package org.apache.spark.sql.connect.service import java.util.concurrent.TimeUnit -import scala.annotation.tailrec -import scala.collection.mutable.ArrayBuffer -import scala.util.control.NonFatal - import com.google.common.base.Ticker import com.google.common.cache.{CacheBuilder, RemovalListener, RemovalNotification} -import com.google.protobuf.{Any => ProtoAny} -import com.google.rpc.{Code => RPCCode, ErrorInfo, Status => RPCStatus} -import io.grpc.{Server, Status} +import io.grpc.Server import io.grpc.netty.NettyServerBuilder -import io.grpc.protobuf.StatusProto import io.grpc.protobuf.services.ProtoReflectionService import io.grpc.stub.StreamObserver import org.apache.commons.lang3.StringUtils -import org.apache.commons.lang3.exception.ExceptionUtils -import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods.{compact, render} -import org.apache.spark.{SparkEnv, SparkException, SparkThrowable} -import org.apache.spark.api.python.PythonException +import org.apache.spark.SparkEnv import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_JVM_STACK_TRACE_MAX_SIZE} -import org.apache.spark.sql.internal.SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED +import 
org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE} +import org.apache.spark.sql.connect.utils.ErrorUtils /** * The SparkConnectService implementation. @@ -58,101 +47,10 @@ class SparkConnectService(debug: Boolean) extends proto.SparkConnectServiceGrpc.SparkConnectServiceImplBase with Logging { - private def allClasses(cl: Class[_]): Seq[Class[_]] = { - val classes = ArrayBuffer.empty[Class[_]] - if (cl != null && !cl.equals(classOf[java.lang.Object])) { - classes.append(cl) // Includes itself. - } - - @tailrec - def appendSuperClasses(clazz: Class[_]): Unit = { - if (clazz == null || clazz.equals(classOf[java.lang.Object])) return - classes.append(clazz.getSuperclass) - appendSuperClasses(clazz.getSuperclass) - } - - appendSuperClasses(cl) - classes.toSeq - } - - private def buildStatusFromThrowable(st: Throwable, stackTraceEnabled: Boolean): RPCStatus = { - val errorInfo = ErrorInfo - .newBuilder() - .setReason(st.getClass.getName) - .setDomain("org.apache.spark") - .putMetadata("classes", compact(render(allClasses(st.getClass).map(_.getName)))) - - lazy val stackTrace = Option(ExceptionUtils.getStackTrace(st)) - val withStackTrace = if (stackTraceEnabled && stackTrace.nonEmpty) { - val maxSize = SparkEnv.get.conf.get(CONNECT_JVM_STACK_TRACE_MAX_SIZE) - errorInfo.putMetadata("stackTrace", StringUtils.abbreviate(stackTrace.get, maxSize)) - } else { - errorInfo - } - - RPCStatus - .newBuilder() - .setCode(RPCCode.INTERNAL_VALUE) - .addDetails(ProtoAny.pack(withStackTrace.build())) - .setMessage(SparkConnectService.extractErrorMessage(st)) - .build() - } - - private def isPythonExecutionException(se: SparkException): Boolean = { - // See also pyspark.errors.exceptions.captured.convert_exception in PySpark. - se.getCause != null && se.getCause - .isInstanceOf[PythonException] && se.getCause.getStackTrace - .exists(_.toString.contains("org.apache.spark.sql.execution.python")) - } - - /** - * Common exception handling function for the Analysis and Execution methods. Closes the stream - * after the error has been sent. - * - * @param opType - * String value indicating the operation type (analysis, execution) - * @param observer - * The GRPC response observer. - * @tparam V - * @return - */ - private def handleError[V]( - opType: String, - observer: StreamObserver[V], - userId: String, - sessionId: String): PartialFunction[Throwable, Unit] = { - val session = - SparkConnectService - .getOrCreateIsolatedSession(userId, sessionId) - .session - val stackTraceEnabled = session.conf.get(PYSPARK_JVM_STACKTRACE_ENABLED) - - { - case se: SparkException if isPythonExecutionException(se) => - logError(s"Error during: $opType. UserId: $userId. SessionId: $sessionId.", se) - observer.onError( - StatusProto.toStatusRuntimeException( - buildStatusFromThrowable(se.getCause, stackTraceEnabled))) - - case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) => - logError(s"Error during: $opType. UserId: $userId. SessionId: $sessionId.", e) - observer.onError( - StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, stackTraceEnabled))) - - case e: Throwable => - logError(s"Error during: $opType. UserId: $userId. SessionId: $sessionId.", e) - observer.onError( - Status.UNKNOWN - .withCause(e) - .withDescription(StringUtils.abbreviate(e.getMessage, 2048)) - .asRuntimeException()) - } - } - /** * This is the main entry method for Spark Connect and all calls to execute a plan. 
* - * The plan execution is delegated to the [[SparkConnectStreamHandler]]. All error handling + * The plan execution is delegated to the [[SparkConnectExecutePlanHandler]]. All error handling * should be directly implemented in the deferred implementation. But this method catches * generic errors. * @@ -163,9 +61,9 @@ class SparkConnectService(debug: Boolean) request: proto.ExecutePlanRequest, responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = { try { - new SparkConnectStreamHandler(responseObserver).handle(request) + new SparkConnectExecutePlanHandler(responseObserver).handle(request) } catch { - handleError( + ErrorUtils.handleError( "execute", observer = responseObserver, userId = request.getUserContext.getUserId, @@ -191,7 +89,7 @@ class SparkConnectService(debug: Boolean) try { new SparkConnectAnalyzeHandler(responseObserver).handle(request) } catch { - handleError( + ErrorUtils.handleError( "analyze", observer = responseObserver, userId = request.getUserContext.getUserId, @@ -212,7 +110,7 @@ class SparkConnectService(debug: Boolean) try { new SparkConnectConfigHandler(responseObserver).handle(request) } catch { - handleError( + ErrorUtils.handleError( "config", observer = responseObserver, userId = request.getUserContext.getUserId, @@ -239,7 +137,7 @@ class SparkConnectService(debug: Boolean) try { new SparkConnectArtifactStatusesHandler(responseObserver).handle(request) } catch - handleError( + ErrorUtils.handleError( "artifactStatus", observer = responseObserver, userId = request.getUserContext.getUserId, @@ -255,7 +153,7 @@ class SparkConnectService(debug: Boolean) try { new SparkConnectInterruptHandler(responseObserver).handle(request) } catch - handleError( + ErrorUtils.handleError( "interrupt", observer = responseObserver, userId = request.getUserContext.getUserId, diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala new file mode 100644 index 000000000000..d0f754827dad --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connect.utils + +import scala.annotation.tailrec +import scala.collection.mutable.ArrayBuffer +import scala.util.control.NonFatal + +import com.google.protobuf.{Any => ProtoAny} +import com.google.rpc.{Code => RPCCode, ErrorInfo, Status => RPCStatus} +import io.grpc.Status +import io.grpc.protobuf.StatusProto +import io.grpc.stub.StreamObserver +import org.apache.commons.lang3.StringUtils +import org.apache.commons.lang3.exception.ExceptionUtils +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods + +import org.apache.spark.{SparkEnv, SparkException, SparkThrowable} +import org.apache.spark.api.python.PythonException +import org.apache.spark.internal.Logging +import org.apache.spark.sql.connect.config.Connect +import org.apache.spark.sql.connect.service.SparkConnectService +import org.apache.spark.sql.internal.SQLConf + +private[connect] object ErrorUtils extends Logging { + private def allClasses(cl: Class[_]): Seq[Class[_]] = { + val classes = ArrayBuffer.empty[Class[_]] + if (cl != null && !cl.equals(classOf[java.lang.Object])) { + classes.append(cl) // Includes itself. + } + + @tailrec + def appendSuperClasses(clazz: Class[_]): Unit = { + if (clazz == null || clazz.equals(classOf[java.lang.Object])) return + classes.append(clazz.getSuperclass) + appendSuperClasses(clazz.getSuperclass) + } + + appendSuperClasses(cl) + classes.toSeq + } + + private def buildStatusFromThrowable(st: Throwable, stackTraceEnabled: Boolean): RPCStatus = { + val errorInfo = ErrorInfo + .newBuilder() + .setReason(st.getClass.getName) + .setDomain("org.apache.spark") + .putMetadata( + "classes", + JsonMethods.compact(JsonMethods.render(allClasses(st.getClass).map(_.getName)))) + + lazy val stackTrace = Option(ExceptionUtils.getStackTrace(st)) + val withStackTrace = if (stackTraceEnabled && stackTrace.nonEmpty) { + val maxSize = SparkEnv.get.conf.get(Connect.CONNECT_JVM_STACK_TRACE_MAX_SIZE) + errorInfo.putMetadata("stackTrace", StringUtils.abbreviate(stackTrace.get, maxSize)) + } else { + errorInfo + } + + RPCStatus + .newBuilder() + .setCode(RPCCode.INTERNAL_VALUE) + .addDetails(ProtoAny.pack(withStackTrace.build())) + .setMessage(SparkConnectService.extractErrorMessage(st)) + .build() + } + + private def isPythonExecutionException(se: SparkException): Boolean = { + // See also pyspark.errors.exceptions.captured.convert_exception in PySpark. + se.getCause != null && se.getCause + .isInstanceOf[PythonException] && se.getCause.getStackTrace + .exists(_.toString.contains("org.apache.spark.sql.execution.python")) + } + + /** + * Common exception handling function for RPC methods. Closes the stream after the error has + * been sent. + * + * @param opType + * String value indicating the operation type (analysis, execution) + * @param observer + * The GRPC response observer. + * @tparam V + * @return + */ + def handleError[V]( + opType: String, + observer: StreamObserver[V], + userId: String, + sessionId: String): PartialFunction[Throwable, Unit] = { + val session = + SparkConnectService + .getOrCreateIsolatedSession(userId, sessionId) + .session + val stackTraceEnabled = session.conf.get(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED) + + { + case se: SparkException if isPythonExecutionException(se) => + logError(s"Error during: $opType. UserId: $userId. 
SessionId: $sessionId.", se) + observer.onError( + StatusProto.toStatusRuntimeException( + buildStatusFromThrowable(se.getCause, stackTraceEnabled))) + + case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) => + logError(s"Error during: $opType. UserId: $userId. SessionId: $sessionId.", e) + observer.onError( + StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, stackTraceEnabled))) + + case e: Throwable => + logError(s"Error during: $opType. UserId: $userId. SessionId: $sessionId.", e) + observer.onError( + Status.UNKNOWN + .withCause(e) + .withDescription(StringUtils.abbreviate(e.getMessage, 2048)) + .asRuntimeException()) + } + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala new file mode 100644 index 000000000000..88120e616efd --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.utils + +import scala.collection.JavaConverters._ + +import org.apache.spark.connect.proto.ExecutePlanResponse +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec} + +/** + * Helper object for generating responses with metrics from queries. 
+ */ +private[connect] object MetricGenerator extends AdaptiveSparkPlanHelper { + + def createMetricsResponse(sessionId: String, rows: DataFrame): ExecutePlanResponse = { + ExecutePlanResponse + .newBuilder() + .setSessionId(sessionId) + .setMetrics(MetricGenerator.buildMetrics(rows.queryExecution.executedPlan)) + .build() + } + + private def buildMetrics(p: SparkPlan): ExecutePlanResponse.Metrics = { + val b = ExecutePlanResponse.Metrics.newBuilder + b.addAllMetrics(transformPlan(p, p.id).asJava) + b.build() + } + + private def transformChildren(p: SparkPlan): Seq[ExecutePlanResponse.Metrics.MetricObject] = { + allChildren(p).flatMap(c => transformPlan(c, p.id)) + } + + private def allChildren(p: SparkPlan): Seq[SparkPlan] = p match { + case a: AdaptiveSparkPlanExec => Seq(a.executedPlan) + case s: QueryStageExec => Seq(s.plan) + case _ => p.children + } + + private def transformPlan( + p: SparkPlan, + parentId: Int): Seq[ExecutePlanResponse.Metrics.MetricObject] = { + val mv = p.metrics.map(m => + m._1 -> ExecutePlanResponse.Metrics.MetricValue.newBuilder + .setName(m._2.name.getOrElse("")) + .setValue(m._2.value) + .setMetricType(m._2.metricType) + .build()) + val mo = ExecutePlanResponse.Metrics.MetricObject + .newBuilder() + .setName(p.nodeName) + .setPlanId(p.id) + .putAllExecutionMetrics(mv.asJava) + .build() + Seq(mo) ++ transformChildren(p) + } +} diff --git a/docs/sql-error-conditions-sqlstates.md b/docs/sql-error-conditions-sqlstates.md index 6b4c7e62f712..5529c961b3bf 100644 --- a/docs/sql-error-conditions-sqlstates.md +++ b/docs/sql-error-conditions-sqlstates.md @@ -699,6 +699,21 @@ Spark SQL uses the following `SQLSTATE` classes: + +## Class `HY`: CLI-specific condition + + + + + + + + + + + +
+<table>
+  <tr><th>SQLSTATE</th><th>Description and issuing error classes</th></tr>
+  <tr><td>HY008</td><td>operation canceled</td></tr>
+  <tr><td></td><td>OPERATION_CANCELED</td></tr>
+</table>
+
## Class `XX`: internal error diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index 21eda114d06e..91b77a6452bc 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -1380,6 +1380,12 @@ SQLSTATE: none assigned Number of given aliases does not match number of output columns. Function name: ``; number of aliases: ``; number of output columns: ``. +### OPERATION_CANCELED + +[SQLSTATE: HY008](sql-error-conditions-sqlstates.html#class-HY-cli-specific-condition) + +Operation has been canceled. + ### ORDER_BY_POS_OUT_OF_RANGE [SQLSTATE: 42805](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
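
Note on the execution flow introduced by this patch: it is essentially a monitor-based producer/consumer hand-off. The execution thread caches consecutively numbered responses in ExecuteResponseObserver and calls notifyAll(), while ExecuteGrpcResponseSender waits on the observer's monitor and drains responses in index order until the final index is known. The standalone Scala sketch below illustrates only that hand-off pattern in miniature; all names in it are hypothetical and it is not part of the patch.

// Illustrative sketch only (not part of the patch): a minimal version of the
// monitor-based hand-off between a response producer and a response consumer.
import scala.collection.mutable

class ResponseBuffer[T] {
  private val responses = mutable.Map.empty[Long, T]
  private var lastProducedIndex = 0L          // first response gets index 1
  private var finalIndex: Option[Long] = None // set once the producer completes

  // Producer: cache the next response under a consecutive index and wake consumers.
  def produce(r: T): Unit = synchronized {
    lastProducedIndex += 1
    responses(lastProducedIndex) = r
    notifyAll()
  }

  // Producer: mark the stream as finished so waiting consumers can stop.
  def complete(): Unit = synchronized {
    finalIndex = Some(lastProducedIndex)
    notifyAll()
  }

  // Consumer: block until the response with `index` is available, or return None
  // if the stream completed before ever reaching that index.
  def consume(index: Long): Option[T] = synchronized {
    while (responses.get(index).isEmpty && finalIndex.forall(index <= _)) {
      wait()
    }
    responses.remove(index) // drop the cached response once it has been consumed
  }
}

object ResponseBufferExample extends App {
  val buf = new ResponseBuffer[String]
  val consumer = new Thread(() => {
    var i = 1L
    var done = false
    while (!done) {
      buf.consume(i) match {
        case Some(r) => println(s"consumed #$i: $r"); i += 1
        case None => done = true // producer finished and all responses were drained
      }
    }
  })
  consumer.start()
  (1 to 3).foreach(n => buf.produce(s"response-$n"))
  buf.complete()
  consumer.join()
}

The real classes layer two things on top of this sketch: a sender can be detached when a new one attaches to the observer, and an error recorded by the observer is forwarded to the gRPC stream via onError instead of completing it normally.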