diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala index 4dc1251a4ca84..4f5bb264170de 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala @@ -24,9 +24,12 @@ import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.when +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import scala.util.Random import org.apache.spark.{Aggregator, MapOutputTracker, ShuffleDependency, SparkConf, SparkEnv, TaskContext} +import org.apache.spark.api.shuffle.ShuffleLocation import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} @@ -194,14 +197,17 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { } when(mapOutputTracker.getMapSizesByShuffleLocation(0, 0, 1)) - .thenReturn { - val shuffleBlockIdsAndSizes = (0 until NUM_MAPS).map { mapId => - val shuffleBlockId = ShuffleBlockId(0, mapId, 0) - (shuffleBlockId, dataFileLength) + .thenAnswer(new Answer[Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])]] { + def answer(invocationOnMock: InvocationOnMock): + Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { + val shuffleBlockIdsAndSizes = (0 until NUM_MAPS).map { mapId => + val shuffleBlockId = ShuffleBlockId(0, mapId, 0) + (shuffleBlockId, dataFileLength) + } + Seq((Option.apply(DefaultMapShuffleLocations.get(dataBlockId)), shuffleBlockIdsAndSizes)) + .toIterator } - Seq((Option.apply(DefaultMapShuffleLocations.get(dataBlockId)), shuffleBlockIdsAndSizes)) - .toIterator - } + }) when(dependency.serializer).thenReturn(serializer) when(dependency.aggregator).thenReturn(aggregator)