From 63d2d8d936e3d3a532f39ed0c375e134361f2a34 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 18 Dec 2014 15:58:09 -0800 Subject: [PATCH 01/36] rpc temp checkin --- .../org/apache/spark/HeartbeatReceiver.scala | 13 +- .../scala/org/apache/spark/SparkContext.scala | 8 +- .../scala/org/apache/spark/SparkEnv.scala | 6 + .../CoarseGrainedExecutorBackend.scala | 42 ++--- .../org/apache/spark/executor/Executor.scala | 20 ++- .../apache/spark/executor/ExecutorActor.scala | 11 +- .../spark/network/rpc/SimpleRpcServer.scala | 86 +++++++++++ .../scala/org/apache/spark/rpc/RpcEnv.scala | 143 ++++++++++++++++++ .../spark/network/client/TransportClient.java | 2 +- .../network/protocol/MessageDecoder.java | 1 + .../apache/spark/network/util/JavaUtils.java | 26 ++++ 11 files changed, 311 insertions(+), 47 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/network/rpc/SimpleRpcServer.scala create mode 100644 core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 83ae57b7f151..b317cdcae86d 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -17,11 +17,10 @@ package org.apache.spark -import akka.actor.Actor import org.apache.spark.executor.TaskMetrics -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.rpc.{RpcEndPointRef, RpcEndPoint} import org.apache.spark.scheduler.TaskScheduler -import org.apache.spark.util.ActorLogReceive +import org.apache.spark.storage.BlockManagerId /** * A heartbeat from executors to the driver. This is a shared message used by several internal @@ -37,13 +36,13 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** * Lives in the driver to receive heartbeats from executors.. */ -private[spark] class HeartbeatReceiver(scheduler: TaskScheduler) - extends Actor with ActorLogReceive with Logging { +private[spark] class HeartbeatReceiver(scheduler: TaskScheduler, conf: SparkConf) + extends RpcEndPoint with Logging { - override def receiveWithLogging = { + override def receive(sender: RpcEndPointRef, message: Any): Unit = message match { case Heartbeat(executorId, taskMetrics, blockManagerId) => val response = HeartbeatResponse( !scheduler.executorHeartbeatReceived(executorId, taskMetrics, blockManagerId)) - sender ! 
response + sender.send(response) } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 8e5378ecc08d..a4276f9b16b7 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -320,8 +320,10 @@ class SparkContext(config: SparkConf) extends Logging { // Create and start the scheduler private[spark] var (schedulerBackend, taskScheduler) = SparkContext.createTaskScheduler(this, master) - private val heartbeatReceiver = env.actorSystem.actorOf( - Props(new HeartbeatReceiver(taskScheduler)), "HeartbeatReceiver") + + private val heartbeatReceiver = + env.rpcEnv.setupEndPoint("HeartbeatReceiver", new HeartbeatReceiver(taskScheduler, conf)) + @volatile private[spark] var dagScheduler: DAGScheduler = _ try { dagScheduler = new DAGScheduler(this) @@ -1205,7 +1207,7 @@ class SparkContext(config: SparkConf) extends Logging { if (dagSchedulerCopy != null) { env.metricsSystem.report() metadataCleaner.cancel() - env.actorSystem.stop(heartbeatReceiver) + env.rpcEnv.stop(heartbeatReceiver) cleaner.foreach(_.stop()) dagSchedulerCopy.stop() taskScheduler = null diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index f4215f268a0d..03f934d34276 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -34,6 +34,7 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.rpc.{AkkaRpcEnv, RpcEnv} import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} @@ -54,6 +55,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} class SparkEnv ( val executorId: String, val actorSystem: ActorSystem, + val rpcEnv: RpcEnv, val serializer: Serializer, val closureSerializer: Serializer, val cacheManager: CacheManager, @@ -87,6 +89,7 @@ class SparkEnv ( blockManager.master.stop() metricsSystem.stop() actorSystem.shutdown() + rpcEnv.stopAll() // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut // down, but let's call it anyway in case it gets fixed in a later release // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. @@ -222,6 +225,8 @@ object SparkEnv extends Logging { AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager) } + val rpcEnv = new AkkaRpcEnv(actorSystem, conf) + // Figure out which port Akka actually bound to in case the original port is 0 or occupied. // This is so that we tell the executors the correct port to connect to. 
if (isDriver) { @@ -353,6 +358,7 @@ object SparkEnv extends Logging { new SparkEnv( executorId, actorSystem, + rpcEnv, serializer, closureSerializer, cacheManager, diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 5f46f3b1f085..cc0f4eeb6365 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -19,11 +19,12 @@ package org.apache.spark.executor import java.nio.ByteBuffer +import org.apache.spark.rpc.{AkkaRpcEnv, RpcEnv, RpcEndPointRef, RpcEndPoint} + import scala.concurrent.Await -import akka.actor.{Actor, ActorSelection, ActorSystem, Props} +import akka.actor.{ActorSystem, Props} import akka.pattern.Patterns -import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent} import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} import org.apache.spark.TaskState.TaskState @@ -31,7 +32,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} private[spark] class CoarseGrainedExecutorBackend( driverUrl: String, @@ -40,29 +41,36 @@ private[spark] class CoarseGrainedExecutorBackend( cores: Int, sparkProperties: Seq[(String, String)], actorSystem: ActorSystem) - extends Actor with ActorLogReceive with ExecutorBackend with Logging { + extends RpcEndPoint with ExecutorBackend with Logging { Utils.checkHostPort(hostPort, "Expected hostport") var executor: Executor = null - var driver: ActorSelection = null + var driver: RpcEndPointRef = null + + preStart() - override def preStart() { + def preStart() { logInfo("Connecting to driver: " + driverUrl) - driver = context.actorSelection(driverUrl) - driver ! RegisterExecutor(executorId, hostPort, cores) - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + val rpcEnv = new AkkaRpcEnv(actorSystem, conf) + driver = rpcEnv.setupEndPointRefByUrl(driverUrl) + driver.send(RegisterExecutor(executorId, hostPort, cores)) + } + + override def remoteConnectionTerminated(remoteAddress: String): Unit = { + logError(s"Driver $remoteAddress disassociated! Shutting down.") + System.exit(1) } - override def receiveWithLogging = { + override def receive(sender: RpcEndPointRef, message: Any): Unit = message match { case RegisteredExecutor => logInfo("Successfully registered with driver") val (hostname, _) = Utils.parseHostPort(hostPort) executor = new Executor(executorId, hostname, sparkProperties, cores, isLocal = false, actorSystem) - case RegisterExecutorFailed(message) => - logError("Slave registration failed: " + message) + case RegisterExecutorFailed(failureMessage) => + logError("Slave registration failed: " + failureMessage) System.exit(1) case LaunchTask(data) => @@ -84,19 +92,15 @@ private[spark] class CoarseGrainedExecutorBackend( executor.killTask(taskId, interruptThread) } - case x: DisassociatedEvent => - logError(s"Driver $x disassociated! 
Shutting down.") - System.exit(1) - case StopExecutor => logInfo("Driver commanded a shutdown") executor.stop() - context.stop(self) - context.system.shutdown() + //rpcEnv.stop(this) + // TODO(rxin): Stop this properly. } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { - driver ! StatusUpdate(executorId, taskId, state, data) + driver.send(StatusUpdate(executorId, taskId, state, data)) } } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 52de6980ecbf..054c5a9d6b26 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -22,6 +22,8 @@ import java.lang.management.ManagementFactory import java.nio.ByteBuffer import java.util.concurrent._ +import org.apache.spark.network.rpc.{SimpleRpcClient, SimpleRpcServer} + import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal @@ -95,8 +97,8 @@ private[spark] class Executor( } // Create an actor for receiving RPCs from the driver - private val executorActor = env.actorSystem.actorOf( - Props(new ExecutorActor(executorId)), "ExecutorActor") + private val executorActor = + env.rpcEnv.setupEndPoint("ExecutorActor", new ExecutorActor(executorId)) // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager @@ -137,7 +139,7 @@ private[spark] class Executor( def stop() { env.metricsSystem.report() - env.actorSystem.stop(executorActor) + env.rpcEnv.stop(executorActor) isStopped = true threadPool.shutdown() if (!isLocal) { @@ -363,11 +365,8 @@ private[spark] class Executor( } def startDriverHeartbeater() { - val interval = conf.getInt("spark.executor.heartbeatInterval", 10000) - val timeout = AkkaUtils.lookupTimeout(conf) - val retryAttempts = AkkaUtils.numRetries(conf) - val retryIntervalMs = AkkaUtils.retryWaitMs(conf) - val heartbeatReceiverRef = AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem) + val interval = conf.getInt("spark.executor.heartbeatInterval", 3000) + val heartbeatReceiverRef = env.rpcEnv.setupDriverEndPointRef("HeartbeatReceiver") val t = new Thread() { override def run() { @@ -379,7 +378,7 @@ private[spark] class Executor( for (taskRunner <- runningTasks.values()) { if (!taskRunner.attemptedTask.isEmpty) { Option(taskRunner.task).flatMap(_.metrics).foreach { metrics => - metrics.updateShuffleReadMetrics + metrics.updateShuffleReadMetrics() if (isLocal) { // JobProgressListener will hold an reference of it during // onExecutorMetricsUpdate(), then JobProgressListener can not see @@ -396,8 +395,7 @@ private[spark] class Executor( val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) try { - val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, - retryAttempts, retryIntervalMs, timeout) + val response = heartbeatReceiverRef.askWithReply[HeartbeatResponse](message) if (response.reregisterBlockManager) { logWarning("Told to re-register on heartbeat") env.blockManager.reregister() diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala index 41925f7e97e8..c05f9cf143f1 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala @@ -17,10 +17,9 @@ package 
org.apache.spark.executor -import akka.actor.Actor import org.apache.spark.Logging - -import org.apache.spark.util.{Utils, ActorLogReceive} +import org.apache.spark.rpc.{RpcEndPointRef, RpcEndPoint} +import org.apache.spark.util.Utils /** * Driver -> Executor message to trigger a thread dump. @@ -31,11 +30,11 @@ private[spark] case object TriggerThreadDump * Actor that runs inside of executors to enable driver -> executor RPC. */ private[spark] -class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging { +class ExecutorActor(executorId: String) extends RpcEndPoint with Logging { - override def receiveWithLogging = { + override def receive(sender: RpcEndPointRef, message: Any): Unit = message match { case TriggerThreadDump => - sender ! Utils.getThreadDump() + sender.send(Utils.getThreadDump()) } } diff --git a/core/src/main/scala/org/apache/spark/network/rpc/SimpleRpcServer.scala b/core/src/main/scala/org/apache/spark/network/rpc/SimpleRpcServer.scala new file mode 100644 index 000000000000..9b76ae076a34 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/rpc/SimpleRpcServer.scala @@ -0,0 +1,86 @@ +package org.apache.spark.network.rpc + +import java.nio.ByteBuffer +import org.apache.spark.SparkConf +import org.apache.spark.network.TransportContext +import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.server._ +import org.apache.spark.network.util.JavaUtils +import org.slf4j.Logger + + +class SimpleRpcClient(conf: SparkConf) { + private val transportConf = SparkTransportConf.fromSparkConf(conf, 1) + val transportContext = new TransportContext(transportConf, new RpcHandler { + override def getStreamManager: StreamManager = new OneForOneStreamManager + + override def receive( + client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = { + println("gotten some message " + JavaUtils.bytesToString(ByteBuffer.wrap(message))) + callback.onSuccess(new Array[Byte](0)) + } + }) + val clientF = transportContext.createClientFactory() + val client = clientF.createClient("localhost", 12345) + + def sendMessage(message: Any): Unit = { + client.sendRpcSync(JavaUtils.serialize(message), 5000) + } +} + + +abstract class SimpleRpcServer(conf: SparkConf) { + + protected def log: Logger + + private val transportConf = SparkTransportConf.fromSparkConf(conf, 1) + + val transportContext = new TransportContext(transportConf, new RpcHandler { + override def getStreamManager: StreamManager = new OneForOneStreamManager + + override def receive( + client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = { + callback.onSuccess(new Array[Byte](0)) + val received = JavaUtils.deserialize[Any](message) + println("got mesage " + received) + remote = client + if (receiveWithLogging.isDefinedAt(received)) { + receiveWithLogging.apply(received) + } + } + }) + + private[this] val clientFactory = transportContext.createClientFactory() + private[this] var server: TransportServer = _ + + startServer() + private[this] val client = clientFactory.createClient("localhost", 12345) + + def startServer(): Unit = { + server = transportContext.createServer(12345) + log.info("RPC server created on " + server.getPort) + } + + var remote: TransportClient = _ + + def reply(message: Any): Unit = { +// val c = clientFactory.createClient("localhost", +// remote.channel.remoteAddress.asInstanceOf[InetSocketAddress].getPort) +// 
c.sendRpc(JavaUtils.serialize(message), new RpcResponseCallback { +// override def onSuccess(response: Array[Byte]): Unit = {} +// override def onFailure(e: Throwable): Unit = {} +// }) +// remote.sendRpc(JavaUtils.serialize(message), new RpcResponseCallback { +// override def onFailure(e: Throwable): Unit = {} +// override def onSuccess(response: Array[Byte]): Unit = {} +// }) + remote.sendRpcSync(JavaUtils.serialize(message), 5000) + } + + def sendMessage(message: Any): Unit = { + client.sendRpcSync(JavaUtils.serialize(message), 5000) + } + + def receiveWithLogging: PartialFunction[Any, Unit] +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala new file mode 100644 index 000000000000..8c64859608d0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} + +import scala.concurrent.Await +import scala.concurrent.duration.Duration + +import akka.actor.{ActorRef, Actor, Props, ActorSystem} +import akka.pattern.ask + +import org.apache.spark.{Logging, SparkException, SparkConf} +import org.apache.spark.util.AkkaUtils + + +abstract class RpcEnv { + def setupEndPoint(name: String, endpoint: RpcEndPoint): RpcEndPointRef + + def setupDriverEndPointRef(name: String): RpcEndPointRef + + def setupEndPointRefByUrl(url: String): RpcEndPointRef + + def stop(endpoint: RpcEndPointRef): Unit + + def stopAll(): Unit +} + + +abstract class RpcEndPoint { + def receive(sender: RpcEndPointRef, message: Any): Unit + + def remoteConnectionTerminated(remoteAddress: String): Unit = { + // By default, do nothing. 
+ } +} + + +abstract class RpcEndPointRef { + + def address: String + + def askWithReply[T](message: Any): T + + def send(message: Any): Unit +} + + +class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { + + override def setupEndPoint(name: String, endpoint: RpcEndPoint): RpcEndPointRef = { + val actorRef = actorSystem.actorOf(Props(new Actor { + override def preStart(): Unit = { + // Listen for remote client disconnection events, since they don't go through Akka's watch() + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + } + + override def receive: Receive = { + case DisassociatedEvent(_, remoteAddress, _) => + endpoint.remoteConnectionTerminated(remoteAddress.toString) + + case message: Any => + endpoint.receive(new AkkaRpcEndPointRef(sender(), conf), message) + } + }), name = name) + new AkkaRpcEndPointRef(actorRef, conf) + } + + override def setupDriverEndPointRef(name: String): RpcEndPointRef = { + new AkkaRpcEndPointRef(AkkaUtils.makeDriverRef(name, conf, actorSystem), conf) + } + + override def setupEndPointRefByUrl(url: String): RpcEndPointRef = { + val timeout = Duration.create(conf.getLong("spark.akka.lookupTimeout", 30), "seconds") + val ref = Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) + new AkkaRpcEndPointRef(ref, conf) + } + + override def stopAll(): Unit = { + // Do nothing since actorSystem was created outside. + } + + override def stop(endpoint: RpcEndPointRef): Unit = { + require(endpoint.isInstanceOf[AkkaRpcEndPointRef]) + actorSystem.stop(endpoint.asInstanceOf[AkkaRpcEndPointRef].actorRef) + } +} + + +class AkkaRpcEndPointRef(private[rpc] val actorRef: ActorRef, conf: SparkConf) + extends RpcEndPointRef with Serializable with Logging { + + private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3) + private[this] val retryWaitMs = conf.getInt("spark.akka.retry.wait", 3000) + private[this] val timeout = + Duration.create(conf.getLong("spark.akka.lookupTimeout", 30), "seconds") + + override def address: String = actorRef.path.address.toString + + override def askWithReply[T](message: Any): T = { + var attempts = 0 + var lastException: Exception = null + while (attempts < maxRetries) { + attempts += 1 + try { + val future = actorRef.ask(message)(timeout) + val result = Await.result(future, timeout) + if (result == null) { + throw new SparkException("Actor returned null") + } + return result.asInstanceOf[T] + } catch { + case ie: InterruptedException => throw ie + case e: Exception => + lastException = e + logWarning("Error sending message in " + attempts + " attempts", e) + } + Thread.sleep(retryWaitMs) + } + + throw new SparkException( + "Error sending message [message = " + message + "]", lastException) + } + + override def send(message: Any): Unit = { + actorRef ! 
message + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 37f2e34ceb24..4b22f87f7316 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -67,7 +67,7 @@ public class TransportClient implements Closeable { private final Logger logger = LoggerFactory.getLogger(TransportClient.class); - private final Channel channel; + public final Channel channel; private final TransportResponseHandler handler; public TransportClient(Channel channel, TransportResponseHandler handler) { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index 81f8d7f96350..7c4fd8161454 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -34,6 +34,7 @@ public final class MessageDecoder extends MessageToMessageDecoder { private final Logger logger = LoggerFactory.getLogger(MessageDecoder.class); + @Override public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { Message.Type msgType = Message.Type.decode(in); diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index bf8a1fc42fc6..d58f4e104243 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -41,6 +41,32 @@ public class JavaUtils { private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class); + public static T deserialize(byte[] bytes) { + try { + ObjectInputStream is = new ObjectInputStream(new ByteArrayInputStream(bytes)); + Object out = is.readObject(); + is.close(); + return (T) out; + } catch (ClassNotFoundException e) { + throw new RuntimeException("Could not deserialize object", e); + } catch (IOException e) { + throw new RuntimeException("Could not deserialize object", e); + } + } + + // TODO: Make this configurable, do not use Java serialization! + public static byte[] serialize(Object object) { + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ObjectOutputStream os = new ObjectOutputStream(baos); + os.writeObject(object); + os.close(); + return baos.toByteArray(); + } catch (IOException e) { + throw new RuntimeException("Could not serialize object", e); + } + } + /** Closes the given object, ignoring IOExceptions. */ public static void closeQuietly(Closeable closeable) { try { From c9f2cc929777befb1ce428c4d90293edda13a684 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 20 Dec 2014 01:26:03 -0800 Subject: [PATCH 02/36] First version that ran. 
--- .../org/apache/spark/HeartbeatReceiver.scala | 2 +- .../CoarseGrainedExecutorBackend.scala | 30 +++++++++---------- .../apache/spark/executor/ExecutorActor.scala | 2 +- .../scala/org/apache/spark/rpc/RpcEnv.scala | 23 ++++++++++++-- .../cluster/CoarseGrainedClusterMessage.scala | 5 +++- .../CoarseGrainedSchedulerBackend.scala | 27 +++++++++-------- .../scheduler/cluster/ExecutorData.scala | 6 ++-- 7 files changed, 59 insertions(+), 36 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index b317cdcae86d..2015d22324d1 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -39,7 +39,7 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) private[spark] class HeartbeatReceiver(scheduler: TaskScheduler, conf: SparkConf) extends RpcEndPoint with Logging { - override def receive(sender: RpcEndPointRef, message: Any): Unit = message match { + override def receive(sender: RpcEndPointRef) = { case Heartbeat(executorId, taskMetrics, blockManagerId) => val response = HeartbeatResponse( !scheduler.executorHeartbeatReceived(executorId, taskMetrics, blockManagerId)) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 1d88c23eea7b..37582cf23e61 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -46,13 +46,10 @@ private[spark] class CoarseGrainedExecutorBackend( var executor: Executor = null var driver: RpcEndPointRef = null - preStart() - - def preStart() { + def notifyDriver(driverUrl: String, selfRef: RpcEndPointRef): Unit = { logInfo("Connecting to driver: " + driverUrl) - val rpcEnv = new AkkaRpcEnv(actorSystem, conf) - driver = rpcEnv.setupEndPointRefByUrl(driverUrl) - driver.send(RegisterExecutor(executorId, hostPort, cores)) + driver = env.rpcEnv.setupEndPointRefByUrl(driverUrl) + driver.send(RegisterExecutor(executorId, hostPort, cores, selfRef)) } override def remoteConnectionTerminated(remoteAddress: String): Unit = { @@ -60,7 +57,7 @@ private[spark] class CoarseGrainedExecutorBackend( System.exit(1) } - override def receive(sender: RpcEndPointRef, message: Any): Unit = message match { + override def receive(sender: RpcEndPointRef) = { case RegisteredExecutor => logInfo("Successfully registered with driver") val (hostname, _) = Utils.parseHostPort(hostPort) @@ -122,12 +119,12 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val port = executorConf.getInt("spark.executor.port", 0) val (fetcher, _) = AkkaUtils.createActorSystem( "driverPropsFetcher", hostname, port, executorConf, new SecurityManager(executorConf)) - val driver = fetcher.actorSelection(driverUrl) + val driverActor = fetcher.actorSelection(driverUrl) val timeout = AkkaUtils.askTimeout(executorConf) - val fut = Patterns.ask(driver, RetrieveSparkProps, timeout) + val fut = Patterns.ask(driverActor, RetrieveSparkProps, timeout) val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] ++ Seq[(String, String)](("spark.app.id", appId)) - fetcher.shutdown() + //fetcher.shutdown() // Create SparkEnv using properties we fetched from the driver. 
val driverConf = new SparkConf().setAll(props) @@ -138,12 +135,15 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val boundPort = env.conf.getInt("spark.executor.port", 0) assert(boundPort != 0) - // Start the CoarseGrainedExecutorBackend actor. + // Start the CoarseGrainedExecutorBackend RPC end point. val sparkHostPort = hostname + ":" + boundPort - env.actorSystem.actorOf( - Props(classOf[CoarseGrainedExecutorBackend], - driverUrl, executorId, sparkHostPort, cores, env), - name = "Executor") + val rpc = new CoarseGrainedExecutorBackend(driverUrl, executorId, sparkHostPort, cores, env) + val rpcRef = env.rpcEnv.setupEndPoint("Executor", rpc) + + rpc.notifyDriver(driverUrl, rpcRef) + + // Notify the driver of our existence. + workerUrl.foreach { url => env.actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher") } diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala index c05f9cf143f1..8f10dd15a7fe 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala @@ -32,7 +32,7 @@ private[spark] case object TriggerThreadDump private[spark] class ExecutorActor(executorId: String) extends RpcEndPoint with Logging { - override def receive(sender: RpcEndPointRef, message: Any): Unit = message match { + override def receive(sender: RpcEndPointRef) = { case TriggerThreadDump => sender.send(Utils.getThreadDump()) } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 8c64859608d0..89a7b41e42a8 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -29,6 +29,9 @@ import org.apache.spark.{Logging, SparkException, SparkConf} import org.apache.spark.util.AkkaUtils +/** + * An RPC environment. + */ abstract class RpcEnv { def setupEndPoint(name: String, endpoint: RpcEndPoint): RpcEndPointRef @@ -42,8 +45,12 @@ abstract class RpcEnv { } +/** + * An end point for the RPC that defines what functions to trigger given a message. + */ abstract class RpcEndPoint { - def receive(sender: RpcEndPointRef, message: Any): Unit + + def receive(sender: RpcEndPointRef): PartialFunction[Any, Unit] def remoteConnectionTerminated(remoteAddress: String): Unit = { // By default, do nothing. @@ -51,12 +58,18 @@ abstract class RpcEndPoint { } +/** + * A reference for a remote [[RpcEndPoint]]. + */ abstract class RpcEndPointRef { def address: String def askWithReply[T](message: Any): T + /** + * Send a message to the remote end point asynchronously. No delivery guarantee is provided. + */ def send(message: Any): Unit } @@ -75,7 +88,11 @@ class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { endpoint.remoteConnectionTerminated(remoteAddress.toString) case message: Any => - endpoint.receive(new AkkaRpcEndPointRef(sender(), conf), message) + println("got message " + name + " : " + message) + val pf = endpoint.receive(new AkkaRpcEndPointRef(sender(), conf)) + if (pf.isDefinedAt(message)) { + pf.apply(message) + } } }), name = name) new AkkaRpcEndPointRef(actorRef, conf) @@ -140,4 +157,6 @@ class AkkaRpcEndPointRef(private[rpc] val actorRef: ActorRef, conf: SparkConf) override def send(message: Any): Unit = { actorRef ! 
message } + + override def toString: String = s"${getClass.getSimpleName}($actorRef)" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 1da6fe976da5..7c828bb05ceb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -20,8 +20,10 @@ package org.apache.spark.scheduler.cluster import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState +import org.apache.spark.rpc.RpcEndPointRef import org.apache.spark.util.{SerializableBuffer, Utils} + private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable private[spark] object CoarseGrainedClusterMessages { @@ -39,7 +41,8 @@ private[spark] object CoarseGrainedClusterMessages { case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage // Executors to driver - case class RegisterExecutor(executorId: String, hostPort: String, cores: Int) + case class RegisterExecutor( + executorId: String, hostPort: String, cores: Int, rpcRef: RpcEndPointRef) extends CoarseGrainedClusterMessage { Utils.checkHostPort(hostPort, "Expected host port") } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index fe9914b50bc5..1956ebbf8181 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -71,7 +71,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive { override protected def log = CoarseGrainedSchedulerBackend.this.log - private val addressToExecutorId = new HashMap[Address, String] + private val addressToExecutorId = new HashMap[String, String] override def preStart() { // Listen for remote client disconnection events, since they don't go through Akka's watch() @@ -84,19 +84,19 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } def receiveWithLogging = { - case RegisterExecutor(executorId, hostPort, cores) => + case RegisterExecutor(executorId, hostPort, cores, executorRpcRef) => Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorDataMap.contains(executorId)) { - sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) + executorRpcRef.send(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) } else { - logInfo("Registered executor: " + sender + " with ID " + executorId) - sender ! 
RegisteredExecutor + logInfo("Registered executor: " + executorRpcRef + " with ID " + executorId) + executorRpcRef.send(RegisteredExecutor) - addressToExecutorId(sender.path.address) = executorId + addressToExecutorId(executorRpcRef.address) = executorId totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) val (host, _) = Utils.parseHostPort(hostPort) - val data = new ExecutorData(sender, sender.path.address, host, cores, cores) + val data = new ExecutorData(executorRpcRef, executorRpcRef.address, host, cores, cores) // This must be synchronized because variables mutated // in this block are read when requesting executors CoarseGrainedSchedulerBackend.this.synchronized { @@ -129,7 +129,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste case KillTask(taskId, executorId, interruptThread) => executorDataMap.get(executorId) match { case Some(executorInfo) => - executorInfo.executorActor ! KillTask(taskId, executorId, interruptThread) + executorInfo.executorActor.send(KillTask(taskId, executorId, interruptThread)) case None => // Ignoring the task kill since the executor is not registered. logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") @@ -142,7 +142,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste case StopExecutors => logInfo("Asking each executor to shut down") for ((_, executorData) <- executorDataMap) { - executorData.executorActor ! StopExecutor + executorData.executorActor.send(StopExecutor) } sender ! true @@ -150,11 +150,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste removeExecutor(executorId, reason) sender ! true - case DisassociatedEvent(_, address, _) => - addressToExecutorId.get(address).foreach(removeExecutor(_, - "remote Akka client disassociated")) +// case DisassociatedEvent(_, address, _) => +// addressToExecutorId.get(address).foreach(removeExecutor(_, +// "remote Akka client disassociated")) case RetrieveSparkProps => + println("sending the other side properties RetrieveSparkProps") sender ! sparkProperties } @@ -195,7 +196,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste else { val executorData = executorDataMap(task.executorId) executorData.freeCores -= scheduler.CPUS_PER_TASK - executorData.executorActor ! LaunchTask(new SerializableBuffer(serializedTask)) + executorData.executorActor.send(LaunchTask(new SerializableBuffer(serializedTask))) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index b71bd5783d6d..630b38a58e72 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster -import akka.actor.{Address, ActorRef} +import org.apache.spark.rpc.RpcEndPointRef /** * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. 
@@ -29,8 +29,8 @@ import akka.actor.{Address, ActorRef} * @param totalCores The total number of cores available to the executor */ private[cluster] class ExecutorData( - val executorActor: ActorRef, - val executorAddress: Address, + val executorActor: RpcEndPointRef, + val executorAddress: String, val executorHost: String , var freeCores: Int, val totalCores: Int From bdc67b56033c10a4e3bff0ab7999886ff39a5481 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 24 Dec 2014 10:58:00 -0800 Subject: [PATCH 03/36] Minor update. --- .../executor/CoarseGrainedExecutorBackend.scala | 15 +++++---------- .../main/scala/org/apache/spark/rpc/RpcEnv.scala | 10 +++++++--- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 37582cf23e61..0423b48e7a31 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -46,12 +46,6 @@ private[spark] class CoarseGrainedExecutorBackend( var executor: Executor = null var driver: RpcEndPointRef = null - def notifyDriver(driverUrl: String, selfRef: RpcEndPointRef): Unit = { - logInfo("Connecting to driver: " + driverUrl) - driver = env.rpcEnv.setupEndPointRefByUrl(driverUrl) - driver.send(RegisterExecutor(executorId, hostPort, cores, selfRef)) - } - override def remoteConnectionTerminated(remoteAddress: String): Unit = { logError(s"Driver $remoteAddress disassociated! Shutting down.") System.exit(1) @@ -124,7 +118,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val fut = Patterns.ask(driverActor, RetrieveSparkProps, timeout) val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] ++ Seq[(String, String)](("spark.app.id", appId)) - //fetcher.shutdown() + fetcher.shutdown() // Create SparkEnv using properties we fetched from the driver. val driverConf = new SparkConf().setAll(props) @@ -140,9 +134,10 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val rpc = new CoarseGrainedExecutorBackend(driverUrl, executorId, sparkHostPort, cores, env) val rpcRef = env.rpcEnv.setupEndPoint("Executor", rpc) - rpc.notifyDriver(driverUrl, rpcRef) - - // Notify the driver of our existence. + // Register this executor with the driver. 
+ logInfo("Connecting to driver: " + driverUrl) + val driverRef = env.rpcEnv.setupEndPointRefByUrl(driverUrl) + driverRef.send(RegisterExecutor(executorId, sparkHostPort, cores, rpcRef)) workerUrl.foreach { url => env.actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher") diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 89a7b41e42a8..8a08058e7a3e 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -17,13 +17,13 @@ package org.apache.spark.rpc -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} - import scala.concurrent.Await import scala.concurrent.duration.Duration import akka.actor.{ActorRef, Actor, Props, ActorSystem} import akka.pattern.ask +import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import org.slf4j.Logger import org.apache.spark.{Logging, SparkException, SparkConf} import org.apache.spark.util.AkkaUtils @@ -55,6 +55,10 @@ abstract class RpcEndPoint { def remoteConnectionTerminated(remoteAddress: String): Unit = { // By default, do nothing. } + + protected def log: Logger + + private[rpc] def logMessage = log } @@ -88,7 +92,7 @@ class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { endpoint.remoteConnectionTerminated(remoteAddress.toString) case message: Any => - println("got message " + name + " : " + message) + endpoint.logMessage.trace("Received RPC message: " + message) val pf = endpoint.receive(new AkkaRpcEndPointRef(sender(), conf)) if (pf.isDefinedAt(message)) { pf.apply(message) From 2f4b9d864ade386b171ae7983449ae86bc55eae6 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 25 Dec 2014 11:41:43 +0800 Subject: [PATCH 04/36] abstract class => trait --- core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 8a08058e7a3e..d240bace784a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.AkkaUtils /** * An RPC environment. */ -abstract class RpcEnv { +trait RpcEnv { def setupEndPoint(name: String, endpoint: RpcEndPoint): RpcEndPointRef def setupDriverEndPointRef(name: String): RpcEndPointRef @@ -48,7 +48,7 @@ abstract class RpcEnv { /** * An end point for the RPC that defines what functions to trigger given a message. */ -abstract class RpcEndPoint { +trait RpcEndPoint { def receive(sender: RpcEndPointRef): PartialFunction[Any, Unit] @@ -65,7 +65,7 @@ abstract class RpcEndPoint { /** * A reference for a remote [[RpcEndPoint]]. 
*/ -abstract class RpcEndPointRef { +trait RpcEndPointRef { def address: String @@ -131,7 +131,7 @@ class AkkaRpcEndPointRef(private[rpc] val actorRef: ActorRef, conf: SparkConf) private[this] val timeout = Duration.create(conf.getLong("spark.akka.lookupTimeout", 30), "seconds") - override def address: String = actorRef.path.address.toString + override val address: String = actorRef.path.address.toString override def askWithReply[T](message: Any): T = { var attempts = 0 From 0f7f0325ecef99e0cf758c0675c44bddc46a0541 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 25 Dec 2014 12:03:50 +0800 Subject: [PATCH 05/36] Move Akka classes to org.apache.spark.rpc.akka --- .../scala/org/apache/spark/SparkEnv.scala | 3 +- .../CoarseGrainedExecutorBackend.scala | 3 +- .../scala/org/apache/spark/rpc/RpcEnv.scala | 98 --------------- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 116 ++++++++++++++++++ .../org/apache/spark/rpc/RpcEnvSuite.scala | 24 ++++ .../spark/rpc/akka/AkkaRpcEnvSuite.scala | 24 ++++ 6 files changed, 168 insertions(+), 100 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala create mode 100644 core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 2a22e60abaff..d7d357f61778 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -34,7 +34,8 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService -import org.apache.spark.rpc.{AkkaRpcEnv, RpcEnv} +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 0423b48e7a31..9fc76a34f72f 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -28,7 +28,8 @@ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher -import org.apache.spark.rpc.{AkkaRpcEnv, RpcEnv, RpcEndPointRef, RpcEndPoint} +import org.apache.spark.rpc.{RpcEnv, RpcEndPointRef, RpcEndPoint} +import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index d240bace784a..e8e4e5e2f778 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -17,18 +17,8 @@ package org.apache.spark.rpc -import scala.concurrent.Await -import scala.concurrent.duration.Duration - -import akka.actor.{ActorRef, 
Actor, Props, ActorSystem} -import akka.pattern.ask -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import org.slf4j.Logger -import org.apache.spark.{Logging, SparkException, SparkConf} -import org.apache.spark.util.AkkaUtils - - /** * An RPC environment. */ @@ -76,91 +66,3 @@ trait RpcEndPointRef { */ def send(message: Any): Unit } - - -class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { - - override def setupEndPoint(name: String, endpoint: RpcEndPoint): RpcEndPointRef = { - val actorRef = actorSystem.actorOf(Props(new Actor { - override def preStart(): Unit = { - // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - } - - override def receive: Receive = { - case DisassociatedEvent(_, remoteAddress, _) => - endpoint.remoteConnectionTerminated(remoteAddress.toString) - - case message: Any => - endpoint.logMessage.trace("Received RPC message: " + message) - val pf = endpoint.receive(new AkkaRpcEndPointRef(sender(), conf)) - if (pf.isDefinedAt(message)) { - pf.apply(message) - } - } - }), name = name) - new AkkaRpcEndPointRef(actorRef, conf) - } - - override def setupDriverEndPointRef(name: String): RpcEndPointRef = { - new AkkaRpcEndPointRef(AkkaUtils.makeDriverRef(name, conf, actorSystem), conf) - } - - override def setupEndPointRefByUrl(url: String): RpcEndPointRef = { - val timeout = Duration.create(conf.getLong("spark.akka.lookupTimeout", 30), "seconds") - val ref = Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) - new AkkaRpcEndPointRef(ref, conf) - } - - override def stopAll(): Unit = { - // Do nothing since actorSystem was created outside. - } - - override def stop(endpoint: RpcEndPointRef): Unit = { - require(endpoint.isInstanceOf[AkkaRpcEndPointRef]) - actorSystem.stop(endpoint.asInstanceOf[AkkaRpcEndPointRef].actorRef) - } -} - - -class AkkaRpcEndPointRef(private[rpc] val actorRef: ActorRef, conf: SparkConf) - extends RpcEndPointRef with Serializable with Logging { - - private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3) - private[this] val retryWaitMs = conf.getInt("spark.akka.retry.wait", 3000) - private[this] val timeout = - Duration.create(conf.getLong("spark.akka.lookupTimeout", 30), "seconds") - - override val address: String = actorRef.path.address.toString - - override def askWithReply[T](message: Any): T = { - var attempts = 0 - var lastException: Exception = null - while (attempts < maxRetries) { - attempts += 1 - try { - val future = actorRef.ask(message)(timeout) - val result = Await.result(future, timeout) - if (result == null) { - throw new SparkException("Actor returned null") - } - return result.asInstanceOf[T] - } catch { - case ie: InterruptedException => throw ie - case e: Exception => - lastException = e - logWarning("Error sending message in " + attempts + " attempts", e) - } - Thread.sleep(retryWaitMs) - } - - throw new SparkException( - "Error sending message [message = " + message + "]", lastException) - } - - override def send(message: Any): Unit = { - actorRef ! 
message - } - - override def toString: String = s"${getClass.getSimpleName}($actorRef)" -} diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala new file mode 100644 index 000000000000..667011ba9852 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.akka + +import scala.concurrent.Await +import scala.concurrent.duration.Duration + +import akka.actor.{ActorRef, Actor, Props, ActorSystem} +import akka.pattern.ask +import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} + +import org.apache.spark.{Logging, SparkException, SparkConf} +import org.apache.spark.rpc._ +import org.apache.spark.util.AkkaUtils + +class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { + + override def setupEndPoint(name: String, endpoint: RpcEndPoint): RpcEndPointRef = { + val actorRef = actorSystem.actorOf(Props(new Actor { + override def preStart(): Unit = { + // Listen for remote client disconnection events, since they don't go through Akka's watch() + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + } + + override def receive: Receive = { + case DisassociatedEvent(_, remoteAddress, _) => + endpoint.remoteConnectionTerminated(remoteAddress.toString) + + case message: Any => + endpoint.logMessage.trace("Received RPC message: " + message) + val pf = endpoint.receive(new AkkaRpcEndPointRef(sender(), conf)) + if (pf.isDefinedAt(message)) { + pf.apply(message) + } + } + }), name = name) + new AkkaRpcEndPointRef(actorRef, conf) + } + + override def setupDriverEndPointRef(name: String): RpcEndPointRef = { + new AkkaRpcEndPointRef(AkkaUtils.makeDriverRef(name, conf, actorSystem), conf) + } + + override def setupEndPointRefByUrl(url: String): RpcEndPointRef = { + val timeout = Duration.create(conf.getLong("spark.akka.lookupTimeout", 30), "seconds") + val ref = Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) + new AkkaRpcEndPointRef(ref, conf) + } + + override def stopAll(): Unit = { + // Do nothing since actorSystem was created outside. 
+ } + + override def stop(endpoint: RpcEndPointRef): Unit = { + require(endpoint.isInstanceOf[AkkaRpcEndPointRef]) + actorSystem.stop(endpoint.asInstanceOf[AkkaRpcEndPointRef].actorRef) + } +} + + +class AkkaRpcEndPointRef(private[rpc] val actorRef: ActorRef, conf: SparkConf) + extends RpcEndPointRef with Serializable with Logging { + + private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3) + private[this] val retryWaitMs = conf.getInt("spark.akka.retry.wait", 3000) + private[this] val timeout = + Duration.create(conf.getLong("spark.akka.lookupTimeout", 30), "seconds") + + override val address: String = actorRef.path.address.toString + + override def askWithReply[T](message: Any): T = { + var attempts = 0 + var lastException: Exception = null + while (attempts < maxRetries) { + attempts += 1 + try { + val future = actorRef.ask(message)(timeout) + val result = Await.result(future, timeout) + if (result == null) { + throw new SparkException("Actor returned null") + } + return result.asInstanceOf[T] + } catch { + case ie: InterruptedException => throw ie + case e: Exception => + lastException = e + logWarning("Error sending message in " + attempts + " attempts", e) + } + Thread.sleep(retryWaitMs) + } + + throw new SparkException( + "Error sending message [message = " + message + "]", lastException) + } + + override def send(message: Any): Unit = { + actorRef ! message + } + + override def toString: String = s"${getClass.getSimpleName}($actorRef)" +} diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala new file mode 100644 index 000000000000..51ff16547940 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import org.scalatest.FunSuite + +abstract class RpcEnvSuite extends FunSuite { + +} diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala new file mode 100644 index 000000000000..a475f5481896 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.akka + +import org.apache.spark.rpc.RpcEnvSuite + +class AkkaRpcEnvSuite extends RpcEnvSuite{ + +} From 2d2cba373ee3a0a93c34514c9cd5b26f37c8f06e Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 25 Dec 2014 16:56:03 +0800 Subject: [PATCH 06/36] Update the APIs to support to get the `sender` RpcEndPointRef when sending a message and add some unit tests --- .../org/apache/spark/HeartbeatReceiver.scala | 6 +- .../scala/org/apache/spark/SparkContext.scala | 4 +- .../CoarseGrainedExecutorBackend.scala | 5 +- .../org/apache/spark/executor/Executor.scala | 2 +- .../apache/spark/executor/ExecutorActor.scala | 5 +- .../scala/org/apache/spark/rpc/RpcEnv.scala | 49 ++++++++- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 30 +++--- .../org/apache/spark/rpc/RpcEnvSuite.scala | 99 ++++++++++++++++++- .../spark/rpc/akka/AkkaRpcEnvSuite.scala | 20 +++- 9 files changed, 191 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 2015d22324d1..2eb2a69e4adc 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -18,7 +18,7 @@ package org.apache.spark import org.apache.spark.executor.TaskMetrics -import org.apache.spark.rpc.{RpcEndPointRef, RpcEndPoint} +import org.apache.spark.rpc.{RpcEnv, RpcEndPointRef, RpcEndPoint} import org.apache.spark.scheduler.TaskScheduler import org.apache.spark.storage.BlockManagerId @@ -36,8 +36,8 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** * Lives in the driver to receive heartbeats from executors.. 
*/ -private[spark] class HeartbeatReceiver(scheduler: TaskScheduler, conf: SparkConf) - extends RpcEndPoint with Logging { +private[spark] class HeartbeatReceiver(override val rpcEnv: RpcEnv, + scheduler: TaskScheduler, conf: SparkConf) extends RpcEndPoint with Logging { override def receive(sender: RpcEndPointRef) = { case Heartbeat(executorId, taskMetrics, blockManagerId) => diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 43a51007b642..23b1fce949fa 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -324,8 +324,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] var (schedulerBackend, taskScheduler) = SparkContext.createTaskScheduler(this, master) - private val heartbeatReceiver = - env.rpcEnv.setupEndPoint("HeartbeatReceiver", new HeartbeatReceiver(taskScheduler, conf)) + private val heartbeatReceiver = env.rpcEnv.setupEndPoint("HeartbeatReceiver", + new HeartbeatReceiver(env.rpcEnv, taskScheduler, conf)) @volatile private[spark] var dagScheduler: DAGScheduler = _ try { diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 9fc76a34f72f..b09ff8d09803 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -28,8 +28,7 @@ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher -import org.apache.spark.rpc.{RpcEnv, RpcEndPointRef, RpcEndPoint} -import org.apache.spark.rpc.akka.AkkaRpcEnv +import org.apache.spark.rpc.{RpcEndPointRef, RpcEndPoint} import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} @@ -42,6 +41,8 @@ private[spark] class CoarseGrainedExecutorBackend( env: SparkEnv) extends RpcEndPoint with ExecutorBackend with Logging { + override val rpcEnv = env.rpcEnv + Utils.checkHostPort(hostPort, "Expected hostport") var executor: Executor = null diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index f44e2facb4c3..80dc6225ad03 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -84,7 +84,7 @@ private[spark] class Executor( // Create an actor for receiving RPCs from the driver private val executorActor = - env.rpcEnv.setupEndPoint("ExecutorActor", new ExecutorActor(executorId)) + env.rpcEnv.setupEndPoint("ExecutorActor", new ExecutorActor(env.rpcEnv, executorId)) // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala index 8f10dd15a7fe..93756bf82a94 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala @@ -18,7 +18,7 @@ package org.apache.spark.executor import 
org.apache.spark.Logging -import org.apache.spark.rpc.{RpcEndPointRef, RpcEndPoint} +import org.apache.spark.rpc.{RpcEnv, RpcEndPointRef, RpcEndPoint} import org.apache.spark.util.Utils /** @@ -30,7 +30,8 @@ private[spark] case object TriggerThreadDump * Actor that runs inside of executors to enable driver -> executor RPC. */ private[spark] -class ExecutorActor(executorId: String) extends RpcEndPoint with Logging { +class ExecutorActor(override val rpcEnv: RpcEnv, executorId: String) + extends RpcEndPoint with Logging { override def receive(sender: RpcEndPointRef) = { case TriggerThreadDump => diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index e8e4e5e2f778..f6f49266f430 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -17,19 +17,45 @@ package org.apache.spark.rpc +import java.util.concurrent.ConcurrentHashMap + import org.slf4j.Logger /** * An RPC environment. */ trait RpcEnv { - def setupEndPoint(name: String, endpoint: RpcEndPoint): RpcEndPointRef + + /** + * Need this map to set up the `sender` for the send method. + */ + private val endPointToRef = new ConcurrentHashMap[RpcEndPoint, RpcEndPointRef]() + + /** + * Need this map to remove `RpcEndPoint` from `endPointToRef` via a `RpcEndPointRef` + */ + private val refToEndPoint = new ConcurrentHashMap[RpcEndPointRef, RpcEndPoint]() + + + protected def registerEndPoint(endPoint: RpcEndPoint, endPointRef: RpcEndPointRef): Unit = { + refToEndPoint.put(endPointRef, endPoint) + endPointToRef.put(endPoint, endPointRef) + } + + protected def unregisterEndPoint(endPointRef: RpcEndPointRef): Unit = { + val endPoint = refToEndPoint.remove(endPointRef) + endPointToRef.remove(endPoint) + } + + def endPointRef(endPoint: RpcEndPoint): RpcEndPointRef = endPointToRef.get(endPoint) + + def setupEndPoint(name: String, endPoint: RpcEndPoint): RpcEndPointRef def setupDriverEndPointRef(name: String): RpcEndPointRef def setupEndPointRefByUrl(url: String): RpcEndPointRef - def stop(endpoint: RpcEndPointRef): Unit + def stop(endPoint: RpcEndPointRef): Unit def stopAll(): Unit } @@ -40,6 +66,13 @@ trait RpcEnv { */ trait RpcEndPoint { + val rpcEnv: RpcEnv + + /** + * Provide the implicit sender. + */ + implicit final def self: RpcEndPointRef = rpcEnv.endPointRef(this) + def receive(sender: RpcEndPointRef): PartialFunction[Any, Unit] def remoteConnectionTerminated(remoteAddress: String): Unit = { @@ -49,9 +82,17 @@ trait RpcEndPoint { protected def log: Logger private[rpc] def logMessage = log + + final def stop(): Unit = { + rpcEnv.stop(self) + } } +object RpcEndPoint { + final val noSender: RpcEndPointRef = null +} + /** * A reference for a remote [[RpcEndPoint]]. */ @@ -64,5 +105,5 @@ trait RpcEndPointRef { /** * Send a message to the remote end point asynchronously. No delivery guarantee is provided. 
*/ - def send(message: Any): Unit -} + def send(message: Any)(implicit sender: RpcEndPointRef = RpcEndPoint.noSender): Unit +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 667011ba9852..2e7b3c7e08e2 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -30,7 +30,7 @@ import org.apache.spark.util.AkkaUtils class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { - override def setupEndPoint(name: String, endpoint: RpcEndPoint): RpcEndPointRef = { + override def setupEndPoint(name: String, endPoint: RpcEndPoint): RpcEndPointRef = { val actorRef = actorSystem.actorOf(Props(new Actor { override def preStart(): Unit = { // Listen for remote client disconnection events, since they don't go through Akka's watch() @@ -39,17 +39,19 @@ class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { override def receive: Receive = { case DisassociatedEvent(_, remoteAddress, _) => - endpoint.remoteConnectionTerminated(remoteAddress.toString) + endPoint.remoteConnectionTerminated(remoteAddress.toString) case message: Any => - endpoint.logMessage.trace("Received RPC message: " + message) - val pf = endpoint.receive(new AkkaRpcEndPointRef(sender(), conf)) + endPoint.logMessage.trace("Received RPC message: " + message) + val pf = endPoint.receive(new AkkaRpcEndPointRef(sender(), conf)) if (pf.isDefinedAt(message)) { pf.apply(message) } } }), name = name) - new AkkaRpcEndPointRef(actorRef, conf) + val endPointRef = new AkkaRpcEndPointRef(actorRef, conf) + registerEndPoint(endPoint, endPointRef) + endPointRef } override def setupDriverEndPointRef(name: String): RpcEndPointRef = { @@ -66,14 +68,14 @@ class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { // Do nothing since actorSystem was created outside. } - override def stop(endpoint: RpcEndPointRef): Unit = { - require(endpoint.isInstanceOf[AkkaRpcEndPointRef]) - actorSystem.stop(endpoint.asInstanceOf[AkkaRpcEndPointRef].actorRef) + override def stop(endPoint: RpcEndPointRef): Unit = { + require(endPoint.isInstanceOf[AkkaRpcEndPointRef]) + unregisterEndPoint(endPoint) + actorSystem.stop(endPoint.asInstanceOf[AkkaRpcEndPointRef].actorRef) } } - -class AkkaRpcEndPointRef(private[rpc] val actorRef: ActorRef, conf: SparkConf) +private[akka] class AkkaRpcEndPointRef(val actorRef: ActorRef, conf: SparkConf) extends RpcEndPointRef with Serializable with Logging { private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3) @@ -108,7 +110,13 @@ class AkkaRpcEndPointRef(private[rpc] val actorRef: ActorRef, conf: SparkConf) "Error sending message [message = " + message + "]", lastException) } - override def send(message: Any): Unit = { + override def send(message: Any)(implicit sender: RpcEndPointRef = RpcEndPoint.noSender): Unit = { + implicit val actorSender: ActorRef = + if (sender == null) { + Actor.noSender + } else { + sender.asInstanceOf[AkkaRpcEndPointRef].actorRef + } actorRef ! 
message } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 51ff16547940..d47704ecbd74 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -17,8 +17,103 @@ package org.apache.spark.rpc -import org.scalatest.FunSuite +import org.apache.spark.Logging +import org.scalatest.{BeforeAndAfterAll, FunSuite} -abstract class RpcEnvSuite extends FunSuite { +abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { + var env: RpcEnv = _ + + override def beforeAll(): Unit = { + env = createRpcEnv + } + + override def afterAll(): Unit = { + if(env != null) { + destroyRpcEnv(env) + } + } + + def createRpcEnv: RpcEnv + + def destroyRpcEnv(rpcEnv: RpcEnv) + + test("send a message locally") { + @volatile var message: String = null + val rpcEndPointRef = env.setupEndPoint("send_test", new RpcEndPoint with Logging { + override val rpcEnv = env + + override def receive(sender: RpcEndPointRef) = { + case msg: String => message = msg + } + }) + rpcEndPointRef.send("hello") + Thread.sleep(2000) + assert("hello" === message) + } + + test("ask a message locally") { + val rpcEndPointRef = env.setupEndPoint("ask_test", new RpcEndPoint with Logging { + override val rpcEnv = env + + override def receive(sender: RpcEndPointRef) = { + case msg: String => sender.send(msg) + } + }) + val reply = rpcEndPointRef.askWithReply[String]("hello") + assert("hello" === reply) + } + + test("ping pong") { + case object Start + + case class Ping(id: Int) + + case class Pong(id: Int) + + val pongRef = env.setupEndPoint("pong", new RpcEndPoint with Logging { + override val rpcEnv = env + + override def receive(sender: RpcEndPointRef) = { + case Ping(id) => sender.send(Pong(id)) + } + }) + + val pingRef = env.setupEndPoint("ping", new RpcEndPoint with Logging { + override val rpcEnv = env + + var requester: RpcEndPointRef = _ + + override def receive(sender: RpcEndPointRef) = { + case Start => { + requester = sender + pongRef.send(Ping(1)) + } + case p @ Pong(id) => { + if (id < 10) { + sender.send(Ping(id + 1)) + } else { + requester.send(p) + } + } + } + }) + + val reply = pingRef.askWithReply[Pong](Start) + assert(Pong(10) === reply) + } + + test("register and unregister") { + val endPoint = new RpcEndPoint with Logging { + override val rpcEnv = env + + override def receive(sender: RpcEndPointRef) = { + case msg: String => sender.send(msg) + } + } + val rpcEndPointRef = env.setupEndPoint("register_test", endPoint) + assert(rpcEndPointRef eq env.endPointRef(endPoint)) + endPoint.stop() + assert(null == env.endPointRef(endPoint)) + } } diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala index a475f5481896..462810e3ec74 100644 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -17,8 +17,24 @@ package org.apache.spark.rpc.akka -import org.apache.spark.rpc.RpcEnvSuite +import akka.actor.ActorSystem -class AkkaRpcEnvSuite extends RpcEnvSuite{ +import org.apache.spark.rpc.{RpcEnv, RpcEnvSuite} +import org.apache.spark.SparkConf +class AkkaRpcEnvSuite extends RpcEnvSuite { + + var akkaSystem: ActorSystem = _ + + override def createRpcEnv: RpcEnv = { + val conf = new SparkConf() + akkaSystem = ActorSystem("test") + new AkkaRpcEnv(akkaSystem, conf) + } + + override 
def destroyRpcEnv(rpcEnv: RpcEnv): Unit = { + if (akkaSystem != null) { + akkaSystem.shutdown() + } + } } From 1fc4a0176021e1397b0e3639f9d97f3f236a5329 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 25 Dec 2014 17:15:08 +0800 Subject: [PATCH 07/36] endPoint => endpoint --- .../org/apache/spark/HeartbeatReceiver.scala | 6 +-- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../CoarseGrainedExecutorBackend.scala | 12 ++--- .../org/apache/spark/executor/Executor.scala | 4 +- .../apache/spark/executor/ExecutorActor.scala | 6 +-- .../scala/org/apache/spark/rpc/RpcEnv.scala | 46 +++++++++---------- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 38 +++++++-------- .../cluster/CoarseGrainedClusterMessage.scala | 4 +- .../scheduler/cluster/ExecutorData.scala | 4 +- .../org/apache/spark/rpc/RpcEnvSuite.scala | 34 +++++++------- 10 files changed, 78 insertions(+), 78 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 2eb2a69e4adc..32feab48245e 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -18,7 +18,7 @@ package org.apache.spark import org.apache.spark.executor.TaskMetrics -import org.apache.spark.rpc.{RpcEnv, RpcEndPointRef, RpcEndPoint} +import org.apache.spark.rpc.{RpcEnv, RpcEndpointRef, RpcEndpoint} import org.apache.spark.scheduler.TaskScheduler import org.apache.spark.storage.BlockManagerId @@ -37,9 +37,9 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) * Lives in the driver to receive heartbeats from executors.. */ private[spark] class HeartbeatReceiver(override val rpcEnv: RpcEnv, - scheduler: TaskScheduler, conf: SparkConf) extends RpcEndPoint with Logging { + scheduler: TaskScheduler, conf: SparkConf) extends RpcEndpoint with Logging { - override def receive(sender: RpcEndPointRef) = { + override def receive(sender: RpcEndpointRef) = { case Heartbeat(executorId, taskMetrics, blockManagerId) => val response = HeartbeatResponse( !scheduler.executorHeartbeatReceived(executorId, taskMetrics, blockManagerId)) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 23b1fce949fa..71b09f903b7d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -324,7 +324,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] var (schedulerBackend, taskScheduler) = SparkContext.createTaskScheduler(this, master) - private val heartbeatReceiver = env.rpcEnv.setupEndPoint("HeartbeatReceiver", + private val heartbeatReceiver = env.rpcEnv.setupEndpoint("HeartbeatReceiver", new HeartbeatReceiver(env.rpcEnv, taskScheduler, conf)) @volatile private[spark] var dagScheduler: DAGScheduler = _ diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index b09ff8d09803..765b63a37d51 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -28,7 +28,7 @@ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import 
org.apache.spark.deploy.worker.WorkerWatcher -import org.apache.spark.rpc.{RpcEndPointRef, RpcEndPoint} +import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint} import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} @@ -39,21 +39,21 @@ private[spark] class CoarseGrainedExecutorBackend( hostPort: String, cores: Int, env: SparkEnv) - extends RpcEndPoint with ExecutorBackend with Logging { + extends RpcEndpoint with ExecutorBackend with Logging { override val rpcEnv = env.rpcEnv Utils.checkHostPort(hostPort, "Expected hostport") var executor: Executor = null - var driver: RpcEndPointRef = null + var driver: RpcEndpointRef = null override def remoteConnectionTerminated(remoteAddress: String): Unit = { logError(s"Driver $remoteAddress disassociated! Shutting down.") System.exit(1) } - override def receive(sender: RpcEndPointRef) = { + override def receive(sender: RpcEndpointRef) = { case RegisteredExecutor => logInfo("Successfully registered with driver") val (hostname, _) = Utils.parseHostPort(hostPort) @@ -134,11 +134,11 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Start the CoarseGrainedExecutorBackend RPC end point. val sparkHostPort = hostname + ":" + boundPort val rpc = new CoarseGrainedExecutorBackend(driverUrl, executorId, sparkHostPort, cores, env) - val rpcRef = env.rpcEnv.setupEndPoint("Executor", rpc) + val rpcRef = env.rpcEnv.setupEndpoint("Executor", rpc) // Register this executor with the driver. logInfo("Connecting to driver: " + driverUrl) - val driverRef = env.rpcEnv.setupEndPointRefByUrl(driverUrl) + val driverRef = env.rpcEnv.setupEndpointRefByUrl(driverUrl) driverRef.send(RegisterExecutor(executorId, sparkHostPort, cores, rpcRef)) workerUrl.foreach { url => diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 80dc6225ad03..1b4d2521ed26 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -84,7 +84,7 @@ private[spark] class Executor( // Create an actor for receiving RPCs from the driver private val executorActor = - env.rpcEnv.setupEndPoint("ExecutorActor", new ExecutorActor(env.rpcEnv, executorId)) + env.rpcEnv.setupEndpoint("ExecutorActor", new ExecutorActor(env.rpcEnv, executorId)) // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager @@ -354,7 +354,7 @@ private[spark] class Executor( def startDriverHeartbeater() { val interval = conf.getInt("spark.executor.heartbeatInterval", 3000) - val heartbeatReceiverRef = env.rpcEnv.setupDriverEndPointRef("HeartbeatReceiver") + val heartbeatReceiverRef = env.rpcEnv.setupDriverEndpointRef("HeartbeatReceiver") val t = new Thread() { override def run() { diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala index 93756bf82a94..b81e647d762f 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala @@ -18,7 +18,7 @@ package org.apache.spark.executor import org.apache.spark.Logging -import org.apache.spark.rpc.{RpcEnv, RpcEndPointRef, RpcEndPoint} +import org.apache.spark.rpc.{RpcEnv, RpcEndpointRef, RpcEndpoint} import org.apache.spark.util.Utils /** @@ -31,9 +31,9 @@ 
private[spark] case object TriggerThreadDump */ private[spark] class ExecutorActor(override val rpcEnv: RpcEnv, executorId: String) - extends RpcEndPoint with Logging { + extends RpcEndpoint with Logging { - override def receive(sender: RpcEndPointRef) = { + override def receive(sender: RpcEndpointRef) = { case TriggerThreadDump => sender.send(Utils.getThreadDump()) } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index f6f49266f430..3692c14cf012 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -29,33 +29,33 @@ trait RpcEnv { /** * Need this map to set up the `sender` for the send method. */ - private val endPointToRef = new ConcurrentHashMap[RpcEndPoint, RpcEndPointRef]() + private val endpointToRef = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]() /** - * Need this map to remove `RpcEndPoint` from `endPointToRef` via a `RpcEndPointRef` + * Need this map to remove `RpcEndpoint` from `endpointToRef` via a `RpcEndpointRef` */ - private val refToEndPoint = new ConcurrentHashMap[RpcEndPointRef, RpcEndPoint]() + private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]() - protected def registerEndPoint(endPoint: RpcEndPoint, endPointRef: RpcEndPointRef): Unit = { - refToEndPoint.put(endPointRef, endPoint) - endPointToRef.put(endPoint, endPointRef) + protected def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = { + refToEndpoint.put(endpointRef, endpoint) + endpointToRef.put(endpoint, endpointRef) } - protected def unregisterEndPoint(endPointRef: RpcEndPointRef): Unit = { - val endPoint = refToEndPoint.remove(endPointRef) - endPointToRef.remove(endPoint) + protected def unregisterEndpoint(endpointRef: RpcEndpointRef): Unit = { + val endpoint = refToEndpoint.remove(endpointRef) + endpointToRef.remove(endpoint) } - def endPointRef(endPoint: RpcEndPoint): RpcEndPointRef = endPointToRef.get(endPoint) + def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointToRef.get(endpoint) - def setupEndPoint(name: String, endPoint: RpcEndPoint): RpcEndPointRef + def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef - def setupDriverEndPointRef(name: String): RpcEndPointRef + def setupDriverEndpointRef(name: String): RpcEndpointRef - def setupEndPointRefByUrl(url: String): RpcEndPointRef + def setupEndpointRefByUrl(url: String): RpcEndpointRef - def stop(endPoint: RpcEndPointRef): Unit + def stop(endpoint: RpcEndpointRef): Unit def stopAll(): Unit } @@ -64,16 +64,16 @@ trait RpcEnv { /** * An end point for the RPC that defines what functions to trigger given a message. */ -trait RpcEndPoint { +trait RpcEndpoint { val rpcEnv: RpcEnv /** * Provide the implicit sender. */ - implicit final def self: RpcEndPointRef = rpcEnv.endPointRef(this) + implicit final def self: RpcEndpointRef = rpcEnv.endpointRef(this) - def receive(sender: RpcEndPointRef): PartialFunction[Any, Unit] + def receive(sender: RpcEndpointRef): PartialFunction[Any, Unit] def remoteConnectionTerminated(remoteAddress: String): Unit = { // By default, do nothing. @@ -89,21 +89,21 @@ trait RpcEndPoint { } -object RpcEndPoint { - final val noSender: RpcEndPointRef = null +object RpcEndpoint { + final val noSender: RpcEndpointRef = null } /** - * A reference for a remote [[RpcEndPoint]]. + * A reference for a remote [[RpcEndpoint]]. 
*/ -trait RpcEndPointRef { +trait RpcEndpointRef { def address: String def askWithReply[T](message: Any): T /** - * Send a message to the remote end point asynchronously. No delivery guarantee is provided. + * Send a message to the remote endpoint asynchronously. No delivery guarantee is provided. */ - def send(message: Any)(implicit sender: RpcEndPointRef = RpcEndPoint.noSender): Unit + def send(message: Any)(implicit sender: RpcEndpointRef = RpcEndpoint.noSender): Unit } \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 2e7b3c7e08e2..798065e0d418 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -30,7 +30,7 @@ import org.apache.spark.util.AkkaUtils class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { - override def setupEndPoint(name: String, endPoint: RpcEndPoint): RpcEndPointRef = { + override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { val actorRef = actorSystem.actorOf(Props(new Actor { override def preStart(): Unit = { // Listen for remote client disconnection events, since they don't go through Akka's watch() @@ -39,44 +39,44 @@ class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { override def receive: Receive = { case DisassociatedEvent(_, remoteAddress, _) => - endPoint.remoteConnectionTerminated(remoteAddress.toString) + endpoint.remoteConnectionTerminated(remoteAddress.toString) case message: Any => - endPoint.logMessage.trace("Received RPC message: " + message) - val pf = endPoint.receive(new AkkaRpcEndPointRef(sender(), conf)) + endpoint.logMessage.trace("Received RPC message: " + message) + val pf = endpoint.receive(new AkkaRpcEndpointRef(sender(), conf)) if (pf.isDefinedAt(message)) { pf.apply(message) } } }), name = name) - val endPointRef = new AkkaRpcEndPointRef(actorRef, conf) - registerEndPoint(endPoint, endPointRef) - endPointRef + val endpointRef = new AkkaRpcEndpointRef(actorRef, conf) + registerEndpoint(endpoint, endpointRef) + endpointRef } - override def setupDriverEndPointRef(name: String): RpcEndPointRef = { - new AkkaRpcEndPointRef(AkkaUtils.makeDriverRef(name, conf, actorSystem), conf) + override def setupDriverEndpointRef(name: String): RpcEndpointRef = { + new AkkaRpcEndpointRef(AkkaUtils.makeDriverRef(name, conf, actorSystem), conf) } - override def setupEndPointRefByUrl(url: String): RpcEndPointRef = { + override def setupEndpointRefByUrl(url: String): RpcEndpointRef = { val timeout = Duration.create(conf.getLong("spark.akka.lookupTimeout", 30), "seconds") val ref = Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) - new AkkaRpcEndPointRef(ref, conf) + new AkkaRpcEndpointRef(ref, conf) } override def stopAll(): Unit = { // Do nothing since actorSystem was created outside. 
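// Illustrative usage sketch of the renamed RpcEndpoint / RpcEndpointRef API above (not part of
// this patch): "echo" and EchoEndpoint are hypothetical names. Logging is mixed in because, at
// this point in the series, RpcEndpoint still declares an abstract `log` member.
import org.apache.spark.Logging
import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}

class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint with Logging {
  override def receive(sender: RpcEndpointRef) = {
    // Reply to whoever sent the message; inside an endpoint, `self` is supplied as the
    // implicit sender of this send(...) call.
    case msg: String => sender.send(msg)
  }
}

def demo(rpcEnv: RpcEnv): String = {
  val echoRef = rpcEnv.setupEndpoint("echo", new EchoEndpoint(rpcEnv))
  echoRef.send("fire-and-forget")           // asynchronous, no delivery guarantee
  echoRef.askWithReply[String]("ping")      // blocking request/reply, retried on failure
}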
} - override def stop(endPoint: RpcEndPointRef): Unit = { - require(endPoint.isInstanceOf[AkkaRpcEndPointRef]) - unregisterEndPoint(endPoint) - actorSystem.stop(endPoint.asInstanceOf[AkkaRpcEndPointRef].actorRef) + override def stop(endpoint: RpcEndpointRef): Unit = { + require(endpoint.isInstanceOf[AkkaRpcEndpointRef]) + unregisterEndpoint(endpoint) + actorSystem.stop(endpoint.asInstanceOf[AkkaRpcEndpointRef].actorRef) } } -private[akka] class AkkaRpcEndPointRef(val actorRef: ActorRef, conf: SparkConf) - extends RpcEndPointRef with Serializable with Logging { +private[akka] class AkkaRpcEndpointRef(val actorRef: ActorRef, conf: SparkConf) + extends RpcEndpointRef with Serializable with Logging { private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3) private[this] val retryWaitMs = conf.getInt("spark.akka.retry.wait", 3000) @@ -110,12 +110,12 @@ private[akka] class AkkaRpcEndPointRef(val actorRef: ActorRef, conf: SparkConf) "Error sending message [message = " + message + "]", lastException) } - override def send(message: Any)(implicit sender: RpcEndPointRef = RpcEndPoint.noSender): Unit = { + override def send(message: Any)(implicit sender: RpcEndpointRef = RpcEndpoint.noSender): Unit = { implicit val actorSender: ActorRef = if (sender == null) { Actor.noSender } else { - sender.asInstanceOf[AkkaRpcEndPointRef].actorRef + sender.asInstanceOf[AkkaRpcEndpointRef].actorRef } actorRef ! message } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 7c828bb05ceb..1c79d4d575b2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -20,7 +20,7 @@ package org.apache.spark.scheduler.cluster import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState -import org.apache.spark.rpc.RpcEndPointRef +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.{SerializableBuffer, Utils} @@ -42,7 +42,7 @@ private[spark] object CoarseGrainedClusterMessages { // Executors to driver case class RegisterExecutor( - executorId: String, hostPort: String, cores: Int, rpcRef: RpcEndPointRef) + executorId: String, hostPort: String, cores: Int, rpcRef: RpcEndpointRef) extends CoarseGrainedClusterMessage { Utils.checkHostPort(hostPort, "Expected host port") } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index 630b38a58e72..7045c577443c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster -import org.apache.spark.rpc.RpcEndPointRef +import org.apache.spark.rpc.RpcEndpointRef /** * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. 
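// Hypothetical sketch of the registration path that RegisterExecutor and ExecutorData serve;
// the real CoarseGrainedSchedulerBackend handler is not part of this hunk. The executor sends
// its own RpcEndpointRef in RegisterExecutor, and the driver keeps that ref in ExecutorData so
// it can later push LaunchTask / KillTask / StopExecutor back with a plain send(...).
// `executorDataMap` is assumed to be the backend's mutable registry of live executors.
val registerExecutor: PartialFunction[Any, Unit] = {
  case RegisterExecutor(executorId, hostPort, cores, rpcRef) =>
    val executorHost = Utils.parseHostPort(hostPort)._1
    executorDataMap(executorId) =
      new ExecutorData(rpcRef, rpcRef.address, executorHost, cores, cores)
    rpcRef.send(RegisteredExecutor)   // the executor creates its Executor instance on receipt
}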
@@ -29,7 +29,7 @@ import org.apache.spark.rpc.RpcEndPointRef * @param totalCores The total number of cores available to the executor */ private[cluster] class ExecutorData( - val executorActor: RpcEndPointRef, + val executorActor: RpcEndpointRef, val executorAddress: String, val executorHost: String , var freeCores: Int, diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index d47704ecbd74..1edd878cf94a 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -40,27 +40,27 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { test("send a message locally") { @volatile var message: String = null - val rpcEndPointRef = env.setupEndPoint("send_test", new RpcEndPoint with Logging { + val rpcEndpointRef = env.setupEndpoint("send_test", new RpcEndpoint with Logging { override val rpcEnv = env - override def receive(sender: RpcEndPointRef) = { + override def receive(sender: RpcEndpointRef) = { case msg: String => message = msg } }) - rpcEndPointRef.send("hello") + rpcEndpointRef.send("hello") Thread.sleep(2000) assert("hello" === message) } test("ask a message locally") { - val rpcEndPointRef = env.setupEndPoint("ask_test", new RpcEndPoint with Logging { + val rpcEndpointRef = env.setupEndpoint("ask_test", new RpcEndpoint with Logging { override val rpcEnv = env - override def receive(sender: RpcEndPointRef) = { + override def receive(sender: RpcEndpointRef) = { case msg: String => sender.send(msg) } }) - val reply = rpcEndPointRef.askWithReply[String]("hello") + val reply = rpcEndpointRef.askWithReply[String]("hello") assert("hello" === reply) } @@ -71,20 +71,20 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { case class Pong(id: Int) - val pongRef = env.setupEndPoint("pong", new RpcEndPoint with Logging { + val pongRef = env.setupEndpoint("pong", new RpcEndpoint with Logging { override val rpcEnv = env - override def receive(sender: RpcEndPointRef) = { + override def receive(sender: RpcEndpointRef) = { case Ping(id) => sender.send(Pong(id)) } }) - val pingRef = env.setupEndPoint("ping", new RpcEndPoint with Logging { + val pingRef = env.setupEndpoint("ping", new RpcEndpoint with Logging { override val rpcEnv = env - var requester: RpcEndPointRef = _ + var requester: RpcEndpointRef = _ - override def receive(sender: RpcEndPointRef) = { + override def receive(sender: RpcEndpointRef) = { case Start => { requester = sender pongRef.send(Ping(1)) @@ -104,16 +104,16 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } test("register and unregister") { - val endPoint = new RpcEndPoint with Logging { + val endpoint = new RpcEndpoint with Logging { override val rpcEnv = env - override def receive(sender: RpcEndPointRef) = { + override def receive(sender: RpcEndpointRef) = { case msg: String => sender.send(msg) } } - val rpcEndPointRef = env.setupEndPoint("register_test", endPoint) - assert(rpcEndPointRef eq env.endPointRef(endPoint)) - endPoint.stop() - assert(null == env.endPointRef(endPoint)) + val rpcEndpointRef = env.setupEndpoint("register_test", endpoint) + assert(rpcEndpointRef eq env.endpointRef(endpoint)) + endpoint.stop() + assert(null == env.endpointRef(endpoint)) } } From 0627986d232e0257f5ee12a6b3340af9713aff01 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 26 Dec 2014 11:28:27 +0800 Subject: [PATCH 08/36] Remove log api from RpcEndpoint trait --- 
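With the logging hook removed, an RpcEndpoint no longer has to mix in Logging or provide a
`log` member, so throwaway endpoints (for example in tests) can be plain anonymous instances.
A minimal sketch, assuming an RpcEnv named `env` is already in scope; the endpoint name
"echo_test" is illustrative:

    val ref = env.setupEndpoint("echo_test", new RpcEndpoint {
      override val rpcEnv = env
      override def receive(sender: RpcEndpointRef) = {
        case msg: String => sender.send(msg)
      }
    })
    assert(ref.askWithReply[String]("hi") == "hi")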
core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala | 6 ------ .../scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala | 4 ++-- .../test/scala/org/apache/spark/rpc/RpcEnvSuite.scala | 11 +++++------ 3 files changed, 7 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 3692c14cf012..31029462003d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -19,8 +19,6 @@ package org.apache.spark.rpc import java.util.concurrent.ConcurrentHashMap -import org.slf4j.Logger - /** * An RPC environment. */ @@ -79,10 +77,6 @@ trait RpcEndpoint { // By default, do nothing. } - protected def log: Logger - - private[rpc] def logMessage = log - final def stop(): Unit = { rpcEnv.stop(self) } diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 798065e0d418..54f1df6e3f05 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.AkkaUtils class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { - val actorRef = actorSystem.actorOf(Props(new Actor { + val actorRef = actorSystem.actorOf(Props(new Actor with Logging { override def preStart(): Unit = { // Listen for remote client disconnection events, since they don't go through Akka's watch() context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) @@ -42,7 +42,7 @@ class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { endpoint.remoteConnectionTerminated(remoteAddress.toString) case message: Any => - endpoint.logMessage.trace("Received RPC message: " + message) + logTrace("Received RPC message: " + message) val pf = endpoint.receive(new AkkaRpcEndpointRef(sender(), conf)) if (pf.isDefinedAt(message)) { pf.apply(message) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 1edd878cf94a..9c0cd824307e 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.rpc -import org.apache.spark.Logging import org.scalatest.{BeforeAndAfterAll, FunSuite} abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { @@ -40,7 +39,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { test("send a message locally") { @volatile var message: String = null - val rpcEndpointRef = env.setupEndpoint("send_test", new RpcEndpoint with Logging { + val rpcEndpointRef = env.setupEndpoint("send_test", new RpcEndpoint { override val rpcEnv = env override def receive(sender: RpcEndpointRef) = { @@ -53,7 +52,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } test("ask a message locally") { - val rpcEndpointRef = env.setupEndpoint("ask_test", new RpcEndpoint with Logging { + val rpcEndpointRef = env.setupEndpoint("ask_test", new RpcEndpoint { override val rpcEnv = env override def receive(sender: RpcEndpointRef) = { @@ -71,7 +70,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { case class Pong(id: Int) - val pongRef = env.setupEndpoint("pong", new RpcEndpoint with Logging 
{ + val pongRef = env.setupEndpoint("pong", new RpcEndpoint { override val rpcEnv = env override def receive(sender: RpcEndpointRef) = { @@ -79,7 +78,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } }) - val pingRef = env.setupEndpoint("ping", new RpcEndpoint with Logging { + val pingRef = env.setupEndpoint("ping", new RpcEndpoint { override val rpcEnv = env var requester: RpcEndpointRef = _ @@ -104,7 +103,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } test("register and unregister") { - val endpoint = new RpcEndpoint with Logging { + val endpoint = new RpcEndpoint { override val rpcEnv = env override def receive(sender: RpcEndpointRef) = { From ee988d325ed726886a785ceef4862cf95847f688 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 26 Dec 2014 11:37:28 +0800 Subject: [PATCH 09/36] Throw an exception if RpcEndpoint is not in RpcEnv --- core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala | 6 +++++- .../main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala | 2 ++ core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala | 6 +++++- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 31029462003d..06b9149bff1b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -45,7 +45,11 @@ trait RpcEnv { endpointToRef.remove(endpoint) } - def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointToRef.get(endpoint) + def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = { + val endpointRef = endpointToRef.get(endpoint) + require(endpointRef != null, s"Cannot find RpcEndpointRef of ${endpoint} in ${this}") + endpointRef + } def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 54f1df6e3f05..2c4c1b5a1b95 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -73,6 +73,8 @@ class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { unregisterEndpoint(endpoint) actorSystem.stop(endpoint.asInstanceOf[AkkaRpcEndpointRef].actorRef) } + + override def toString = s"${getClass.getSimpleName}($actorSystem)" } private[akka] class AkkaRpcEndpointRef(val actorRef: ActorRef, conf: SparkConf) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 9c0cd824307e..165bad4f5e5b 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -113,6 +113,10 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val rpcEndpointRef = env.setupEndpoint("register_test", endpoint) assert(rpcEndpointRef eq env.endpointRef(endpoint)) endpoint.stop() - assert(null == env.endpointRef(endpoint)) + + val e = intercept[IllegalArgumentException] { + env.endpointRef(endpoint) + } + assert(e.getMessage.contains("Cannot find RpcEndpointRef")) } } From 2e99b4a7cc688119449ce523c90e9ce3c702738f Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 26 Dec 2014 18:48:07 +0800 Subject: [PATCH 10/36] Add tests for HeartbeatReceiver --- .../org/apache/spark/HeartbeatReceiver.scala | 4 +- .../scala/org/apache/spark/SparkContext.scala | 2 +- 
.../scala/org/apache/spark/rpc/RpcEnv.scala | 13 +++- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 5 +- .../CoarseGrainedSchedulerBackend.scala | 6 +- .../apache/spark/HeartbeatReceiverSuite.scala | 68 +++++++++++++++++++ 6 files changed, 88 insertions(+), 10 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 32feab48245e..da7bec3e5c7e 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -36,8 +36,8 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** * Lives in the driver to receive heartbeats from executors.. */ -private[spark] class HeartbeatReceiver(override val rpcEnv: RpcEnv, - scheduler: TaskScheduler, conf: SparkConf) extends RpcEndpoint with Logging { +private[spark] class HeartbeatReceiver(override val rpcEnv: RpcEnv, scheduler: TaskScheduler) + extends RpcEndpoint { override def receive(sender: RpcEndpointRef) = { case Heartbeat(executorId, taskMetrics, blockManagerId) => diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 71b09f903b7d..61ebecc4d4da 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -325,7 +325,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli SparkContext.createTaskScheduler(this, master) private val heartbeatReceiver = env.rpcEnv.setupEndpoint("HeartbeatReceiver", - new HeartbeatReceiver(env.rpcEnv, taskScheduler, conf)) + new HeartbeatReceiver(env.rpcEnv, taskScheduler)) @volatile private[spark] var dagScheduler: DAGScheduler = _ try { diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 06b9149bff1b..8d00707e1844 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -42,7 +42,9 @@ trait RpcEnv { protected def unregisterEndpoint(endpointRef: RpcEndpointRef): Unit = { val endpoint = refToEndpoint.remove(endpointRef) - endpointToRef.remove(endpoint) + if (endpoint != null) { + endpointToRef.remove(endpoint) + } } def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = { @@ -75,6 +77,13 @@ trait RpcEndpoint { */ implicit final def self: RpcEndpointRef = rpcEnv.endpointRef(this) + /** + * Same assumption like Actor: messages sent to a RpcEndpoint will be delivered in sequence, and + * messages from the same RpcEndpoint will be delivered in order. + * + * @param sender + * @return + */ def receive(sender: RpcEndpointRef): PartialFunction[Any, Unit] def remoteConnectionTerminated(remoteAddress: String): Unit = { @@ -104,4 +113,4 @@ trait RpcEndpointRef { * Send a message to the remote endpoint asynchronously. No delivery guarantee is provided. 
*/ def send(message: Any)(implicit sender: RpcEndpointRef = RpcEndpoint.noSender): Unit -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 2c4c1b5a1b95..66e0c88dd3a3 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -32,6 +32,7 @@ class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { val actorRef = actorSystem.actorOf(Props(new Actor with Logging { + override def preStart(): Unit = { // Listen for remote client disconnection events, since they don't go through Akka's watch() context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) @@ -40,9 +41,8 @@ class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { override def receive: Receive = { case DisassociatedEvent(_, remoteAddress, _) => endpoint.remoteConnectionTerminated(remoteAddress.toString) - case message: Any => - logTrace("Received RPC message: " + message) + logInfo("Received RPC message: " + message) val pf = endpoint.receive(new AkkaRpcEndpointRef(sender(), conf)) if (pf.isDefinedAt(message)) { pf.apply(message) @@ -117,6 +117,7 @@ private[akka] class AkkaRpcEndpointRef(val actorRef: ActorRef, conf: SparkConf) if (sender == null) { Actor.noSender } else { + require(sender.isInstanceOf[AkkaRpcEndpointRef]) sender.asInstanceOf[AkkaRpcEndpointRef].actorRef } actorRef ! message diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 1956ebbf8181..996ce4e166d9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -150,9 +150,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste removeExecutor(executorId, reason) sender ! true -// case DisassociatedEvent(_, address, _) => -// addressToExecutorId.get(address).foreach(removeExecutor(_, -// "remote Akka client disassociated")) + case DisassociatedEvent(_, address, _) => + addressToExecutorId.get(address.toString).foreach(removeExecutor(_, + "remote Akka client disassociated")) case RetrieveSparkProps => println("sending the other side properties RetrieveSparkProps") diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala new file mode 100644 index 000000000000..0a8c1c5ea129 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.BlockManagerId +import org.scalatest.FunSuite +import org.mockito.Mockito._ +import org.mockito.Matchers +import org.mockito.Matchers._ + +import org.apache.spark.scheduler.TaskScheduler + +class HeartbeatReceiverSuite extends FunSuite with LocalSparkContext { + + test("HeartbeatReceiver") { + sc = new SparkContext("local[2]", "test") + val scheduler = mock(classOf[TaskScheduler]) + when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true) + + sc.env.rpcEnv.setupEndpoint("heartbeat", new HeartbeatReceiver(sc.env.rpcEnv, scheduler)) + val receiverRef = sc.env.rpcEnv.setupDriverEndpointRef("heartbeat") + + val metrics = new TaskMetrics + metrics.jvmGCTime = 100 + val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) + val response = receiverRef.askWithReply[HeartbeatResponse]( + Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + + verify(scheduler).executorHeartbeatReceived( + Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + assert(false === response.reregisterBlockManager) + } + + test("HeartbeatReceiver re-register") { + sc = new SparkContext("local[2]", "test") + val scheduler = mock(classOf[TaskScheduler]) + when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(false) + + sc.env.rpcEnv.setupEndpoint("heartbeat", new HeartbeatReceiver(sc.env.rpcEnv, scheduler)) + val receiverRef = sc.env.rpcEnv.setupDriverEndpointRef("heartbeat") + + val metrics = new TaskMetrics + metrics.jvmGCTime = 100 + val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) + val response = receiverRef.askWithReply[HeartbeatResponse]( + Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + + verify(scheduler).executorHeartbeatReceived( + Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + assert(true === response.reregisterBlockManager) + } +} From 73db9e56fe86ac1a3ecbb7d16a80c1a3dd5b5011 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 26 Dec 2014 19:13:53 +0800 Subject: [PATCH 11/36] Minor changes --- .../spark/executor/CoarseGrainedExecutorBackend.scala | 4 ++-- .../scala/org/apache/spark/executor/ExecutorActor.scala | 4 +--- .../main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala | 6 +++++- .../scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 6 +++--- .../org/apache/spark/scheduler/cluster/ExecutorData.scala | 4 ++-- 5 files changed, 13 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 765b63a37d51..4dc72ca48659 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -84,9 +84,9 @@ private[spark] class CoarseGrainedExecutorBackend( case StopExecutor => logInfo("Driver commanded a shutdown") - executor.stop() - 
//rpcEnv.stop(this) // TODO(rxin): Stop this properly. + stop() + rpcEnv.stopAll() } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala index b81e647d762f..5c7d07195808 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala @@ -17,7 +17,6 @@ package org.apache.spark.executor -import org.apache.spark.Logging import org.apache.spark.rpc.{RpcEnv, RpcEndpointRef, RpcEndpoint} import org.apache.spark.util.Utils @@ -30,8 +29,7 @@ private[spark] case object TriggerThreadDump * Actor that runs inside of executors to enable driver -> executor RPC. */ private[spark] -class ExecutorActor(override val rpcEnv: RpcEnv, executorId: String) - extends RpcEndpoint with Logging { +class ExecutorActor(override val rpcEnv: RpcEnv, executorId: String) extends RpcEndpoint { override def receive(sender: RpcEndpointRef) = { case TriggerThreadDump => diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 66e0c88dd3a3..600635943234 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -39,7 +39,11 @@ class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { } override def receive: Receive = { - case DisassociatedEvent(_, remoteAddress, _) => + case message @ DisassociatedEvent(_, remoteAddress, _) => + val pf = endpoint.receive(new AkkaRpcEndpointRef(sender(), conf)) + if (pf.isDefinedAt(message)) { + pf.apply(message) + } endpoint.remoteConnectionTerminated(remoteAddress.toString) case message: Any => logInfo("Received RPC message: " + message) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 996ce4e166d9..c1e32831d632 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -129,7 +129,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste case KillTask(taskId, executorId, interruptThread) => executorDataMap.get(executorId) match { case Some(executorInfo) => - executorInfo.executorActor.send(KillTask(taskId, executorId, interruptThread)) + executorInfo.executorEndpoint.send(KillTask(taskId, executorId, interruptThread)) case None => // Ignoring the task kill since the executor is not registered. logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") @@ -142,7 +142,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste case StopExecutors => logInfo("Asking each executor to shut down") for ((_, executorData) <- executorDataMap) { - executorData.executorActor.send(StopExecutor) + executorData.executorEndpoint.send(StopExecutor) } sender ! 
true @@ -196,7 +196,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste else { val executorData = executorDataMap(task.executorId) executorData.freeCores -= scheduler.CPUS_PER_TASK - executorData.executorActor.send(LaunchTask(new SerializableBuffer(serializedTask))) + executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask))) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index 7045c577443c..bb5122fd6ac8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -22,14 +22,14 @@ import org.apache.spark.rpc.RpcEndpointRef /** * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. * - * @param executorActor The ActorRef representing this executor + * @param executorEndpoint The RpcEndpointRef representing this executor * @param executorAddress The network address of this executor * @param executorHost The hostname that this executor is running on * @param freeCores The current number of cores available for work on the executor * @param totalCores The total number of cores available to the executor */ private[cluster] class ExecutorData( - val executorActor: RpcEndpointRef, + val executorEndpoint: RpcEndpointRef, val executorAddress: String, val executorHost: String , var freeCores: Int, From 19053ca1ff95d6c001befe6107de5e15f96e9f13 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 26 Dec 2014 20:07:30 +0800 Subject: [PATCH 12/36] Add test for ExecutorActor --- .../scala/org/apache/spark/SparkContext.scala | 5 +- .../CoarseGrainedExecutorBackend.scala | 54 +++++++++---------- .../org/apache/spark/executor/Executor.scala | 18 ++++--- .../cluster/CoarseGrainedClusterMessage.scala | 5 +- .../CoarseGrainedSchedulerBackend.scala | 23 ++++---- .../scheduler/cluster/ExecutorData.scala | 8 +-- .../spark/executor/ExecutorActorSuite.scala | 33 ++++++++++++ 7 files changed, 88 insertions(+), 58 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/executor/ExecutorActorSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 61ebecc4d4da..75f793455c70 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -410,9 +410,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli Some(Utils.getThreadDump()) } else { val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get - val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem) - Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef, - AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf))) + val endpointRef = env.rpcEnv.setupDriverEndpointRef("ExecutorActor") + Some(endpointRef.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump)) } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 4dc72ca48659..c794a7bc3599 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ 
b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -23,15 +23,15 @@ import scala.concurrent.Await import akka.actor.{Actor, ActorSelection, Props} import akka.pattern.Patterns +import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent} import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher -import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint} import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} private[spark] class CoarseGrainedExecutorBackend( driverUrl: String, @@ -39,28 +39,28 @@ private[spark] class CoarseGrainedExecutorBackend( hostPort: String, cores: Int, env: SparkEnv) - extends RpcEndpoint with ExecutorBackend with Logging { - - override val rpcEnv = env.rpcEnv + extends Actor with ActorLogReceive with ExecutorBackend with Logging { Utils.checkHostPort(hostPort, "Expected hostport") var executor: Executor = null - var driver: RpcEndpointRef = null + var driver: ActorSelection = null - override def remoteConnectionTerminated(remoteAddress: String): Unit = { - logError(s"Driver $remoteAddress disassociated! Shutting down.") - System.exit(1) + override def preStart() { + logInfo("Connecting to driver: " + driverUrl) + driver = context.actorSelection(driverUrl) + driver ! RegisterExecutor(executorId, hostPort, cores) + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) } - override def receive(sender: RpcEndpointRef) = { + override def receiveWithLogging = { case RegisteredExecutor => logInfo("Successfully registered with driver") val (hostname, _) = Utils.parseHostPort(hostPort) executor = new Executor(executorId, hostname, env, isLocal = false) - case RegisterExecutorFailed(failureMessage) => - logError("Slave registration failed: " + failureMessage) + case RegisterExecutorFailed(message) => + logError("Slave registration failed: " + message) System.exit(1) case LaunchTask(data) => @@ -82,15 +82,19 @@ private[spark] class CoarseGrainedExecutorBackend( executor.killTask(taskId, interruptThread) } + case x: DisassociatedEvent => + logError(s"Driver $x disassociated! Shutting down.") + System.exit(1) + case StopExecutor => logInfo("Driver commanded a shutdown") - // TODO(rxin): Stop this properly. - stop() - rpcEnv.stopAll() + executor.stop() + context.stop(self) + context.system.shutdown() } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { - driver.send(StatusUpdate(executorId, taskId, state, data)) + driver ! 
StatusUpdate(executorId, taskId, state, data) } } @@ -115,9 +119,9 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val port = executorConf.getInt("spark.executor.port", 0) val (fetcher, _) = AkkaUtils.createActorSystem( "driverPropsFetcher", hostname, port, executorConf, new SecurityManager(executorConf)) - val driverActor = fetcher.actorSelection(driverUrl) + val driver = fetcher.actorSelection(driverUrl) val timeout = AkkaUtils.askTimeout(executorConf) - val fut = Patterns.ask(driverActor, RetrieveSparkProps, timeout) + val fut = Patterns.ask(driver, RetrieveSparkProps, timeout) val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] ++ Seq[(String, String)](("spark.app.id", appId)) fetcher.shutdown() @@ -131,16 +135,12 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val boundPort = env.conf.getInt("spark.executor.port", 0) assert(boundPort != 0) - // Start the CoarseGrainedExecutorBackend RPC end point. + // Start the CoarseGrainedExecutorBackend actor. val sparkHostPort = hostname + ":" + boundPort - val rpc = new CoarseGrainedExecutorBackend(driverUrl, executorId, sparkHostPort, cores, env) - val rpcRef = env.rpcEnv.setupEndpoint("Executor", rpc) - - // Register this executor with the driver. - logInfo("Connecting to driver: " + driverUrl) - val driverRef = env.rpcEnv.setupEndpointRefByUrl(driverUrl) - driverRef.send(RegisterExecutor(executorId, sparkHostPort, cores, rpcRef)) - + env.actorSystem.actorOf( + Props(classOf[CoarseGrainedExecutorBackend], + driverUrl, executorId, sparkHostPort, cores, env), + name = "Executor") workerUrl.foreach { url => env.actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher") } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 1b4d2521ed26..928653aa1168 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -22,8 +22,6 @@ import java.lang.management.ManagementFactory import java.nio.ByteBuffer import java.util.concurrent._ -import org.apache.spark.network.rpc.{SimpleRpcClient, SimpleRpcServer} - import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal @@ -83,8 +81,8 @@ private[spark] class Executor( } // Create an actor for receiving RPCs from the driver - private val executorActor = - env.rpcEnv.setupEndpoint("ExecutorActor", new ExecutorActor(env.rpcEnv, executorId)) + private val executorActor = env.rpcEnv.setupEndpoint( + "ExecutorActor", new ExecutorActor(env.rpcEnv, executorId)) // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager @@ -353,8 +351,11 @@ private[spark] class Executor( } def startDriverHeartbeater() { - val interval = conf.getInt("spark.executor.heartbeatInterval", 3000) - val heartbeatReceiverRef = env.rpcEnv.setupDriverEndpointRef("HeartbeatReceiver") + val interval = conf.getInt("spark.executor.heartbeatInterval", 10000) + val timeout = AkkaUtils.lookupTimeout(conf) + val retryAttempts = AkkaUtils.numRetries(conf) + val retryIntervalMs = AkkaUtils.retryWaitMs(conf) + val heartbeatReceiverRef = AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem) val t = new Thread() { override def run() { @@ -368,7 +369,7 @@ private[spark] class Executor( for (taskRunner <- runningTasks.values()) { if (!taskRunner.attemptedTask.isEmpty) { 
Option(taskRunner.task).flatMap(_.metrics).foreach { metrics => - metrics.updateShuffleReadMetrics() + metrics.updateShuffleReadMetrics metrics.jvmGCTime = curGCTime - taskRunner.startGCTime if (isLocal) { // JobProgressListener will hold an reference of it during @@ -386,7 +387,8 @@ private[spark] class Executor( val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) try { - val response = heartbeatReceiverRef.askWithReply[HeartbeatResponse](message) + val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, + retryAttempts, retryIntervalMs, timeout) if (response.reregisterBlockManager) { logWarning("Told to re-register on heartbeat") env.blockManager.reregister() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 1c79d4d575b2..1da6fe976da5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -20,10 +20,8 @@ package org.apache.spark.scheduler.cluster import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState -import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.{SerializableBuffer, Utils} - private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable private[spark] object CoarseGrainedClusterMessages { @@ -41,8 +39,7 @@ private[spark] object CoarseGrainedClusterMessages { case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage // Executors to driver - case class RegisterExecutor( - executorId: String, hostPort: String, cores: Int, rpcRef: RpcEndpointRef) + case class RegisterExecutor(executorId: String, hostPort: String, cores: Int) extends CoarseGrainedClusterMessage { Utils.checkHostPort(hostPort, "Expected host port") } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index c1e32831d632..fe9914b50bc5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -71,7 +71,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive { override protected def log = CoarseGrainedSchedulerBackend.this.log - private val addressToExecutorId = new HashMap[String, String] + private val addressToExecutorId = new HashMap[Address, String] override def preStart() { // Listen for remote client disconnection events, since they don't go through Akka's watch() @@ -84,19 +84,19 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } def receiveWithLogging = { - case RegisterExecutor(executorId, hostPort, cores, executorRpcRef) => + case RegisterExecutor(executorId, hostPort, cores) => Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorDataMap.contains(executorId)) { - executorRpcRef.send(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) + sender ! 
RegisterExecutorFailed("Duplicate executor ID: " + executorId) } else { - logInfo("Registered executor: " + executorRpcRef + " with ID " + executorId) - executorRpcRef.send(RegisteredExecutor) + logInfo("Registered executor: " + sender + " with ID " + executorId) + sender ! RegisteredExecutor - addressToExecutorId(executorRpcRef.address) = executorId + addressToExecutorId(sender.path.address) = executorId totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) val (host, _) = Utils.parseHostPort(hostPort) - val data = new ExecutorData(executorRpcRef, executorRpcRef.address, host, cores, cores) + val data = new ExecutorData(sender, sender.path.address, host, cores, cores) // This must be synchronized because variables mutated // in this block are read when requesting executors CoarseGrainedSchedulerBackend.this.synchronized { @@ -129,7 +129,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste case KillTask(taskId, executorId, interruptThread) => executorDataMap.get(executorId) match { case Some(executorInfo) => - executorInfo.executorEndpoint.send(KillTask(taskId, executorId, interruptThread)) + executorInfo.executorActor ! KillTask(taskId, executorId, interruptThread) case None => // Ignoring the task kill since the executor is not registered. logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") @@ -142,7 +142,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste case StopExecutors => logInfo("Asking each executor to shut down") for ((_, executorData) <- executorDataMap) { - executorData.executorEndpoint.send(StopExecutor) + executorData.executorActor ! StopExecutor } sender ! true @@ -151,11 +151,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste sender ! true case DisassociatedEvent(_, address, _) => - addressToExecutorId.get(address.toString).foreach(removeExecutor(_, + addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disassociated")) case RetrieveSparkProps => - println("sending the other side properties RetrieveSparkProps") sender ! sparkProperties } @@ -196,7 +195,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste else { val executorData = executorDataMap(task.executorId) executorData.freeCores -= scheduler.CPUS_PER_TASK - executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask))) + executorData.executorActor ! LaunchTask(new SerializableBuffer(serializedTask)) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index bb5122fd6ac8..b71bd5783d6d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -17,20 +17,20 @@ package org.apache.spark.scheduler.cluster -import org.apache.spark.rpc.RpcEndpointRef +import akka.actor.{Address, ActorRef} /** * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. 
* - * @param executorEndpoint The RpcEndpointRef representing this executor + * @param executorActor The ActorRef representing this executor * @param executorAddress The network address of this executor * @param executorHost The hostname that this executor is running on * @param freeCores The current number of cores available for work on the executor * @param totalCores The total number of cores available to the executor */ private[cluster] class ExecutorData( - val executorEndpoint: RpcEndpointRef, - val executorAddress: String, + val executorActor: ActorRef, + val executorAddress: Address, val executorHost: String , var freeCores: Int, val totalCores: Int diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorActorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorActorSuite.scala new file mode 100644 index 000000000000..96f7f7a2f214 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorActorSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.executor + +import org.apache.spark.util.ThreadStackTrace +import org.apache.spark.{LocalSparkContext, SparkContext} +import org.scalatest.FunSuite + +class ExecutorActorSuite extends FunSuite with LocalSparkContext { + + test("ExecutorActor") { + sc = new SparkContext("local[2]", "test") + sc.env.rpcEnv.setupEndpoint("executor-actor", new ExecutorActor(sc.env.rpcEnv, "executor-1")) + val receiverRef = sc.env.rpcEnv.setupDriverEndpointRef("executor-actor") + val response = receiverRef.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump) + assert(response.size > 0) + } + } From a435eb234a6be4fdbdb1f899ac4ffb22d031df72 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sat, 27 Dec 2014 18:11:51 +0800 Subject: [PATCH 13/36] Fault tolerance for RpcEndpoint --- .../scala/org/apache/spark/rpc/RpcEnv.scala | 2 +- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 60 +++++++++++-------- .../org/apache/spark/rpc/RpcEnvSuite.scala | 40 ++++++++++++- .../spark/rpc/akka/AkkaRpcEnvSuite.scala | 1 + 4 files changed, 77 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 8d00707e1844..9847d1871391 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -53,7 +53,7 @@ trait RpcEnv { endpointRef } - def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef + def setupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef def setupDriverEndpointRef(name: String): RpcEndpointRef diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 600635943234..1955dad6ce2b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -17,6 +17,8 @@ package org.apache.spark.rpc.akka +import java.util.concurrent.CountDownLatch + import scala.concurrent.Await import scala.concurrent.duration.Duration @@ -30,32 +32,42 @@ import org.apache.spark.util.AkkaUtils class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { - override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { - val actorRef = actorSystem.actorOf(Props(new Actor with Logging { + override def setupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = { + val latch = new CountDownLatch(1) + try { + @volatile var endpointRef: AkkaRpcEndpointRef = null + val actorRef = actorSystem.actorOf(Props(new Actor with Logging { - override def preStart(): Unit = { - // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - } + val endpoint = endpointCreator + latch.await() + require(endpointRef != null) + registerEndpoint(endpoint, endpointRef) - override def receive: Receive = { - case message @ DisassociatedEvent(_, remoteAddress, _) => - val pf = endpoint.receive(new AkkaRpcEndpointRef(sender(), conf)) - if (pf.isDefinedAt(message)) { - pf.apply(message) - } - endpoint.remoteConnectionTerminated(remoteAddress.toString) - case message: Any => - logInfo("Received RPC message: " + message) - val pf = endpoint.receive(new AkkaRpcEndpointRef(sender(), conf)) - if (pf.isDefinedAt(message)) { - pf.apply(message) - } - } - }), name = name) - val endpointRef = new AkkaRpcEndpointRef(actorRef, 
conf) - registerEndpoint(endpoint, endpointRef) - endpointRef + override def preStart(): Unit = { + // Listen for remote client disconnection events, since they don't go through Akka's watch() + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + } + + override def receive: Receive = { + case message @ DisassociatedEvent(_, remoteAddress, _) => + val pf = endpoint.receive(new AkkaRpcEndpointRef(sender(), conf)) + if (pf.isDefinedAt(message)) { + pf.apply(message) + } + endpoint.remoteConnectionTerminated(remoteAddress.toString) + case message: Any => + logInfo("Received RPC message: " + message) + val pf = endpoint.receive(new AkkaRpcEndpointRef(sender(), conf)) + if (pf.isDefinedAt(message)) { + pf.apply(message) + } + } + }), name = name) + endpointRef = new AkkaRpcEndpointRef(actorRef, conf) + endpointRef + } finally { + latch.countDown() + } } override def setupDriverEndpointRef(name: String): RpcEndpointRef = { diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 165bad4f5e5b..0eabd73ae2e5 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -17,8 +17,15 @@ package org.apache.spark.rpc +import scala.concurrent.duration._ +import scala.language.postfixOps + import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.concurrent.Eventually._ +/** + * Common tests for an RpcEnv implementation. + */ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { var env: RpcEnv = _ @@ -111,7 +118,10 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } } val rpcEndpointRef = env.setupEndpoint("register_test", endpoint) - assert(rpcEndpointRef eq env.endpointRef(endpoint)) + + eventually(timeout(5 seconds), interval(200 milliseconds)) { + assert(rpcEndpointRef eq env.endpointRef(endpoint)) + } endpoint.stop() val e = intercept[IllegalArgumentException] { @@ -119,4 +129,32 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } assert(e.getMessage.contains("Cannot find RpcEndpointRef")) } + + test("fault tolerance") { + case class SetState(state: Int) + + case object Crash + + case object GetState + + val rpcEndpointRef = env.setupEndpoint("fault_tolerance", new RpcEndpoint { + override val rpcEnv = env + + var state: Int = 0 + + override def receive(sender: RpcEndpointRef) = { + case SetState(state) => this.state = state + case Crash => throw new RuntimeException("Oops") + case GetState => sender.send(state) + } + }) + assert(0 === rpcEndpointRef.askWithReply[Int](GetState)) + + rpcEndpointRef.send(SetState(10)) + assert(10 === rpcEndpointRef.askWithReply[Int](GetState)) + + rpcEndpointRef.send(Crash) + // RpcEndpoint is crashed. Should reset its state. 
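+    // (Presumably the RpcEnv recreates the endpoint via the by-name `endpointCreator` argument of
+    // `setupEndpoint`, so `state` should be back at its initial value of 0 after the crash.)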
+ assert(0 === rpcEndpointRef.askWithReply[Int](GetState)) + } } diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala index 462810e3ec74..1a1da58fd148 100644 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -33,6 +33,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { } override def destroyRpcEnv(rpcEnv: RpcEnv): Unit = { + rpcEnv.stopAll() if (akkaSystem != null) { akkaSystem.shutdown() } From 2636664922701272408a6af487125c260c6ad416 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 29 Dec 2014 16:08:39 +0800 Subject: [PATCH 14/36] Change CoarseGrainedExecutorBackend to a RpcEndpoint --- .../CoarseGrainedExecutorBackend.scala | 50 ++++++++++--------- .../scala/org/apache/spark/rpc/RpcEnv.scala | 9 +++- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 4 +- .../cluster/CoarseGrainedClusterMessage.scala | 5 +- .../CoarseGrainedSchedulerBackend.scala | 2 +- 5 files changed, 40 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index c794a7bc3599..dc46911a4513 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -19,19 +19,18 @@ package org.apache.spark.executor import java.nio.ByteBuffer -import scala.concurrent.Await - -import akka.actor.{Actor, ActorSelection, Props} -import akka.pattern.Patterns -import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent} +import akka.actor.Props +import akka.remote.DisassociatedEvent import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher +import org.apache.spark.rpc.akka.AkkaRpcEnv +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef} import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} private[spark] class CoarseGrainedExecutorBackend( driverUrl: String, @@ -39,21 +38,23 @@ private[spark] class CoarseGrainedExecutorBackend( hostPort: String, cores: Int, env: SparkEnv) - extends Actor with ActorLogReceive with ExecutorBackend with Logging { + extends RpcEndpoint with ExecutorBackend with Logging { + + override val rpcEnv = env.rpcEnv Utils.checkHostPort(hostPort, "Expected hostport") var executor: Executor = null - var driver: ActorSelection = null + var driver: RpcEndpointRef = _ - override def preStart() { + override def preStart(): Unit = { + // self is valid now. So now we can use `send` logInfo("Connecting to driver: " + driverUrl) - driver = context.actorSelection(driverUrl) - driver ! 
RegisterExecutor(executorId, hostPort, cores) - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + driver = rpcEnv.setupEndpointRefByUrl(driverUrl) + driver.send(RegisterExecutor(executorId, hostPort, cores, self)) } - override def receiveWithLogging = { + override def receive(sender: RpcEndpointRef) = { case RegisteredExecutor => logInfo("Successfully registered with driver") val (hostname, _) = Utils.parseHostPort(hostPort) @@ -89,12 +90,13 @@ private[spark] class CoarseGrainedExecutorBackend( case StopExecutor => logInfo("Driver commanded a shutdown") executor.stop() - context.stop(self) - context.system.shutdown() + stop() + rpcEnv.stopAll() + env.actorSystem.shutdown() } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { - driver ! StatusUpdate(executorId, taskId, state, data) + driver.send(StatusUpdate(executorId, taskId, state, data)) } } @@ -119,12 +121,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val port = executorConf.getInt("spark.executor.port", 0) val (fetcher, _) = AkkaUtils.createActorSystem( "driverPropsFetcher", hostname, port, executorConf, new SecurityManager(executorConf)) - val driver = fetcher.actorSelection(driverUrl) - val timeout = AkkaUtils.askTimeout(executorConf) - val fut = Patterns.ask(driver, RetrieveSparkProps, timeout) - val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] ++ + val rpcEnv = new AkkaRpcEnv(fetcher, executorConf) + val driver = rpcEnv.setupEndpointRefByUrl(driverUrl) + val props = driver.askWithReply[Seq[(String, String)]](RetrieveSparkProps) ++ Seq[(String, String)](("spark.app.id", appId)) + + rpcEnv.stopAll() fetcher.shutdown() + fetcher.awaitTermination() // Create SparkEnv using properties we fetched from the driver. val driverConf = new SparkConf().setAll(props) @@ -137,10 +141,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Start the CoarseGrainedExecutorBackend actor. val sparkHostPort = hostname + ":" + boundPort - env.actorSystem.actorOf( - Props(classOf[CoarseGrainedExecutorBackend], - driverUrl, executorId, sparkHostPort, cores, env), - name = "Executor") + env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( + driverUrl, executorId, sparkHostPort, cores, env)) workerUrl.foreach { url => env.actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher") } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 9847d1871391..3b12cade6802 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -72,10 +72,15 @@ trait RpcEndpoint { val rpcEnv: RpcEnv + def preStart(): Unit = {} + /** - * Provide the implicit sender. + * Provide the implicit sender. `self` will become valid when `preStart` is called. 
*/ - implicit final def self: RpcEndpointRef = rpcEnv.endpointRef(this) + implicit final def self: RpcEndpointRef = { + require(rpcEnv != null, "rpcEnv has not been initialized") + rpcEnv.endpointRef(this) + } /** * Same assumption like Actor: messages sent to a RpcEndpoint will be delivered in sequence, and diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 1955dad6ce2b..9ca4606ca44c 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -31,6 +31,7 @@ import org.apache.spark.rpc._ import org.apache.spark.util.AkkaUtils class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { + // TODO Once finishing the new Rpc mechanism, make actorSystem be a private val override def setupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = { val latch = new CountDownLatch(1) @@ -44,6 +45,7 @@ class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { registerEndpoint(endpoint, endpointRef) override def preStart(): Unit = { + endpoint.preStart() // Listen for remote client disconnection events, since they don't go through Akka's watch() context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) } @@ -93,7 +95,7 @@ class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { override def toString = s"${getClass.getSimpleName}($actorSystem)" } -private[akka] class AkkaRpcEndpointRef(val actorRef: ActorRef, conf: SparkConf) +private[akka] class AkkaRpcEndpointRef(val actorRef: ActorRef, @transient conf: SparkConf) extends RpcEndpointRef with Serializable with Logging { private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 1da6fe976da5..f2cc88b5a384 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler.cluster import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.{SerializableBuffer, Utils} private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable @@ -39,8 +40,8 @@ private[spark] object CoarseGrainedClusterMessages { case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage // Executors to driver - case class RegisterExecutor(executorId: String, hostPort: String, cores: Int) - extends CoarseGrainedClusterMessage { + case class RegisterExecutor(executorId: String, hostPort: String, cores: Int, + executorRef: RpcEndpointRef) extends CoarseGrainedClusterMessage { Utils.checkHostPort(hostPort, "Expected host port") } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index fe9914b50bc5..6a6962d5105a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -84,7 +84,7 @@ class 
CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
     }
 
     def receiveWithLogging = {
-      case RegisterExecutor(executorId, hostPort, cores) =>
+      case RegisterExecutor(executorId, hostPort, cores, executorRef) =>
         Utils.checkHostPort(hostPort, "Host port expected " + hostPort)
         if (executorDataMap.contains(executorId)) {
           sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId)

From 85cfb333fe157d6eebcb4292fb0abe5af8624969 Mon Sep 17 00:00:00 2001
From: zsxwing
Date: Mon, 29 Dec 2014 17:37:49 +0800
Subject: [PATCH 15/36] Change CoarseMesosSchedulerBackend to use RpcEndpoint

---
 .../CoarseGrainedExecutorBackend.scala        | 10 +--
 .../scala/org/apache/spark/rpc/RpcEnv.scala   |  9 +++
 .../apache/spark/rpc/akka/AkkaRpcEnv.scala    |  6 +-
 .../CoarseGrainedSchedulerBackend.scala       | 80 +++++++++----------
 .../scheduler/cluster/ExecutorData.scala      |  6 +-
 .../cluster/SimrSchedulerBackend.scala        |  2 +-
 .../cluster/SparkDeploySchedulerBackend.scala |  2 +-
 .../cluster/YarnSchedulerBackend.scala        |  4 +-
 .../mesos/CoarseMesosSchedulerBackend.scala   |  2 +-
 9 files changed, 61 insertions(+), 60 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index dc46911a4513..75fbee49c621 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -20,7 +20,6 @@ package org.apache.spark.executor
 import java.nio.ByteBuffer
 
 import akka.actor.Props
-import akka.remote.DisassociatedEvent
 
 import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv}
 import org.apache.spark.TaskState.TaskState
@@ -83,10 +82,6 @@ private[spark] class CoarseGrainedExecutorBackend(
         executor.killTask(taskId, interruptThread)
       }
 
-    case x: DisassociatedEvent =>
-      logError(s"Driver $x disassociated! Shutting down.")
-      System.exit(1)
-
     case StopExecutor =>
       logInfo("Driver commanded a shutdown")
       executor.stop()
@@ -95,6 +90,11 @@ private[spark] class CoarseGrainedExecutorBackend(
       env.actorSystem.shutdown()
   }
 
+  override def remoteConnectionTerminated(remoteAddress: String): Unit = {
+    logError(s"Driver $remoteAddress disassociated! Shutting down.")
+    System.exit(1)
+  }
+
   override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
     driver.send(StatusUpdate(executorId, taskId, state, data))
   }
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 3b12cade6802..0458e95f9949 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -67,6 +67,15 @@ trait RpcEnv {
 
 /**
  * An end point for the RPC that defines what functions to trigger given a message.
+ *
+ * An RpcEndpoint is guaranteed to have `preStart`, `receive` and `remoteConnectionTerminated`
+ * called in sequence.
+ *
+ * Happens-before relation:
+ *
+ * constructor -> preStart -> receive* -> remoteConnectionTerminated
+ *
+ * TODO: Do we need to guarantee that no message is delivered after `remoteConnectionTerminated`?
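+ *
+ * A minimal usage sketch (illustrative only; `EchoEndpoint` is a made-up name, and the calls shown
+ * are the ones defined on this trait and on [[RpcEnv]]):
+ *
+ * {{{
+ * class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
+ *   override def receive(sender: RpcEndpointRef) = {
+ *     case msg => sender.send(msg)  // reply to the sender with the same message
+ *   }
+ * }
+ *
+ * val echoRef = rpcEnv.setupEndpoint("echo", new EchoEndpoint(rpcEnv))
+ * echoRef.send("ping")
+ * }}}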
*/ trait RpcEndpoint { diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 9ca4606ca44c..1a56ae423550 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -51,11 +51,7 @@ class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { } override def receive: Receive = { - case message @ DisassociatedEvent(_, remoteAddress, _) => - val pf = endpoint.receive(new AkkaRpcEndpointRef(sender(), conf)) - if (pf.isDefinedAt(message)) { - pf.apply(message) - } + case DisassociatedEvent(_, remoteAddress, _) => endpoint.remoteConnectionTerminated(remoteAddress.toString) case message: Any => logInfo("Received RPC message: " + message) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 6a6962d5105a..8f507f6acbeb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -18,19 +18,15 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.{TimeUnit, Executors} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.Await -import scala.concurrent.duration._ - -import akka.actor._ -import akka.pattern.ask -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} +import org.apache.spark.rpc.{RpcEndpoint, RpcEnv, RpcEndpointRef} import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils} +import org.apache.spark.util.{SerializableBuffer, AkkaUtils, Utils} /** * A scheduler backend that waits for coarse grained executors to connect to it through Akka. @@ -41,7 +37,7 @@ import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Ut * (spark.deploy.*). 
*/ private[spark] -class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSystem: ActorSystem) +class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: RpcEnv) extends ExecutorAllocationClient with SchedulerBackend with Logging { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed @@ -69,34 +65,34 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Executors we have requested the cluster manager to kill that have not died yet private val executorsPendingToRemove = new HashSet[String] - class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive { + class DriverActor(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends RpcEndpoint with Logging { override protected def log = CoarseGrainedSchedulerBackend.this.log - private val addressToExecutorId = new HashMap[Address, String] + private val addressToExecutorId = new HashMap[String, String] - override def preStart() { - // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + private val reviveScheduler = Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("driver-revive-scheduler")) + override def preStart() { // Periodically revive offers to allow delay scheduling to work val reviveInterval = conf.getLong("spark.scheduler.revive.interval", 1000) - import context.dispatcher - context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) + reviveScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = self.send(ReviveOffers) + }, 0, reviveInterval, TimeUnit.MILLISECONDS) } - def receiveWithLogging = { + def receive(sender: RpcEndpointRef) = { case RegisterExecutor(executorId, hostPort, cores, executorRef) => Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorDataMap.contains(executorId)) { - sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) + sender.send(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) } else { logInfo("Registered executor: " + sender + " with ID " + executorId) - sender ! RegisteredExecutor + sender.send(RegisteredExecutor) - addressToExecutorId(sender.path.address) = executorId + addressToExecutorId(sender.address) = executorId totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) val (host, _) = Utils.parseHostPort(hostPort) - val data = new ExecutorData(sender, sender.path.address, host, cores, cores) + val data = new ExecutorData(sender, sender.address, host, cores, cores) // This must be synchronized because variables mutated // in this block are read when requesting executors CoarseGrainedSchedulerBackend.this.synchronized { @@ -129,33 +125,29 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste case KillTask(taskId, executorId, interruptThread) => executorDataMap.get(executorId) match { case Some(executorInfo) => - executorInfo.executorActor ! KillTask(taskId, executorId, interruptThread) + executorInfo.executorActor.send(KillTask(taskId, executorId, interruptThread)) case None => // Ignoring the task kill since the executor is not registered. logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") } case StopDriver => - sender ! 
true - context.stop(self) + sender.send(true) + stop() case StopExecutors => logInfo("Asking each executor to shut down") for ((_, executorData) <- executorDataMap) { - executorData.executorActor ! StopExecutor + executorData.executorActor .send(StopExecutor) } - sender ! true + sender.send(true) case RemoveExecutor(executorId, reason) => removeExecutor(executorId, reason) - sender ! true - - case DisassociatedEvent(_, address, _) => - addressToExecutorId.get(address).foreach(removeExecutor(_, - "remote Akka client disassociated")) + sender.send(true) case RetrieveSparkProps => - sender ! sparkProperties + sender.send(sparkProperties) } // Make fake resource offers on all executors @@ -195,11 +187,16 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste else { val executorData = executorDataMap(task.executorId) executorData.freeCores -= scheduler.CPUS_PER_TASK - executorData.executorActor ! LaunchTask(new SerializableBuffer(serializedTask)) + executorData.executorActor.send(LaunchTask(new SerializableBuffer(serializedTask))) } } } + override def remoteConnectionTerminated(remoteAddress: String): Unit = { + addressToExecutorId.get(remoteAddress).foreach(removeExecutor(_, + "remote Akka client disassociated")) + } + // Remove a disconnected slave from the cluster def removeExecutor(executorId: String, reason: String): Unit = { executorDataMap.get(executorId) match { @@ -218,7 +215,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } } - var driverActor: ActorRef = null + var driverActor: RpcEndpointRef = null val taskIdsOnSlave = new HashMap[String, HashSet[String]] override def start() { @@ -229,16 +226,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } } // TODO (prashant) send conf instead of properties - driverActor = actorSystem.actorOf( - Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME) + driverActor = rpcEnv.setupEndpoint(CoarseGrainedSchedulerBackend.ACTOR_NAME, + new DriverActor(rpcEnv, properties)) } def stopExecutors() { try { if (driverActor != null) { logInfo("Shutting down all executors") - val future = driverActor.ask(StopExecutors)(timeout) - Await.ready(future, timeout) + driverActor.askWithReply(StopExecutors) } } catch { case e: Exception => @@ -250,8 +246,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste stopExecutors() try { if (driverActor != null) { - val future = driverActor.ask(StopDriver)(timeout) - Await.ready(future, timeout) + driverActor.askWithReply(StopDriver) } } catch { case e: Exception => @@ -260,11 +255,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } override def reviveOffers() { - driverActor ! ReviveOffers + driverActor.send(ReviveOffers) } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - driverActor ! 
KillTask(taskId, executorId, interruptThread) + driverActor.send(KillTask(taskId, executorId, interruptThread)) } override def defaultParallelism(): Int = { @@ -274,8 +269,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Called by subclasses when notified of a lost worker def removeExecutor(executorId: String, reason: String) { try { - val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout) - Await.ready(future, timeout) + driverActor.askWithReply(RemoveExecutor(executorId, reason)) } catch { case e: Exception => throw new SparkException("Error notifying standalone scheduler's driver actor", e) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index b71bd5783d6d..7045c577443c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster -import akka.actor.{Address, ActorRef} +import org.apache.spark.rpc.RpcEndpointRef /** * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. @@ -29,8 +29,8 @@ import akka.actor.{Address, ActorRef} * @param totalCores The total number of cores available to the executor */ private[cluster] class ExecutorData( - val executorActor: ActorRef, - val executorAddress: Address, + val executorActor: RpcEndpointRef, + val executorAddress: String, val executorHost: String , var freeCores: Int, val totalCores: Int diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index ee10aa061f4e..a851a01c07ba 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -27,7 +27,7 @@ private[spark] class SimrSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, driverFilePath: String) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with Logging { val tmpPath = new Path(driverFilePath + "_tmp") diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 8c7de75600b5..f342d9c466e2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -27,7 +27,7 @@ private[spark] class SparkDeploySchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, masters: Array[String]) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with AppClientListener with Logging { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 50721b9d6cd6..23f232211746 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -33,7 +33,9 @@ import org.apache.spark.util.AkkaUtils private[spark] abstract 
class YarnSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) { + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) { + + val actorSystem = sc.env.actorSystem if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { minRegisteredRatio = 0.8 diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 5289661eb896..371c2debaa21 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -47,7 +47,7 @@ private[spark] class CoarseMesosSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, master: String) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with MScheduler with Logging { From 8e561b49cbf48519680cbf5809ed699cf910eb61 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 30 Dec 2014 15:19:34 +0800 Subject: [PATCH 16/36] Change MapOutputTrackerMasterActor to use RpcEndpoint --- .../org/apache/spark/MapOutputTracker.scala | 29 +++----- .../scala/org/apache/spark/SparkEnv.scala | 15 +++- .../apache/spark/MapOutputTrackerSuite.scala | 62 ++++++++++------ .../apache/spark/util/AkkaUtilsSuite.scala | 74 ++++++++++--------- 4 files changed, 101 insertions(+), 79 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 6e4edc7c80d7..6d32c3231784 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -21,13 +21,10 @@ import java.io._ import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.{HashSet, HashMap, Map} -import scala.concurrent.Await +import scala.collection.mutable.{HashSet, Map} import scala.collection.JavaConversions._ -import akka.actor._ -import akka.pattern.ask - +import org.apache.spark.rpc.{RpcEnv, RpcEndpointRef, RpcEndpoint} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.BlockManagerId @@ -39,14 +36,13 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int) private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage /** Actor class for MapOutputTrackerMaster */ -private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf) - extends Actor with ActorLogReceive with Logging { +private[spark] class MapOutputTrackerMasterActor(override val rpcEnv: RpcEnv, + tracker: MapOutputTrackerMaster, conf: SparkConf) extends RpcEndpoint with Logging { val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) - override def receiveWithLogging = { + override def receive(sender: RpcEndpointRef) = { case GetMapOutputStatuses(shuffleId: Int) => - val hostPort = sender.path.address.hostPort - logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) + logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + sender) val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId) val serializedSize = 
mapOutputStatuses.size if (serializedSize > maxAkkaFrameSize) { @@ -60,12 +56,12 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster logError(msg, exception) throw exception } - sender ! mapOutputStatuses + sender.send(mapOutputStatuses) case StopMapOutputTracker => logInfo("MapOutputTrackerActor stopped!") - sender ! true - context.stop(self) + sender.send(true) + stop() } } @@ -75,12 +71,9 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster * (driver and executor) use different HashMap to store its metadata. */ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { - private val timeout = AkkaUtils.askTimeout(conf) - private val retryAttempts = AkkaUtils.numRetries(conf) - private val retryIntervalMs = AkkaUtils.retryWaitMs(conf) /** Set to the MapOutputTrackerActor living on the driver. */ - var trackerActor: ActorRef = _ + var trackerActor: RpcEndpointRef = _ /** * This HashMap has different behavior for the driver and the executors. @@ -110,7 +103,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging */ protected def askTracker(message: Any): Any = { try { - AkkaUtils.askWithReply(message, trackerActor, retryAttempts, retryIntervalMs, timeout) + trackerActor.askWithReply(message) } catch { case e: Exception => logError("Error communicating with MapOutputTracker", e) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index c92beaebee89..04e527993e12 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -34,7 +34,7 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService -import org.apache.spark.rpc.RpcEnv +import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv} import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer @@ -272,6 +272,15 @@ object SparkEnv extends Logging { } } + def registerOrLookupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = { + if (isDriver) { + logInfo("Registering " + name) + rpcEnv.setupEndpoint(name, endpointCreator) + } else { + rpcEnv.setupDriverEndpointRef(name) + } + } + val mapOutputTracker = if (isDriver) { new MapOutputTrackerMaster(conf) } else { @@ -280,9 +289,9 @@ object SparkEnv extends Logging { // Have to assign trackerActor after initialization as MapOutputTrackerActor // requires the MapOutputTracker itself - mapOutputTracker.trackerActor = registerOrLookup( + mapOutputTracker.trackerActor = registerOrLookupEndpoint( "MapOutputTracker", - new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) + new MapOutputTrackerMasterActor(rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) // Let the user specify short names for shuffle managers val shortShuffleMgrNames = Map( diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index d27880f4bc32..f4fb476d4864 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark -import 
scala.concurrent.Await +import org.apache.spark.rpc.akka.AkkaRpcEnv import akka.actor._ -import akka.testkit.TestActorRef +import org.mockito.Mockito._ import org.scalatest.FunSuite +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.BlockManagerId @@ -33,18 +34,21 @@ class MapOutputTrackerSuite extends FunSuite { test("master start and stop") { val actorSystem = ActorSystem("test") + val rpcEnv = new AkkaRpcEnv(actorSystem, conf) val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = - actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerActor = rpcEnv.setupEndpoint( + "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, tracker, conf)) tracker.stop() + rpcEnv.stopAll() actorSystem.shutdown() } test("master register shuffle and fetch") { val actorSystem = ActorSystem("test") + val rpcEnv = new AkkaRpcEnv(actorSystem, conf) val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = - actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerActor = rpcEnv.setupEndpoint( + "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) assert(tracker.containsShuffle(10)) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) @@ -57,13 +61,16 @@ class MapOutputTrackerSuite extends FunSuite { assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000), (BlockManagerId("b", "hostB", 1000), size10000))) tracker.stop() + rpcEnv.stopAll() actorSystem.shutdown() } test("master register and unregister shuffle") { val actorSystem = ActorSystem("test") + val rpcEnv = new AkkaRpcEnv(actorSystem, conf) val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerActor = rpcEnv.setupEndpoint( + "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) @@ -78,14 +85,16 @@ class MapOutputTrackerSuite extends FunSuite { assert(tracker.getServerStatuses(10, 0).isEmpty) tracker.stop() + rpcEnv.stopAll() actorSystem.shutdown() } test("master register shuffle and unregister map output and fetch") { val actorSystem = ActorSystem("test") + val rpcEnv = new AkkaRpcEnv(actorSystem, conf) val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = - actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerActor = rpcEnv.setupEndpoint( + "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) @@ -104,6 +113,7 @@ class MapOutputTrackerSuite extends FunSuite { intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } tracker.stop() + rpcEnv.stopAll() actorSystem.shutdown() } @@ -111,18 +121,18 @@ class MapOutputTrackerSuite extends FunSuite { val hostname = "localhost" val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, securityManager = new SecurityManager(conf)) + val rpcEnv = new AkkaRpcEnv(actorSystem, conf) val masterTracker = new 
MapOutputTrackerMaster(conf)
-    masterTracker.trackerActor = actorSystem.actorOf(
-        Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+    masterTracker.trackerActor = rpcEnv.setupEndpoint(
+      "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, masterTracker, conf))
 
     val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf,
       securityManager = new SecurityManager(conf))
+    val slaveRpcEnv = new AkkaRpcEnv(slaveSystem, conf)
     val slaveTracker = new MapOutputTrackerWorker(conf)
-    val selection = slaveSystem.actorSelection(
-      s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
-    val timeout = AkkaUtils.lookupTimeout(conf)
-    slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+    val selection = s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker"
+    slaveTracker.trackerActor = slaveRpcEnv.setupEndpointRefByUrl(selection)
 
     masterTracker.registerShuffle(10, 1)
     masterTracker.incrementEpoch()
@@ -147,7 +157,9 @@ class MapOutputTrackerSuite extends FunSuite {
 
     masterTracker.stop()
     slaveTracker.stop()
+    rpcEnv.stopAll()
     actorSystem.shutdown()
+    slaveRpcEnv.stopAll()
     slaveSystem.shutdown()
   }
 
@@ -158,17 +170,19 @@ class MapOutputTrackerSuite extends FunSuite {
     val masterTracker = new MapOutputTrackerMaster(conf)
     val actorSystem = ActorSystem("test")
-    val actorRef = TestActorRef[MapOutputTrackerMasterActor](
-      Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem)
-    val masterActor = actorRef.underlyingActor
+    val rpcEnv = new AkkaRpcEnv(actorSystem, conf)
+    val masterActor = new MapOutputTrackerMasterActor(rpcEnv, masterTracker, newConf)
+    rpcEnv.setupEndpoint("MapOutputTracker", masterActor)
 
     // Frame size should be ~123B, and no exception should be thrown
     masterTracker.registerShuffle(10, 1)
     masterTracker.registerMapOutput(10, 0, MapStatus(
       BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0)))
+    val sender = mock(classOf[RpcEndpointRef])
+    masterActor.receive(sender).apply(GetMapOutputStatuses(10))
 
 // masterTracker.stop() // this throws an exception
+    rpcEnv.stopAll()
     actorSystem.shutdown()
   }
 
@@ -179,9 +193,8 @@ class MapOutputTrackerSuite extends FunSuite {
     val masterTracker = new MapOutputTrackerMaster(conf)
     val actorSystem = ActorSystem("test")
-    val actorRef = TestActorRef[MapOutputTrackerMasterActor](
-      Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem)
-    val masterActor = actorRef.underlyingActor
+    val rpcEnv = new AkkaRpcEnv(actorSystem, conf)
+    val masterActor = new MapOutputTrackerMasterActor(rpcEnv, masterTracker, newConf)
 
     // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception.
// Note that the size is hand-selected here because map output statuses are compressed before @@ -191,9 +204,12 @@ class MapOutputTrackerSuite extends FunSuite { masterTracker.registerMapOutput(20, i, new CompressedMapStatus( BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) } - intercept[SparkException] { masterActor.receive(GetMapOutputStatuses(20)) } + intercept[SparkException] { + masterActor.receive(RpcEndpoint.noSender).apply(GetMapOutputStatuses(20)) + } // masterTracker.stop() // this throws an exception + rpcEnv.stopAll() actorSystem.shutdown() } } diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 7bca1711ae22..b3c8f99afb6e 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -17,13 +17,10 @@ package org.apache.spark.util -import scala.concurrent.Await - -import akka.actor._ - import org.scalatest.FunSuite import org.apache.spark._ +import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage.BlockManagerId @@ -44,10 +41,10 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { conf = conf, securityManager = securityManager) System.setProperty("spark.hostPort", hostname + ":" + boundPort) assert(securityManager.isAuthenticationEnabled() === true) - + val rpcEnv = new AkkaRpcEnv(actorSystem, conf) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerActor = rpcEnv.setupEndpoint( + "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, masterTracker, conf)) val badconf = new SparkConf badconf.set("spark.authenticate", "true") @@ -58,15 +55,16 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf, securityManager = securityManagerBad) + val slaveRpcEnv = new AkkaRpcEnv(slaveSystem, conf) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") - val timeout = AkkaUtils.lookupTimeout(conf) + val selection = s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker" intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerActor = slaveRpcEnv.setupEndpointRefByUrl(selection) } + rpcEnv.stopAll() actorSystem.shutdown() + slaveRpcEnv.stopAll() slaveSystem.shutdown() } @@ -74,31 +72,31 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { val conf = new SparkConf conf.set("spark.authenticate", "false") conf.set("spark.authenticate.secret", "bad") - val securityManager = new SecurityManager(conf); + val securityManager = new SecurityManager(conf) val hostname = "localhost" val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, securityManager = securityManager) System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = new AkkaRpcEnv(actorSystem, conf) assert(securityManager.isAuthenticationEnabled() === false) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), 
"MapOutputTracker") + masterTracker.trackerActor = rpcEnv.setupEndpoint( + "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, masterTracker, conf)) val badconf = new SparkConf badconf.set("spark.authenticate", "false") badconf.set("spark.authenticate.secret", "good") - val securityManagerBad = new SecurityManager(badconf); + val securityManagerBad = new SecurityManager(badconf) val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = badconf, securityManager = securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + val slaveRpcEnv = new AkkaRpcEnv(slaveSystem, conf) + val selection = s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker" + slaveTracker.trackerActor = slaveRpcEnv.setupEndpointRefByUrl(selection) assert(securityManagerBad.isAuthenticationEnabled() === false) @@ -116,7 +114,9 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { assert(slaveTracker.getServerStatuses(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000))) + rpcEnv.stopAll() actorSystem.shutdown() + slaveRpcEnv.stopAll() slaveSystem.shutdown() } @@ -124,33 +124,33 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { val conf = new SparkConf conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf); + val securityManager = new SecurityManager(conf) val hostname = "localhost" val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, securityManager = securityManager) System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = new AkkaRpcEnv(actorSystem, conf) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerActor = rpcEnv.setupEndpoint( + "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, masterTracker, conf)) val goodconf = new SparkConf goodconf.set("spark.authenticate", "true") goodconf.set("spark.authenticate.secret", "good") - val securityManagerGood = new SecurityManager(goodconf); + val securityManagerGood = new SecurityManager(goodconf) assert(securityManagerGood.isAuthenticationEnabled() === true) val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = goodconf, securityManager = securityManagerGood) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + val slaveRpcEnv = new AkkaRpcEnv(slaveSystem, conf) + val selection = s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker" + slaveTracker.trackerActor = slaveRpcEnv.setupEndpointRefByUrl(selection) masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() @@ -166,7 +166,9 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { assert(slaveTracker.getServerStatuses(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000))) + rpcEnv.stopAll() 
actorSystem.shutdown() + slaveRpcEnv.stopAll() slaveSystem.shutdown() } @@ -175,37 +177,39 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf); + val securityManager = new SecurityManager(conf) val hostname = "localhost" val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, securityManager = securityManager) System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = new AkkaRpcEnv(actorSystem, conf) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerActor = rpcEnv.setupEndpoint( + "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, masterTracker, conf)) val badconf = new SparkConf badconf.set("spark.authenticate", "false") badconf.set("spark.authenticate.secret", "bad") - val securityManagerBad = new SecurityManager(badconf); + val securityManagerBad = new SecurityManager(badconf) assert(securityManagerBad.isAuthenticationEnabled() === false) val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = badconf, securityManager = securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") - val timeout = AkkaUtils.lookupTimeout(conf) + val slaveRpcEnv = new AkkaRpcEnv(slaveSystem, conf) + val selection = s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker" intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerActor = slaveRpcEnv.setupEndpointRefByUrl(selection) } + rpcEnv.stopAll() actorSystem.shutdown() + slaveRpcEnv.stopAll() slaveSystem.shutdown() } From a067228c1188e202fbdc49eaf1578b87f82183fc Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 30 Dec 2014 15:23:14 +0800 Subject: [PATCH 17/36] Fix the code style --- .../main/scala/org/apache/spark/SparkEnv.scala | 6 +++--- .../spark/network/rpc/SimpleRpcServer.scala | 17 +++++++++++++++++ .../org/apache/spark/rpc/akka/AkkaRpcEnv.scala | 3 ++- .../cluster/CoarseGrainedSchedulerBackend.scala | 9 ++++++--- 4 files changed, 28 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 04e527993e12..4314e0b800c1 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -289,9 +289,9 @@ object SparkEnv extends Logging { // Have to assign trackerActor after initialization as MapOutputTrackerActor // requires the MapOutputTracker itself - mapOutputTracker.trackerActor = registerOrLookupEndpoint( - "MapOutputTracker", - new MapOutputTrackerMasterActor(rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) + mapOutputTracker.trackerActor = registerOrLookupEndpoint("MapOutputTracker", + new MapOutputTrackerMasterActor( + rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) // Let the user specify short names for shuffle managers val shortShuffleMgrNames = Map( diff --git a/core/src/main/scala/org/apache/spark/network/rpc/SimpleRpcServer.scala 
b/core/src/main/scala/org/apache/spark/network/rpc/SimpleRpcServer.scala index 9b76ae076a34..10bbeb032e89 100644 --- a/core/src/main/scala/org/apache/spark/network/rpc/SimpleRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/rpc/SimpleRpcServer.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.  You may obtain a copy of the License at + * + *    http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +  package org.apache.spark.network.rpc  import java.nio.ByteBuffer diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 1a56ae423550..9ebd2d90f466 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -46,7 +46,8 @@ class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv {          override def preStart(): Unit = {            endpoint.preStart() -          // Listen for remote client disconnection events, since they don't go through Akka's watch() +          // Listen for remote client disconnection events, +          // since they don't go through Akka's watch()            context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])          }  diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 8f507f6acbeb..35f2a47b6802 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -59,17 +59,20 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp    private val executorDataMap = new HashMap[String, ExecutorData]     // Number of executors requested from the cluster manager that have not registered yet    private var numPendingExecutors = 0    // Executors we have requested the cluster manager to kill that have not died yet   private val executorsPendingToRemove = new HashSet[String]   -  class DriverActor(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends RpcEndpoint with Logging { +  class DriverActor(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) +    extends RpcEndpoint with Logging { +     override protected def log = CoarseGrainedSchedulerBackend.this.log      private val addressToExecutorId = new HashMap[String, String]   -    private val reviveScheduler = Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("driver-revive-scheduler")) +    private val reviveScheduler = +      Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("driver-revive-scheduler"))      override def preStart() {        // Periodically revive offers to allow delay 
scheduling to work From b13fbd9e2e29153cd5eea533a791dff47b1b9ca7 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 30 Dec 2014 17:10:43 +0800 Subject: [PATCH 18/36] Change DAGScheduler to use RpcEndpoint --- .../scala/org/apache/spark/rpc/RpcEnv.scala | 12 ++ .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 27 ++++- .../apache/spark/scheduler/DAGScheduler.scala | 110 ++++++++---------- .../spark/scheduler/DAGSchedulerSuite.scala | 47 +++----- 4 files changed, 97 insertions(+), 99 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 0458e95f9949..3c4baa787471 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -100,10 +100,22 @@ trait RpcEndpoint { */ def receive(sender: RpcEndpointRef): PartialFunction[Any, Unit] + /** + * Call onError when any exception is thrown during handling messages. + * + * @param e + */ + def onError(e: Throwable): Unit = { + // By default, throw e and let RpcEnv handle it + throw e + } + def remoteConnectionTerminated(remoteAddress: String): Unit = { // By default, do nothing. } + def postStop(): Unit = {} + final def stop(): Unit = { rpcEnv.stop(self) } diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 9ebd2d90f466..6c735a54e42b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -30,6 +30,8 @@ import org.apache.spark.{Logging, SparkException, SparkConf} import org.apache.spark.rpc._ import org.apache.spark.util.AkkaUtils +import scala.util.control.NonFatal + class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { // TODO Once finishing the new Rpc mechanism, make actorSystem be a private val @@ -53,15 +55,28 @@ class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { override def receive: Receive = { case DisassociatedEvent(_, remoteAddress, _) => - endpoint.remoteConnectionTerminated(remoteAddress.toString) + try { + endpoint.remoteConnectionTerminated(remoteAddress.toString) + } catch { + case NonFatal(e) => endpoint.onError(e) + } case message: Any => - logInfo("Received RPC message: " + message) - val pf = endpoint.receive(new AkkaRpcEndpointRef(sender(), conf)) - if (pf.isDefinedAt(message)) { - pf.apply(message) + try { + logInfo("Received RPC message: " + message) + val pf = endpoint.receive(new AkkaRpcEndpointRef(sender(), conf)) + if (pf.isDefinedAt(message)) { + pf.apply(message) + } + } catch { + case NonFatal(e) => endpoint.onError(e) } } - }), name = name) + + override def postStop(): Unit = { + endpoint.postStop() + } + + }), name = name) endpointRef = new AkkaRpcEndpointRef(actorRef, conf) endpointRef } finally { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index cb8ccfbdbdcb..0c397882f258 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.util.Properties import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.{TimeUnit, Executors} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} import scala.concurrent.Await 
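For orientation, here is a minimal sketch, not code from this patch, of the RpcEndpoint contract that DAGSchedulerEventProcessActor is being ported to in this commit: receive(sender) returns a PartialFunction, onError takes over the role of the old Akka supervisor strategy, and postStop covers cleanup. The endpoint name and message handling below are made up for illustration.

import org.apache.spark.Logging
import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}

// Hypothetical endpoint, for illustration only; it is not part of the patch.
class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint with Logging {

  override def receive(sender: RpcEndpointRef): PartialFunction[Any, Unit] = {
    case msg: String =>
      logInfo(s"received $msg")
      sender.send(msg)  // asynchronous reply, no delivery guarantee
  }

  override def onError(e: Throwable): Unit = {
    // Replaces the Akka supervisor strategy: decide here how a failed message is handled.
    logError("EchoEndpoint failed", e)
  }

  override def postStop(): Unit = logInfo("EchoEndpoint stopped")
}

// Registration and use would mirror the rest of the patch, e.g.:
//   val ref = rpcEnv.setupEndpoint("echo", new EchoEndpoint(rpcEnv))
//   ref.send("ping")
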
@@ -28,8 +29,6 @@ import scala.language.postfixOps import scala.reflect.ClassTag import scala.util.control.NonFatal -import akka.actor._ -import akka.actor.SupervisorStrategy.Stop import akka.pattern.ask import akka.util.Timeout @@ -38,6 +37,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.{RpcEnv, RpcEndpointRef, RpcEndpoint} import org.apache.spark.storage._ import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils} import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -81,6 +81,11 @@ class DAGScheduler( def this(sc: SparkContext) = this(sc, sc.taskScheduler) + private val rpcEnv = env.rpcEnv + + private val messageScheduler = + Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("dag-scheduler-message")) + private[scheduler] val nextJobId = new AtomicInteger(0) private[scheduler] def numTotalJobs: Int = nextJobId.get() private val nextStageId = new AtomicInteger(0) @@ -112,41 +117,28 @@ class DAGScheduler( // stray messages to detect. private val failedEpoch = new HashMap[String, Long] - private val dagSchedulerActorSupervisor = - env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this))) - // A closure serializer that we reuse. // This is only safe because DAGScheduler runs in a single thread. private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() - private[scheduler] var eventProcessActor: ActorRef = _ - /** If enabled, we may run certain actions like take() and first() locally. */ private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false) /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */ private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) - private def initializeEventProcessActor() { - // blocking the thread until supervisor is started, which ensures eventProcessActor is - // not null before any job is submitted - implicit val timeout = Timeout(30 seconds) - val initEventActorReply = - dagSchedulerActorSupervisor ? Props(new DAGSchedulerEventProcessActor(this)) - eventProcessActor = Await.result(initEventActorReply, timeout.duration). - asInstanceOf[ActorRef] - } - - initializeEventProcessActor() + private[scheduler] var eventProcessActor = rpcEnv.setupEndpoint( + "DAGSchedulerEventProcessActor-" + DAGScheduler.nextId, + new DAGSchedulerEventProcessActor(rpcEnv, this)) // Called by TaskScheduler to report task's starting. def taskStarted(task: Task[_], taskInfo: TaskInfo) { - eventProcessActor ! BeginEvent(task, taskInfo) + eventProcessActor.send(BeginEvent(task, taskInfo)) } // Called to report that a task has completed and results are being fetched remotely. def taskGettingResult(taskInfo: TaskInfo) { - eventProcessActor ! GettingResultEvent(taskInfo) + eventProcessActor.send(GettingResultEvent(taskInfo)) } // Called by TaskScheduler to report task completions or failures. @@ -157,7 +149,8 @@ class DAGScheduler( accumUpdates: Map[Long, Any], taskInfo: TaskInfo, taskMetrics: TaskMetrics) { - eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics) + eventProcessActor.send( + CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)) } /** @@ -179,18 +172,18 @@ class DAGScheduler( // Called by TaskScheduler when an executor fails. 
def executorLost(execId: String) { - eventProcessActor ! ExecutorLost(execId) + eventProcessActor.send(ExecutorLost(execId)) } // Called by TaskScheduler when a host is added def executorAdded(execId: String, host: String) { - eventProcessActor ! ExecutorAdded(execId, host) + eventProcessActor.send(ExecutorAdded(execId, host)) } // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or // cancellation of the job itself. def taskSetFailed(taskSet: TaskSet, reason: String) { - eventProcessActor ! TaskSetFailed(taskSet, reason) + eventProcessActor.send(TaskSetFailed(taskSet, reason)) } private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = { @@ -495,8 +488,8 @@ class DAGScheduler( assert(partitions.size > 0) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler) - eventProcessActor ! JobSubmitted( - jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties) + eventProcessActor.send(JobSubmitted( + jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)) waiter } @@ -536,8 +529,8 @@ class DAGScheduler( val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray val jobId = nextJobId.getAndIncrement() - eventProcessActor ! JobSubmitted( - jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties) + eventProcessActor.send(JobSubmitted( + jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)) listener.awaitResult() // Will throw an exception if the job fails } @@ -546,19 +539,19 @@ class DAGScheduler( */ def cancelJob(jobId: Int) { logInfo("Asked to cancel job " + jobId) - eventProcessActor ! JobCancelled(jobId) + eventProcessActor.send(JobCancelled(jobId)) } def cancelJobGroup(groupId: String) { logInfo("Asked to cancel job group " + groupId) - eventProcessActor ! JobGroupCancelled(groupId) + eventProcessActor.send(JobGroupCancelled(groupId)) } /** * Cancel all jobs that are running or waiting in the queue. */ def cancelAllJobs() { - eventProcessActor ! AllJobsCancelled + eventProcessActor.send(AllJobsCancelled) } private[scheduler] def doCancelAllJobs() { @@ -574,7 +567,7 @@ class DAGScheduler( * Cancel all jobs associated with a running or scheduled stage. */ def cancelStage(stageId: Int) { - eventProcessActor ! StageCancelled(stageId) + eventProcessActor.send(StageCancelled(stageId)) } /** @@ -1086,8 +1079,11 @@ class DAGScheduler( import env.actorSystem.dispatcher logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + s"$failedStage (${failedStage.name}) due to fetch failure") - env.actorSystem.scheduler.scheduleOnce( - RESUBMIT_TIMEOUT, eventProcessActor, ResubmitFailedStages) + messageScheduler.schedule(new Runnable { + override def run(): Unit = { + eventProcessActor.send(ResubmitFailedStages) + } + }, RESUBMIT_TIMEOUT.toMillis, TimeUnit.MILLISECONDS) } failedStages += failedStage failedStages += mapStage @@ -1345,35 +1341,13 @@ class DAGScheduler( def stop() { logInfo("Stopping DAGScheduler") - dagSchedulerActorSupervisor ! 
PoisonPill + rpcEnv.stop(eventProcessActor) taskScheduler.stop() } } -private[scheduler] class DAGSchedulerActorSupervisor(dagScheduler: DAGScheduler) - extends Actor with Logging { - - override val supervisorStrategy = - OneForOneStrategy() { - case x: Exception => - logError("eventProcesserActor failed; shutting down SparkContext", x) - try { - dagScheduler.doCancelAllJobs() - } catch { - case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t) - } - dagScheduler.sc.stop() - Stop - } - - def receive = { - case p: Props => sender ! context.actorOf(p) - case _ => logWarning("received unknown message in DAGSchedulerActorSupervisor") - } -} - -private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGScheduler) - extends Actor with Logging { +private[scheduler] class DAGSchedulerEventProcessActor( + override val rpcEnv: RpcEnv, dagScheduler: DAGScheduler) extends RpcEndpoint with Logging { override def preStart() { // set DAGScheduler for taskScheduler to ensure eventProcessActor is always @@ -1384,7 +1358,7 @@ private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGSchedule /** * The main event loop of the DAG scheduler. */ - def receive = { + def receive(sender: RpcEndpointRef) = { case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) @@ -1427,6 +1401,17 @@ private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGSchedule // Cancel any active jobs in postStop hook dagScheduler.cleanUpAfterSchedulerStop() } + + override def onError(e: Throwable): Unit = { + logError("eventProcesserActor failed; shutting down SparkContext", e) + stop() + try { + dagScheduler.doCancelAllJobs() + } catch { + case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t) + } + dagScheduler.sc.stop() + } } private[spark] object DAGScheduler { @@ -1438,4 +1423,9 @@ private[spark] object DAGScheduler { // The time, in millis, to wake up between polls of the completion queue in order to potentially // resubmit failed stages val POLL_TIMEOUT = 10L + + private val id = new AtomicInteger(0) + + // To resolve the conflicts of actor name in the unit tests + def nextId = id.getAndDecrement } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index d6ec9e129cce..a72be620aefa 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -20,26 +20,18 @@ package org.apache.spark.scheduler import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map} import scala.language.reflectiveCalls -import akka.actor._ -import akka.testkit.{ImplicitSender, TestKit, TestActorRef} -import org.scalatest.{BeforeAndAfter, FunSuiteLike} +import org.scalatest.{FunSuite, BeforeAndAfter} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.RpcEndpoint import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.CallSite import org.apache.spark.executor.TaskMetrics -class BuggyDAGEventProcessActor extends Actor { - val state = 0 - def receive = { - case _ => throw new SparkException("error") - } 
-} - /** * An RDD for passing to DAGScheduler. These RDDs will use the dependencies and * preferredLocations (if any) that are passed to them. They are deliberately not executable @@ -65,8 +57,7 @@ class MyRDD( class DAGSchedulerSuiteDummyException extends Exception -class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with FunSuiteLike - with ImplicitSender with BeforeAndAfter with LocalSparkContext with Timeouts { +class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with Timeouts { val conf = new SparkConf /** Set of TaskSets the DAGScheduler has requested executed. */ @@ -111,9 +102,10 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } } + var sc: SparkContext = null var mapOutputTracker: MapOutputTrackerMaster = null var scheduler: DAGScheduler = null - var dagEventProcessTestActor: TestActorRef[DAGSchedulerEventProcessActor] = null + var dagEventProcessTestActor: RpcEndpoint = null /** * Set of cache locations to return from our mock BlockManagerMaster. @@ -167,13 +159,13 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F runLocallyWithinThread(job) } } - dagEventProcessTestActor = TestActorRef[DAGSchedulerEventProcessActor]( - Props(classOf[DAGSchedulerEventProcessActor], scheduler))(system) + val rpcEnv = sc.env.rpcEnv + dagEventProcessTestActor = new DAGSchedulerEventProcessActor(rpcEnv, scheduler) + rpcEnv.setupEndpoint("DAGSchedulerEventProcessActorTest", dagEventProcessTestActor) } - override def afterAll() { - super.afterAll() - TestKit.shutdownActorSystem(system) + after { + sc.stop() } /** @@ -190,7 +182,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F * DAGScheduler event loop. */ private def runEvent(event: DAGSchedulerEvent) { - dagEventProcessTestActor.receive(event) + dagEventProcessTestActor.receive(RpcEndpoint.noSender)(event) } /** @@ -396,8 +388,9 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F runLocallyWithinThread(job) } } - dagEventProcessTestActor = TestActorRef[DAGSchedulerEventProcessActor]( - Props(classOf[DAGSchedulerEventProcessActor], noKillScheduler))(system) + val rpcEnv = sc.env.rpcEnv + dagEventProcessTestActor = new DAGSchedulerEventProcessActor(rpcEnv, noKillScheduler) + rpcEnv.setupEndpoint("DAGSchedulerEventProcessActor-nokill", dagEventProcessTestActor) val jobId = submit(new MyRDD(sc, 1, Nil), Array(0)) cancel(jobId) // Because the job wasn't actually cancelled, we shouldn't have received a failure message. @@ -725,18 +718,6 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F assert(sc.parallelize(1 to 10, 2).first() === 1) } - test("DAGSchedulerActorSupervisor closes the SparkContext when EventProcessActor crashes") { - val actorSystem = ActorSystem("test") - val supervisor = actorSystem.actorOf( - Props(classOf[DAGSchedulerActorSupervisor], scheduler), "dagSupervisor") - supervisor ! Props[BuggyDAGEventProcessActor] - val child = expectMsgType[ActorRef] - watch(child) - child ! 
"hi" - expectMsgPF(){ case Terminated(child) => () } - assert(scheduler.sc.dagScheduler === null) - } - test("accumulator not calculated for resubmitted result stage") { //just for register val accum = new Accumulator[Int](0, AccumulatorParam.IntAccumulatorParam) From 811b6b8f6a27e950fef456476f35221281ee8446 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 4 Jan 2015 12:17:32 +0800 Subject: [PATCH 19/36] Change BlockManager to use RpcEndpoint --- .../scala/org/apache/spark/SparkEnv.scala | 6 +- .../scala/org/apache/spark/rpc/RpcEnv.scala | 14 +++ .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 24 ++++- .../apache/spark/scheduler/DAGScheduler.scala | 11 +- .../apache/spark/storage/BlockManager.scala | 15 +-- .../spark/storage/BlockManagerMaster.scala | 12 +-- .../storage/BlockManagerMasterActor.scala | 100 +++++++++--------- .../spark/storage/BlockManagerMessages.scala | 5 +- .../storage/BlockManagerSlaveActor.scala | 27 ++--- .../BlockManagerReplicationSuite.scala | 15 ++- .../spark/storage/BlockManagerSuite.scala | 23 ++-- 11 files changed, 140 insertions(+), 112 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 4314e0b800c1..3896a3d5d14b 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -311,12 +311,12 @@ object SparkEnv extends Logging { new NioBlockTransferService(conf, securityManager) } - val blockManagerMaster = new BlockManagerMaster(registerOrLookup( + val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint( "BlockManagerMaster", - new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver) + new BlockManagerMasterActor(rpcEnv, isLocal, conf, listenerBus)), conf, isDriver) // NB: blockManager is not valid until initialize() is called later. - val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, + val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster, serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 3c4baa787471..412ddb0ce645 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -19,6 +19,10 @@ package org.apache.spark.rpc import java.util.concurrent.ConcurrentHashMap +import scala.concurrent.Future +import scala.concurrent.duration.{FiniteDuration, Duration} +import scala.reflect.ClassTag + /** * An RPC environment. */ @@ -133,8 +137,18 @@ trait RpcEndpointRef { def address: String + def host: Option[String] + + def port: Option[Int] + + def ask[T: ClassTag](message: Any): Future[T] + + def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] + def askWithReply[T](message: Any): T + def askWithReply[T](message: Any, timeout: FiniteDuration): T + /** * Send a message to the remote endpoint asynchronously. No delivery guarantee is provided. 
*/ diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 6c735a54e42b..5912901ff32e 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -20,16 +20,19 @@ package org.apache.spark.rpc.akka import java.util.concurrent.CountDownLatch import scala.concurrent.Await -import scala.concurrent.duration.Duration +import scala.concurrent.duration._ +import scala.concurrent.Future +import scala.language.postfixOps import akka.actor.{ActorRef, Actor, Props, ActorSystem} -import akka.pattern.ask +import akka.pattern.{ask => akkaAsk} import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import org.apache.spark.{Logging, SparkException, SparkConf} import org.apache.spark.rpc._ import org.apache.spark.util.AkkaUtils +import scala.reflect.ClassTag import scala.util.control.NonFatal class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { @@ -112,12 +115,23 @@ private[akka] class AkkaRpcEndpointRef(val actorRef: ActorRef, @transient conf: private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3) private[this] val retryWaitMs = conf.getInt("spark.akka.retry.wait", 3000) - private[this] val timeout = - Duration.create(conf.getLong("spark.akka.lookupTimeout", 30), "seconds") + private[this] val defaultTimeout = conf.getLong("spark.akka.lookupTimeout", 30) seconds override val address: String = actorRef.path.address.toString - override def askWithReply[T](message: Any): T = { + override val host: Option[String] = actorRef.path.address.host + + override val port: Option[Int] = actorRef.path.address.port + + override def ask[T: ClassTag](message: Any): Future[T] = ask(message, defaultTimeout) + + override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { + actorRef.ask(message)(timeout).mapTo[T] + } + + override def askWithReply[T](message: Any): T = askWithReply(message, defaultTimeout) + + override def askWithReply[T](message: Any, timeout: FiniteDuration): T = { var attempts = 0 var lastException: Exception = null while (attempts < maxRetries) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 0c397882f258..2e04e7565dbf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -23,15 +23,11 @@ import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.{TimeUnit, Executors} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} -import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps import scala.reflect.ClassTag import scala.util.control.NonFatal -import akka.pattern.ask -import akka.util.Timeout - import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics @@ -163,11 +159,8 @@ class DAGScheduler( taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics) blockManagerId: BlockManagerId): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) - implicit val timeout = Timeout(600 seconds) - - Await.result( - blockManagerMaster.driverActor ? 
BlockManagerHeartbeat(blockManagerId), - timeout.duration).asInstanceOf[Boolean] + blockManagerMaster.driverActor.askWithReply[Boolean]( + BlockManagerHeartbeat(blockManagerId), 600 seconds) } // Called by TaskScheduler when an executor fails. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index d7b184f8a10e..0285ec18be6f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -38,6 +38,7 @@ import org.apache.spark.network.netty.{SparkTransportConf, NettyBlockTransferSer import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.network.util.{ConfigProvider, TransportConf} +import org.apache.spark.rpc.{RpcEnv, RpcEndpointRef} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.shuffle.hash.HashShuffleManager @@ -65,7 +66,7 @@ private[spark] class BlockResult( */ private[spark] class BlockManager( executorId: String, - actorSystem: ActorSystem, + rpcEnv: RpcEnv, val master: BlockManagerMaster, defaultSerializer: Serializer, maxMemory: Long, @@ -137,9 +138,9 @@ private[spark] class BlockManager( // Whether to compress shuffle output temporarily spilled to disk private val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) - private val slaveActor = actorSystem.actorOf( - Props(new BlockManagerSlaveActor(this, mapOutputTracker)), - name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) + private val slaveActor = rpcEnv.setupEndpoint( + name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next, + new BlockManagerSlaveActor(rpcEnv, this, mapOutputTracker)) // Pending re-registration action being executed asynchronously or null if none is pending. // Accesses should synchronize on asyncReregisterLock. 
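As a reading aid, a hedged sketch of how call sites change once an ActorRef becomes an RpcEndpointRef: fire-and-forget goes through send, blocking request-reply goes through askWithReply with an explicit timeout, and the non-blocking variant returns a Future. The Ping and RemoveData messages, talkToDriver, and driverRef are assumptions made up for this example; the real code uses BlockManagerMessages such as BlockManagerHeartbeat, and the ref would come from setupEndpoint or setupEndpointRefByUrl.

import scala.concurrent.Future
import scala.concurrent.duration._

import org.apache.spark.rpc.RpcEndpointRef

// Illustrative messages only.
case class Ping(id: String)
case class RemoveData(id: Int)

def talkToDriver(driverRef: RpcEndpointRef): Unit = {
  // Asynchronous send, no delivery guarantee (replaces `actorRef ! msg`).
  driverRef.send(Ping("bm-1"))

  // Blocking ask with an explicit timeout (replaces Await.result(actorRef ? msg, timeout)).
  val alive: Boolean = driverRef.askWithReply[Boolean](Ping("bm-1"), 600.seconds)

  // Non-blocking ask returning a Future, as used for slave-side removals.
  val removed: Future[Int] = driverRef.ask[Int](RemoveData(7))
}
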
@@ -168,7 +169,7 @@ private[spark] class BlockManager( */ def this( execId: String, - actorSystem: ActorSystem, + rpcEnv: RpcEnv, master: BlockManagerMaster, serializer: Serializer, conf: SparkConf, @@ -177,7 +178,7 @@ private[spark] class BlockManager( blockTransferService: BlockTransferService, securityManager: SecurityManager, numUsableCores: Int) = { - this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), + this(execId, rpcEnv, master, serializer, BlockManager.getMaxMemory(conf), conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) } @@ -1207,7 +1208,7 @@ private[spark] class BlockManager( shuffleClient.close() } diskBlockManager.stop() - actorSystem.stop(slaveActor) + rpcEnv.stop(slaveActor) blockInfo.clear() memoryStore.clear() diskStore.clear() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index b63c7f191155..4743a9f64a34 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -20,20 +20,17 @@ package org.apache.spark.storage import scala.concurrent.{Await, Future} import scala.concurrent.ExecutionContext.Implicits.global -import akka.actor._ - import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.AkkaUtils private[spark] class BlockManagerMaster( - var driverActor: ActorRef, + var driverActor: RpcEndpointRef, conf: SparkConf, isDriver: Boolean) extends Logging { - private val AKKA_RETRY_ATTEMPTS: Int = AkkaUtils.numRetries(conf) - private val AKKA_RETRY_INTERVAL_MS: Int = AkkaUtils.retryWaitMs(conf) val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" @@ -46,7 +43,7 @@ class BlockManagerMaster( } /** Register the BlockManager's id with the driver. */ - def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: RpcEndpointRef) { logInfo("Trying to register BlockManager") tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor)) logInfo("Registered BlockManager") @@ -218,8 +215,7 @@ class BlockManagerMaster( * throw a SparkException if this fails. 
*/ private def askDriverWithReply[T](message: Any): T = { - AkkaUtils.askWithReply(message, driverActor, AKKA_RETRY_ATTEMPTS, AKKA_RETRY_INTERVAL_MS, - timeout) + driverActor.askWithReply(message) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 9cbda41223a8..fa7a25608e57 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -17,29 +17,34 @@ package org.apache.spark.storage +import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors} import java.util.{HashMap => JHashMap} import scala.collection.mutable import scala.collection.JavaConversions._ -import scala.concurrent.Future -import scala.concurrent.duration._ +import scala.concurrent.{ExecutionContext, Future} -import akka.actor.{Actor, ActorRef, Cancellable} -import akka.pattern.ask - -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.{RpcEnv, RpcEndpointRef, RpcEndpoint} import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} +import org.apache.spark.util.Utils /** * BlockManagerMasterActor is an actor on the master node to track statuses of * all slaves' block managers. */ private[spark] -class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus: LiveListenerBus) - extends Actor with ActorLogReceive with Logging { +class BlockManagerMasterActor(override val rpcEnv: RpcEnv, val isLocal: Boolean, conf: SparkConf, + listenerBus: LiveListenerBus) extends RpcEndpoint with Logging { + + val scheduler = Executors.newScheduledThreadPool(1, + Utils.namedThreadFactory("block-manager-master-actor-heartbeat-scheduler")) + + implicit val executor = ExecutionContext.fromExecutor( + Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors(), + Utils.namedThreadFactory("block-manager-master-actor-ask-timeout-executor"))) // Mapping from block manager id to the block manager's information. private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo] @@ -50,85 +55,83 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Mapping from block id to the set of block managers that have the block. 
private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] - private val akkaTimeout = AkkaUtils.askTimeout(conf) - val slaveTimeout = conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", math.max(conf.getInt("spark.executor.heartbeatInterval", 10000) * 3, 45000)) val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000) - var timeoutCheckingTask: Cancellable = null + var timeoutCheckingTask: ScheduledFuture[_] = null override def preStart() { - import context.dispatcher - timeoutCheckingTask = context.system.scheduler.schedule(0.seconds, - checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) super.preStart() + timeoutCheckingTask = scheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = self.send(ExpireDeadHosts) + }, 0, checkTimeoutInterval, TimeUnit.MILLISECONDS) } - override def receiveWithLogging = { + override def receive(sender: RpcEndpointRef) = { case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => register(blockManagerId, maxMemSize, slaveActor) - sender ! true + sender.send(true) case UpdateBlockInfo( blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) => - sender ! updateBlockInfo( - blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) + sender.send(updateBlockInfo( + blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize)) case GetLocations(blockId) => - sender ! getLocations(blockId) + sender.send(getLocations(blockId)) case GetLocationsMultipleBlockIds(blockIds) => - sender ! getLocationsMultipleBlockIds(blockIds) + sender.send(getLocationsMultipleBlockIds(blockIds)) case GetPeers(blockManagerId) => - sender ! getPeers(blockManagerId) + sender.send(getPeers(blockManagerId)) case GetActorSystemHostPortForExecutor(executorId) => - sender ! getActorSystemHostPortForExecutor(executorId) + sender.send(getActorSystemHostPortForExecutor(executorId)) case GetMemoryStatus => - sender ! memoryStatus + sender.send(memoryStatus) case GetStorageStatus => - sender ! storageStatus + sender.send(storageStatus) case GetBlockStatus(blockId, askSlaves) => - sender ! blockStatus(blockId, askSlaves) + sender.send(blockStatus(blockId, askSlaves)) case GetMatchingBlockIds(filter, askSlaves) => - sender ! getMatchingBlockIds(filter, askSlaves) + sender.send(getMatchingBlockIds(filter, askSlaves)) case RemoveRdd(rddId) => - sender ! removeRdd(rddId) + sender.send(removeRdd(rddId)) case RemoveShuffle(shuffleId) => - sender ! removeShuffle(shuffleId) + sender.send(removeShuffle(shuffleId)) case RemoveBroadcast(broadcastId, removeFromDriver) => - sender ! removeBroadcast(broadcastId, removeFromDriver) + sender.send(removeBroadcast(broadcastId, removeFromDriver)) case RemoveBlock(blockId) => removeBlockFromWorkers(blockId) - sender ! true + sender.send(true) case RemoveExecutor(execId) => removeExecutor(execId) - sender ! true + sender.send(true) case StopBlockManagerMaster => - sender ! true + sender.send(true) if (timeoutCheckingTask != null) { - timeoutCheckingTask.cancel() + timeoutCheckingTask.cancel(true) } - context.stop(self) + stop() case ExpireDeadHosts => expireDeadHosts() case BlockManagerHeartbeat(blockManagerId) => - sender ! 
heartbeatReceived(blockManagerId) + sender.send(heartbeatReceived(blockManagerId)) case other => logWarning("Got unknown message: " + other) @@ -149,22 +152,20 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. // The dispatcher is used as an implicit argument into the Future sequence construction. - import context.dispatcher val removeMsg = RemoveRdd(rddId) Future.sequence( blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + bm.slaveActor.ask[Int](removeMsg) }.toSeq ) } private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { // Nothing to do in the BlockManagerMasterActor data structures - import context.dispatcher val removeMsg = RemoveShuffle(shuffleId) Future.sequence( blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean] + bm.slaveActor.ask[Boolean](removeMsg) }.toSeq ) } @@ -175,14 +176,13 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus * from the executors, but not from the driver. */ private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = { - import context.dispatcher val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver) val requiredBlockManagers = blockManagerInfo.values.filter { info => removeFromDriver || !info.blockManagerId.isDriver } Future.sequence( requiredBlockManagers.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + bm.slaveActor.ask[Int](removeMsg) }.toSeq ) } @@ -252,7 +252,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Remove the block from the slave's BlockManager. // Doesn't actually wait for a confirmation and the message might get lost. // If message loss becomes frequent, we should add retry logic here. 
- blockManager.get.slaveActor.ask(RemoveBlock(blockId))(akkaTimeout) + blockManager.get.slaveActor.send(RemoveBlock(blockId)) } } } @@ -282,7 +282,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private def blockStatus( blockId: BlockId, askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = { - import context.dispatcher val getBlockStatus = GetBlockStatus(blockId) /* * Rather than blocking on the block status query, master actor should simply return @@ -292,7 +291,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus blockManagerInfo.values.map { info => val blockStatusFuture = if (askSlaves) { - info.slaveActor.ask(getBlockStatus)(akkaTimeout).mapTo[Option[BlockStatus]] + info.slaveActor.ask[Option[BlockStatus]](getBlockStatus) } else { Future { info.getStatus(blockId) } } @@ -311,13 +310,12 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private def getMatchingBlockIds( filter: BlockId => Boolean, askSlaves: Boolean): Future[Seq[BlockId]] = { - import context.dispatcher val getMatchingBlockIds = GetMatchingBlockIds(filter) Future.sequence( blockManagerInfo.values.map { info => val future = if (askSlaves) { - info.slaveActor.ask(getMatchingBlockIds)(akkaTimeout).mapTo[Seq[BlockId]] + info.slaveActor.ask[Seq[BlockId]](getMatchingBlockIds) } else { Future { info.blocks.keys.filter(filter).toSeq } } @@ -326,7 +324,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus ).map(_.flatten.toSeq) } - private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: RpcEndpointRef) { val time = System.currentTimeMillis() if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { @@ -421,8 +419,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus for ( blockManagerId <- blockManagerIdByExecutor.get(executorId); info <- blockManagerInfo.get(blockManagerId); - host <- info.slaveActor.path.address.host; - port <- info.slaveActor.path.address.port + host <- info.slaveActor.host; + port <- info.slaveActor.port ) yield { (host, port) } @@ -447,7 +445,7 @@ private[spark] class BlockManagerInfo( val blockManagerId: BlockManagerId, timeMs: Long, val maxMem: Long, - val slaveActor: ActorRef) + val slaveActor: RpcEndpointRef) extends Logging { private var _lastSeenMs: Long = timeMs diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 3f32099d08cc..26e1b311c032 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -19,8 +19,7 @@ package org.apache.spark.storage import java.io.{Externalizable, ObjectInput, ObjectOutput} -import akka.actor.ActorRef - +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] object BlockManagerMessages { @@ -52,7 +51,7 @@ private[spark] object BlockManagerMessages { case class RegisterBlockManager( blockManagerId: BlockManagerId, maxMemSize: Long, - sender: ActorRef) + sender: RpcEndpointRef) extends ToBlockManagerMaster case class UpdateBlockInfo( diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 
8462871e798a..b2462e1a3291 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -17,13 +17,15 @@ package org.apache.spark.storage -import scala.concurrent.Future +import java.util.concurrent.Executors -import akka.actor.{ActorRef, Actor} +import scala.concurrent.ExecutionContext +import scala.concurrent.Future import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} +import org.apache.spark.rpc.{RpcEnv, RpcEndpoint, RpcEndpointRef} import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.ActorLogReceive +import org.apache.spark.util.Utils /** * An actor to take commands from the master to execute options. For example, @@ -31,14 +33,15 @@ import org.apache.spark.util.ActorLogReceive */ private[storage] class BlockManagerSlaveActor( + override val rpcEnv: RpcEnv, blockManager: BlockManager, - mapOutputTracker: MapOutputTracker) - extends Actor with ActorLogReceive with Logging { + mapOutputTracker: MapOutputTracker) extends RpcEndpoint with Logging { - import context.dispatcher + implicit val executor = ExecutionContext.fromExecutorService(Executors.newScheduledThreadPool(1, + Utils.namedThreadFactory("block-manager-slave-actor-executor"))) // Operations that involve removing blocks may be slow and should be done asynchronously - override def receiveWithLogging = { + override def receive(sender: RpcEndpointRef) = { case RemoveBlock(blockId) => doAsync[Boolean]("removing block " + blockId, sender) { blockManager.removeBlock(blockId) @@ -64,25 +67,25 @@ class BlockManagerSlaveActor( } case GetBlockStatus(blockId, _) => - sender ! blockManager.getStatus(blockId) + sender.send(blockManager.getStatus(blockId)) case GetMatchingBlockIds(filter, _) => - sender ! blockManager.getMatchingBlockIds(filter) + sender.send(blockManager.getMatchingBlockIds(filter)) } - private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) { + private def doAsync[T](actionMessage: String, responseActor: RpcEndpointRef)(body: => T) { val future = Future { logDebug(actionMessage) body } future.onSuccess { case response => logDebug("Done " + actionMessage + ", response is " + response) - responseActor ! response + responseActor.send(response) logDebug("Sent response: " + response + " to " + responseActor) } future.onFailure { case t: Throwable => logError("Error in " + actionMessage, t) - responseActor ! 
null.asInstanceOf[T] + responseActor.send(null.asInstanceOf[T]) } } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index c2903c859799..6b9112dc2bc0 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -22,7 +22,7 @@ import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import akka.actor.{ActorSystem, Props} +import akka.actor.ActorSystem import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ @@ -30,6 +30,8 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} import org.apache.spark.network.BlockTransferService import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager @@ -41,6 +43,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd private val conf = new SparkConf(false) var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null val securityMgr = new SecurityManager(conf) val mapOutputTracker = new MapOutputTrackerMaster(conf) @@ -61,7 +64,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) - val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + val store = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) store.initialize("app-id") allStores += store @@ -72,6 +75,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd val (actorSystem, boundPort) = AkkaUtils.createActorSystem( "test", "localhost", 0, conf = conf, securityManager = securityMgr) this.actorSystem = actorSystem + this.rpcEnv = new AkkaRpcEnv(actorSystem, conf) conf.set("spark.authenticate", "false") conf.set("spark.driver.port", boundPort.toString) @@ -84,7 +88,8 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd conf.set("spark.storage.cachedPeersTtl", "10") master = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), + rpcEnv.setupEndpoint("block-manager-master", + new BlockManagerMasterActor(rpcEnv, true, conf, new LiveListenerBus)), conf, true) allStores.clear() } @@ -92,6 +97,8 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd after { allStores.foreach { _.stop() } allStores.clear() + rpcEnv.stopAll() + rpcEnv = null actorSystem.shutdown() actorSystem.awaitTermination() actorSystem = null @@ -262,7 +269,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd val failableTransfer = mock(classOf[BlockTransferService]) // this wont actually work when(failableTransfer.hostName).thenReturn("some-hostname") when(failableTransfer.port).thenReturn(1000) - 
val failableStore = new BlockManager("failable-store", actorSystem, master, serializer, + val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer, 10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0) failableStore.initialize("app-id") allStores += failableStore // so that this gets stopped after test diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index ffe6f039145e..169fb727f1ee 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -22,13 +22,11 @@ import java.util.Arrays import java.util.concurrent.TimeUnit import scala.collection.mutable.ArrayBuffer -import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps import akka.actor._ -import akka.pattern.ask import akka.util.Timeout import org.mockito.Mockito.{mock, when} @@ -40,6 +38,8 @@ import org.scalatest.concurrent.Timeouts._ import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} import org.apache.spark.executor.DataReadMethod import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.rpc.akka.AkkaRpcEnv +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.shuffle.hash.HashShuffleManager @@ -53,6 +53,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach private val conf = new SparkConf(false) var store: BlockManager = null var store2: BlockManager = null + var rpcEnv: RpcEnv = null var actorSystem: ActorSystem = null var master: BlockManagerMaster = null conf.set("spark.authenticate", "false") @@ -72,7 +73,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) - val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + val manager = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) manager.initialize("app-id") manager @@ -82,6 +83,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach val (actorSystem, boundPort) = AkkaUtils.createActorSystem( "test", "localhost", 0, conf = conf, securityManager = securityMgr) this.actorSystem = actorSystem + this.rpcEnv = new AkkaRpcEnv(actorSystem, conf) // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case System.setProperty("os.arch", "amd64") @@ -92,7 +94,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach conf.set("spark.storage.unrollMemoryThreshold", "512") master = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), + rpcEnv.setupEndpoint("block-manager-master", + new BlockManagerMasterActor(rpcEnv, true, conf, new LiveListenerBus)), conf, true) val initialize = PrivateMethod[Unit]('initialize) @@ -108,6 +111,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach store2.stop() store2 = null } + rpcEnv.stopAll() + rpcEnv = null actorSystem.shutdown() actorSystem.awaitTermination() actorSystem 
= null @@ -357,11 +362,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - implicit val timeout = Timeout(30, TimeUnit.SECONDS) - val reregister = !Await.result( - master.driverActor ? BlockManagerHeartbeat(store.blockManagerId), - timeout.duration).asInstanceOf[Boolean] - assert(reregister == true) + val reregister = ! master.driverActor.askWithReply[Boolean]( + BlockManagerHeartbeat(store.blockManagerId)) + assert(reregister === true) } test("reregistration on block update") { @@ -785,7 +788,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach test("block store put failure") { // Use Java serializer so we can create an unserializable error. val transfer = new NioBlockTransferService(conf, securityMgr) - store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, actorSystem, master, + store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) From c1d3df82a7f9f055e675db067a868c47eea7c2f3 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 4 Jan 2015 17:55:14 +0800 Subject: [PATCH 20/36] Change WorkerWatcher to use RpcEndpoint --- .../spark/deploy/worker/WorkerWatcher.scala | 44 +++++++++++-------- .../CoarseGrainedExecutorBackend.scala | 4 +- .../org/apache/spark/executor/Executor.scala | 2 - .../deploy/worker/WorkerWatcherSuite.scala | 28 +++++++----- 4 files changed, 43 insertions(+), 35 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 63a8ac817b61..145ae97ce3b1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -17,26 +17,25 @@ package org.apache.spark.deploy.worker -import akka.actor.{Actor, Address, AddressFromURIString} -import akka.remote.{AssociatedEvent, AssociationErrorEvent, AssociationEvent, DisassociatedEvent, RemotingLifecycleEvent} +import akka.remote.{AssociatedEvent, AssociationErrorEvent, AssociationEvent} import org.apache.spark.Logging import org.apache.spark.deploy.DeployMessages.SendHeartbeat -import org.apache.spark.util.ActorLogReceive +import org.apache.spark.rpc.{RpcEnv, RpcEndpointRef, RpcEndpoint} /** * Actor which connects to a worker process and terminates the JVM if the connection is severed. * Provides fate sharing between a worker and its associated child processes. */ -private[spark] class WorkerWatcher(workerUrl: String) - extends Actor with ActorLogReceive with Logging { +private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String) + extends RpcEndpoint with Logging { override def preStart() { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - logInfo(s"Connecting to worker $workerUrl") - val worker = context.actorSelection(workerUrl) - worker ! 
SendHeartbeat // need to send a message here to initiate connection + if (!isTesting) { + val worker = rpcEnv.setupEndpointRefByUrl(workerUrl) + worker.send(SendHeartbeat) // need to send a message here to initiate connection + } } // Used to avoid shutting down JVM during tests @@ -45,30 +44,37 @@ private[spark] class WorkerWatcher(workerUrl: String) private var isTesting = false // Lets us filter events only from the worker's actor system - private val expectedHostPort = AddressFromURIString(workerUrl).hostPort - private def isWorker(address: Address) = address.hostPort == expectedHostPort + private val expectedHostPort = new java.net.URI(workerUrl) + private def isWorker(address: String) = { + val uri = new java.net.URI(address) + expectedHostPort.getHost == uri.getHost && expectedHostPort.getPort == uri.getPort + } def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) - override def receiveWithLogging = { - case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => + override def receive(sender: RpcEndpointRef) = { + case AssociatedEvent(localAddress, remoteAddress, inbound) + if isWorker(remoteAddress.toString) => logInfo(s"Successfully connected to $workerUrl") case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) - if isWorker(remoteAddress) => + if isWorker(remoteAddress.toString) => // These logs may not be seen if the worker (and associated pipe) has died logError(s"Could not initialize connection to worker $workerUrl. Exiting.") logError(s"Error was: $cause") exitNonZero() - case DisassociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => - // This log message will never be seen - logError(s"Lost connection to worker actor $workerUrl. Exiting.") - exitNonZero() - case e: AssociationEvent => // pass through association events relating to other remote actor systems case e => logWarning(s"Received unexpected actor system event: $e") } + + override def remoteConnectionTerminated(remoteAddress: String): Unit = { + if(isWorker(remoteAddress)) { + // This log message will never be seen + logError(s"Lost connection to worker actor $workerUrl. 
Exiting.") + exitNonZero() + } + } } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 75fbee49c621..30f4e2a07d10 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -19,8 +19,6 @@ package org.apache.spark.executor import java.nio.ByteBuffer -import akka.actor.Props - import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil @@ -144,7 +142,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( driverUrl, executorId, sparkHostPort, cores, env)) workerUrl.foreach { url => - env.actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher") + env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } env.actorSystem.awaitTermination() } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 928653aa1168..5301746b8432 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -26,8 +26,6 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal -import akka.actor.Props - import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler._ diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index 5e538d6fab2a..a3dcf9b14033 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -17,32 +17,38 @@ package org.apache.spark.deploy.worker -import akka.actor.{ActorSystem, AddressFromURIString, Props} -import akka.testkit.TestActorRef -import akka.remote.DisassociatedEvent +import akka.actor.{ActorSystem, AddressFromURIString} +import org.apache.spark.SparkConf +import org.apache.spark.rpc.akka.AkkaRpcEnv import org.scalatest.FunSuite class WorkerWatcherSuite extends FunSuite { test("WorkerWatcher shuts down on valid disassociation") { val actorSystem = ActorSystem("test") + val rpcEnv = new AkkaRpcEnv(actorSystem, new SparkConf()) val targetWorkerUrl = "akka://1.2.3.4/user/Worker" val targetWorkerAddress = AddressFromURIString(targetWorkerUrl) - val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem) - val workerWatcher = actorRef.underlyingActor + val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) - actorRef.underlyingActor.receive(new DisassociatedEvent(null, targetWorkerAddress, false)) - assert(actorRef.underlyingActor.isShutDown) + rpcEnv.setupEndpoint("worker-watcher", workerWatcher) + workerWatcher.remoteConnectionTerminated(targetWorkerAddress.toString) + assert(workerWatcher.isShutDown) + rpcEnv.stopAll() + actorSystem.shutdown() } test("WorkerWatcher stays alive on invalid disassociation") { val actorSystem = ActorSystem("test") + val rpcEnv = new AkkaRpcEnv(actorSystem, new SparkConf()) val 
targetWorkerUrl = "akka://1.2.3.4/user/Worker" val otherAkkaURL = "akka://4.3.2.1/user/OtherActor" val otherAkkaAddress = AddressFromURIString(otherAkkaURL) - val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem) - val workerWatcher = actorRef.underlyingActor + val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) - actorRef.underlyingActor.receive(new DisassociatedEvent(null, otherAkkaAddress, false)) - assert(!actorRef.underlyingActor.isShutDown) + rpcEnv.setupEndpoint("worker-watcher", workerWatcher) + workerWatcher.remoteConnectionTerminated(otherAkkaAddress.toString) + assert(!workerWatcher.isShutDown) + rpcEnv.stopAll() + actorSystem.shutdown() } } From 20682d1fd4a211b58742f0effeb611df0a0d00ba Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 4 Jan 2015 19:46:50 +0800 Subject: [PATCH 21/36] Change Master to use RpcEndpoint --- .../spark/deploy/master/ApplicationInfo.scala | 6 +- .../apache/spark/deploy/master/Master.scala | 162 +++++++++--------- .../spark/deploy/master/WorkerInfo.scala | 6 +- .../deploy/master/ui/ApplicationPage.scala | 11 +- .../spark/deploy/master/ui/MasterPage.scala | 8 +- .../spark/deploy/master/ui/MasterWebUI.scala | 2 +- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 2 +- .../storage/BlockManagerMasterActor.scala | 2 +- 8 files changed, 96 insertions(+), 103 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index ad7d81747c37..149a9fccde16 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -22,10 +22,8 @@ import java.util.Date import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import akka.actor.ActorRef - -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.ApplicationDescription +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] class ApplicationInfo( @@ -33,7 +31,7 @@ private[spark] class ApplicationInfo( val id: String, val desc: ApplicationDescription, val submitDate: Date, - val driver: ActorRef, + val driver: RpcEndpointRef, defaultCores: Int) extends Serializable { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index e8a5cfc746fe..c34d2493c532 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -20,17 +20,15 @@ package org.apache.spark.deploy.master import java.io.FileNotFoundException import java.net.URLEncoder import java.text.SimpleDateFormat +import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors} import java.util.Date +import org.apache.spark.rpc.akka.AkkaRpcEnv + import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.Await -import scala.concurrent.duration._ -import scala.language.postfixOps import scala.util.Random -import akka.actor._ -import akka.pattern.ask -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import akka.actor.ActorSystem import akka.serialization.Serialization import akka.serialization.SerializationExtension import org.apache.hadoop.fs.Path @@ -44,18 +42,23 @@ import org.apache.spark.deploy.master.DriverState.DriverState import 
org.apache.spark.deploy.master.MasterMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.rpc.{RpcEndpoint, RpcEnv, RpcEndpointRef} import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} private[spark] class Master( + override val rpcEnv: RpcEnv, host: String, port: Int, webUiPort: Int, val securityMgr: SecurityManager) - extends Actor with ActorLogReceive with Logging with LeaderElectable { + extends RpcEndpoint with Logging with LeaderElectable { + + val scheduler = Executors.newScheduledThreadPool(1, + Utils.namedThreadFactory("check-worker-timeout")) - import context.dispatcher // to use Akka's scheduler.schedule() + private def internalActorSystem: ActorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem val conf = new SparkConf val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) @@ -69,12 +72,12 @@ private[spark] class Master( val workers = new HashSet[WorkerInfo] val idToWorker = new HashMap[String, WorkerInfo] - val addressToWorker = new HashMap[Address, WorkerInfo] + val addressToWorker = new HashMap[String, WorkerInfo] val apps = new HashSet[ApplicationInfo] val idToApp = new HashMap[String, ApplicationInfo] - val actorToApp = new HashMap[ActorRef, ApplicationInfo] - val addressToApp = new HashMap[Address, ApplicationInfo] + val actorToApp = new HashMap[RpcEndpointRef, ApplicationInfo] + val addressToApp = new HashMap[String, ApplicationInfo] val waitingApps = new ArrayBuffer[ApplicationInfo] val completedApps = new ArrayBuffer[ApplicationInfo] var nextAppNumber = 0 @@ -108,7 +111,7 @@ private[spark] class Master( var leaderElectionAgent: LeaderElectionAgent = _ - private var recoveryCompletionTask: Cancellable = _ + private var recoveryCompletionTask: ScheduledFuture[_] = _ // As a temporary workaround before better ways of configuring memory, we allow users to set // a flag that will perform round-robin scheduling across the nodes (spreading out each app @@ -125,10 +128,11 @@ private[spark] class Master( logInfo("Starting Spark master at " + masterUrl) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) webUi.bind() masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort - context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut) + scheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = self.send(CheckForWorkerTimeOut) + }, 0, WORKER_TIMEOUT, TimeUnit.MILLISECONDS) masterMetricsSystem.registerSource(masterSource) masterMetricsSystem.start() @@ -142,16 +146,16 @@ private[spark] class Master( case "ZOOKEEPER" => logInfo("Persisting recovery state to ZooKeeper") val zkFactory = - new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(context.system)) + new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(internalActorSystem)) (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this)) case "FILESYSTEM" => val fsFactory = - new FileSystemRecoveryModeFactory(conf, SerializationExtension(context.system)) + new FileSystemRecoveryModeFactory(conf, SerializationExtension(internalActorSystem)) 
(fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this)) case "CUSTOM" => val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory")) val factory = clazz.getConstructor(conf.getClass, Serialization.getClass) - .newInstance(conf, SerializationExtension(context.system)) + .newInstance(conf, SerializationExtension(internalActorSystem)) .asInstanceOf[StandaloneRecoveryModeFactory] (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this)) case _ => @@ -161,9 +165,9 @@ private[spark] class Master( leaderElectionAgent = leaderElectionAgent_ } - override def preRestart(reason: Throwable, message: Option[Any]) { - super.preRestart(reason, message) // calls postStop()! - logError("Master actor restarted due to exception", reason) + override def onError(reason: Throwable) { + logError("Master actor is crashed due to exception", reason) + throw reason // throw it so that the master will be restarted } override def postStop() { @@ -171,7 +175,7 @@ private[spark] class Master( applicationMetricsSystem.report() // prevent the CompleteRecovery message sending to restarted master if (recoveryCompletionTask != null) { - recoveryCompletionTask.cancel() + recoveryCompletionTask.cancel(true) } webUi.stop() masterMetricsSystem.stop() @@ -181,14 +185,14 @@ private[spark] class Master( } override def electedLeader() { - self ! ElectedLeader + self.send(ElectedLeader) } override def revokedLeadership() { - self ! RevokedLeadership + self.send(RevokedLeadership) } - override def receiveWithLogging = { + override def receive(sender: RpcEndpointRef) = { case ElectedLeader => { val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { @@ -199,8 +203,9 @@ private[spark] class Master( logInfo("I have been elected leader! New state: " + state) if (state == RecoveryState.RECOVERING) { beginRecovery(storedApps, storedDrivers, storedWorkers) - recoveryCompletionTask = context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis, self, - CompleteRecovery) + recoveryCompletionTask = scheduler.schedule(new Runnable { + override def run(): Unit = self.send(CompleteRecovery) + }, WORKER_TIMEOUT, TimeUnit.MILLISECONDS) } } @@ -218,20 +223,20 @@ private[spark] class Master( if (state == RecoveryState.STANDBY) { // ignore, don't send response } else if (idToWorker.contains(id)) { - sender ! RegisterWorkerFailed("Duplicate worker ID") + sender.send(RegisterWorkerFailed("Duplicate worker ID")) } else { val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, sender, workerUiPort, publicAddress) if (registerWorker(worker)) { persistenceEngine.addWorker(worker) - sender ! RegisteredWorker(masterUrl, masterWebUiUrl) + sender.send(RegisteredWorker(masterUrl, masterWebUiUrl)) schedule() } else { - val workerAddress = worker.actor.path.address + val workerAddress = worker.actor.address logWarning("Worker registration failed. Attempted to re-register worker at same " + "address: " + workerAddress) - sender ! RegisterWorkerFailed("Attempted to re-register worker at same address: " - + workerAddress) + sender.send(RegisterWorkerFailed("Attempted to re-register worker at same address: " + + workerAddress)) } } } @@ -239,7 +244,7 @@ private[spark] class Master( case RequestSubmitDriver(description) => { if (state != RecoveryState.ALIVE) { val msg = s"Can only accept driver submissions in ALIVE state. Current state: $state." - sender ! 
SubmitDriverResponse(false, None, msg) + sender.send(SubmitDriverResponse(false, None, msg)) } else { logInfo("Driver submitted " + description.command.mainClass) val driver = createDriver(description) @@ -251,15 +256,15 @@ private[spark] class Master( // TODO: It might be good to instead have the submission client poll the master to determine // the current status of the driver. For now it's simply "fire and forget". - sender ! SubmitDriverResponse(true, Some(driver.id), - s"Driver successfully submitted as ${driver.id}") + sender.send(SubmitDriverResponse(true, Some(driver.id), + s"Driver successfully submitted as ${driver.id}")) } } case RequestKillDriver(driverId) => { if (state != RecoveryState.ALIVE) { val msg = s"Can only kill drivers in ALIVE state. Current state: $state." - sender ! KillDriverResponse(driverId, success = false, msg) + sender.send(KillDriverResponse(driverId, success = false, msg)) } else { logInfo("Asked to kill driver " + driverId) val driver = drivers.find(_.id == driverId) @@ -267,23 +272,23 @@ private[spark] class Master( case Some(d) => if (waitingDrivers.contains(d)) { waitingDrivers -= d - self ! DriverStateChanged(driverId, DriverState.KILLED, None) + self.send(DriverStateChanged(driverId, DriverState.KILLED, None)) } else { // We just notify the worker to kill the driver here. The final bookkeeping occurs // on the return path when the worker submits a state change back to the master // to notify it that the driver was successfully killed. d.worker.foreach { w => - w.actor ! KillDriver(driverId) + w.actor.send(KillDriver(driverId)) } } // TODO: It would be nice for this to be a synchronous response val msg = s"Kill request for $driverId submitted" logInfo(msg) - sender ! KillDriverResponse(driverId, success = true, msg) + sender.send(KillDriverResponse(driverId, success = true, msg)) case None => val msg = s"Driver $driverId has already finished or does not exist" logWarning(msg) - sender ! KillDriverResponse(driverId, success = false, msg) + sender.send(KillDriverResponse(driverId, success = false, msg)) } } } @@ -291,10 +296,10 @@ private[spark] class Master( case RequestDriverStatus(driverId) => { (drivers ++ completedDrivers).find(_.id == driverId) match { case Some(driver) => - sender ! DriverStatusResponse(found = true, Some(driver.state), - driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception) + sender.send(DriverStatusResponse(found = true, Some(driver.state), + driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception)) case None => - sender ! DriverStatusResponse(found = false, None, None, None, None) + sender.send(DriverStatusResponse(found = false, None, None, None, None)) } } @@ -307,7 +312,7 @@ private[spark] class Master( registerApplication(app) logInfo("Registered app " + description.name + " with ID " + app.id) persistenceEngine.addApplication(app) - sender ! RegisteredApplication(app.id, masterUrl) + sender.send(RegisteredApplication(app.id, masterUrl)) schedule() } } @@ -319,7 +324,7 @@ private[spark] class Master( val appInfo = idToApp(appId) exec.state = state if (state == ExecutorState.RUNNING) { appInfo.resetRetryCount() } - exec.application.driver ! 
ExecutorUpdated(execId, state, message, exitStatus) + exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus)) if (ExecutorState.isFinished(state)) { // Remove this executor from the worker and app logInfo(s"Removing executor ${exec.fullId} because it is $state") @@ -364,7 +369,7 @@ private[spark] class Master( if (workers.map(_.id).contains(workerId)) { logWarning(s"Got heartbeat from unregistered worker $workerId." + " Asking it to re-register.") - sender ! ReconnectWorker(masterUrl) + sender.send(ReconnectWorker(masterUrl)) } else { logWarning(s"Got heartbeat from unregistered worker $workerId." + " This worker was never registered, so ignoring the heartbeat.") @@ -412,17 +417,9 @@ private[spark] class Master( if (canCompleteRecovery) { completeRecovery() } } - case DisassociatedEvent(_, address, _) => { - // The disconnected client could've been either a worker or an app; remove whichever it was - logInfo(s"$address got disassociated, removing it.") - addressToWorker.get(address).foreach(removeWorker) - addressToApp.get(address).foreach(finishApplication) - if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } - } - case RequestMasterState => { - sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray, - drivers.toArray, completedDrivers.toArray, state) + sender.send(MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray, + drivers.toArray, completedDrivers.toArray, state)) } case CheckForWorkerTimeOut => { @@ -430,10 +427,18 @@ private[spark] class Master( } case RequestWebUIPort => { - sender ! WebUIPortResponse(webUi.boundPort) + sender.send(WebUIPortResponse(webUi.boundPort)) } } + override def remoteConnectionTerminated(address: String): Unit = { + // The disconnected client could've been either a worker or an app; remove whichever it was + logInfo(s"$address got disassociated, removing it.") + addressToWorker.get(address).foreach(removeWorker) + addressToApp.get(address).foreach(finishApplication) + if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } + } + def canCompleteRecovery = workers.count(_.state == WorkerState.UNKNOWN) == 0 && apps.count(_.state == ApplicationState.UNKNOWN) == 0 @@ -445,7 +450,7 @@ private[spark] class Master( try { registerApplication(app) app.state = ApplicationState.UNKNOWN - app.driver ! MasterChanged(masterUrl, masterWebUiUrl) + app.driver.send(MasterChanged(masterUrl, masterWebUiUrl)) } catch { case e: Exception => logInfo("App " + app.id + " had exception on reconnect") } @@ -462,7 +467,7 @@ private[spark] class Master( try { registerWorker(worker) worker.state = WorkerState.UNKNOWN - worker.actor ! MasterChanged(masterUrl, masterWebUiUrl) + worker.actor.send(MasterChanged(masterUrl, masterWebUiUrl)) } catch { case e: Exception => logInfo("Worker " + worker.id + " had exception on reconnect") } @@ -584,10 +589,10 @@ private[spark] class Master( def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo) { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! LaunchExecutor(masterUrl, - exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory) - exec.application.driver ! 
ExecutorAdded( - exec.id, worker.id, worker.hostPort, exec.cores, exec.memory) + worker.actor.send(LaunchExecutor(masterUrl, + exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory)) + exec.application.driver.send(ExecutorAdded( + exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)) } def registerWorker(worker: WorkerInfo): Boolean = { @@ -599,7 +604,7 @@ private[spark] class Master( workers -= w } - val workerAddress = worker.actor.path.address + val workerAddress = worker.actor.address if (addressToWorker.contains(workerAddress)) { val oldWorker = addressToWorker(workerAddress) if (oldWorker.state == WorkerState.UNKNOWN) { @@ -622,11 +627,11 @@ private[spark] class Master( logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port) worker.setState(WorkerState.DEAD) idToWorker -= worker.id - addressToWorker -= worker.actor.path.address + addressToWorker -= worker.actor.address for (exec <- worker.executors.values) { logInfo("Telling app of lost executor: " + exec.id) - exec.application.driver ! ExecutorUpdated( - exec.id, ExecutorState.LOST, Some("worker lost"), None) + exec.application.driver.send(ExecutorUpdated( + exec.id, ExecutorState.LOST, Some("worker lost"), None)) exec.application.removeExecutor(exec) } for (driver <- worker.drivers.values) { @@ -648,14 +653,14 @@ private[spark] class Master( schedule() } - def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = { + def createApplication(desc: ApplicationDescription, driver: RpcEndpointRef): ApplicationInfo = { val now = System.currentTimeMillis() val date = new Date(now) new ApplicationInfo(now, newApplicationId(date), desc, date, driver, defaultCores) } def registerApplication(app: ApplicationInfo): Unit = { - val appAddress = app.driver.path.address + val appAddress = app.driver.address if (addressToWorker.contains(appAddress)) { logInfo("Attempted to re-register application at same address: " + appAddress) return @@ -679,7 +684,7 @@ private[spark] class Master( apps -= app idToApp -= app.id actorToApp -= app.driver - addressToApp -= app.driver.path.address + addressToApp -= app.driver.address if (completedApps.size >= RETAINED_APPLICATIONS) { val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1) completedApps.take(toRemove).foreach( a => { @@ -696,19 +701,19 @@ private[spark] class Master( for (exec <- app.executors.values) { exec.worker.removeExecutor(exec) - exec.worker.actor ! KillExecutor(masterUrl, exec.application.id, exec.id) + exec.worker.actor.send(KillExecutor(masterUrl, exec.application.id, exec.id)) exec.state = ExecutorState.KILLED } app.markFinished(state) if (state != ApplicationState.FINISHED) { - app.driver ! ApplicationRemoved(state.toString) + app.driver.send(ApplicationRemoved(state.toString)) } persistenceEngine.removeApplication(app) schedule() // Tell all workers that the application has finished, so they can clean up any app state. workers.foreach { w => - w.actor ! ApplicationFinished(app.id) + w.actor.send(ApplicationFinished(app.id)) } } } @@ -817,7 +822,7 @@ private[spark] class Master( logInfo("Launching driver " + driver.id + " on worker " + worker.id) worker.addDriver(driver) driver.worker = Some(worker) - worker.actor ! 
LaunchDriver(driver.id, driver.desc) + worker.actor.send(LaunchDriver(driver.id, driver.desc)) driver.state = DriverState.RUNNING } @@ -873,11 +878,10 @@ private[spark] object Master extends Logging { val securityMgr = new SecurityManager(conf) val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf, securityManager = securityMgr) - val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort, - securityMgr), actorName) - val timeout = AkkaUtils.askTimeout(conf) - val respFuture = actor.ask(RequestWebUIPort)(timeout) - val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse] + val rpcEnv = new AkkaRpcEnv(actorSystem, conf) + val actor = rpcEnv.setupEndpoint(actorName, + new Master(rpcEnv, host, boundPort, webUiPort, securityMgr)) + val resp = actor.askWithReply[WebUIPortResponse](RequestWebUIPort) (actorSystem, boundPort, resp.webUIBoundPort) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index 473ddc23ff0f..0f1e51f88b72 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -19,9 +19,7 @@ package org.apache.spark.deploy.master import scala.collection.mutable -import akka.actor.ActorRef - -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] class WorkerInfo( @@ -30,7 +28,7 @@ private[spark] class WorkerInfo( val port: Int, val cores: Int, val memory: Int, - val actor: ActorRef, + val actor: RpcEndpointRef, val webUiPort: Int, val publicAddress: String) extends Serializable { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 4588c130ef43..96827229eee3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -19,10 +19,8 @@ package org.apache.spark.deploy.master.ui import javax.servlet.http.HttpServletRequest -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask import org.json4s.JValue import org.apache.spark.deploy.{ExecutorState, JsonProtocol} @@ -33,14 +31,12 @@ import org.apache.spark.util.Utils private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") { - private val master = parent.masterActorRef - private val timeout = parent.timeout + private def master = parent.masterActorRef /** Executor details for a particular application */ override def renderJson(request: HttpServletRequest): JValue = { val appId = request.getParameter("appId") - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, timeout) + val state = master.askWithReply[MasterStateResponse](RequestMasterState) val app = state.activeApps.find(_.id == appId).getOrElse({ state.completedApps.find(_.id == appId).getOrElse(null) }) @@ -50,8 +46,7 @@ private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app /** Executor details for a particular application */ def render(request: HttpServletRequest): Seq[Node] = { val appId = request.getParameter("appId") - val stateFuture = (master ? 
RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, timeout) + val state = master.askWithReply[MasterStateResponse](RequestMasterState) val app = state.activeApps.find(_.id == appId).getOrElse({ state.completedApps.find(_.id == appId).getOrElse(null) }) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 7ca3b08a2872..740fe3460cfb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -32,19 +32,17 @@ import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { - private val master = parent.masterActorRef + private def master = parent.masterActorRef private val timeout = parent.timeout override def renderJson(request: HttpServletRequest): JValue = { - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, timeout) + val state = master.askWithReply[MasterStateResponse](RequestMasterState) JsonProtocol.writeMasterState(state) } /** Index view listing applications and executors */ def render(request: HttpServletRequest): Seq[Node] = { - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, timeout) + val state = master.askWithReply[MasterStateResponse](RequestMasterState) val workerHeaders = Seq("Id", "Address", "State", "Cores", "Memory") val workers = state.workers.sortBy(_.id) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 73400c5affb5..3727bc7b9e83 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -30,7 +30,7 @@ private[spark] class MasterWebUI(val master: Master, requestedPort: Int) extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging { - val masterActorRef = master.self + def masterActorRef = master.self val timeout = AkkaUtils.askTimeout(master.conf) initialize() diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 5912901ff32e..fd99b62f4750 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -35,7 +35,7 @@ import org.apache.spark.util.AkkaUtils import scala.reflect.ClassTag import scala.util.control.NonFatal -class AkkaRpcEnv(actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { +class AkkaRpcEnv(val actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { // TODO Once finishing the new Rpc mechanism, make actorSystem be a private val override def setupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index fa7a25608e57..133fc3dfa8b7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -43,7 +43,7 @@ class 
BlockManagerMasterActor(override val rpcEnv: RpcEnv, val isLocal: Boolean, Utils.namedThreadFactory("block-manager-master-actor-heartbeat-scheduler")) implicit val executor = ExecutionContext.fromExecutor( - Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors(), + Executors.newFixedThreadPool(Runtime.getRuntime.availableProcessors(), Utils.namedThreadFactory("block-manager-master-actor-ask-timeout-executor"))) // Mapping from block manager id to the block manager's information. From 7b43e395e0528a0184cc60e80b69226cd7844a37 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 5 Jan 2015 16:57:38 +0800 Subject: [PATCH 22/36] Add RpcAddress and change AppClient to use RpcEndpoint --- .../spark/deploy/client/AppClient.scala | 91 ++++++++++--------- .../spark/deploy/client/TestClient.scala | 4 +- .../apache/spark/deploy/master/Master.scala | 8 +- .../spark/deploy/worker/WorkerWatcher.scala | 14 +-- .../CoarseGrainedExecutorBackend.scala | 4 +- .../scala/org/apache/spark/rpc/RpcEnv.scala | 15 +-- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 10 +- .../CoarseGrainedSchedulerBackend.scala | 6 +- .../scheduler/cluster/ExecutorData.scala | 4 +- .../cluster/SparkDeploySchedulerBackend.scala | 2 +- .../storage/BlockManagerMasterActor.scala | 6 +- .../org/apache/spark/util/AkkaUtils.scala | 8 +- .../deploy/worker/WorkerWatcherSuite.scala | 11 ++- 13 files changed, 97 insertions(+), 86 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 4efebcaa350f..209632ec3e63 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -17,20 +17,19 @@ package org.apache.spark.deploy.client -import java.util.concurrent.TimeoutException +import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors, TimeoutException} -import scala.concurrent.Await import scala.concurrent.duration._ import akka.actor._ -import akka.pattern.ask -import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} +import akka.remote.AssociationErrorEvent import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master -import org.apache.spark.util.{ActorLogReceive, Utils, AkkaUtils} +import org.apache.spark.rpc.{RpcAddress, RpcEnv, RpcEndpointRef, RpcEndpoint} +import org.apache.spark.util.Utils /** * Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL, @@ -40,74 +39,77 @@ import org.apache.spark.util.{ActorLogReceive, Utils, AkkaUtils} * @param masterUrls Each url should look like spark://host:port. 
*/ private[spark] class AppClient( - actorSystem: ActorSystem, + rpcEnv: RpcEnv, masterUrls: Array[String], appDescription: ApplicationDescription, listener: AppClientListener, conf: SparkConf) extends Logging { - val REGISTRATION_TIMEOUT = 20.seconds + val REGISTRATION_TIMEOUT = 20.seconds.toMillis val REGISTRATION_RETRIES = 3 - var masterAddress: Address = null - var actor: ActorRef = null + var masterAddress: RpcAddress = null + var actor: RpcEndpointRef = null var appId: String = null var registered = false var activeMasterUrl: String = null - class ClientActor extends Actor with ActorLogReceive with Logging { - var master: ActorSelection = null + class ClientActor(override val rpcEnv: RpcEnv) extends RpcEndpoint with Logging { + var master: RpcEndpointRef = null var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times var alreadyDead = false // To avoid calling listener.dead() multiple times - var registrationRetryTimer: Option[Cancellable] = None + var registrationRetryTimer: Option[ScheduledFuture[_]] = None + + private val scheduler = + Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("client-actor")) override def preStart() { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) try { registerWithMaster() } catch { case e: Exception => logWarning("Failed to connect to master", e) markDisconnected() - context.stop(self) + stop() } } def tryRegisterAllMasters() { for (masterUrl <- masterUrls) { logInfo("Connecting to master " + masterUrl + "...") - val actor = context.actorSelection(Master.toAkkaUrl(masterUrl)) - actor ! RegisterApplication(appDescription) + val actor = rpcEnv.setupEndpointRefByUrl(Master.toAkkaUrl(masterUrl)) + actor.send(RegisterApplication(appDescription)) } } def registerWithMaster() { tryRegisterAllMasters() - import context.dispatcher var retries = 0 registrationRetryTimer = Some { - context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) { - Utils.tryOrExit { - retries += 1 - if (registered) { - registrationRetryTimer.foreach(_.cancel()) - } else if (retries >= REGISTRATION_RETRIES) { - markDead("All masters are unresponsive! Giving up.") - } else { - tryRegisterAllMasters() + scheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = { + Utils.tryOrExit { + retries += 1 + if (registered) { + registrationRetryTimer.foreach(_.cancel(true)) + } else if (retries >= REGISTRATION_RETRIES) { + markDead("All masters are unresponsive! 
Giving up.") + } else { + tryRegisterAllMasters() + } } } - } + }, REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT, TimeUnit.MILLISECONDS) } } def changeMaster(url: String) { activeMasterUrl = url - master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl)) + master = rpcEnv.setupEndpointRefByUrl(Master.toAkkaUrl(activeMasterUrl)) masterAddress = activeMasterUrl match { case Master.sparkUrlRegex(host, port) => - Address("akka.tcp", Master.systemName, host, port.toInt) + RpcAddress(host, port.toInt) case x => throw new SparkException("Invalid spark URL: " + x) } @@ -119,7 +121,7 @@ private[spark] class AppClient( .contains(remoteUrl.hostPort) } - override def receiveWithLogging = { + override def receive(sender: RpcEndpointRef) = { case RegisteredApplication(appId_, masterUrl) => appId = appId_ registered = true @@ -128,13 +130,13 @@ private[spark] class AppClient( case ApplicationRemoved(message) => markDead("Master removed our application: %s".format(message)) - context.stop(self) + stop() case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => val fullId = appId + "/" + id logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores)) - master ! ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None) + master.send(ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) case ExecutorUpdated(id, state, message, exitStatus) => @@ -149,19 +151,22 @@ private[spark] class AppClient( logInfo("Master has changed, new master is at " + masterUrl) changeMaster(masterUrl) alreadyDisconnected = false - sender ! MasterChangeAcknowledged(appId) - - case DisassociatedEvent(_, address, _) if address == masterAddress => - logWarning(s"Connection to $address failed; waiting for master to reconnect...") - markDisconnected() + sender.send(MasterChangeAcknowledged(appId)) case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) => logWarning(s"Could not connect to $address: $cause") case StopAppClient => markDead("Application has been stopped.") - sender ! true - context.stop(self) + sender.send(true) + stop() + } + + override def remoteConnectionTerminated(address: RpcAddress): Unit = { + if (address == masterAddress) { + logWarning(s"Connection to $address failed; waiting for master to reconnect...") + markDisconnected() + } } /** @@ -182,22 +187,20 @@ private[spark] class AppClient( } override def postStop() { - registrationRetryTimer.foreach(_.cancel()) + registrationRetryTimer.foreach(_.cancel(true)) } } def start() { // Just launch an actor; it will call back into the listener. 
- actor = actorSystem.actorOf(Props(new ClientActor)) + actor = rpcEnv.setupEndpoint("client-actor", new ClientActor(rpcEnv)) } def stop() { if (actor != null) { try { - val timeout = AkkaUtils.askTimeout(conf) - val future = actor.ask(StopAppClient)(timeout) - Await.result(future, timeout) + actor.askWithReply(StopAppClient) } catch { case e: TimeoutException => logInfo("Stop request to Master timed out; it may already be shut down.") diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index 88a0862b96af..10b86294a7c0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy.client +import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.util.{AkkaUtils, Utils} @@ -48,10 +49,11 @@ private[spark] object TestClient { val conf = new SparkConf val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0, conf = conf, securityManager = new SecurityManager(conf)) + val rpcEnv = new AkkaRpcEnv(actorSystem, conf) val desc = new ApplicationDescription("TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), Seq(), Seq()), "ignored") val listener = new TestListener - val client = new AppClient(actorSystem, Array(url), desc, listener, new SparkConf) + val client = new AppClient(rpcEnv, Array(url), desc, listener, new SparkConf) client.start() actorSystem.awaitTermination() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index c34d2493c532..2c6b895ee3b7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -42,7 +42,7 @@ import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.MasterMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.rpc.{RpcEndpoint, RpcEnv, RpcEndpointRef} +import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEnv, RpcEndpointRef} import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.ui.SparkUI import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} @@ -72,12 +72,12 @@ private[spark] class Master( val workers = new HashSet[WorkerInfo] val idToWorker = new HashMap[String, WorkerInfo] - val addressToWorker = new HashMap[String, WorkerInfo] + val addressToWorker = new HashMap[RpcAddress, WorkerInfo] val apps = new HashSet[ApplicationInfo] val idToApp = new HashMap[String, ApplicationInfo] val actorToApp = new HashMap[RpcEndpointRef, ApplicationInfo] - val addressToApp = new HashMap[String, ApplicationInfo] + val addressToApp = new HashMap[RpcAddress, ApplicationInfo] val waitingApps = new ArrayBuffer[ApplicationInfo] val completedApps = new ArrayBuffer[ApplicationInfo] var nextAppNumber = 0 @@ -431,7 +431,7 @@ private[spark] class Master( } } - override def remoteConnectionTerminated(address: String): Unit = { + override def remoteConnectionTerminated(address: RpcAddress): Unit = { // The disconnected client could've been either a worker or an app; remove whichever it was 
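A small sketch of the blocking request-reply convenience that this series uses in place of the scattered `ask` + `Await.result` + cast call sites (AppClient.stop, the master web UI pages, and the suites above). `AskRef` is a stand-in trait assumed here so the example is runnable on its own; only the ask-then-await-then-cast shape is taken from the patch.

import java.util.concurrent.TimeUnit

import scala.concurrent.{Await, Future, Promise}
import scala.concurrent.duration.Duration
import scala.reflect.ClassTag

object AskWithReplySketch {
  // Stand-in for the asynchronous side of an endpoint reference.
  trait AskRef {
    def ask[T: ClassTag](message: Any): Future[T]

    // One place that does the ask + Await + cast that used to be
    // repeated at every call site.
    def askWithReply[T: ClassTag](
        message: Any,
        timeout: Duration = Duration(30, TimeUnit.SECONDS)): T =
      Await.result(ask[T](message), timeout)
  }

  def main(args: Array[String]): Unit = {
    // A toy endpoint reference that answers every ask with `true`.
    val ref = new AskRef {
      def ask[T: ClassTag](message: Any): Future[T] = {
        val p = Promise[T]()
        p.success(true.asInstanceOf[T])
        p.future
      }
    }
    val reregister: Boolean = ref.askWithReply[Boolean]("BlockManagerHeartbeat")  // toy message
    println(s"reregister = $reregister")
  }
}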
logInfo(s"$address got disassociated, removing it.") addressToWorker.get(address).foreach(removeWorker) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 145ae97ce3b1..14c10b3dafe9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -21,7 +21,8 @@ import akka.remote.{AssociatedEvent, AssociationErrorEvent, AssociationEvent} import org.apache.spark.Logging import org.apache.spark.deploy.DeployMessages.SendHeartbeat -import org.apache.spark.rpc.{RpcEnv, RpcEndpointRef, RpcEndpoint} +import org.apache.spark.rpc.{RpcAddress, RpcEnv, RpcEndpointRef, RpcEndpoint} +import org.apache.spark.util.AkkaUtils /** * Actor which connects to a worker process and terminates the JVM if the connection is severed. @@ -45,20 +46,19 @@ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: Strin // Lets us filter events only from the worker's actor system private val expectedHostPort = new java.net.URI(workerUrl) - private def isWorker(address: String) = { - val uri = new java.net.URI(address) - expectedHostPort.getHost == uri.getHost && expectedHostPort.getPort == uri.getPort + private def isWorker(address: RpcAddress) = { + expectedHostPort.getHost == address.host && expectedHostPort.getPort == address.port } def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) override def receive(sender: RpcEndpointRef) = { case AssociatedEvent(localAddress, remoteAddress, inbound) - if isWorker(remoteAddress.toString) => + if isWorker(AkkaUtils.akkaAddressToRpcAddress(remoteAddress)) => logInfo(s"Successfully connected to $workerUrl") case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) - if isWorker(remoteAddress.toString) => + if isWorker(AkkaUtils.akkaAddressToRpcAddress(remoteAddress)) => // These logs may not be seen if the worker (and associated pipe) has died logError(s"Could not initialize connection to worker $workerUrl. Exiting.") logError(s"Error was: $cause") @@ -70,7 +70,7 @@ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: Strin case e => logWarning(s"Received unexpected actor system event: $e") } - override def remoteConnectionTerminated(remoteAddress: String): Unit = { + override def remoteConnectionTerminated(remoteAddress: RpcAddress): Unit = { if(isWorker(remoteAddress)) { // This log message will never be seen logError(s"Lost connection to worker actor $workerUrl. 
Exiting.") diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 30f4e2a07d10..bf0b5c21ed3a 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -24,7 +24,7 @@ import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.rpc.akka.AkkaRpcEnv -import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef} +import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEndpointRef} import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} @@ -88,7 +88,7 @@ private[spark] class CoarseGrainedExecutorBackend( env.actorSystem.shutdown() } - override def remoteConnectionTerminated(remoteAddress: String): Unit = { + override def remoteConnectionTerminated(remoteAddress: RpcAddress): Unit = { logError(s"Driver $remoteAddress disassociated! Shutting down.") System.exit(1) } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 412ddb0ce645..dcd526117f4a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -114,7 +114,7 @@ trait RpcEndpoint { throw e } - def remoteConnectionTerminated(remoteAddress: String): Unit = { + def remoteConnectionTerminated(remoteAddress: RpcAddress): Unit = { // By default, do nothing. } @@ -135,11 +135,7 @@ object RpcEndpoint { */ trait RpcEndpointRef { - def address: String - - def host: Option[String] - - def port: Option[Int] + def address: RpcAddress def ask[T: ClassTag](message: Any): Future[T] @@ -154,3 +150,10 @@ trait RpcEndpointRef { */ def send(message: Any)(implicit sender: RpcEndpointRef = RpcEndpoint.noSender): Unit } + +case class RpcAddress(host: String, port: Int) { + + val hostPort: String = host + ":" + port + + override val toString: String = hostPort +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index fd99b62f4750..07726232511e 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -59,7 +59,9 @@ class AkkaRpcEnv(val actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { override def receive: Receive = { case DisassociatedEvent(_, remoteAddress, _) => try { - endpoint.remoteConnectionTerminated(remoteAddress.toString) + // TODO How to handle that a remoteAddress doesn't have host & port + endpoint.remoteConnectionTerminated( + AkkaUtils.akkaAddressToRpcAddress(remoteAddress)) } catch { case NonFatal(e) => endpoint.onError(e) } @@ -117,11 +119,7 @@ private[akka] class AkkaRpcEndpointRef(val actorRef: ActorRef, @transient conf: private[this] val retryWaitMs = conf.getInt("spark.akka.retry.wait", 3000) private[this] val defaultTimeout = conf.getLong("spark.akka.lookupTimeout", 30) seconds - override val address: String = actorRef.path.address.toString - - override val host: Option[String] = actorRef.path.address.host - - override val port: Option[Int] = actorRef.path.address.port + override val address: RpcAddress = 
AkkaUtils.akkaAddressToRpcAddress(actorRef.path.address) override def ask[T: ClassTag](message: Any): Future[T] = ask(message, defaultTimeout) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 35f2a47b6802..3e05afc5aa16 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -23,7 +23,7 @@ import java.util.concurrent.{TimeUnit, Executors} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} -import org.apache.spark.rpc.{RpcEndpoint, RpcEnv, RpcEndpointRef} +import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEnv, RpcEndpointRef} import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{SerializableBuffer, AkkaUtils, Utils} @@ -69,7 +69,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp extends RpcEndpoint with Logging { override protected def log = CoarseGrainedSchedulerBackend.this.log - private val addressToExecutorId = new HashMap[String, String] + private val addressToExecutorId = new HashMap[RpcAddress, String] private val reviveScheduler = Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("driver-revive-scheduler")) @@ -195,7 +195,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } } - override def remoteConnectionTerminated(remoteAddress: String): Unit = { + override def remoteConnectionTerminated(remoteAddress: RpcAddress): Unit = { addressToExecutorId.get(remoteAddress).foreach(removeExecutor(_, "remote Akka client disassociated")) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index 7045c577443c..8e6a3fe3360f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster -import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} /** * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. 
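
Keying addressToExecutorId by RpcAddress rather than by a raw address string means lookups rely on case-class structural equality instead of whatever formatting Akka happened to use for the string. A minimal sketch of that property, assuming only the RpcAddress case class introduced above (the executor id below is illustrative):

import scala.collection.mutable.HashMap
import org.apache.spark.rpc.RpcAddress

object RpcAddressKeySketch {
  def main(args: Array[String]): Unit = {
    val addressToExecutorId = new HashMap[RpcAddress, String]
    addressToExecutorId(RpcAddress("10.0.0.1", 7077)) = "executor-1"

    // An address reconstructed later (for example from a disconnection event)
    // hits the same entry because case classes compare by value.
    val fromEvent = RpcAddress("10.0.0.1", 7077)
    assert(addressToExecutorId.get(fromEvent) == Some("executor-1"))
    println(fromEvent.hostPort) // prints 10.0.0.1:7077
  }
}
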
@@ -30,7 +30,7 @@ import org.apache.spark.rpc.RpcEndpointRef */ private[cluster] class ExecutorData( val executorActor: RpcEndpointRef, - val executorAddress: String, + val executorAddress: RpcAddress, val executorHost: String , var freeCores: Int, val totalCores: Int diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index f342d9c466e2..7cbf917bd2bb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -72,7 +72,7 @@ private[spark] class SparkDeploySchedulerBackend( val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, appUIAddress, sc.eventLogDir) - client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) + client = new AppClient(sc.env.rpcEnv, masters, appDesc, this, conf) client.start() waitForRegistration() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 133fc3dfa8b7..ea3d1f45cd33 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -418,11 +418,9 @@ class BlockManagerMasterActor(override val rpcEnv: RpcEnv, val isLocal: Boolean, private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { for ( blockManagerId <- blockManagerIdByExecutor.get(executorId); - info <- blockManagerInfo.get(blockManagerId); - host <- info.slaveActor.host; - port <- info.slaveActor.port + info <- blockManagerInfo.get(blockManagerId) ) yield { - (host, port) + (info.slaveActor.address.host, info.slaveActor.address.port) } } } diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 8c2457f56bff..580727e9bf17 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -17,11 +17,13 @@ package org.apache.spark.util +import org.apache.spark.rpc.RpcAddress + import scala.collection.JavaConversions.mapAsJavaMap import scala.concurrent.Await import scala.concurrent.duration.{Duration, FiniteDuration} -import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem} +import akka.actor.{Address, ActorRef, ActorSystem, ExtendedActorSystem} import akka.pattern.ask import com.typesafe.config.ConfigFactory @@ -233,4 +235,8 @@ private[spark] object AkkaUtils extends Logging { logInfo(s"Connecting to $name: $url") Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) } + + def akkaAddressToRpcAddress(akkaAddress: Address): RpcAddress = { + RpcAddress(akkaAddress.host.getOrElse("localhost"), akkaAddress.port.getOrElse(-1)) + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index a3dcf9b14033..4a7530426d88 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -20,18 +20,19 @@ package org.apache.spark.deploy.worker import akka.actor.{ActorSystem, AddressFromURIString} import org.apache.spark.SparkConf import 
org.apache.spark.rpc.akka.AkkaRpcEnv +import org.apache.spark.util.AkkaUtils import org.scalatest.FunSuite class WorkerWatcherSuite extends FunSuite { test("WorkerWatcher shuts down on valid disassociation") { val actorSystem = ActorSystem("test") val rpcEnv = new AkkaRpcEnv(actorSystem, new SparkConf()) - val targetWorkerUrl = "akka://1.2.3.4/user/Worker" + val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" val targetWorkerAddress = AddressFromURIString(targetWorkerUrl) val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) - workerWatcher.remoteConnectionTerminated(targetWorkerAddress.toString) + workerWatcher.remoteConnectionTerminated(AkkaUtils.akkaAddressToRpcAddress(targetWorkerAddress)) assert(workerWatcher.isShutDown) rpcEnv.stopAll() actorSystem.shutdown() @@ -40,13 +41,13 @@ class WorkerWatcherSuite extends FunSuite { test("WorkerWatcher stays alive on invalid disassociation") { val actorSystem = ActorSystem("test") val rpcEnv = new AkkaRpcEnv(actorSystem, new SparkConf()) - val targetWorkerUrl = "akka://1.2.3.4/user/Worker" - val otherAkkaURL = "akka://4.3.2.1/user/OtherActor" + val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" + val otherAkkaURL = "akka://test@4.3.2.1:1234/user/OtherActor" val otherAkkaAddress = AddressFromURIString(otherAkkaURL) val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) - workerWatcher.remoteConnectionTerminated(otherAkkaAddress.toString) + workerWatcher.remoteConnectionTerminated(AkkaUtils.akkaAddressToRpcAddress(otherAkkaAddress)) assert(!workerWatcher.isShutDown) rpcEnv.stopAll() actorSystem.shutdown() From 6dff656eeaa9c9a53918a80f88eee57914ca6738 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 5 Jan 2015 19:25:02 +0800 Subject: [PATCH 23/36] Change Worker to use RpcEndpoint --- .../org/apache/spark/deploy/Client.scala | 43 ++++--- .../spark/deploy/master/ui/MasterPage.scala | 3 - .../spark/deploy/worker/DriverRunner.scala | 6 +- .../spark/deploy/worker/ExecutorRunner.scala | 8 +- .../apache/spark/deploy/worker/Worker.scala | 107 +++++++++++------- .../spark/deploy/worker/ui/WorkerPage.scala | 10 +- .../spark/deploy/worker/ui/WorkerWebUI.scala | 2 - 7 files changed, 94 insertions(+), 85 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index f2687ce6b42b..0b4262144084 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -17,31 +17,26 @@ package org.apache.spark.deploy -import scala.concurrent._ - -import akka.actor._ -import akka.pattern.ask -import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} +import akka.remote.AssociationErrorEvent import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} +import org.apache.spark.rpc.akka.AkkaRpcEnv +import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEnv, RpcEndpointRef} +import org.apache.spark.util.{AkkaUtils, Utils} /** * Proxy that relays messages to the driver. 
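
The suite now spells out fully qualified Akka URLs because akkaAddressToRpcAddress needs a host and port to build a meaningful RpcAddress; the bare form "akka://1.2.3.4/user/Worker" parses with no host or port at all. A small sketch of that conversion, assuming the AkkaUtils helper added in this series (AkkaUtils is private[spark], so the snippet is assumed to live under the org.apache.spark package):

package org.apache.spark.deploy

import akka.actor.AddressFromURIString
import org.apache.spark.rpc.RpcAddress
import org.apache.spark.util.AkkaUtils

object AkkaUrlToRpcAddressSketch {
  def main(args: Array[String]): Unit = {
    val url = "akka://test@1.2.3.4:1234/user/Worker"
    val akkaAddress = AddressFromURIString(url)
    val rpcAddress: RpcAddress = AkkaUtils.akkaAddressToRpcAddress(akkaAddress)
    println(rpcAddress) // prints 1.2.3.4:1234
  }
}
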
*/ -private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) - extends Actor with ActorLogReceive with Logging { +private class ClientActor(override val rpcEnv: RpcEnv, driverArgs: ClientArguments, conf: SparkConf) + extends RpcEndpoint with Logging { - var masterActor: ActorSelection = _ - val timeout = AkkaUtils.askTimeout(conf) + var masterActor: RpcEndpointRef = _ override def preStart() = { - masterActor = context.actorSelection(Master.toAkkaUrl(driverArgs.master)) - - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + masterActor = rpcEnv.setupEndpointRefByUrl(Master.toAkkaUrl(driverArgs.master)) println(s"Sending ${driverArgs.cmd} command to ${driverArgs.master}") @@ -77,11 +72,11 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) driverArgs.supervise, command) - masterActor ! RequestSubmitDriver(driverDescription) + masterActor.send(RequestSubmitDriver(driverDescription)) case "kill" => val driverId = driverArgs.driverId - masterActor ! RequestKillDriver(driverId) + masterActor.send(RequestKillDriver(driverId)) } } @@ -90,9 +85,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) println(s"... waiting before polling master for driver state") Thread.sleep(5000) println("... polling master for driver state") - val statusFuture = (masterActor ? RequestDriverStatus(driverId))(timeout) - .mapTo[DriverStatusResponse] - val statusResponse = Await.result(statusFuture, timeout) + val statusResponse = masterActor.askWithReply[DriverStatusResponse](RequestDriverStatus(driverId)) statusResponse.found match { case false => @@ -116,7 +109,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) } } - override def receiveWithLogging = { + override def receive(sender: RpcEndpointRef) = { case SubmitDriverResponse(success, driverId, message) => println(message) @@ -126,15 +119,16 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) println(message) if (success) pollAndReportStatus(driverId) else System.exit(-1) - case DisassociatedEvent(_, remoteAddress, _) => - println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") - System.exit(-1) - case AssociationErrorEvent(cause, _, remoteAddress, _, _) => println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") println(s"Cause was: $cause") System.exit(-1) } + + override def remoteConnectionTerminated(remoteAddress: RpcAddress): Unit = { + println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") + System.exit(-1) + } } /** @@ -160,7 +154,8 @@ object Client { val (actorSystem, _) = AkkaUtils.createActorSystem( "driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) - actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf)) + val rpcEnv = new AkkaRpcEnv(actorSystem, conf) + rpcEnv.setupEndpoint("client-actor", new ClientActor(rpcEnv, driverArgs, conf)) actorSystem.awaitTermination() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 740fe3460cfb..7d9e5f7d78c7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -19,10 +19,8 @@ package org.apache.spark.deploy.master.ui import javax.servlet.http.HttpServletRequest -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask 
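
pollAndReportStatus collapses the old ask-then-Await pair into a single blocking askWithReply call. A minimal sketch of that call shape, assuming the RpcEndpointRef API as it stands at this point in the series (the message and response types below are illustrative stand-ins, not the real DeployMessages):

import org.apache.spark.rpc.RpcEndpointRef

case class StatusRequest(driverId: String)   // illustrative stand-in
case class StatusResponse(found: Boolean)    // illustrative stand-in

object AskWithReplySketch {
  def poll(master: RpcEndpointRef, driverId: String): Boolean = {
    // askWithReply[T] sends the message, blocks for the reply and returns it as T,
    // replacing `(master ? msg)(timeout).mapTo[T]` followed by Await.result.
    val response = master.askWithReply[StatusResponse](StatusRequest(driverId))
    response.found
  }
}
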
import org.json4s.JValue import org.apache.spark.deploy.JsonProtocol @@ -33,7 +31,6 @@ import org.apache.spark.util.Utils private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { private def master = parent.masterActorRef - private val timeout = parent.timeout override def renderJson(request: HttpServletRequest): JValue = { val state = master.askWithReply[MasterStateResponse](RequestMasterState) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 28cab36c7b9e..54326bd3721c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -19,6 +19,8 @@ package org.apache.spark.deploy.worker import java.io._ +import org.apache.spark.rpc.RpcEndpointRef + import scala.collection.JavaConversions._ import scala.collection.Map @@ -44,7 +46,7 @@ private[spark] class DriverRunner( val workDir: File, val sparkHome: File, val driverDesc: DriverDescription, - val worker: ActorRef, + val worker: RpcEndpointRef, val workerUrl: String) extends Logging { @@ -98,7 +100,7 @@ private[spark] class DriverRunner( finalState = Some(state) - worker ! DriverStateChanged(driverId, state, finalException) + worker.send(DriverStateChanged(driverId, state, finalException)) } }.start() } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index acbdf0d8bd7b..7e77c9028e9c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -21,13 +21,13 @@ import java.io._ import scala.collection.JavaConversions._ -import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.spark.{SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.logging.FileAppender /** @@ -40,7 +40,7 @@ private[spark] class ExecutorRunner( val appDesc: ApplicationDescription, val cores: Int, val memory: Int, - val worker: ActorRef, + val worker: RpcEndpointRef, val workerId: String, val host: String, val sparkHome: File, @@ -94,7 +94,7 @@ private[spark] class ExecutorRunner( } exitCode = Some(process.waitFor()) } - worker ! ExecutorStateChanged(appId, execId, state, message, exitCode) + worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode)) } /** Stop this executor runner, including killing the process it launched */ @@ -151,7 +151,7 @@ private[spark] class ExecutorRunner( val exitCode = process.waitFor() state = ExecutorState.EXITED val message = "Command exited with code " + exitCode - worker ! 
ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode)) + worker.send(ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode))) } catch { case interrupted: InterruptedException => { logInfo("Runner thread for executor " + fullId + " interrupted") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index f0f3da5eec4d..01fb8aeb1fc9 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -20,16 +20,20 @@ package org.apache.spark.deploy.worker import java.io.File import java.io.IOException import java.text.SimpleDateFormat +import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors} import java.util.{UUID, Date} +import org.apache.spark.rpc.akka.AkkaRpcEnv +import org.apache.spark.rpc.{RpcAddress, RpcEnv, RpcEndpointRef, RpcEndpoint} + import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} +import scala.concurrent.ExecutionContext import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import akka.actor._ -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import akka.actor.ActorSystem import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} @@ -37,12 +41,13 @@ import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} /** * @param masterUrls Each url should look like spark://host:port. 
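
DriverRunner and ExecutorRunner now hold an RpcEndpointRef to the worker and report state changes with send, the fire-and-forget counterpart of `actorRef ! msg`. A stripped-down sketch of that pattern, assuming the RpcEndpointRef trait shown earlier (the message type here is illustrative):

import org.apache.spark.rpc.RpcEndpointRef

case class RunnerExited(fullId: String, exitCode: Int)  // illustrative stand-in

class RunnerSketch(worker: RpcEndpointRef) {
  def reportExit(fullId: String, exitCode: Int): Unit = {
    // send needs no reply and can be called from a plain runner thread;
    // the implicit sender defaults to RpcEndpoint.noSender outside an endpoint.
    worker.send(RunnerExited(fullId, exitCode))
  }
}
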
*/ private[spark] class Worker( + override val rpcEnv: RpcEnv, host: String, port: Int, webUiPort: Int, @@ -54,8 +59,13 @@ private[spark] class Worker( workDirPath: String = null, val conf: SparkConf, val securityMgr: SecurityManager) - extends Actor with ActorLogReceive with Logging { - import context.dispatcher + extends RpcEndpoint with Logging { + + val scheduler = Executors.newScheduledThreadPool(1, + Utils.namedThreadFactory("worker-scheduler")) + + implicit val cleanupExecutor = ExecutionContext.fromExecutor( + Executors.newFixedThreadPool(1, Utils.namedThreadFactory("cleanup-thread"))) Utils.checkHost(host, "Expected hostname") assert (port > 0) @@ -89,8 +99,8 @@ private[spark] class Worker( val APP_DATA_RETENTION_SECS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) val testing: Boolean = sys.props.contains("spark.testing") - var master: ActorSelection = null - var masterAddress: Address = null + var master: RpcEndpointRef = null + var masterAddress: RpcAddress = null var activeMasterUrl: String = "" var activeMasterWebUiUrl : String = "" val akkaUrl = "akka.tcp://%s@%s:%s/user/%s".format(actorSystemName, host, port, actorName) @@ -128,7 +138,7 @@ private[spark] class Worker( val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) val workerSource = new WorkerSource(this) - var registrationRetryTimer: Option[Cancellable] = None + var registrationRetryTimer: Option[ScheduledFuture[_]] = None def coresFree: Int = cores - coresUsed def memoryFree: Int = memory - memoryUsed @@ -158,7 +168,6 @@ private[spark] class Worker( logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") logInfo("Spark home: " + sparkHome) createWorkDir() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) shuffleService.startIfEnabled() webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() @@ -173,24 +182,25 @@ private[spark] class Worker( def changeMaster(url: String, uiUrl: String) { activeMasterUrl = url activeMasterWebUiUrl = uiUrl - master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl)) + master = rpcEnv.setupEndpointRefByUrl(Master.toAkkaUrl(activeMasterUrl)) masterAddress = activeMasterUrl match { case Master.sparkUrlRegex(_host, _port) => - Address("akka.tcp", Master.systemName, _host, _port.toInt) + RpcAddress(_host, _port.toInt) case x => throw new SparkException("Invalid spark URL: " + x) } connected = true // Cancel any outstanding re-registration attempts because we found a new master - registrationRetryTimer.foreach(_.cancel()) + registrationRetryTimer.foreach(_.cancel(true)) registrationRetryTimer = None } private def tryRegisterAllMasters() { for (masterUrl <- masterUrls) { logInfo("Connecting to master " + masterUrl + "...") - val actor = context.actorSelection(Master.toAkkaUrl(masterUrl)) - actor ! 
RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort, publicAddress) + val actor = rpcEnv.setupEndpointRefByUrl(Master.toAkkaUrl(masterUrl)) + actor.send( + RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort, publicAddress)) } } @@ -203,7 +213,7 @@ private[spark] class Worker( Utils.tryOrExit { connectionAttemptCount += 1 if (registered) { - registrationRetryTimer.foreach(_.cancel()) + registrationRetryTimer.foreach(_.cancel(true)) registrationRetryTimer = None } else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) { logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)") @@ -228,8 +238,8 @@ private[spark] class Worker( * less likely scenario. */ if (master != null) { - master ! RegisterWorker( - workerId, host, port, cores, memory, webUi.boundPort, publicAddress) + master.send(RegisterWorker( + workerId, host, port, cores, memory, webUi.boundPort, publicAddress)) } else { // We are retrying the initial registration tryRegisterAllMasters() @@ -237,10 +247,12 @@ private[spark] class Worker( // We have exceeded the initial registration retry threshold // All retries from now on should use a higher interval if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) { - registrationRetryTimer.foreach(_.cancel()) + registrationRetryTimer.foreach(_.cancel(true)) registrationRetryTimer = Some { - context.system.scheduler.schedule(PROLONGED_REGISTRATION_RETRY_INTERVAL, - PROLONGED_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) + scheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = self.send(ReregisterWithMaster) + }, PROLONGED_REGISTRATION_RETRY_INTERVAL.toMillis, + PROLONGED_REGISTRATION_RETRY_INTERVAL.toMillis, TimeUnit.MILLISECONDS) } } } else { @@ -259,8 +271,10 @@ private[spark] class Worker( tryRegisterAllMasters() connectionAttemptCount = 0 registrationRetryTimer = Some { - context.system.scheduler.schedule(INITIAL_REGISTRATION_RETRY_INTERVAL, - INITIAL_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) + scheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = self.send(ReregisterWithMaster) + }, INITIAL_REGISTRATION_RETRY_INTERVAL.toMillis, + INITIAL_REGISTRATION_RETRY_INTERVAL.toMillis, TimeUnit.MILLISECONDS) } case Some(_) => logInfo("Not spawning another attempt to register with the master, since there is an" + @@ -268,20 +282,23 @@ private[spark] class Worker( } } - override def receiveWithLogging = { + override def receive(sender: RpcEndpointRef) = { case RegisteredWorker(masterUrl, masterWebUiUrl) => logInfo("Successfully registered with master " + masterUrl) registered = true changeMaster(masterUrl, masterWebUiUrl) - context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat) + scheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = self.send(SendHeartbeat) + }, 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS) if (CLEANUP_ENABLED) { logInfo(s"Worker cleanup enabled; old application directories will be deleted in: $workDir") - context.system.scheduler.schedule(CLEANUP_INTERVAL_MILLIS millis, - CLEANUP_INTERVAL_MILLIS millis, self, WorkDirCleanup) + scheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = self.send(WorkDirCleanup) + }, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS) } case SendHeartbeat => - if (connected) { master ! 
Heartbeat(workerId) } + if (connected) { master.send(Heartbeat(workerId)) } case WorkDirCleanup => // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor @@ -314,10 +331,10 @@ private[spark] class Worker( val execs = executors.values. map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state)) - sender ! WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq) + sender.send(WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq)) case Heartbeat => - logInfo(s"Received heartbeat from driver ${sender.path}") + logInfo(s"Received heartbeat from driver ${sender.address}") case RegisterWorkerFailed(message) => if (!registered) { @@ -359,7 +376,7 @@ private[spark] class Worker( manager.start() coresUsed += cores_ memoryUsed += memory_ - master ! ExecutorStateChanged(appId, execId, manager.state, None, None) + master.send(ExecutorStateChanged(appId, execId, manager.state, None, None)) } catch { case e: Exception => { logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e) @@ -367,14 +384,14 @@ private[spark] class Worker( executors(appId + "/" + execId).kill() executors -= appId + "/" + execId } - master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, - Some(e.toString), None) + master.send(ExecutorStateChanged(appId, execId, ExecutorState.FAILED, + Some(e.toString), None)) } } } case ExecutorStateChanged(appId, execId, state, message, exitStatus) => - master ! ExecutorStateChanged(appId, execId, state, message, exitStatus) + master.send(ExecutorStateChanged(appId, execId, state, message, exitStatus)) val fullId = appId + "/" + execId if (ExecutorState.isFinished(state)) { executors.get(fullId) match { @@ -441,22 +458,18 @@ private[spark] class Worker( case _ => logDebug(s"Driver $driverId changed state to $state") } - master ! DriverStateChanged(driverId, state, exception) + master.send(DriverStateChanged(driverId, state, exception)) val driver = drivers.remove(driverId).get finishedDrivers(driverId) = driver memoryUsed -= driver.driverDesc.mem coresUsed -= driver.driverDesc.cores } - case x: DisassociatedEvent if x.remoteAddress == masterAddress => - logInfo(s"$x Disassociated !") - masterDisconnected() - case RequestWorkerState => - sender ! WorkerStateResponse(host, port, workerId, executors.values.toList, + sender.send(WorkerStateResponse(host, port, workerId, executors.values.toList, finishedExecutors.values.toList, drivers.values.toList, finishedDrivers.values.toList, activeMasterUrl, cores, memory, - coresUsed, memoryUsed, activeMasterWebUiUrl) + coresUsed, memoryUsed, activeMasterWebUiUrl)) case ReregisterWithMaster => reregisterWithMaster() @@ -466,6 +479,13 @@ private[spark] class Worker( maybeCleanupApplication(id) } + override def remoteConnectionTerminated(remoteAddress: RpcAddress): Unit = { + if (remoteAddress == masterAddress ) { + logInfo(s"$remoteAddress Disassociated !") + masterDisconnected() + } + } + private def masterDisconnected() { logError("Connection to master failed! 
Waiting for master to reconnect...") connected = false @@ -491,7 +511,7 @@ private[spark] class Worker( override def postStop() { metricsSystem.report() - registrationRetryTimer.foreach(_.cancel()) + registrationRetryTimer.foreach(_.cancel(true)) executors.values.foreach(_.kill()) drivers.values.foreach(_.kill()) shuffleService.stop() @@ -527,8 +547,9 @@ private[spark] object Worker extends Logging { val securityMgr = new SecurityManager(conf) val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf, securityManager = securityMgr) - actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, - masterUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName) + val rpcEnv = new AkkaRpcEnv(actorSystem, conf) + rpcEnv.setupEndpoint(actorName, new Worker(rpcEnv, host, boundPort, webUiPort, cores, memory, + masterUrls, systemName, actorName, workDir, conf, securityMgr)) (actorSystem, boundPort) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 327b90503280..c3f9b91e7056 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -17,10 +17,8 @@ package org.apache.spark.deploy.worker.ui -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask import javax.servlet.http.HttpServletRequest import org.json4s.JValue @@ -34,17 +32,15 @@ import org.apache.spark.util.Utils private[spark] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { val workerActor = parent.worker.self val worker = parent.worker - val timeout = parent.timeout override def renderJson(request: HttpServletRequest): JValue = { - val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] - val workerState = Await.result(stateFuture, timeout) + val workerState = workerActor.askWithReply[WorkerStateResponse](RequestWorkerState) JsonProtocol.writeWorkerState(workerState) } def render(request: HttpServletRequest): Seq[Node] = { - val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] - val workerState = Await.result(stateFuture, timeout) + val workerState = workerActor.askWithReply[WorkerStateResponse](RequestWorkerState) + JsonProtocol.writeWorkerState(workerState) val executorHeaders = Seq("ExecutorID", "Cores", "State", "Memory", "Job Details", "Logs") val runningExecutors = workerState.executors diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 7ac81a2d87ef..9c2f51e27a56 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -38,8 +38,6 @@ class WorkerWebUI( extends WebUI(worker.securityMgr, requestedPort, worker.conf, name = "WorkerUI") with Logging { - val timeout = AkkaUtils.askTimeout(worker.conf) - initialize() /** Initialize all components of the server. 
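
With no actor context available, the Worker replaces Akka's scheduler with a plain ScheduledExecutorService whose tasks simply post messages back to the endpoint through self.send. A minimal sketch of that translation, assuming the RpcEndpointRef API used above (the heartbeat message here is an illustrative string, not DeployMessages.SendHeartbeat):

import java.util.concurrent.{Executors, ScheduledFuture, TimeUnit}
import org.apache.spark.rpc.RpcEndpointRef

object SelfSchedulingSketch {
  // Rough equivalent of `context.system.scheduler.schedule(0 millis, interval millis, self, SendHeartbeat)`.
  // The returned future plays the role of the old Cancellable: cancel(true) stops the timer.
  def scheduleHeartbeats(self: RpcEndpointRef, intervalMs: Long): ScheduledFuture[_] = {
    val scheduler = Executors.newScheduledThreadPool(1)
    scheduler.scheduleAtFixedRate(new Runnable {
      override def run(): Unit = self.send("SendHeartbeat")
    }, 0, intervalMs, TimeUnit.MILLISECONDS)
  }
}
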
*/ From 3e90325ad5475c6642e53ea5425db73ce1d87db8 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 5 Jan 2015 19:59:07 +0800 Subject: [PATCH 24/36] Some cleanup --- .../scala/org/apache/spark/SparkContext.scala | 1 - .../main/scala/org/apache/spark/SparkEnv.scala | 17 ++++------------- .../spark/deploy/worker/DriverRunner.scala | 1 - .../spark/deploy/worker/DriverWrapper.scala | 6 ++++-- .../org/apache/spark/storage/BlockManager.scala | 1 - 5 files changed, 8 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3d65126c44d3..8f96ecebb464 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -36,7 +36,6 @@ import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, Sequence import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob} import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.mesos.MesosNativeLibrary -import akka.actor.Props import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 3896a3d5d14b..222759377cd7 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable import scala.util.Properties -import akka.actor._ +import akka.actor.ActorSystem import com.google.common.collect.MapMaker import org.apache.spark.annotation.DeveloperApi @@ -263,16 +263,7 @@ object SparkEnv extends Logging { val closureSerializer = instantiateClassFromConf[Serializer]( "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer") - def registerOrLookup(name: String, newActor: => Actor): ActorRef = { - if (isDriver) { - logInfo("Registering " + name) - actorSystem.actorOf(Props(newActor), name = name) - } else { - AkkaUtils.makeDriverRef(name, conf, actorSystem) - } - } - - def registerOrLookupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = { + def registerOrLookup(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = { if (isDriver) { logInfo("Registering " + name) rpcEnv.setupEndpoint(name, endpointCreator) @@ -289,7 +280,7 @@ object SparkEnv extends Logging { // Have to assign trackerActor after initialization as MapOutputTrackerActor // requires the MapOutputTracker itself - mapOutputTracker.trackerActor = registerOrLookupEndpoint("MapOutputTracker", + mapOutputTracker.trackerActor = registerOrLookup("MapOutputTracker", new MapOutputTrackerMasterActor( rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) @@ -311,7 +302,7 @@ object SparkEnv extends Logging { new NioBlockTransferService(conf, securityManager) } - val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint( + val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", new BlockManagerMasterActor(rpcEnv, isLocal, conf, listenerBus)), conf, isDriver) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 54326bd3721c..9b063e3c41f9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ 
b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -24,7 +24,6 @@ import org.apache.spark.rpc.RpcEndpointRef import scala.collection.JavaConversions._ import scala.collection.Map -import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.hadoop.conf.Configuration diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index 05e242e6df70..c44dd05d31bb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.worker -import akka.actor._ +import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.util.{AkkaUtils, Utils} @@ -32,13 +32,15 @@ object DriverWrapper { val conf = new SparkConf() val (actorSystem, _) = AkkaUtils.createActorSystem("Driver", Utils.localHostName(), 0, conf, new SecurityManager(conf)) - actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher") + val rpcEnv = new AkkaRpcEnv(actorSystem, conf) + rpcEnv.setupEndpoint("workerWatcher", new WorkerWatcher(rpcEnv, workerUrl)) // Delegate to supplied main class val clazz = Class.forName(args(1)) val mainMethod = clazz.getMethod("main", classOf[Array[String]]) mainMethod.invoke(null, extraArgs.toArray[String]) + rpcEnv.stopAll() actorSystem.shutdown() case _ => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 0285ec18be6f..a4705967ec81 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -26,7 +26,6 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.util.Random -import akka.actor.{ActorSystem, Props} import sun.nio.ch.DirectBuffer import org.apache.spark._ From 9a348cf28da1d071d978ff0ff1270af322b9cc91 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 6 Jan 2015 11:10:20 +0800 Subject: [PATCH 25/36] Change YarnSchedulerBackend and LocalBackend to use RpcEndpoint --- .../master/ZooKeeperLeaderElectionAgent.scala | 2 - .../deploy/master/ui/ApplicationPage.scala | 2 +- .../spark/deploy/master/ui/MasterPage.scala | 2 +- .../spark/deploy/master/ui/MasterWebUI.scala | 2 +- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 6 +-- .../scheduler/cluster/ExecutorData.scala | 2 +- .../cluster/YarnSchedulerBackend.scala | 49 +++++++------------ .../spark/scheduler/local/LocalBackend.scala | 24 +++++---- .../apache/spark/util/ActorLogReceive.scala | 2 +- 9 files changed, 38 insertions(+), 53 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index 8eaa0ad94851..017844dd57de 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -17,8 +17,6 @@ package org.apache.spark.deploy.master -import akka.actor.ActorRef - import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.master.MasterMessages._ import org.apache.curator.framework.CuratorFramework diff --git 
a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 96827229eee3..c83a9600896c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.Utils private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") { - private def master = parent.masterActorRef + private def master = parent.masterEndpointRef /** Executor details for a particular application */ override def renderJson(request: HttpServletRequest): JValue = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 7d9e5f7d78c7..67bad341532e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -30,7 +30,7 @@ import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { - private def master = parent.masterActorRef + private def master = parent.masterEndpointRef override def renderJson(request: HttpServletRequest): JValue = { val state = master.askWithReply[MasterStateResponse](RequestMasterState) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 3727bc7b9e83..2c9d2c6b1e52 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -30,7 +30,7 @@ private[spark] class MasterWebUI(val master: Master, requestedPort: Int) extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging { - def masterActorRef = master.self + def masterEndpointRef = master.self val timeout = AkkaUtils.askTimeout(master.conf) initialize() diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 07726232511e..6f3e74c4156a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -30,7 +30,7 @@ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import org.apache.spark.{Logging, SparkException, SparkConf} import org.apache.spark.rpc._ -import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.{ActorLogReceive, AkkaUtils} import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -42,7 +42,7 @@ class AkkaRpcEnv(val actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { val latch = new CountDownLatch(1) try { @volatile var endpointRef: AkkaRpcEndpointRef = null - val actorRef = actorSystem.actorOf(Props(new Actor with Logging { + val actorRef = actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { val endpoint = endpointCreator latch.await() @@ -56,7 +56,7 @@ class AkkaRpcEnv(val actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) } - override def receive: Receive = { + override def receiveWithLogging: Receive = { case DisassociatedEvent(_, remoteAddress, _) => try { // TODO How to handle that a remoteAddress 
doesn't have host & port diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index 8e6a3fe3360f..d9ef4b85c5d8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -22,7 +22,7 @@ import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} /** * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. * - * @param executorActor The ActorRef representing this executor + * @param executorActor The RpcEndpointRef representing this executor * @param executorAddress The network address of this executor * @param executorHost The hostname that this executor is running on * @param freeCores The current number of cores available for work on the executor diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 23f232211746..1dbc49ac98e8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -17,10 +17,8 @@ package org.apache.spark.scheduler.cluster -import akka.actor.{Actor, ActorRef, Props} -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} - import org.apache.spark.SparkContext +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, RpcEndpoint} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.ui.JettyUtils @@ -35,18 +33,14 @@ private[spark] abstract class YarnSchedulerBackend( sc: SparkContext) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) { - val actorSystem = sc.env.actorSystem - if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { minRegisteredRatio = 0.8 } protected var totalExpectedExecutors = 0 - private val yarnSchedulerActor: ActorRef = - actorSystem.actorOf( - Props(new YarnSchedulerActor), - name = YarnSchedulerBackend.ACTOR_NAME) + private val yarnSchedulerActor: RpcEndpointRef = + rpcEnv.setupEndpoint(YarnSchedulerBackend.ACTOR_NAME, new YarnSchedulerActor(rpcEnv)) private implicit val askTimeout = AkkaUtils.askTimeout(sc.conf) @@ -55,16 +49,14 @@ private[spark] abstract class YarnSchedulerBackend( * This includes executors already pending or running. */ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - AkkaUtils.askWithReply[Boolean]( - RequestExecutors(requestedTotal), yarnSchedulerActor, askTimeout) + yarnSchedulerActor.askWithReply(RequestExecutors(requestedTotal)) } /** * Request that the ApplicationMaster kill the specified executors. */ override def doKillExecutors(executorIds: Seq[String]): Boolean = { - AkkaUtils.askWithReply[Boolean]( - KillExecutors(executorIds), yarnSchedulerActor, askTimeout) + yarnSchedulerActor.askWithReply(KillExecutors(executorIds)) } override def sufficientResourcesRegistered(): Boolean = { @@ -96,15 +88,10 @@ private[spark] abstract class YarnSchedulerBackend( /** * An actor that communicates with the ApplicationMaster. 
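
Scheduler backends now create their helper endpoints through rpcEnv.setupEndpoint instead of actorSystem.actorOf. A minimal endpoint definition and registration, assuming the RpcEndpoint and RpcEnv traits as they stand at this point in the series (all names below are illustrative):

import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}

case object Ping  // illustrative message

class PingEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
  // Endpoints at this stage implement `receive(sender) = { case ... }`,
  // mirroring Akka's partial-function style but with an explicit sender ref.
  override def receive(sender: RpcEndpointRef) = {
    case Ping => sender.send("pong")
  }
}

object PingRegistration {
  // setupEndpoint registers the endpoint under a name and returns the ref used to reach it,
  // replacing `actorSystem.actorOf(Props(...), name)`.
  def register(rpcEnv: RpcEnv): RpcEndpointRef =
    rpcEnv.setupEndpoint("ping-endpoint", new PingEndpoint(rpcEnv))
}
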
*/ - private class YarnSchedulerActor extends Actor { - private var amActor: Option[ActorRef] = None - - override def preStart(): Unit = { - // Listen for disassociation events - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - } + private class YarnSchedulerActor(override val rpcEnv: RpcEnv) extends RpcEndpoint { + private var amActor: Option[RpcEndpointRef] = None - override def receive = { + override def receive(sender: RpcEndpointRef) = { case RegisterClusterManager => logInfo(s"ApplicationMaster registered as $sender") amActor = Some(sender) @@ -112,29 +99,31 @@ private[spark] abstract class YarnSchedulerBackend( case r: RequestExecutors => amActor match { case Some(actor) => - sender ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout) + actor.askWithReply(r) + sender.send(actor.askWithReply[Boolean](r)) case None => logWarning("Attempted to request executors before the AM has registered!") - sender ! false + sender.send(false) } case k: KillExecutors => amActor match { case Some(actor) => - sender ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout) + sender.send(actor.askWithReply[Boolean](k)) case None => logWarning("Attempted to kill executors before the AM has registered!") - sender ! false + sender.send(false) } case AddWebUIFilter(filterName, filterParams, proxyBase) => addWebUIFilter(filterName, filterParams, proxyBase) - sender ! true + sender.send(true) + } - case d: DisassociatedEvent => - if (amActor.isDefined && sender == amActor.get) { - logWarning(s"ApplicationMaster has disassociated: $d") - } + override def remoteConnectionTerminated(remoteAddress: RpcAddress): Unit = { + if (amActor.isDefined && remoteAddress == amActor.get.address) { + logWarning(s"ApplicationMaster has disassociated: $remoteAddress") + } } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index b3bd3110ac80..a058542529a9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -19,13 +19,11 @@ package org.apache.spark.scheduler.local import java.nio.ByteBuffer -import akka.actor.{Actor, ActorRef, Props} - import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} +import org.apache.spark.rpc.{RpcEnv, RpcEndpointRef, RpcEndpoint} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.util.ActorLogReceive private case class ReviveOffers() @@ -41,10 +39,11 @@ private case class StopExecutor() * and the TaskSchedulerImpl. 
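
Because the AM ref answers asks synchronously through askWithReply, YarnSchedulerActor can relay the Boolean result straight back to whoever asked it. A sketch of that relay, assuming the same RpcEndpoint API (the request type is an illustrative stand-in for the real cluster messages):

import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}

case class ResourceRequest(total: Int)  // illustrative stand-in

class RelaySketch(override val rpcEnv: RpcEnv, am: RpcEndpointRef) extends RpcEndpoint {
  override def receive(sender: RpcEndpointRef) = {
    case r: ResourceRequest =>
      // A single ask against the remote endpoint; its Boolean reply is passed on to the caller.
      sender.send(am.askWithReply[Boolean](r))
  }
}
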
*/ private[spark] class LocalActor( + override val rpcEnv: RpcEnv, scheduler: TaskSchedulerImpl, executorBackend: LocalBackend, private val totalCores: Int) - extends Actor with ActorLogReceive with Logging { + extends RpcEndpoint with Logging { private var freeCores = totalCores @@ -54,7 +53,7 @@ private[spark] class LocalActor( private val executor = new Executor( localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true) - override def receiveWithLogging = { + override def receive(sender: RpcEndpointRef) = { case ReviveOffers => reviveOffers() @@ -90,31 +89,30 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: extends SchedulerBackend with ExecutorBackend { private val appId = "local-" + System.currentTimeMillis - var localActor: ActorRef = null + var localActor: RpcEndpointRef = null override def start() { - localActor = SparkEnv.get.actorSystem.actorOf( - Props(new LocalActor(scheduler, this, totalCores)), - "LocalBackendActor") + localActor = SparkEnv.get.rpcEnv.setupEndpoint("LocalBackendActor", + new LocalActor(SparkEnv.get.rpcEnv, scheduler, this, totalCores)) } override def stop() { - localActor ! StopExecutor + localActor.send(StopExecutor) } override def reviveOffers() { - localActor ! ReviveOffers + localActor.send(ReviveOffers) } override def defaultParallelism() = scheduler.conf.getInt("spark.default.parallelism", totalCores) override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - localActor ! KillTask(taskId, interruptThread) + localActor.send(KillTask(taskId, interruptThread)) } override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { - localActor ! StatusUpdate(taskId, state, serializedData) + localActor.send(StatusUpdate(taskId, state, serializedData)) } override def applicationId(): String = appId diff --git a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala index 332d0cbb2dc0..142a32eb1dce 100644 --- a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala +++ b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala @@ -43,7 +43,7 @@ private[spark] trait ActorLogReceive { private val _receiveWithLogging = receiveWithLogging - override def isDefinedAt(o: Any): Boolean = _receiveWithLogging.isDefinedAt(o) + final override def isDefinedAt(o: Any): Boolean = _receiveWithLogging.isDefinedAt(o) override def apply(o: Any): Unit = { if (log.isDebugEnabled) { From e08d762c375fc4f5127fd01f03e09b0ec0ed0f8e Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 6 Jan 2015 12:28:41 +0800 Subject: [PATCH 26/36] Tune the interface for network --- .../org/apache/spark/deploy/Client.scala | 16 ++++---- .../spark/deploy/client/AppClient.scala | 6 +-- .../apache/spark/deploy/master/Master.scala | 6 +-- .../apache/spark/deploy/worker/Worker.scala | 6 +-- .../spark/deploy/worker/WorkerWatcher.scala | 38 +++++++++---------- .../CoarseGrainedExecutorBackend.scala | 4 +- .../scala/org/apache/spark/rpc/RpcEnv.scala | 22 +++++++---- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 28 +++++++++++--- .../apache/spark/scheduler/DAGScheduler.scala | 2 +- .../CoarseGrainedSchedulerBackend.scala | 4 +- .../cluster/YarnSchedulerBackend.scala | 2 +- .../storage/BlockManagerMasterActor.scala | 4 +- .../deploy/worker/WorkerWatcherSuite.scala | 4 +- 13 files changed, 83 insertions(+), 59 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala 
b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 0b4262144084..9fdd6a76c8c9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -17,7 +17,6 @@ package org.apache.spark.deploy -import akka.remote.AssociationErrorEvent import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf} @@ -35,7 +34,7 @@ private class ClientActor(override val rpcEnv: RpcEnv, driverArgs: ClientArgumen var masterActor: RpcEndpointRef = _ - override def preStart() = { + override def onStart() = { masterActor = rpcEnv.setupEndpointRefByUrl(Master.toAkkaUrl(driverArgs.master)) println(s"Sending ${driverArgs.cmd} command to ${driverArgs.master}") @@ -110,7 +109,6 @@ private class ClientActor(override val rpcEnv: RpcEnv, driverArgs: ClientArgumen } override def receive(sender: RpcEndpointRef) = { - case SubmitDriverResponse(success, driverId, message) => println(message) if (success) pollAndReportStatus(driverId.get) else System.exit(-1) @@ -119,14 +117,16 @@ private class ClientActor(override val rpcEnv: RpcEnv, driverArgs: ClientArgumen println(message) if (success) pollAndReportStatus(driverId) else System.exit(-1) - case AssociationErrorEvent(cause, _, remoteAddress, _, _) => - println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") - println(s"Cause was: $cause") - System.exit(-1) } - override def remoteConnectionTerminated(remoteAddress: RpcAddress): Unit = { + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") + System.exit(-1) + } + + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") + println(s"Cause was: $cause") System.exit(-1) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 209632ec3e63..996734575efc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -64,7 +64,7 @@ private[spark] class AppClient( private val scheduler = Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("client-actor")) - override def preStart() { + override def onStart() { try { registerWithMaster() } catch { @@ -162,7 +162,7 @@ private[spark] class AppClient( stop() } - override def remoteConnectionTerminated(address: RpcAddress): Unit = { + override def onDisconnected(address: RpcAddress): Unit = { if (address == masterAddress) { logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() @@ -186,7 +186,7 @@ private[spark] class AppClient( } } - override def postStop() { + override def onStop() { registrationRetryTimer.foreach(_.cancel(true)) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 2c6b895ee3b7..7e2faa63f463 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -124,7 +124,7 @@ private[spark] class Master( throw new SparkException("spark.deploy.defaultCores must be positive") } - override def preStart() { + override def onStart() { logInfo("Starting Spark master at " + masterUrl) 
logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") // Listen for remote client disconnection events, since they don't go through Akka's watch() @@ -170,7 +170,7 @@ private[spark] class Master( throw reason // throw it so that the master will be restarted } - override def postStop() { + override def onStop() { masterMetricsSystem.report() applicationMetricsSystem.report() // prevent the CompleteRecovery message sending to restarted master @@ -431,7 +431,7 @@ private[spark] class Master( } } - override def remoteConnectionTerminated(address: RpcAddress): Unit = { + override def onDisconnected(address: RpcAddress): Unit = { // The disconnected client could've been either a worker or an app; remove whichever it was logInfo(s"$address got disassociated, removing it.") addressToWorker.get(address).foreach(removeWorker) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 01fb8aeb1fc9..6d9424bce927 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -161,7 +161,7 @@ private[spark] class Worker( } } - override def preStart() { + override def onStart() { assert(!registered) logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( host, port, cores, Utils.megabytesToString(memory))) @@ -479,7 +479,7 @@ private[spark] class Worker( maybeCleanupApplication(id) } - override def remoteConnectionTerminated(remoteAddress: RpcAddress): Unit = { + override def onDisconnected(remoteAddress: RpcAddress): Unit = { if (remoteAddress == masterAddress ) { logInfo(s"$remoteAddress Disassociated !") masterDisconnected() @@ -509,7 +509,7 @@ private[spark] class Worker( "worker-%s-%s-%d".format(createDateFormat.format(new Date), host, port) } - override def postStop() { + override def onStop() { metricsSystem.report() registrationRetryTimer.foreach(_.cancel(true)) executors.values.foreach(_.kill()) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 14c10b3dafe9..3e626ce7124a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -17,12 +17,9 @@ package org.apache.spark.deploy.worker -import akka.remote.{AssociatedEvent, AssociationErrorEvent, AssociationEvent} - import org.apache.spark.Logging import org.apache.spark.deploy.DeployMessages.SendHeartbeat import org.apache.spark.rpc.{RpcAddress, RpcEnv, RpcEndpointRef, RpcEndpoint} -import org.apache.spark.util.AkkaUtils /** * Actor which connects to a worker process and terminates the JVM if the connection is severed. 
@@ -31,7 +28,7 @@ import org.apache.spark.util.AkkaUtils private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String) extends RpcEndpoint with Logging { - override def preStart() { + override def onStart() { logInfo(s"Connecting to worker $workerUrl") if (!isTesting) { val worker = rpcEnv.setupEndpointRefByUrl(workerUrl) @@ -53,28 +50,29 @@ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: Strin def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) override def receive(sender: RpcEndpointRef) = { - case AssociatedEvent(localAddress, remoteAddress, inbound) - if isWorker(AkkaUtils.akkaAddressToRpcAddress(remoteAddress)) => - logInfo(s"Successfully connected to $workerUrl") - - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) - if isWorker(AkkaUtils.akkaAddressToRpcAddress(remoteAddress)) => - // These logs may not be seen if the worker (and associated pipe) has died - logError(s"Could not initialize connection to worker $workerUrl. Exiting.") - logError(s"Error was: $cause") - exitNonZero() - - case e: AssociationEvent => - // pass through association events relating to other remote actor systems - case e => logWarning(s"Received unexpected actor system event: $e") } - override def remoteConnectionTerminated(remoteAddress: RpcAddress): Unit = { - if(isWorker(remoteAddress)) { + override def onConnected(remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { + logInfo(s"Successfully connected to $workerUrl") + } + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { // This log message will never be seen logError(s"Lost connection to worker actor $workerUrl. Exiting.") exitNonZero() } } + + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { + // These logs may not be seen if the worker (and associated pipe) has died + logError(s"Could not initialize connection to worker $workerUrl. Exiting.") + logError(s"Error was: $cause") + exitNonZero() + } + } } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index bf0b5c21ed3a..b9ee77718456 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -44,7 +44,7 @@ private[spark] class CoarseGrainedExecutorBackend( var executor: Executor = null var driver: RpcEndpointRef = _ - override def preStart(): Unit = { + override def onStart(): Unit = { // self is valid now. So now we can use `send` logInfo("Connecting to driver: " + driverUrl) driver = rpcEnv.setupEndpointRefByUrl(driverUrl) @@ -88,7 +88,7 @@ private[spark] class CoarseGrainedExecutorBackend( env.actorSystem.shutdown() } - override def remoteConnectionTerminated(remoteAddress: RpcAddress): Unit = { + override def onDisconnected(remoteAddress: RpcAddress): Unit = { logError(s"Driver $remoteAddress disassociated! 
Shutting down.") System.exit(1) } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index dcd526117f4a..1e4615add8ac 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -85,8 +85,6 @@ trait RpcEndpoint { val rpcEnv: RpcEnv - def preStart(): Unit = {} - /** * Provide the implicit sender. `self` will become valid when `preStart` is called. */ @@ -95,6 +93,8 @@ trait RpcEndpoint { rpcEnv.endpointRef(this) } + def onStart(): Unit = {} + /** * Same assumption like Actor: messages sent to a RpcEndpoint will be delivered in sequence, and * messages from the same RpcEndpoint will be delivered in order. @@ -107,18 +107,26 @@ trait RpcEndpoint { /** * Call onError when any exception is thrown during handling messages. * - * @param e + * @param cause */ - def onError(e: Throwable): Unit = { + def onError(cause: Throwable): Unit = { // By default, throw e and let RpcEnv handle it - throw e + throw cause + } + + def onConnected(remoteAddress: RpcAddress): Unit = { + // By default, do nothing. } - def remoteConnectionTerminated(remoteAddress: RpcAddress): Unit = { + def onDisconnected(remoteAddress: RpcAddress): Unit = { // By default, do nothing. } - def postStop(): Unit = {} + def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + // By default, throw e and let RpcEnv handle it + } + + def onStop(): Unit = {} final def stop(): Unit = { rpcEnv.stop(self) diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 6f3e74c4156a..b24f42c4fd92 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -26,7 +26,7 @@ import scala.language.postfixOps import akka.actor.{ActorRef, Actor, Props, ActorSystem} import akka.pattern.{ask => akkaAsk} -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import akka.remote._ import org.apache.spark.{Logging, SparkException, SparkConf} import org.apache.spark.rpc._ @@ -50,21 +50,39 @@ class AkkaRpcEnv(val actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { registerEndpoint(endpoint, endpointRef) override def preStart(): Unit = { - endpoint.preStart() + endpoint.onStart() // Listen for remote client disconnection events, // since they don't go through Akka's watch() context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) } override def receiveWithLogging: Receive = { + case AssociatedEvent(_, remoteAddress, _) => + try { + // TODO How to handle that a remoteAddress doesn't have host & port + endpoint.onConnected(AkkaUtils.akkaAddressToRpcAddress(remoteAddress)) + } catch { + case NonFatal(e) => endpoint.onError(e) + } + case DisassociatedEvent(_, remoteAddress, _) => try { // TODO How to handle that a remoteAddress doesn't have host & port - endpoint.remoteConnectionTerminated( - AkkaUtils.akkaAddressToRpcAddress(remoteAddress)) + endpoint.onDisconnected(AkkaUtils.akkaAddressToRpcAddress(remoteAddress)) } catch { case NonFatal(e) => endpoint.onError(e) } + + case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) => + try { + // TODO How to handle that a remoteAddress doesn't have host & port + endpoint.onNetworkError(cause, AkkaUtils.akkaAddressToRpcAddress(remoteAddress)) + } catch { + case NonFatal(e) => endpoint.onError(e) + } + case e: RemotingLifecycleEvent => + // 
ignore? + case message: Any => try { logInfo("Received RPC message: " + message) @@ -78,7 +96,7 @@ class AkkaRpcEnv(val actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { } override def postStop(): Unit = { - endpoint.postStop() + endpoint.onStop() } }), name = name) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 95ed84e74b91..365c8145c96a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1386,7 +1386,7 @@ private[scheduler] class DAGSchedulerEventProcessActor( dagScheduler.resubmitFailedStages() } - override def postStop() { + override def onStop() { // Cancel any active jobs in postStop hook dagScheduler.cleanUpAfterSchedulerStop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 3e05afc5aa16..da1c770b0d5f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -74,7 +74,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp private val reviveScheduler = Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("driver-revive-scheduler")) - override def preStart() { + override def onStart() { // Periodically revive offers to allow delay scheduling to work val reviveInterval = conf.getLong("spark.scheduler.revive.interval", 1000) reviveScheduler.scheduleAtFixedRate(new Runnable { @@ -195,7 +195,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } } - override def remoteConnectionTerminated(remoteAddress: RpcAddress): Unit = { + override def onDisconnected(remoteAddress: RpcAddress): Unit = { addressToExecutorId.get(remoteAddress).foreach(removeExecutor(_, "remote Akka client disassociated")) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 1dbc49ac98e8..1cfc9bb8f5f3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -120,7 +120,7 @@ private[spark] abstract class YarnSchedulerBackend( sender.send(true) } - override def remoteConnectionTerminated(remoteAddress: RpcAddress): Unit = { + override def onDisconnected(remoteAddress: RpcAddress): Unit = { if (amActor.isDefined && remoteAddress == amActor.get.address) { logWarning(s"ApplicationMaster has disassociated: $remoteAddress") } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index c708b1c32a15..5627c7a48d71 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -61,8 +61,8 @@ class BlockManagerMasterActor(override val rpcEnv: RpcEnv, val isLocal: Boolean, var timeoutCheckingTask: ScheduledFuture[_] = null - override def preStart() { - super.preStart() + override def onStart() { + super.onStart() timeoutCheckingTask = scheduler.scheduleAtFixedRate(new Runnable 
{ override def run(): Unit = self.send(ExpireDeadHosts) }, 0, checkTimeoutInterval, TimeUnit.MILLISECONDS) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index 4a7530426d88..e3b0571cbe1a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -32,7 +32,7 @@ class WorkerWatcherSuite extends FunSuite { val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) - workerWatcher.remoteConnectionTerminated(AkkaUtils.akkaAddressToRpcAddress(targetWorkerAddress)) + workerWatcher.onDisconnected(AkkaUtils.akkaAddressToRpcAddress(targetWorkerAddress)) assert(workerWatcher.isShutDown) rpcEnv.stopAll() actorSystem.shutdown() @@ -47,7 +47,7 @@ class WorkerWatcherSuite extends FunSuite { val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) - workerWatcher.remoteConnectionTerminated(AkkaUtils.akkaAddressToRpcAddress(otherAkkaAddress)) + workerWatcher.onDisconnected(AkkaUtils.akkaAddressToRpcAddress(otherAkkaAddress)) assert(!workerWatcher.isShutDown) rpcEnv.stopAll() actorSystem.shutdown() From 1e32c4feb5a1b8adc2503bfb6896e56751f8952b Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 6 Jan 2015 15:26:14 +0800 Subject: [PATCH 27/36] Add NetworkRpcEndpoint for RpcEndpoint interested in network events --- .../org/apache/spark/deploy/Client.scala | 4 ++-- .../spark/deploy/client/AppClient.scala | 4 ++-- .../apache/spark/deploy/master/Master.scala | 4 ++-- .../apache/spark/deploy/worker/Worker.scala | 7 +++---- .../spark/deploy/worker/WorkerWatcher.scala | 4 ++-- .../CoarseGrainedExecutorBackend.scala | 4 ++-- .../scala/org/apache/spark/rpc/RpcEnv.scala | 19 +++++++++++------- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 20 ++++++++++--------- .../CoarseGrainedSchedulerBackend.scala | 4 ++-- .../cluster/YarnSchedulerBackend.scala | 4 ++-- .../org/apache/spark/util/AkkaUtils.scala | 1 + 11 files changed, 41 insertions(+), 34 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 9fdd6a76c8c9..3a3588844ccc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -23,14 +23,14 @@ import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.rpc.akka.AkkaRpcEnv -import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEnv, RpcEndpointRef} +import org.apache.spark.rpc.{RpcAddress, RpcEnv, RpcEndpointRef, NetworkRpcEndpoint} import org.apache.spark.util.{AkkaUtils, Utils} /** * Proxy that relays messages to the driver. 
*/ private class ClientActor(override val rpcEnv: RpcEnv, driverArgs: ClientArguments, conf: SparkConf) - extends RpcEndpoint with Logging { + extends NetworkRpcEndpoint with Logging { var masterActor: RpcEndpointRef = _ diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 996734575efc..53af87f67598 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -28,7 +28,7 @@ import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master -import org.apache.spark.rpc.{RpcAddress, RpcEnv, RpcEndpointRef, RpcEndpoint} +import org.apache.spark.rpc._ import org.apache.spark.util.Utils /** @@ -55,7 +55,7 @@ private[spark] class AppClient( var registered = false var activeMasterUrl: String = null - class ClientActor(override val rpcEnv: RpcEnv) extends RpcEndpoint with Logging { + class ClientActor(override val rpcEnv: RpcEnv) extends NetworkRpcEndpoint with Logging { var master: RpcEndpointRef = null var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times var alreadyDead = false // To avoid calling listener.dead() multiple times diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 7e2faa63f463..e9417acb249a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -42,7 +42,7 @@ import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.MasterMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEnv, RpcEndpointRef} +import org.apache.spark.rpc.{RpcAddress, NetworkRpcEndpoint, RpcEnv, RpcEndpointRef} import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.ui.SparkUI import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} @@ -53,7 +53,7 @@ private[spark] class Master( port: Int, webUiPort: Int, val securityMgr: SecurityManager) - extends RpcEndpoint with Logging with LeaderElectable { + extends NetworkRpcEndpoint with Logging with LeaderElectable { val scheduler = Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("check-worker-timeout")) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 6d9424bce927..a465d8546f48 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -23,9 +23,6 @@ import java.text.SimpleDateFormat import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors} import java.util.{UUID, Date} -import org.apache.spark.rpc.akka.AkkaRpcEnv -import org.apache.spark.rpc.{RpcAddress, RpcEnv, RpcEndpointRef, RpcEndpoint} - import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} import scala.concurrent.ExecutionContext @@ -41,6 +38,8 @@ import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} import 
org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.rpc.akka.AkkaRpcEnv +import org.apache.spark.rpc.{RpcAddress, RpcEnv, RpcEndpointRef, NetworkRpcEndpoint} import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} /** @@ -59,7 +58,7 @@ private[spark] class Worker( workDirPath: String = null, val conf: SparkConf, val securityMgr: SecurityManager) - extends RpcEndpoint with Logging { + extends NetworkRpcEndpoint with Logging { val scheduler = Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("worker-scheduler")) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 3e626ce7124a..ce8c095dfcbe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -19,14 +19,14 @@ package org.apache.spark.deploy.worker import org.apache.spark.Logging import org.apache.spark.deploy.DeployMessages.SendHeartbeat -import org.apache.spark.rpc.{RpcAddress, RpcEnv, RpcEndpointRef, RpcEndpoint} +import org.apache.spark.rpc._ /** * Actor which connects to a worker process and terminates the JVM if the connection is severed. * Provides fate sharing between a worker and its associated child processes. */ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String) - extends RpcEndpoint with Logging { + extends NetworkRpcEndpoint with Logging { override def onStart() { logInfo(s"Connecting to worker $workerUrl") diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index b9ee77718456..21c4e1e544d0 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -24,7 +24,7 @@ import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.rpc.akka.AkkaRpcEnv -import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEndpointRef} +import org.apache.spark.rpc.{RpcAddress, NetworkRpcEndpoint, RpcEndpointRef} import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} @@ -35,7 +35,7 @@ private[spark] class CoarseGrainedExecutorBackend( hostPort: String, cores: Int, env: SparkEnv) - extends RpcEndpoint with ExecutorBackend with Logging { + extends NetworkRpcEndpoint with ExecutorBackend with Logging { override val rpcEnv = env.rpcEnv diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 1e4615add8ac..59456ee3d407 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -114,6 +114,18 @@ trait RpcEndpoint { throw cause } + def onStop(): Unit = {} + + final def stop(): Unit = { + rpcEnv.stop(self) + } +} + +/** + * A RpcEndoint interested in network events. + */ +trait NetworkRpcEndpoint extends RpcEndpoint { + def onConnected(remoteAddress: RpcAddress): Unit = { // By default, do nothing. 
} @@ -125,15 +137,8 @@ trait RpcEndpoint { def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { // By default, throw e and let RpcEnv handle it } - - def onStop(): Unit = {} - - final def stop(): Unit = { - rpcEnv.stop(self) - } } - object RpcEndpoint { final val noSender: RpcEndpointRef = null } diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index b24f42c4fd92..55ca4775b15f 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -51,32 +51,34 @@ class AkkaRpcEnv(val actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { override def preStart(): Unit = { endpoint.onStart() - // Listen for remote client disconnection events, - // since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + if (endpoint.isInstanceOf[NetworkRpcEndpoint]) { + // Listen for remote client disconnection events, + // since they don't go through Akka's watch() + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + } } override def receiveWithLogging: Receive = { case AssociatedEvent(_, remoteAddress, _) => try { - // TODO How to handle that a remoteAddress doesn't have host & port - endpoint.onConnected(AkkaUtils.akkaAddressToRpcAddress(remoteAddress)) + endpoint.asInstanceOf[NetworkRpcEndpoint]. + onConnected(AkkaUtils.akkaAddressToRpcAddress(remoteAddress)) } catch { case NonFatal(e) => endpoint.onError(e) } case DisassociatedEvent(_, remoteAddress, _) => try { - // TODO How to handle that a remoteAddress doesn't have host & port - endpoint.onDisconnected(AkkaUtils.akkaAddressToRpcAddress(remoteAddress)) + endpoint.asInstanceOf[NetworkRpcEndpoint]. + onDisconnected(AkkaUtils.akkaAddressToRpcAddress(remoteAddress)) } catch { case NonFatal(e) => endpoint.onError(e) } case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) => try { - // TODO How to handle that a remoteAddress doesn't have host & port - endpoint.onNetworkError(cause, AkkaUtils.akkaAddressToRpcAddress(remoteAddress)) + endpoint.asInstanceOf[NetworkRpcEndpoint]. 
+ onNetworkError(cause, AkkaUtils.akkaAddressToRpcAddress(remoteAddress)) } catch { case NonFatal(e) => endpoint.onError(e) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index da1c770b0d5f..cb946ef0d3c0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -23,7 +23,7 @@ import java.util.concurrent.{TimeUnit, Executors} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} -import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEnv, RpcEndpointRef} +import org.apache.spark.rpc.{RpcAddress, NetworkRpcEndpoint, RpcEnv, RpcEndpointRef} import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{SerializableBuffer, AkkaUtils, Utils} @@ -66,7 +66,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp private val executorsPendingToRemove = new HashSet[String] class DriverActor(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) - extends RpcEndpoint with Logging { + extends NetworkRpcEndpoint with Logging { override protected def log = CoarseGrainedSchedulerBackend.this.log private val addressToExecutorId = new HashMap[RpcAddress, String] diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 1cfc9bb8f5f3..83cb78ff9e00 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler.cluster import org.apache.spark.SparkContext -import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, RpcEndpoint} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, NetworkRpcEndpoint} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.ui.JettyUtils @@ -88,7 +88,7 @@ private[spark] abstract class YarnSchedulerBackend( /** * An actor that communicates with the ApplicationMaster. 
*/ - private class YarnSchedulerActor(override val rpcEnv: RpcEnv) extends RpcEndpoint { + private class YarnSchedulerActor(override val rpcEnv: RpcEnv) extends NetworkRpcEndpoint { private var amActor: Option[RpcEndpointRef] = None override def receive(sender: RpcEndpointRef) = { diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 21c6ac6007e9..86b67de98ba4 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -237,6 +237,7 @@ private[spark] object AkkaUtils extends Logging { } def akkaAddressToRpcAddress(akkaAddress: Address): RpcAddress = { + // TODO How to handle that a remoteAddress doesn't have host & port RpcAddress(akkaAddress.host.getOrElse("localhost"), akkaAddress.port.getOrElse(-1)) } } From 9a9c1b1f23cb2c2390a05a07933c5dc73bd1bc3b Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 6 Jan 2015 15:29:54 +0800 Subject: [PATCH 28/36] Fix the code style --- core/src/main/scala/org/apache/spark/deploy/Client.scala | 3 ++- .../main/scala/org/apache/spark/deploy/master/Master.scala | 4 ++-- core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala | 6 ++++-- .../scala/org/apache/spark/storage/BlockManagerMaster.scala | 3 ++- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 3a3588844ccc..3b1a5b3de479 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -84,7 +84,8 @@ private class ClientActor(override val rpcEnv: RpcEnv, driverArgs: ClientArgumen println(s"... waiting before polling master for driver state") Thread.sleep(5000) println("... polling master for driver state") - val statusResponse = masterActor.askWithReply[DriverStatusResponse](RequestDriverStatus(driverId)) + val statusResponse = + masterActor.askWithReply[DriverStatusResponse](RequestDriverStatus(driverId)) statusResponse.found match { case false => diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index e9417acb249a..c066be4d552b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -418,8 +418,8 @@ private[spark] class Master( } case RequestMasterState => { - sender.send(MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray, - drivers.toArray, completedDrivers.toArray, state)) + sender.send(MasterStateResponse(host, port, workers.toArray, apps.toArray, + completedApps.toArray, drivers.toArray, completedDrivers.toArray, state)) } case CheckForWorkerTimeOut => { diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 59456ee3d407..b918dfb88868 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -19,6 +19,8 @@ package org.apache.spark.rpc import java.util.concurrent.ConcurrentHashMap +import org.apache.spark.deploy.master.Master + import scala.concurrent.Future import scala.concurrent.duration.{FiniteDuration, Duration} import scala.reflect.ClassTag @@ -146,7 +148,7 @@ object RpcEndpoint { /** * A reference for a remote [[RpcEndpoint]]. 
*/ -trait RpcEndpointRef { +trait RpcEndpointRef {Master def address: RpcAddress @@ -169,4 +171,4 @@ case class RpcAddress(host: String, port: Int) { val hostPort: String = host + ":" + port override val toString: String = hostPort -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 4743a9f64a34..1906de736244 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -43,7 +43,8 @@ class BlockManagerMaster( } /** Register the BlockManager's id with the driver. */ - def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: RpcEndpointRef) { + def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long, + slaveActor: RpcEndpointRef) { logInfo("Trying to register BlockManager") tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor)) logInfo("Registered BlockManager") From a05cba50fc41d5aee8f8f33d3611820ad8c6a2b8 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 6 Jan 2015 16:17:33 +0800 Subject: [PATCH 29/36] Fix AppClient --- .../apache/spark/deploy/client/AppClient.scala | 17 ++++++++--------- .../scala/org/apache/spark/rpc/RpcEnv.scala | 9 +++++++++ 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 53af87f67598..3322bd4b1df8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -21,9 +21,6 @@ import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors, TimeoutExcept import scala.concurrent.duration._ -import akka.actor._ -import akka.remote.AssociationErrorEvent - import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ @@ -115,9 +112,8 @@ private[spark] class AppClient( } } - private def isPossibleMaster(remoteUrl: Address) = { - masterUrls.map(s => Master.toAkkaUrl(s)) - .map(u => AddressFromURIString(u).hostPort) + private def isPossibleMaster(remoteUrl: RpcAddress) = { + masterUrls.map(RpcAddress.fromURIString(_).hostPort) .contains(remoteUrl.hostPort) } @@ -153,9 +149,6 @@ private[spark] class AppClient( alreadyDisconnected = false sender.send(MasterChangeAcknowledged(appId)) - case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) => - logWarning(s"Could not connect to $address: $cause") - case StopAppClient => markDead("Application has been stopped.") sender.send(true) @@ -169,6 +162,12 @@ private[spark] class AppClient( } } + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + if (isPossibleMaster(remoteAddress)) { + logWarning(s"Could not connect to $remoteAddress: $cause") + } + } + /** * Notify the listener that we disconnected, if we hadn't already done so before. 
*/ diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index b918dfb88868..cd8a508af653 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -172,3 +172,12 @@ case class RpcAddress(host: String, port: Int) { override val toString: String = hostPort } + +object RpcAddress { + + def fromURIString(uri: String): RpcAddress = { + val u = new java.net.URI(uri) + RpcAddress(u.getHost, u.getPort) + } + +} From b80d8b1b4930ffa39916e571e29626bc5d00c9d9 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 6 Jan 2015 19:12:24 +0800 Subject: [PATCH 30/36] Hide ActorSystem into AkkaRpcEnv --- .../scala/org/apache/spark/SparkEnv.scala | 21 ++++---- .../org/apache/spark/deploy/Client.scala | 7 ++- .../spark/deploy/LocalSparkCluster.scala | 17 ++++--- .../spark/deploy/client/TestClient.scala | 7 ++- .../apache/spark/deploy/master/Master.scala | 23 ++++----- .../spark/deploy/worker/DriverWrapper.scala | 4 +- .../apache/spark/deploy/worker/Worker.scala | 20 ++++---- .../CoarseGrainedExecutorBackend.scala | 11 ++--- .../org/apache/spark/executor/Executor.scala | 8 +--- .../scala/org/apache/spark/rpc/RpcEnv.scala | 7 ++- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 33 +++++++++++-- .../apache/spark/scheduler/DAGScheduler.scala | 1 - .../spark/scheduler/TaskSchedulerImpl.scala | 13 +++-- .../apache/spark/MapOutputTrackerSuite.scala | 36 ++++---------- .../deploy/worker/WorkerWatcherSuite.scala | 8 +--- .../org/apache/spark/rpc/RpcEnvSuite.scala | 4 +- .../spark/rpc/akka/AkkaRpcEnvSuite.scala | 13 +---- .../spark/scheduler/SparkListenerSuite.scala | 4 +- .../scheduler/TaskResultGetterSuite.scala | 7 ++- .../BlockManagerReplicationSuite.scala | 15 ++---- .../spark/storage/BlockManagerSuite.scala | 15 ++---- .../apache/spark/util/AkkaUtilsSuite.scala | 48 +++++++------------ 22 files changed, 132 insertions(+), 190 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 222759377cd7..95d69f8efe44 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConversions._ import scala.collection.mutable import scala.util.Properties -import akka.actor.ActorSystem import com.google.common.collect.MapMaker import org.apache.spark.annotation.DeveloperApi @@ -40,7 +39,7 @@ import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -55,7 +54,6 @@ import org.apache.spark.util.{AkkaUtils, Utils} @DeveloperApi class SparkEnv ( val executorId: String, - val actorSystem: ActorSystem, val rpcEnv: RpcEnv, val serializer: Serializer, val closureSerializer: Serializer, @@ -72,6 +70,9 @@ class SparkEnv ( val shuffleMemoryManager: ShuffleMemoryManager, val conf: SparkConf) extends Logging { + // TODO actorSystem is used by Streaming + val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + private[spark] var isStopped = false private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() @@ -89,7 +90,6 @@ class SparkEnv ( blockManager.stop() blockManager.master.stop() metricsSystem.stop() - 
actorSystem.shutdown() rpcEnv.stopAll() // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut // down, but let's call it anyway in case it gets fixed in a later release @@ -216,18 +216,14 @@ object SparkEnv extends Logging { val securityManager = new SecurityManager(conf) // Create the ActorSystem for Akka and get the port it binds to. - val (actorSystem, boundPort) = { - val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName - AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager) - } - - val rpcEnv = new AkkaRpcEnv(actorSystem, conf) + val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName + val rpcEnv = AkkaRpcEnv(actorSystemName, hostname, port, conf, securityManager) // Figure out which port Akka actually bound to in case the original port is 0 or occupied. if (isDriver) { - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.boundPort.toString) } else { - conf.set("spark.executor.port", boundPort.toString) + conf.set("spark.executor.port", rpcEnv.boundPort.toString) } // Create an instance of the class with the given name, possibly initializing it with our conf @@ -354,7 +350,6 @@ object SparkEnv extends Logging { new SparkEnv( executorId, - actorSystem, rpcEnv, serializer, closureSerializer, diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 3b1a5b3de479..43a5f17d7a34 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -24,7 +24,7 @@ import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.rpc.{RpcAddress, RpcEnv, RpcEndpointRef, NetworkRpcEndpoint} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils /** * Proxy that relays messages to the driver. 
@@ -152,12 +152,11 @@ object Client { conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING")) Logger.getRootLogger.setLevel(driverArgs.logLevel) - val (actorSystem, _) = AkkaUtils.createActorSystem( + val rpcEnv = AkkaRpcEnv( "driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) - val rpcEnv = new AkkaRpcEnv(actorSystem, conf) rpcEnv.setupEndpoint("client-actor", new ClientActor(rpcEnv, driverArgs, conf)) - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 9a7a113c9571..f23ae9a65a62 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -19,11 +19,10 @@ package org.apache.spark.deploy import scala.collection.mutable.ArrayBuffer -import akka.actor.ActorSystem - import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.worker.Worker import org.apache.spark.deploy.master.Master +import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.Utils /** @@ -37,22 +36,22 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I extends Logging { private val localHostname = Utils.localHostName() - private val masterActorSystems = ArrayBuffer[ActorSystem]() - private val workerActorSystems = ArrayBuffer[ActorSystem]() + private val masterActorSystems = ArrayBuffer[RpcEnv]() + private val workerActorSystems = ArrayBuffer[RpcEnv]() def start(): Array[String] = { logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") /* Start the Master */ val conf = new SparkConf(false) - val (masterSystem, masterPort, _) = Master.startSystemAndActor(localHostname, 0, 0, conf) + val (masterSystem, _) = Master.startSystemAndActor(localHostname, 0, 0, conf) masterActorSystems += masterSystem - val masterUrl = "spark://" + localHostname + ":" + masterPort + val masterUrl = "spark://" + localHostname + ":" + masterSystem.boundPort val masters = Array(masterUrl) /* Start the Workers */ for (workerNum <- 1 to numWorkers) { - val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker, + val workerSystem = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker, memoryPerWorker, masters, null, Some(workerNum)) workerActorSystems += workerSystem } @@ -65,9 +64,9 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I // Stop the workers before the master so they don't get upset that it disconnected // TODO: In Akka 2.1.x, ActorSystem.awaitTermination hangs when you have remote actors! // This is unfortunate, but for now we just comment it out. 
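// Illustrative sketch of the RpcEnv-only start/stop flow this patch moves LocalSparkCluster to,
// assuming the Master.startSystemAndActor / Worker.startSystemAndActor signatures shown just
// above (both now hand back RpcEnv handles instead of ActorSystems). Host name, core and memory
// figures below are assumptions for illustration.
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkConf
import org.apache.spark.deploy.master.Master
import org.apache.spark.deploy.worker.Worker
import org.apache.spark.rpc.RpcEnv

object LocalClusterSketch {
  def main(args: Array[String]): Unit = {
    val rpcEnvs = ArrayBuffer[RpcEnv]()
    val conf = new SparkConf(false)

    // Start a master; the bound port now comes from the RpcEnv rather than the ActorSystem.
    val (masterEnv, _) = Master.startSystemAndActor("localhost", 0, 0, conf)
    rpcEnvs += masterEnv
    val masterUrl = "spark://localhost:" + masterEnv.boundPort

    // Start a single worker pointing at that master.
    rpcEnvs += Worker.startSystemAndActor("localhost", 0, 0, 1, 512, Array(masterUrl), null)

    // Tear-down goes through the RpcEnv handles only; no ActorSystem.shutdown() calls remain.
    rpcEnvs.foreach(_.stopAll())
  }
}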
- workerActorSystems.foreach(_.shutdown()) + workerActorSystems.foreach(_.stopAll()) // workerActorSystems.foreach(_.awaitTermination()) - masterActorSystems.foreach(_.shutdown()) + masterActorSystems.foreach(_.stopAll()) // masterActorSystems.foreach(_.awaitTermination()) masterActorSystems.clear() workerActorSystems.clear() diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index 10b86294a7c0..07ccfe4af353 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy.client import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, Command} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils private[spark] object TestClient { @@ -47,14 +47,13 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) val conf = new SparkConf - val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0, + val rpcEnv = AkkaRpcEnv("spark", Utils.localIpAddress, 0, conf = conf, securityManager = new SecurityManager(conf)) - val rpcEnv = new AkkaRpcEnv(actorSystem, conf) val desc = new ApplicationDescription("TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), Seq(), Seq()), "ignored") val listener = new TestListener val client = new AppClient(rpcEnv, Array(url), desc, listener, new SparkConf) client.start() - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index c066be4d552b..3261cf679ef0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -23,12 +23,9 @@ import java.text.SimpleDateFormat import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors} import java.util.Date -import org.apache.spark.rpc.akka.AkkaRpcEnv - import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.util.Random -import akka.actor.ActorSystem import akka.serialization.Serialization import akka.serialization.SerializationExtension import org.apache.hadoop.fs.Path @@ -42,10 +39,11 @@ import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.MasterMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.rpc.{RpcAddress, NetworkRpcEndpoint, RpcEnv, RpcEndpointRef} import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{SignalLogger, Utils} private[spark] class Master( override val rpcEnv: RpcEnv, @@ -58,7 +56,8 @@ private[spark] class Master( val scheduler = Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("check-worker-timeout")) - private def internalActorSystem: ActorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + // TODO hide the actor system + private def internalActorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem val conf = new SparkConf val hadoopConf = 
SparkHadoopUtil.get.newConfiguration(conf) @@ -856,8 +855,8 @@ private[spark] object Master extends Logging { SignalLogger.register(log) val conf = new SparkConf val args = new MasterArguments(argStrings, conf) - val (actorSystem, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf) - actorSystem.awaitTermination() + val (rpcEnv, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf) + rpcEnv.awaitTermination() } /** Returns an `akka.tcp://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */ @@ -874,14 +873,12 @@ private[spark] object Master extends Logging { host: String, port: Int, webUiPort: Int, - conf: SparkConf): (ActorSystem, Int, Int) = { + conf: SparkConf): (RpcEnv, Int) = { val securityMgr = new SecurityManager(conf) - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf, - securityManager = securityMgr) - val rpcEnv = new AkkaRpcEnv(actorSystem, conf) + val rpcEnv = AkkaRpcEnv(systemName, host, port, conf = conf, securityManager = securityMgr) val actor = rpcEnv.setupEndpoint(actorName, - new Master(rpcEnv, host, boundPort, webUiPort, securityMgr)) + new Master(rpcEnv, host, rpcEnv.boundPort, webUiPort, securityMgr)) val resp = actor.askWithReply[WebUIPortResponse](RequestWebUIPort) - (actorSystem, boundPort, resp.webUIBoundPort) + (rpcEnv, resp.webUIBoundPort) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index c44dd05d31bb..b5bd626d0957 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -30,9 +30,8 @@ object DriverWrapper { args.toList match { case workerUrl :: mainClass :: extraArgs => val conf = new SparkConf() - val (actorSystem, _) = AkkaUtils.createActorSystem("Driver", + val rpcEnv = AkkaRpcEnv("Driver", Utils.localHostName(), 0, conf, new SecurityManager(conf)) - val rpcEnv = new AkkaRpcEnv(actorSystem, conf) rpcEnv.setupEndpoint("workerWatcher", new WorkerWatcher(rpcEnv, workerUrl)) // Delegate to supplied main class @@ -41,7 +40,6 @@ object DriverWrapper { mainMethod.invoke(null, extraArgs.toArray[String]) rpcEnv.stopAll() - actorSystem.shutdown() case _ => System.err.println("Usage: DriverWrapper [options]") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index a465d8546f48..da908da283b9 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -30,8 +30,6 @@ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import akka.actor.ActorSystem - import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ @@ -40,7 +38,7 @@ import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.rpc.{RpcAddress, RpcEnv, RpcEndpointRef, NetworkRpcEndpoint} -import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{SignalLogger, Utils} /** * @param masterUrls Each url should look like spark://host:port. 
@@ -524,9 +522,9 @@ private[spark] object Worker extends Logging { SignalLogger.register(log) val conf = new SparkConf val args = new WorkerArguments(argStrings, conf) - val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, + val rpcEnv = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, args.memory, args.masters, args.workDir) - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } def startSystemAndActor( @@ -537,19 +535,17 @@ private[spark] object Worker extends Logging { memory: Int, masterUrls: Array[String], workDir: String, - workerNumber: Option[Int] = None): (ActorSystem, Int) = { + workerNumber: Option[Int] = None): RpcEnv = { // The LocalSparkCluster runs multiple local sparkWorkerX actor systems val conf = new SparkConf val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") val actorName = "Worker" val securityMgr = new SecurityManager(conf) - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, - conf = conf, securityManager = securityMgr) - val rpcEnv = new AkkaRpcEnv(actorSystem, conf) - rpcEnv.setupEndpoint(actorName, new Worker(rpcEnv, host, boundPort, webUiPort, cores, memory, - masterUrls, systemName, actorName, workDir, conf, securityMgr)) - (actorSystem, boundPort) + val rpcEnv = AkkaRpcEnv(systemName, host, port, conf = conf, securityManager = securityMgr) + rpcEnv.setupEndpoint(actorName, new Worker(rpcEnv, host, rpcEnv.boundPort, webUiPort, cores, + memory, masterUrls, systemName, actorName, workDir, conf, securityMgr)) + rpcEnv } } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 21c4e1e544d0..8a60ea630f3c 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -27,7 +27,7 @@ import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.rpc.{RpcAddress, NetworkRpcEndpoint, RpcEndpointRef} import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{SignalLogger, Utils} private[spark] class CoarseGrainedExecutorBackend( driverUrl: String, @@ -85,7 +85,6 @@ private[spark] class CoarseGrainedExecutorBackend( executor.stop() stop() rpcEnv.stopAll() - env.actorSystem.shutdown() } override def onDisconnected(remoteAddress: RpcAddress): Unit = { @@ -117,16 +116,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Bootstrap to fetch the driver's Spark properties. val executorConf = new SparkConf val port = executorConf.getInt("spark.executor.port", 0) - val (fetcher, _) = AkkaUtils.createActorSystem( + val rpcEnv = AkkaRpcEnv( "driverPropsFetcher", hostname, port, executorConf, new SecurityManager(executorConf)) - val rpcEnv = new AkkaRpcEnv(fetcher, executorConf) val driver = rpcEnv.setupEndpointRefByUrl(driverUrl) val props = driver.askWithReply[Seq[(String, String)]](RetrieveSparkProps) ++ Seq[(String, String)](("spark.app.id", appId)) rpcEnv.stopAll() - fetcher.shutdown() - fetcher.awaitTermination() + rpcEnv.awaitTermination() // Create SparkEnv using properties we fetched from the driver. 
val driverConf = new SparkConf().setAll(props) @@ -144,7 +141,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { workerUrl.foreach { url => env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } - env.actorSystem.awaitTermination() + env.rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 5301746b8432..ae14cc929416 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -350,10 +350,7 @@ private[spark] class Executor( def startDriverHeartbeater() { val interval = conf.getInt("spark.executor.heartbeatInterval", 10000) - val timeout = AkkaUtils.lookupTimeout(conf) - val retryAttempts = AkkaUtils.numRetries(conf) - val retryIntervalMs = AkkaUtils.retryWaitMs(conf) - val heartbeatReceiverRef = AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem) + val heartbeatReceiverRef = env.rpcEnv.setupDriverEndpointRef("HeartbeatReceiver") val t = new Thread() { override def run() { @@ -385,8 +382,7 @@ private[spark] class Executor( val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) try { - val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, - retryAttempts, retryIntervalMs, timeout) + val response = heartbeatReceiverRef.askWithReply[HeartbeatResponse](message) if (response.reregisterBlockManager) { logWarning("Told to re-register on heartbeat") env.blockManager.reregister() diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index cd8a508af653..8fad9e006e47 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -59,6 +59,8 @@ trait RpcEnv { endpointRef } + def boundPort: Int + def setupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef def setupDriverEndpointRef(name: String): RpcEndpointRef @@ -68,6 +70,8 @@ trait RpcEnv { def stop(endpoint: RpcEndpointRef): Unit def stopAll(): Unit + + def awaitTermination(): Unit } @@ -148,7 +152,8 @@ object RpcEndpoint { /** * A reference for a remote [[RpcEndpoint]]. 
*/ -trait RpcEndpointRef {Master +trait RpcEndpointRef { + Master def address: RpcAddress diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 55ca4775b15f..c97bd82c272f 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -19,6 +19,8 @@ package org.apache.spark.rpc.akka import java.util.concurrent.CountDownLatch +import com.google.common.annotations.VisibleForTesting + import scala.concurrent.Await import scala.concurrent.duration._ import scala.concurrent.Future @@ -28,15 +30,15 @@ import akka.actor.{ActorRef, Actor, Props, ActorSystem} import akka.pattern.{ask => akkaAsk} import akka.remote._ -import org.apache.spark.{Logging, SparkException, SparkConf} +import org.apache.spark.{SecurityManager, Logging, SparkException, SparkConf} import org.apache.spark.rpc._ import org.apache.spark.util.{ActorLogReceive, AkkaUtils} import scala.reflect.ClassTag import scala.util.control.NonFatal -class AkkaRpcEnv(val actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { - // TODO Once finishing the new Rpc mechanism, make actorSystem be a private val +class AkkaRpcEnv private (val actorSystem: ActorSystem, conf: SparkConf, val boundPort: Int) + extends RpcEnv { override def setupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = { val latch = new CountDownLatch(1) @@ -120,7 +122,7 @@ class AkkaRpcEnv(val actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { } override def stopAll(): Unit = { - // Do nothing since actorSystem was created outside. + actorSystem.shutdown() } override def stop(endpoint: RpcEndpointRef): Unit = { @@ -129,9 +131,32 @@ class AkkaRpcEnv(val actorSystem: ActorSystem, conf: SparkConf) extends RpcEnv { actorSystem.stop(endpoint.asInstanceOf[AkkaRpcEndpointRef].actorRef) } + override def awaitTermination(): Unit = { + actorSystem.awaitTermination() + } + override def toString = s"${getClass.getSimpleName}($actorSystem)" } +object AkkaRpcEnv { + + def apply( + name: String, + host: String, + port: Int, + conf: SparkConf, + securityManager: SecurityManager): AkkaRpcEnv = { + val (actorSystem, boundPort) = + AkkaUtils.createActorSystem(name, host, port, conf, securityManager) + new AkkaRpcEnv(actorSystem, conf, boundPort) + } + + @VisibleForTesting + def apply(name: String, conf: SparkConf): AkkaRpcEnv = { + new AkkaRpcEnv(ActorSystem(name), conf, -1) + } +} + private[akka] class AkkaRpcEndpointRef(val actorRef: ActorRef, @transient conf: SparkConf) extends RpcEndpointRef with Serializable with Logging { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 365c8145c96a..417346fa4638 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1071,7 +1071,6 @@ class DAGScheduler( // in that case the event will already have been scheduled. eventProcessActor may be // null during unit tests. 
// TODO: Cancel running tasks in the stage - import env.actorSystem.dispatcher logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + s"$failedStage (${failedStage.name}) due to fetch failure") messageScheduler.schedule(new Runnable { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index a41f3eef195d..8cc62490e99a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -18,14 +18,13 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer +import java.util.concurrent.{TimeUnit, Executors} import java.util.{TimerTask, Timer} import java.util.concurrent.atomic.AtomicLong -import scala.concurrent.duration._ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet -import scala.language.postfixOps import scala.util.Random import org.apache.spark._ @@ -141,11 +140,11 @@ private[spark] class TaskSchedulerImpl( if (!isLocal && conf.getBoolean("spark.speculation", false)) { logInfo("Starting speculative execution thread") - import sc.env.actorSystem.dispatcher - sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds, - SPECULATION_INTERVAL milliseconds) { - Utils.tryOrExit { checkSpeculatableTasks() } - } + val scheduler = + Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("task-scheduler-speculation")) + scheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryOrExit { checkSpeculatableTasks() } + }, SPECULATION_INTERVAL, SPECULATION_INTERVAL, TimeUnit.MILLISECONDS) } } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index f4fb476d4864..7df1717ce67b 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark import org.apache.spark.rpc.akka.AkkaRpcEnv -import akka.actor._ import org.mockito.Mockito._ import org.scalatest.FunSuite @@ -27,25 +26,21 @@ import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.AkkaUtils class MapOutputTrackerSuite extends FunSuite { private val conf = new SparkConf test("master start and stop") { - val actorSystem = ActorSystem("test") - val rpcEnv = new AkkaRpcEnv(actorSystem, conf) + val rpcEnv = AkkaRpcEnv("test", conf) val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = rpcEnv.setupEndpoint( "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, tracker, conf)) tracker.stop() rpcEnv.stopAll() - actorSystem.shutdown() } test("master register shuffle and fetch") { - val actorSystem = ActorSystem("test") - val rpcEnv = new AkkaRpcEnv(actorSystem, conf) + val rpcEnv = AkkaRpcEnv("test", conf) val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = rpcEnv.setupEndpoint( "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, tracker, conf)) @@ -62,12 +57,10 @@ class MapOutputTrackerSuite extends FunSuite { (BlockManagerId("b", "hostB", 1000), size10000))) tracker.stop() rpcEnv.stopAll() - actorSystem.shutdown() } test("master register and unregister shuffle") { - 
val actorSystem = ActorSystem("test") - val rpcEnv = new AkkaRpcEnv(actorSystem, conf) + val rpcEnv = AkkaRpcEnv("test", conf) val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = rpcEnv.setupEndpoint( "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, tracker, conf)) @@ -86,12 +79,10 @@ class MapOutputTrackerSuite extends FunSuite { tracker.stop() rpcEnv.stopAll() - actorSystem.shutdown() } test("master register shuffle and unregister map output and fetch") { - val actorSystem = ActorSystem("test") - val rpcEnv = new AkkaRpcEnv(actorSystem, conf) + val rpcEnv = AkkaRpcEnv("test", conf) val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = rpcEnv.setupEndpoint( "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, tracker, conf)) @@ -114,24 +105,21 @@ class MapOutputTrackerSuite extends FunSuite { tracker.stop() rpcEnv.stopAll() - actorSystem.shutdown() } test("remote fetch") { val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, + val rpcEnv = AkkaRpcEnv("spark", hostname, 0, conf = conf, securityManager = new SecurityManager(conf)) - val rpcEnv = new AkkaRpcEnv(actorSystem, conf) val masterTracker = new MapOutputTrackerMaster(conf) masterTracker.trackerActor = rpcEnv.setupEndpoint( "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, masterTracker, conf)) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf, + val slaveRpcEnv = AkkaRpcEnv("spark-slave", hostname, 0, conf = conf, securityManager = new SecurityManager(conf)) - val slaveRpcEnv = new AkkaRpcEnv(slaveSystem, conf) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker" + val selection = s"akka.tcp://spark@localhost:${rpcEnv.boundPort}/user/MapOutputTracker" slaveTracker.trackerActor = slaveRpcEnv.setupEndpointRefByUrl(selection) masterTracker.registerShuffle(10, 1) @@ -158,9 +146,7 @@ class MapOutputTrackerSuite extends FunSuite { masterTracker.stop() slaveTracker.stop() rpcEnv.stopAll() - actorSystem.shutdown() slaveRpcEnv.stopAll() - slaveSystem.shutdown() } test("remote fetch below akka frame size") { @@ -169,8 +155,7 @@ class MapOutputTrackerSuite extends FunSuite { newConf.set("spark.akka.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) - val actorSystem = ActorSystem("test") - val rpcEnv = new AkkaRpcEnv(actorSystem, conf) + val rpcEnv = AkkaRpcEnv("test", conf) val masterActor = new MapOutputTrackerMasterActor(rpcEnv, masterTracker, newConf) rpcEnv.setupEndpoint("MapOutputTracker", masterActor) @@ -183,7 +168,6 @@ class MapOutputTrackerSuite extends FunSuite { // masterTracker.stop() // this throws an exception rpcEnv.stopAll() - actorSystem.shutdown() } test("remote fetch exceeds akka frame size") { @@ -192,8 +176,7 @@ class MapOutputTrackerSuite extends FunSuite { newConf.set("spark.akka.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) - val actorSystem = ActorSystem("test") - val rpcEnv = new AkkaRpcEnv(actorSystem, conf) + val rpcEnv = AkkaRpcEnv("test", conf) val masterActor = new MapOutputTrackerMasterActor(rpcEnv, masterTracker, newConf) // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. 
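The TaskSchedulerImpl hunk above replaces the Akka scheduler with a plain JDK ScheduledExecutorService for the periodic speculation check. A minimal, self-contained sketch of that pattern follows; the thread-factory helper, object name, and interval value are illustrative stand-ins rather than Spark's Utils.namedThreadFactory and SPECULATION_INTERVAL.

```scala
import java.util.concurrent.{Executors, ThreadFactory, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger

object SpeculationPollerSketch {
  // Stand-in for Utils.namedThreadFactory: daemon threads with a readable name prefix.
  private def namedDaemonFactory(prefix: String): ThreadFactory = new ThreadFactory {
    private val count = new AtomicInteger(0)
    override def newThread(r: Runnable): Thread = {
      val t = new Thread(r, s"$prefix-${count.incrementAndGet()}")
      t.setDaemon(true)
      t
    }
  }

  def main(args: Array[String]): Unit = {
    val intervalMs = 100L  // stand-in for SPECULATION_INTERVAL
    val poller =
      Executors.newScheduledThreadPool(1, namedDaemonFactory("task-scheduler-speculation"))
    // Same shape as the patched TaskSchedulerImpl: a fixed-rate check on a dedicated thread.
    poller.scheduleAtFixedRate(new Runnable {
      override def run(): Unit = println("checking for speculatable tasks")
    }, intervalMs, intervalMs, TimeUnit.MILLISECONDS)

    Thread.sleep(350)   // let a few checks run, then shut the pool down
    poller.shutdown()
  }
}
```

Using a dedicated, named daemon thread keeps the periodic check off the Akka dispatcher, which is the point of the change above.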
@@ -210,6 +193,5 @@ class MapOutputTrackerSuite extends FunSuite { // masterTracker.stop() // this throws an exception rpcEnv.stopAll() - actorSystem.shutdown() } } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index e3b0571cbe1a..9650ac280aa7 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -25,8 +25,7 @@ import org.scalatest.FunSuite class WorkerWatcherSuite extends FunSuite { test("WorkerWatcher shuts down on valid disassociation") { - val actorSystem = ActorSystem("test") - val rpcEnv = new AkkaRpcEnv(actorSystem, new SparkConf()) + val rpcEnv = AkkaRpcEnv("test", new SparkConf()) val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" val targetWorkerAddress = AddressFromURIString(targetWorkerUrl) val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) @@ -35,12 +34,10 @@ class WorkerWatcherSuite extends FunSuite { workerWatcher.onDisconnected(AkkaUtils.akkaAddressToRpcAddress(targetWorkerAddress)) assert(workerWatcher.isShutDown) rpcEnv.stopAll() - actorSystem.shutdown() } test("WorkerWatcher stays alive on invalid disassociation") { - val actorSystem = ActorSystem("test") - val rpcEnv = new AkkaRpcEnv(actorSystem, new SparkConf()) + val rpcEnv = AkkaRpcEnv("test", new SparkConf()) val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" val otherAkkaURL = "akka://test@4.3.2.1:1234/user/OtherActor" val otherAkkaAddress = AddressFromURIString(otherAkkaURL) @@ -50,6 +47,5 @@ class WorkerWatcherSuite extends FunSuite { workerWatcher.onDisconnected(AkkaUtils.akkaAddressToRpcAddress(otherAkkaAddress)) assert(!workerWatcher.isShutDown) rpcEnv.stopAll() - actorSystem.shutdown() } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 0eabd73ae2e5..7a7c4e06b35f 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -36,14 +36,12 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { override def afterAll(): Unit = { if(env != null) { - destroyRpcEnv(env) + env.stopAll() } } def createRpcEnv: RpcEnv - def destroyRpcEnv(rpcEnv: RpcEnv) - test("send a message locally") { @volatile var message: String = null val rpcEndpointRef = env.setupEndpoint("send_test", new RpcEndpoint { diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala index 1a1da58fd148..a2f9d5e10c05 100644 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -17,25 +17,14 @@ package org.apache.spark.rpc.akka -import akka.actor.ActorSystem - import org.apache.spark.rpc.{RpcEnv, RpcEnvSuite} import org.apache.spark.SparkConf class AkkaRpcEnvSuite extends RpcEnvSuite { - var akkaSystem: ActorSystem = _ - override def createRpcEnv: RpcEnv = { val conf = new SparkConf() - akkaSystem = ActorSystem("test") - new AkkaRpcEnv(akkaSystem, conf) + AkkaRpcEnv("test", conf) } - override def destroyRpcEnv(rpcEnv: RpcEnv): Unit = { - rpcEnv.stopAll() - if (akkaSystem != null) { - akkaSystem.shutdown() - } - } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala 
b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 24f41bf8cccd..9475e4150baf 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.Matchers import org.apache.spark.{LocalSparkContext, SparkContext} import org.apache.spark.executor.TaskMetrics +import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.util.ResetSystemProperties class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers with BeforeAndAfter @@ -271,7 +272,8 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers // Make a task whose result is larger than the akka frame size System.setProperty("spark.akka.frameSize", "1") val akkaFrameSize = - sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt + sc.env.rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem.settings. + config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt val result = sc.parallelize(Seq(1), 1) .map { x => 1.to(akkaFrameSize).toArray } .reduce { case (x, y) => x } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index e3a3803e6483..6578d2f39092 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Eventually._ import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv} +import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.storage.TaskResultBlockId /** @@ -86,7 +87,8 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with LocalSpark test("handling results larger than Akka frame size") { sc = new SparkContext("local", "test", conf) val akkaFrameSize = - sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt + sc.env.rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem.settings. + config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) assert(result === 1.to(akkaFrameSize).toArray) @@ -111,7 +113,8 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with LocalSpark val resultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler) scheduler.taskResultGetter = resultGetter val akkaFrameSize = - sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt + sc.env.rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem.settings. 
+ config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) assert(resultGetter.removeBlockSuccessfully) assert(result === 1.to(akkaFrameSize).toArray) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 6b9112dc2bc0..42b1190bd22f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -22,9 +22,8 @@ import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import akka.actor.ActorSystem import org.mockito.Mockito.{mock, when} -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} @@ -36,13 +35,11 @@ import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage.StorageLevel._ -import org.apache.spark.util.{AkkaUtils, SizeEstimator} /** Testsuite that tests block replication in BlockManager */ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAndAfter { private val conf = new SparkConf(false) - var actorSystem: ActorSystem = null var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null val securityMgr = new SecurityManager(conf) @@ -72,13 +69,11 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd } before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( + this.rpcEnv = AkkaRpcEnv( "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem - this.rpcEnv = new AkkaRpcEnv(actorSystem, conf) conf.set("spark.authenticate", "false") - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.boundPort.toString) conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") @@ -98,10 +93,8 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd allStores.foreach { _.stop() } allStores.clear() rpcEnv.stopAll() + rpcEnv.awaitTermination() rpcEnv = null - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null master = null } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 169fb727f1ee..aa2ac0165a92 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -19,16 +19,12 @@ package org.apache.spark.storage import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays -import java.util.concurrent.TimeUnit import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import akka.actor._ -import akka.util.Timeout - import org.mockito.Mockito.{mock, when} import org.scalatest._ @@ -54,7 +50,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach var store: BlockManager = null var store2: 
BlockManager = null var rpcEnv: RpcEnv = null - var actorSystem: ActorSystem = null var master: BlockManagerMaster = null conf.set("spark.authenticate", "false") val securityMgr = new SecurityManager(conf) @@ -80,16 +75,14 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach } override def beforeEach(): Unit = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( + this.rpcEnv = AkkaRpcEnv( "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem - this.rpcEnv = new AkkaRpcEnv(actorSystem, conf) // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case System.setProperty("os.arch", "amd64") conf.set("os.arch", "amd64") conf.set("spark.test.useCompressedOops", "true") - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.boundPort.toString) conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") @@ -112,10 +105,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach store2 = null } rpcEnv.stopAll() + rpcEnv.awaitTermination() rpcEnv = null - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null master = null } diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 678c2ccdc9de..df32658fbfee 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -37,11 +37,10 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + val rpcEnv = AkkaRpcEnv("spark", hostname, 0, conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + System.setProperty("spark.hostPort", hostname + ":" + rpcEnv.boundPort) assert(securityManager.isAuthenticationEnabled() === true) - val rpcEnv = new AkkaRpcEnv(actorSystem, conf) val masterTracker = new MapOutputTrackerMaster(conf) masterTracker.trackerActor = rpcEnv.setupEndpoint( "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, masterTracker, conf)) @@ -53,19 +52,16 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(securityManagerBad.isAuthenticationEnabled() === true) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + val slaveRpcEnv = AkkaRpcEnv("spark-slave", hostname, 0, conf = conf, securityManager = securityManagerBad) - val slaveRpcEnv = new AkkaRpcEnv(slaveSystem, conf) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker" + val selection = s"akka.tcp://spark@localhost:${rpcEnv.boundPort}/user/MapOutputTracker" intercept[akka.actor.ActorNotFound] { slaveTracker.trackerActor = slaveRpcEnv.setupEndpointRefByUrl(selection) } rpcEnv.stopAll() - actorSystem.shutdown() slaveRpcEnv.stopAll() - slaveSystem.shutdown() } test("remote fetch security off") { @@ -75,10 +71,9 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + val rpcEnv = AkkaRpcEnv("spark", hostname, 
0, conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) - val rpcEnv = new AkkaRpcEnv(actorSystem, conf) + System.setProperty("spark.hostPort", hostname + ":" + rpcEnv.boundPort) assert(securityManager.isAuthenticationEnabled() === false) @@ -91,11 +86,10 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro badconf.set("spark.authenticate.secret", "good") val securityManagerBad = new SecurityManager(badconf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + val slaveRpcEnv = AkkaRpcEnv("spark-slave", hostname, 0, conf = badconf, securityManager = securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val slaveRpcEnv = new AkkaRpcEnv(slaveSystem, conf) - val selection = s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker" + val selection = s"akka.tcp://spark@localhost:${rpcEnv.boundPort}/user/MapOutputTracker" slaveTracker.trackerActor = slaveRpcEnv.setupEndpointRefByUrl(selection) assert(securityManagerBad.isAuthenticationEnabled() === false) @@ -115,9 +109,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro Seq((BlockManagerId("a", "hostA", 1000), size1000))) rpcEnv.stopAll() - actorSystem.shutdown() slaveRpcEnv.stopAll() - slaveSystem.shutdown() } test("remote fetch security pass") { @@ -127,10 +119,9 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + val rpcEnv = AkkaRpcEnv("spark", hostname, 0, conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) - val rpcEnv = new AkkaRpcEnv(actorSystem, conf) + System.setProperty("spark.hostPort", hostname + ":" + rpcEnv.boundPort) assert(securityManager.isAuthenticationEnabled() === true) @@ -145,11 +136,10 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(securityManagerGood.isAuthenticationEnabled() === true) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + val slaveRpcEnv = AkkaRpcEnv("spark-slave", hostname, 0, conf = goodconf, securityManager = securityManagerGood) val slaveTracker = new MapOutputTrackerWorker(conf) - val slaveRpcEnv = new AkkaRpcEnv(slaveSystem, conf) - val selection = s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker" + val selection = s"akka.tcp://spark@localhost:${rpcEnv.boundPort}/user/MapOutputTracker" slaveTracker.trackerActor = slaveRpcEnv.setupEndpointRefByUrl(selection) masterTracker.registerShuffle(10, 1) @@ -167,9 +157,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro Seq((BlockManagerId("a", "hostA", 1000), size1000))) rpcEnv.stopAll() - actorSystem.shutdown() slaveRpcEnv.stopAll() - slaveSystem.shutdown() } test("remote fetch security off client") { @@ -180,10 +168,9 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + val rpcEnv = AkkaRpcEnv("spark", hostname, 0, conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) - val rpcEnv = new AkkaRpcEnv(actorSystem, conf) + System.setProperty("spark.hostPort", 
hostname + ":" + rpcEnv.boundPort) assert(securityManager.isAuthenticationEnabled() === true) @@ -198,19 +185,16 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(securityManagerBad.isAuthenticationEnabled() === false) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + val slaveRpcEnv = AkkaRpcEnv("spark-slave", hostname, 0, conf = badconf, securityManager = securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val slaveRpcEnv = new AkkaRpcEnv(slaveSystem, conf) - val selection = s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker" + val selection = s"akka.tcp://spark@localhost:${rpcEnv.boundPort}/user/MapOutputTracker" intercept[akka.actor.ActorNotFound] { slaveTracker.trackerActor = slaveRpcEnv.setupEndpointRefByUrl(selection) } rpcEnv.stopAll() - actorSystem.shutdown() slaveRpcEnv.stopAll() - slaveSystem.shutdown() } } From afe39977d5c197bfd64fee67b6f485d07c6e246b Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 6 Jan 2015 20:10:27 +0800 Subject: [PATCH 31/36] Make AkkaRpcEnv pluggable --- .../scala/org/apache/spark/SparkEnv.scala | 4 +- .../org/apache/spark/deploy/Client.scala | 3 +- .../spark/deploy/client/TestClient.scala | 4 +- .../apache/spark/deploy/master/Master.scala | 2 +- .../spark/deploy/worker/DriverWrapper.scala | 6 +-- .../apache/spark/deploy/worker/Worker.scala | 3 +- .../CoarseGrainedExecutorBackend.scala | 5 +- .../scala/org/apache/spark/rpc/RpcEnv.scala | 46 +++++++++++++++++-- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 2 +- .../apache/spark/MapOutputTrackerSuite.scala | 20 ++++---- .../deploy/worker/WorkerWatcherSuite.scala | 8 ++-- .../BlockManagerReplicationSuite.scala | 3 +- .../spark/storage/BlockManagerSuite.scala | 3 +- .../apache/spark/util/AkkaUtilsSuite.scala | 18 ++++---- 14 files changed, 80 insertions(+), 47 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 95d69f8efe44..b05ccce42e61 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -33,8 +33,8 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService -import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv} import org.apache.spark.rpc.akka.AkkaRpcEnv +import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv} import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} @@ -217,7 +217,7 @@ object SparkEnv extends Logging { // Create the ActorSystem for Akka and get the port it binds to. val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName - val rpcEnv = AkkaRpcEnv(actorSystemName, hostname, port, conf, securityManager) + val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager) // Figure out which port Akka actually bound to in case the original port is 0 or occupied. 
if (isDriver) { diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 43a5f17d7a34..b55a971153ac 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -22,7 +22,6 @@ import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} -import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.rpc.{RpcAddress, RpcEnv, RpcEndpointRef, NetworkRpcEndpoint} import org.apache.spark.util.Utils @@ -152,7 +151,7 @@ object Client { conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING")) Logger.getRootLogger.setLevel(driverArgs.logLevel) - val rpcEnv = AkkaRpcEnv( + val rpcEnv = RpcEnv.create( "driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) rpcEnv.setupEndpoint("client-actor", new ClientActor(rpcEnv, driverArgs, conf)) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index 07ccfe4af353..a8dd9086618c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.client -import org.apache.spark.rpc.akka.AkkaRpcEnv +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.util.Utils @@ -47,7 +47,7 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) val conf = new SparkConf - val rpcEnv = AkkaRpcEnv("spark", Utils.localIpAddress, 0, + val rpcEnv = RpcEnv.create("spark", Utils.localIpAddress, 0, conf = conf, securityManager = new SecurityManager(conf)) val desc = new ApplicationDescription("TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), Seq(), Seq()), "ignored") diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 3261cf679ef0..6f2f59fd1709 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -875,7 +875,7 @@ private[spark] object Master extends Logging { webUiPort: Int, conf: SparkConf): (RpcEnv, Int) = { val securityMgr = new SecurityManager(conf) - val rpcEnv = AkkaRpcEnv(systemName, host, port, conf = conf, securityManager = securityMgr) + val rpcEnv = RpcEnv.create(systemName, host, port, conf = conf, securityManager = securityMgr) val actor = rpcEnv.setupEndpoint(actorName, new Master(rpcEnv, host, rpcEnv.boundPort, webUiPort, securityMgr)) val resp = actor.askWithReply[WebUIPortResponse](RequestWebUIPort) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index b5bd626d0957..dfcd34e51a5e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -17,10 +17,10 @@ package org.apache.spark.deploy.worker -import org.apache.spark.rpc.akka.AkkaRpcEnv +import org.apache.spark.rpc.RpcEnv import 
org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils /** * Utility object for launching driver programs such that they share fate with the Worker process. @@ -30,7 +30,7 @@ object DriverWrapper { args.toList match { case workerUrl :: mainClass :: extraArgs => val conf = new SparkConf() - val rpcEnv = AkkaRpcEnv("Driver", + val rpcEnv = RpcEnv.create("Driver", Utils.localHostName(), 0, conf, new SecurityManager(conf)) rpcEnv.setupEndpoint("workerWatcher", new WorkerWatcher(rpcEnv, workerUrl)) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index da908da283b9..54d0fb21f397 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -36,7 +36,6 @@ import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.rpc.{RpcAddress, RpcEnv, RpcEndpointRef, NetworkRpcEndpoint} import org.apache.spark.util.{SignalLogger, Utils} @@ -542,7 +541,7 @@ private[spark] object Worker extends Logging { val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") val actorName = "Worker" val securityMgr = new SecurityManager(conf) - val rpcEnv = AkkaRpcEnv(systemName, host, port, conf = conf, securityManager = securityMgr) + val rpcEnv = RpcEnv.create(systemName, host, port, conf = conf, securityManager = securityMgr) rpcEnv.setupEndpoint(actorName, new Worker(rpcEnv, host, rpcEnv.boundPort, webUiPort, cores, memory, masterUrls, systemName, actorName, workDir, conf, securityMgr)) rpcEnv diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 8a60ea630f3c..15b2595e9aae 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -23,8 +23,7 @@ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher -import org.apache.spark.rpc.akka.AkkaRpcEnv -import org.apache.spark.rpc.{RpcAddress, NetworkRpcEndpoint, RpcEndpointRef} +import org.apache.spark.rpc.{RpcEnv, RpcAddress, NetworkRpcEndpoint, RpcEndpointRef} import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{SignalLogger, Utils} @@ -116,7 +115,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Bootstrap to fetch the driver's Spark properties. 
val executorConf = new SparkConf val port = executorConf.getInt("spark.executor.port", 0) - val rpcEnv = AkkaRpcEnv( + val rpcEnv = RpcEnv.create( "driverPropsFetcher", hostname, port, executorConf, new SecurityManager(executorConf)) val driver = rpcEnv.setupEndpointRefByUrl(driverUrl) val props = driver.askWithReply[Seq[(String, String)]](RetrieveSparkProps) ++ diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 8fad9e006e47..5fff9d43555d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -19,12 +19,16 @@ package org.apache.spark.rpc import java.util.concurrent.ConcurrentHashMap -import org.apache.spark.deploy.master.Master - import scala.concurrent.Future -import scala.concurrent.duration.{FiniteDuration, Duration} +import scala.concurrent.duration.FiniteDuration import scala.reflect.ClassTag +import com.google.common.annotations.VisibleForTesting + +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.deploy.master.Master +import org.apache.spark.util.Utils + /** * An RPC environment. */ @@ -74,6 +78,42 @@ trait RpcEnv { def awaitTermination(): Unit } +object RpcEnv { + + private def getRpcEnvCompanion(conf: SparkConf): AnyRef = { + val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnv") + val rpcEnvName = conf.get("spark.rpc", "akka") + val rpcEnvClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) + val companion = Class.forName( + rpcEnvClassName + "$", true, Utils.getContextOrSparkClassLoader).getField("MODULE$").get(null) + companion + } + + def create( + name: String, + host: String, + port: Int, + conf: SparkConf, + securityManager: SecurityManager): RpcEnv = { + val companion = getRpcEnvCompanion(conf) + companion.getClass.getMethod("apply", + classOf[String], + classOf[String], + java.lang.Integer.TYPE, + classOf[SparkConf], + classOf[SecurityManager]). + invoke(companion, name, host, port: java.lang.Integer, conf, securityManager). + asInstanceOf[RpcEnv] + } + + @VisibleForTesting + def create(name: String, conf: SparkConf): RpcEnv = { + val companion = getRpcEnvCompanion(conf) + companion.getClass.getMethod("apply", classOf[String], classOf[SparkConf]). + invoke(companion, name, conf).asInstanceOf[RpcEnv] + } + +} /** * An end point for the RPC that defines what functions to trigger given a message. 
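The RpcEnv.create factory introduced above locates an implementation's companion object through reflection: the spark.rpc setting is mapped to a class name, the compiler-generated ClassName$ singleton is loaded, and its apply method is invoked on the MODULE$ instance. A self-contained sketch of that companion-object lookup, using hypothetical names (Greeter, EnglishGreeter) instead of Spark classes:

```scala
// Hypothetical trait and companion, used only to illustrate the MODULE$ lookup.
trait Greeter { def greet(name: String): String }

object EnglishGreeter {
  def apply(prefix: String): Greeter = new Greeter {
    override def greet(name: String): String = s"$prefix, $name"
  }
}

object CompanionLookupSketch {
  def main(args: Array[String]): Unit = {
    // A top-level Scala `object Foo` compiles to class "Foo$", whose static MODULE$ field
    // holds the singleton instance.
    val companion = Class.forName("EnglishGreeter$").getField("MODULE$").get(null)
    val greeter = companion.getClass
      .getMethod("apply", classOf[String])
      .invoke(companion, "Hello")
      .asInstanceOf[Greeter]
    println(greeter.greet("Spark"))   // prints: Hello, Spark
  }
}
```

The same trick is what lets an alternative RpcEnv be selected purely by configuration, without a compile-time dependency from core on the Akka-based implementation.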
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index c97bd82c272f..08f6342f0656 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -138,7 +138,7 @@ class AkkaRpcEnv private (val actorSystem: ActorSystem, conf: SparkConf, val bou override def toString = s"${getClass.getSimpleName}($actorSystem)" } -object AkkaRpcEnv { +private[rpc] object AkkaRpcEnv { def apply( name: String, diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 7df1717ce67b..3e2577ff1672 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -17,12 +17,10 @@ package org.apache.spark -import org.apache.spark.rpc.akka.AkkaRpcEnv - import org.mockito.Mockito._ import org.scalatest.FunSuite -import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef} +import org.apache.spark.rpc.{RpcEnv, RpcEndpoint, RpcEndpointRef} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.BlockManagerId @@ -31,7 +29,7 @@ class MapOutputTrackerSuite extends FunSuite { private val conf = new SparkConf test("master start and stop") { - val rpcEnv = AkkaRpcEnv("test", conf) + val rpcEnv = RpcEnv.create("test", conf) val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = rpcEnv.setupEndpoint( "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, tracker, conf)) @@ -40,7 +38,7 @@ class MapOutputTrackerSuite extends FunSuite { } test("master register shuffle and fetch") { - val rpcEnv = AkkaRpcEnv("test", conf) + val rpcEnv = RpcEnv.create("test", conf) val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = rpcEnv.setupEndpoint( "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, tracker, conf)) @@ -60,7 +58,7 @@ class MapOutputTrackerSuite extends FunSuite { } test("master register and unregister shuffle") { - val rpcEnv = AkkaRpcEnv("test", conf) + val rpcEnv = RpcEnv.create("test", conf) val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = rpcEnv.setupEndpoint( "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, tracker, conf)) @@ -82,7 +80,7 @@ class MapOutputTrackerSuite extends FunSuite { } test("master register shuffle and unregister map output and fetch") { - val rpcEnv = AkkaRpcEnv("test", conf) + val rpcEnv = RpcEnv.create("test", conf) val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = rpcEnv.setupEndpoint( "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, tracker, conf)) @@ -109,14 +107,14 @@ class MapOutputTrackerSuite extends FunSuite { test("remote fetch") { val hostname = "localhost" - val rpcEnv = AkkaRpcEnv("spark", hostname, 0, conf = conf, + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf = conf, securityManager = new SecurityManager(conf)) val masterTracker = new MapOutputTrackerMaster(conf) masterTracker.trackerActor = rpcEnv.setupEndpoint( "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, masterTracker, conf)) - val slaveRpcEnv = AkkaRpcEnv("spark-slave", hostname, 0, conf = conf, + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, conf = conf, securityManager = new SecurityManager(conf)) val slaveTracker = new 
MapOutputTrackerWorker(conf) val selection = s"akka.tcp://spark@localhost:${rpcEnv.boundPort}/user/MapOutputTracker" @@ -155,7 +153,7 @@ class MapOutputTrackerSuite extends FunSuite { newConf.set("spark.akka.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) - val rpcEnv = AkkaRpcEnv("test", conf) + val rpcEnv = RpcEnv.create("test", conf) val masterActor = new MapOutputTrackerMasterActor(rpcEnv, masterTracker, newConf) rpcEnv.setupEndpoint("MapOutputTracker", masterActor) @@ -176,7 +174,7 @@ class MapOutputTrackerSuite extends FunSuite { newConf.set("spark.akka.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) - val rpcEnv = AkkaRpcEnv("test", conf) + val rpcEnv = RpcEnv.create("test", conf) val masterActor = new MapOutputTrackerMasterActor(rpcEnv, masterTracker, newConf) // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index 9650ac280aa7..8d5cb38ef729 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -17,15 +17,15 @@ package org.apache.spark.deploy.worker -import akka.actor.{ActorSystem, AddressFromURIString} +import akka.actor.AddressFromURIString import org.apache.spark.SparkConf -import org.apache.spark.rpc.akka.AkkaRpcEnv +import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.AkkaUtils import org.scalatest.FunSuite class WorkerWatcherSuite extends FunSuite { test("WorkerWatcher shuts down on valid disassociation") { - val rpcEnv = AkkaRpcEnv("test", new SparkConf()) + val rpcEnv = RpcEnv.create("test", new SparkConf()) val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" val targetWorkerAddress = AddressFromURIString(targetWorkerUrl) val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) @@ -37,7 +37,7 @@ class WorkerWatcherSuite extends FunSuite { } test("WorkerWatcher stays alive on invalid disassociation") { - val rpcEnv = AkkaRpcEnv("test", new SparkConf()) + val rpcEnv = RpcEnv.create("test", new SparkConf()) val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" val otherAkkaURL = "akka://test@4.3.2.1:1234/user/OtherActor" val otherAkkaAddress = AddressFromURIString(otherAkkaURL) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 42b1190bd22f..efc796f4ab16 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -30,7 +30,6 @@ import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, Securi import org.apache.spark.network.BlockTransferService import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.rpc.RpcEnv -import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager @@ -69,7 +68,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd } before { - this.rpcEnv = AkkaRpcEnv( + this.rpcEnv = RpcEnv.create( "test", "localhost", 0, conf = conf, securityManager = securityMgr) 
conf.set("spark.authenticate", "false") diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index aa2ac0165a92..d736f2eb29d0 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -34,7 +34,6 @@ import org.scalatest.concurrent.Timeouts._ import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} import org.apache.spark.executor.DataReadMethod import org.apache.spark.network.nio.NioBlockTransferService -import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} @@ -75,7 +74,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach } override def beforeEach(): Unit = { - this.rpcEnv = AkkaRpcEnv( + this.rpcEnv = RpcEnv.create( "test", "localhost", 0, conf = conf, securityManager = securityMgr) // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index df32658fbfee..bbad58c11627 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.util +import org.apache.spark.rpc.RpcEnv import org.scalatest.FunSuite import org.apache.spark._ -import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage.BlockManagerId @@ -37,7 +37,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val rpcEnv = AkkaRpcEnv("spark", hostname, 0, + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf = conf, securityManager = securityManager) System.setProperty("spark.hostPort", hostname + ":" + rpcEnv.boundPort) assert(securityManager.isAuthenticationEnabled() === true) @@ -52,7 +52,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(securityManagerBad.isAuthenticationEnabled() === true) - val slaveRpcEnv = AkkaRpcEnv("spark-slave", hostname, 0, + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, conf = conf, securityManager = securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) val selection = s"akka.tcp://spark@localhost:${rpcEnv.boundPort}/user/MapOutputTracker" @@ -71,7 +71,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val rpcEnv = AkkaRpcEnv("spark", hostname, 0, + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf = conf, securityManager = securityManager) System.setProperty("spark.hostPort", hostname + ":" + rpcEnv.boundPort) @@ -86,7 +86,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro badconf.set("spark.authenticate.secret", "good") val securityManagerBad = new SecurityManager(badconf) - val slaveRpcEnv = AkkaRpcEnv("spark-slave", hostname, 0, + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, conf = badconf, securityManager = securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) val selection = 
s"akka.tcp://spark@localhost:${rpcEnv.boundPort}/user/MapOutputTracker" @@ -119,7 +119,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val rpcEnv = AkkaRpcEnv("spark", hostname, 0, + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf = conf, securityManager = securityManager) System.setProperty("spark.hostPort", hostname + ":" + rpcEnv.boundPort) @@ -136,7 +136,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(securityManagerGood.isAuthenticationEnabled() === true) - val slaveRpcEnv = AkkaRpcEnv("spark-slave", hostname, 0, + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, conf = goodconf, securityManager = securityManagerGood) val slaveTracker = new MapOutputTrackerWorker(conf) val selection = s"akka.tcp://spark@localhost:${rpcEnv.boundPort}/user/MapOutputTracker" @@ -168,7 +168,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val rpcEnv = AkkaRpcEnv("spark", hostname, 0, + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf = conf, securityManager = securityManager) System.setProperty("spark.hostPort", hostname + ":" + rpcEnv.boundPort) @@ -185,7 +185,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(securityManagerBad.isAuthenticationEnabled() === false) - val slaveRpcEnv = AkkaRpcEnv("spark-slave", hostname, 0, + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, conf = badconf, securityManager = securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) val selection = s"akka.tcp://spark@localhost:${rpcEnv.boundPort}/user/MapOutputTracker" From 952d468e954f182ee8741523a415ad122cc1aef3 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 7 Jan 2015 12:23:10 +0800 Subject: [PATCH 32/36] Add comments and minor interface changes --- .../scala/org/apache/spark/rpc/RpcEnv.scala | 180 +++++++++++++++--- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 89 +++++---- .../org/apache/spark/util/AkkaUtils.scala | 2 +- 3 files changed, 201 insertions(+), 70 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 5fff9d43555d..cf9d84abb5cc 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -26,7 +26,6 @@ import scala.reflect.ClassTag import com.google.common.annotations.VisibleForTesting import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.deploy.master.Master import org.apache.spark.util.Utils /** @@ -35,7 +34,8 @@ import org.apache.spark.util.Utils trait RpcEnv { /** - * Need this map to set up the `sender` for the send method. + * A lookup table to search a [[RpcEndpointRef]] for a [[RpcEndpoint]]. We need it to make + * [[RpcEndpoint.self]] work. 
*/ private val endpointToRef = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]() @@ -44,10 +44,9 @@ trait RpcEnv { */ private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]() - protected def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = { - refToEndpoint.put(endpointRef, endpoint) endpointToRef.put(endpoint, endpointRef) + refToEndpoint.put(endpointRef, endpoint) } protected def unregisterEndpoint(endpointRef: RpcEndpointRef): Unit = { @@ -57,30 +56,77 @@ trait RpcEnv { } } + /** + * Retrieve the [[RpcEndpointRef]] of `endpoint`. + */ def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = { val endpointRef = endpointToRef.get(endpoint) require(endpointRef != null, s"Cannot find RpcEndpointRef of ${endpoint} in ${this}") endpointRef } + /** + * Return the port that [[RpcEnv]] is listening to. + */ def boundPort: Int + /** + * Register a [[RpcEndpoint]] with a name and return its [[RpcEndpointRef]]. + */ def setupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef + /** + * Retrieve a [[RpcEndpointRef]] which is located in the driver via its name. + */ def setupDriverEndpointRef(name: String): RpcEndpointRef + /** + * Retrieve the [[RpcEndpointRef]] represented by `url`. + */ def setupEndpointRefByUrl(url: String): RpcEndpointRef + /** + * Stop [[RpcEndpoint]] specified by `endpoint`. + */ def stop(endpoint: RpcEndpointRef): Unit + /** + * Shutdown this [[RpcEnv]] asynchronously. If need to make sure [[RpcEnv]] exits successfully, + * call [[awaitTermination()]] straight after [[stopAll()]]. + */ def stopAll(): Unit + /** + * Wait until [[RpcEnv]] exits. + * + * TODO do we need a timeout parameter? + */ def awaitTermination(): Unit } +private[rpc] case class RpcEnvConfig( + conf: SparkConf, + name: String, + host: String, + port: Int, + securityManager: SecurityManager) + +/** + * A RpcEnv implementation must have a companion object with an + * `apply(config: RpcEnvConfig): RpcEnv` method so that it can be created via Reflection. + * + * {{{ + * object MyCustomRpcEnv { + * def apply(config: RpcEnvConfig): RpcEnv = { + * ... + * } + * } + * }}} + */ object RpcEnv { private def getRpcEnvCompanion(conf: SparkConf): AnyRef = { + // Add more RpcEnv implementations here val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnv") val rpcEnvName = conf.get("spark.rpc", "akka") val rpcEnvClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) @@ -95,17 +141,14 @@ object RpcEnv { port: Int, conf: SparkConf, securityManager: SecurityManager): RpcEnv = { + // Using Reflection to create the RpcEnv to avoid to depend on Akka directly + val config = RpcEnvConfig(conf, name, host, port, securityManager) val companion = getRpcEnvCompanion(conf) - companion.getClass.getMethod("apply", - classOf[String], - classOf[String], - java.lang.Integer.TYPE, - classOf[SparkConf], - classOf[SecurityManager]). - invoke(companion, name, host, port: java.lang.Integer, conf, securityManager). - asInstanceOf[RpcEnv] + companion.getClass.getMethod("apply", classOf[RpcEnvConfig]). + invoke(companion, config).asInstanceOf[RpcEnv] } + // TODO Remove it @VisibleForTesting def create(name: String, conf: SparkConf): RpcEnv = { val companion = getRpcEnvCompanion(conf) @@ -118,29 +161,36 @@ object RpcEnv { /** * An end point for the RPC that defines what functions to trigger given a message. 
* - * RpcEndpoint will be guaranteed that `preStart`, `receive` and `remoteConnectionTerminated` will + * RpcEndpoint will be guaranteed that `onStart`, `receive` and `onStop` will * be called in sequence. * - * Happen before relation: + * The life-cycle will be: * - * constructor preStart receive* remoteConnectionTerminated + * constructor onStart receive* onStop * - * ?? Need to guarantee that no message will be delivered after remoteConnectionTerminated ?? + * If any error is thrown from one of the RpcEndpoint methods except `onError`, [[RpcEndpoint.onError]] + * will be invoked with the cause. If onError throws an error, it will force [[RpcEndpoint]] to + * restart by creating a new one. */ trait RpcEndpoint { + /** + * The [[RpcEnv]] that this [[RpcEndpoint]] is registered to. + */ val rpcEnv: RpcEnv /** - * Provide the implicit sender. `self` will become valid when `preStart` is called. + * Provide the implicit sender. `self` will become valid when `onStart` is called. + * + * Note: Before `onStart` is called, [[RpcEndpoint]] has not yet been registered and there is no + * valid [[RpcEndpointRef]] for it, so don't call `self` before `onStart` is called. In other + * words, don't call [[RpcEndpointRef.send]] in the constructor of [[RpcEndpoint]]. */ implicit final def self: RpcEndpointRef = { require(rpcEnv != null, "rpcEnv has not been initialized") rpcEnv.endpointRef(this) } - def onStart(): Unit = {} - /** * Same assumption like Actor: messages sent to a RpcEndpoint will be delivered in sequence, and * messages from the same RpcEndpoint will be delivered in order. @@ -160,8 +210,23 @@ trait RpcEndpoint { throw cause } - def onStop(): Unit = {} + /** + * Invoked before [[RpcEndpoint]] starts to handle any message. + */ + def onStart(): Unit = { + // By default, do nothing. + } + /** + * Invoked when [[RpcEndpoint]] is stopping. + */ + def onStop(): Unit = { + // By default, do nothing. + } + + /** + * A convenient method to stop [[RpcEndpoint]]. + */ final def stop(): Unit = { rpcEnv.stop(self) } @@ -169,19 +234,40 @@ /** * A RpcEndoint interested in network events. + * + * [[NetworkRpcEndpoint]] will be guaranteed that `onStart`, `receive`, `onConnected`, + * `onDisconnected`, `onNetworkError` and `onStop` will be called in sequence. + * + * The life-cycle will be: + * + * constructor onStart (receive|onConnected|onDisconnected|onNetworkError)* onStop + * + * If any error is thrown from `onConnected`, `onDisconnected` or `onNetworkError`, + * [[RpcEndpoint.onError]] will be invoked with the cause. If onError throws an error, it will + * force [[RpcEndpoint]] to restart by creating a new one. */ trait NetworkRpcEndpoint extends RpcEndpoint { + /** + * Invoked when `remoteAddress` is connected to the current node. + */ def onConnected(remoteAddress: RpcAddress): Unit = { // By default, do nothing. } + /** + * Invoked when `remoteAddress` is lost. + */ def onDisconnected(remoteAddress: RpcAddress): Unit = { // By default, do nothing. } + /** + * Invoked when some network error happens in the connection between the current node and + * `remoteAddress`. + */ def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { - // By default, throw e and let RpcEnv handle it + // By default, do nothing. } } @@ -190,28 +276,71 @@ object RpcEndpoint { } /** - * A reference for a remote [[RpcEndpoint]]. + * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe.
*/ trait RpcEndpointRef { - Master + /** + * Return the address for the [[RpcEndpointRef]] + */ def address: RpcAddress + /** + * Send a message to the corresponding [[RpcEndpoint]] and return a `Future` to receive the reply + * within a default timeout. + */ def ask[T: ClassTag](message: Any): Future[T] + /** + * Send a message to the corresponding [[RpcEndpoint]] and return a `Future` to receive the reply + * within the specified timeout. + */ def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] + /** + * Send a message to the corresponding [[RpcEndpoint]] and get its result within a default + * timeout, or throw a SparkException if this fails even after the default number of retries. + * + * Note: this is a blocking action which may take a long time, so don't call it in a message + * loop of [[RpcEndpoint]]. + * + * @param message the message to send + * @tparam T type of the reply message + * @return the reply message from the corresponding [[RpcEndpoint]] + */ def askWithReply[T](message: Any): T + /** + * Send a message to the corresponding [[RpcEndpoint]] and get its result within a specified + * timeout, or throw a SparkException if this fails even after the specified number of retries. + * + * Note: this is a blocking action which may take a long time, so don't call it in a message + * loop of [[RpcEndpoint]]. + * + * @param message the message to send + * @param timeout the timeout duration + * @tparam T type of the reply message + * @return the reply message from the corresponding [[RpcEndpoint]] + */ def askWithReply[T](message: Any, timeout: FiniteDuration): T /** - * Send a message to the remote endpoint asynchronously. No delivery guarantee is provided. + * Sends a one-way asynchronous message. Fire-and-forget semantics. + * + * If invoked from within an [[RpcEndpoint]], then `self` is implicitly passed on as the + * 'sender' argument. If not, then no sender is available. + * + * This `sender` reference is then available in the receiving [[RpcEndpoint]] as the `sender` + * parameter of [[RpcEndpoint.receive]]. */ def send(message: Any)(implicit sender: RpcEndpointRef = RpcEndpoint.noSender): Unit } +/** + * Represents a host with a port. + */ case class RpcAddress(host: String, port: Int) { + // TODO do we need to add the type of RpcEnv in the address? val hostPort: String = host + ":" + port @@ -220,6 +349,9 @@ object RpcAddress { + /** + * Return the [[RpcAddress]] represented by `uri`.
+ */ def fromURIString(uri: String): RpcAddress = { val u = new java.net.URI(uri) RpcAddress(u.getHost, u.getPort) diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 08f6342f0656..63d65d07c693 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -19,25 +19,33 @@ package org.apache.spark.rpc.akka import java.util.concurrent.CountDownLatch -import com.google.common.annotations.VisibleForTesting - import scala.concurrent.Await import scala.concurrent.duration._ import scala.concurrent.Future import scala.language.postfixOps +import scala.reflect.ClassTag +import scala.util.control.NonFatal import akka.actor.{ActorRef, Actor, Props, ActorSystem} import akka.pattern.{ask => akkaAsk} import akka.remote._ +import com.google.common.annotations.VisibleForTesting -import org.apache.spark.{SecurityManager, Logging, SparkException, SparkConf} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.rpc._ import org.apache.spark.util.{ActorLogReceive, AkkaUtils} -import scala.reflect.ClassTag -import scala.util.control.NonFatal - -class AkkaRpcEnv private (val actorSystem: ActorSystem, conf: SparkConf, val boundPort: Int) +/** + * A RpcEnv implementation based on Akka. + * + * TODO Once we remove all usages of Akka in other place, we can move this file to a new project and + * remove Akka from the dependencies. + * + * @param actorSystem + * @param conf + * @param boundPort + */ +private[spark] class AkkaRpcEnv private (val actorSystem: ActorSystem, conf: SparkConf, val boundPort: Int) extends RpcEnv { override def setupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = { @@ -47,20 +55,23 @@ class AkkaRpcEnv private (val actorSystem: ActorSystem, conf: SparkConf, val bou val actorRef = actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { val endpoint = endpointCreator + // Wait until `endpointRef` is set. TODO better solution? latch.await() require(endpointRef != null) registerEndpoint(endpoint, endpointRef) + var isNetworkRpcEndpoint = false + override def preStart(): Unit = { - endpoint.onStart() if (endpoint.isInstanceOf[NetworkRpcEndpoint]) { - // Listen for remote client disconnection events, - // since they don't go through Akka's watch() + isNetworkRpcEndpoint = true + // Listen for remote client network events only when it's `NetworkRpcEndpoint` context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) } + endpoint.onStart() } - override def receiveWithLogging: Receive = { + override def receiveWithLogging: Receive = if (isNetworkRpcEndpoint) { case AssociatedEvent(_, remoteAddress, _) => try { endpoint.asInstanceOf[NetworkRpcEndpoint]. @@ -85,11 +96,22 @@ class AkkaRpcEnv private (val actorSystem: ActorSystem, conf: SparkConf, val bou case NonFatal(e) => endpoint.onError(e) } case e: RemotingLifecycleEvent => - // ignore? + // TODO ignore? 
case message: Any => + logDebug("Received RPC message: " + message) + try { + val pf = endpoint.receive(new AkkaRpcEndpointRef(sender(), conf)) + if (pf.isDefinedAt(message)) { + pf.apply(message) + } + } catch { + case NonFatal(e) => endpoint.onError(e) + } + } else { + case message: Any => + logDebug("Received RPC message: " + message) try { - logInfo("Received RPC message: " + message) val pf = endpoint.receive(new AkkaRpcEndpointRef(sender(), conf)) if (pf.isDefinedAt(message)) { pf.apply(message) @@ -140,17 +162,13 @@ class AkkaRpcEnv private (val actorSystem: ActorSystem, conf: SparkConf, val bou private[rpc] object AkkaRpcEnv { - def apply( - name: String, - host: String, - port: Int, - conf: SparkConf, - securityManager: SecurityManager): AkkaRpcEnv = { - val (actorSystem, boundPort) = - AkkaUtils.createActorSystem(name, host, port, conf, securityManager) - new AkkaRpcEnv(actorSystem, conf, boundPort) + def apply(config: RpcEnvConfig): RpcEnv = { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem( + config.name, config.host, config.port, config.conf, config.securityManager) + new AkkaRpcEnv(actorSystem, config.conf, boundPort) } + // TODO Remove it @VisibleForTesting def apply(name: String, conf: SparkConf): AkkaRpcEnv = { new AkkaRpcEnv(ActorSystem(name), conf, -1) @@ -159,9 +177,10 @@ private[rpc] object AkkaRpcEnv { private[akka] class AkkaRpcEndpointRef(val actorRef: ActorRef, @transient conf: SparkConf) extends RpcEndpointRef with Serializable with Logging { + // `conf` won't be used after initialization. So it's safe to be transient. private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3) - private[this] val retryWaitMs = conf.getInt("spark.akka.retry.wait", 3000) + private[this] val retryWaitMs = conf.getLong("spark.akka.retry.wait", 3000) private[this] val defaultTimeout = conf.getLong("spark.akka.lookupTimeout", 30) seconds override val address: RpcAddress = AkkaUtils.akkaAddressToRpcAddress(actorRef.path.address) @@ -175,28 +194,8 @@ private[akka] class AkkaRpcEndpointRef(val actorRef: ActorRef, @transient conf: override def askWithReply[T](message: Any): T = askWithReply(message, defaultTimeout) override def askWithReply[T](message: Any, timeout: FiniteDuration): T = { - var attempts = 0 - var lastException: Exception = null - while (attempts < maxRetries) { - attempts += 1 - try { - val future = actorRef.ask(message)(timeout) - val result = Await.result(future, timeout) - if (result == null) { - throw new SparkException("Actor returned null") - } - return result.asInstanceOf[T] - } catch { - case ie: InterruptedException => throw ie - case e: Exception => - lastException = e - logWarning("Error sending message in " + attempts + " attempts", e) - } - Thread.sleep(retryWaitMs) - } - - throw new SparkException( - "Error sending message [message = " + message + "]", lastException) + // TODO: Consider removing multiple attempts + AkkaUtils.askWithReply(message, actorRef, maxRetries, retryWaitMs, timeout) } override def send(message: Any)(implicit sender: RpcEndpointRef = RpcEndpoint.noSender): Unit = { diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 86b67de98ba4..4bb406e8999a 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -180,7 +180,7 @@ private[spark] object AkkaUtils extends Logging { message: Any, actor: ActorRef, maxAttempts: Int, - retryInterval: Int, + 
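`AkkaRpcEndpointRef` above drives its blocking `askWithReply` path from three configuration keys, which it now hands off to `AkkaUtils.askWithReply`. A small sketch of setting them explicitly, using the default values that appear in this patch:

```scala
import org.apache.spark.SparkConf

// Keys read by AkkaRpcEndpointRef; the values shown are the patch defaults.
val conf = new SparkConf()
  .set("spark.akka.num.retries", "3")     // attempts before askWithReply gives up
  .set("spark.akka.retry.wait", "3000")   // milliseconds to sleep between attempts
  .set("spark.akka.lookupTimeout", "30")  // seconds; the default ask timeout
```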
retryInterval: Long, timeout: FiniteDuration): T = { // TODO: Consider removing multiple attempts if (actor == null) { From d8687ba735c45b58882442c4770f2826f50e8a6c Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 8 Jan 2015 10:25:26 +0800 Subject: [PATCH 33/36] Add a type parameter to `askTracker` --- .../main/scala/org/apache/spark/MapOutputTracker.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 6d32c3231784..e874d7622833 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -42,7 +42,8 @@ private[spark] class MapOutputTrackerMasterActor(override val rpcEnv: RpcEnv, override def receive(sender: RpcEndpointRef) = { case GetMapOutputStatuses(shuffleId: Int) => - logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + sender) + logInfo( + "Asked to send map output locations for shuffle " + shuffleId + " to " + sender.address) val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId) val serializedSize = mapOutputStatuses.size if (serializedSize > maxAkkaFrameSize) { @@ -101,7 +102,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * Send a message to the trackerActor and get its result within a default timeout, or * throw a SparkException if this fails. */ - protected def askTracker(message: Any): Any = { + protected def askTracker[T](message: Any): T = { try { trackerActor.askWithReply(message) } catch { @@ -113,7 +114,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging /** Send a one-way message to the trackerActor, to which we expect it to reply with true. */ protected def sendTracker(message: Any) { - val response = askTracker(message) + val response = askTracker[Boolean](message) if (response != true) { throw new SparkException( "Error reply received from MapOutputTracker. 
Expecting true, got " + response.toString) @@ -153,8 +154,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging logInfo("Doing the fetch; tracker actor = " + trackerActor) // This try-finally prevents hangs due to timeouts: try { - val fetchedBytes = - askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]] + val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) logInfo("Got the output locations") mapStatuses.put(shuffleId, fetchedStatuses) From 69380934edcec9e513cfb5fc0632ac660a630e7d Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 8 Jan 2015 20:48:36 +0800 Subject: [PATCH 34/36] Revert the network changes since they are not ready to review --- .../spark/network/rpc/SimpleRpcServer.scala | 103 ------------------ .../spark/network/client/TransportClient.java | 2 +- .../network/protocol/MessageDecoder.java | 1 - .../apache/spark/network/util/JavaUtils.java | 26 ----- 4 files changed, 1 insertion(+), 131 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/network/rpc/SimpleRpcServer.scala diff --git a/core/src/main/scala/org/apache/spark/network/rpc/SimpleRpcServer.scala b/core/src/main/scala/org/apache/spark/network/rpc/SimpleRpcServer.scala deleted file mode 100644 index 10bbeb032e89..000000000000 --- a/core/src/main/scala/org/apache/spark/network/rpc/SimpleRpcServer.scala +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.rpc - -import java.nio.ByteBuffer -import org.apache.spark.SparkConf -import org.apache.spark.network.TransportContext -import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} -import org.apache.spark.network.netty.SparkTransportConf -import org.apache.spark.network.server._ -import org.apache.spark.network.util.JavaUtils -import org.slf4j.Logger - - -class SimpleRpcClient(conf: SparkConf) { - private val transportConf = SparkTransportConf.fromSparkConf(conf, 1) - val transportContext = new TransportContext(transportConf, new RpcHandler { - override def getStreamManager: StreamManager = new OneForOneStreamManager - - override def receive( - client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = { - println("gotten some message " + JavaUtils.bytesToString(ByteBuffer.wrap(message))) - callback.onSuccess(new Array[Byte](0)) - } - }) - val clientF = transportContext.createClientFactory() - val client = clientF.createClient("localhost", 12345) - - def sendMessage(message: Any): Unit = { - client.sendRpcSync(JavaUtils.serialize(message), 5000) - } -} - - -abstract class SimpleRpcServer(conf: SparkConf) { - - protected def log: Logger - - private val transportConf = SparkTransportConf.fromSparkConf(conf, 1) - - val transportContext = new TransportContext(transportConf, new RpcHandler { - override def getStreamManager: StreamManager = new OneForOneStreamManager - - override def receive( - client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = { - callback.onSuccess(new Array[Byte](0)) - val received = JavaUtils.deserialize[Any](message) - println("got mesage " + received) - remote = client - if (receiveWithLogging.isDefinedAt(received)) { - receiveWithLogging.apply(received) - } - } - }) - - private[this] val clientFactory = transportContext.createClientFactory() - private[this] var server: TransportServer = _ - - startServer() - private[this] val client = clientFactory.createClient("localhost", 12345) - - def startServer(): Unit = { - server = transportContext.createServer(12345) - log.info("RPC server created on " + server.getPort) - } - - var remote: TransportClient = _ - - def reply(message: Any): Unit = { -// val c = clientFactory.createClient("localhost", -// remote.channel.remoteAddress.asInstanceOf[InetSocketAddress].getPort) -// c.sendRpc(JavaUtils.serialize(message), new RpcResponseCallback { -// override def onSuccess(response: Array[Byte]): Unit = {} -// override def onFailure(e: Throwable): Unit = {} -// }) -// remote.sendRpc(JavaUtils.serialize(message), new RpcResponseCallback { -// override def onFailure(e: Throwable): Unit = {} -// override def onSuccess(response: Array[Byte]): Unit = {} -// }) - remote.sendRpcSync(JavaUtils.serialize(message), 5000) - } - - def sendMessage(message: Any): Unit = { - client.sendRpcSync(JavaUtils.serialize(message), 5000) - } - - def receiveWithLogging: PartialFunction[Any, Unit] -} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 4b22f87f7316..37f2e34ceb24 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -67,7 +67,7 @@ public class TransportClient implements Closeable { private final Logger logger = LoggerFactory.getLogger(TransportClient.class); - public 
final Channel channel; + private final Channel channel; private final TransportResponseHandler handler; public TransportClient(Channel channel, TransportResponseHandler handler) { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index 7c4fd8161454..81f8d7f96350 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -34,7 +34,6 @@ public final class MessageDecoder extends MessageToMessageDecoder { private final Logger logger = LoggerFactory.getLogger(MessageDecoder.class); - @Override public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { Message.Type msgType = Message.Type.decode(in); diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index d58f4e104243..bf8a1fc42fc6 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -41,32 +41,6 @@ public class JavaUtils { private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class); - public static T deserialize(byte[] bytes) { - try { - ObjectInputStream is = new ObjectInputStream(new ByteArrayInputStream(bytes)); - Object out = is.readObject(); - is.close(); - return (T) out; - } catch (ClassNotFoundException e) { - throw new RuntimeException("Could not deserialize object", e); - } catch (IOException e) { - throw new RuntimeException("Could not deserialize object", e); - } - } - - // TODO: Make this configurable, do not use Java serialization! - public static byte[] serialize(Object object) { - try { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - ObjectOutputStream os = new ObjectOutputStream(baos); - os.writeObject(object); - os.close(); - return baos.toByteArray(); - } catch (IOException e) { - throw new RuntimeException("Could not serialize object", e); - } - } - /** Closes the given object, ignoring IOExceptions. 
*/ public static void closeQuietly(Closeable closeable) { try { From ef040bf9811c2ca3020167fe4fc73d31805a9a5d Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 9 Jan 2015 17:58:32 +0800 Subject: [PATCH 35/36] Fix ReceivedBlockHandlerSuite in streaming --- .../streaming/ReceivedBlockHandlerSuite.scala | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 132ff2443fc0..7864b2e712e4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -24,7 +24,6 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.postfixOps -import akka.actor.{ActorSystem, Props} import com.google.common.io.Files import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration @@ -33,6 +32,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager @@ -56,7 +56,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche val manualClock = new ManualClock val blockManagerSize = 10000000 - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var blockManagerMaster: BlockManagerMaster = null var blockManager: BlockManager = null var tempDirectory: File = null @@ -64,14 +64,14 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche before { val (actorSystem, boundPort) = AkkaUtils.createActorSystem( "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem - conf.set("spark.driver.port", boundPort.toString) + rpcEnv = RpcEnv.create("test", "localhost", 0, conf = conf, securityManager = securityMgr) + conf.set("spark.driver.port", rpcEnv.boundPort.toString) - blockManagerMaster = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), + blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("block-manager-master-actor", + new BlockManagerMasterActor(rpcEnv, true, conf, new LiveListenerBus)), conf, true) - blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, + blockManager = new BlockManager("bm", rpcEnv, blockManagerMaster, serializer, blockManagerSize, conf, mapOutputTracker, shuffleManager, new NioBlockTransferService(conf, securityMgr), securityMgr, 0) blockManager.initialize("app-id") @@ -89,9 +89,9 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche blockManagerMaster.stop() blockManagerMaster = null } - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.stopAll() + rpcEnv.awaitTermination() + rpcEnv = null if (tempDirectory != null && tempDirectory.exists()) { FileUtils.deleteDirectory(tempDirectory) From c3359f09e551ab79c829b7b24215b70921709045 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 16 Jan 2015 17:48:20 +0800 Subject: [PATCH 36/36] Revert DAGScheduler --- .../apache/spark/scheduler/DAGScheduler.scala | 116 ++++++++++-------- .../spark/scheduler/DAGSchedulerSuite.scala | 
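The fixture change above swaps the raw ActorSystem lifecycle for the RpcEnv one. Below is a minimal sketch of that lifecycle, using the factory and methods that appear in this patch (`RpcEnv.create`, `boundPort`, `setupEndpoint`, `stopAll`, `awaitTermination`); the endpoint registration is left as a placeholder comment.

```scala
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.RpcEnv

val conf = new SparkConf()
val securityMgr = new SecurityManager(conf)

// Port 0 asks for any free port; boundPort reports what was actually bound.
val rpcEnv = RpcEnv.create("test", "localhost", 0, conf = conf, securityManager = securityMgr)
conf.set("spark.driver.port", rpcEnv.boundPort.toString)

// A real suite would register endpoints here, e.g.
// val ref = rpcEnv.setupEndpoint("block-manager-master-actor", someEndpoint)

// Teardown mirrors the `after` block above.
rpcEnv.stopAll()
rpcEnv.awaitTermination()
```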
47 ++++--- 2 files changed, 98 insertions(+), 65 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 502d94010679..9e6f11686a4b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -20,20 +20,24 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.util.Properties import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.{TimeUnit, Executors} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} +import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps import scala.reflect.ClassTag import scala.util.control.NonFatal +import akka.actor._ +import akka.actor.SupervisorStrategy.Stop +import akka.pattern.ask +import akka.util.Timeout + import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD -import org.apache.spark.rpc.{RpcEnv, RpcEndpointRef, RpcEndpoint} import org.apache.spark.storage._ import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils} import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -77,11 +81,6 @@ class DAGScheduler( def this(sc: SparkContext) = this(sc, sc.taskScheduler) - private val rpcEnv = env.rpcEnv - - private val messageScheduler = - Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("dag-scheduler-message")) - private[scheduler] val nextJobId = new AtomicInteger(0) private[scheduler] def numTotalJobs: Int = nextJobId.get() private val nextStageId = new AtomicInteger(0) @@ -113,30 +112,42 @@ class DAGScheduler( // stray messages to detect. private val failedEpoch = new HashMap[String, Long] + private val dagSchedulerActorSupervisor = + env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this))) + // A closure serializer that we reuse. // This is only safe because DAGScheduler runs in a single thread. private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() + private[scheduler] var eventProcessActor: ActorRef = _ + /** If enabled, we may run certain actions like take() and first() locally. */ private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false) /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */ private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) - private[scheduler] var eventProcessActor = rpcEnv.setupEndpoint( - "DAGSchedulerEventProcessActor-" + DAGScheduler.nextId, - new DAGSchedulerEventProcessActor(rpcEnv, this)) + private def initializeEventProcessActor() { + // blocking the thread until supervisor is started, which ensures eventProcessActor is + // not null before any job is submitted + implicit val timeout = Timeout(30 seconds) + val initEventActorReply = + dagSchedulerActorSupervisor ? Props(new DAGSchedulerEventProcessActor(this)) + eventProcessActor = Await.result(initEventActorReply, timeout.duration). + asInstanceOf[ActorRef] + } + initializeEventProcessActor() taskScheduler.setDAGScheduler(this) // Called by TaskScheduler to report task's starting. 
def taskStarted(task: Task[_], taskInfo: TaskInfo) { - eventProcessActor.send(BeginEvent(task, taskInfo)) + eventProcessActor ! BeginEvent(task, taskInfo) } // Called to report that a task has completed and results are being fetched remotely. def taskGettingResult(taskInfo: TaskInfo) { - eventProcessActor.send(GettingResultEvent(taskInfo)) + eventProcessActor ! GettingResultEvent(taskInfo) } // Called by TaskScheduler to report task completions or failures. @@ -147,8 +158,7 @@ class DAGScheduler( accumUpdates: Map[Long, Any], taskInfo: TaskInfo, taskMetrics: TaskMetrics) { - eventProcessActor.send( - CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)) + eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics) } /** @@ -167,18 +177,18 @@ class DAGScheduler( // Called by TaskScheduler when an executor fails. def executorLost(execId: String) { - eventProcessActor.send(ExecutorLost(execId)) + eventProcessActor ! ExecutorLost(execId) } // Called by TaskScheduler when a host is added def executorAdded(execId: String, host: String) { - eventProcessActor.send(ExecutorAdded(execId, host)) + eventProcessActor ! ExecutorAdded(execId, host) } // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or // cancellation of the job itself. def taskSetFailed(taskSet: TaskSet, reason: String) { - eventProcessActor.send(TaskSetFailed(taskSet, reason)) + eventProcessActor ! TaskSetFailed(taskSet, reason) } private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = { @@ -483,8 +493,8 @@ class DAGScheduler( assert(partitions.size > 0) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler) - eventProcessActor.send(JobSubmitted( - jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)) + eventProcessActor ! JobSubmitted( + jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties) waiter } @@ -524,8 +534,8 @@ class DAGScheduler( val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray val jobId = nextJobId.getAndIncrement() - eventProcessActor.send(JobSubmitted( - jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)) + eventProcessActor ! JobSubmitted( + jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties) listener.awaitResult() // Will throw an exception if the job fails } @@ -534,19 +544,19 @@ class DAGScheduler( */ def cancelJob(jobId: Int) { logInfo("Asked to cancel job " + jobId) - eventProcessActor.send(JobCancelled(jobId)) + eventProcessActor ! JobCancelled(jobId) } def cancelJobGroup(groupId: String) { logInfo("Asked to cancel job group " + groupId) - eventProcessActor.send(JobGroupCancelled(groupId)) + eventProcessActor ! JobGroupCancelled(groupId) } /** * Cancel all jobs that are running or waiting in the queue. */ def cancelAllJobs() { - eventProcessActor.send(AllJobsCancelled) + eventProcessActor ! AllJobsCancelled } private[scheduler] def doCancelAllJobs() { @@ -562,7 +572,7 @@ class DAGScheduler( * Cancel all jobs associated with a running or scheduled stage. */ def cancelStage(stageId: Int) { - eventProcessActor.send(StageCancelled(stageId)) + eventProcessActor ! StageCancelled(stageId) } /** @@ -1051,13 +1061,11 @@ class DAGScheduler( // in that case the event will already have been scheduled. eventProcessActor may be // null during unit tests. 
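The revert restores the Akka supervision arrangement: `initializeEventProcessActor` (above) asks the supervisor for a child built from a `Props`, and `DAGSchedulerActorSupervisor` (below) replies with the child's `ActorRef`. Stripped of the DAGScheduler specifics, the handshake looks roughly like the sketch that follows; the actor names and the demo ActorSystem are placeholders.

```scala
import scala.concurrent.Await
import scala.concurrent.duration._

import akka.actor.{Actor, ActorRef, ActorSystem, Props}
import akka.pattern.ask
import akka.util.Timeout

// A parent that builds children on demand and hands back their refs,
// mirroring DAGSchedulerActorSupervisor's Props handler.
class ParentActor extends Actor {
  def receive = {
    case p: Props => sender ! context.actorOf(p)
  }
}

class ChildActor extends Actor {
  def receive = { case _ => () }
}

val system = ActorSystem("demo")
val parent = system.actorOf(Props[ParentActor], "parent")

// Block until the child exists, as initializeEventProcessActor does, so the
// returned ref is never null when the first event is submitted.
implicit val timeout = Timeout(30.seconds)
val child = Await.result(parent ? Props[ChildActor], timeout.duration).asInstanceOf[ActorRef]
```

Creating the child under the parent, rather than directly on the ActorSystem, is what lets the supervisor's `OneForOneStrategy` observe crashes and shut the SparkContext down.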
// TODO: Cancel running tasks in the stage + import env.actorSystem.dispatcher logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + s"$failedStage (${failedStage.name}) due to fetch failure") - messageScheduler.schedule(new Runnable { - override def run(): Unit = { - eventProcessActor.send(ResubmitFailedStages) - } - }, RESUBMIT_TIMEOUT.toMillis, TimeUnit.MILLISECONDS) + env.actorSystem.scheduler.scheduleOnce( + RESUBMIT_TIMEOUT, eventProcessActor, ResubmitFailedStages) } failedStages += failedStage failedStages += mapStage @@ -1315,18 +1323,40 @@ class DAGScheduler( def stop() { logInfo("Stopping DAGScheduler") - rpcEnv.stop(eventProcessActor) + dagSchedulerActorSupervisor ! PoisonPill taskScheduler.stop() } } -private[scheduler] class DAGSchedulerEventProcessActor( - override val rpcEnv: RpcEnv, dagScheduler: DAGScheduler) extends RpcEndpoint with Logging { +private[scheduler] class DAGSchedulerActorSupervisor(dagScheduler: DAGScheduler) + extends Actor with Logging { + + override val supervisorStrategy = + OneForOneStrategy() { + case x: Exception => + logError("eventProcesserActor failed; shutting down SparkContext", x) + try { + dagScheduler.doCancelAllJobs() + } catch { + case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t) + } + dagScheduler.sc.stop() + Stop + } + + def receive = { + case p: Props => sender ! context.actorOf(p) + case _ => logWarning("received unknown message in DAGSchedulerActorSupervisor") + } +} + +private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGScheduler) + extends Actor with Logging { /** * The main event loop of the DAG scheduler. */ - def receive(sender: RpcEndpointRef) = { + def receive = { case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) @@ -1365,21 +1395,10 @@ private[scheduler] class DAGSchedulerEventProcessActor( dagScheduler.resubmitFailedStages() } - override def onStop() { + override def postStop() { // Cancel any active jobs in postStop hook dagScheduler.cleanUpAfterSchedulerStop() } - - override def onError(e: Throwable): Unit = { - logError("eventProcesserActor failed; shutting down SparkContext", e) - stop() - try { - dagScheduler.doCancelAllJobs() - } catch { - case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t) - } - dagScheduler.sc.stop() - } } private[spark] object DAGScheduler { @@ -1391,9 +1410,4 @@ private[spark] object DAGScheduler { // The time, in millis, to wake up between polls of the completion queue in order to potentially // resubmit failed stages val POLL_TIMEOUT = 10L - - private val id = new AtomicInteger(0) - - // To resolve the conflicts of actor name in the unit tests - def nextId = id.getAndDecrement } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 7d9c4793feaf..d30eb10bbe94 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -20,18 +20,26 @@ package org.apache.spark.scheduler import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map} import scala.language.reflectiveCalls -import org.scalatest.{FunSuite, BeforeAndAfter} +import akka.actor._ +import akka.testkit.{ImplicitSender, TestKit, TestActorRef} +import org.scalatest.{BeforeAndAfter, FunSuiteLike} import 
org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.rpc.RpcEndpoint import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.CallSite import org.apache.spark.executor.TaskMetrics +class BuggyDAGEventProcessActor extends Actor { + val state = 0 + def receive = { + case _ => throw new SparkException("error") + } +} + /** * An RDD for passing to DAGScheduler. These RDDs will use the dependencies and * preferredLocations (if any) that are passed to them. They are deliberately not executable @@ -57,7 +65,8 @@ class MyRDD( class DAGSchedulerSuiteDummyException extends Exception -class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with Timeouts { +class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with FunSuiteLike + with ImplicitSender with BeforeAndAfter with LocalSparkContext with Timeouts { val conf = new SparkConf /** Set of TaskSets the DAGScheduler has requested executed. */ @@ -102,10 +111,9 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with Timeouts { } } - var sc: SparkContext = null var mapOutputTracker: MapOutputTrackerMaster = null var scheduler: DAGScheduler = null - var dagEventProcessTestActor: RpcEndpoint = null + var dagEventProcessTestActor: TestActorRef[DAGSchedulerEventProcessActor] = null /** * Set of cache locations to return from our mock BlockManagerMaster. @@ -159,13 +167,13 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with Timeouts { runLocallyWithinThread(job) } } - val rpcEnv = sc.env.rpcEnv - dagEventProcessTestActor = new DAGSchedulerEventProcessActor(rpcEnv, scheduler) - rpcEnv.setupEndpoint("DAGSchedulerEventProcessActorTest", dagEventProcessTestActor) + dagEventProcessTestActor = TestActorRef[DAGSchedulerEventProcessActor]( + Props(classOf[DAGSchedulerEventProcessActor], scheduler))(system) } - after { - sc.stop() + override def afterAll() { + super.afterAll() + TestKit.shutdownActorSystem(system) } /** @@ -182,7 +190,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with Timeouts { * DAGScheduler event loop. */ private def runEvent(event: DAGSchedulerEvent) { - dagEventProcessTestActor.receive(RpcEndpoint.noSender)(event) + dagEventProcessTestActor.receive(event) } /** @@ -389,9 +397,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with Timeouts { runLocallyWithinThread(job) } } - val rpcEnv = sc.env.rpcEnv - dagEventProcessTestActor = new DAGSchedulerEventProcessActor(rpcEnv, noKillScheduler) - rpcEnv.setupEndpoint("DAGSchedulerEventProcessActor-nokill", dagEventProcessTestActor) + dagEventProcessTestActor = TestActorRef[DAGSchedulerEventProcessActor]( + Props(classOf[DAGSchedulerEventProcessActor], noKillScheduler))(system) val jobId = submit(new MyRDD(sc, 1, Nil), Array(0)) cancel(jobId) // Because the job wasn't actually cancelled, we shouldn't have received a failure message. @@ -719,6 +726,18 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with Timeouts { assert(sc.parallelize(1 to 10, 2).first() === 1) } + test("DAGSchedulerActorSupervisor closes the SparkContext when EventProcessActor crashes") { + val actorSystem = ActorSystem("test") + val supervisor = actorSystem.actorOf( + Props(classOf[DAGSchedulerActorSupervisor], scheduler), "dagSupervisor") + supervisor ! 
Props[BuggyDAGEventProcessActor] + val child = expectMsgType[ActorRef] + watch(child) + child ! "hi" + expectMsgPF(){ case Terminated(child) => () } + assert(scheduler.sc.dagScheduler === null) + } + test("accumulator not calculated for resubmitted result stage") { //just for register val accum = new Accumulator[Int](0, AccumulatorParam.IntAccumulatorParam)