diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
index f1f4b4324edf..3cf7f7e6e8b9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -17,9 +17,13 @@
 
 package org.apache.spark.rdd
 
-import scala.reflect.ClassTag
+import java.util.Comparator
 
-import org.apache.spark.{Logging, RangePartitioner}
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect._
+
+import org.apache.spark.{Logging, RangePartitioner, SparkEnv}
+import org.apache.spark.util.collection.ExternalAppendOnlyMap
 
 /**
  * Extra functions available on RDDs of (key, value) pairs where the key is sortable through
@@ -41,14 +45,16 @@ import org.apache.spark.{Logging, RangePartitioner}
  * rdd.sortByKey()
  * }}}
  */
-class OrderedRDDFunctions[K : Ordering : ClassTag,
+
+class OrderedRDDFunctions[K: Ordering: ClassTag,
                           V: ClassTag,
-                          P <: Product2[K, V] : ClassTag](
+                          P <: Product2[K, V]: ClassTag](
     self: RDD[P])
   extends Logging with Serializable {
 
   private val ordering = implicitly[Ordering[K]]
+  private type SortCombiner = ArrayBuffer[P]
 
   /**
    * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
    * `collect` or `save` on the resulting RDD will return or output an ordered list of records
@@ -56,15 +62,56 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
    * order of the keys).
    */
   def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = {
+    val externalSorting = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true)
     val part = new RangePartitioner(numPartitions, self, ascending)
     val shuffled = new ShuffledRDD[K, V, V, P](self, part).setKeyOrdering(ordering)
-    shuffled.mapPartitions(iter => {
-      val buf = iter.toArray
+    if (!externalSorting) {
+      shuffled.mapPartitions(iter => {
+        val buf = iter.toArray
+        if (ascending) {
+          buf.sortWith((x, y) => ordering.lt(x._1, y._1)).iterator
+        } else {
+          buf.sortWith((x, y) => ordering.gt(x._1, y._1)).iterator
+        }
+      }, preservesPartitioning = true)
+    } else {
+      shuffled.mapPartitions(iter => {
+        val map = createExternalMap(ascending)
+        while (iter.hasNext) {
+          val kv = iter.next()
+          map.insert(kv._1, kv.asInstanceOf[P])
+        }
+        map.sortIterator
+      }).flatMap(elem => elem._2)
+    }
+  }
+
+  private def createExternalMap(ascending: Boolean): ExternalAppendOnlyMap[K, P, SortCombiner] = {
+    val createCombiner: (P => SortCombiner) = value => {
+      val newCombiner = new SortCombiner
+      newCombiner += value
+      newCombiner
+    }
+    val mergeValue: (SortCombiner, P) => SortCombiner = (combiner, value) => {
+      combiner += value
+      combiner
+    }
+    val mergeCombiners: (SortCombiner, SortCombiner) => SortCombiner =
+      (combiner1, combiner2) => {combiner1 ++= combiner2}
+
+    new ExternalAppendOnlyMap[K, P, SortCombiner](
+      createCombiner, mergeValue, mergeCombiners,
+      customizedComparator = new KeyComparator[K, SortCombiner](ascending, ordering))
+  }
+
+  private class KeyComparator[K, SortCombiner](ascending: Boolean, ord: Ordering[K])
+    extends Comparator[(K, SortCombiner)] {
+    def compare (kc1: (K, SortCombiner), kc2: (K, SortCombiner)): Int = {
       if (ascending) {
-        buf.sortWith((x, y) => ordering.lt(x._1, y._1)).iterator
+        ord.compare(kc1._1, kc2._1)
       } else {
-        buf.sortWith((x, y) => ordering.gt(x._1, y._1)).iterator
+        ord.compare(kc2._1, kc1._1)
       }
-    }, preservesPartitioning = true)
+    }
   }
 }
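For reviewers, a minimal, hypothetical driver sketch of how the two `sortByKey` code paths above are selected. The object name, master URL, and sample data are illustrative only; `spark.shuffle.spill` is the configuration key read by the patch, defaulting to `true`, which routes `sortByKey` through the new `ExternalAppendOnlyMap`-backed branch, while `false` keeps the original in-memory array sort.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._

object SortByKeySpillSketch {
  def main(args: Array[String]): Unit = {
    // "true" (the default) exercises the external-sorting branch added above;
    // "false" falls back to the in-memory Array-based sort.
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("sortByKey spill sketch")
      .set("spark.shuffle.spill", "true")
    val sc = new SparkContext(conf)

    val pairs = sc.parallelize(Seq(("b", 2), ("c", 3), ("a", 1)), numSlices = 2)
    // Each output partition holds a sorted key range, so collect() returns the
    // records in ascending key order: (a,1), (b,2), (c,3).
    println(pairs.sortByKey(ascending = true).collect().mkString(", "))

    sc.stop()
  }
}
```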
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 292d0962f4fd..580bb5206345 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -63,7 +63,9 @@ class ExternalAppendOnlyMap[K, V, C](
     mergeValue: (C, V) => C,
     mergeCombiners: (C, C) => C,
     serializer: Serializer = SparkEnv.get.serializer,
-    blockManager: BlockManager = SparkEnv.get.blockManager)
+    blockManager: BlockManager = SparkEnv.get.blockManager,
+    customizedComparator: Comparator[(K, C)] = null
+  )
   extends Iterable[(K, C)] with Serializable with Logging {
 
   import ExternalAppendOnlyMap._
@@ -105,7 +107,8 @@ class ExternalAppendOnlyMap[K, V, C](
   private var _diskBytesSpilled = 0L
 
   private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
-  private val comparator = new KCComparator[K, C]
+  private val comparator =
+    if (customizedComparator == null) new KCComparator[K, C] else customizedComparator
   private val ser = serializer.newInstance()
 
   /**
@@ -220,6 +223,14 @@ class ExternalAppendOnlyMap[K, V, C](
     }
   }
 
+  def sortIterator: Iterator[(K, C)] = {
+    if (spilledMaps.isEmpty) {
+      currentMap.destructiveSortedIterator(comparator)
+    } else {
+      new ExternalIterator()
+    }
+  }
+
   /**
    * An iterator that sort-merges (K, C) pairs from the in-memory map and the spilled maps
    */
@@ -250,12 +261,14 @@
     private def getMorePairs(it: BufferedIterator[(K, C)]): ArrayBuffer[(K, C)] = {
       val kcPairs = new ArrayBuffer[(K, C)]
       if (it.hasNext) {
-        var kc = it.next()
+        val kc = it.next()
         kcPairs += kc
-        val minHash = getKeyHashCode(kc)
-        while (it.hasNext && it.head._1.hashCode() == minHash) {
-          kc = it.next()
-          kcPairs += kc
+        while (it.hasNext && comparator.compare(kc, it.head) == 0) {
+          var kc1 = it.next()
+          kcPairs += kc1
+          //if (comparator.compare(kc, it.head) != 0) {
+          //  return kcPairs
+          //}
         }
       }
       kcPairs
@@ -293,15 +306,14 @@ class ExternalAppendOnlyMap[K, V, C](
       }
       // Select a key from the StreamBuffer that holds the lowest key hash
       val minBuffer = mergeHeap.dequeue()
-      val (minPairs, minHash) = (minBuffer.pairs, minBuffer.minKeyHash)
+      val minPairs = minBuffer.pairs
       val minPair = minPairs.remove(0)
       var (minKey, minCombiner) = minPair
-      assert(getKeyHashCode(minPair) == minHash)
 
       // For all other streams that may have this key (i.e. have the same minimum key hash),
       // merge in the corresponding value (if any) from that stream
       val mergedBuffers = ArrayBuffer[StreamBuffer](minBuffer)
-      while (mergeHeap.length > 0 && mergeHeap.head.minKeyHash == minHash) {
+      while (mergeHeap.length > 0 && comparator.compare(mergeHeap.head.pairs.head, minPair) == 0) {
         val newBuffer = mergeHeap.dequeue()
         minCombiner = mergeIfKeyExists(minKey, minCombiner, newBuffer)
         mergedBuffers += newBuffer
@@ -334,15 +346,9 @@ class ExternalAppendOnlyMap[K, V, C](
 
       def isEmpty = pairs.length == 0
 
-      // Invalid if there are no more pairs in this stream
-      def minKeyHash: Int = {
-        assert(pairs.length > 0)
-        getKeyHashCode(pairs.head)
-      }
-
       override def compareTo(other: StreamBuffer): Int = {
         // descending order because mutable.PriorityQueue dequeues the max, not the min
-        if (other.minKeyHash < minKeyHash) -1 else if (other.minKeyHash == minKeyHash) 0 else 1
+        comparator.compare(other.pairs.head, pairs.head)
       }
     }
   }
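The merge logic above now compares whole keys through the injected comparator rather than key hash codes, because a hash-based comparison cannot define the total order that `sortIterator` has to produce. Below is a standalone sketch of such a key-based `Comparator[(K, C)]`, mirroring the `KeyComparator` added in `OrderedRDDFunctions`; the class and object names are illustrative and not part of the patch.

```scala
import java.util.Comparator

// Orders (key, combiner) pairs by key alone, ascending or descending.
class KeyOrderingComparator[K, C](ascending: Boolean, ord: Ordering[K])
  extends Comparator[(K, C)] {
  override def compare(kc1: (K, C), kc2: (K, C)): Int =
    if (ascending) ord.compare(kc1._1, kc2._1) else ord.compare(kc2._1, kc1._1)
}

object KeyComparatorSketch {
  def main(args: Array[String]): Unit = {
    val cmp = new KeyOrderingComparator[Int, String](ascending = true, Ordering.Int)
    val pairs = Array((3, "c"), (1, "a"), (2, "b"))
    // A key comparator yields a total order by key, so compare(...) == 0 means
    // the keys themselves are equal -- the property getMorePairs and the
    // StreamBuffer merge now rely on; equal hash codes would not guarantee this.
    java.util.Arrays.sort(pairs, cmp)
    println(pairs.mkString(", "))  // (1,a), (2,b), (3,c)
  }
}
```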