Skip to content

Commit 6f0c44b

Browse files
mccheahbulldozer-bot[bot]
authored andcommitted
Set the task context in writer benchmarks (apache-spark-on-k8s#529)
1 parent bc40da2 commit 6f0c44b

File tree

3 files changed

+5
-4
lines changed

3 files changed

+5
-4
lines changed

core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ abstract class ShuffleWriterBenchmarkBase extends BenchmarkBase {
111111
when(rpcEnv.setupEndpoint(any[String], any[RpcEndpoint])).thenReturn(rpcEndpointRef)
112112

113113
def setup(): Unit = {
114+
TaskContext.setTaskContext(taskContext)
114115
memoryManager = new TestMemoryManager(defaultConf)
115116
memoryManager.limit(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES)
116117
taskMemoryManager = new TaskMemoryManager(memoryManager, 0)

core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,9 @@ package org.apache.spark.shuffle.sort
1919

2020
import org.mockito.Mockito.when
2121

22-
import org.apache.spark.{Aggregator, SparkEnv}
22+
import org.apache.spark.{Aggregator, SparkEnv, TaskContext}
2323
import org.apache.spark.benchmark.Benchmark
2424
import org.apache.spark.shuffle.BaseShuffleHandle
25-
import org.apache.spark.shuffle.sort.io.DefaultShuffleDataIO
2625

2726
/**
2827
* Benchmark to measure performance for aggregate primitives.
@@ -76,6 +75,7 @@ object SortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase {
7675
}
7776

7877
when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager)
78+
TaskContext.setTaskContext(taskContext)
7979

8080
val shuffleWriter = new SortShuffleWriter[String, String, String](
8181
blockResolver,

core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,8 @@
1616
*/
1717
package org.apache.spark.shuffle.sort
1818

19-
import org.apache.spark.SparkConf
19+
import org.apache.spark.{SparkConf, TaskContext}
2020
import org.apache.spark.benchmark.Benchmark
21-
import org.apache.spark.util.Utils
2221

2322
/**
2423
* Benchmark to measure performance for aggregate primitives.
@@ -44,6 +43,7 @@ object UnsafeShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase {
4443
val conf = new SparkConf(loadDefaults = false)
4544
conf.set("spark.file.transferTo", String.valueOf(transferTo))
4645

46+
TaskContext.setTaskContext(taskContext)
4747
new UnsafeShuffleWriter[String, String](
4848
blockManager,
4949
blockResolver,

0 commit comments

Comments
 (0)