@@ -17,69 +17,89 @@

 package org.apache.spark.deploy.mesos
 
-import java.net.SocketAddress
 import java.nio.ByteBuffer
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
 
-import scala.collection.mutable
+import scala.collection.JavaConverters._
 
 import org.apache.spark.{Logging, SecurityManager, SparkConf}
 import org.apache.spark.deploy.ExternalShuffleService
 import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
 import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
 import org.apache.spark.network.shuffle.protocol.BlockTransferMessage
-import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver
+import org.apache.spark.network.shuffle.protocol.mesos.{RegisterDriver, ShuffleServiceHeartbeat}
 import org.apache.spark.network.util.TransportConf
+import org.apache.spark.util.ThreadUtils
 
 /**
  * An RPC endpoint that receives registration requests from Spark drivers running on Mesos.
  * It detects driver termination and calls the cleanup callback to [[ExternalShuffleService]].
  */
-private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportConf)
+private[mesos] class MesosExternalShuffleBlockHandler(
+    transportConf: TransportConf,
+    cleanerIntervalS: Long)
   extends ExternalShuffleBlockHandler(transportConf, null) with Logging {
 
-  // Stores a map of driver socket addresses to app ids
-  private val connectedApps = new mutable.HashMap[SocketAddress, String]
+  ThreadUtils.newDaemonSingleThreadScheduledExecutor("shuffle-cleaner-watcher")
+    .scheduleAtFixedRate(new CleanerThread(), 0, cleanerIntervalS, TimeUnit.SECONDS)
+
+  // Stores a map of app id to app state (timeout value and last heartbeat)
+  private val connectedApps = new ConcurrentHashMap[String, AppState]()
 
   protected override def handleMessage(
       message: BlockTransferMessage,
       client: TransportClient,
       callback: RpcResponseCallback): Unit = {
     message match {
-      case RegisterDriverParam(appId) =>
+      case RegisterDriverParam(appId, appState) =>
         val address = client.getSocketAddress
-        logDebug(s"Received registration request from app $appId (remote address $address).")
-        if (connectedApps.contains(address)) {
-          val existingAppId = connectedApps(address)
-          if (!existingAppId.equals(appId)) {
-            logError(s"A new app '$appId' has connected to existing address $address, " +
-              s"removing previously registered app '$existingAppId'.")
-            applicationRemoved(existingAppId, true)
-          }
+        val timeout = appState.heartbeatTimeout
+        logInfo(s"Received registration request from app $appId (remote address $address, " +
+          s"heartbeat timeout $timeout ms).")
+        if (connectedApps.containsKey(appId)) {
+          logWarning(s"Received a registration request from app $appId, but it was already " +
+            s"registered")
         }
-        connectedApps(address) = appId
+        connectedApps.put(appId, appState)
         callback.onSuccess(ByteBuffer.allocate(0))
+      case Heartbeat(appId) =>
+        val address = client.getSocketAddress
+        Option(connectedApps.get(appId)) match {
+          case Some(existingAppState) =>
+            logTrace(s"Received ShuffleServiceHeartbeat from app '$appId' (remote " +
+              s"address $address).")
+            existingAppState.lastHeartbeat = System.nanoTime()
+          case None =>
+            logWarning(s"Received ShuffleServiceHeartbeat from an unknown app (remote " +
+              s"address $address, appId '$appId').")
+        }
       case _ => super.handleMessage(message, client, callback)
     }
   }
 
-  /**
-   * On connection termination, clean up shuffle files written by the associated application.
-   */
-  override def connectionTerminated(client: TransportClient): Unit = {
-    val address = client.getSocketAddress
-    if (connectedApps.contains(address)) {
-      val appId = connectedApps(address)
-      logInfo(s"Application $appId disconnected (address was $address).")
-      applicationRemoved(appId, true /* cleanupLocalDirs */)
-      connectedApps.remove(address)
-    } else {
-      logWarning(s"Unknown $address disconnected.")
-    }
-  }
-
   /** An extractor object for matching [[RegisterDriver]] message. */
   private object RegisterDriverParam {
-    def unapply(r: RegisterDriver): Option[String] = Some(r.getAppId)
+    def unapply(r: RegisterDriver): Option[(String, AppState)] =
+      Some((r.getAppId, new AppState(r.getHeartbeatTimeoutMs, System.nanoTime())))
   }
+
+  private object Heartbeat {
+    def unapply(h: ShuffleServiceHeartbeat): Option[String] = Some(h.getAppId)
+  }
+
+  private class AppState(val heartbeatTimeout: Long, @volatile var lastHeartbeat: Long)
+
+  private class CleanerThread extends Runnable {
+    override def run(): Unit = {
+      val now = System.nanoTime()
+      connectedApps.asScala.foreach { case (appId, appState) =>
+        if (now - appState.lastHeartbeat > appState.heartbeatTimeout * 1000 * 1000) {
+          logInfo(s"Application $appId timed out. Removing shuffle files.")
+          connectedApps.remove(appId)
+          applicationRemoved(appId, true)
+        }
+      }
+    }
+  }
 }
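The behavioral change in this file: shuffle-file cleanup no longer piggybacks on `connectionTerminated`. Drivers now register with a heartbeat timeout, send `ShuffleServiceHeartbeat` messages, and a scheduled `CleanerThread` expires any app whose last heartbeat is older than its timeout. Note that `appState.heartbeatTimeout * 1000 * 1000` converts the millisecond timeout to nanoseconds so it can be compared against `System.nanoTime()` deltas. A minimal, self-contained sketch of that expiry pass (the `sweep` helper and `removeApp` callback are illustrative, not part of the patch):

```scala
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._

// Mirrors the patch's AppState: timeout is in milliseconds,
// lastHeartbeat is a System.nanoTime() timestamp.
class AppState(val heartbeatTimeout: Long, @volatile var lastHeartbeat: Long)

object HeartbeatSweepSketch {
  val connectedApps = new ConcurrentHashMap[String, AppState]()

  // One cleaner pass: remove every app whose last heartbeat is older than
  // its timeout, converting ms -> ns to match the nanoTime clock.
  def sweep(removeApp: String => Unit): Unit = {
    val now = System.nanoTime()
    connectedApps.asScala.foreach { case (appId, appState) =>
      if (now - appState.lastHeartbeat > appState.heartbeatTimeout * 1000L * 1000L) {
        connectedApps.remove(appId)
        removeApp(appId) // the real handler calls applicationRemoved(appId, true)
      }
    }
  }
}
```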

@@ -93,7 +113,8 @@ private[mesos] class MesosExternalShuffleService(conf: SparkConf, securityManage

   protected override def newShuffleBlockHandler(
       conf: TransportConf): ExternalShuffleBlockHandler = {
-    new MesosExternalShuffleBlockHandler(conf)
+    val cleanerIntervalS = this.conf.getTimeAsSeconds("spark.shuffle.cleaner.interval", "30s")
+    new MesosExternalShuffleBlockHandler(conf, cleanerIntervalS)
   }
 }
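The cleaner's period is read from `spark.shuffle.cleaner.interval` via `SparkConf.getTimeAsSeconds`, defaulting to 30 seconds. A hedged example of how that lookup behaves (the `60s` override is illustrative):

```scala
import org.apache.spark.SparkConf

// getTimeAsSeconds parses suffixed durations ("30s", "5m", "1h") and
// falls back to the default string when the key is unset.
val conf = new SparkConf().set("spark.shuffle.cleaner.interval", "60s")
val cleanerIntervalS = conf.getTimeAsSeconds("spark.shuffle.cleaner.interval", "30s")
// cleanerIntervalS == 60
```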

@@ -29,7 +29,7 @@ import org.apache.spark.rpc.RpcAddress
  * @param rpcAddress The socket address of the endpoint.
  * @param name Name of the endpoint.
  */
-private[netty] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) {
+private[spark] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) {
 
   require(name != null, "RpcEndpoint name must be provided.")
 
@@ -44,7 +44,11 @@ private[netty] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val nam
     }
 }
 
-private[netty] object RpcEndpointAddress {
+private[spark] object RpcEndpointAddress {
 
+  def apply(host: String, port: Int, name: String): RpcEndpointAddress = {
+    new RpcEndpointAddress(host, port, name)
+  }
+
   def apply(sparkUrl: String): RpcEndpointAddress = {
     try {
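Widening the visibility to `private[spark]` and adding the three-argument `apply` lets code outside `org.apache.spark.rpc.netty` build an endpoint address directly from a host, port, and name instead of going through a Spark URL. A hedged usage sketch (the host, port, and endpoint name are illustrative; `toString` follows the `spark://name@host:port` form this class already produces):

```scala
// One call instead of parsing a URL:
val addr = RpcEndpointAddress("shuffle-host.example.com", 7337, "ExampleEndpoint")
// addr.toString == "spark://ExampleEndpoint@shuffle-host.example.com:7337"

// Equivalent construction via the existing URL-parsing factory:
val same = RpcEndpointAddress("spark://ExampleEndpoint@shuffle-host.example.com:7337")
```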