@@ -19,7 +19,7 @@ package org.apache.spark.scheduler

import java.nio.ByteBuffer
import java.util.{Locale, Timer, TimerTask}
-import java.util.concurrent.TimeUnit
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import java.util.concurrent.atomic.AtomicLong

import scala.collection.Set
@@ -91,7 +91,7 @@ private[spark] class TaskSchedulerImpl(
private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]

// Protected by `this`
-private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager]
+private[scheduler] val taskIdToTaskSetManager = new ConcurrentHashMap[Long, TaskSetManager]
val taskIdToExecutorId = new HashMap[Long, String]

@volatile private var hasReceivedTask = false
@@ -315,7 +315,7 @@ private[spark] class TaskSchedulerImpl(
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
tasks(i) += task
val tid = task.taskId
-taskIdToTaskSetManager(tid) = taskSet
+taskIdToTaskSetManager.put(tid, taskSet)
taskIdToExecutorId(tid) = execId
executorIdToRunningTaskIds(execId).add(tid)
availableCpus(i) -= CPUS_PER_TASK
@@ -465,7 +465,7 @@ private[spark] class TaskSchedulerImpl(
var reason: Option[ExecutorLossReason] = None
synchronized {
try {
-taskIdToTaskSetManager.get(tid) match {
+Option(taskIdToTaskSetManager.get(tid)) match {
Member:

Good catch, I like this direction.

I have a question about the change in semantics. By removing the synchronization in accumUpdatesWithTaskIds(), the get() and remove() pair inside the synchronized cleanupTaskState() is no longer atomic with respect to that get. Is this change OK?

Contributor:

I think your question is about the get() and remove() inside cleanupTaskState both being inside a single synchronized block, correct?

I don't see that as a problem here, since taskIdToTaskSetManager is now a ConcurrentHashMap: each individual operation is atomic on its own, and calling remove on a key that isn't there simply does nothing. There is no other code that removes from this map, so I don't think that can happen anyway. With this change, the only access outside a synchronized block is the get in accumUpdatesWithTaskIds, which is harmless if the entry has already been removed.

Member:

ConcurrentHashMap makes each individual operation, such as get() and remove(), atomic on its own. So I reviewed the places where more than one operation is performed inside a synchronized block, and this is one of them: with this PR, the get in accumUpdatesWithTaskIds can execute between the get() and remove() in cleanupTaskState. My question is just asking for confirmation that this is safe.

Contributor:

Yes, as far as I can see it's safe. If the get happens before the entry is removed, it computes the accumulators; if it happens after, it just gets an empty array back. That is no different from when it was synchronized. Nothing in statusUpdate between the get and the call to cleanupTaskState (where the entry is removed) depends on the accumulators or anything else, as far as I can see.
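To make the "safe either way" argument concrete, here is a hedged, self-contained model of the race; the names FakeTaskSetManager and taskIdToMgr are hypothetical stand-ins, not Spark's real classes. A reader does an unsynchronized lookup while another thread removes the entry, and whichever order wins, the reader either reports the task or silently skips it.

```scala
import java.util.concurrent.ConcurrentHashMap

object HeartbeatRaceSketch {
  // Hypothetical stand-in for TaskSetManager; not Spark's real class.
  final case class FakeTaskSetManager(stageId: Int, stageAttemptId: Int)

  def main(args: Array[String]): Unit = {
    val taskIdToMgr = new ConcurrentHashMap[Long, FakeTaskSetManager]()
    taskIdToMgr.put(7L, FakeTaskSetManager(stageId = 0, stageAttemptId = 1))

    // Stand-in for the remove performed by cleanupTaskState on another thread.
    val remover = new Thread(new Runnable {
      override def run(): Unit = { taskIdToMgr.remove(7L); () }
    })
    remover.start()

    // Stand-in for the now-unsynchronized lookup in accumUpdatesWithTaskIds.
    val reported = Seq(7L).flatMap { id =>
      Option(taskIdToMgr.get(id)).map(mgr => (id, mgr.stageId, mgr.stageAttemptId))
    }
    remover.join()

    // Either List((7,0,1)) if the read won the race, or List() if the remove did;
    // both outcomes are handled, matching the "empty array" case described above.
    println(s"heartbeat reported: $reported")
  }
}
```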

case Some(taskSet) =>
if (state == TaskState.LOST) {
// TaskState.LOST is only used by the deprecated Mesos fine-grained scheduling mode,
@@ -517,10 +517,10 @@ private[spark] class TaskSchedulerImpl(
accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])],
blockManagerId: BlockManagerId): Boolean = {
// (taskId, stageId, stageAttemptId, accumUpdates)
-val accumUpdatesWithTaskIds: Array[(Long, Int, Int, Seq[AccumulableInfo])] = synchronized {
+val accumUpdatesWithTaskIds: Array[(Long, Int, Int, Seq[AccumulableInfo])] = {
accumUpdates.flatMap { case (id, updates) =>
val accInfos = updates.map(acc => acc.toInfo(Some(acc.value), None))
-taskIdToTaskSetManager.get(id).map { taskSetMgr =>
+Option(taskIdToTaskSetManager.get(id)).map { taskSetMgr =>
Member:

Just leaving a small concern here: the original code held the lock over the whole set of ids in accumUpdates. After this change, an id that would originally have been found may no longer be found, because taskIdToTaskSetManager can be modified by removeExecutor or statusUpdate in the meantime. It's not a big problem if the executor has been removed, though.

Contributor:

I agree this could happen, but it shouldn't cause issues: even before this change, the executor could have been removed right before this function was called (it's all timing dependent), so the change doesn't alter that behavior. This path only updates accumulators for running tasks; if a task had already finished, its accumulator updates would have been processed via the task end events.
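A hedged sketch of the trade-off discussed in this thread, using a simplified String-valued map and hypothetical method names rather than Spark's real code: the old shape resolves the whole batch of ids under one lock, while the new shape looks each id up on its own and simply drops any id removed concurrently.

```scala
import java.util.concurrent.ConcurrentHashMap

object SnapshotVsPerIdLookup {
  // Old shape: resolve the whole batch of ids under one lock, so no entry can
  // disappear part-way through the iteration (as long as removers take the same lock).
  def lookupAllUnderLock(
      ids: Seq[Long],
      map: ConcurrentHashMap[Long, String],
      lock: AnyRef): Seq[(Long, String)] = lock.synchronized {
    ids.flatMap(id => Option(map.get(id)).map(id -> _))
  }

  // New shape: resolve each id independently; an id removed concurrently is just
  // skipped, which is acceptable here because a missing task id only means the
  // task already finished and reported its accumulators through task end events.
  def lookupEachIndependently(
      ids: Seq[Long],
      map: ConcurrentHashMap[Long, String]): Seq[(Long, String)] =
    ids.flatMap(id => Option(map.get(id)).map(id -> _))

  def main(args: Array[String]): Unit = {
    val map = new ConcurrentHashMap[Long, String]()
    map.put(1L, "stage 0.0")
    map.put(2L, "stage 1.0")
    // Id 3 is absent, so it is silently dropped from the result.
    println(lookupEachIndependently(Seq(1L, 2L, 3L), map))
  }
}
```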

(id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, accInfos)
}
}
@@ -290,7 +290,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
for (task <- tasks.flatten) {
val serializedTask = TaskDescription.encode(task)
if (serializedTask.limit() >= maxRpcMessageSize) {
-scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
+Option(scheduler.taskIdToTaskSetManager.get(task.taskId)).foreach { taskSetMgr =>
try {
var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
"spark.rpc.message.maxSize (%d bytes). Consider increasing " +
@@ -400,7 +400,8 @@ private[spark] abstract class MockBackend(
// get the task now, since that requires a lock on TaskSchedulerImpl, to prevent individual
// tests from introducing a race if they need it.
val newTasks = newTaskDescriptions.map { taskDescription =>
-val taskSet = taskScheduler.taskIdToTaskSetManager(taskDescription.taskId).taskSet
+val taskSet =
+  Option(taskScheduler.taskIdToTaskSetManager.get(taskDescription.taskId).taskSet).get
val task = taskSet.tasks(taskDescription.index)
(taskDescription, task)
}
@@ -248,7 +248,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
taskScheduler.submitTasks(attempt2)
val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten
assert(1 === taskDescriptions3.length)
-val mgr = taskScheduler.taskIdToTaskSetManager.get(taskDescriptions3(0).taskId).get
+val mgr = Option(taskScheduler.taskIdToTaskSetManager.get(taskDescriptions3(0).taskId)).get
assert(mgr.taskSet.stageAttemptId === 1)
assert(!failedTaskSet)
}
@@ -286,7 +286,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
assert(10 === taskDescriptions3.length)

taskDescriptions3.foreach { task =>
-val mgr = taskScheduler.taskIdToTaskSetManager.get(task.taskId).get
+val mgr = Option(taskScheduler.taskIdToTaskSetManager.get(task.taskId)).get
assert(mgr.taskSet.stageAttemptId === 1)
}
assert(!failedTaskSet)
@@ -724,7 +724,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
// only schedule one task because of locality
assert(taskDescs.size === 1)

-val mgr = taskScheduler.taskIdToTaskSetManager.get(taskDescs(0).taskId).get
+val mgr = Option(taskScheduler.taskIdToTaskSetManager.get(taskDescs(0).taskId)).get
assert(mgr.myLocalityLevels.toSet === Set(TaskLocality.NODE_LOCAL, TaskLocality.ANY))
// we should know about both executors, even though we only scheduled tasks on one of them
assert(taskScheduler.getExecutorsAliveOnHost("host0") === Some(Set("executor0")))