diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 81ee7ab58ab5..3c2980e442ab 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -215,8 +215,6 @@ private void writeSortedFile(boolean isLastFile) throws IOException { } } - inMemSorter.reset(); - if (!isLastFile) { // i.e. this is a spill file // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter @@ -255,6 +253,10 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { writeSortedFile(false); final long spillSize = freeMemory(); + inMemSorter.reset(); + // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the + // records. Otherwise, if the task is over allocated memory, then without freeing the memory pages, + // we might not be able to get memory for the pointer array. taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); return spillSize; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index fe79ff0e3052..76b0e6a304ac 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -51,9 +51,12 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) { */ private int pos = 0; + private int initialSize; + ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize) { this.consumer = consumer; assert (initialSize > 0); + this.initialSize = initialSize; this.array = consumer.allocateArray(initialSize); this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE); } @@ -70,6 +73,10 @@ public int numRecords() { } public void reset() { + if (consumer != null) { + consumer.freeArray(array); + this.array = consumer.allocateArray(initialSize); + } pos = 0; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index ded8f0472b27..ef79b4908347 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -200,14 +200,17 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); } spillWriter.close(); - - inMemSorter.reset(); } final long spillSize = freeMemory(); // Note that this is more-or-less going to be a multiple of the page size, so wasted space in // pages will currently be counted as memory spilled even though that space isn't actually // written to disk. This also counts the space needed to store the sorter's pointer array. + inMemSorter.reset(); + // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the + // records. Otherwise, if the task is over allocated memory, then without freeing the memory pages, + // we might not be able to get memory for the pointer array. + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); return spillSize; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 145c3a195064..01eae0e8dc14 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -84,6 +84,8 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { */ private int pos = 0; + private long initialSize; + public UnsafeInMemorySorter( final MemoryConsumer consumer, final TaskMemoryManager memoryManager, @@ -102,6 +104,7 @@ public UnsafeInMemorySorter( LongArray array) { this.consumer = consumer; this.memoryManager = memoryManager; + this.initialSize = array.size(); if (recordComparator != null) { this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); @@ -123,6 +126,10 @@ public void free() { } public void reset() { + if (consumer != null) { + consumer.freeArray(array); + this.array = consumer.allocateArray(initialSize); + } pos = 0; }