diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index f3bd0797aa03..43fa44320381 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -152,6 +152,28 @@ private[spark] class TaskSetManager( addPendingTask(i) } + sortPendingTasksForHosts(pendingTasksForHost) + + // Improve tasks preferrd locality by sorting tasks partial ordering. + private def sortPendingTasksForHosts(tasksMap: HashMap[String, ArrayBuffer[Int]]) { + tasksMap.foreach(pair => { + val host = pair._1 + var v = pair._2 + var map = new HashMap[Int, ArrayBuffer[Int]]() + v.foreach(index => { + for (loc <- tasks(index).preferredLocations) { + var i = 0 + if (loc.host == host) { + map.getOrElseUpdate(i, new ArrayBuffer) += index + } + i += 1 + } + }) + v.clear + map.foreach(kv => v ++= kv._2) + }) + } + // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling val myLocalityLevels = computeValidLocalityLevels() val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level