diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 3196c1ece15eb..1095d7285781f 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -243,7 +243,7 @@ object SparkEnv extends Logging {
 
     val systemName = if (isDriver) driverSystemName else executorSystemName
     val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port.getOrElse(-1), conf,
-      securityManager, clientMode = !isDriver)
+      securityManager, numUsableCores, !isDriver)
 
     // Figure out which port RpcEnv actually bound to in case the original port is 0 or occupied.
     if (isDriver) {
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 530743c03640b..de2cc56bc6b16 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -40,7 +40,7 @@ private[spark] object RpcEnv {
       conf: SparkConf,
       securityManager: SecurityManager,
       clientMode: Boolean = false): RpcEnv = {
-    create(name, host, host, port, conf, securityManager, clientMode)
+    create(name, host, host, port, conf, securityManager, 0, clientMode)
   }
 
   def create(
@@ -50,9 +50,10 @@ private[spark] object RpcEnv {
       port: Int,
       conf: SparkConf,
       securityManager: SecurityManager,
+      numUsableCores: Int,
       clientMode: Boolean): RpcEnv = {
     val config = RpcEnvConfig(conf, name, bindAddress, advertiseAddress, port, securityManager,
-      clientMode)
+      numUsableCores, clientMode)
     new NettyRpcEnvFactory().create(config)
   }
 }
@@ -201,4 +202,5 @@ private[spark] case class RpcEnvConfig(
     advertiseAddress: String,
     port: Int,
     securityManager: SecurityManager,
+    numUsableCores: Int,
     clientMode: Boolean)
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
index a02cf30a5d831..adde3293185cd 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -32,8 +32,11 @@ import org.apache.spark.util.ThreadUtils
 
 /**
  * A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s).
+ *
+ * @param numUsableCores Number of CPU cores allocated to the process, for sizing the thread pool.
+ *                       If 0, will consider the available CPUs on the host.
  */
-private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
+private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging {
 
   private class EndpointData(
       val name: String,
@@ -189,8 +192,10 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
 
   /** Thread pool used for dispatching messages. */
   private val threadpool: ThreadPoolExecutor = {
+    val availableCores =
+      if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
     val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads",
-      math.max(2, Runtime.getRuntime.availableProcessors()))
+      math.max(2, availableCores))
     val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
     for (i <- 0 until numThreads) {
       pool.execute(new MessageLoop)
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index b316e5443f639..4d88a1be6042f 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -44,14 +44,15 @@ private[netty] class NettyRpcEnv(
     val conf: SparkConf,
     javaSerializerInstance: JavaSerializerInstance,
     host: String,
-    securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
+    securityManager: SecurityManager,
+    numUsableCores: Int) extends RpcEnv(conf) with Logging {
 
   private[netty] val transportConf = SparkTransportConf.fromSparkConf(
     conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"),
     "rpc",
     conf.getInt("spark.rpc.io.threads", 0))
 
-  private val dispatcher: Dispatcher = new Dispatcher(this)
+  private val dispatcher: Dispatcher = new Dispatcher(this, numUsableCores)
 
   private val streamManager = new NettyStreamManager(this)
 
@@ -448,7 +449,7 @@ private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
       new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
     val nettyEnv =
       new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress,
-        config.securityManager)
+        config.securityManager, config.numUsableCores)
     if (!config.clientMode) {
       val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
         nettyEnv.startServer(config.bindAddress, actualPort)
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
index 2b1bce4d208f6..777163709bbf5 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
@@ -31,7 +31,7 @@ class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar {
       port: Int,
       clientMode: Boolean = false): RpcEnv = {
     val config = RpcEnvConfig(conf, "test", "localhost", "localhost", port,
-      new SecurityManager(conf), clientMode)
+      new SecurityManager(conf), 0, clientMode)
     new NettyRpcEnvFactory().create(config)
   }
 
@@ -47,7 +47,7 @@ class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar {
   test("advertise address different from bind address") {
     val sparkConf = new SparkConf()
     val config = RpcEnvConfig(sparkConf, "test", "localhost", "example.com", 0,
-      new SecurityManager(sparkConf), false)
+      new SecurityManager(sparkConf), 0, false)
     val env = new NettyRpcEnvFactory().create(config)
     try {
       assert(env.address.hostPort.startsWith("example.com:"))
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 4868180569778..00b53e29ed102 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -431,8 +431,10 @@ private[spark] class ApplicationMaster(
   }
 
   private def runExecutorLauncher(securityMgr: SecurityManager): Unit = {
-    rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, -1, sparkConf, securityMgr,
-      clientMode = true)
+    val hostname = Utils.localHostName
+    val amCores = sparkConf.get(AM_CORES)
+    rpcEnv = RpcEnv.create("sparkYarnAM", hostname, hostname, -1, sparkConf, securityMgr,
+      amCores, true)
     val driverRef = waitForSparkDriver()
     addAmIpFilter()
     registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption("spark.driver.appUIAddress"),
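
Note on the Dispatcher change above: before this patch, the dispatcher sized its thread pool from Runtime.getRuntime.availableProcessors(), so an executor allocated only a few cores on a large host would still spawn one dispatcher thread per host CPU. The patch threads the allocated core count (numUsableCores) through RpcEnvConfig into the Dispatcher, with 0 meaning "unknown, fall back to the host". Below is a minimal, standalone sketch of that sizing rule; DispatcherSizing and numDispatcherThreads are hypothetical names used only for illustration and are not part of the patch, and confOverride stands in for the "spark.rpc.netty.dispatcher.numThreads" setting.

// Standalone sketch (hypothetical object, not part of the patch) of the
// thread-pool sizing rule the patched Dispatcher applies.
object DispatcherSizing {

  // numUsableCores: cores allocated to this process by the cluster manager,
  // or 0 when the allocation is unknown (e.g. local deployments, tests).
  // confOverride: stands in for "spark.rpc.netty.dispatcher.numThreads".
  def numDispatcherThreads(numUsableCores: Int, confOverride: Option[Int] = None): Int = {
    // Prefer the allocated core count; fall back to what the host reports.
    val availableCores =
      if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
    // An explicitly configured thread count still wins, exactly as in the patch.
    confOverride.getOrElse(math.max(2, availableCores))
  }

  def main(args: Array[String]): Unit = {
    // Executor allocated 4 cores on a 64-core host: pool gets 4 threads, not 64.
    println(numDispatcherThreads(numUsableCores = 4))
    // Allocation unknown (0): fall back to the host's core count, floored at 2.
    println(numDispatcherThreads(numUsableCores = 0))
    // Explicit configuration takes precedence over both: 8 threads.
    println(numDispatcherThreads(numUsableCores = 4, confOverride = Some(8)))
  }
}

The same reasoning explains the YARN hunk: the ApplicationMaster now passes its own core allocation (AM_CORES) into RpcEnv.create, so the AM's dispatcher pool matches what YARN actually granted it.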