diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala index d0d25b43d047..798e7b54aa8c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala @@ -30,7 +30,20 @@ private[spark] class PartitionedAppendOnlyMap[K, V] def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]]) : Iterator[((Int, K), V)] = { - val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator) + val comparator : Comparator[(Int, K)] = + if (keyComparator.isEmpty) { + partitionComparator + } else + new Comparator[(Int, K)] { + override def compare(a: (Int, K), b: (Int, K)): Int = { + val partitionDiff = a._1 - b._1 + if (partitionDiff != 0) { + partitionDiff + } else { + keyComparator.get.compare(a._2, b._2) + } + } + } destructiveSortedIterator(comparator) }