Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/SparkEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ object SparkEnv extends Logging {

val systemName = if (isDriver) driverSystemName else executorSystemName
val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port.getOrElse(-1), conf,
securityManager, clientMode = !isDriver)
securityManager, numUsableCores, !isDriver)

// Figure out which port RpcEnv actually bound to in case the original port is 0 or occupied.
if (isDriver) {
Expand Down
6 changes: 4 additions & 2 deletions core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ private[spark] object RpcEnv {
conf: SparkConf,
securityManager: SecurityManager,
clientMode: Boolean = false): RpcEnv = {
create(name, host, host, port, conf, securityManager, clientMode)
create(name, host, host, port, conf, securityManager, 0, clientMode)
}

def create(
Expand All @@ -50,9 +50,10 @@ private[spark] object RpcEnv {
port: Int,
conf: SparkConf,
securityManager: SecurityManager,
numUsableCores: Int,
clientMode: Boolean): RpcEnv = {
val config = RpcEnvConfig(conf, name, bindAddress, advertiseAddress, port, securityManager,
clientMode)
numUsableCores, clientMode)
new NettyRpcEnvFactory().create(config)
}
}
Expand Down Expand Up @@ -201,4 +202,5 @@ private[spark] case class RpcEnvConfig(
advertiseAddress: String,
port: Int,
securityManager: SecurityManager,
numUsableCores: Int,
clientMode: Boolean)
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ import org.apache.spark.util.ThreadUtils

/**
* A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s).
*
* @param numUsableCores Number of CPU cores allocated to the process, for sizing the thread pool.
* If 0, will consider the available CPUs on the host.
*/
private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging {

private class EndpointData(
val name: String,
Expand Down Expand Up @@ -189,8 +192,10 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {

/** Thread pool used for dispatching messages. */
private val threadpool: ThreadPoolExecutor = {
val availableCores =
if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads",
math.max(2, Runtime.getRuntime.availableProcessors()))
math.max(2, availableCores))
val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
for (i <- 0 until numThreads) {
pool.execute(new MessageLoop)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,15 @@ private[netty] class NettyRpcEnv(
val conf: SparkConf,
javaSerializerInstance: JavaSerializerInstance,
host: String,
securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
securityManager: SecurityManager,
numUsableCores: Int) extends RpcEnv(conf) with Logging {

private[netty] val transportConf = SparkTransportConf.fromSparkConf(
conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"),
"rpc",
conf.getInt("spark.rpc.io.threads", 0))

private val dispatcher: Dispatcher = new Dispatcher(this)
private val dispatcher: Dispatcher = new Dispatcher(this, numUsableCores)

private val streamManager = new NettyStreamManager(this)

Expand Down Expand Up @@ -448,7 +449,7 @@ private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
val nettyEnv =
new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress,
config.securityManager)
config.securityManager, config.numUsableCores)
if (!config.clientMode) {
val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
nettyEnv.startServer(config.bindAddress, actualPort)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar {
port: Int,
clientMode: Boolean = false): RpcEnv = {
val config = RpcEnvConfig(conf, "test", "localhost", "localhost", port,
new SecurityManager(conf), clientMode)
new SecurityManager(conf), 0, clientMode)
new NettyRpcEnvFactory().create(config)
}

Expand All @@ -47,7 +47,7 @@ class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar {
test("advertise address different from bind address") {
val sparkConf = new SparkConf()
val config = RpcEnvConfig(sparkConf, "test", "localhost", "example.com", 0,
new SecurityManager(sparkConf), false)
new SecurityManager(sparkConf), 0, false)
val env = new NettyRpcEnvFactory().create(config)
try {
assert(env.address.hostPort.startsWith("example.com:"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,10 @@ private[spark] class ApplicationMaster(
}

private def runExecutorLauncher(securityMgr: SecurityManager): Unit = {
rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, -1, sparkConf, securityMgr,
clientMode = true)
val hostname = Utils.localHostName
val amCores = sparkConf.get(AM_CORES)
rpcEnv = RpcEnv.create("sparkYarnAM", hostname, hostname, -1, sparkConf, securityMgr,
amCores, true)
val driverRef = waitForSparkDriver()
addAmIpFilter()
registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption("spark.driver.appUIAddress"),
Expand Down