core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -106,7 +106,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
*/
protected def askTracker[T: ClassTag](message: Any): T = {
try {
-    trackerEndpoint.askWithReply[T](message)
+    trackerEndpoint.askWithRetry[T](message)
} catch {
case e: Exception =>
logError("Error communicating with MapOutputTracker", e)
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -555,7 +555,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
SparkEnv.executorActorSystemName,
RpcAddress(host, port),
ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME)
-    Some(endpointRef.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump))
+    Some(endpointRef.askWithRetry[Array[ThreadStackTrace]](TriggerThreadDump))
}
} catch {
case e: Exception =>
core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -57,7 +57,7 @@ private[spark] class CoarseGrainedExecutorBackend(
logInfo("Connecting to driver: " + driverUrl)
rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref =>
driver = Some(ref)
-    ref.sendWithReply[RegisteredExecutor.type](
+    ref.ask[RegisteredExecutor.type](
RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls))
} onComplete {
case Success(msg) => Utils.tryLogNonFatalError {
@@ -154,7 +154,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
executorConf,
new SecurityManager(executorConf))
val driver = fetcher.setupEndpointRefByURI(driverUrl)
-    val props = driver.askWithReply[Seq[(String, String)]](RetrieveSparkProps) ++
+    val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps) ++
Seq[(String, String)](("spark.app.id", appId))
fetcher.shutdown()

core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -424,7 +424,7 @@ private[spark] class Executor(

val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId)
try {
-    val response = heartbeatReceiverRef.askWithReply[HeartbeatResponse](message)
+    val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](message)
if (response.reregisterBlockManager) {
logWarning("Told to re-register on heartbeat")
env.blockManager.reregister()
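For illustration only (not part of this patch): a minimal sketch of how the renamed APIs differ, assuming a hypothetical driver reference and hypothetical `Ping`/`Pong` message types. `ask` (formerly `sendWithReply`) sends once and returns a `Future`, while `askWithRetry` (formerly `askWithReply`) blocks and retries.

// Sketch only: Ping, Pong and driverRef are hypothetical stand-ins.
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
import scala.util.{Failure, Success}

import org.apache.spark.rpc.RpcEndpointRef

case class Ping(id: Int)
case class Pong(id: Int)

def example(driverRef: RpcEndpointRef): Unit = {
  // Fire-and-forget: no reply expected.
  driverRef.send(Ping(0))

  // Formerly sendWithReply: one attempt, asynchronous reply on a Future.
  val reply: Future[Pong] = driverRef.ask[Pong](Ping(1))
  reply.onComplete {
    case Success(pong) => println(s"got pong ${pong.id}")
    case Failure(e) => println(s"ask failed: $e")
  }

  // Formerly askWithReply: blocks and retries, so the receiver's handling
  // of Ping must be idempotent.
  val pong: Pong = driverRef.askWithRetry[Pong](Ping(2))
  println(s"got pong ${pong.id}")
}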
41 changes: 41 additions & 0 deletions core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.rpc

/**
* A callback that [[RpcEndpoint]] can use to send back a message or failure. It's thread-safe
* and can be called from any thread.
*/
private[spark] trait RpcCallContext {

/**
* Reply with a message to the sender. If the sender is [[RpcEndpoint]], its [[RpcEndpoint.receive]]
* will be called.
*/
def reply(response: Any): Unit

/**
* Report a failure to the sender.
*/
def sendFailure(e: Throwable): Unit

/**
* The sender of this message.
*/
def sender: RpcEndpointRef
}
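To show how this context is meant to be used, a minimal sketch (the `EchoEndpoint` below is hypothetical, not part of this patch) of an endpoint replying and reporting failures from `receiveAndReply`:

import org.apache.spark.SparkException
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

// Hypothetical endpoint: echoes strings back and fails on anything else.
class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    // Send a normal reply back to the asker.
    case s: String => context.reply(s)
    // Report a failure; the asker's Future fails with this exception.
    case other => context.sendFailure(new SparkException(s"unexpected message: $other"))
  }
}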
148 changes: 148 additions & 0 deletions core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.rpc

import org.apache.spark.SparkException

/**
* A factory class to create the [[RpcEnv]]. It must have an empty constructor so that it can be
* created using reflection.
*/
private[spark] trait RpcEnvFactory {

def create(config: RpcEnvConfig): RpcEnv
}

/**
* A trait that requires RpcEnv to send messages to it thread-safely.
*
* Thread-safety means processing of one message happens before processing of the next message by
* the same [[ThreadSafeRpcEndpoint]]. In other words, changes to internal fields of a
* [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the
* [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent.
*
* However, there is no guarantee that the same thread will be executing the same
* [[ThreadSafeRpcEndpoint]] for different messages.
*/
private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint


/**
* An end point for the RPC that defines what functions to trigger given a message.
*
* It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence.
*
* The life-cycle will be:
*
* constructor -> onStart -> receive* -> onStop
*
* Note: `receive` can be called concurrently. If you want `receive` to be thread-safe, please use
* [[ThreadSafeRpcEndpoint]].
*
* If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be
* invoked with the cause. If `onError` throws an error, [[RpcEnv]] will ignore it.
*/
private[spark] trait RpcEndpoint {

/**
* The [[RpcEnv]] that this [[RpcEndpoint]] is registered to.
*/
val rpcEnv: RpcEnv

/**
* The [[RpcEndpointRef]] of this [[RpcEndpoint]]. `self` will become valid when `onStart` is
* called, and `self` will become `null` when `onStop` is called.
*
* Note: before `onStart` is called, the [[RpcEndpoint]] has not yet been registered and there is
* no valid [[RpcEndpointRef]] for it, so don't call `self` until `onStart` has been called.
*/
final def self: RpcEndpointRef = {
require(rpcEnv != null, "rpcEnv has not been initialized")
rpcEnv.endpointRef(this)
}

/**
* Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply]]. If an unmatched
* message is received, [[SparkException]] will be thrown and sent to `onError`.
*/
def receive: PartialFunction[Any, Unit] = {
case _ => throw new SparkException(self + " does not implement 'receive'")
}

/**
* Process messages from [[RpcEndpointRef.ask]]. If an unmatched message is received,
* [[SparkException]] will be thrown and sent to `onError`.
*/
def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case _ => context.sendFailure(new SparkException(self + " won't reply anything"))
}

/**
* Invoked when any exception is thrown during handling messages.
*/
def onError(cause: Throwable): Unit = {
// By default, throw e and let RpcEnv handle it
throw cause
}

/**
* Invoked before [[RpcEndpoint]] starts to handle any message.
*/
def onStart(): Unit = {
// By default, do nothing.
}

/**
* Invoked when [[RpcEndpoint]] is stopping.
*/
def onStop(): Unit = {
// By default, do nothing.
}

/**
* Invoked when `remoteAddress` is connected to the current node.
*/
def onConnected(remoteAddress: RpcAddress): Unit = {
// By default, do nothing.
}

/**
* Invoked when `remoteAddress` is lost.
*/
def onDisconnected(remoteAddress: RpcAddress): Unit = {
// By default, do nothing.
}

/**
* Invoked when some network error happens in the connection between the current node and
* `remoteAddress`.
*/
def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
// By default, do nothing.
}

/**
* A convenience method to stop the [[RpcEndpoint]].
*/
final def stop(): Unit = {
val _self = self
if (_self != null) {
rpcEnv.stop(_self)
}
}
}
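To tie the life-cycle and the [[ThreadSafeRpcEndpoint]] contract together, a sketch (the `CounterEndpoint`, `Increment` and `GetCount` below are hypothetical, not part of this patch). Because it extends ThreadSafeRpcEndpoint, the mutable `count` field needs no volatile or synchronization:

import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}

// Hypothetical messages, for illustration only.
case object Increment
case object GetCount

class CounterEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
  // Messages are processed one at a time, so plain mutable state is safe here.
  private var count = 0

  override def onStart(): Unit = {
    // `self` is valid from this point until onStop.
  }

  override def receive: PartialFunction[Any, Unit] = {
    case Increment => count += 1 // from RpcEndpointRef.send, no reply
  }

  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case GetCount => context.reply(count) // from RpcEndpointRef.ask
  }

  override def onStop(): Unit = {
    // Clean up resources here; `self` is null by now.
  }
}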
119 changes: 119 additions & 0 deletions core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.rpc

import scala.concurrent.{Await, Future}
import scala.concurrent.duration.FiniteDuration
import scala.reflect.ClassTag

import org.apache.spark.util.RpcUtils
import org.apache.spark.{SparkException, Logging, SparkConf}

/**
* A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe.
*/
private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
extends Serializable with Logging {

private[this] val maxRetries = RpcUtils.numRetries(conf)
private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf)
private[this] val defaultAskTimeout = RpcUtils.askTimeout(conf)

/**
* Return the address for the [[RpcEndpointRef]].
*/
def address: RpcAddress

def name: String

/**
* Sends a one-way asynchronous message. Fire-and-forget semantics.
*/
def send(message: Any): Unit

/**
* Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and return a [[Future]] to
* receive the reply within the specified timeout.
*
* This method only sends the message once and never retries.
*/
def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T]

/**
* Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and return a [[Future]] to
* receive the reply within a default timeout.
*
* This method only sends the message once and never retries.
*/
def ask[T: ClassTag](message: Any): Future[T] = ask(message, defaultAskTimeout)

/**
* Send a message to the corresponding [[RpcEndpoint]] and get its result within a default
* timeout, or throw a SparkException if this fails even after the default number of retries.
* The default `timeout` will be used in every attempt of calling `ask`. Because this
* method retries, the message handling on the receiver side should be idempotent.
*
* Note: this is a blocking action which may cost a lot of time, so don't call it in a message
* loop of [[RpcEndpoint]].
*
* @param message the message to send
* @tparam T type of the reply message
* @return the reply message from the corresponding [[RpcEndpoint]]
*/
def askWithRetry[T: ClassTag](message: Any): T = askWithRetry(message, defaultAskTimeout)

/**
* Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and get its result within a
* specified timeout, or throw a SparkException if this fails even after the specified number of
* retries. `timeout` will be used in every attempt of calling `ask`. Because this method
* retries, the message handling on the receiver side should be idempotent.
*
* Note: this is a blocking action which may cost a lot of time, so don't call it in a message
* loop of [[RpcEndpoint]].
*
* @param message the message to send
* @param timeout the timeout duration
* @tparam T type of the reply message
* @return the reply message from the corresponding [[RpcEndpoint]]
*/
def askWithRetry[T: ClassTag](message: Any, timeout: FiniteDuration): T = {
// TODO: Consider removing multiple attempts
var attempts = 0
var lastException: Exception = null
while (attempts < maxRetries) {
attempts += 1
try {
val future = ask[T](message, timeout)
val result = Await.result(future, timeout)
if (result == null) {
throw new SparkException("Actor returned null")
}
return result
} catch {
case ie: InterruptedException => throw ie
case e: Exception =>
lastException = e
logWarning(s"Error sending message [message = $message] in $attempts attempts", e)
}
Thread.sleep(retryWaitMs)
}

throw new SparkException(
s"Error sending message [message = $message]", lastException)
}
}
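Finally, a caller-side sketch putting the pieces together. It reuses the hypothetical `CounterEndpoint` from above and assumes `RpcEnv` exposes a `setupEndpoint(name, endpoint)` registration method (defined elsewhere in this patch):

import scala.concurrent.Await
import scala.concurrent.duration._

import org.apache.spark.rpc.RpcEnv

// Sketch only: rpcEnv and CounterEndpoint are assumptions, not part of this file.
def demo(rpcEnv: RpcEnv): Unit = {
  val ref = rpcEnv.setupEndpoint("counter", new CounterEndpoint(rpcEnv))

  ref.send(Increment) // one-way, handled by receive
  ref.send(Increment)

  // One attempt with an explicit timeout; the reply arrives on a Future.
  val future = ref.ask[Int](GetCount, 10.seconds)
  println(Await.result(future, 10.seconds))

  // Blocking with retries; safe because GetCount is idempotent.
  println(ref.askWithRetry[Int](GetCount))
}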