Commit cf2e0ae

[SPARK-11096] Post-hoc review Netty based RPC implementation - round 2
A few more changes:

1. Renamed IDVerifier -> RpcEndpointVerifier.
2. Renamed NettyRpcAddress -> RpcEndpointAddress.
3. Simplified NettyRpcHandler a bit by removing the connection count tracking. This is OK because I now force spark.shuffle.io.numConnectionsPerPeer to 1.
4. Reduced spark.rpc.connect.threads to 64. It would be great to eventually remove this extra thread pool.
5. Minor cleanup & documentation.

Author: Reynold Xin <rxin@databricks.com>

Closes #9112 from rxin/SPARK-11096.

1 parent 615cc85 commit cf2e0ae

File tree

7 files changed: +81 -107 lines changed

core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala

Lines changed: 0 additions & 9 deletions
@@ -93,15 +93,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
     defaultLookupTimeout.awaitResult(asyncSetupEndpointRefByURI(uri))
   }
 
-  /**
-   * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`
-   * asynchronously.
-   */
-  def asyncSetupEndpointRef(
-      systemName: String, address: RpcAddress, endpointName: String): Future[RpcEndpointRef] = {
-    asyncSetupEndpointRefByURI(uriOf(systemName, address, endpointName))
-  }
-
   /**
   * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`.
   * This is a blocking action.
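Note: the deleted asyncSetupEndpointRef was a pure convenience wrapper, so any caller can recover the same behavior by composing the two primitives that remain on RpcEnv. A minimal sketch (the system name, address, and endpoint name are hypothetical):

    import scala.concurrent.Future
    import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv}

    // Equivalent to the removed wrapper: build the URI, then resolve it asynchronously.
    def asyncSetupEndpointRef(
        rpcEnv: RpcEnv,
        systemName: String,
        address: RpcAddress,
        endpointName: String): Future[RpcEndpointRef] = {
      rpcEnv.asyncSetupEndpointRefByURI(rpcEnv.uriOf(systemName, address, endpointName))
    }

    // e.g. asyncSetupEndpointRef(rpcEnv, "sparkDriver", RpcAddress("localhost", 7077), "tracker")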

core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala

Lines changed: 5 additions & 2 deletions
@@ -29,6 +29,9 @@ import org.apache.spark.network.client.RpcResponseCallback
 import org.apache.spark.rpc._
 import org.apache.spark.util.ThreadUtils
 
+/**
+ * A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s).
+ */
 private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
 
   private class EndpointData(
@@ -42,7 +45,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
   private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]
 
   // Track the receivers whose inboxes may contain messages.
-  private val receivers = new LinkedBlockingQueue[EndpointData]()
+  private val receivers = new LinkedBlockingQueue[EndpointData]
 
   /**
   * True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced
@@ -52,7 +55,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
   private var stopped = false
 
   def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
-    val addr = new NettyRpcAddress(nettyEnv.address.host, nettyEnv.address.port, name)
+    val addr = new RpcEndpointAddress(nettyEnv.address.host, nettyEnv.address.port, name)
     val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
     synchronized {
       if (stopped) {
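For context, user code reaches this dispatcher through RpcEnv.setupEndpoint, which lands in registerRpcEndpoint above. A minimal endpoint sketch (the endpoint and its message protocol are hypothetical):

    import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

    // A trivial request/reply endpoint; its inbox is drained by the Dispatcher.
    class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case msg: String => context.reply(msg) // echo the payload back to the sender
      }
    }

    // val echoRef = rpcEnv.setupEndpoint("echo", new EchoEndpoint(rpcEnv))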

core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala

Lines changed: 46 additions & 68 deletions
@@ -22,7 +22,6 @@ import java.nio.ByteBuffer
 import java.util.concurrent._
 import javax.annotation.concurrent.GuardedBy
 
-import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.concurrent.{Future, Promise}
 import scala.reflect.ClassTag
@@ -45,23 +44,25 @@ private[netty] class NettyRpcEnv(
     host: String,
     securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
 
-  private val transportConf =
-    SparkTransportConf.fromSparkConf(conf, conf.getInt("spark.rpc.io.threads", 0))
+  // Override numConnectionsPerPeer to 1 for RPC.
+  private val transportConf = SparkTransportConf.fromSparkConf(
+    conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"),
+    conf.getInt("spark.rpc.io.threads", 0))
 
   private val dispatcher: Dispatcher = new Dispatcher(this)
 
   private val transportContext =
     new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this))
 
   private val clientFactory = {
-    val bootstraps: Seq[TransportClientBootstrap] =
+    val bootstraps: java.util.List[TransportClientBootstrap] =
       if (securityManager.isAuthenticationEnabled()) {
-        Seq(new SaslClientBootstrap(transportConf, "", securityManager,
+        java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager,
          securityManager.isSaslEncryptionEnabled()))
      } else {
-        Nil
+        java.util.Collections.emptyList[TransportClientBootstrap]
      }
-    transportContext.createClientFactory(bootstraps.asJava)
+    transportContext.createClientFactory(bootstraps)
   }
 
   val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout")
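The clone-then-set here is deliberate: mutating the shared SparkConf would leak the numConnectionsPerPeer override into any other component reading the same key. A small sketch of the isolation (values hypothetical):

    import org.apache.spark.SparkConf

    val conf = new SparkConf().set("spark.shuffle.io.numConnectionsPerPeer", "3")
    val rpcConf = conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1")

    assert(conf.get("spark.shuffle.io.numConnectionsPerPeer") == "3")    // shuffle side unchanged
    assert(rpcConf.get("spark.shuffle.io.numConnectionsPerPeer") == "1") // RPC side pinned to one connection per peer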
@@ -71,7 +72,7 @@ private[netty] class NettyRpcEnv(
   // TODO: a non-blocking TransportClientFactory.createClient in future
   private val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
     "netty-rpc-connection",
-    conf.getInt("spark.rpc.connect.threads", 256))
+    conf.getInt("spark.rpc.connect.threads", 64))
 
   @volatile private var server: TransportServer = _
 
@@ -83,7 +84,8 @@ private[netty] class NettyRpcEnv(
       java.util.Collections.emptyList()
     }
     server = transportContext.createServer(port, bootstraps)
-    dispatcher.registerRpcEndpoint(IDVerifier.NAME, new IDVerifier(this, dispatcher))
+    dispatcher.registerRpcEndpoint(
+      RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
   }
 
   override lazy val address: RpcAddress = {
@@ -96,11 +98,11 @@ private[netty] class NettyRpcEnv(
   }
 
   def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
-    val addr = NettyRpcAddress(uri)
+    val addr = RpcEndpointAddress(uri)
     val endpointRef = new NettyRpcEndpointRef(conf, addr, this)
-    val idVerifierRef =
-      new NettyRpcEndpointRef(conf, NettyRpcAddress(addr.host, addr.port, IDVerifier.NAME), this)
-    idVerifierRef.ask[Boolean](ID(endpointRef.name)).flatMap { find =>
+    val verifier = new NettyRpcEndpointRef(
+      conf, RpcEndpointAddress(addr.host, addr.port, RpcEndpointVerifier.NAME), this)
+    verifier.ask[Boolean](RpcEndpointVerifier.CheckExistence(endpointRef.name)).flatMap { find =>
       if (find) {
         Future.successful(endpointRef)
       } else {
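The lookup stays fully asynchronous: ask returns a Future[Boolean], and flatMap converts the existence check into either a usable ref or a failed future without blocking the caller's thread. A generic sketch of the pattern (all names hypothetical):

    import scala.concurrent.Future
    import scala.concurrent.ExecutionContext.Implicits.global

    // Stand-in for verifier.ask[Boolean](RpcEndpointVerifier.CheckExistence(name)).
    def endpointExists(name: String): Future[Boolean] = Future.successful(name == "known")

    def setupRef(name: String): Future[String] =
      endpointExists(name).flatMap { found =>
        if (found) Future.successful(s"ref to $name") // endpoint exists: hand back the ref
        else Future.failed(new IllegalStateException(s"Cannot find endpoint: $name"))
      }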
@@ -117,16 +119,18 @@ private[netty] class NettyRpcEnv(
   private[netty] def send(message: RequestMessage): Unit = {
     val remoteAddr = message.receiver.address
     if (remoteAddr == address) {
+      // Message to a local RPC endpoint.
       val promise = Promise[Any]()
       dispatcher.postLocalMessage(message, promise)
       promise.future.onComplete {
         case Success(response) =>
           val ack = response.asInstanceOf[Ack]
-          logDebug(s"Receive ack from ${ack.sender}")
+          logTrace(s"Received ack from ${ack.sender}")
         case Failure(e) =>
           logError(s"Exception when sending $message", e)
       }(ThreadUtils.sameThread)
     } else {
+      // Message to a remote RPC endpoint.
       try {
         // `createClient` will block if it cannot find a known connection, so we should run it in
         // clientConnectionExecutor
@@ -204,11 +208,10 @@ private[netty] class NettyRpcEnv(
           }
         })
       } catch {
-        case e: RejectedExecutionException => {
+        case e: RejectedExecutionException =>
          if (!promise.tryFailure(e)) {
            logWarning(s"Ignore failure", e)
          }
-        }
       }
     }
     promise.future
@@ -231,7 +234,7 @@ private[netty] class NettyRpcEnv(
   }
 
   override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String =
-    new NettyRpcAddress(address.host, address.port, endpointName).toString
+    new RpcEndpointAddress(address.host, address.port, endpointName).toString
 
   override def shutdown(): Unit = {
     cleanup()
@@ -310,9 +313,9 @@ private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf)
 
   @transient @volatile private var nettyEnv: NettyRpcEnv = _
 
-  @transient @volatile private var _address: NettyRpcAddress = _
+  @transient @volatile private var _address: RpcEndpointAddress = _
 
-  def this(conf: SparkConf, _address: NettyRpcAddress, nettyEnv: NettyRpcEnv) {
+  def this(conf: SparkConf, _address: RpcEndpointAddress, nettyEnv: NettyRpcEnv) {
     this(conf)
     this._address = _address
     this.nettyEnv = nettyEnv
@@ -322,7 +325,7 @@ private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf)
 
   private def readObject(in: ObjectInputStream): Unit = {
     in.defaultReadObject()
-    _address = in.readObject().asInstanceOf[NettyRpcAddress]
+    _address = in.readObject().asInstanceOf[RpcEndpointAddress]
     nettyEnv = NettyRpcEnv.currentEnv.value
   }
 
@@ -406,49 +409,37 @@ private[netty] class NettyRpcHandler(
   private type RemoteEnvAddress = RpcAddress
 
   // Store all client addresses and their NettyRpcEnv addresses.
+  // TODO: Is this even necessary?
   @GuardedBy("this")
   private val remoteAddresses = new mutable.HashMap[ClientAddress, RemoteEnvAddress]()
 
-  // Store the connections from other NettyRpcEnv addresses. We need to keep track of the connection
-  // count because `TransportClientFactory.createClient` will create multiple connections
-  // (at most `spark.shuffle.io.numConnectionsPerPeer` connections) and randomly select a connection
-  // to send the message. See `TransportClientFactory.createClient` for more details.
-  @GuardedBy("this")
-  private val remoteConnectionCount = new mutable.HashMap[RemoteEnvAddress, Int]()
-
   override def receive(
       client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = {
     val requestMessage = nettyEnv.deserialize[RequestMessage](message)
-    val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
+    val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
     assert(addr != null)
     val remoteEnvAddress = requestMessage.senderAddress
     val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
-    val broadcastMessage: Option[RemoteProcessConnected] =
-      synchronized {
-        // If the first connection to a remote RpcEnv is found, we should broadcast "Associated"
-        if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) {
-          // clientAddr connects at the first time
-          val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0)
-          // Increase the connection number of remoteEnvAddress
-          remoteConnectionCount.put(remoteEnvAddress, count + 1)
-          if (count == 0) {
-            // This is the first connection, so fire "Associated"
-            Some(RemoteProcessConnected(remoteEnvAddress))
-          } else {
-            None
-          }
-        } else {
-          None
-        }
+
+    // TODO: Can we add connection callback (channel registered) to the underlying framework?
+    // A variable to track whether we should dispatch the RemoteProcessConnected message.
+    var dispatchRemoteProcessConnected = false
+    synchronized {
+      if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) {
+        // clientAddr connects at the first time, fire "RemoteProcessConnected"
+        dispatchRemoteProcessConnected = true
      }
-    broadcastMessage.foreach(dispatcher.postToAll)
+    }
+    if (dispatchRemoteProcessConnected) {
+      dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress))
+    }
     dispatcher.postRemoteMessage(requestMessage, callback)
   }
 
   override def getStreamManager: StreamManager = new OneForOneStreamManager
 
   override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = {
-    val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
+    val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
     if (addr != null) {
       val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
       val broadcastMessage =
@@ -469,34 +460,21 @@ private[netty] class NettyRpcHandler(
   }
 
   override def connectionTerminated(client: TransportClient): Unit = {
-    val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
+    val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
     if (addr != null) {
       val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
-      val broadcastMessage =
-        synchronized {
-          // If the last connection to a remote RpcEnv is terminated, we should broadcast
-          // "Disassociated"
-          remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>
-            remoteAddresses -= clientAddr
-            val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0)
-            assert(count != 0, "remoteAddresses and remoteConnectionCount are not consistent")
-            if (count - 1 == 0) {
-              // We lost all clients, so clean up and fire "Disassociated"
-              remoteConnectionCount.remove(remoteEnvAddress)
-              Some(RemoteProcessDisconnected(remoteEnvAddress))
-            } else {
-              // Decrease the connection number of remoteEnvAddress
-              remoteConnectionCount.put(remoteEnvAddress, count - 1)
-              None
-            }
-          }
+      val messageOpt: Option[RemoteProcessDisconnected] =
+        synchronized {
+          remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>
+            remoteAddresses -= clientAddr
+            Some(RemoteProcessDisconnected(remoteEnvAddress))
          }
-      broadcastMessage.foreach(dispatcher.postToAll)
+        }
+      messageOpt.foreach(dispatcher.postToAll)
     } else {
       // If the channel is closed before connecting, its remoteAddress will be null. In this case,
       // we can ignore it since we don't fire "Associated".
       // See java.net.Socket.getRemoteSocketAddress
     }
   }
-
 }
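With spark.shuffle.io.numConnectionsPerPeer forced to 1, each remote env has at most one inbound connection, so the first successful put into remoteAddresses means "connected" and removal means "disconnected"; the old per-env reference counting became dead weight. A standalone sketch of the simplified bookkeeping (types reduced to strings for illustration):

    import scala.collection.mutable

    object ConnectionTracker {
      // clientAddr -> remoteEnvAddress; one connection per peer makes counting unnecessary.
      private val remoteAddresses = new mutable.HashMap[String, String]()

      def onReceive(clientAddr: String, remoteEnvAddress: String): Option[String] = synchronized {
        // First insert for this client address == the first (and only) connection.
        if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) Some(remoteEnvAddress) else None
      }

      def onTerminated(clientAddr: String): Option[String] = synchronized {
        remoteAddresses.remove(clientAddr) // Some(addr) => fire "disconnected" for addr
      }
    }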

core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala renamed to core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala

Lines changed: 18 additions & 14 deletions
@@ -17,40 +17,44 @@
 
 package org.apache.spark.rpc.netty
 
-import java.net.URI
-
 import org.apache.spark.SparkException
 import org.apache.spark.rpc.RpcAddress
 
-private[netty] case class NettyRpcAddress(host: String, port: Int, name: String) {
+/**
+ * An address identifier for an RPC endpoint.
+ *
+ * @param host host name of the remote process.
+ * @param port the port the remote RPC environment binds to.
+ * @param name name of the remote endpoint.
+ */
+private[netty] case class RpcEndpointAddress(host: String, port: Int, name: String) {
 
   def toRpcAddress: RpcAddress = RpcAddress(host, port)
 
   override val toString = s"spark://$name@$host:$port"
 }
 
-private[netty] object NettyRpcAddress {
+private[netty] object RpcEndpointAddress {
 
-  def apply(sparkUrl: String): NettyRpcAddress = {
+  def apply(sparkUrl: String): RpcEndpointAddress = {
     try {
-      val uri = new URI(sparkUrl)
+      val uri = new java.net.URI(sparkUrl)
       val host = uri.getHost
       val port = uri.getPort
       val name = uri.getUserInfo
       if (uri.getScheme != "spark" ||
-        host == null ||
-        port < 0 ||
-        name == null ||
-        (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null
-        uri.getFragment != null ||
-        uri.getQuery != null) {
+          host == null ||
+          port < 0 ||
+          name == null ||
+          (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null
+          uri.getFragment != null ||
+          uri.getQuery != null) {
        throw new SparkException("Invalid Spark URL: " + sparkUrl)
      }
-      NettyRpcAddress(host, port, name)
+      RpcEndpointAddress(host, port, name)
    } catch {
      case e: java.net.URISyntaxException =>
        throw new SparkException("Invalid Spark URL: " + sparkUrl, e)
    }
  }
-
 }
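For reference, the address round-trips through its spark:// string form, and apply rejects anything that is not exactly scheme + name + host + port. A quick sketch (ignoring the private[netty] visibility, and using the same values as the test below):

    val addr = RpcEndpointAddress("localhost", 12345, "test")
    assert(addr.toString == "spark://test@localhost:12345") // render
    assert(RpcEndpointAddress(addr.toString) == addr)       // parse back
    // RpcEndpointAddress("http://localhost:12345")         // non-spark scheme: throws SparkException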

core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala renamed to core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala

Lines changed: 11 additions & 10 deletions
@@ -14,26 +14,27 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+
 package org.apache.spark.rpc.netty
 
 import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}
 
 /**
- * A message used to ask the remote [[IDVerifier]] if an [[RpcEndpoint]] exists
- */
-private[netty] case class ID(name: String)
-
-/**
- * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if a [[RpcEndpoint]] exists in this [[RpcEnv]]
+ * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if an [[RpcEndpoint]] exists.
+ *
+ * This is used when setting up a remote endpoint reference.
 */
-private[netty] class IDVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher)
+private[netty] class RpcEndpointVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher)
  extends RpcEndpoint {
 
   override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
-    case ID(name) => context.reply(dispatcher.verify(name))
+    case RpcEndpointVerifier.CheckExistence(name) => context.reply(dispatcher.verify(name))
   }
 }
 
-private[netty] object IDVerifier {
-  val NAME = "id-verifier"
+private[netty] object RpcEndpointVerifier {
+  val NAME = "endpoint-verifier"
+
+  /** A message used to ask the remote [[RpcEndpointVerifier]] if an [[RpcEndpoint]] exists. */
+  case class CheckExistence(name: String)
 }

core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ import org.apache.spark.SparkFunSuite
 class NettyRpcAddressSuite extends SparkFunSuite {
 
   test("toString") {
-    val addr = NettyRpcAddress("localhost", 12345, "test")
+    val addr = RpcEndpointAddress("localhost", 12345, "test")
     assert(addr.toString === "spark://test@localhost:12345")
   }
 
