core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala

@@ -51,7 +51,7 @@ private[spark] class CoarseGrainedExecutorBackend(
     userClassPath: Seq[URL],
     env: SparkEnv,
     resourcesFileOpt: Option[String])
-  extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging {
+  extends IsolatedRpcEndpoint with ExecutorBackend with Logging {
 
   private implicit val formats = DefaultFormats
 
16 changes: 16 additions & 0 deletions core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
@@ -146,3 +146,19 @@ private[spark] trait RpcEndpoint {
  * [[ThreadSafeRpcEndpoint]] for different messages.
  */
 private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint
+
+/**
+ * An endpoint that uses a dedicated thread pool for delivering messages.
+ */
+private[spark] trait IsolatedRpcEndpoint extends RpcEndpoint {
+
+  /**
+   * How many threads to use for delivering messages. By default, use a single thread.
+   *
+   * Note that requesting more than one thread means that the endpoint should be able to handle
+   * messages arriving from many threads at once, and all the things that entails (including
+   * messages being delivered to the endpoint out of order).
+   */
+  def threadCount(): Int = 1
Contributor (@squito) commented:

I'm trying to wrap my head around what happens if you create an IsolatedRpcEndpoint with threadCount() > 1, given the code in Inbox which checks for inheritance from ThreadSafeRpcEndpoint:

    if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {

I guess if you expect one endpoint to be served by multiple threads, it makes sense you'd want Inbox.enableConcurrent = true, and you'd have to make your endpoint safe for that -- but it's worth a comment here at least.
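
For context, here is a minimal, runnable model of the gate being discussed. This is an editorial sketch with simplified names, not the actual Inbox code; only the isInstanceOf check mirrors the real logic:

    // Simplified model of how Inbox decides whether concurrent delivery is allowed.
    // ThreadSafeRpcEndpoint never enables it; plain endpoints enable it on OnStart.
    object EnableConcurrentModel {
      trait RpcEndpoint
      trait ThreadSafeRpcEndpoint extends RpcEndpoint

      class Inbox(endpoint: RpcEndpoint) {
        private var enableConcurrent = false
        private var stopped = false

        // Mirrors the OnStart case in Inbox.process().
        def onStart(): Unit = synchronized {
          if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint] && !stopped) {
            enableConcurrent = true
          }
        }

        def canProcessConcurrently: Boolean = synchronized(enableConcurrent)
      }

      def main(args: Array[String]): Unit = {
        val plain = new Inbox(new RpcEndpoint {})
        val safe = new Inbox(new ThreadSafeRpcEndpoint {})
        plain.onStart()
        safe.onStart()
        println(plain.canProcessConcurrently)  // true
        println(safe.canProcessConcurrently)   // false
      }
    }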

Member commented:

I have the same question as @squito. How do you deal with ThreadSafeRpcEndpoint?

We could set Inbox.enableConcurrent = false when threadCount() > 1, but then the extra threads would be wasted.

Contributor (author) commented:

I already updated the comment. ThreadSafeRpcEndpoint is irrelevant here. You may even extend both if you want; but if you do that, either it does nothing (because the thread pool has a single thread), or you're doing it wrong (because the thread pool has multiple threads but you just want one).

So it's pointless to mix in both traits.
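
To make that takeaway concrete, here is a hypothetical endpoint (name and body invented for illustration; not part of this PR) that opts into a dedicated pool. With threadCount() > 1 it must tolerate concurrent, possibly out-of-order delivery on its own:

    // Hypothetical example only. Assumes it is compiled inside Spark, since
    // IsolatedRpcEndpoint is private[spark].
    private[spark] class ChattyEndpoint(override val rpcEnv: RpcEnv)
      extends IsolatedRpcEndpoint with Logging {

      // Four delivery threads: receive() may now run concurrently, and messages
      // may be handled out of order, as the new scaladoc warns.
      override def threadCount(): Int = 4

      override def receive: PartialFunction[Any, Unit] = {
        case msg => logInfo(s"handling $msg")  // handler must be thread-safe
      }
    }

Mixing ThreadSafeRpcEndpoint into such a class would buy nothing, which is the author's point.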


+
+}
130 changes: 38 additions & 92 deletions core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -17,20 +17,16 @@

 package org.apache.spark.rpc.netty
 
-import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, CountDownLatch}
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.JavaConverters._
 import scala.concurrent.Promise
 import scala.util.control.NonFatal
 
-import org.apache.spark.{SparkConf, SparkContext, SparkException}
+import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config.EXECUTOR_ID
-import org.apache.spark.internal.config.Network.RPC_NETTY_DISPATCHER_NUM_THREADS
 import org.apache.spark.network.client.RpcResponseCallback
 import org.apache.spark.rpc._
-import org.apache.spark.util.ThreadUtils

 /**
  * A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s).

@@ -40,20 +36,23 @@ import org.apache.spark.util.ThreadUtils
  */
 private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging {
 
-  private class EndpointData(
-      val name: String,
-      val endpoint: RpcEndpoint,
-      val ref: NettyRpcEndpointRef) {
-    val inbox = new Inbox(ref, endpoint)
-  }
 
-  private val endpoints: ConcurrentMap[String, EndpointData] =
-    new ConcurrentHashMap[String, EndpointData]
+  private val endpoints: ConcurrentMap[String, MessageLoop] =
+    new ConcurrentHashMap[String, MessageLoop]
   private val endpointRefs: ConcurrentMap[RpcEndpoint, RpcEndpointRef] =
     new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]
 
-  // Track the receivers whose inboxes may contain messages.
-  private val receivers = new LinkedBlockingQueue[EndpointData]
+  private val shutdownLatch = new CountDownLatch(1)
+  private lazy val sharedLoop = new SharedMessageLoop(nettyEnv.conf, this, numUsableCores)
+
+  private def getMessageLoop(name: String, endpoint: RpcEndpoint): MessageLoop = {
+    endpoint match {
+      case e: IsolatedRpcEndpoint =>
+        new DedicatedMessageLoop(name, e, this)
+      case _ =>
+        sharedLoop.register(name, endpoint)
+        sharedLoop
+    }
+  }

   /**
    * True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced

@@ -69,13 +68,11 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging {
       if (stopped) {
         throw new IllegalStateException("RpcEnv has been stopped")
       }
-      if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) {
+      if (endpoints.putIfAbsent(name, getMessageLoop(name, endpoint)) != null) {
         throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
       }
-      val data = endpoints.get(name)
-      endpointRefs.put(data.endpoint, data.ref)
-      receivers.offer(data)  // for the OnStart message
     }
+    endpointRefs.put(endpoint, endpointRef)
     endpointRef
   }

@@ -85,10 +82,9 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging {

   // Should be idempotent
   private def unregisterRpcEndpoint(name: String): Unit = {
-    val data = endpoints.remove(name)
-    if (data != null) {
-      data.inbox.stop()
-      receivers.offer(data)  // for the OnStop message
+    val loop = endpoints.remove(name)
+    if (loop != null) {
+      loop.unregister(name)
     }
     // Don't clean `endpointRefs` here because it's possible that some messages are being processed
     // now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via
@@ -155,14 +151,13 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging {
       message: InboxMessage,
       callbackIfStopped: (Exception) => Unit): Unit = {
     val error = synchronized {
-      val data = endpoints.get(endpointName)
+      val loop = endpoints.get(endpointName)
       if (stopped) {
         Some(new RpcEnvStoppedException())
-      } else if (data == null) {
+      } else if (loop == null) {
         Some(new SparkException(s"Could not find $endpointName."))
       } else {
-        data.inbox.post(message)
-        receivers.offer(data)
+        loop.post(endpointName, message)
         None
       }
     }
@@ -177,15 +172,23 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging {
       }
       stopped = true
     }
-    // Stop all endpoints. This will queue all endpoints for processing by the message loops.
-    endpoints.keySet().asScala.foreach(unregisterRpcEndpoint)
-    // Enqueue a message that tells the message loops to stop.
-    receivers.offer(PoisonPill)
-    threadpool.shutdown()
+    var stopSharedLoop = false
+    endpoints.asScala.foreach { case (name, loop) =>
+      unregisterRpcEndpoint(name)
+      if (!loop.isInstanceOf[SharedMessageLoop]) {
+        loop.stop()
+      } else {
+        stopSharedLoop = true
+      }
+    }
+    if (stopSharedLoop) {
+      sharedLoop.stop()
+    }
+    shutdownLatch.countDown()
   }
 
   def awaitTermination(): Unit = {
-    threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)
+    shutdownLatch.await()
   }
 
   /**
Expand All @@ -194,61 +197,4 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte
   def verify(name: String): Boolean = {
     endpoints.containsKey(name)
   }
-
-  private def getNumOfThreads(conf: SparkConf): Int = {
-    val availableCores =
-      if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
-
-    val modNumThreads = conf.get(RPC_NETTY_DISPATCHER_NUM_THREADS)
-      .getOrElse(math.max(2, availableCores))
-
-    conf.get(EXECUTOR_ID).map { id =>
-      val role = if (id == SparkContext.DRIVER_IDENTIFIER) "driver" else "executor"
-      conf.getInt(s"spark.$role.rpc.netty.dispatcher.numThreads", modNumThreads)
-    }.getOrElse(modNumThreads)
-  }
-
-  /** Thread pool used for dispatching messages. */
-  private val threadpool: ThreadPoolExecutor = {
-    val numThreads = getNumOfThreads(nettyEnv.conf)
-    val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
-    for (i <- 0 until numThreads) {
-      pool.execute(new MessageLoop)
-    }
-    pool
-  }
-
-  /** Message loop used for dispatching messages. */
-  private class MessageLoop extends Runnable {
-    override def run(): Unit = {
-      try {
-        while (true) {
-          try {
-            val data = receivers.take()
-            if (data == PoisonPill) {
-              // Put PoisonPill back so that other MessageLoops can see it.
-              receivers.offer(PoisonPill)
-              return
-            }
-            data.inbox.process(Dispatcher.this)
-          } catch {
-            case NonFatal(e) => logError(e.getMessage, e)
-          }
-        }
-      } catch {
-        case _: InterruptedException => // exit
-        case t: Throwable =>
-          try {
-            // Re-submit a MessageLoop so that Dispatcher will still work if
-            // UncaughtExceptionHandler decides to not kill JVM.
-            threadpool.execute(new MessageLoop)
-          } finally {
-            throw t
-          }
-      }
-    }
-  }
-
-  /** A poison endpoint that indicates MessageLoop should exit its message loop. */
-  private val PoisonPill = new EndpointData(null, null, null)
 }
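
The new SharedMessageLoop and DedicatedMessageLoop types live in a new MessageLoop.scala that is not shown in this excerpt. Inferred purely from the call sites above (so these signatures are assumptions, not the PR's actual file), the shared surface looks roughly like:

    // Shape inferred from Dispatcher's call sites: post, unregister, stop, plus
    // SharedMessageLoop.register and the two constructors used in getMessageLoop.
    private[netty] abstract class MessageLoop {
      // Deliver a message to the named endpoint's inbox.
      def post(endpointName: String, message: InboxMessage): Unit
      // Queue the endpoint's OnStop processing and drop it from this loop.
      def unregister(name: String): Unit
      // Shut down this loop's delivery threads.
      def stop(): Unit
    }

Each IsolatedRpcEndpoint gets its own DedicatedMessageLoop (and thread pool), while every other endpoint registers with the single lazily created SharedMessageLoop, which replaces the old receivers queue and the dispatcher-event-loop thread pool.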
6 changes: 2 additions & 4 deletions core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
@@ -54,9 +54,7 @@ private[netty] case class RemoteProcessConnectionError(cause: Throwable, remoteAddress: RpcAddress)
 /**
  * An inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely.
  */
-private[netty] class Inbox(
-    val endpointRef: NettyRpcEndpointRef,
-    val endpoint: RpcEndpoint)
+private[netty] class Inbox(val endpointName: String, val endpoint: RpcEndpoint)
   extends Logging {
 
   inbox =>  // Give this an alias so we can use it more clearly in closures.
@@ -195,7 +193,7 @@ private[netty] class Inbox(
    * Exposed for testing.
    */
   protected def onDrop(message: InboxMessage): Unit = {
-    logWarning(s"Drop $message because $endpointRef is stopped")
+    logWarning(s"Drop $message because endpoint $endpointName is stopped")
   }
 
   /**