-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-4084] Reuse sort key in Sorter #2937
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
464ddce
cf94e8a
b00db4d
6ffbe66
5f0d530
8626356
7de2efd
78f2879
720f731
38ba50c
a72f53c
0b7b682
d73c3d0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,18 +20,25 @@ | |
| import java.util.Comparator; | ||
|
|
||
| /** | ||
| * A port of the Android Timsort class, which utilizes a "stable, adaptive, iterative mergesort." | ||
| * A port of the Android TimSort class, which utilizes a "stable, adaptive, iterative mergesort." | ||
| * See the method comment on sort() for more details. | ||
| * | ||
| * This has been kept in Java with the original style in order to match very closely with the | ||
| * Anroid source code, and thus be easy to verify correctness. | ||
| * Android source code, and thus be easy to verify correctness. The class is package private. We put | ||
| * a simple Scala wrapper {@link org.apache.spark.util.collection.Sorter}, which is available to | ||
| * package org.apache.spark. | ||
| * | ||
| * The purpose of the port is to generalize the interface to the sort to accept input data formats | ||
| * besides simple arrays where every element is sorted individually. For instance, the AppendOnlyMap | ||
| * uses this to sort an Array with alternating elements of the form [key, value, key, value]. | ||
| * This generalization comes with minimal overhead -- see SortDataFormat for more information. | ||
| * | ||
| * We allow key reuse to prevent creating many key objects -- see SortDataFormat. | ||
| * | ||
| * @see org.apache.spark.util.collection.SortDataFormat | ||
| * @see org.apache.spark.util.collection.Sorter | ||
| */ | ||
| class Sorter<K, Buffer> { | ||
| class TimSort<K, Buffer> { | ||
|
|
||
| /** | ||
| * This is the minimum sized sequence that will be merged. Shorter | ||
|
|
@@ -54,7 +61,7 @@ class Sorter<K, Buffer> { | |
|
|
||
| private final SortDataFormat<K, Buffer> s; | ||
|
|
||
| public Sorter(SortDataFormat<K, Buffer> sortDataFormat) { | ||
| public TimSort(SortDataFormat<K, Buffer> sortDataFormat) { | ||
| this.s = sortDataFormat; | ||
| } | ||
|
|
||
|
|
@@ -91,7 +98,7 @@ public Sorter(SortDataFormat<K, Buffer> sortDataFormat) { | |
| * | ||
| * @author Josh Bloch | ||
| */ | ||
| void sort(Buffer a, int lo, int hi, Comparator<? super K> c) { | ||
| public void sort(Buffer a, int lo, int hi, Comparator<? super K> c) { | ||
| assert c != null; | ||
|
|
||
| int nRemaining = hi - lo; | ||
|
|
@@ -162,10 +169,13 @@ private void binarySort(Buffer a, int lo, int hi, int start, Comparator<? super | |
| if (start == lo) | ||
| start++; | ||
|
|
||
| K key0 = s.newKey(); | ||
| K key1 = s.newKey(); | ||
|
|
||
| Buffer pivotStore = s.allocate(1); | ||
| for ( ; start < hi; start++) { | ||
| s.copyElement(a, start, pivotStore, 0); | ||
| K pivot = s.getKey(pivotStore, 0); | ||
| K pivot = s.getKey(pivotStore, 0, key0); | ||
|
|
||
| // Set left (and right) to the index where a[start] (pivot) belongs | ||
| int left = lo; | ||
|
|
@@ -178,7 +188,7 @@ private void binarySort(Buffer a, int lo, int hi, int start, Comparator<? super | |
| */ | ||
| while (left < right) { | ||
| int mid = (left + right) >>> 1; | ||
| if (c.compare(pivot, s.getKey(a, mid)) < 0) | ||
| if (c.compare(pivot, s.getKey(a, mid, key1)) < 0) | ||
| right = mid; | ||
| else | ||
| left = mid + 1; | ||
|
|
@@ -235,13 +245,16 @@ private int countRunAndMakeAscending(Buffer a, int lo, int hi, Comparator<? supe | |
| if (runHi == hi) | ||
| return 1; | ||
|
|
||
| K key0 = s.newKey(); | ||
| K key1 = s.newKey(); | ||
|
|
||
| // Find end of run, and reverse range if descending | ||
| if (c.compare(s.getKey(a, runHi++), s.getKey(a, lo)) < 0) { // Descending | ||
| while (runHi < hi && c.compare(s.getKey(a, runHi), s.getKey(a, runHi - 1)) < 0) | ||
| if (c.compare(s.getKey(a, runHi++, key0), s.getKey(a, lo, key1)) < 0) { // Descending | ||
| while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0) | ||
| runHi++; | ||
| reverseRange(a, lo, runHi); | ||
| } else { // Ascending | ||
| while (runHi < hi && c.compare(s.getKey(a, runHi), s.getKey(a, runHi - 1)) >= 0) | ||
| while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0) | ||
| runHi++; | ||
| } | ||
|
|
||
|
|
@@ -468,11 +481,13 @@ private void mergeAt(int i) { | |
| } | ||
| stackSize--; | ||
|
|
||
| K key0 = s.newKey(); | ||
|
|
||
| /* | ||
| * Find where the first element of run2 goes in run1. Prior elements | ||
| * in run1 can be ignored (because they're already in place). | ||
| */ | ||
| int k = gallopRight(s.getKey(a, base2), a, base1, len1, 0, c); | ||
| int k = gallopRight(s.getKey(a, base2, key0), a, base1, len1, 0, c); | ||
| assert k >= 0; | ||
| base1 += k; | ||
| len1 -= k; | ||
|
|
@@ -483,7 +498,7 @@ private void mergeAt(int i) { | |
| * Find where the last element of run1 goes in run2. Subsequent elements | ||
| * in run2 can be ignored (because they're already in place). | ||
| */ | ||
| len2 = gallopLeft(s.getKey(a, base1 + len1 - 1), a, base2, len2, len2 - 1, c); | ||
| len2 = gallopLeft(s.getKey(a, base1 + len1 - 1, key0), a, base2, len2, len2 - 1, c); | ||
| assert len2 >= 0; | ||
| if (len2 == 0) | ||
| return; | ||
|
|
@@ -517,10 +532,12 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< | |
| assert len > 0 && hint >= 0 && hint < len; | ||
| int lastOfs = 0; | ||
| int ofs = 1; | ||
| if (c.compare(key, s.getKey(a, base + hint)) > 0) { | ||
| K key0 = s.newKey(); | ||
|
|
||
| if (c.compare(key, s.getKey(a, base + hint, key0)) > 0) { | ||
| // Gallop right until a[base+hint+lastOfs] < key <= a[base+hint+ofs] | ||
| int maxOfs = len - hint; | ||
| while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs)) > 0) { | ||
| while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key0)) > 0) { | ||
| lastOfs = ofs; | ||
| ofs = (ofs << 1) + 1; | ||
| if (ofs <= 0) // int overflow | ||
|
|
@@ -535,7 +552,7 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< | |
| } else { // key <= a[base + hint] | ||
| // Gallop left until a[base+hint-ofs] < key <= a[base+hint-lastOfs] | ||
| final int maxOfs = hint + 1; | ||
| while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs)) <= 0) { | ||
| while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key0)) <= 0) { | ||
| lastOfs = ofs; | ||
| ofs = (ofs << 1) + 1; | ||
| if (ofs <= 0) // int overflow | ||
|
|
@@ -560,7 +577,7 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< | |
| while (lastOfs < ofs) { | ||
| int m = lastOfs + ((ofs - lastOfs) >>> 1); | ||
|
|
||
| if (c.compare(key, s.getKey(a, base + m)) > 0) | ||
| if (c.compare(key, s.getKey(a, base + m, key0)) > 0) | ||
| lastOfs = m + 1; // a[base + m] < key | ||
| else | ||
| ofs = m; // key <= a[base + m] | ||
|
|
@@ -587,10 +604,12 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator | |
|
|
||
| int ofs = 1; | ||
| int lastOfs = 0; | ||
| if (c.compare(key, s.getKey(a, base + hint)) < 0) { | ||
| K key1 = s.newKey(); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. key0
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is an input parameter called |
||
|
|
||
| if (c.compare(key, s.getKey(a, base + hint, key1)) < 0) { | ||
| // Gallop left until a[b+hint - ofs] <= key < a[b+hint - lastOfs] | ||
| int maxOfs = hint + 1; | ||
| while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs)) < 0) { | ||
| while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key1)) < 0) { | ||
| lastOfs = ofs; | ||
| ofs = (ofs << 1) + 1; | ||
| if (ofs <= 0) // int overflow | ||
|
|
@@ -606,7 +625,7 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator | |
| } else { // a[b + hint] <= key | ||
| // Gallop right until a[b+hint + lastOfs] <= key < a[b+hint + ofs] | ||
| int maxOfs = len - hint; | ||
| while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs)) >= 0) { | ||
| while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key1)) >= 0) { | ||
| lastOfs = ofs; | ||
| ofs = (ofs << 1) + 1; | ||
| if (ofs <= 0) // int overflow | ||
|
|
@@ -630,7 +649,7 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator | |
| while (lastOfs < ofs) { | ||
| int m = lastOfs + ((ofs - lastOfs) >>> 1); | ||
|
|
||
| if (c.compare(key, s.getKey(a, base + m)) < 0) | ||
| if (c.compare(key, s.getKey(a, base + m, key1)) < 0) | ||
| ofs = m; // key < a[b + m] | ||
| else | ||
| lastOfs = m + 1; // a[b + m] <= key | ||
|
|
@@ -679,6 +698,9 @@ private void mergeLo(int base1, int len1, int base2, int len2) { | |
| return; | ||
| } | ||
|
|
||
| K key0 = s.newKey(); | ||
| K key1 = s.newKey(); | ||
|
|
||
| Comparator<? super K> c = this.c; // Use local variable for performance | ||
| int minGallop = this.minGallop; // " " " " " | ||
| outer: | ||
|
|
@@ -692,7 +714,7 @@ private void mergeLo(int base1, int len1, int base2, int len2) { | |
| */ | ||
| do { | ||
| assert len1 > 1 && len2 > 0; | ||
| if (c.compare(s.getKey(a, cursor2), s.getKey(tmp, cursor1)) < 0) { | ||
| if (c.compare(s.getKey(a, cursor2, key0), s.getKey(tmp, cursor1, key1)) < 0) { | ||
| s.copyElement(a, cursor2++, a, dest++); | ||
| count2++; | ||
| count1 = 0; | ||
|
|
@@ -714,7 +736,7 @@ private void mergeLo(int base1, int len1, int base2, int len2) { | |
| */ | ||
| do { | ||
| assert len1 > 1 && len2 > 0; | ||
| count1 = gallopRight(s.getKey(a, cursor2), tmp, cursor1, len1, 0, c); | ||
| count1 = gallopRight(s.getKey(a, cursor2, key0), tmp, cursor1, len1, 0, c); | ||
| if (count1 != 0) { | ||
| s.copyRange(tmp, cursor1, a, dest, count1); | ||
| dest += count1; | ||
|
|
@@ -727,7 +749,7 @@ private void mergeLo(int base1, int len1, int base2, int len2) { | |
| if (--len2 == 0) | ||
| break outer; | ||
|
|
||
| count2 = gallopLeft(s.getKey(tmp, cursor1), a, cursor2, len2, 0, c); | ||
| count2 = gallopLeft(s.getKey(tmp, cursor1, key0), a, cursor2, len2, 0, c); | ||
| if (count2 != 0) { | ||
| s.copyRange(a, cursor2, a, dest, count2); | ||
| dest += count2; | ||
|
|
@@ -784,6 +806,9 @@ private void mergeHi(int base1, int len1, int base2, int len2) { | |
| int cursor2 = len2 - 1; // Indexes into tmp array | ||
| int dest = base2 + len2 - 1; // Indexes into a | ||
|
|
||
| K key0 = s.newKey(); | ||
| K key1 = s.newKey(); | ||
|
|
||
| // Move last element of first run and deal with degenerate cases | ||
| s.copyElement(a, cursor1--, a, dest--); | ||
| if (--len1 == 0) { | ||
|
|
@@ -811,7 +836,7 @@ private void mergeHi(int base1, int len1, int base2, int len2) { | |
| */ | ||
| do { | ||
| assert len1 > 0 && len2 > 1; | ||
| if (c.compare(s.getKey(tmp, cursor2), s.getKey(a, cursor1)) < 0) { | ||
| if (c.compare(s.getKey(tmp, cursor2, key0), s.getKey(a, cursor1, key1)) < 0) { | ||
| s.copyElement(a, cursor1--, a, dest--); | ||
| count1++; | ||
| count2 = 0; | ||
|
|
@@ -833,7 +858,7 @@ private void mergeHi(int base1, int len1, int base2, int len2) { | |
| */ | ||
| do { | ||
| assert len1 > 0 && len2 > 1; | ||
| count1 = len1 - gallopRight(s.getKey(tmp, cursor2), a, base1, len1, len1 - 1, c); | ||
| count1 = len1 - gallopRight(s.getKey(tmp, cursor2, key0), a, base1, len1, len1 - 1, c); | ||
| if (count1 != 0) { | ||
| dest -= count1; | ||
| cursor1 -= count1; | ||
|
|
@@ -846,7 +871,7 @@ private void mergeHi(int base1, int len1, int base2, int len2) { | |
| if (--len2 == 1) | ||
| break outer; | ||
|
|
||
| count2 = len2 - gallopLeft(s.getKey(a, cursor1), tmp, 0, len2, len2 - 1, c); | ||
| count2 = len2 - gallopLeft(s.getKey(a, cursor1, key0), tmp, 0, len2, len2 - 1, c); | ||
| if (count2 != 0) { | ||
| dest -= count2; | ||
| cursor2 -= count2; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1237,12 +1237,28 @@ private[spark] object Utils extends Logging { | |
| /** | ||
| * Timing method based on iterations that permit JVM JIT optimization. | ||
| * @param numIters number of iterations | ||
| * @param f function to be executed | ||
| * @param f function to be executed. If prepare is not None, the running time of each call to f | ||
| * must be an order of magnitude longer than one millisecond for accurate timing. | ||
| * @param prepare function to be executed before each call to f. Its running time doesn't count. | ||
| * @return the total time across all iterations (not couting preparation time) | ||
| */ | ||
| def timeIt(numIters: Int)(f: => Unit): Long = { | ||
| val start = System.currentTimeMillis | ||
| times(numIters)(f) | ||
| System.currentTimeMillis - start | ||
| def timeIt(numIters: Int)(f: => Unit, prepare: Option[() => Unit] = None): Long = { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe add that it returns the total time across all iterations (which is not the behavior I expected).
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done. |
||
| if (prepare.isEmpty) { | ||
| val start = System.currentTimeMillis | ||
| times(numIters)(f) | ||
| System.currentTimeMillis - start | ||
| } else { | ||
| var i = 0 | ||
| var sum = 0L | ||
| while (i < numIters) { | ||
| prepare.get.apply() | ||
| val start = System.currentTimeMillis | ||
| f | ||
| sum += System.currentTimeMillis - start | ||
| i += 1 | ||
| } | ||
| sum | ||
| } | ||
| } | ||
|
|
||
| /** | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,33 +27,51 @@ import scala.reflect.ClassTag | |
| * Example format: an array of numbers, where each element is also the key. | ||
| * See [[KVArraySortDataFormat]] for a more exciting format. | ||
| * | ||
| * This trait extends Any to ensure it is universal (and thus compiled to a Java interface). | ||
| * Note: Declaring and instantiating multiple subclasses of this class would prevent JIT inlining | ||
| * overridden methods and hence decrease the shuffle performance. | ||
| * | ||
| * @tparam K Type of the sort key of each element | ||
| * @tparam Buffer Internal data structure used by a particular format (e.g., Array[Int]). | ||
| */ | ||
| // TODO: Making Buffer a real trait would be a better abstraction, but adds some complexity. | ||
| private[spark] trait SortDataFormat[K, Buffer] extends Any { | ||
| private[spark] | ||
| abstract class SortDataFormat[K, Buffer] { | ||
|
|
||
| /** | ||
| * Creates a new mutable key for reuse. This should be implemented if you want to override | ||
| * [[getKey(Buffer, Int, K)]]. | ||
| */ | ||
| def newKey(): K = null.asInstanceOf[K] | ||
|
|
||
| /** Return the sort key for the element at the given index. */ | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe add a comment that this is ONLY invoked by the default |
||
| protected def getKey(data: Buffer, pos: Int): K | ||
|
|
||
| /** | ||
| * Returns the sort key for the element at the given index and reuse the input key if possible. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a note that the default implementation simply ignores the reuse parameter and invokes the other method. Also give the precondition that the "reused" key will have initially been constructed via newKey().
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done. |
||
| * The default implementation ignores the reuse parameter and invokes [[getKey(Buffer, Int]]. | ||
| * If you want to override this method, you must implement [[newKey()]]. | ||
| */ | ||
| def getKey(data: Buffer, pos: Int, reuse: K): K = { | ||
| getKey(data, pos) | ||
| } | ||
|
|
||
| /** Swap two elements. */ | ||
| protected def swap(data: Buffer, pos0: Int, pos1: Int): Unit | ||
| def swap(data: Buffer, pos0: Int, pos1: Int): Unit | ||
|
|
||
| /** Copy a single element from src(srcPos) to dst(dstPos). */ | ||
| protected def copyElement(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int): Unit | ||
| def copyElement(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int): Unit | ||
|
|
||
| /** | ||
| * Copy a range of elements starting at src(srcPos) to dst, starting at dstPos. | ||
| * Overlapping ranges are allowed. | ||
| */ | ||
| protected def copyRange(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int, length: Int): Unit | ||
| def copyRange(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int, length: Int): Unit | ||
|
|
||
| /** | ||
| * Allocates a Buffer that can hold up to 'length' elements. | ||
| * All elements of the buffer should be considered invalid until data is explicitly copied in. | ||
| */ | ||
| protected def allocate(length: Int): Buffer | ||
| def allocate(length: Int): Buffer | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -67,9 +85,9 @@ private[spark] trait SortDataFormat[K, Buffer] extends Any { | |
| private[spark] | ||
| class KVArraySortDataFormat[K, T <: AnyRef : ClassTag] extends SortDataFormat[K, Array[T]] { | ||
|
|
||
| override protected def getKey(data: Array[T], pos: Int): K = data(2 * pos).asInstanceOf[K] | ||
| override def getKey(data: Array[T], pos: Int): K = data(2 * pos).asInstanceOf[K] | ||
|
|
||
| override protected def swap(data: Array[T], pos0: Int, pos1: Int) { | ||
| override def swap(data: Array[T], pos0: Int, pos1: Int) { | ||
| val tmpKey = data(2 * pos0) | ||
| val tmpVal = data(2 * pos0 + 1) | ||
| data(2 * pos0) = data(2 * pos1) | ||
|
|
@@ -78,17 +96,16 @@ class KVArraySortDataFormat[K, T <: AnyRef : ClassTag] extends SortDataFormat[K, | |
| data(2 * pos1 + 1) = tmpVal | ||
| } | ||
|
|
||
| override protected def copyElement(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int) { | ||
| override def copyElement(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int) { | ||
| dst(2 * dstPos) = src(2 * srcPos) | ||
| dst(2 * dstPos + 1) = src(2 * srcPos + 1) | ||
| } | ||
|
|
||
| override protected def copyRange(src: Array[T], srcPos: Int, | ||
| dst: Array[T], dstPos: Int, length: Int) { | ||
| override def copyRange(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int, length: Int) { | ||
| System.arraycopy(src, 2 * srcPos, dst, 2 * dstPos, 2 * length) | ||
| } | ||
|
|
||
| override protected def allocate(length: Int): Array[T] = { | ||
| override def allocate(length: Int): Array[T] = { | ||
| new Array[T](2 * length) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
key0, key1 pls