1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -311,6 +311,7 @@ object SparkEnv extends Logging {
// Let the user specify short names for shuffle managers
val shortShuffleMgrNames = Map(
"hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager",
"memory" -> "org.apache.spark.shuffle.memory.MemoryShuffleManager",
"sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")
val shuffleMgrName = conf.get("spark.shuffle.manager", "sort")
val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
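With the new "memory" short name registered above, a job opts into the in-memory shuffle purely through configuration. A minimal usage sketch under that assumption (local mode, otherwise standard SparkConf/SparkContext API; the snippet is illustrative and not part of this PR):

import org.apache.spark.{SparkConf, SparkContext}

object MemoryShuffleExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("memory-shuffle-example")
      .setMaster("local[2]")
      // "memory" is resolved to MemoryShuffleManager through shortShuffleMgrNames above.
      .set("spark.shuffle.manager", "memory")
    val sc = new SparkContext(conf)
    try {
      // Any shuffle-producing operation now keeps its map output in the memory store.
      val counts = sc.parallelize(0 until 1000).map(i => (i % 10, 1)).reduceByKey(_ + _)
      println(counts.collect().mkString(", "))
    } finally {
      sc.stop()
    }
  }
}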
@@ -0,0 +1,104 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.memory

import java.nio.ByteBuffer
import java.util.concurrent.ConcurrentLinkedQueue

import scala.collection.JavaConversions._

import org.apache.spark.{Logging, SparkConf, SparkEnv}
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.shuffle.ShuffleBlockManager
import org.apache.spark.storage.{BlockNotFoundException, ShuffleBlockId}
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}

/**
* Tracks metadata about the shuffle blocks stored on a particular executor, and periodically
* deletes old shuffle files.
*
* TODO: There is significant code duplication between this class and FileShuffleBlockManager;
* consider making both classes inherit from a common parent class.
*/
private[spark] class MemoryShuffleBlockManager(conf: SparkConf)
extends ShuffleBlockManager with Logging {

private lazy val blockManager = SparkEnv.get.blockManager

private class ShuffleState(val numBuckets: Int) {
val completedMapTasks = new ConcurrentLinkedQueue[Int]()
}
private val shuffleIdToState = new TimeStampedHashMap[ShuffleId, ShuffleState]()

private val metadataCleaner =
new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup, conf)

override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = {
val segment = getBlockData(blockId)
Some(segment.nioByteBuffer())
}

override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
blockManager.memoryStore.getBytes(blockId) match {
case Some(bytes) => new NioManagedBuffer(bytes)
case None => throw new BlockNotFoundException(blockId.toString)
}
}

/**
* Registers shuffle output for a particular map task, so that it can be deleted later by the
* metadata cleaner.
*/
def addShuffleOutput(shuffleId: ShuffleId, mapId: Int, numBuckets: Int) {
shuffleIdToState.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
shuffleIdToState(shuffleId).completedMapTasks.add(mapId)
}

/** Remove all the blocks / files and metadata related to a particular shuffle. */
def removeShuffle(shuffleId: ShuffleId): Boolean = {
// Do not change the ordering of this: shuffleState should be removed only
// after the corresponding shuffle blocks have been removed.
val cleaned = removeShuffleBlocks(shuffleId)
shuffleIdToState.remove(shuffleId)
cleaned
}

private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = {
shuffleIdToState.get(shuffleId) match {
case Some(state) =>
for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) {
val blockId = ShuffleBlockId(shuffleId, mapId, reduceId)
blockManager.removeBlock(blockId, tellMaster = false)
}
logInfo(s"Deleted all files for shuffle $shuffleId")
true
case None =>
logInfo(s"Could not find files for shuffle $shuffleId for deleting")
false
}
}

private def cleanup(cleanupTime: Long) {
shuffleIdToState.clearOldValues(
cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId))
}

override def stop() {
metadataCleaner.cancel()
}
}
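
The cleanup path above enumerates one ShuffleBlockId per (completed map task, reduce partition) pair before dropping the blocks from the memory store. The following standalone sketch models just that enumeration with plain Scala collections; the ShuffleBlockId case class is a simplified stand-in for org.apache.spark.storage.ShuffleBlockId, and the sample values are illustrative rather than taken from this PR:

// Simplified stand-in for org.apache.spark.storage.ShuffleBlockId; no Spark classes involved.
case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) {
  override def toString: String = s"shuffle_${shuffleId}_${mapId}_$reduceId"
}

object BlockEnumerationSketch {
  def main(args: Array[String]): Unit = {
    val shuffleId = 0
    val completedMapTasks = Seq(0, 1, 2)  // map tasks registered via addShuffleOutput
    val numBuckets = 4                    // reduce partitions, from the dependency's partitioner

    // Same nested loop as removeShuffleBlocks: every (map, reduce) pair yields one block.
    val blocksToRemove =
      for (mapId <- completedMapTasks; reduceId <- 0 until numBuckets)
        yield ShuffleBlockId(shuffleId, mapId, reduceId)

    // 3 map tasks * 4 buckets = 12 in-memory blocks would be dropped for this shuffle.
    blocksToRemove.foreach(println)
  }
}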
@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.memory

import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleHandle, ShuffleManager, ShuffleReader,
ShuffleWriter}
import org.apache.spark.shuffle.hash.HashShuffleReader

/**
* A ShuffleManager that stores shuffle data in-memory.
*
* This shuffle manager uses hashing: it creates one in-memory block per reduce partition on each
* mapper.
*/
private[spark] class MemoryShuffleManager(conf: SparkConf) extends ShuffleManager {
private val memoryShuffleBlockManager = new MemoryShuffleBlockManager(conf)

override def registerShuffle[K, V, C](
shuffleId: Int,
numMaps: Int,
dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
new BaseShuffleHandle(shuffleId, numMaps, dependency)
}

override def getReader[K, C](
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext): ShuffleReader[K, C] = {
new HashShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}

override def getWriter[K, V](
handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V] = {
new MemoryShuffleWriter(
memoryShuffleBlockManager, handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context)
}

override def unregisterShuffle(shuffleId: Int): Boolean = {
memoryShuffleBlockManager.removeShuffle(shuffleId)
}

override def shuffleBlockManager: MemoryShuffleBlockManager = {
memoryShuffleBlockManager
}

override def stop(): Unit = {
shuffleBlockManager.stop()
}
}
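
As the class comment notes, this manager hashes records into one in-memory block per reduce partition on each mapper, so the block count scales with maps times reduces. A back-of-the-envelope sketch with made-up figures (not from this PR) illustrates why wide shuffles put pressure on executor memory:

object MemoryShuffleBlockCount {
  def main(args: Array[String]): Unit = {
    val numMaps = 200     // map tasks in the stage (example value)
    val numReduces = 100  // dep.partitioner.numPartitions (example value)
    val totalBlocks = numMaps.toLong * numReduces
    println(s"$numMaps maps x $numReduces reduce partitions = $totalBlocks in-memory shuffle blocks")
    // Unlike sort-based shuffle, which writes one consolidated data file per map task,
    // the per-bucket block count here grows multiplicatively with the shuffle's width.
  }
}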
@@ -0,0 +1,140 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.memory

import java.io.{ByteArrayOutputStream, OutputStream}
import java.nio.ByteBuffer

import org.apache.spark.{ShuffleDependency, SparkEnv, TaskContext}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.serializer.{SerializationStream, Serializer}
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter}
import org.apache.spark.storage.{BlockManager, ShuffleBlockId, StorageLevel}

/** A ShuffleWriter that stores all shuffle data in memory using the block manager. */
private[spark] class MemoryShuffleWriter[K, V](
shuffleBlockManager: MemoryShuffleBlockManager,
handle: BaseShuffleHandle[K, V, _],
mapId: Int,
context: TaskContext) extends ShuffleWriter[K, V] {

val dep = handle.dependency

// Create a different writer for each output bucket.
val blockManager = SparkEnv.get.blockManager
val numBuckets = dep.partitioner.numPartitions
val shuffleData = Array.tabulate[SerializedObjectWriter](numBuckets) {
bucketId =>
new SerializedObjectWriter(blockManager, dep, mapId, bucketId)
}

val shuffleWriteMetrics = new ShuffleWriteMetrics()
context.taskMetrics().shuffleWriteMetrics = Some(shuffleWriteMetrics)

override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
val iter = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
dep.aggregator.get.combineValuesByKey(records, context)
} else {
records
}
} else {
require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
records
}

// Write the data to the appropriate bucket.
for (elem <- iter) {
val bucketId = dep.partitioner.getPartition(elem._1)
shuffleData(bucketId).write(elem)
shuffleWriteMetrics.incShuffleRecordsWritten(1)
}
}

override def stop(success: Boolean): Option[MapStatus] = {
// Store the shuffle data in the block manager (if the shuffle was successful) and update the
// bytes written in ShuffleWriteMetrics.
val sizes = shuffleData.map { shuffleWriter =>
val bytesWritten = shuffleWriter.close(success)
shuffleWriteMetrics.incShuffleBytesWritten(bytesWritten)
bytesWritten
}
if (success) {
shuffleBlockManager.addShuffleOutput(dep.shuffleId, mapId, numBuckets)
Some(MapStatus(SparkEnv.get.blockManager.blockManagerId, sizes))
} else {
None
}
}
}
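
write() above routes each record to a bucket via dep.partitioner.getPartition and appends it to that bucket's serialized writer. A standalone sketch of the same routing, using a HashPartitioner-style modulo as a stand-in for the dependency's partitioner (all names and values here are illustrative, not part of this PR):

object BucketRoutingSketch {
  def main(args: Array[String]): Unit = {
    val numBuckets = 4

    // Mirrors HashPartitioner: non-negative modulo of the key's hash code.
    def getPartition(key: Int): Int = {
      val mod = key.hashCode % numBuckets
      if (mod < 0) mod + numBuckets else mod
    }

    val records = (0 until 20).map(i => (i, s"value-$i"))
    val buckets = Array.fill(numBuckets)(Vector.newBuilder[(Int, String)])
    for (elem <- records) {
      // Corresponds to shuffleData(bucketId).write(elem) in the writer above.
      buckets(getPartition(elem._1)) += elem
    }

    buckets.zipWithIndex.foreach { case (b, id) =>
      println(s"bucket $id -> keys ${b.result().map(_._1).mkString(", ")}")
    }
  }
}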

/** Serializes and optionally compresses data into an in-memory byte stream. */
private[spark] class SerializedObjectWriter(
blockManager: BlockManager, dep: ShuffleDependency[_,_,_], partitionId: Int, bucketId: Int) {

/**
* A ByteArrayOutputStream that will convert the underlying byte array to a byte buffer without
* copying all of the data. This is to avoid calling the ByteArrayOutputStream.toByteArray
* method, because that method makes a copy of the byte array.
*/
private class ByteArrayOutputStreamWithZeroCopyByteBuffer extends ByteArrayOutputStream {
def getByteBuffer(): ByteBuffer = ByteBuffer.wrap(buf, 0, size())
}

private val byteOutputStream = new ByteArrayOutputStreamWithZeroCopyByteBuffer()
private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))
private val shuffleId = dep.shuffleId
private val blockId = ShuffleBlockId(shuffleId, partitionId, bucketId)

/* Only initialize compressionStream and serializationStream if some bytes are written, otherwise
* 16 bytes will always be written to the byteOutputStream (and those bytes will be unnecessarily
* transferred to reduce tasks). */
private var initialized = false
private var compressionStream: OutputStream = null
private var serializationStream: SerializationStream = null

def open() {
compressionStream = blockManager.wrapForCompression(blockId, byteOutputStream)
serializationStream = ser.newInstance().serializeStream(compressionStream)
initialized = true
}

def write(value: Any) {
if (!initialized) {
open()
}
serializationStream.writeObject(value)
}

def close(saveToBlockManager: Boolean): Long = {
if (initialized) {
serializationStream.flush()
serializationStream.close()
if (saveToBlockManager) {
val result = blockManager.putBytes(
blockId,
byteOutputStream.getByteBuffer(),
StorageLevel.MEMORY_ONLY_SER,
tellMaster = false)
return result.size
}
}
return 0
}
}
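
The ByteArrayOutputStreamWithZeroCopyByteBuffer above avoids ByteArrayOutputStream.toByteArray, which copies the whole array, by wrapping the stream's internal buffer directly. A standalone, Spark-free sketch of the same trick (class and object names here are illustrative):

import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer

class ZeroCopyByteArrayOutputStream extends ByteArrayOutputStream {
  // buf and size() are protected members of ByteArrayOutputStream; wrapping them shares
  // the underlying array with the stream instead of copying it.
  def toByteBuffer: ByteBuffer = ByteBuffer.wrap(buf, 0, size())
}

object ZeroCopyExample {
  def main(args: Array[String]): Unit = {
    val out = new ZeroCopyByteArrayOutputStream
    out.write("hello shuffle".getBytes("UTF-8"))

    val copied = ByteBuffer.wrap(out.toByteArray)  // toByteArray allocates and copies
    val wrapped = out.toByteBuffer                 // shares the stream's internal array

    println(s"copied remaining=${copied.remaining()}, wrapped remaining=${wrapped.remaining()}")
    // Caveat, same as in the writer above: the wrapped buffer is only valid as long as
    // no further writes resize the stream's internal array.
  }
}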
27 changes: 27 additions & 0 deletions core/src/test/scala/org/apache/spark/MemoryShuffleSuite.scala
@@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark

import org.scalatest.BeforeAndAfterAll

/** Runs all tests in ShuffleSuite using in-memory shuffle. */
class MemoryShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
override def beforeAll() {
conf.set("spark.shuffle.manager", "memory")
}
}
13 changes: 9 additions & 4 deletions core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -242,14 +242,14 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext
shuffleSpillCompress <- Set(true, false);
shuffleCompress <- Set(true, false)
) {
val conf = new SparkConf()
val myConf = conf.clone()
.setAppName("test")
.setMaster("local")
.set("spark.shuffle.spill.compress", shuffleSpillCompress.toString)
.set("spark.shuffle.compress", shuffleCompress.toString)
.set("spark.shuffle.memoryFraction", "0.001")
resetSparkContext()
sc = new SparkContext(conf)
sc = new SparkContext(myConf)
try {
sc.parallelize(0 until 100000).map(i => (i / 4, i)).groupByKey().collect()
} catch {
@@ -268,16 +268,21 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext
rdd.count()

// Delete one of the local shuffle blocks.
val hashFile = sc.env.blockManager.diskBlockManager.getFile(new ShuffleBlockId(0, 0, 0))
val shuffleBlockId = new ShuffleBlockId(0, 0, 0)
val hashFile = sc.env.blockManager.diskBlockManager.getFile(shuffleBlockId)
val sortFile = sc.env.blockManager.diskBlockManager.getFile(new ShuffleDataBlockId(0, 0, 0))
assert(hashFile.exists() || sortFile.exists())
val memoryBlock = sc.env.blockManager.memoryStore.getBytes(shuffleBlockId)
assert(hashFile.exists() || sortFile.exists() || memoryBlock.isDefined)

if (hashFile.exists()) {
hashFile.delete()
}
if (sortFile.exists()) {
sortFile.delete()
}
if (memoryBlock.isDefined) {
sc.env.blockManager.memoryStore.remove(shuffleBlockId)
}

// This count should retry the execution of the previous stage and rerun shuffle.
rdd.count()