Skip to content

Commit 19596c2

Browse files
jinxing authored and kayousterhout committed
[SPARK-16929] Improve performance when checking for speculatable tasks.
## What changes were proposed in this pull request?

1. Use a MedianHeap to record the durations of successful tasks. When checking for speculatable tasks, the median duration can then be obtained with O(1) time complexity.
2. `checkSpeculatableTasks` synchronizes on `TaskSchedulerImpl`. If `checkSpeculatableTasks` doesn't finish within 100ms, the speculation thread may release the lock and then immediately re-acquire it. Change `scheduleAtFixedRate` to `scheduleWithFixedDelay` when scheduling the calls to `checkSpeculatableTasks`.

## How was this patch tested?

Added MedianHeapSuite.

Author: jinxing <jinxing6042@126.com>

Closes #16867 from jinxing64/SPARK-16929.
1 parent bb823ca commit 19596c2

File tree

5 files changed

+176
-6
lines changed

5 files changed

+176
-6
lines changed

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ private[spark] class TaskSchedulerImpl private[scheduler](
174174

175175
if (!isLocal && conf.getBoolean("spark.speculation", false)) {
176176
logInfo("Starting speculative execution thread")
177-
speculationScheduler.scheduleAtFixedRate(new Runnable {
177+
speculationScheduler.scheduleWithFixedDelay(new Runnable {
178178
override def run(): Unit = Utils.tryOrStopSparkContext(sc) {
179179
checkSpeculatableTasks()
180180
}

core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,18 @@ package org.apache.spark.scheduler
1919

2020
import java.io.NotSerializableException
2121
import java.nio.ByteBuffer
22-
import java.util.Arrays
2322
import java.util.concurrent.ConcurrentLinkedQueue
2423

2524
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
26-
import scala.math.{max, min}
25+
import scala.math.max
2726
import scala.util.control.NonFatal
2827

2928
import org.apache.spark._
3029
import org.apache.spark.internal.Logging
3130
import org.apache.spark.scheduler.SchedulingMode._
3231
import org.apache.spark.TaskState.TaskState
3332
import org.apache.spark.util.{AccumulatorV2, Clock, SystemClock, Utils}
33+
import org.apache.spark.util.collection.MedianHeap
3434

3535
/**
3636
* Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of
@@ -63,6 +63,8 @@ private[spark] class TaskSetManager(
6363
// Limit of bytes for total size of results (default is 1GB)
6464
val maxResultSize = Utils.getMaxResultSize(conf)
6565

66+
val speculationEnabled = conf.getBoolean("spark.speculation", false)
67+
6668
// Serializer for closures and tasks.
6769
val env = SparkEnv.get
6870
val ser = env.closureSerializer.newInstance()
@@ -141,6 +143,11 @@ private[spark] class TaskSetManager(
141143
// Task index, start and finish time for each task attempt (indexed by task ID)
142144
private val taskInfos = new HashMap[Long, TaskInfo]
143145

146+
// Use a MedianHeap to record durations of successful tasks so we know when to launch
147+
// speculative tasks. This is only used when speculation is enabled, to avoid the overhead
148+
// of inserting into the heap when the heap won't be used.
149+
val successfulTaskDurations = new MedianHeap()
150+
144151
// How frequently to reprint duplicate exceptions in full, in milliseconds
145152
val EXCEPTION_PRINT_INTERVAL =
146153
conf.getLong("spark.logging.exceptionPrintInterval", 10000)
@@ -698,6 +705,9 @@ private[spark] class TaskSetManager(
698705
val info = taskInfos(tid)
699706
val index = info.index
700707
info.markFinished(TaskState.FINISHED, clock.getTimeMillis())
708+
if (speculationEnabled) {
709+
successfulTaskDurations.insert(info.duration)
710+
}
701711
removeRunningTask(tid)
702712
// This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the
703713
// "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not
@@ -919,11 +929,10 @@ private[spark] class TaskSetManager(
919929
var foundTasks = false
920930
val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
921931
logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
932+
922933
if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
923934
val time = clock.getTimeMillis()
924-
val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
925-
Arrays.sort(durations)
926-
val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.length - 1))
935+
var medianDuration = successfulTaskDurations.median
927936
val threshold = max(SPECULATION_MULTIPLIER * medianDuration, minTimeToSpeculation)
928937
// TODO: Threshold should also look at standard deviation of task durations and have a lower
929938
// bound based on that.
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.util.collection
19+
20+
import scala.collection.mutable.PriorityQueue
21+
22+
/**
23+
* MedianHeap is designed to be used to quickly track the median of a group of numbers
24+
* that may contain duplicates. Inserting a new number has O(log n) time complexity and
25+
* determining the median has O(1) time complexity.
26+
* The basic idea is to maintain two heaps: a smallerHalf and a largerHalf. The smallerHalf
27+
* stores the smaller half of all numbers while the largerHalf stores the larger half.
28+
* The sizes of two heaps need to be balanced each time when a new number is inserted so
29+
* that their sizes will not be different by more than 1. Therefore each time when
30+
* findMedian() is called we check if two heaps have the same size. If they do, we should
31+
* return the average of the two top values of heaps. Otherwise we return the top of the
32+
* heap which has one more element.
33+
*/
34+
private[spark] class MedianHeap(implicit val ord: Ordering[Double]) {
35+
36+
/**
37+
* Stores all the numbers less than the current median in a smallerHalf,
38+
* i.e median is the maximum, at the root.
39+
*/
40+
private[this] var smallerHalf = PriorityQueue.empty[Double](ord)
41+
42+
/**
43+
* Stores all the numbers greater than the current median in a largerHalf,
44+
* i.e median is the minimum, at the root.
45+
*/
46+
private[this] var largerHalf = PriorityQueue.empty[Double](ord.reverse)
47+
48+
def isEmpty(): Boolean = {
49+
smallerHalf.isEmpty && largerHalf.isEmpty
50+
}
51+
52+
def size(): Int = {
53+
smallerHalf.size + largerHalf.size
54+
}
55+
56+
def insert(x: Double): Unit = {
57+
// If both heaps are empty, we arbitrarily insert it into a heap, let's say, the largerHalf.
58+
if (isEmpty) {
59+
largerHalf.enqueue(x)
60+
} else {
61+
// If the number is larger than current median, it should be inserted into largerHalf,
62+
// otherwise smallerHalf.
63+
if (x > median) {
64+
largerHalf.enqueue(x)
65+
} else {
66+
smallerHalf.enqueue(x)
67+
}
68+
}
69+
rebalance()
70+
}
71+
72+
private[this] def rebalance(): Unit = {
73+
if (largerHalf.size - smallerHalf.size > 1) {
74+
smallerHalf.enqueue(largerHalf.dequeue())
75+
}
76+
if (smallerHalf.size - largerHalf.size > 1) {
77+
largerHalf.enqueue(smallerHalf.dequeue)
78+
}
79+
}
80+
81+
def median: Double = {
82+
if (isEmpty) {
83+
throw new NoSuchElementException("MedianHeap is empty.")
84+
}
85+
if (largerHalf.size == smallerHalf.size) {
86+
(largerHalf.head + smallerHalf.head) / 2.0
87+
} else if (largerHalf.size > smallerHalf.size) {
88+
largerHalf.head
89+
} else {
90+
smallerHalf.head
91+
}
92+
}
93+
}

core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
893893
val taskSet = FakeTask.createTaskSet(4)
894894
// Set the speculation multiplier to be 0 so speculative tasks are launched immediately
895895
sc.conf.set("spark.speculation.multiplier", "0.0")
896+
sc.conf.set("spark.speculation", "true")
896897
val clock = new ManualClock()
897898
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock)
898899
val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task =>
@@ -948,6 +949,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
948949
// Set the speculation multiplier to be 0 so speculative tasks are launched immediately
949950
sc.conf.set("spark.speculation.multiplier", "0.0")
950951
sc.conf.set("spark.speculation.quantile", "0.6")
952+
sc.conf.set("spark.speculation", "true")
951953
val clock = new ManualClock()
952954
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock)
953955
val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task =>
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.util.collection
19+
20+
import java.util.NoSuchElementException
21+
22+
import org.apache.spark.SparkFunSuite
23+
24+
class MedianHeapSuite extends SparkFunSuite {

  /** Builds a heap holding every element of `values`, inserted in iteration order. */
  private def heapOf(values: Seq[Int]): MedianHeap = {
    val heap = new MedianHeap()
    values.foreach(v => heap.insert(v))
    heap
  }

  /** Inserts `value` into `heap` exactly `times` times. */
  private def insertRepeatedly(heap: MedianHeap, value: Int, times: Int): Unit = {
    (0 until times).foreach(_ => heap.insert(value))
  }

  test("If no numbers in MedianHeap, NoSuchElementException is thrown.") {
    val emptyHeap = new MedianHeap()
    intercept[NoSuchElementException] {
      emptyHeap.median
    }
  }

  test("Median should be correct when size of MedianHeap is even") {
    val heap = heapOf(Seq(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
    assert(heap.size() === 10)
    assert(heap.median === 4.5)
  }

  test("Median should be correct when size of MedianHeap is odd") {
    val heap = heapOf(Seq(0, 1, 2, 3, 4, 5, 6, 7, 8))
    assert(heap.size() === 9)
    assert(heap.median === 4)
  }

  test("Median should be correct though there are duplicated numbers inside.") {
    val heap = heapOf(Seq(0, 0, 1, 1, 2, 3, 4))
    assert(heap.size() === 7)
    assert(heap.median === 1)
  }

  test("Median should be correct when input data is skewed.") {
    val heap = new MedianHeap()
    // Each batch outnumbers everything inserted before it, so it must own the median.
    insertRepeatedly(heap, 5, 10)
    assert(heap.median === 5)
    insertRepeatedly(heap, 10, 100)
    assert(heap.median === 10)
    insertRepeatedly(heap, 0, 1000)
    assert(heap.median === 0)
  }
}

0 commit comments

Comments
 (0)