Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,25 @@
import java.util.Comparator;

/**
* A port of the Android Timsort class, which utilizes a "stable, adaptive, iterative mergesort."
* A port of the Android TimSort class, which utilizes a "stable, adaptive, iterative mergesort."
* See the method comment on sort() for more details.
*
* This has been kept in Java with the original style in order to match very closely with the
* Anroid source code, and thus be easy to verify correctness.
* Android source code, and thus be easy to verify correctness. The class is package private. We
* provide a simple Scala wrapper, {@link org.apache.spark.util.collection.Sorter}, which is
* available to package org.apache.spark.
*
* The purpose of the port is to generalize the interface to the sort to accept input data formats
* besides simple arrays where every element is sorted individually. For instance, the AppendOnlyMap
* uses this to sort an Array with alternating elements of the form [key, value, key, value].
* This generalization comes with minimal overhead -- see SortDataFormat for more information.
*
* We allow key reuse to prevent creating many key objects -- see SortDataFormat.
*
* @see org.apache.spark.util.collection.SortDataFormat
* @see org.apache.spark.util.collection.Sorter
*/
class Sorter<K, Buffer> {
class TimSort<K, Buffer> {

/**
* This is the minimum sized sequence that will be merged. Shorter
Expand All @@ -54,7 +61,7 @@ class Sorter<K, Buffer> {

private final SortDataFormat<K, Buffer> s;

public Sorter(SortDataFormat<K, Buffer> sortDataFormat) {
/**
* Creates a sorter bound to the given data format.
*
* @param sortDataFormat the format stored in {@code s}, through which this class reads keys and
*                       copies or allocates buffer elements
*/
public TimSort(SortDataFormat<K, Buffer> sortDataFormat) {
this.s = sortDataFormat;
}

Expand Down Expand Up @@ -91,7 +98,7 @@ public Sorter(SortDataFormat<K, Buffer> sortDataFormat) {
*
* @author Josh Bloch
*/
void sort(Buffer a, int lo, int hi, Comparator<? super K> c) {
public void sort(Buffer a, int lo, int hi, Comparator<? super K> c) {
assert c != null;

int nRemaining = hi - lo;
Expand Down Expand Up @@ -162,10 +169,13 @@ private void binarySort(Buffer a, int lo, int hi, int start, Comparator<? super
if (start == lo)
start++;

K key0 = s.newKey();
K key1 = s.newKey();

Buffer pivotStore = s.allocate(1);
for ( ; start < hi; start++) {
s.copyElement(a, start, pivotStore, 0);
K pivot = s.getKey(pivotStore, 0);
K pivot = s.getKey(pivotStore, 0, key0);

// Set left (and right) to the index where a[start] (pivot) belongs
int left = lo;
Expand All @@ -178,7 +188,7 @@ private void binarySort(Buffer a, int lo, int hi, int start, Comparator<? super
*/
while (left < right) {
int mid = (left + right) >>> 1;
if (c.compare(pivot, s.getKey(a, mid)) < 0)
if (c.compare(pivot, s.getKey(a, mid, key1)) < 0)
right = mid;
else
left = mid + 1;
Expand Down Expand Up @@ -235,13 +245,16 @@ private int countRunAndMakeAscending(Buffer a, int lo, int hi, Comparator<? supe
if (runHi == hi)
return 1;

K key0 = s.newKey();
K key1 = s.newKey();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

key0, key1 pls


// Find end of run, and reverse range if descending
if (c.compare(s.getKey(a, runHi++), s.getKey(a, lo)) < 0) { // Descending
while (runHi < hi && c.compare(s.getKey(a, runHi), s.getKey(a, runHi - 1)) < 0)
if (c.compare(s.getKey(a, runHi++, key0), s.getKey(a, lo, key1)) < 0) { // Descending
while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0)
runHi++;
reverseRange(a, lo, runHi);
} else { // Ascending
while (runHi < hi && c.compare(s.getKey(a, runHi), s.getKey(a, runHi - 1)) >= 0)
while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0)
runHi++;
}

Expand Down Expand Up @@ -468,11 +481,13 @@ private void mergeAt(int i) {
}
stackSize--;

K key0 = s.newKey();

/*
* Find where the first element of run2 goes in run1. Prior elements
* in run1 can be ignored (because they're already in place).
*/
int k = gallopRight(s.getKey(a, base2), a, base1, len1, 0, c);
int k = gallopRight(s.getKey(a, base2, key0), a, base1, len1, 0, c);
assert k >= 0;
base1 += k;
len1 -= k;
Expand All @@ -483,7 +498,7 @@ private void mergeAt(int i) {
* Find where the last element of run1 goes in run2. Subsequent elements
* in run2 can be ignored (because they're already in place).
*/
len2 = gallopLeft(s.getKey(a, base1 + len1 - 1), a, base2, len2, len2 - 1, c);
len2 = gallopLeft(s.getKey(a, base1 + len1 - 1, key0), a, base2, len2, len2 - 1, c);
assert len2 >= 0;
if (len2 == 0)
return;
Expand Down Expand Up @@ -517,10 +532,12 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator<
assert len > 0 && hint >= 0 && hint < len;
int lastOfs = 0;
int ofs = 1;
if (c.compare(key, s.getKey(a, base + hint)) > 0) {
K key0 = s.newKey();

if (c.compare(key, s.getKey(a, base + hint, key0)) > 0) {
// Gallop right until a[base+hint+lastOfs] < key <= a[base+hint+ofs]
int maxOfs = len - hint;
while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs)) > 0) {
while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key0)) > 0) {
lastOfs = ofs;
ofs = (ofs << 1) + 1;
if (ofs <= 0) // int overflow
Expand All @@ -535,7 +552,7 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator<
} else { // key <= a[base + hint]
// Gallop left until a[base+hint-ofs] < key <= a[base+hint-lastOfs]
final int maxOfs = hint + 1;
while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs)) <= 0) {
while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key0)) <= 0) {
lastOfs = ofs;
ofs = (ofs << 1) + 1;
if (ofs <= 0) // int overflow
Expand All @@ -560,7 +577,7 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator<
while (lastOfs < ofs) {
int m = lastOfs + ((ofs - lastOfs) >>> 1);

if (c.compare(key, s.getKey(a, base + m)) > 0)
if (c.compare(key, s.getKey(a, base + m, key0)) > 0)
lastOfs = m + 1; // a[base + m] < key
else
ofs = m; // key <= a[base + m]
Expand All @@ -587,10 +604,12 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator

int ofs = 1;
int lastOfs = 0;
if (c.compare(key, s.getKey(a, base + hint)) < 0) {
K key1 = s.newKey();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

key0

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is an input parameter called key. So this one became key1.


if (c.compare(key, s.getKey(a, base + hint, key1)) < 0) {
// Gallop left until a[b+hint - ofs] <= key < a[b+hint - lastOfs]
int maxOfs = hint + 1;
while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs)) < 0) {
while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key1)) < 0) {
lastOfs = ofs;
ofs = (ofs << 1) + 1;
if (ofs <= 0) // int overflow
Expand All @@ -606,7 +625,7 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator
} else { // a[b + hint] <= key
// Gallop right until a[b+hint + lastOfs] <= key < a[b+hint + ofs]
int maxOfs = len - hint;
while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs)) >= 0) {
while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key1)) >= 0) {
lastOfs = ofs;
ofs = (ofs << 1) + 1;
if (ofs <= 0) // int overflow
Expand All @@ -630,7 +649,7 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator
while (lastOfs < ofs) {
int m = lastOfs + ((ofs - lastOfs) >>> 1);

if (c.compare(key, s.getKey(a, base + m)) < 0)
if (c.compare(key, s.getKey(a, base + m, key1)) < 0)
ofs = m; // key < a[b + m]
else
lastOfs = m + 1; // a[b + m] <= key
Expand Down Expand Up @@ -679,6 +698,9 @@ private void mergeLo(int base1, int len1, int base2, int len2) {
return;
}

K key0 = s.newKey();
K key1 = s.newKey();

Comparator<? super K> c = this.c; // Use local variable for performance
int minGallop = this.minGallop; // " " " " "
outer:
Expand All @@ -692,7 +714,7 @@ private void mergeLo(int base1, int len1, int base2, int len2) {
*/
do {
assert len1 > 1 && len2 > 0;
if (c.compare(s.getKey(a, cursor2), s.getKey(tmp, cursor1)) < 0) {
if (c.compare(s.getKey(a, cursor2, key0), s.getKey(tmp, cursor1, key1)) < 0) {
s.copyElement(a, cursor2++, a, dest++);
count2++;
count1 = 0;
Expand All @@ -714,7 +736,7 @@ private void mergeLo(int base1, int len1, int base2, int len2) {
*/
do {
assert len1 > 1 && len2 > 0;
count1 = gallopRight(s.getKey(a, cursor2), tmp, cursor1, len1, 0, c);
count1 = gallopRight(s.getKey(a, cursor2, key0), tmp, cursor1, len1, 0, c);
if (count1 != 0) {
s.copyRange(tmp, cursor1, a, dest, count1);
dest += count1;
Expand All @@ -727,7 +749,7 @@ private void mergeLo(int base1, int len1, int base2, int len2) {
if (--len2 == 0)
break outer;

count2 = gallopLeft(s.getKey(tmp, cursor1), a, cursor2, len2, 0, c);
count2 = gallopLeft(s.getKey(tmp, cursor1, key0), a, cursor2, len2, 0, c);
if (count2 != 0) {
s.copyRange(a, cursor2, a, dest, count2);
dest += count2;
Expand Down Expand Up @@ -784,6 +806,9 @@ private void mergeHi(int base1, int len1, int base2, int len2) {
int cursor2 = len2 - 1; // Indexes into tmp array
int dest = base2 + len2 - 1; // Indexes into a

K key0 = s.newKey();
K key1 = s.newKey();

// Move last element of first run and deal with degenerate cases
s.copyElement(a, cursor1--, a, dest--);
if (--len1 == 0) {
Expand Down Expand Up @@ -811,7 +836,7 @@ private void mergeHi(int base1, int len1, int base2, int len2) {
*/
do {
assert len1 > 0 && len2 > 1;
if (c.compare(s.getKey(tmp, cursor2), s.getKey(a, cursor1)) < 0) {
if (c.compare(s.getKey(tmp, cursor2, key0), s.getKey(a, cursor1, key1)) < 0) {
s.copyElement(a, cursor1--, a, dest--);
count1++;
count2 = 0;
Expand All @@ -833,7 +858,7 @@ private void mergeHi(int base1, int len1, int base2, int len2) {
*/
do {
assert len1 > 0 && len2 > 1;
count1 = len1 - gallopRight(s.getKey(tmp, cursor2), a, base1, len1, len1 - 1, c);
count1 = len1 - gallopRight(s.getKey(tmp, cursor2, key0), a, base1, len1, len1 - 1, c);
if (count1 != 0) {
dest -= count1;
cursor1 -= count1;
Expand All @@ -846,7 +871,7 @@ private void mergeHi(int base1, int len1, int base2, int len2) {
if (--len2 == 1)
break outer;

count2 = len2 - gallopLeft(s.getKey(a, cursor1), tmp, 0, len2, len2 - 1, c);
count2 = len2 - gallopLeft(s.getKey(a, cursor1, key0), tmp, 0, len2, len2 - 1, c);
if (count2 != 0) {
dest -= count2;
cursor2 -= count2;
Expand Down
26 changes: 21 additions & 5 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1237,12 +1237,28 @@ private[spark] object Utils extends Logging {
/**
* Timing method based on iterations that permit JVM JIT optimization.
* @param numIters number of iterations
* @param f function to be executed
* @param f function to be executed. If prepare is not None, the running time of each call to f
* must be an order of magnitude longer than one millisecond for accurate timing.
* @param prepare function to be executed before each call to f. Its running time doesn't count.
* @return the total time across all iterations (not counting preparation time)
*/
def timeIt(numIters: Int)(f: => Unit): Long = {
val start = System.currentTimeMillis
times(numIters)(f)
System.currentTimeMillis - start
def timeIt(numIters: Int)(f: => Unit, prepare: Option[() => Unit] = None): Long = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add that it returns the total time across all iterations (which is not the behavior I expected).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

if (prepare.isEmpty) {
val start = System.currentTimeMillis
times(numIters)(f)
System.currentTimeMillis - start
} else {
var i = 0
var sum = 0L
while (i < numIters) {
prepare.get.apply()
val start = System.currentTimeMillis
f
sum += System.currentTimeMillis - start
i += 1
}
sum
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,33 +27,51 @@ import scala.reflect.ClassTag
* Example format: an array of numbers, where each element is also the key.
* See [[KVArraySortDataFormat]] for a more exciting format.
*
* This trait extends Any to ensure it is universal (and thus compiled to a Java interface).
* Note: Declaring and instantiating multiple subclasses of this class would prevent the JIT from
* inlining overridden methods and hence decrease shuffle performance.
*
* @tparam K Type of the sort key of each element
* @tparam Buffer Internal data structure used by a particular format (e.g., Array[Int]).
*/
// TODO: Making Buffer a real trait would be a better abstraction, but adds some complexity.
private[spark] trait SortDataFormat[K, Buffer] extends Any {
private[spark]
abstract class SortDataFormat[K, Buffer] {

/**
* Creates a new mutable key for reuse. The default implementation returns null (cast to K),
* which is sufficient as long as [[getKey(Buffer, Int, K)]] keeps ignoring its reuse argument.
* This should be implemented if you want to override [[getKey(Buffer, Int, K)]].
*/
def newKey(): K = null.asInstanceOf[K]

/** Return the sort key for the element at the given index. */
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a comment that this is ONLY invoked by the default getKey(data: Buffer, pos: Int, reuse: K) method. That is, you should not call this from outside.

protected def getKey(data: Buffer, pos: Int): K

/**
* Returns the sort key for the element at the given index and reuses the input key if possible.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a note that the default implementation simply ignores the reuse parameter and invokes the other method. Also give the precondition that the "reused" key will have initially been constructed via newKey().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

* The default implementation ignores the reuse parameter and invokes [[getKey(Buffer, Int)]].
* If you want to override this method, you must implement [[newKey()]].
*/
def getKey(data: Buffer, pos: Int, reuse: K): K =
  getKey(data, pos) // `reuse` is deliberately ignored by this default; no key object is recycled

/** Swap two elements. */
protected def swap(data: Buffer, pos0: Int, pos1: Int): Unit
def swap(data: Buffer, pos0: Int, pos1: Int): Unit

/** Copy a single element from src(srcPos) to dst(dstPos). */
protected def copyElement(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int): Unit
def copyElement(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int): Unit

/**
* Copy a range of elements starting at src(srcPos) to dst, starting at dstPos.
* Overlapping ranges are allowed.
*/
protected def copyRange(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int, length: Int): Unit
def copyRange(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int, length: Int): Unit

/**
* Allocates a Buffer that can hold up to 'length' elements.
* All elements of the buffer should be considered invalid until data is explicitly copied in.
*/
protected def allocate(length: Int): Buffer
def allocate(length: Int): Buffer
}

/**
Expand All @@ -67,9 +85,9 @@ private[spark] trait SortDataFormat[K, Buffer] extends Any {
private[spark]
class KVArraySortDataFormat[K, T <: AnyRef : ClassTag] extends SortDataFormat[K, Array[T]] {

override protected def getKey(data: Array[T], pos: Int): K = data(2 * pos).asInstanceOf[K]
// The key of element `pos` is read from array slot 2 * pos and cast (unchecked) to K.
override def getKey(data: Array[T], pos: Int): K = data(2 * pos).asInstanceOf[K]

override protected def swap(data: Array[T], pos0: Int, pos1: Int) {
override def swap(data: Array[T], pos0: Int, pos1: Int) {
val tmpKey = data(2 * pos0)
val tmpVal = data(2 * pos0 + 1)
data(2 * pos0) = data(2 * pos1)
Expand All @@ -78,17 +96,16 @@ class KVArraySortDataFormat[K, T <: AnyRef : ClassTag] extends SortDataFormat[K,
data(2 * pos1 + 1) = tmpVal
}

override protected def copyElement(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int) {
/**
 * Copies one logical element from `src` to `dst`. An element occupies two adjacent array slots
 * (2 * pos and 2 * pos + 1), so both slots are copied.
 *
 * @param src    source array
 * @param srcPos element index (not array index) to read from
 * @param dst    destination array
 * @param dstPos element index (not array index) to write to
 */
override def copyElement(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int): Unit = {
  // Explicit `: Unit =` replaces the deprecated procedure syntax and mirrors the abstract
  // declaration in SortDataFormat.
  dst(2 * dstPos) = src(2 * srcPos)
  dst(2 * dstPos + 1) = src(2 * srcPos + 1)
}

override protected def copyRange(src: Array[T], srcPos: Int,
dst: Array[T], dstPos: Int, length: Int) {
/**
 * Copies `length` logical elements from `src` (starting at element `srcPos`) to `dst` (starting
 * at element `dstPos`). Element positions and the length are doubled to obtain array offsets
 * because every element takes two array slots.
 *
 * @param src    source array
 * @param srcPos first element index to read
 * @param dst    destination array
 * @param dstPos first element index to write
 * @param length number of elements (not array slots) to copy
 */
override def copyRange(
    src: Array[T],
    srcPos: Int,
    dst: Array[T],
    dstPos: Int,
    length: Int): Unit = {
  // Explicit `: Unit =` replaces the deprecated procedure syntax and mirrors the abstract
  // declaration in SortDataFormat.
  System.arraycopy(src, 2 * srcPos, dst, 2 * dstPos, 2 * length)
}

override protected def allocate(length: Int): Array[T] = {
// Two array slots are reserved per element, hence the doubled array length.
override def allocate(length: Int): Array[T] = new Array[T](2 * length)
}
Loading