Skip to content

Commit 7ab0ce6

Browse files
Marcelo Vanzin authored and Andrew Or committed
[SPARK-11131][CORE] Fix race in worker registration protocol.
[SPARK-11131][CORE] Fix race in worker registration protocol. Because the registration RPC was not really an RPC, but a bunch of disconnected messages, it was possible for other messages to be sent before the reply to the registration arrived, and that would confuse the Worker. Especially in local-cluster mode, the worker was susceptible to receiving an executor request before it received a message from the master saying registration succeeded. On top of the above, the change also fixes a ClassCastException when the registration fails, which also affects the executor registration protocol. Because the `ask` is issued with a specific return type, if the error message (of a different type) was returned instead, the code would just die with an exception. This is fixed by having a common base trait for these reply messages. Author: Marcelo Vanzin <vanzin@cloudera.com> Closes #9138 from vanzin/SPARK-11131.
1 parent 6758213 commit 7ab0ce6

File tree

6 files changed

+86
-56
lines changed

6 files changed

+86
-56
lines changed

core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,14 @@ private[deploy] object DeployMessages {
6969

7070
// Master to Worker
7171

72+
sealed trait RegisterWorkerResponse
73+
7274
case class RegisteredWorker(master: RpcEndpointRef, masterWebUiUrl: String) extends DeployMessage
75+
with RegisterWorkerResponse
76+
77+
case class RegisterWorkerFailed(message: String) extends DeployMessage with RegisterWorkerResponse
7378

74-
case class RegisterWorkerFailed(message: String) extends DeployMessage
79+
case object MasterInStandby extends DeployMessage with RegisterWorkerResponse
7580

7681
case class ReconnectWorker(masterUrl: String) extends DeployMessage
7782

core/src/main/scala/org/apache/spark/deploy/master/Master.scala

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -233,31 +233,6 @@ private[deploy] class Master(
233233
System.exit(0)
234234
}
235235

236-
case RegisterWorker(
237-
id, workerHost, workerPort, workerRef, cores, memory, workerUiPort, publicAddress) => {
238-
logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
239-
workerHost, workerPort, cores, Utils.megabytesToString(memory)))
240-
if (state == RecoveryState.STANDBY) {
241-
// ignore, don't send response
242-
} else if (idToWorker.contains(id)) {
243-
workerRef.send(RegisterWorkerFailed("Duplicate worker ID"))
244-
} else {
245-
val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory,
246-
workerRef, workerUiPort, publicAddress)
247-
if (registerWorker(worker)) {
248-
persistenceEngine.addWorker(worker)
249-
workerRef.send(RegisteredWorker(self, masterWebUiUrl))
250-
schedule()
251-
} else {
252-
val workerAddress = worker.endpoint.address
253-
logWarning("Worker registration failed. Attempted to re-register worker at same " +
254-
"address: " + workerAddress)
255-
workerRef.send(RegisterWorkerFailed("Attempted to re-register worker at same address: "
256-
+ workerAddress))
257-
}
258-
}
259-
}
260-
261236
case RegisterApplication(description, driver) => {
262237
// TODO Prevent repeated registrations from some driver
263238
if (state == RecoveryState.STANDBY) {
@@ -387,6 +362,31 @@ private[deploy] class Master(
387362
}
388363

389364
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
365+
case RegisterWorker(
366+
id, workerHost, workerPort, workerRef, cores, memory, workerUiPort, publicAddress) => {
367+
logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
368+
workerHost, workerPort, cores, Utils.megabytesToString(memory)))
369+
if (state == RecoveryState.STANDBY) {
370+
context.reply(MasterInStandby)
371+
} else if (idToWorker.contains(id)) {
372+
context.reply(RegisterWorkerFailed("Duplicate worker ID"))
373+
} else {
374+
val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory,
375+
workerRef, workerUiPort, publicAddress)
376+
if (registerWorker(worker)) {
377+
persistenceEngine.addWorker(worker)
378+
context.reply(RegisteredWorker(self, masterWebUiUrl))
379+
schedule()
380+
} else {
381+
val workerAddress = worker.endpoint.address
382+
logWarning("Worker registration failed. Attempted to re-register worker at same " +
383+
"address: " + workerAddress)
384+
context.reply(RegisterWorkerFailed("Attempted to re-register worker at same address: "
385+
+ workerAddress))
386+
}
387+
}
388+
}
389+
390390
case RequestSubmitDriver(description) => {
391391
if (state != RecoveryState.ALIVE) {
392392
val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " +

core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala

Lines changed: 47 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFut
2626

2727
import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap}
2828
import scala.concurrent.ExecutionContext
29-
import scala.util.Random
29+
import scala.util.{Failure, Random, Success}
3030
import scala.util.control.NonFatal
3131

3232
import org.apache.spark.{Logging, SecurityManager, SparkConf}
@@ -213,8 +213,7 @@ private[deploy] class Worker(
213213
logInfo("Connecting to master " + masterAddress + "...")
214214
val masterEndpoint =
215215
rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME)
216-
masterEndpoint.send(RegisterWorker(
217-
workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress))
216+
registerWithMaster(masterEndpoint)
218217
} catch {
219218
case ie: InterruptedException => // Cancelled
220219
case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e)
@@ -271,8 +270,7 @@ private[deploy] class Worker(
271270
logInfo("Connecting to master " + masterAddress + "...")
272271
val masterEndpoint =
273272
rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME)
274-
masterEndpoint.send(RegisterWorker(
275-
workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress))
273+
registerWithMaster(masterEndpoint)
276274
} catch {
277275
case ie: InterruptedException => // Cancelled
278276
case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e)
@@ -341,25 +339,54 @@ private[deploy] class Worker(
341339
}
342340
}
343341

344-
override def receive: PartialFunction[Any, Unit] = {
345-
case RegisteredWorker(masterRef, masterWebUiUrl) =>
346-
logInfo("Successfully registered with master " + masterRef.address.toSparkURL)
347-
registered = true
348-
changeMaster(masterRef, masterWebUiUrl)
349-
forwordMessageScheduler.scheduleAtFixedRate(new Runnable {
350-
override def run(): Unit = Utils.tryLogNonFatalError {
351-
self.send(SendHeartbeat)
352-
}
353-
}, 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS)
354-
if (CLEANUP_ENABLED) {
355-
logInfo(s"Worker cleanup enabled; old application directories will be deleted in: $workDir")
342+
private def registerWithMaster(masterEndpoint: RpcEndpointRef): Unit = {
343+
masterEndpoint.ask[RegisterWorkerResponse](RegisterWorker(
344+
workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress))
345+
.onComplete {
346+
// This is a very fast action so we can use "ThreadUtils.sameThread"
347+
case Success(msg) =>
348+
Utils.tryLogNonFatalError {
349+
handleRegisterResponse(msg)
350+
}
351+
case Failure(e) =>
352+
logError(s"Cannot register with master: ${masterEndpoint.address}", e)
353+
System.exit(1)
354+
}(ThreadUtils.sameThread)
355+
}
356+
357+
private def handleRegisterResponse(msg: RegisterWorkerResponse): Unit = synchronized {
358+
msg match {
359+
case RegisteredWorker(masterRef, masterWebUiUrl) =>
360+
logInfo("Successfully registered with master " + masterRef.address.toSparkURL)
361+
registered = true
362+
changeMaster(masterRef, masterWebUiUrl)
356363
forwordMessageScheduler.scheduleAtFixedRate(new Runnable {
357364
override def run(): Unit = Utils.tryLogNonFatalError {
358-
self.send(WorkDirCleanup)
365+
self.send(SendHeartbeat)
359366
}
360-
}, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS)
361-
}
367+
}, 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS)
368+
if (CLEANUP_ENABLED) {
369+
logInfo(
370+
s"Worker cleanup enabled; old application directories will be deleted in: $workDir")
371+
forwordMessageScheduler.scheduleAtFixedRate(new Runnable {
372+
override def run(): Unit = Utils.tryLogNonFatalError {
373+
self.send(WorkDirCleanup)
374+
}
375+
}, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS)
376+
}
362377

378+
case RegisterWorkerFailed(message) =>
379+
if (!registered) {
380+
logError("Worker registration failed: " + message)
381+
System.exit(1)
382+
}
383+
384+
case MasterInStandby =>
385+
// Ignore. Master not yet ready.
386+
}
387+
}
388+
389+
override def receive: PartialFunction[Any, Unit] = synchronized {
363390
case SendHeartbeat =>
364391
if (connected) { sendToMaster(Heartbeat(workerId, self)) }
365392

@@ -399,12 +426,6 @@ private[deploy] class Worker(
399426
map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state))
400427
masterRef.send(WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq))
401428

402-
case RegisterWorkerFailed(message) =>
403-
if (!registered) {
404-
logError("Worker registration failed: " + message)
405-
System.exit(1)
406-
}
407-
408429
case ReconnectWorker(masterUrl) =>
409430
logInfo(s"Master with url $masterUrl requested this worker to reconnect.")
410431
registerWithMaster()

core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,12 @@ private[spark] class CoarseGrainedExecutorBackend(
5959
rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref =>
6060
// This is a very fast action so we can use "ThreadUtils.sameThread"
6161
driver = Some(ref)
62-
ref.ask[RegisteredExecutor.type](
62+
ref.ask[RegisterExecutorResponse](
6363
RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls))
6464
}(ThreadUtils.sameThread).onComplete {
6565
// This is a very fast action so we can use "ThreadUtils.sameThread"
6666
case Success(msg) => Utils.tryLogNonFatalError {
67-
Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor
67+
Option(self).foreach(_.send(msg)) // msg must be RegisterExecutorResponse
6868
}
6969
case Failure(e) => {
7070
logError(s"Cannot register with driver: $driverUrl", e)

core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,13 @@ private[spark] object CoarseGrainedClusterMessages {
3636
case class KillTask(taskId: Long, executor: String, interruptThread: Boolean)
3737
extends CoarseGrainedClusterMessage
3838

39+
sealed trait RegisterExecutorResponse
40+
3941
case object RegisteredExecutor extends CoarseGrainedClusterMessage
42+
with RegisterExecutorResponse
4043

4144
case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage
45+
with RegisterExecutorResponse
4246

4347
// Executors to driver
4448
case class RegisterExecutor(

core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,9 @@ class HeartbeatReceiverSuite
173173
val dummyExecutorEndpoint2 = new FakeExecutorEndpoint(rpcEnv)
174174
val dummyExecutorEndpointRef1 = rpcEnv.setupEndpoint("fake-executor-1", dummyExecutorEndpoint1)
175175
val dummyExecutorEndpointRef2 = rpcEnv.setupEndpoint("fake-executor-2", dummyExecutorEndpoint2)
176-
fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisteredExecutor.type](
176+
fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisterExecutorResponse](
177177
RegisterExecutor(executorId1, dummyExecutorEndpointRef1, "dummy:4040", 0, Map.empty))
178-
fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisteredExecutor.type](
178+
fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisterExecutorResponse](
179179
RegisterExecutor(executorId2, dummyExecutorEndpointRef2, "dummy:4040", 0, Map.empty))
180180
heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
181181
addExecutorAndVerify(executorId1)

0 commit comments

Comments (0)