Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,6 @@ public long size() {
* Creates a memory block pointing to the memory used by the long array.
*/
public static MemoryBlock fromLongArray(final long[] array) {
return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8);
return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8L);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.util.collection.Sorter;
import org.apache.spark.util.collection.unsafe.sort.RadixSort;

final class ShuffleInMemorySorter {

private final Sorter<PackedRecordPointer, LongArray> sorter;
private static final class SortComparator implements Comparator<PackedRecordPointer> {
@Override
public int compare(PackedRecordPointer left, PackedRecordPointer right) {
Expand All @@ -44,6 +44,9 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) {
* An array of record pointers and partition ids that have been encoded by
* {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating
* records.
*
* Only part of the array will be used to store the pointers, the rest part is preserved as
* temporary buffer for sorting.
*/
private LongArray array;

Expand All @@ -54,14 +57,14 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) {
private final boolean useRadixSort;

/**
* Set to 2x for radix sort to reserve extra memory for sorting, otherwise 1x.
* The position in the pointer array where new records can be inserted.
*/
private final int memoryAllocationFactor;
private int pos = 0;

/**
* The position in the pointer array where new records can be inserted.
* How many records could be inserted, because part of the array should be left for sorting.
*/
private int pos = 0;
private int usableCapacity = 0;

private int initialSize;

Expand All @@ -70,9 +73,14 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) {
assert (initialSize > 0);
this.initialSize = initialSize;
this.useRadixSort = useRadixSort;
this.memoryAllocationFactor = useRadixSort ? 2 : 1;
this.array = consumer.allocateArray(initialSize);
this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE);
this.usableCapacity = getUsableCapacity();
}

private int getUsableCapacity() {
// Radix sort requires same amount of used memory as buffer, Tim sort requires
// half of the used memory as buffer.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we also update the comments above while instantiating array to better indicate the contents of the array (and talk about the additional buffer that's needed for sorting)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

by the way, out of curiosity, what's the worst case scenario in TimSort that requires 0.5x more buffer space?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sameeragarwal I think you already had a test case for the worst case (there are two ordered part in the array, the shortest part will be copied into a temporary buffer, os it need 0.5 buffer space), it's doced in TimSort.

return (int) (array.size() / (useRadixSort ? 2 : 1.5));
}

public void free() {
Expand All @@ -89,7 +97,8 @@ public int numRecords() {
public void reset() {
if (consumer != null) {
consumer.freeArray(array);
this.array = consumer.allocateArray(initialSize);
array = consumer.allocateArray(initialSize);
usableCapacity = getUsableCapacity();
}
pos = 0;
}
Expand All @@ -101,14 +110,15 @@ public void expandPointerArray(LongArray newArray) {
array.getBaseOffset(),
newArray.getBaseObject(),
newArray.getBaseOffset(),
array.size() * (8 / memoryAllocationFactor)
pos * 8L
);
consumer.freeArray(array);
array = newArray;
usableCapacity = getUsableCapacity();
}

public boolean hasSpaceForAnotherRecord() {
return pos < array.size() / memoryAllocationFactor;
return pos < usableCapacity;
}

public long getMemoryUsage() {
Expand Down Expand Up @@ -170,6 +180,14 @@ public ShuffleSorterIterator getSortedIterator() {
PackedRecordPointer.PARTITION_ID_START_BYTE_INDEX,
PackedRecordPointer.PARTITION_ID_END_BYTE_INDEX, false, false);
} else {
MemoryBlock unused = new MemoryBlock(
array.getBaseObject(),
array.getBaseOffset() + pos * 8L,
(array.size() - pos) * 8L);
LongArray buffer = new LongArray(unused);
Sorter<PackedRecordPointer, LongArray> sorter =
new Sorter<>(new ShuffleSortDataFormat(buffer));

sorter.sort(array, 0, pos, SORT_COMPARATOR);
}
return new ShuffleSorterIterator(pos, array, offset);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.util.collection.SortDataFormat;

final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, LongArray> {

public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat();
private final LongArray buffer;

private ShuffleSortDataFormat() { }
ShuffleSortDataFormat(LongArray buffer) {
this.buffer = buffer;
}

@Override
public PackedRecordPointer getKey(LongArray data, int pos) {
Expand Down Expand Up @@ -70,8 +71,8 @@ public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int

@Override
public LongArray allocate(int length) {
// This buffer is used temporary (usually small), so it's fine to allocated from JVM heap.
return new LongArray(MemoryBlock.fromLongArray(new long[length]));
assert (length <= buffer.size()) :
"the buffer is smaller than required: " + buffer.size() + " < " + length;
return buffer;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,8 @@ public BytesToBytesMap(
SparkEnv.get() != null ? SparkEnv.get().blockManager() : null,
SparkEnv.get() != null ? SparkEnv.get().serializerManager() : null,
initialCapacity,
0.70,
// In order to re-use the longArray for sorting, the load factor cannot be larger than 0.5.
0.5,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we make this a constant and document why it was chosen (to enable radix sort)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

given that this may cause BytesToBytesMap to spill to disk more, do we have a sense of the performance implications of this change? Also, can we make this configurable (so that if spark.sql.sort.enableRadixSort = false, users can explicitly set the loadFactor to 0.7)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As analized in the comment, this only have about 10% (or less) difference in practice, I'd like not to have this complexity.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change also us to use radix sort when switching happens, so it's not always regression of performance. It's hard to measure the impact for different workloads, I'd like to prefer for robustness instead of small performance improvements for some workload.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, that sounds fair

pageSizeBytes,
enablePerfMetrics);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.util.collection.Sorter;

/**
Expand Down Expand Up @@ -69,8 +70,6 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) {
private final MemoryConsumer consumer;
private final TaskMemoryManager memoryManager;
@Nullable
private final Sorter<RecordPointerAndKeyPrefix, LongArray> sorter;
@Nullable
private final Comparator<RecordPointerAndKeyPrefix> sortComparator;

/**
Expand All @@ -79,14 +78,12 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) {
@Nullable
private final PrefixComparators.RadixSortSupport radixSortSupport;

/**
* Set to 2x for radix sort to reserve extra memory for sorting, otherwise 1x.
*/
private final int memoryAllocationFactor;

/**
* Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at
* index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
*
* Only part of the array will be used to store the pointers, the rest part is preserved as
* temporary buffer for sorting.
*/
private LongArray array;

Expand All @@ -95,6 +92,11 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) {
*/
private int pos = 0;

/**
* How many records could be inserted, because part of the array should be left for sorting.
*/
private int usableCapacity = 0;

private long initialSize;

private long totalSortTimeNanos = 0L;
Expand All @@ -121,20 +123,24 @@ public UnsafeInMemorySorter(
this.memoryManager = memoryManager;
this.initialSize = array.size();
if (recordComparator != null) {
this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
if (canUseRadixSort && prefixComparator instanceof PrefixComparators.RadixSortSupport) {
this.radixSortSupport = (PrefixComparators.RadixSortSupport)prefixComparator;
} else {
this.radixSortSupport = null;
}
} else {
this.sorter = null;
this.sortComparator = null;
this.radixSortSupport = null;
}
this.memoryAllocationFactor = this.radixSortSupport != null ? 2 : 1;
this.array = array;
this.usableCapacity = getUsableCapacity();
}

private int getUsableCapacity() {
// Radix sort requires same amount of used memory as buffer, Tim sort requires
// half of the used memory as buffer.
return (int) (array.size() / (radixSortSupport != null ? 2 : 1.5));
}

/**
Expand All @@ -150,7 +156,8 @@ public void free() {
public void reset() {
if (consumer != null) {
consumer.freeArray(array);
this.array = consumer.allocateArray(initialSize);
array = consumer.allocateArray(initialSize);
usableCapacity = getUsableCapacity();
}
pos = 0;
}
Expand All @@ -174,7 +181,7 @@ public long getMemoryUsage() {
}

public boolean hasSpaceForAnotherRecord() {
return pos + 1 < (array.size() / memoryAllocationFactor);
return pos + 1 < usableCapacity;
}

public void expandPointerArray(LongArray newArray) {
Expand All @@ -186,9 +193,10 @@ public void expandPointerArray(LongArray newArray) {
array.getBaseOffset(),
newArray.getBaseObject(),
newArray.getBaseOffset(),
array.size() * (8 / memoryAllocationFactor));
pos * 8L);
consumer.freeArray(array);
array = newArray;
usableCapacity = getUsableCapacity();
}

/**
Expand Down Expand Up @@ -275,13 +283,20 @@ public void loadNext() {
public SortedIterator getSortedIterator() {
int offset = 0;
long start = System.nanoTime();
if (sorter != null) {
if (sortComparator != null) {
if (this.radixSortSupport != null) {
// TODO(ekl) we should handle NULL values before radix sort for efficiency, since they
// force a full-width sort (and we cannot radix-sort nullable long fields at all).
offset = RadixSort.sortKeyPrefixArray(
array, pos / 2, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned());
} else {
MemoryBlock unused = new MemoryBlock(
array.getBaseObject(),
array.getBaseOffset() + pos * 8L,
(array.size() - pos) * 8L);
LongArray buffer = new LongArray(unused);
Sorter<RecordPointerAndKeyPrefix, LongArray> sorter =
new Sorter<>(new UnsafeSortDataFormat(buffer));
sorter.sort(array, 0, pos / 2, sortComparator);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.util.collection.SortDataFormat;

/**
Expand All @@ -32,9 +31,11 @@
public final class UnsafeSortDataFormat
extends SortDataFormat<RecordPointerAndKeyPrefix, LongArray> {

public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat();
private final LongArray buffer;

private UnsafeSortDataFormat() { }
public UnsafeSortDataFormat(LongArray buffer) {
this.buffer = buffer;
}

@Override
public RecordPointerAndKeyPrefix getKey(LongArray data, int pos) {
Expand Down Expand Up @@ -83,9 +84,9 @@ public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int

@Override
public LongArray allocate(int length) {
assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large";
// This is used as temporary buffer, it's fine to allocate from JVM heap.
return new LongArray(MemoryBlock.fromLongArray(new long[length * 2]));
assert (length * 2 <= buffer.size()) :
"the buffer is smaller than required: " + buffer.size() + " < " + (length * 2);
return buffer;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@
import java.nio.ByteBuffer;
import java.util.*;

import scala.*;
import scala.Option;
import scala.Product2;
import scala.Tuple2;
import scala.Tuple2$;
import scala.collection.Iterator;
import scala.runtime.AbstractFunction1;

import com.google.common.collect.Iterators;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Iterators;
import com.google.common.io.ByteStreams;
import org.junit.After;
import org.junit.Before;
Expand All @@ -35,29 +38,33 @@
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.lessThan;
import static org.junit.Assert.*;
import static org.mockito.Answers.RETURNS_SMART_NULLS;
import static org.mockito.Mockito.*;

import org.apache.spark.*;
import org.apache.spark.HashPartitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.io.LZ4CompressionCodec;
import org.apache.spark.io.LZFCompressionCodec;
import org.apache.spark.io.SnappyCompressionCodec;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.serializer.*;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.serializer.*;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.storage.*;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.util.Utils;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.lessThan;
import static org.junit.Assert.*;
import static org.mockito.Answers.RETURNS_SMART_NULLS;
import static org.mockito.Mockito.*;

public class UnsafeShuffleWriterSuite {

static final int NUM_PARTITITONS = 4;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ public void spillInIterator() throws IOException {
@Test
public void multipleValuesForSameKey() {
BytesToBytesMap map =
new BytesToBytesMap(taskMemoryManager, blockManager, serializerManager, 1, 0.75, 1024, false);
new BytesToBytesMap(taskMemoryManager, blockManager, serializerManager, 1, 0.5, 1024, false);
try {
int i;
for (i = 0; i < 1024; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,10 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
// that can trigger copyRange() in TimSort.mergeLo() or TimSort.mergeHi()
val ref = Array.tabulate[Long](size) { i => if (i < size / 2) size / 2 + i else i }
val buf = new LongArray(MemoryBlock.fromLongArray(ref))
val tmp = new Array[Long](size/2)
val tmpBuf = new LongArray(MemoryBlock.fromLongArray(tmp))

new Sorter(UnsafeSortDataFormat.INSTANCE).sort(
new Sorter(new UnsafeSortDataFormat(tmpBuf)).sort(
buf, 0, size, new Comparator[RecordPointerAndKeyPrefix] {
override def compare(
r1: RecordPointerAndKeyPrefix,
Expand Down
Loading