9 changes: 0 additions & 9 deletions core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -93,15 +93,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
defaultLookupTimeout.awaitResult(asyncSetupEndpointRefByURI(uri))
}

/**
* Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`
* asynchronously.
*/
def asyncSetupEndpointRef(
systemName: String, address: RpcAddress, endpointName: String): Future[RpcEndpointRef] = {
asyncSetupEndpointRefByURI(uriOf(systemName, address, endpointName))
}

/**
* Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`.
* This is a blocking action.
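The overload deleted above was a one-line wrapper. A caller that still holds a (systemName, address, endpointName) triple can recover the same behavior from the two methods that remain; a minimal sketch, assuming only the `uriOf` and `asyncSetupEndpointRefByURI` signatures visible in this diff (and a call site inside org.apache.spark, since RpcEnv is private[spark]):

import scala.concurrent.Future
import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv}

// Equivalent of the removed helper: build the spark://name@host:port URL
// via uriOf, then resolve the endpoint reference asynchronously.
def asyncSetupEndpointRef(
    rpcEnv: RpcEnv,
    systemName: String,
    address: RpcAddress,
    endpointName: String): Future[RpcEndpointRef] = {
  rpcEnv.asyncSetupEndpointRefByURI(rpcEnv.uriOf(systemName, address, endpointName))
}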
core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -29,6 +29,9 @@ import org.apache.spark.network.client.RpcResponseCallback
import org.apache.spark.rpc._
import org.apache.spark.util.ThreadUtils

/**
* A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s).
*/
private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {

private class EndpointData(
@@ -42,7 +45,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]

// Track the receivers whose inboxes may contain messages.
private val receivers = new LinkedBlockingQueue[EndpointData]()
private val receivers = new LinkedBlockingQueue[EndpointData]

/**
* True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced
@@ -52,7 +55,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
private var stopped = false

def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
val addr = new NettyRpcAddress(nettyEnv.address.host, nettyEnv.address.port, name)
val addr = new RpcEndpointAddress(nettyEnv.address.host, nettyEnv.address.port, name)
val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
synchronized {
if (stopped) {
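For context on what the new doc comment means by routing: every endpoint registers under a name, and the dispatcher delivers each incoming message to the inbox of the endpoint with the matching name. A sketch of a registrable endpoint, using the RpcEndpoint API that appears later in this diff; the echo endpoint and its registration name are hypothetical:

import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

// A trivial endpoint the dispatcher could route to. receiveAndReply handles
// messages sent with ask; the context carries the reply channel to the sender.
class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case msg: String => context.reply(msg)
  }
}

// Registering it yields a ref addressed as spark://echo@host:port:
//   val ref = dispatcher.registerRpcEndpoint("echo", new EchoEndpoint(nettyEnv))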
114 changes: 46 additions & 68 deletions core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -22,7 +22,6 @@ import java.nio.ByteBuffer
import java.util.concurrent._
import javax.annotation.concurrent.GuardedBy

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.concurrent.{Future, Promise}
import scala.reflect.ClassTag
@@ -45,23 +44,25 @@ private[netty] class NettyRpcEnv(
host: String,
securityManager: SecurityManager) extends RpcEnv(conf) with Logging {

private val transportConf =
SparkTransportConf.fromSparkConf(conf, conf.getInt("spark.rpc.io.threads", 0))
// Override numConnectionsPerPeer to 1 for RPC.
private val transportConf = SparkTransportConf.fromSparkConf(
conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"),
conf.getInt("spark.rpc.io.threads", 0))

private val dispatcher: Dispatcher = new Dispatcher(this)

private val transportContext =
new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this))

private val clientFactory = {
val bootstraps: Seq[TransportClientBootstrap] =
val bootstraps: java.util.List[TransportClientBootstrap] =
if (securityManager.isAuthenticationEnabled()) {
Seq(new SaslClientBootstrap(transportConf, "", securityManager,
java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager,
securityManager.isSaslEncryptionEnabled()))
} else {
Nil
java.util.Collections.emptyList[TransportClientBootstrap]
}
transportContext.createClientFactory(bootstraps.asJava)
transportContext.createClientFactory(bootstraps)
}

val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout")
@@ -71,7 +72,7 @@ private[netty] class NettyRpcEnv(
// TODO: a non-blocking TransportClientFactory.createClient in future
private val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
"netty-rpc-connection",
conf.getInt("spark.rpc.connect.threads", 256))
conf.getInt("spark.rpc.connect.threads", 64))

@volatile private var server: TransportServer = _

@@ -83,7 +84,8 @@ private[netty] class NettyRpcEnv(
java.util.Collections.emptyList()
}
server = transportContext.createServer(port, bootstraps)
dispatcher.registerRpcEndpoint(IDVerifier.NAME, new IDVerifier(this, dispatcher))
dispatcher.registerRpcEndpoint(
RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
}

override lazy val address: RpcAddress = {
@@ -96,11 +98,11 @@
}

def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
val addr = NettyRpcAddress(uri)
val addr = RpcEndpointAddress(uri)
val endpointRef = new NettyRpcEndpointRef(conf, addr, this)
val idVerifierRef =
new NettyRpcEndpointRef(conf, NettyRpcAddress(addr.host, addr.port, IDVerifier.NAME), this)
idVerifierRef.ask[Boolean](ID(endpointRef.name)).flatMap { find =>
val verifier = new NettyRpcEndpointRef(
conf, RpcEndpointAddress(addr.host, addr.port, RpcEndpointVerifier.NAME), this)
verifier.ask[Boolean](RpcEndpointVerifier.CheckExistence(endpointRef.name)).flatMap { find =>
if (find) {
Future.successful(endpointRef)
} else {
@@ -117,16 +119,18 @@ private[netty] class NettyRpcEnv(
private[netty] def send(message: RequestMessage): Unit = {
val remoteAddr = message.receiver.address
if (remoteAddr == address) {
// Message to a local RPC endpoint.
val promise = Promise[Any]()
dispatcher.postLocalMessage(message, promise)
promise.future.onComplete {
case Success(response) =>
val ack = response.asInstanceOf[Ack]
logDebug(s"Receive ack from ${ack.sender}")
logTrace(s"Received ack from ${ack.sender}")
case Failure(e) =>
logError(s"Exception when sending $message", e)
}(ThreadUtils.sameThread)
} else {
// Message to a remote RPC endpoint.
try {
// `createClient` will block if it cannot find a known connection, so we should run it in
// clientConnectionExecutor
@@ -204,11 +208,10 @@ private[netty] class NettyRpcEnv(
}
})
} catch {
case e: RejectedExecutionException => {
case e: RejectedExecutionException =>
if (!promise.tryFailure(e)) {
logWarning(s"Ignore failure", e)
}
}
}
}
promise.future
@@ -231,7 +234,7 @@ private[netty] class NettyRpcEnv(
}

override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String =
new NettyRpcAddress(address.host, address.port, endpointName).toString
new RpcEndpointAddress(address.host, address.port, endpointName).toString

override def shutdown(): Unit = {
cleanup()
@@ -310,9 +313,9 @@ private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf)

@transient @volatile private var nettyEnv: NettyRpcEnv = _

@transient @volatile private var _address: NettyRpcAddress = _
@transient @volatile private var _address: RpcEndpointAddress = _

def this(conf: SparkConf, _address: NettyRpcAddress, nettyEnv: NettyRpcEnv) {
def this(conf: SparkConf, _address: RpcEndpointAddress, nettyEnv: NettyRpcEnv) {
this(conf)
this._address = _address
this.nettyEnv = nettyEnv
@@ -322,7 +325,7 @@

private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject()
_address = in.readObject().asInstanceOf[NettyRpcAddress]
_address = in.readObject().asInstanceOf[RpcEndpointAddress]
nettyEnv = NettyRpcEnv.currentEnv.value
}

@@ -406,49 +409,37 @@ private[netty] class NettyRpcHandler(
private type RemoteEnvAddress = RpcAddress

// Store all client addresses and their NettyRpcEnv addresses.
// TODO: Is this even necessary?
@GuardedBy("this")
private val remoteAddresses = new mutable.HashMap[ClientAddress, RemoteEnvAddress]()
[Review comment, Contributor (author)]: @zsxwing is this necessary at all? Maybe it is OK to not have this and always reconstruct the RpcAddress.

[Reply, Member]: Yes, it's necessary. RemoteEnvAddress is the port that an RpcEnv listens to, but ClientAddress is a random port that the client opens. For these network events, RpcEndpoint needs to use RemoteEnvAddress.


// Store the connections from other NettyRpcEnv addresses. We need to keep track of the connection
// count because `TransportClientFactory.createClient` will create multiple connections
// (at most `spark.shuffle.io.numConnectionsPerPeer` connections) and randomly select a connection
// to send the message. See `TransportClientFactory.createClient` for more details.
@GuardedBy("this")
private val remoteConnectionCount = new mutable.HashMap[RemoteEnvAddress, Int]()

override def receive(
client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = {
val requestMessage = nettyEnv.deserialize[RequestMessage](message)
val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
assert(addr != null)
val remoteEnvAddress = requestMessage.senderAddress
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
val broadcastMessage: Option[RemoteProcessConnected] =
synchronized {
// If the first connection to a remote RpcEnv is found, we should broadcast "Associated"
if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) {
// clientAddr connects at the first time
val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0)
// Increase the connection number of remoteEnvAddress
remoteConnectionCount.put(remoteEnvAddress, count + 1)
if (count == 0) {
// This is the first connection, so fire "Associated"
Some(RemoteProcessConnected(remoteEnvAddress))
} else {
None
}
} else {
None
}

// TODO: Can we add connection callback (channel registered) to the underlying framework?
// A variable to track whether we should dispatch the RemoteProcessConnected message.
var dispatchRemoteProcessConnected = false
synchronized {
if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) {
// clientAddr connects at the first time, fire "RemoteProcessConnected"
dispatchRemoteProcessConnected = true
}
broadcastMessage.foreach(dispatcher.postToAll)
}
if (dispatchRemoteProcessConnected) {
dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress))
}
dispatcher.postRemoteMessage(requestMessage, callback)
}

override def getStreamManager: StreamManager = new OneForOneStreamManager

override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = {
val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
if (addr != null) {
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
val broadcastMessage =
Expand All @@ -469,34 +460,21 @@ private[netty] class NettyRpcHandler(
}

override def connectionTerminated(client: TransportClient): Unit = {
val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
if (addr != null) {
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
val broadcastMessage =
synchronized {
// If the last connection to a remote RpcEnv is terminated, we should broadcast
// "Disassociated"
remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>
remoteAddresses -= clientAddr
val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0)
assert(count != 0, "remoteAddresses and remoteConnectionCount are not consistent")
if (count - 1 == 0) {
// We lost all clients, so clean up and fire "Disassociated"
remoteConnectionCount.remove(remoteEnvAddress)
Some(RemoteProcessDisconnected(remoteEnvAddress))
} else {
// Decrease the connection number of remoteEnvAddress
remoteConnectionCount.put(remoteEnvAddress, count - 1)
None
}
}
val messageOpt: Option[RemoteProcessDisconnected] =
synchronized {
remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>
remoteAddresses -= clientAddr
Some(RemoteProcessDisconnected(remoteEnvAddress))
}
broadcastMessage.foreach(dispatcher.postToAll)
}
messageOpt.foreach(dispatcher.postToAll)
} else {
// If the channel is closed before connecting, its remoteAddress will be null. In this case,
// we can ignore it since we don't fire "Associated".
// See java.net.Socket.getRemoteSocketAddress
}
}

}
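Two notes on the NettyRpcHandler changes above. First, the simplification appears to be enabled by the transportConf change at the top of this file: with spark.shuffle.io.numConnectionsPerPeer forced to 1, a remote RpcEnv holds at most one connection to us, so the old remoteConnectionCount bookkeeping can be replaced by a plain first-seen check. Second, the review thread hinges on which address the connection events carry; here is a sketch of a consumer, assuming the standard RpcEndpoint lifecycle callbacks through which the dispatcher delivers RemoteProcessConnected and RemoteProcessDisconnected (the logging endpoint itself is hypothetical):

import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEnv}

// The address passed to these callbacks is the remote RpcEnv's listening
// address (RemoteEnvAddress), not the ephemeral client socket port; this is
// why NettyRpcHandler keeps the clientAddr -> remoteEnvAddress map.
class ConnectionLogger(override val rpcEnv: RpcEnv) extends RpcEndpoint {
  override def onConnected(remoteAddress: RpcAddress): Unit =
    println(s"remote RpcEnv connected: $remoteAddress")
  override def onDisconnected(remoteAddress: RpcAddress): Unit =
    println(s"remote RpcEnv disconnected: $remoteAddress")
}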
@@ -17,40 +17,44 @@

package org.apache.spark.rpc.netty

import java.net.URI

import org.apache.spark.SparkException
import org.apache.spark.rpc.RpcAddress

private[netty] case class NettyRpcAddress(host: String, port: Int, name: String) {
/**
* An address identifier for an RPC endpoint.
*
* @param host host name of the remote process.
* @param port the port the remote RPC environment binds to.
* @param name name of the remote endpoint.
*/
private[netty] case class RpcEndpointAddress(host: String, port: Int, name: String) {

def toRpcAddress: RpcAddress = RpcAddress(host, port)

override val toString = s"spark://$name@$host:$port"
}

private[netty] object NettyRpcAddress {
private[netty] object RpcEndpointAddress {

def apply(sparkUrl: String): NettyRpcAddress = {
def apply(sparkUrl: String): RpcEndpointAddress = {
try {
val uri = new URI(sparkUrl)
val uri = new java.net.URI(sparkUrl)
val host = uri.getHost
val port = uri.getPort
val name = uri.getUserInfo
if (uri.getScheme != "spark" ||
host == null ||
port < 0 ||
name == null ||
(uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null
uri.getFragment != null ||
uri.getQuery != null) {
host == null ||
port < 0 ||
name == null ||
(uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null
uri.getFragment != null ||
uri.getQuery != null) {
throw new SparkException("Invalid Spark URL: " + sparkUrl)
}
NettyRpcAddress(host, port, name)
RpcEndpointAddress(host, port, name)
} catch {
case e: java.net.URISyntaxException =>
throw new SparkException("Invalid Spark URL: " + sparkUrl, e)
}
}

}
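A quick round trip through the new address syntax, consistent with the toString test at the bottom of this page; these assertions are an illustration, not part of the patch:

import org.apache.spark.rpc.RpcAddress

val addr = RpcEndpointAddress("localhost", 12345, "test")
assert(addr.toString == "spark://test@localhost:12345")

// apply(sparkUrl) is the inverse; anything that is not exactly
// spark://name@host:port (a path, query, or fragment) throws SparkException.
val parsed = RpcEndpointAddress("spark://test@localhost:12345")
assert(parsed == addr)
assert(parsed.toRpcAddress == RpcAddress("localhost", 12345))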
@@ -14,26 +14,27 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.rpc.netty

import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

/**
* A message used to ask the remote [[IDVerifier]] if an [[RpcEndpoint]] exists
*/
private[netty] case class ID(name: String)

/**
* An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if a [[RpcEndpoint]] exists in this [[RpcEnv]]
* An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if an [[RpcEndpoint]] exists.
*
* This is used when setting up a remote endpoint reference.
*/
private[netty] class IDVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher)
private[netty] class RpcEndpointVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher)
extends RpcEndpoint {

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case ID(name) => context.reply(dispatcher.verify(name))
case RpcEndpointVerifier.CheckExistence(name) => context.reply(dispatcher.verify(name))
}
}

private[netty] object IDVerifier {
val NAME = "id-verifier"
private[netty] object RpcEndpointVerifier {
val NAME = "endpoint-verifier"

/** A message used to ask the remote [[RpcEndpointVerifier]] if an [[RpcEndpoint]] exists. */
case class CheckExistence(name: String)
}
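Condensed from the asyncSetupEndpointRefByURI hunk earlier in this diff, this is the shape of a verifier exchange as it runs inside NettyRpcEnv (conf, uri, and this come from that context; the failure branch is a sketch, since the diff truncates it):

// Ask the remote env's always-registered verifier endpoint whether `name`
// exists before handing out a reference to it.
val addr = RpcEndpointAddress(uri)
val endpointRef = new NettyRpcEndpointRef(conf, addr, this)
val verifier = new NettyRpcEndpointRef(
  conf, RpcEndpointAddress(addr.host, addr.port, RpcEndpointVerifier.NAME), this)
verifier.ask[Boolean](RpcEndpointVerifier.CheckExistence(endpointRef.name)).flatMap { exists =>
  if (exists) Future.successful(endpointRef)
  else Future.failed(new RuntimeException(s"Cannot find endpoint: $uri")) // sketch
}(ThreadUtils.sameThread)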
core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.SparkFunSuite
class NettyRpcAddressSuite extends SparkFunSuite {

test("toString") {
val addr = NettyRpcAddress("localhost", 12345, "test")
val addr = RpcEndpointAddress("localhost", 12345, "test")
assert(addr.toString === "spark://test@localhost:12345")
}
