diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
index e73ba39468828..2db6e91a93fa4 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -272,9 +272,6 @@ void spill() throws IOException {
       spills.size() > 1 ? " times" : " time");
 
     writeSortedFile(false);
-    final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage();
-    inMemSorter = null;
-    shuffleMemoryManager.release(inMemSorterMemoryUsage);
     final long spillSize = freeMemory();
     taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
 
@@ -316,6 +313,12 @@ private long freeMemory() {
     currentPage = null;
     currentPagePosition = -1;
     freeSpaceInCurrentPage = 0;
+    if (inMemSorter != null) {
+      final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage();
+      inMemSorter = null;
+      shuffleMemoryManager.release(inMemSorterMemoryUsage);
+      memoryFreed += inMemSorterMemoryUsage;
+    }
     return memoryFreed;
   }
 
@@ -329,10 +332,6 @@ public void cleanupResources() {
         logger.error("Unable to delete spill file {}", spill.file.getPath());
       }
     }
-    if (inMemSorter != null) {
-      shuffleMemoryManager.release(inMemSorter.getMemoryUsage());
-      inMemSorter = null;
-    }
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 9edf9f048f9fd..4f1c0f4d1256e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -23,7 +23,7 @@ import java.nio.ByteBuffer
 import scala.collection.mutable.HashMap
 
 import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext}
+import org.apache.spark._
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.serializer.SerializerInstance
 import org.apache.spark.unsafe.memory.TaskMemoryManager
@@ -48,7 +48,7 @@ private[spark] abstract class Task[T](
     val stageId: Int,
     val stageAttemptId: Int,
     val partitionId: Int,
-    internalAccumulators: Seq[Accumulator[Long]]) extends Serializable {
+    internalAccumulators: Seq[Accumulator[Long]]) extends Logging with Serializable {
 
   /**
    * The key of the Map is the accumulator id and the value of the Map is the latest accumulator
@@ -84,19 +84,35 @@ private[spark] abstract class Task[T](
     if (_killed) {
       kill(interruptThread = false)
     }
+    var exceptionThrown: Boolean = true
     try {
-      (runTask(context), context.collectAccumulators())
+      val res = (runTask(context), context.collectAccumulators())
+      exceptionThrown = false
+      res
     } finally {
       context.markTaskCompleted()
       try {
+        val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
+        val shuffleMemoryUsed = shuffleMemoryManager.getMemoryConsumptionForThisTask()
         Utils.tryLogNonFatalError {
           // Release memory used by this thread for shuffles
-          SparkEnv.get.shuffleMemoryManager.releaseMemoryForThisTask()
+          shuffleMemoryManager.releaseMemoryForThisTask()
        }
         Utils.tryLogNonFatalError {
           // Release memory used by this thread for unrolling blocks
           SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask()
         }
+        if (SparkEnv.get.conf.contains("spark.testing") && shuffleMemoryUsed != 0) {
+          val errMsg =
+            s"Shuffle memory leak detected; size = $shuffleMemoryUsed bytes, TID = $taskAttemptId"
+          if (!exceptionThrown) {
+            throw new SparkException(errMsg)
+          } else {
+            // The task failed with an exception, so don't throw here in order to avoid masking
+            // the original exception:
+            logWarning(errMsg)
+          }
+        }
       } finally {
         TaskContext.unset()
       }
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index 0c8f08f0f3b1b..0d7976888e181 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -104,7 +104,7 @@ private[spark] class HashShuffleReader[K, C](
         context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
         context.internalMetricsToAccumulators(
           InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
-        sorter.iterator
+        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
       case None =>
         aggregatedIter
     }
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index f929b12606f0a..5a1eb120c1e61 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -30,6 +30,7 @@ import org.apache.spark.{Logging, SparkEnv, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.serializer.{DeserializationStream, Serializer}
 import org.apache.spark.storage.{BlockId, BlockManager}
+import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator
 import org.apache.spark.executor.ShuffleWriteMetrics
 
@@ -122,6 +123,10 @@ class ExternalAppendOnlyMap[K, V, C](
    * The shuffle memory usage of the first trackMemoryThreshold entries is not tracked.
    */
   def insertAll(entries: Iterator[Product2[K, V]]): Unit = {
+    if (currentMap == null) {
+      throw new IllegalStateException(
+        "Cannot insert new elements into a map after calling iterator")
+    }
     // An update function for the map that we reuse across entries to avoid allocating
     // a new closure each time
     var curEntry: Product2[K, V] = null
@@ -216,13 +221,22 @@ class ExternalAppendOnlyMap[K, V, C](
     spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
   }
 
+  private def freeCurrentMap(): Unit = {
+    currentMap = null // So that the memory can be garbage-collected
+    releaseMemoryForThisThread()
+  }
+
   /**
-   * Return an iterator that merges the in-memory map with the spilled maps.
+   * Return a destructive iterator that merges the in-memory map with the spilled maps.
    * If no spill has occurred, simply return the in-memory map's iterator.
    */
   override def iterator: Iterator[(K, C)] = {
+    if (currentMap == null) {
+      throw new IllegalStateException(
+        "ExternalAppendOnlyMap.iterator is destructive and should only be called once.")
+    }
     if (spilledMaps.isEmpty) {
-      currentMap.iterator
+      CompletionIterator[(K, C), Iterator[(K, C)]](currentMap.iterator, freeCurrentMap())
     } else {
       new ExternalIterator()
     }
@@ -239,7 +253,8 @@ class ExternalAppendOnlyMap[K, V, C](
 
     // Input streams are derived both from the in-memory map and spilled maps on disk
     // The in-memory map is sorted in place, while the spilled maps are already in sorted order
-    private val sortedMap = currentMap.destructiveSortedIterator(keyComparator)
+    private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](
+      currentMap.destructiveSortedIterator(keyComparator), freeCurrentMap())
     private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)
 
     inputStreams.foreach { it =>
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 138c05dff19e4..44ba1786c1d82 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -709,6 +709,9 @@ private[spark] class ExternalSorter[K, V, C](
   }
 
   def stop(): Unit = {
+    map = null // So that the memory can be garbage-collected
+    buffer = null // So that the memory can be garbage-collected
+    releaseMemoryForThisThread()
     spills.foreach(s => s.file.delete())
     spills.clear()
   }
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
index 747ecf075a397..6205742e8514c 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -100,7 +100,7 @@ private[spark] trait Spillable[C] extends Logging {
   /**
    * Release our memory back to the shuffle pool so that other threads can grab it.
    */
-  private def releaseMemoryForThisThread(): Unit = {
+  protected def releaseMemoryForThisThread(): Unit = {
     // The amount we requested does not include the initial memory tracking threshold
     shuffleMemoryManager.release(myMemoryThreshold - initialMemoryThreshold)
     myMemoryThreshold = initialMemoryThreshold
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index 12e9bafcc92c1..3a5597589ffa2 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -46,23 +46,27 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
     conf
   }
 
-  test("simple insert") {
+  test("single insert insert") {
     val conf = createSparkConf(loadDefaults = false)
     sc = new SparkContext("local", "test", conf)
     val map = createExternalMap[Int]
-
-    // Single insert
     map.insert(1, 10)
-    var it = map.iterator
+    val it = map.iterator
     assert(it.hasNext)
     val kv = it.next()
     assert(kv._1 === 1 && kv._2 === ArrayBuffer[Int](10))
     assert(!it.hasNext)
+    sc.stop()
+  }
 
-    // Multiple insert
+  test("multiple insert") {
+    val conf = createSparkConf(loadDefaults = false)
+    sc = new SparkContext("local", "test", conf)
+    val map = createExternalMap[Int]
+    map.insert(1, 10)
     map.insert(2, 20)
     map.insert(3, 30)
-    it = map.iterator
+    val it = map.iterator
     assert(it.hasNext)
     assert(it.toSet === Set[(Int, ArrayBuffer[Int])](
       (1, ArrayBuffer[Int](10)),
@@ -141,39 +145,22 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
     sc = new SparkContext("local", "test", conf)
     val map = createExternalMap[Int]
 
+    val nullInt = null.asInstanceOf[Int]
     map.insert(1, 5)
     map.insert(2, 6)
     map.insert(3, 7)
-    assert(map.size === 3)
-    assert(map.iterator.toSet === Set[(Int, Seq[Int])](
-      (1, Seq[Int](5)),
-      (2, Seq[Int](6)),
-      (3, Seq[Int](7))
-    ))
-
-    // Null keys
-    val nullInt = null.asInstanceOf[Int]
+    map.insert(4, nullInt)
     map.insert(nullInt, 8)
-    assert(map.size === 4)
-    assert(map.iterator.toSet === Set[(Int, Seq[Int])](
+    map.insert(nullInt, nullInt)
+    val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.sorted))
+    assert(result === Set[(Int, Seq[Int])](
       (1, Seq[Int](5)),
       (2, Seq[Int](6)),
       (3, Seq[Int](7)),
-      (nullInt, Seq[Int](8))
+      (4, Seq[Int](nullInt)),
+      (nullInt, Seq[Int](nullInt, 8))
     ))
 
-    // Null values
-    map.insert(4, nullInt)
-    map.insert(nullInt, nullInt)
-    assert(map.size === 5)
-    val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet))
-    assert(result === Set[(Int, Set[Int])](
-      (1, Set[Int](5)),
-      (2, Set[Int](6)),
-      (3, Set[Int](7)),
-      (4, Set[Int](nullInt)),
-      (nullInt, Set[Int](nullInt, 8))
-    ))
     sc.stop()
   }
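
Note (not part of the patch): the recurring pattern above is wrapping an iterator in org.apache.spark.util.CompletionIterator so that cleanup such as sorter.stop() or freeCurrentMap() runs once the wrapped iterator has been fully consumed, instead of waiting for the end of the task. The code below is a minimal, standalone sketch of that pattern for illustration only; the object name CompletionIteratorSketch and the wrap method are made up for the example and are not Spark's actual CompletionIterator implementation.

object CompletionIteratorSketch {

  // Wrap `sub` so that `completion` runs exactly once, the first time
  // hasNext reports that the underlying iterator is exhausted.
  def wrap[A](sub: Iterator[A])(completion: => Unit): Iterator[A] = new Iterator[A] {
    private var completed = false

    override def hasNext: Boolean = {
      val more = sub.hasNext
      if (!more && !completed) {
        completed = true
        completion // e.g. release shuffle memory or delete spill files
      }
      more
    }

    override def next(): A = sub.next()
  }

  def main(args: Array[String]): Unit = {
    var cleanedUp = false
    val it = wrap(Iterator(1, 2, 3)) { cleanedUp = true }
    println(it.sum)    // 6
    println(cleanedUp) // true: the callback ran when the iterator was drained
  }
}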