common/utils/src/main/resources/error/error-classes.json (6 additions, 0 deletions)
@@ -2166,6 +2166,12 @@
"Number of given aliases does not match number of output columns. Function name: <funcName>; number of aliases: <aliasesNum>; number of output columns: <outColsNum>."
]
},
"OPERATION_CANCELED" : {
"message" : [
"Operation has been canceled."
],
"sqlState" : "HY008"
},
"ORDER_BY_POS_OUT_OF_RANGE" : {
"message" : [
"ORDER BY position <index> is not in select list (valid range is [1, <size>])."
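For context, a hedged sketch of how server-side code can raise this new error class; onOperationCanceled is a hypothetical helper, and the error-class constructor of SparkException is assumed from its use elsewhere in Spark, not shown in this diff:

// Hedged sketch, not part of this diff: raising the new error class.
import org.apache.spark.SparkException

def onOperationCanceled(): Nothing = {
  throw new SparkException(
    errorClass = "OPERATION_CANCELED",
    messageParameters = Map.empty, // the message template takes no parameters
    cause = null)
}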
@@ -49,15 +49,15 @@ class SparkSessionE2ESuite extends RemoteSparkSession {
     q1.onComplete {
       case Success(_) =>
         error = Some("q1 shouldn't have finished!")
-      case Failure(t) if t.getMessage.contains("cancelled") =>
+      case Failure(t) if t.getMessage.contains("OPERATION_CANCELED") =>
         q1Interrupted = true
       case Failure(t) =>
         error = Some("unexpected failure in q1: " + t.toString)
     }
     q2.onComplete {
       case Success(_) =>
         error = Some("q2 shouldn't have finished!")
-      case Failure(t) if t.getMessage.contains("cancelled") =>
+      case Failure(t) if t.getMessage.contains("OPERATION_CANCELED") =>
         q2Interrupted = true
       case Failure(t) =>
         error = Some("unexpected failure in q2: " + t.toString)
@@ -89,11 +89,11 @@ class SparkSessionE2ESuite extends RemoteSparkSession {
     val e1 = intercept[SparkException] {
       spark.range(10).map(n => { Thread.sleep(30.seconds.toMillis); n }).collect()
     }
-    assert(e1.getMessage.contains("cancelled"), s"Unexpected exception: $e1")
+    assert(e1.getMessage.contains("OPERATION_CANCELED"), s"Unexpected exception: $e1")
     val e2 = intercept[SparkException] {
       spark.range(10).map(n => { Thread.sleep(30.seconds.toMillis); n }).collect()
     }
-    assert(e2.getMessage.contains("cancelled"), s"Unexpected exception: $e2")
+    assert(e2.getMessage.contains("OPERATION_CANCELED"), s"Unexpected exception: $e2")
     finished = true
     assert(ThreadUtils.awaitResult(interruptor, 10.seconds))
   }
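The assertions above now match on the OPERATION_CANCELED error class instead of the free-form "cancelled" text. A hedged sketch of the client-side pattern under test; the Future wiring and the timing are illustrative assumptions, not code from this diff:

// Hedged sketch (assumes a Spark Connect `spark` session and an implicit
// ExecutionContext): a blocking query is interrupted and fails with a message
// carrying the OPERATION_CANCELED error class.
import scala.concurrent.{ExecutionContext, Future}
import scala.util.Failure
import org.apache.spark.SparkException

implicit val ec: ExecutionContext = ExecutionContext.global

val query = Future {
  spark.range(10).map(n => { Thread.sleep(30000); n }).collect()
}
// In a real test, wait until the query is actually running before interrupting.
spark.interruptAll()
query.onComplete {
  case Failure(e: SparkException) if e.getMessage.contains("OPERATION_CANCELED") =>
    println("query was canceled as expected")
  case other =>
    println(s"unexpected result: $other")
}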
@@ -0,0 +1,25 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.connect.execution

private[execution] case class CachedStreamResponse[T](
    // the actual cached response
    response: T,
    // index of the response in the response stream.
    // responses produced in the stream are numbered consecutively starting from 1.
    streamIndex: Long)
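A minimal usage sketch (the String payloads are illustrative): stream indexes are 1-based and consecutive, which is what lets a consumer resume from the last index it consumed:

// Illustrative only: cache two responses under their 1-based stream indexes.
val first = CachedStreamResponse[String]("response-1", streamIndex = 1)
val second = CachedStreamResponse[String]("response-2", streamIndex = 2)
assert(second.streamIndex == first.streamIndex + 1) // consecutive numbering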
@@ -0,0 +1,117 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.connect.execution

import io.grpc.stub.StreamObserver

import org.apache.spark.internal.Logging

/**
 * ExecuteGrpcResponseSender sends responses to the GRPC stream. It runs on the RPC thread, and
 * gets notified by ExecuteResponseObserver about available responses. It notifies the
 * ExecuteResponseObserver back about cached responses that can be removed after being sent out.
 * @param grpcObserver
 *   the GRPC request StreamObserver
 */
private[connect] class ExecuteGrpcResponseSender[T](grpcObserver: StreamObserver[T])
    extends Logging {

  private var detached = false

  /**
   * Detach this sender from the executionObserver. Called only from the executionObserver that
   * this sender is attached to. The executionObserver holds its lock, and needs to notify after
   * this call.
   */
  def detach(): Unit = {
    if (detached) {
      throw new IllegalStateException("ExecuteGrpcResponseSender already detached!")
    }
    detached = true
  }

  /**
   * Attach to the executionObserver, consume responses from it, and send them to grpcObserver.
   * @param lastConsumedStreamIndex
   *   the last index that was already consumed and sent. This sender will start from the index
   *   after that. 0 means start from the beginning (since the first response has index 1).
   * @return
   *   true if the execution was detached before the stream completed; the caller then needs to
   *   finish the grpcObserver stream. false if the stream finished; in that case the
   *   grpcObserver stream has already been completed.
   */
  def run(
      executionObserver: ExecuteResponseObserver[T],
      lastConsumedStreamIndex: Long): Boolean = {
    // Register to be notified about available responses.
    executionObserver.attachConsumer(this)

    var nextIndex = lastConsumedStreamIndex + 1
    var finished = false

    while (!finished) {
      var response: Option[CachedStreamResponse[T]] = None
      // Get the next available response.
      // Wait until either this sender got detached, the next response is ready,
      // or the stream is complete and all responses have already been sent.
      logDebug(s"Trying to get next response with index=$nextIndex.")
      executionObserver.synchronized {
        logDebug(s"Acquired lock.")
        while (!detached && response.isEmpty &&
            executionObserver.getLastIndex().forall(nextIndex <= _)) {
          logDebug(s"Try to get response with index=$nextIndex from observer.")
          response = executionObserver.getResponse(nextIndex)
          logDebug(s"Response index=$nextIndex from observer: ${response.isDefined}")
          // If the response is empty, release the executionObserver lock and wait to get
          // notified. The state of detached, response and lastIndex is changed under the
          // executionObserver lock, which notifies upon state change.
          if (response.isEmpty) {
            logDebug(s"Wait for response to become available.")
            executionObserver.wait()
            logDebug(s"Reacquired lock after waiting.")
          }
        }
        logDebug(
          s"Exiting loop: detached=$detached, response=$response, " +
            s"lastIndex=${executionObserver.getLastIndex()}")
      }

      // Send the next available response.
      if (detached) {
        // This sender got detached by the observer.
        logDebug(s"Detached from observer at index ${nextIndex - 1}. Complete stream.")
        finished = true
      } else if (response.isDefined) {
        // There is a response available to be sent.
        grpcObserver.onNext(response.get.response)
        logDebug(s"Sent response index=$nextIndex.")
        nextIndex += 1
      } else if (executionObserver.getLastIndex().forall(nextIndex > _)) {
        // The stream is finished and all responses have been sent.
        logDebug(s"Sent all responses up to index ${nextIndex - 1}.")
        executionObserver.getError() match {
          case Some(t) => grpcObserver.onError(t)
          case None => grpcObserver.onCompleted()
        }
        finished = true
      }
    }
    // Return true if this sender was detached, false if the stream finished.
    detached
  }
}
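A hedged wiring sketch of how a handler might drive this sender; streamToClient is a hypothetical helper, and only the attachConsumer/run/detach API below is defined in this diff:

// Hedged sketch: stream an execution's responses to a gRPC client.
import io.grpc.stub.StreamObserver

def streamToClient[T](
    executionObserver: ExecuteResponseObserver[T],
    grpcObserver: StreamObserver[T]): Unit = {
  val sender = new ExecuteGrpcResponseSender[T](grpcObserver)
  // 0 = nothing consumed yet, so the sender starts from index 1.
  val wasDetached = sender.run(executionObserver, lastConsumedStreamIndex = 0)
  if (wasDetached) {
    // Per run()'s contract, the caller must finish the gRPC stream after a detach;
    // a reattaching consumer would instead resume from the last consumed index.
    grpcObserver.onCompleted()
  }
  // If run() returned false, grpcObserver was already completed (or errored).
}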
@@ -0,0 +1,161 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.connect.execution

import scala.collection.mutable

import io.grpc.stub.StreamObserver

import org.apache.spark.internal.Logging

/**
 * This StreamObserver runs on the execution thread. The execution pushes responses to it, and it
 * caches them. ExecuteGrpcResponseSender is the consumer of the responses that
 * ExecuteResponseObserver "produces". It waits on the monitor of ExecuteResponseObserver, and
 * newly produced responses notify that monitor.
 * @see
 *   getResponse.
 *
 * ExecuteResponseObserver controls how long responses stay cached after being returned to the
 * consumer,
 * @see
 *   removeCachedResponses.
 *
 * A single ExecuteGrpcResponseSender can be attached to the ExecuteResponseObserver at a time.
 * Attaching a new one notifies the existing one that it was detached.
 * @see
 *   attachConsumer
 */
private[connect] class ExecuteResponseObserver[T]() extends StreamObserver[T] with Logging {

  /**
   * Cached responses produced by the execution. Map from response index -> response. Response
   * indexes are numbered consecutively starting from 1.
   */
  private val responses: mutable.Map[Long, CachedStreamResponse[T]] =
    new mutable.HashMap[Long, CachedStreamResponse[T]]()

  /** Cached error of the execution, if an error was thrown. */
  private var error: Option[Throwable] = None

  /**
   * If the execution stream is finished (completed or with error), the index of the final
   * response.
   */
  private var finalProducedIndex: Option[Long] = None // index of final response before completed.

  /** The index of the last response produced by execution. */
  private var lastProducedIndex: Long = 0 // first response will have index 1

  /**
   * Highest response index that was consumed. Keeps track of it to decide which responses need
   * to be cached, and to assert that all responses are consumed.
   */
  private var highestConsumedIndex: Long = 0

  /**
   * Consumer that waits for available responses. There can be only one at a time, @see
   * attachConsumer.
   */
  private var responseSender: Option[ExecuteGrpcResponseSender[T]] = None

  def onNext(r: T): Unit = synchronized {
    if (finalProducedIndex.nonEmpty) {
      throw new IllegalStateException("Stream onNext can't be called after stream completed")
    }
    lastProducedIndex += 1
    responses += ((lastProducedIndex, CachedStreamResponse[T](r, lastProducedIndex)))
    logDebug(s"Saved response with index=$lastProducedIndex")
    notifyAll()
  }

  def onError(t: Throwable): Unit = synchronized {
    if (finalProducedIndex.nonEmpty) {
      throw new IllegalStateException("Stream onError can't be called after stream completed")
    }
    error = Some(t)
    finalProducedIndex = Some(lastProducedIndex) // no responses to be sent after error.
    logDebug(s"Error. Last stream index is $lastProducedIndex.")
    notifyAll()
  }

  def onCompleted(): Unit = synchronized {
    if (finalProducedIndex.nonEmpty) {
      throw new IllegalStateException("Stream onCompleted can't be called after stream completed")
    }
    finalProducedIndex = Some(lastProducedIndex)
    logDebug(s"Completed. Last stream index is $lastProducedIndex.")
    notifyAll()
  }

  /** Attach a new consumer (ExecuteGrpcResponseSender). */
  def attachConsumer(newSender: ExecuteGrpcResponseSender[T]): Unit = synchronized {
    // Detach the current sender before attaching the new one.
    // this.synchronized() needs to be held while detaching a sender, and the detached sender
    // needs to be notified with notifyAll() afterwards.
    responseSender.foreach(_.detach())
    responseSender = Some(newSender)
    notifyAll() // notify the detached sender
  }

  /** Get response with a given index in the stream, if set. */
  def getResponse(index: Long): Option[CachedStreamResponse[T]] = synchronized {
    // We index stream responses from 1; getting a lower index would be invalid.
    assert(index >= 1)
    // It would be invalid if the consumer skipped a response.
    assert(index <= highestConsumedIndex + 1)
    val ret = responses.get(index)
    if (ret.isDefined) {
      if (index > highestConsumedIndex) highestConsumedIndex = index
      removeCachedResponses()
    }
    ret
  }

  /** Get the stream error if there is one, otherwise None. */
  def getError(): Option[Throwable] = synchronized {
    error
  }

  /** If the stream is finished, the index of the last response, otherwise None. */
  def getLastIndex(): Option[Long] = synchronized {
    finalProducedIndex
  }

  /** Returns whether the stream is finished. */
  def completed(): Boolean = synchronized {
    finalProducedIndex.isDefined
  }

  /** The consumer (ExecuteGrpcResponseSender) waits on the monitor of ExecuteResponseObserver. */
  private def notifyConsumer(): Unit = {
    notifyAll()
  }

  /**
   * Remove cached responses after they have been returned from getResponse, according to the
   * caching policy:
   *   - if the query is not reattachable, remove all responses up to and including
   *     highestConsumedIndex.
   */
  private def removeCachedResponses(): Unit = {
    var i = highestConsumedIndex
    while (i >= 1 && responses.get(i).isDefined) {
      responses.remove(i)
      i -= 1
    }
  }
}
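To see the producer/consumer handshake end to end, a hedged sketch; the String payloads, manual threads, and the stand-in StreamObserver are assumptions for illustration, only the observer/sender API comes from this diff:

// Hedged sketch: an execution thread produces three responses while a consumer
// thread streams them out and then completes.
import io.grpc.stub.StreamObserver

val observer = new ExecuteResponseObserver[String]()

// Stand-in for the gRPC stream; a real server would get this from grpc-java.
val grpcObserver = new StreamObserver[String] {
  override def onNext(v: String): Unit = println(s"sent: $v")
  override def onError(t: Throwable): Unit = println(s"error: $t")
  override def onCompleted(): Unit = println("completed")
}

val producer = new Thread(() => {
  (1 to 3).foreach(i => observer.onNext(s"response-$i")) // indexed 1..3 internally
  observer.onCompleted()
})
val consumer = new Thread(() => {
  // Blocks on the observer's monitor until responses arrive; returns false here
  // because the stream finishes rather than the sender being detached.
  val detached = new ExecuteGrpcResponseSender[String](grpcObserver).run(observer, 0)
  assert(!detached)
})
producer.start(); consumer.start()
producer.join(); consumer.join()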