@@ -22,7 +22,6 @@ import java.nio.ByteBuffer
2222import java .util .concurrent ._
2323import javax .annotation .concurrent .GuardedBy
2424
25- import scala .collection .JavaConverters ._
2625import scala .collection .mutable
2726import scala .concurrent .{Future , Promise }
2827import scala .reflect .ClassTag
@@ -45,23 +44,25 @@ private[netty] class NettyRpcEnv(
4544 host : String ,
4645 securityManager : SecurityManager ) extends RpcEnv (conf) with Logging {
4746
48- private val transportConf =
49- SparkTransportConf .fromSparkConf(conf, conf.getInt(" spark.rpc.io.threads" , 0 ))
47+ // Override numConnectionsPerPeer to 1 for RPC.
48+ private val transportConf = SparkTransportConf .fromSparkConf(
49+ conf.clone.set(" spark.shuffle.io.numConnectionsPerPeer" , " 1" ),
50+ conf.getInt(" spark.rpc.io.threads" , 0 ))
5051
5152 private val dispatcher : Dispatcher = new Dispatcher (this )
5253
5354 private val transportContext =
5455 new TransportContext (transportConf, new NettyRpcHandler (dispatcher, this ))
5556
5657 private val clientFactory = {
57- val bootstraps : Seq [TransportClientBootstrap ] =
58+ val bootstraps : java.util. List [TransportClientBootstrap ] =
5859 if (securityManager.isAuthenticationEnabled()) {
59- Seq (new SaslClientBootstrap (transportConf, " " , securityManager,
60+ java.util. Arrays .asList (new SaslClientBootstrap (transportConf, " " , securityManager,
6061 securityManager.isSaslEncryptionEnabled()))
6162 } else {
62- Nil
63+ java.util. Collections .emptyList[ TransportClientBootstrap ]
6364 }
64- transportContext.createClientFactory(bootstraps.asJava )
65+ transportContext.createClientFactory(bootstraps)
6566 }
6667
6768 val timeoutScheduler = ThreadUtils .newDaemonSingleThreadScheduledExecutor(" netty-rpc-env-timeout" )
@@ -71,7 +72,7 @@ private[netty] class NettyRpcEnv(
7172 // TODO: a non-blocking TransportClientFactory.createClient in future
7273 private val clientConnectionExecutor = ThreadUtils .newDaemonCachedThreadPool(
7374 " netty-rpc-connection" ,
74- conf.getInt(" spark.rpc.connect.threads" , 256 ))
75+ conf.getInt(" spark.rpc.connect.threads" , 64 ))
7576
7677 @ volatile private var server : TransportServer = _
7778
@@ -83,7 +84,8 @@ private[netty] class NettyRpcEnv(
8384 java.util.Collections .emptyList()
8485 }
8586 server = transportContext.createServer(port, bootstraps)
86- dispatcher.registerRpcEndpoint(IDVerifier .NAME , new IDVerifier (this , dispatcher))
87+ dispatcher.registerRpcEndpoint(
88+ RpcEndpointVerifier .NAME , new RpcEndpointVerifier (this , dispatcher))
8789 }
8890
8991 override lazy val address : RpcAddress = {
@@ -96,11 +98,11 @@ private[netty] class NettyRpcEnv(
9698 }
9799
98100 def asyncSetupEndpointRefByURI (uri : String ): Future [RpcEndpointRef ] = {
99- val addr = NettyRpcAddress (uri)
101+ val addr = RpcEndpointAddress (uri)
100102 val endpointRef = new NettyRpcEndpointRef (conf, addr, this )
101- val idVerifierRef =
102- new NettyRpcEndpointRef ( conf, NettyRpcAddress (addr.host, addr.port, IDVerifier .NAME ), this )
103- idVerifierRef .ask[Boolean ](ID (endpointRef.name)).flatMap { find =>
103+ val verifier = new NettyRpcEndpointRef (
104+ conf, RpcEndpointAddress (addr.host, addr.port, RpcEndpointVerifier .NAME ), this )
105+ verifier .ask[Boolean ](RpcEndpointVerifier . CheckExistence (endpointRef.name)).flatMap { find =>
104106 if (find) {
105107 Future .successful(endpointRef)
106108 } else {
@@ -117,16 +119,18 @@ private[netty] class NettyRpcEnv(
117119 private [netty] def send (message : RequestMessage ): Unit = {
118120 val remoteAddr = message.receiver.address
119121 if (remoteAddr == address) {
122+ // Message to a local RPC endpoint.
120123 val promise = Promise [Any ]()
121124 dispatcher.postLocalMessage(message, promise)
122125 promise.future.onComplete {
123126 case Success (response) =>
124127 val ack = response.asInstanceOf [Ack ]
125- logDebug (s " Receive ack from ${ack.sender}" )
128+ logTrace (s " Received ack from ${ack.sender}" )
126129 case Failure (e) =>
127130 logError(s " Exception when sending $message" , e)
128131 }(ThreadUtils .sameThread)
129132 } else {
133+ // Message to a remote RPC endpoint.
130134 try {
131135 // `createClient` will block if it cannot find a known connection, so we should run it in
132136 // clientConnectionExecutor
@@ -204,11 +208,10 @@ private[netty] class NettyRpcEnv(
204208 }
205209 })
206210 } catch {
207- case e : RejectedExecutionException => {
211+ case e : RejectedExecutionException =>
208212 if (! promise.tryFailure(e)) {
209213 logWarning(s " Ignore failure " , e)
210214 }
211- }
212215 }
213216 }
214217 promise.future
@@ -231,7 +234,7 @@ private[netty] class NettyRpcEnv(
231234 }
232235
233236 override def uriOf (systemName : String , address : RpcAddress , endpointName : String ): String =
234- new NettyRpcAddress (address.host, address.port, endpointName).toString
237+ new RpcEndpointAddress (address.host, address.port, endpointName).toString
235238
236239 override def shutdown (): Unit = {
237240 cleanup()
@@ -310,9 +313,9 @@ private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf)
310313
311314 @ transient @ volatile private var nettyEnv : NettyRpcEnv = _
312315
313- @ transient @ volatile private var _address : NettyRpcAddress = _
316+ @ transient @ volatile private var _address : RpcEndpointAddress = _
314317
315- def this (conf : SparkConf , _address : NettyRpcAddress , nettyEnv : NettyRpcEnv ) {
318+ def this (conf : SparkConf , _address : RpcEndpointAddress , nettyEnv : NettyRpcEnv ) {
316319 this (conf)
317320 this ._address = _address
318321 this .nettyEnv = nettyEnv
@@ -322,7 +325,7 @@ private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf)
322325
323326 private def readObject (in : ObjectInputStream ): Unit = {
324327 in.defaultReadObject()
325- _address = in.readObject().asInstanceOf [NettyRpcAddress ]
328+ _address = in.readObject().asInstanceOf [RpcEndpointAddress ]
326329 nettyEnv = NettyRpcEnv .currentEnv.value
327330 }
328331
@@ -406,49 +409,37 @@ private[netty] class NettyRpcHandler(
406409 private type RemoteEnvAddress = RpcAddress
407410
408411 // Store all client addresses and their NettyRpcEnv addresses.
412+ // TODO: Is this even necessary?
409413 @ GuardedBy (" this" )
410414 private val remoteAddresses = new mutable.HashMap [ClientAddress , RemoteEnvAddress ]()
411415
412- // Store the connections from other NettyRpcEnv addresses. We need to keep track of the connection
413- // count because `TransportClientFactory.createClient` will create multiple connections
414- // (at most `spark.shuffle.io.numConnectionsPerPeer` connections) and randomly select a connection
415- // to send the message. See `TransportClientFactory.createClient` for more details.
416- @ GuardedBy (" this" )
417- private val remoteConnectionCount = new mutable.HashMap [RemoteEnvAddress , Int ]()
418-
419416 override def receive (
420417 client : TransportClient , message : Array [Byte ], callback : RpcResponseCallback ): Unit = {
421418 val requestMessage = nettyEnv.deserialize[RequestMessage ](message)
422- val addr = client.getChannel() .remoteAddress().asInstanceOf [InetSocketAddress ]
419+ val addr = client.getChannel.remoteAddress().asInstanceOf [InetSocketAddress ]
423420 assert(addr != null )
424421 val remoteEnvAddress = requestMessage.senderAddress
425422 val clientAddr = RpcAddress (addr.getHostName, addr.getPort)
426- val broadcastMessage : Option [RemoteProcessConnected ] =
427- synchronized {
428- // If the first connection to a remote RpcEnv is found, we should broadcast "Associated"
429- if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) {
430- // clientAddr connects at the first time
431- val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0 )
432- // Increase the connection number of remoteEnvAddress
433- remoteConnectionCount.put(remoteEnvAddress, count + 1 )
434- if (count == 0 ) {
435- // This is the first connection, so fire "Associated"
436- Some (RemoteProcessConnected (remoteEnvAddress))
437- } else {
438- None
439- }
440- } else {
441- None
442- }
423+
424+ // TODO: Can we add connection callback (channel registered) to the underlying framework?
425+ // A variable to track whether we should dispatch the RemoteProcessConnected message.
426+ var dispatchRemoteProcessConnected = false
427+ synchronized {
428+ if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) {
429+ // clientAddr connects at the first time, fire "RemoteProcessConnected"
430+ dispatchRemoteProcessConnected = true
443431 }
444- broadcastMessage.foreach(dispatcher.postToAll)
432+ }
433+ if (dispatchRemoteProcessConnected) {
434+ dispatcher.postToAll(RemoteProcessConnected (remoteEnvAddress))
435+ }
445436 dispatcher.postRemoteMessage(requestMessage, callback)
446437 }
447438
448439 override def getStreamManager : StreamManager = new OneForOneStreamManager
449440
450441 override def exceptionCaught (cause : Throwable , client : TransportClient ): Unit = {
451- val addr = client.getChannel() .remoteAddress().asInstanceOf [InetSocketAddress ]
442+ val addr = client.getChannel.remoteAddress().asInstanceOf [InetSocketAddress ]
452443 if (addr != null ) {
453444 val clientAddr = RpcAddress (addr.getHostName, addr.getPort)
454445 val broadcastMessage =
@@ -469,34 +460,21 @@ private[netty] class NettyRpcHandler(
469460 }
470461
471462 override def connectionTerminated (client : TransportClient ): Unit = {
472- val addr = client.getChannel() .remoteAddress().asInstanceOf [InetSocketAddress ]
463+ val addr = client.getChannel.remoteAddress().asInstanceOf [InetSocketAddress ]
473464 if (addr != null ) {
474465 val clientAddr = RpcAddress (addr.getHostName, addr.getPort)
475- val broadcastMessage =
476- synchronized {
477- // If the last connection to a remote RpcEnv is terminated, we should broadcast
478- // "Disassociated"
479- remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>
480- remoteAddresses -= clientAddr
481- val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0 )
482- assert(count != 0 , " remoteAddresses and remoteConnectionCount are not consistent" )
483- if (count - 1 == 0 ) {
484- // We lost all clients, so clean up and fire "Disassociated"
485- remoteConnectionCount.remove(remoteEnvAddress)
486- Some (RemoteProcessDisconnected (remoteEnvAddress))
487- } else {
488- // Decrease the connection number of remoteEnvAddress
489- remoteConnectionCount.put(remoteEnvAddress, count - 1 )
490- None
491- }
492- }
466+ val messageOpt : Option [RemoteProcessDisconnected ] =
467+ synchronized {
468+ remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>
469+ remoteAddresses -= clientAddr
470+ Some (RemoteProcessDisconnected (remoteEnvAddress))
493471 }
494- broadcastMessage.foreach(dispatcher.postToAll)
472+ }
473+ messageOpt.foreach(dispatcher.postToAll)
495474 } else {
496475 // If the channel is closed before connecting, its remoteAddress will be null. In this case,
497476 // we can ignore it since we don't fire "Associated".
498477 // See java.net.Socket.getRemoteSocketAddress
499478 }
500479 }
501-
502480}
0 commit comments