65 changes: 56 additions & 9 deletions core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -17,9 +17,13 @@

 package org.apache.spark.rdd
 
-import scala.reflect.ClassTag
+import java.util.Comparator
 
-import org.apache.spark.{Logging, RangePartitioner}
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect._
 
+import org.apache.spark.{Logging, RangePartitioner, SparkEnv}
+import org.apache.spark.util.collection.ExternalAppendOnlyMap
+
 /**
  * Extra functions available on RDDs of (key, value) pairs where the key is sortable through
@@ -41,30 +45,73 @@ import org.apache.spark.{Logging, RangePartitioner}
  * rdd.sortByKey()
  * }}}
  */
-class OrderedRDDFunctions[K : Ordering : ClassTag,
+
+class OrderedRDDFunctions[K: Ordering: ClassTag,
                           V: ClassTag,
-                          P <: Product2[K, V] : ClassTag](
+                          P <: Product2[K, V]: ClassTag](
     self: RDD[P])
   extends Logging with Serializable {
 
   private val ordering = implicitly[Ordering[K]]
 
+  private type SortCombiner = ArrayBuffer[P]
   /**
    * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
    * `collect` or `save` on the resulting RDD will return or output an ordered list of records
    * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
    * order of the keys).
    */
   def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = {
+    val externalSorting = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true)
     val part = new RangePartitioner(numPartitions, self, ascending)
     val shuffled = new ShuffledRDD[K, V, V, P](self, part).setKeyOrdering(ordering)
-    shuffled.mapPartitions(iter => {
-      val buf = iter.toArray
+    if (!externalSorting) {
+      shuffled.mapPartitions(iter => {
+        val buf = iter.toArray
+        if (ascending) {
+          buf.sortWith((x, y) => ordering.lt(x._1, y._1)).iterator
+        } else {
+          buf.sortWith((x, y) => ordering.gt(x._1, y._1)).iterator
+        }
+      }, preservesPartitioning = true)
+    } else {
+      shuffled.mapPartitions(iter => {
+        val map = createExternalMap(ascending)
+        while (iter.hasNext) {
+          val kv = iter.next()
+          map.insert(kv._1, kv.asInstanceOf[P])
+        }
+        map.sortIterator
+      }).flatMap(elem => elem._2)
+    }
Review comment (Contributor):

nit: I would rewrite this in the following style

shuffled.mapPartitions { iter =>
  ...
}.flatMap { case (k, c) =>
  // flatten combiner to return values
  c.iterator.map { x => (k, x).asInstanceOf[P] }
}

Review comment (Contributor):

(and follow this style guide in other places of your code)

Review comment (Contributor):

This could be a performance hot spot, since we need to reconstruct each tuple. With the current hash map implementation there seems to be no better way to avoid it, so we should measure the cost of this step.

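For reference, a hedged sketch of how the spill branch of sortByKey above might read in the suggested curly-brace style. It reuses names from this diff (shuffled, createExternalMap, SortCombiner) and is a fragment of the method above, not a drop-in replacement:

shuffled.mapPartitions { iter =>
  val map = createExternalMap(ascending)
  iter.foreach { kv => map.insert(kv._1, kv.asInstanceOf[P]) }
  map.sortIterator
}.flatMap { case (_, combiner) =>
  // each combiner is a SortCombiner (ArrayBuffer[P]) holding the original pairs for one key
  combiner.iterator
}
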
+  }
+
+  private def createExternalMap(ascending: Boolean): ExternalAppendOnlyMap[K, P, SortCombiner] = {
Review comment (Contributor):

It seems unnecessary to have a combiner here: if there are multiple key-value pairs with the same key, this requires them to all fit in memory. Instead we should have an option for the ExternalAppendOnlyMap to not attempt to combine them. I'll work on this in my PR.

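To make the memory concern above concrete, a small, purely illustrative sketch (plain Scala, hypothetical numbers, not part of the PR): with the ArrayBuffer-based SortCombiner defined just below, every duplicate value for a single key is appended to one in-memory buffer before anything can spill, so one heavily skewed key must still fit entirely in memory.

import scala.collection.mutable.ArrayBuffer

object SkewSketch {
  def main(args: Array[String]): Unit = {
    val numDuplicates = 1000000                         // imagine this arbitrarily large for one hot key
    val hotKeyCombiner = new ArrayBuffer[(String, Int)] // plays the role of SortCombiner for key "hot"
    (1 to numDuplicates).foreach(i => hotKeyCombiner += (("hot", i)))
    println(s"values buffered in memory for key 'hot': ${hotKeyCombiner.size}")
  }
}
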
+    val createCombiner: (P => SortCombiner) = value => {
+      val newCombiner = new SortCombiner
+      newCombiner += value
+      newCombiner
+    }
+    val mergeValue: (SortCombiner, P) => SortCombiner = (combiner, value) => {
+      combiner += value
+      combiner
+    }
+    val mergeCombiners: (SortCombiner, SortCombiner) => SortCombiner =
+      (combiner1, combiner2) => {combiner1 ++= combiner2}
+
+    new ExternalAppendOnlyMap[K, P, SortCombiner](
+      createCombiner, mergeValue, mergeCombiners,
+      customizedComparator = new KeyComparator[K, SortCombiner](ascending, ordering))
+  }
+
+  private class KeyComparator[K, SortCombiner](ascending: Boolean, ord: Ordering[K])
+  extends Comparator[(K, SortCombiner)] {
Review comment (Contributor):

nit: indent by 2 spaces

+    def compare (kc1: (K, SortCombiner), kc2: (K, SortCombiner)): Int = {
       if (ascending) {
-        buf.sortWith((x, y) => ordering.lt(x._1, y._1)).iterator
+        ord.compare(kc1._1, kc2._1)
       } else {
-        buf.sortWith((x, y) => ordering.gt(x._1, y._1)).iterator
+        ord.compare(kc2._1, kc1._1)
       }
-    }, preservesPartitioning = true)
+    }
   }
 }
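For orientation, a minimal usage sketch of the method this diff modifies (hypothetical driver code, not part of the PR; it assumes a local build of this branch and spark.shuffle.spill left at its default of true, so the external-sorting path above is taken):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._  // brings sortByKey into scope via the implicit conversion

object SortByKeySketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sortByKey-sketch"))
    val pairs = sc.parallelize(Seq((3, "c"), (1, "a"), (2, "b")))
    // Each output partition holds a sorted, contiguous key range, so collect() is fully ordered.
    println(pairs.sortByKey(ascending = true, numPartitions = 2).collect().mkString(", "))
    sc.stop()
  }
}
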
core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -63,7 +63,9 @@ class ExternalAppendOnlyMap[K, V, C](
     mergeValue: (C, V) => C,
     mergeCombiners: (C, C) => C,
     serializer: Serializer = SparkEnv.get.serializer,
-    blockManager: BlockManager = SparkEnv.get.blockManager)
+    blockManager: BlockManager = SparkEnv.get.blockManager,
+    customizedComparator: Comparator[(K, C)] = null
+    )
   extends Iterable[(K, C)] with Serializable with Logging {
 
   import ExternalAppendOnlyMap._
@@ -105,7 +107,8 @@ class ExternalAppendOnlyMap[K, V, C](
   private var _diskBytesSpilled = 0L
 
   private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
-  private val comparator = new KCComparator[K, C]
+  private val comparator =
+    if (customizedComparator == null) new KCComparator[K, C] else customizedComparator
   private val ser = serializer.newInstance()
 
   /**
@@ -220,6 +223,14 @@ class ExternalAppendOnlyMap[K, V, C](
     }
   }
 
+  def sortIterator: Iterator[(K, C)] = {
+    if (spilledMaps.isEmpty) {
+      currentMap.destructiveSortedIterator(comparator)
+    } else {
+      new ExternalIterator()
+    }
+  }
+
   /**
    * An iterator that sort-merges (K, C) pairs from the in-memory map and the spilled maps
    */
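For illustration, a self-contained sketch (plain Scala, not Spark code) of the comparator-driven k-way merge that ExternalIterator performs over the in-memory map and the spilled maps. Spilling, StreamBuffer, and combining of equal keys are omitted, and the names here are invented for the example:

import scala.collection.mutable

object MergeSketch {
  // Each Buf stands in for one sorted stream (the in-memory map or a spilled map).
  case class Buf(pairs: List[(Int, String)])

  def kWayMerge(bufs: Seq[Buf]): List[(Int, String)] = {
    // Order by descending head key because mutable.PriorityQueue dequeues the max,
    // mirroring the StreamBuffer.compareTo convention in the diff below.
    val heap = mutable.PriorityQueue(bufs.filter(_.pairs.nonEmpty): _*)(
      Ordering.by((b: Buf) => -b.pairs.head._1))
    val out = mutable.ListBuffer[(Int, String)]()
    while (heap.nonEmpty) {
      val b = heap.dequeue()                     // stream with the smallest current key
      out += b.pairs.head
      if (b.pairs.tail.nonEmpty) heap.enqueue(Buf(b.pairs.tail))
    }
    out.toList
  }

  def main(args: Array[String]): Unit = {
    val merged = kWayMerge(Seq(Buf(List(1 -> "a", 3 -> "c")), Buf(List(2 -> "b", 4 -> "d"))))
    println(merged)  // List((1,a), (2,b), (3,c), (4,d))
  }
}
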
@@ -250,12 +261,14 @@ class ExternalAppendOnlyMap[K, V, C](
     private def getMorePairs(it: BufferedIterator[(K, C)]): ArrayBuffer[(K, C)] = {
       val kcPairs = new ArrayBuffer[(K, C)]
       if (it.hasNext) {
-        var kc = it.next()
+        val kc = it.next()
         kcPairs += kc
-        val minHash = getKeyHashCode(kc)
-        while (it.hasNext && it.head._1.hashCode() == minHash) {
-          kc = it.next()
-          kcPairs += kc
+        while (it.hasNext && comparator.compare(kc, it.head) == 0) {
+          var kc1 = it.next()
Review comment (Contributor):

This can be a val (same with var kc above)

+          kcPairs += kc1
+          //if (comparator.compare(kc, it.head) != 0) {
+          //  return kcPairs
+          //}
         }
       }
       kcPairs
@@ -293,15 +306,14 @@ class ExternalAppendOnlyMap[K, V, C](
       }
       // Select a key from the StreamBuffer that holds the lowest key hash
       val minBuffer = mergeHeap.dequeue()
-      val (minPairs, minHash) = (minBuffer.pairs, minBuffer.minKeyHash)
+      val minPairs = minBuffer.pairs
       val minPair = minPairs.remove(0)
       var (minKey, minCombiner) = minPair
-      assert(getKeyHashCode(minPair) == minHash)
 
       // For all other streams that may have this key (i.e. have the same minimum key hash),
       // merge in the corresponding value (if any) from that stream
       val mergedBuffers = ArrayBuffer[StreamBuffer](minBuffer)
-      while (mergeHeap.length > 0 && mergeHeap.head.minKeyHash == minHash) {
+      while (mergeHeap.length > 0 && comparator.compare(mergeHeap.head.pairs.head, minPair) == 0) {
         val newBuffer = mergeHeap.dequeue()
         minCombiner = mergeIfKeyExists(minKey, minCombiner, newBuffer)
         mergedBuffers += newBuffer
@@ -334,15 +346,9 @@ class ExternalAppendOnlyMap[K, V, C](

       def isEmpty = pairs.length == 0
 
-      // Invalid if there are no more pairs in this stream
-      def minKeyHash: Int = {
-        assert(pairs.length > 0)
-        getKeyHashCode(pairs.head)
-      }
-
       override def compareTo(other: StreamBuffer): Int = {
         // descending order because mutable.PriorityQueue dequeues the max, not the min
-        if (other.minKeyHash < minKeyHash) -1 else if (other.minKeyHash == minKeyHash) 0 else 1
+        comparator.compare(other.pairs.head, pairs.head)
       }
     }
   }