@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.hash
import org.apache.spark.{InterruptibleIterator, TaskContext}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
import org.apache.spark.util.collection.ExternalSorter

private[spark] class HashShuffleReader[K, C](
handle: BaseShuffleHandle[K, _, C],
@@ -35,8 +36,8 @@ private[spark] class HashShuffleReader[K, C](

/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context,
Serializer.getSerializer(dep.serializer))
val ser = Serializer.getSerializer(dep.serializer)
val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)

val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
@@ -54,16 +55,13 @@ private[spark] class HashShuffleReader[K, C](
// Sort the output if there is a sort ordering defined.
dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) =>
// Define a Comparator for the whole record based on the key Ordering.
val cmp = new Ordering[Product2[K, C]] {
override def compare(o1: Product2[K, C], o2: Product2[K, C]): Int = {
keyOrd.compare(o1._1, o2._1)
}
}
val sortBuffer: Array[Product2[K, C]] = aggregatedIter.toArray
// TODO: do external sort.
scala.util.Sorting.quickSort(sortBuffer)(cmp)
sortBuffer.iterator
// Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
// the ExternalSorter won't spill to disk.
val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
sorter.write(aggregatedIter)
context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled
context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled
sorter.iterator
case None =>
aggregatedIter
}
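For reference, the sort-only pattern the new read path relies on, pulled out in isolation. This is a sketch, not part of the diff: `sortByKeyExternally` is a hypothetical helper name, `ExternalSorter` is private[spark] so code like this only compiles inside Spark's own source tree, and spilling additionally assumes a live TaskContext/SparkEnv set up by the scheduler. Only calls visible in the diff above (the ordering/serializer constructor, write, and iterator) are used.

```scala
import org.apache.spark.serializer.Serializer
import org.apache.spark.util.collection.ExternalSorter

// Sort (key, value) records by key, spilling to disk when the in-memory buffer
// grows too large. No aggregator or partitioner is passed, so the sorter only
// orders the records; it does not combine them.
def sortByKeyExternally[K, C](
    records: Iterator[Product2[K, C]],
    keyOrd: Ordering[K],
    ser: Serializer): Iterator[Product2[K, C]] = {
  val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
  sorter.write(records)
  sorter.iterator
}
```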
@@ -190,6 +190,11 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}")
}
}

// sortByKey - should spill ~17 times
val rddE = sc.parallelize(0 until 100000).map(i => (i/4, i))
val resultE = rddE.sortByKey().collect().toSeq
assert(resultE === (0 until 100000).map(i => (i/4, i)).toSeq)
}

test("spilling in local cluster with many reduce tasks") {
@@ -256,6 +261,11 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}")
}
}

// sortByKey - should spill ~8 times per executor
val rddE = sc.parallelize(0 until 100000).map(i => (i/4, i))
val resultE = rddE.sortByKey().collect().toSeq
assert(resultE === (0 until 100000).map(i => (i/4, i)).toSeq)
}

test("cleanup of intermediate files in sorter") {
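The new suite cases above exercise this path through sortByKey(). A minimal, hypothetical driver (not part of this PR) doing the same outside the test harness might look like the sketch below; `SortByKeySpillExample` is an invented name, and the tiny `spark.shuffle.memoryFraction` value is an assumption chosen purely to force the reduce-side sorter to spill, mirroring the suite's own setup.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._

object SortByKeySpillExample {
  def main(args: Array[String]): Unit = {
    // A tiny shuffle memory fraction forces the reduce-side ExternalSorter to spill.
    val conf = new SparkConf()
      .setMaster("local")
      .setAppName("sortByKey-external-sort")
      .set("spark.shuffle.spill", "true")
      .set("spark.shuffle.memoryFraction", "0.001")
    val sc = new SparkContext(conf)

    // sortByKey() defines a key ordering but no aggregator, so the shuffle read
    // path now goes through ExternalSorter instead of an in-memory quickSort.
    val result = sc.parallelize(0 until 100000).map(i => (i / 4, i)).sortByKey().collect().toSeq
    assert(result == (0 until 100000).map(i => (i / 4, i)))

    sc.stop()
  }
}
```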