Commit bbf359d

More work
1 parent 3a56341 commit bbf359d

File tree

7 files changed: +258 -23 lines changed

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala

Lines changed: 20 additions & 0 deletions

@@ -17,9 +17,12 @@
 
 package org.apache.spark.shuffle.sort
 
+import java.io.{DataInputStream, FileInputStream}
+
 import org.apache.spark.shuffle._
 import org.apache.spark.{TaskContext, ShuffleDependency}
 import org.apache.spark.shuffle.hash.HashShuffleReader
+import org.apache.spark.storage.{DiskBlockManager, FileSegment, ShuffleBlockId}
 
 private[spark] class SortShuffleManager extends ShuffleManager {
   /**
@@ -57,4 +60,21 @@ private[spark] class SortShuffleManager extends ShuffleManager {
 
   /** Shut down this ShuffleManager. */
   override def stop(): Unit = {}
+
+  /** Get the location of a block in a map output file. Uses the index file we create for it. */
+  def getBlockLocation(blockId: ShuffleBlockId, diskManager: DiskBlockManager): FileSegment = {
+    // The block is actually going to be a range of a single map output file for this map,
+    // so
+    val realId = ShuffleBlockId(blockId.shuffleId, blockId.mapId, 0)
+    val indexFile = diskManager.getFile(realId.name + ".index")
+    val in = new DataInputStream(new FileInputStream(indexFile))
+    try {
+      in.skip(blockId.reduceId * 8)
+      val offset = in.readLong()
+      val nextOffset = in.readLong()
+      new FileSegment(diskManager.getFile(realId), offset, nextOffset - offset)
+    } finally {
+      in.close()
+    }
+  }
 }
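Note: getBlockLocation above depends on the index-file layout that SortShuffleWriter produces: (numPartitions + 1) longs, so the bytes of reduce partition r live in [offset(r), offset(r + 1)). The following is a minimal standalone sketch of that layout, not code from this commit; IndexFileSketch and its helper names are illustrative only.

import java.io.{DataInputStream, DataOutputStream, File, FileInputStream, FileOutputStream}

object IndexFileSketch {
  // Hypothetical helper: write the cumulative offsets (one long per reduce partition, plus a
  // final end-of-file offset) to an index file.
  def writeIndex(file: File, offsets: Array[Long]): Unit = {
    val out = new DataOutputStream(new FileOutputStream(file))
    try offsets.foreach(o => out.writeLong(o)) finally out.close()
  }

  // Hypothetical helper: recover (offset, length) for one reduce partition by reading the two
  // adjacent longs at positions reduceId and reduceId + 1, as getBlockLocation does.
  def readSegment(file: File, reduceId: Int): (Long, Long) = {
    val in = new DataInputStream(new FileInputStream(file))
    try {
      in.skip(reduceId * 8L)  // each offset is an 8-byte long
      val start = in.readLong()
      val end = in.readLong()
      (start, end - start)
    } finally {
      in.close()
    }
  }

  def main(args: Array[String]): Unit = {
    val f = File.createTempFile("shuffle_0_0_0", ".index")
    writeIndex(f, Array(0L, 100L, 250L, 250L, 400L))  // 4 partitions; partition 2 wrote nothing
    assert(readSegment(f, 1) == (100L, 150L))         // second partition: 150 bytes at offset 100
    assert(readSegment(f, 2) == (250L, 0L))           // empty partition: zero-length segment
    f.delete()
  }
}

The design choice this illustrates: each map task produces one data file plus one small index file, instead of one file per (map, reduce) pair.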

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala

Lines changed: 85 additions & 6 deletions

@@ -18,10 +18,14 @@
 package org.apache.spark.shuffle.sort
 
 import org.apache.spark.shuffle.{ShuffleWriter, BaseShuffleHandle}
-import org.apache.spark.{SparkEnv, Logging, TaskContext}
+import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext}
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.util.collection.ExternalSorter
+import org.apache.spark.storage.ShuffleBlockId
+import java.util.concurrent.atomic.AtomicInteger
+import org.apache.spark.executor.ShuffleWriteMetrics
+import java.io.{BufferedOutputStream, FileOutputStream, DataOutputStream}
 
 private[spark] class SortShuffleWriter[K, V, C](
     handle: BaseShuffleHandle[K, V, C],
@@ -30,17 +34,24 @@ private[spark] class SortShuffleWriter[K, V, C](
   extends ShuffleWriter[K, V] with Logging {
 
   private val dep = handle.dependency
-  private val numOutputPartitions = dep.partitioner.numPartitions
+  private val numPartitions = dep.partitioner.numPartitions
   private val metrics = context.taskMetrics
 
   private val blockManager = SparkEnv.get.blockManager
+  private val shuffleBlockManager = blockManager.shuffleBlockManager
   private val diskBlockManager = blockManager.diskBlockManager
   private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))
 
+  private val conf = SparkEnv.get.conf
+  private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
+
+  private var sorter: ExternalSorter[K, V, _] = null
+
+  private var stopping = false
+  private var mapStatus: MapStatus = null
+
   /** Write a bunch of records to this task's output */
   override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
-    var sorter: ExternalSorter[K, V, _] = null
-
     val partitions: Iterator[(Int, Iterator[Product2[K, _]])] = {
       if (dep.mapSideCombine) {
         if (!dep.aggregator.isDefined) {
@@ -58,13 +69,81 @@ private[spark] class SortShuffleWriter[K, V, C](
       }
     }
 
+    // Create a single shuffle file with reduce ID 0 that we'll write all results to. We'll later
+    // serve different ranges of this file using an index file that we create at the end.
+    val blockId = ShuffleBlockId(dep.shuffleId, mapId, 0)
+    val shuffleFile = blockManager.diskBlockManager.getFile(blockId)
+
+    // Track location of each range in the output file
+    val offsets = new Array[Long](numPartitions + 1)
+    val lengths = new Array[Long](numPartitions)
+
+    // Statistics
+    var totalBytes = 0L
+    var totalTime = 0L
+
     for ((id, elements) <- partitions) {
+      if (elements.hasNext) {
+        val writer = blockManager.getDiskWriter(blockId, shuffleFile, ser, fileBufferSize)
+        for (elem <- elements) {
+          writer.write(elem)
+        }
+        writer.commit()
+        writer.close()
+        val segment = writer.fileSegment()
+        offsets(id + 1) = segment.offset + segment.length
+        lengths(id) = segment.length
+        totalTime += writer.timeWriting()
+        totalBytes += segment.length
+      } else {
+        // Don't create a new writer to avoid writing any headers and things like that
+        offsets(id + 1) = offsets(id)
+      }
+    }
+
+    val shuffleMetrics = new ShuffleWriteMetrics
+    shuffleMetrics.shuffleBytesWritten = totalBytes
+    shuffleMetrics.shuffleWriteTime = totalTime
+    context.taskMetrics.shuffleWriteMetrics = Some(shuffleMetrics)
+
+    // Write an index file with the offsets of each block, plus a final offset at the end for the
+    // end of the output file. This will be used by SortShuffleManager.getBlockLocation to figure
+    // out where each block begins and ends.
 
+    val diskBlockManager = blockManager.diskBlockManager
+    val indexFile = diskBlockManager.getFile(blockId.name + ".index")
+    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
+    try {
+      var i = 0
+      while (i < numPartitions + 1) {
+        out.writeLong(offsets(i))
+        i += 1
+      }
+    } finally {
+      out.close()
     }
 
-    ???
+    mapStatus = new MapStatus(blockManager.blockManagerId,
+      lengths.map(MapOutputTracker.compressSize))
+
+    // TODO: keep track of our file in a way that can be cleaned up later
   }
 
   /** Close this writer, passing along whether the map completed */
-  override def stop(success: Boolean): Option[MapStatus] = ???
+  override def stop(success: Boolean): Option[MapStatus] = {
+    try {
+      if (stopping) {
+        return None
+      }
+      stopping = true
+      if (success) {
+        return Option(mapStatus)
+      } else {
+        // TODO: clean up our file
+        return None
+      }
    } finally {
      // TODO: sorter.stop()
    }
+  }
 }
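Note: the write path above fills offsets and lengths so that offsets(0) == 0 and offsets(id + 1) - offsets(id) == lengths(id), carrying the previous offset forward for empty partitions; that invariant is what getBlockLocation relies on when it subtracts adjacent offsets. A small standalone sketch of the invariant (OffsetContractSketch and cumulativeOffsets are illustrative names, not part of this commit):

object OffsetContractSketch {
  // Hypothetical helper: build the (numPartitions + 1)-entry cumulative offsets array from
  // per-partition segment lengths, the same shape SortShuffleWriter writes to the index file.
  def cumulativeOffsets(lengths: Array[Long]): Array[Long] = {
    val offsets = new Array[Long](lengths.length + 1)
    for (id <- lengths.indices) {
      offsets(id + 1) = offsets(id) + lengths(id)  // empty partitions just repeat the offset
    }
    offsets
  }

  def main(args: Array[String]): Unit = {
    val lengths = Array(120L, 0L, 75L, 310L)  // partition 1 received no records
    val offsets = cumulativeOffsets(lengths)  // Array(0, 120, 120, 195, 505)
    assert(offsets.length == lengths.length + 1)
    for (id <- lengths.indices) {
      assert(offsets(id + 1) - offsets(id) == lengths(id))
    }
  }
}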

core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala

Lines changed: 9 additions & 5 deletions

@@ -21,10 +21,11 @@ import java.io.File
 import java.text.SimpleDateFormat
 import java.util.{Date, Random, UUID}
 
-import org.apache.spark.Logging
+import org.apache.spark.{SparkEnv, Logging}
 import org.apache.spark.executor.ExecutorExitCode
 import org.apache.spark.network.netty.{PathResolver, ShuffleSender}
 import org.apache.spark.util.Utils
+import org.apache.spark.shuffle.sort.SortShuffleManager
 
 /**
  * Creates and maintains the logical mapping between logical blocks and physical on-disk
@@ -54,12 +55,15 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
   addShutdownHook()
 
   /**
-   * Returns the physical file segment in which the given BlockId is located.
-   * If the BlockId has been mapped to a specific FileSegment, that will be returned.
-   * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
+   * Returns the physical file segment in which the given BlockId is located. If the BlockId has
+   * been mapped to a specific FileSegment by the shuffle layer, that will be returned.
+   * Otherwise, we assume the Block is mapped to the whole file identified by the BlockId.
    */
   def getBlockLocation(blockId: BlockId): FileSegment = {
-    if (blockId.isShuffle && shuffleManager.consolidateShuffleFiles) {
+    if (blockId.isShuffle && SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]) {
+      val sortShuffleManager = SparkEnv.get.shuffleManager.asInstanceOf[SortShuffleManager]
+      sortShuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId], this)
+    } else if (blockId.isShuffle && shuffleManager.consolidateShuffleFiles) {
       shuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId])
     } else {
       val file = getFile(blockId.name)

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala

Lines changed: 39 additions & 5 deletions

@@ -18,6 +18,7 @@
 package org.apache.spark.util.collection
 
 import java.io._
+import java.util.Comparator
 
 import scala.collection.mutable.ArrayBuffer
 
@@ -88,6 +89,13 @@ private[spark] class ExternalSorter[K, V, C](
     (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
   }
 
+  // For now, just compare them by partition; later we can compare by key as well
+  private val comparator = new Comparator[((Int, K), C)] {
+    override def compare(a: ((Int, K), C), b: ((Int, K), C)): Int = {
+      a._1._1 - b._1._1
+    }
+  }
+
   // Information about a spilled file. Includes sizes in bytes of "batches" written by the
   // serializer as we periodically reset its stream, as well as number of elements in each
   // partition, used to efficiently keep track of partitions when merging.
@@ -192,7 +200,7 @@
     }
 
     try {
-      val it = collection.iterator // TODO: destructiveSortedIterator(comparator)
+      val it = collection.destructiveSortedIterator(comparator)
      while (it.hasNext) {
        val elem = it.next()
        val partitionId = elem._1._1
@@ -232,11 +240,22 @@
    * inside each partition. This can be used to either write out a new file or return data to
    * the user.
    */
-  def merge(spills: Seq[SpilledFile]): Iterator[(Int, Iterator[Product2[K, C]])] = {
+  def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
+      : Iterator[(Int, Iterator[Product2[K, C]])] = {
     // TODO: merge intermediate results if they are sorted by the comparator
     val readers = spills.map(new SpillReader(_))
+    val inMemBuffered = inMemory.buffered
     (0 until numPartitions).iterator.map { p =>
-      (p, readers.iterator.flatMap(_.readNextPartition()))
+      val inMemIterator = new Iterator[(K, C)] {
+        override def hasNext: Boolean = {
+          inMemBuffered.hasNext && inMemBuffered.head._1._1 == p
+        }
+        override def next(): (K, C) = {
+          val elem = inMemBuffered.next()
+          (elem._1._2, elem._2)
+        }
+      }
+      (p, readers.iterator.flatMap(_.readNextPartition()) ++ inMemIterator)
     }
   }
 
@@ -301,6 +320,11 @@
         }
         val k = deserStream.readObject().asInstanceOf[K]
         val c = deserStream.readObject().asInstanceOf[C]
+        if (partitionId == numPartitions - 1 &&
+            indexInPartition == spill.elementsPerPartition(partitionId) - 1) {
+          finished = true
+          deserStream.close()
+        }
         (k, c)
       } catch {
         case e: EOFException =>
@@ -319,6 +343,9 @@
      override def hasNext: Boolean = {
        if (nextItem == null) {
          nextItem = readNextItem()
+          if (nextItem == null) {
+            return false
+          }
        }
        // Check that we're still in the right partition; will be numPartitions at EOF
        partitionId == myPartition
@@ -328,7 +355,9 @@
        if (!hasNext) {
          throw new NoSuchElementException
        }
-        nextItem
+        val item = nextItem
+        nextItem = null
+        item
      }
    }
  }
@@ -337,11 +366,16 @@
    * Return an iterator over all the data written to this object, grouped by partition. For each
    * partition we then have an iterator over its contents, and these are expected to be accessed
    * in order (you can't "skip ahead" to one partition without reading the previous one).
+   * Guaranteed to return a key-value pair for each partition, in order of partition ID.
    *
    * For now, we just merge all the spilled files in once pass, but this can be modified to
    * support hierarchical merging.
    */
-  def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = merge(spills)
+  def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
+    val usingMap = aggregator.isDefined
+    val collection: SizeTrackingCollection[((Int, K), C)] = if (usingMap) map else buffer
+    merge(spills, collection.destructiveSortedIterator(comparator))
+  }
 
   /**
    * Return an iterator over all the data written to this object.
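Note: with the in-memory collection sorted by the partition comparator, merge() above walks partitions 0 until numPartitions and peels the matching elements off the front of a single BufferedIterator. A standalone sketch of that peeling pattern follows; byPartition and PartitionPeelSketch are illustrative names, not part of ExternalSorter, and, as with the real partitionedIterator, the inner iterators must be consumed in partition order.

object PartitionPeelSketch {
  // Split one partition-sorted iterator of ((partition, key), value) records into per-partition
  // iterators by peeling elements while the buffered head's partition ID matches.
  def byPartition[K, C](numPartitions: Int, sorted: Iterator[((Int, K), C)])
      : Iterator[(Int, Iterator[(K, C)])] = {
    val buffered = sorted.buffered
    (0 until numPartitions).iterator.map { p =>
      val perPartition = new Iterator[(K, C)] {
        override def hasNext: Boolean = buffered.hasNext && buffered.head._1._1 == p
        override def next(): (K, C) = {
          val ((_, k), c) = buffered.next()
          (k, c)
        }
      }
      (p, perPartition)
    }
  }

  def main(args: Array[String]): Unit = {
    val data = Iterator(((0, "a"), 1), ((0, "b"), 2), ((2, "c"), 3))
    // Consuming in order yields every partition, including empty partition 1.
    val grouped = byPartition(3, data).map { case (p, it) => (p, it.toList) }.toList
    assert(grouped == List((0, List(("a", 1), ("b", 2))), (1, Nil), (2, List(("c", 3)))))
  }
}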

core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala

Lines changed: 0 additions & 7 deletions

@@ -369,10 +369,3 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
   }
 
 }
-
-/**
- * A dummy class that always returns the same hash code, to easily test hash collisions
- */
-case class FixedHashObject(v: Int, h: Int) extends Serializable {
-  override def hashCode(): Int = h
-}
core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala

Lines changed: 80 additions & 0 deletions

@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.{SparkContext, SparkConf, LocalSparkContext}
+import org.apache.spark.SparkContext._
+import scala.collection.mutable.ArrayBuffer
+
+class ExternalSorterSuite extends FunSuite with LocalSparkContext {
+
+  test("spilling in local cluster") {
+    val conf = new SparkConf(true)  // Load defaults, otherwise SPARK_HOME is not found
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+    // reduceByKey - should spill ~8 times
+    val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i))
+    val resultA = rddA.reduceByKey(math.max).collect()
+    assert(resultA.length == 50000)
+    resultA.foreach { case(k, v) =>
+      k match {
+        case 0 => assert(v == 1)
+        case 25000 => assert(v == 50001)
+        case 49999 => assert(v == 99999)
+        case _ =>
+      }
+    }
+
+    // groupByKey - should spill ~17 times
+    val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i))
+    val resultB = rddB.groupByKey().collect()
+    assert(resultB.length == 25000)
+    resultB.foreach { case(i, seq) =>
+      i match {
+        case 0 => assert(seq.toSet == Set[Int](0, 1, 2, 3))
+        case 12500 => assert(seq.toSet == Set[Int](50000, 50001, 50002, 50003))
+        case 24999 => assert(seq.toSet == Set[Int](99996, 99997, 99998, 99999))
+        case _ =>
+      }
+    }
+
+    // cogroup - should spill ~7 times
+    val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i))
+    val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i))
+    val resultC = rddC1.cogroup(rddC2).collect()
+    assert(resultC.length == 10000)
+    resultC.foreach { case(i, (seq1, seq2)) =>
+      i match {
+        case 0 =>
+          assert(seq1.toSet == Set[Int](0))
+          assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000))
+        case 5000 =>
+          assert(seq1.toSet == Set[Int](5000))
+          assert(seq2.toSet == Set[Int]())
+        case 9999 =>
+          assert(seq1.toSet == Set[Int](9999))
+          assert(seq2.toSet == Set[Int]())
+        case _ =>
+      }
+    }
+  }
+}
