Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 20 additions & 16 deletions core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.io.IOException;

import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.MemoryBlock;


Expand All @@ -28,9 +29,9 @@
*/
public abstract class MemoryConsumer {

private final TaskMemoryManager taskMemoryManager;
protected final TaskMemoryManager taskMemoryManager;
private final long pageSize;
private long used;
protected long used;

protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize) {
this.taskMemoryManager = taskMemoryManager;
Expand Down Expand Up @@ -74,26 +75,29 @@ public void spill() throws IOException {
public abstract long spill(long size, MemoryConsumer trigger) throws IOException;

/**
* Acquire `size` bytes memory.
*
* If there is not enough memory, throws OutOfMemoryError.
* Allocates a LongArray of `size`.
*/
protected void acquireMemory(long size) {
long got = taskMemoryManager.acquireExecutionMemory(size, this);
if (got < size) {
taskMemoryManager.releaseExecutionMemory(got, this);
public LongArray allocateArray(long size) {
long required = size * 8L;
MemoryBlock page = taskMemoryManager.allocatePage(required, this);
if (page == null || page.size() < required) {
long got = 0;
if (page != null) {
got = page.size();
taskMemoryManager.freePage(page, this);
}
taskMemoryManager.showMemoryUsage();
throw new OutOfMemoryError("Could not acquire " + size + " bytes of memory, got " + got);
throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got);
}
used += got;
used += required;
return new LongArray(page);
}

/**
* Release `size` bytes memory.
* Frees a LongArray.
*/
protected void releaseMemory(long size) {
used -= size;
taskMemoryManager.releaseExecutionMemory(size, this);
public void freeArray(LongArray array) {
  // Hand the memory block that backs the array to freePage.
  final MemoryBlock backingBlock = array.memoryBlock();
  freePage(backingBlock);
}

/**
Expand All @@ -109,7 +113,7 @@ protected MemoryBlock allocatePage(long required) {
long got = 0;
if (page != null) {
got = page.size();
freePage(page);
taskMemoryManager.freePage(page, this);
}
taskMemoryManager.showMemoryUsage();
throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) {
if (got < required) {
// Call spill() on other consumers to release memory
for (MemoryConsumer c: consumers) {
if (c != null && c != consumer && c.getUsed() > 0) {
if (c != consumer && c.getUsed() > 0) {
try {
long released = c.spill(required - got, consumer);
if (released > 0) {
Expand Down Expand Up @@ -173,7 +173,9 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) {
}
}

consumers.add(consumer);
if (consumer != null) {
consumers.add(consumer);
}
logger.debug("Task {} acquire {} for {}", taskAttemptId, Utils.bytesToString(got), consumer);
return got;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.spark.storage.DiskBlockObjectWriter;
import org.apache.spark.storage.TempShuffleBlockId;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.util.Utils;

Expand Down Expand Up @@ -114,8 +115,7 @@ public ShuffleExternalSorter(
this.numElementsForSpillThreshold =
conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
this.writeMetrics = writeMetrics;
acquireMemory(initialSize * 8L);
this.inMemSorter = new ShuffleInMemorySorter(initialSize);
this.inMemSorter = new ShuffleInMemorySorter(this, initialSize);
this.peakMemoryUsedBytes = getMemoryUsage();
}

Expand Down Expand Up @@ -301,9 +301,8 @@ private long freeMemory() {
public void cleanupResources() {
freeMemory();
if (inMemSorter != null) {
long sorterMemoryUsage = inMemSorter.getMemoryUsage();
inMemSorter.free();
inMemSorter = null;
releaseMemory(sorterMemoryUsage);
}
for (SpillInfo spill : spills) {
if (spill.file.exists() && !spill.file.delete()) {
Expand All @@ -321,26 +320,20 @@ private void growPointerArrayIfNecessary() throws IOException {
assert(inMemSorter != null);
if (!inMemSorter.hasSpaceForAnotherRecord()) {
long used = inMemSorter.getMemoryUsage();
long needed = used + inMemSorter.getMemoryToExpand();
LongArray array;
try {
acquireMemory(needed); // could trigger spilling
// could trigger spilling
array = allocateArray(used / 8 * 2);
} catch (OutOfMemoryError e) {
// should have triggered spilling
assert(inMemSorter.hasSpaceForAnotherRecord());
return;
}
// check if spilling is triggered or not
if (inMemSorter.hasSpaceForAnotherRecord()) {
releaseMemory(needed);
freeArray(array);
} else {
try {
inMemSorter.expandPointerArray();
releaseMemory(used);
} catch (OutOfMemoryError oom) {
// Just in case that JVM had run out of memory
releaseMemory(needed);
spill();
}
inMemSorter.expandPointerArray(array);
}
}
}
Expand Down Expand Up @@ -404,9 +397,8 @@ public SpillInfo[] closeAndGetSpills() throws IOException {
// Do not count the final file towards the spill count.
writeSortedFile(true);
freeMemory();
long sorterMemoryUsage = inMemSorter.getMemoryUsage();
inMemSorter.free();
inMemSorter = null;
releaseMemory(sorterMemoryUsage);
}
return spills.toArray(new SpillInfo[spills.size()]);
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@

import java.util.Comparator;

import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.util.collection.Sorter;

final class ShuffleInMemorySorter {

private final Sorter<PackedRecordPointer, long[]> sorter;
private final Sorter<PackedRecordPointer, LongArray> sorter;
private static final class SortComparator implements Comparator<PackedRecordPointer> {
@Override
public int compare(PackedRecordPointer left, PackedRecordPointer right) {
Expand All @@ -32,24 +35,34 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) {
}
private static final SortComparator SORT_COMPARATOR = new SortComparator();

private final MemoryConsumer consumer;

/**
* An array of record pointers and partition ids that have been encoded by
* {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating
* records.
*/
private long[] array;
private LongArray array;

/**
* The position in the pointer array where new records can be inserted.
*/
private int pos = 0;

public ShuffleInMemorySorter(int initialSize) {
public ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize) {
this.consumer = consumer;
assert (initialSize > 0);
this.array = new long[initialSize];
this.array = consumer.allocateArray(initialSize);
this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE);
}

/**
 * Gives the pointer array back to the owning consumer and clears the reference.
 * Does nothing when the array reference is already null.
 */
public void free() {
  if (array == null) {
    return;
  }
  consumer.freeArray(array);
  array = null;
}

/**
 * Returns {@code pos}, the position in the pointer array where the next record would be
 * inserted. {@code insertRecord} increments it once per record and {@code reset} sets it
 * back to 0.
 */
public int numRecords() {
return pos;
}
Expand All @@ -58,30 +71,25 @@ public void reset() {
pos = 0;
}

private int newLength() {
// Guard against overflow:
return array.length <= Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE;
}

/**
* Returns the memory needed to expand
*/
public long getMemoryToExpand() {
return ((long) (newLength() - array.length)) * 8;
}

public void expandPointerArray() {
final long[] oldArray = array;
array = new long[newLength()];
System.arraycopy(oldArray, 0, array, 0, oldArray.length);
public void expandPointerArray(LongArray newArray) {
  // Copy the whole current pointer array into the (strictly larger) replacement, release the
  // old array through the consumer, and switch over to the new one.
  final LongArray oldArray = array;
  assert(newArray.size() > oldArray.size());
  final long bytesToCopy = oldArray.size() * 8L;
  Platform.copyMemory(
    oldArray.getBaseObject(), oldArray.getBaseOffset(),
    newArray.getBaseObject(), newArray.getBaseOffset(),
    bytesToCopy);
  consumer.freeArray(oldArray);
  array = newArray;
}

public boolean hasSpaceForAnotherRecord() {
return pos < array.length;
return pos < array.size();
}

public long getMemoryUsage() {
return array.length * 8L;
return array.size() * 8L;
}

/**
Expand All @@ -96,14 +104,9 @@ public long getMemoryUsage() {
*/
public void insertRecord(long recordPointer, int partitionId) {
if (!hasSpaceForAnotherRecord()) {
if (array.length == Integer.MAX_VALUE) {
throw new IllegalStateException("Sort pointer array has reached maximum size");
} else {
expandPointerArray();
}
expandPointerArray(consumer.allocateArray(array.size() * 2));
}
array[pos] =
PackedRecordPointer.packPointer(recordPointer, partitionId);
array.set(pos, PackedRecordPointer.packPointer(recordPointer, partitionId));
pos++;
}

Expand All @@ -112,12 +115,12 @@ public void insertRecord(long recordPointer, int partitionId) {
*/
public static final class ShuffleSorterIterator {

private final long[] pointerArray;
private final LongArray pointerArray;
private final int numRecords;
final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
private int position = 0;

public ShuffleSorterIterator(int numRecords, long[] pointerArray) {
public ShuffleSorterIterator(int numRecords, LongArray pointerArray) {
this.numRecords = numRecords;
this.pointerArray = pointerArray;
}
Expand All @@ -127,7 +130,7 @@ public boolean hasNext() {
}

public void loadNext() {
packedRecordPointer.set(pointerArray[position]);
packedRecordPointer.set(pointerArray.get(position));
position++;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,19 @@

package org.apache.spark.shuffle.sort;

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.util.collection.SortDataFormat;

final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, LongArray> {

public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat();

private ShuffleSortDataFormat() { }

@Override
public PackedRecordPointer getKey(long[] data, int pos) {
public PackedRecordPointer getKey(LongArray data, int pos) {
// Since we re-use keys, this method shouldn't be called.
// It always throws; the overload that takes a `reuse` argument is the one that actually
// reads an entry from `data`.
throw new UnsupportedOperationException();
}
Expand All @@ -37,31 +40,38 @@ public PackedRecordPointer newKey() {
}

@Override
public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) {
reuse.set(data[pos]);
public PackedRecordPointer getKey(LongArray data, int pos, PackedRecordPointer reuse) {
  // Load the entry at `pos` into the caller-supplied key object and return that same object.
  final long entry = data.get(pos);
  reuse.set(entry);
  return reuse;
}

@Override
public void swap(long[] data, int pos0, int pos1) {
final long temp = data[pos0];
data[pos0] = data[pos1];
data[pos1] = temp;
public void swap(LongArray data, int pos0, int pos1) {
  // Read both entries first, then write each one into the other's slot.
  final long first = data.get(pos0);
  final long second = data.get(pos1);
  data.set(pos0, second);
  data.set(pos1, first);
}

@Override
public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
dst[dstPos] = src[srcPos];
public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) {
  // Copy a single entry from src[srcPos] to dst[dstPos].
  final long value = src.get(srcPos);
  dst.set(dstPos, value);
}

@Override
public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
System.arraycopy(src, srcPos, dst, dstPos, length);
public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) {
  // Copies `length` consecutive 8-byte entries from `src` (starting at index `srcPos`) into
  // `dst` (starting at index `dstPos`) with a single Platform.copyMemory call.
  //
  // Every index-to-byte conversion multiplies by 8L so the arithmetic is done in 64 bits.
  // With an int literal, `srcPos * 8` (and likewise `dstPos * 8` / `length * 8`) would wrap
  // around once the value reaches 2^28, yielding a wrong (possibly negative) offset or size.
  Platform.copyMemory(
    src.getBaseObject(),
    src.getBaseOffset() + srcPos * 8L,
    dst.getBaseObject(),
    dst.getBaseOffset() + dstPos * 8L,
    length * 8L
  );
}

@Override
public long[] allocate(int length) {
return new long[length];
public LongArray allocate(int length) {
  // This buffer is only used temporarily (and is usually small), so it is fine to allocate
  // it from the JVM heap.
  final long[] heapBuffer = new long[length];
  final MemoryBlock block = MemoryBlock.fromLongArray(heapBuffer);
  return new LongArray(block);
}

}
Loading