4 changes: 3 additions & 1 deletion core/src/main/scala/org/apache/spark/Dependency.scala
@@ -19,6 +19,7 @@ package org.apache.spark

 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.SortOrder.SortOrder
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.ShuffleHandle

@@ -62,7 +63,8 @@ class ShuffleDependency[K, V, C](
     val serializer: Option[Serializer] = None,
     val keyOrdering: Option[Ordering[K]] = None,
     val aggregator: Option[Aggregator[K, V, C]] = None,
-    val mapSideCombine: Boolean = false)
+    val mapSideCombine: Boolean = false,
+    val sortOrder: Option[SortOrder] = None)
   extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {

   val shuffleId: Int = rdd.context.newShuffleId()
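The net effect is that a ShuffleDependency can now carry an optional sort order alongside the optional key ordering. A minimal sketch of constructing one directly, as Spark-internal code would (SortOrder is private[spark]); the pairs RDD and the hash partitioner here are placeholder assumptions, not part of this diff:

    import org.apache.spark.{HashPartitioner, ShuffleDependency}
    import org.apache.spark.rdd.{RDD, SortOrder}

    // pairs is a hypothetical RDD[(Int, String)], e.g. sc.parallelize(Seq(2 -> "b", 1 -> "a"))
    def sortedDep(pairs: RDD[(Int, String)]): ShuffleDependency[Int, String, String] =
      new ShuffleDependency[Int, String, String](
        pairs,
        new HashPartitioner(4),
        serializer = None,
        keyOrdering = Some(Ordering.Int),   // needed for the reader-side sort
        aggregator = None,
        mapSideCombine = false,
        sortOrder = Some(SortOrder.ASCENDING))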
core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -57,14 +57,13 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
    */
   def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = {
     val part = new RangePartitioner(numPartitions, self, ascending)
-    val shuffled = new ShuffledRDD[K, V, V, P](self, part).setKeyOrdering(ordering)
-    shuffled.mapPartitions(iter => {
-      val buf = iter.toArray
-      if (ascending) {
-        buf.sortWith((x, y) => ordering.lt(x._1, y._1)).iterator
-      } else {
-        buf.sortWith((x, y) => ordering.gt(x._1, y._1)).iterator
-      }
-    }, preservesPartitioning = true)
+    new ShuffledRDD[K, V, V, P](self, part)
+      .setKeyOrdering(ordering)
+      .setSortOrder(if (ascending) SortOrder.ASCENDING else SortOrder.DESCENDING)
   }
 }
+
+private[spark] object SortOrder extends Enumeration {
+  type SortOrder = Value
+  val ASCENDING, DESCENDING = Value
+}
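With the sort moved behind the shuffle, sortByKey itself becomes a pure wiring step: it builds the range-partitioned ShuffledRDD and records the requested order, and the shuffle reader performs the actual sort. User-visible behavior should be unchanged; a quick check, assuming a live SparkContext named sc:

    val pairs = sc.parallelize(Seq((3, "c"), (1, "a"), (2, "b")))
    pairs.sortByKey().collect()                    // Array((1,a), (2,b), (3,c))
    pairs.sortByKey(ascending = false).collect()   // Array((3,c), (2,b), (1,a))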
12 changes: 11 additions & 1 deletion core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -21,6 +21,7 @@ import scala.reflect.ClassTag

 import org.apache.spark._
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.SortOrder.SortOrder
 import org.apache.spark.serializer.Serializer

 private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
@@ -51,6 +52,8 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](

   private var mapSideCombine: Boolean = false

+  private var sortOrder: Option[SortOrder] = None
+
   /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
   def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C, P] = {
     this.serializer = Option(serializer)
@@ -75,8 +78,15 @@
     this
   }

+  /** Set sort order for RDD's sorting. */
+  def setSortOrder(sortOrder: SortOrder): ShuffledRDD[K, V, C, P] = {
+    this.sortOrder = Option(sortOrder)
+    this
+  }
+
   override def getDependencies: Seq[Dependency[_]] = {
-    List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine))
+    List(new ShuffleDependency(prev, part, serializer,
+      keyOrdering, aggregator, mapSideCombine, sortOrder))
   }

   override val partitioner = Some(part)
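setSortOrder follows the same fluent pattern as the other setters: it wraps the argument in Option(...) (so a null maps to None, matching setSerializer) and returns this for chaining. A sketch of the chained style, where prev and the partitioner are hypothetical:

    // prev is a hypothetical RDD[(Int, String)] to be shuffled
    val shuffled = new ShuffledRDD[Int, String, String, (Int, String)](prev, new HashPartitioner(4))
      .setKeyOrdering(Ordering.Int)
      .setSortOrder(SortOrder.ASCENDING)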
core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.shuffle.hash

 import org.apache.spark.{InterruptibleIterator, TaskContext}
+import org.apache.spark.rdd.SortOrder
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}

@@ -38,7 +39,7 @@ class HashShuffleReader[K, C](
     val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context,
       Serializer.getSerializer(dep.serializer))

-    if (dep.aggregator.isDefined) {
+    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
       if (dep.mapSideCombine) {
         new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
       } else {
@@ -49,6 +50,17 @@
     } else {
       iter
     }
+
+    val sortedIter = for (sortOrder <- dep.sortOrder; ordering <- dep.keyOrdering) yield {
+      val buf = aggregatedIter.toArray
+      if (sortOrder == SortOrder.ASCENDING) {
+        buf.sortWith((x, y) => ordering.lt(x._1, y._1)).iterator
+      } else {
+        buf.sortWith((x, y) => ordering.gt(x._1, y._1)).iterator
+      }
+    }
+
+    sortedIter.getOrElse(aggregatedIter)
   }

   /** Close this reader */
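The for-comprehension at the end is the interesting bit: over Options it desugars to dep.sortOrder.flatMap(... dep.keyOrdering.map(...)), so sortedIter is defined only when both a sort order and a key ordering were set on the dependency; otherwise it is None and getOrElse passes the aggregated iterator through unsorted. The same guard pattern in isolation, with throwaway values:

    val order: Option[String] = Some("ASCENDING")
    val keyOrd: Option[Ordering[Int]] = None
    val sorted = for (o <- order; k <- keyOrd) yield s"sort $o using $k"
    sorted.getOrElse("unsorted pass-through")   // keyOrd is None, so the fallback wins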