diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 1745d52c8192..0a1de59a9d13 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -18,7 +18,7 @@ package org.apache.spark.network import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.storage.{BlockId, StorageLevel} +import org.apache.spark.storage.{BlockManagerId, BlockId, StorageLevel} private[spark] trait BlockDataManager { @@ -29,6 +29,12 @@ trait BlockDataManager { */ def getBlockData(blockId: BlockId): ManagedBuffer + /** + * Interface to get other executor's block data as the same node as blockManagerId. Throws + * an exception if the block cannot be found or cannot be read successfully. + */ + def getBlockData(blockId: BlockId, blockManagerId: BlockManagerId): ManagedBuffer + /** * Put the block locally, using the given storage level. */ diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index d0178dfde693..5462c8ec5c77 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -162,7 +162,7 @@ class FileShuffleBlockManager(conf: SparkConf) val fileId = shuffleState.nextFileId.getAndIncrement() val files = Array.tabulate[File](numBuckets) { bucketId => val filename = physicalFileName(shuffleId, bucketId, fileId) - blockManager.diskBlockManager.getFile(filename) + blockManager.diskBlockManager.getFile(filename, blockManager.blockManagerId) } val fileGroup = new ShuffleFileGroup(shuffleId, fileId, files) shuffleState.allFileGroups.add(fileGroup) @@ -180,7 +180,8 @@ class FileShuffleBlockManager(conf: SparkConf) Some(segment.nioByteBuffer()) } - override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { + override def getBlockData(blockId: ShuffleBlockId, + blockManagerId: BlockManagerId = blockManager.blockManagerId): ManagedBuffer = { if (consolidateShuffleFiles) { // Search all file groups associated with this shuffle. val shuffleState = shuffleStates(blockId.shuffleId) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala index 87fd161e06c8..41e01939101f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala @@ -52,12 +52,16 @@ class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { ShuffleBlockId(shuffleId, mapId, 0) } - def getDataFile(shuffleId: Int, mapId: Int): File = { - blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, 0)) + def getDataFile(shuffleId: Int, + mapId: Int, + blockManagerId: BlockManagerId = blockManager.blockManagerId): File = { + blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, 0), blockManagerId) } - private def getIndexFile(shuffleId: Int, mapId: Int): File = { - blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, 0)) + private def getIndexFile(shuffleId: Int, + mapId: Int, + blockManagerId: BlockManagerId = blockManager.blockManagerId): File = { + blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, 0), blockManagerId) } /** @@ -101,10 +105,11 @@ class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { Some(getBlockData(blockId).nioByteBuffer()) } - override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { + override def getBlockData(blockId: ShuffleBlockId, + blockManagerId: BlockManagerId = blockManager.blockManagerId): ManagedBuffer = { // The block is actually going to be a range of a single map output file for this map, so // find out the consolidated file, then the offset within that from our index - val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId) + val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId, blockManagerId) val in = new DataInputStream(new FileInputStream(indexFile)) try { @@ -113,7 +118,7 @@ class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { val nextOffset = in.readLong() new FileSegmentManagedBuffer( transportConf, - getDataFile(blockId.shuffleId, blockId.mapId), + getDataFile(blockId.shuffleId, blockId.mapId, blockManagerId), offset, nextOffset - offset) } finally { diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala index b521f0c7fc77..32ad3d47bdfb 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala @@ -19,7 +19,7 @@ package org.apache.spark.shuffle import java.nio.ByteBuffer import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} private[spark] trait ShuffleBlockManager { @@ -31,7 +31,7 @@ trait ShuffleBlockManager { */ def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] - def getBlockData(blockId: ShuffleBlockId): ManagedBuffer + def getBlockData(blockId: ShuffleBlockId, blockManagerId: BlockManagerId): ManagedBuffer def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 1dff09a75d03..7c5fd7a9bc70 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -202,7 +202,8 @@ private[spark] class BlockManager( blockManagerId } - master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + master.registerBlockManager( + blockManagerId, maxMemory, slaveActor, diskBlockManager.getLocalDirsPath()) // Register Executors' configuration with the local shuffle service, if one should exist. if (externalShuffleServiceEnabled && !blockManagerId.isDriver) { @@ -265,7 +266,8 @@ private[spark] class BlockManager( def reregister(): Unit = { // TODO: We might need to rate limit re-registering. logInfo("BlockManager re-registering with master") - master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + master.registerBlockManager( + blockManagerId, maxMemory, slaveActor, diskBlockManager.getLocalDirsPath()) reportAllBlocks() } @@ -295,13 +297,18 @@ private[spark] class BlockManager( } } + override def getBlockData(blockId: BlockId): ManagedBuffer = { + getBlockData(blockId, blockManagerId) + } + /** - * Interface to get local block data. Throws an exception if the block cannot be found or - * cannot be read successfully. + * Interface to get other executor's block data as the same node as blockManagerId. + * Throws an exception if the block cannot be found or cannot be read successfully. */ - override def getBlockData(blockId: BlockId): ManagedBuffer = { + override def getBlockData(blockId: BlockId, blockManagerId: BlockManagerId): ManagedBuffer = { if (blockId.isShuffle) { - shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) + shuffleManager.shuffleBlockManager.getBlockData( + blockId.asInstanceOf[ShuffleBlockId], blockManagerId) } else { val blockBytesOpt = doGetLocal(blockId, asBlockResult = false) .asInstanceOf[Option[ByteBuffer]] diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 061964826f08..c2f0e14d343a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -46,9 +46,12 @@ class BlockManagerMaster( } /** Register the BlockManager's id with the driver. */ - def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + def registerBlockManager(blockManagerId: BlockManagerId, + maxMemSize: Long, + slaveActor: ActorRef, + localDirs: Array[String]) { logInfo("Trying to register BlockManager") - tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor)) + tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor, localDirs)) logInfo("Registered BlockManager") } @@ -75,6 +78,11 @@ class BlockManagerMaster( askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } + /** Return other blockmanager's local dirs on the same machine as blockManagerId */ + def getLocalDirsPath(blockManagerId: BlockManagerId): Map[BlockManagerId, Array[String]] = { + askDriverWithReply[Map[BlockManagerId, Array[String]]](GetLocalDirsPath(blockManagerId)) + } + /** * Check if block manager master has a block. Note that this can be used to check for only * those blocks that are reported to block manager master. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 5b5328016124..631e8fdfc907 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -53,8 +53,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private val akkaTimeout = AkkaUtils.askTimeout(conf) override def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => - register(blockManagerId, maxMemSize, slaveActor) + case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor, localDirsPath) => + register(blockManagerId, maxMemSize, slaveActor, localDirsPath) sender ! true case UpdateBlockInfo( @@ -77,6 +77,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case GetMemoryStatus => sender ! memoryStatus + case GetLocalDirsPath(blockManagerId) => + sender ! getLocalDirsPath(blockManagerId) + case GetStorageStatus => sender ! storageStatus @@ -223,6 +226,15 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus } } + // Return local dirs of other blockmanager on the same machine as blockManagerId + private def getLocalDirsPath( + blockManagerId: BlockManagerId): Map[BlockManagerId, Array[String]] = { + blockManagerInfo + .filter { case(id, _) => (id != blockManagerId && id.host == blockManagerId.host)} + .mapValues { info => info.localDirsPath } + .toMap + } + // Return a map from the block manager id to max memory and remaining memory. private def memoryStatus: Map[BlockManagerId, (Long, Long)] = { blockManagerInfo.map { case(blockManagerId, info) => @@ -291,7 +303,10 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus ).map(_.flatten.toSeq) } - private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + private def register( + id: BlockManagerId, + maxMemSize: Long, + slaveActor: ActorRef, localDirsPath: Array[String]) { val time = System.currentTimeMillis() if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { @@ -308,7 +323,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus blockManagerIdByExecutor(id.executorId) = id blockManagerInfo(id) = new BlockManagerInfo( - id, System.currentTimeMillis(), maxMemSize, slaveActor) + id, System.currentTimeMillis(), maxMemSize, slaveActor, localDirsPath) } listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize)) } @@ -320,7 +335,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus memSize: Long, diskSize: Long, tachyonSize: Long): Boolean = { - if (!blockManagerInfo.contains(blockManagerId)) { if (blockManagerId.isDriver && !isLocal) { // We intentionally do not register the master (except in local mode), @@ -412,7 +426,8 @@ private[spark] class BlockManagerInfo( val blockManagerId: BlockManagerId, timeMs: Long, val maxMem: Long, - val slaveActor: ActorRef) + val slaveActor: ActorRef, + val localDirsPath: Array[String]) extends Logging { private var _lastSeenMs: Long = timeMs diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 48247453edef..eb3ffd272b42 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -52,7 +52,8 @@ private[spark] object BlockManagerMessages { case class RegisterBlockManager( blockManagerId: BlockManagerId, maxMemSize: Long, - sender: ActorRef) + sender: ActorRef, + subDirs: Array[String]) extends ToBlockManagerMaster case class UpdateBlockInfo( @@ -109,4 +110,6 @@ private[spark] object BlockManagerMessages { extends ToBlockManagerMaster case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster + + case class GetLocalDirsPath(blockManagerId: BlockManagerId) extends ToBlockManagerMaster } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 288313787260..abd17dd01346 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -24,6 +24,8 @@ import org.apache.spark.{SparkConf, Logging} import org.apache.spark.executor.ExecutorExitCode import org.apache.spark.util.Utils +import scala.collection.mutable + /** * Creates and maintains the logical mapping between logical blocks and physical on-disk * locations. By default, one block is mapped to one file with a name given by its BlockId. @@ -51,40 +53,74 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon // of subDirs(i) is protected by the lock of subDirs(i) private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) + private val localDirsByBlkMgr = new mutable.HashMap[BlockManagerId, Array[String]] + + def getLocalDirsPath(): Array[String] = { + localDirs.map(file => file.getAbsolutePath) + } + private val shutdownHook = addShutdownHook() - /** Looks up a file by hashing it into one of our local subdirectories. */ - // This method should be kept in sync with - // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getFile(). - def getFile(filename: String): File = { - // Figure out which local directory it hashes to, and which subdirectory in that - val hash = Utils.nonNegativeHash(filename) - val dirId = hash % localDirs.length - val subDirId = (hash / localDirs.length) % subDirsPerLocalDir - - // Create the subdirectory if it doesn't already exist - val subDir = subDirs(dirId).synchronized { - val old = subDirs(dirId)(subDirId) - if (old != null) { - old - } else { - val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) - if (!newDir.exists() && !newDir.mkdir()) { - throw new IOException(s"Failed to create local dir in $newDir.") + def getFile( + fileName: String, + blockManagerId: BlockManagerId): File = { + val hash = Utils.nonNegativeHash(fileName) + val createDirIfAbsent = + blockManagerId.executorId == blockManager.blockManagerId.executorId + + if (createDirIfAbsent) { + val dirId = hash % localDirs.length + val subDirId = (hash / localDirs.length) % subDirsPerLocalDir + + // Create the subdirectory if it doesn't already exist + var subDir = subDirs(dirId)(subDirId) + if (subDir == null) { + subDir = subDirs(dirId).synchronized { + val old = subDirs(dirId)(subDirId) + if (old != null) { + old + } else { + val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) + if (!newDir.exists() && !newDir.mkdir()) { + throw new IOException(s"Failed to create local dir in $newDir.") + } + subDirs(dirId)(subDirId) = newDir + newDir + } + } + } + new File(subDir, fileName) + } else { + var tmpLocalDirs = localDirsByBlkMgr.get(blockManagerId) + if (!tmpLocalDirs.isDefined) { + tmpLocalDirs = localDirsByBlkMgr.synchronized { + val old = localDirsByBlkMgr.get(blockManagerId) + if(old.isDefined) { + old + } else { + localDirsByBlkMgr ++= blockManager.master.getLocalDirsPath(blockManager.blockManagerId) + localDirsByBlkMgr.get(blockManagerId) + } } - subDirs(dirId)(subDirId) = newDir - newDir } - } - new File(subDir, filename) + val dirId = hash % tmpLocalDirs.get.length + val subDirId = (hash / tmpLocalDirs.get.length) % subDirsPerLocalDir + new File(tmpLocalDirs.get(dirId) + "/" + "%02x".format(subDirId), fileName) + } } - def getFile(blockId: BlockId): File = getFile(blockId.name) + def getFile( + blockId: BlockId, + blockManagerId: BlockManagerId = blockManager.blockManagerId): File = { +// val getFromThisExecutor = blockManagerId == blockManager.blockManagerId +// val dirs = if (getFromThisExecutor) getLocalDirsPath else localDirsByBlkMgr(blockManagerId) + getFile(blockId.name, blockManagerId) + } /** Check if disk block manager has a block. */ def containsBlock(blockId: BlockId): Boolean = { - getFile(blockId.name).exists() + getFile(blockId).exists() } /** List all the files currently stored on disk by the disk manager. */ diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 61ef5ff16879..2e91302b6260 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -35,7 +35,7 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc "spark.storage.memoryMapThreshold", 2 * 1024L * 1024L) override def getSize(blockId: BlockId): Long = { - diskManager.getFile(blockId.name).length + diskManager.getFile(blockId).length } override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel): PutResult = { @@ -129,7 +129,7 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc } override def getBytes(blockId: BlockId): Option[ByteBuffer] = { - val file = diskManager.getFile(blockId.name) + val file = diskManager.getFile(blockId) getBytes(file, 0, file.length) } @@ -152,7 +152,7 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc } override def remove(blockId: BlockId): Boolean = { - val file = diskManager.getFile(blockId.name) + val file = diskManager.getFile(blockId) // If consolidation mode is used With HashShuffleMananger, the physical filename for the block // is different from blockId.name. So the file returns here will not be exist, thus we avoid to // delete the whole consolidated file by mistake. @@ -164,7 +164,7 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc } override def contains(blockId: BlockId): Boolean = { - val file = diskManager.getFile(blockId.name) + val file = diskManager.getFile(blockId) file.exists() } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 8f28ef49a8a6..cfa98e98db91 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -20,7 +20,7 @@ package org.apache.spark.storage import java.io.{InputStream, IOException} import java.util.concurrent.LinkedBlockingQueue -import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} +import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet, Queue} import scala.util.{Failure, Success, Try} import org.apache.spark.{Logging, TaskContext} @@ -78,7 +78,7 @@ final class ShuffleBlockFetcherIterator( private[this] val startTime = System.currentTimeMillis /** Local blocks to fetch, excluding zero-sized blocks. */ - private[this] val localBlocks = new ArrayBuffer[BlockId]() + private[this] val localBlocksbyBlockMgr = new HashMap[BlockManagerId, ArrayBuffer[BlockId]] /** Remote blocks to fetch, excluding zero-sized blocks. */ private[this] val remoteBlocks = new HashSet[BlockId]() @@ -181,68 +181,102 @@ final class ShuffleBlockFetcherIterator( // at most maxBytesInFlight in order to limit the amount of data in flight. val remoteRequests = new ArrayBuffer[FetchRequest] + val shuffleMgrName = blockManager.conf.get("spark.shuffle.manager", "sort") + val externalShuffleServiceEnabled = + blockManager.conf.getBoolean("spark.shuffle.service.enabled", false) + // Tracks total number of blocks (including zero sized blocks) var totalBlocks = 0 - for ((address, blockInfos) <- blocksByAddress) { - totalBlocks += blockInfos.size - if (address.executorId == blockManager.blockManagerId.executorId) { - // Filter out zero-sized blocks - localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) - numBlocksToFetch += localBlocks.size - } else { - val iterator = blockInfos.iterator - var curRequestSize = 0L - var curBlocks = new ArrayBuffer[(BlockId, Long)] - while (iterator.hasNext) { - val (blockId, size) = iterator.next() - // Skip empty blocks - if (size > 0) { - curBlocks += ((blockId, size)) - remoteBlocks += blockId - numBlocksToFetch += 1 - curRequestSize += size - } else if (size < 0) { - throw new BlockException(blockId, "Negative block size " + size) - } - if (curRequestSize >= targetRequestSize) { - // Add this FetchRequest - remoteRequests += new FetchRequest(address, curBlocks) - curBlocks = new ArrayBuffer[(BlockId, Long)] - logDebug(s"Creating fetch request of $curRequestSize at $address") - curRequestSize = 0 - } + if (shuffleMgrName.toLowerCase == "hash" || externalShuffleServiceEnabled) { + for ((address, blockInfos) <- blocksByAddress) { + totalBlocks += blockInfos.size + if (address.executorId == blockManager.blockManagerId.executorId) { + val tmpBlocks = blockInfos.filter(_._2 != 0).map(_._1) + localBlocksbyBlockMgr.getOrElseUpdate(address, ArrayBuffer()) ++= tmpBlocks + numBlocksToFetch += tmpBlocks.size + } else { + remoteRequests ++= getRemote(address, blockInfos, targetRequestSize) } - // Add in the final request - if (curBlocks.nonEmpty) { - remoteRequests += new FetchRequest(address, curBlocks) + } + } else { + for ((address, blockInfos) <- blocksByAddress) { + totalBlocks += blockInfos.size + if (address.host == blockManager.blockManagerId.host) { + val tmpBlocks = blockInfos.filter(_._2 != 0).map(_._1) + localBlocksbyBlockMgr.getOrElseUpdate(address, ArrayBuffer()) ++= tmpBlocks + numBlocksToFetch += tmpBlocks.size + } else { + remoteRequests ++= getRemote(address, blockInfos, targetRequestSize) } } } + logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks") remoteRequests } + private[this] def getRemote(address: BlockManagerId, + blockInfos: Seq[(BlockId, Long)], + targetRequestSize: Long): ArrayBuffer[FetchRequest] = { + val remoteRequests = new ArrayBuffer[FetchRequest] + + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[(BlockId, Long)] + while (iterator.hasNext) { + val (blockId, size) = iterator.next() + // Skip empty blocks + if (size > 0) { + curBlocks += ((blockId, size)) + remoteBlocks += blockId + numBlocksToFetch += 1 + curRequestSize += size + } else if (size < 0) { + throw new BlockException(blockId, "Negative block size " + size) + } + if (curRequestSize >= targetRequestSize) { + // Add this FetchRequest + remoteRequests += new FetchRequest(address, curBlocks) + curBlocks = new ArrayBuffer[(BlockId, Long)] + logDebug(s"Creating fetch request of $curRequestSize at $address") + curRequestSize = 0 + } + } + // Add in the final request + if (curBlocks.nonEmpty) { + remoteRequests += new FetchRequest(address, curBlocks) + } + + remoteRequests + } + /** * Fetch the local blocks while we are fetching remote blocks. This is ok because * [[ManagedBuffer]]'s memory is allocated lazily when we create the input stream, so all we * track in-memory are the ManagedBuffer references themselves. */ private[this] def fetchLocalBlocks() { - val iter = localBlocks.iterator + val iter = localBlocksbyBlockMgr.iterator while (iter.hasNext) { - val blockId = iter.next() - try { - val buf = blockManager.getBlockData(blockId) - shuffleMetrics.incLocalBlocksFetched(1) - shuffleMetrics.incLocalBytesRead(buf.size) - buf.retain() - results.put(new SuccessFetchResult(blockId, 0, buf)) - } catch { - case e: Exception => - // If we see an exception, stop immediately. - logError(s"Error occurred while fetching local blocks", e) - results.put(new FailureFetchResult(blockId, e)) - return + val localBlocksByBlkMgrId = iter.next() + val blockManagerId = localBlocksByBlkMgrId._1 + val blockIter = localBlocksByBlkMgrId._2.iterator + + while (blockIter.hasNext) { + val blockId = blockIter.next() + try { + val buf = blockManager.getBlockData(blockId, blockManagerId) + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + buf.retain() + results.put(new SuccessFetchResult(blockId, 0, buf)) + } catch { + case e: Exception => + // If we see an exception, stop immediately. + logError(s"Error occurred while fetching local blocks", e) + results.put(new FailureFetchResult(blockId, e)) + return + } } } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index ecd1cba5b5ab..8b5776c053f1 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -31,7 +31,7 @@ import akka.actor._ import akka.pattern.ask import akka.util.Timeout -import org.mockito.Mockito.{mock, when} +import org.mockito.Mockito.{mock, when, doReturn} import org.scalatest._ import org.scalatest.concurrent.Eventually._ @@ -788,6 +788,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, actorSystem, master, new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + store.initialize("app-id") // The put should fail since a1 is not serializable. class UnserializableClass @@ -817,6 +818,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach // be nice to refactor classes involved in disk storage in a way that // allows for easier testing. val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId when(blockManager.conf).thenReturn(conf.clone.set(confKey, 0.toString)) val diskBlockManager = new DiskBlockManager(blockManager, conf) diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index bc5c74c126b7..4acf0682ed8e 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -21,7 +21,7 @@ import java.io.{File, FileWriter} import scala.language.reflectiveCalls -import org.mockito.Mockito.{mock, when} +import org.mockito.Mockito.{mock, when, doReturn} import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} import org.apache.spark.SparkConf @@ -35,6 +35,8 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before val blockManager = mock(classOf[BlockManager]) when(blockManager.conf).thenReturn(testConf) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId var diskBlockManager: DiskBlockManager = _ override def beforeAll() { diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 37b593b2c5f7..3aec5b65ae3a 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -64,6 +64,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) doReturn(localBmId).when(blockManager).blockManagerId + doReturn(conf).when(blockManager).conf // Make sure blockManager.getBlockData would return the blocks val localBlocks = Map[BlockId, ManagedBuffer]( @@ -71,7 +72,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])) localBlocks.foreach { case (blockId, buf) => - doReturn(buf).when(blockManager).getBlockData(meq(blockId)) + doReturn(buf).when(blockManager).getBlockData(meq(blockId), meq(localBmId)) } // Make sure remote blocks would return @@ -97,7 +98,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { 48 * 1024 * 1024) // 3 local blocks fetched in initialization - verify(blockManager, times(3)).getBlockData(any()) + verify(blockManager, times(3)).getBlockData(any(), any()) for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") @@ -114,7 +115,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { // 3 local blocks, and 2 remote blocks // (but from the same block manager so one call to fetchBlocks) - verify(blockManager, times(3)).getBlockData(any()) + verify(blockManager, times(3)).getBlockData(any(), any()) verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any()) } @@ -122,6 +123,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) doReturn(localBmId).when(blockManager).blockManagerId + doReturn(conf).when(blockManager).conf // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) @@ -185,6 +187,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) doReturn(localBmId).when(blockManager).blockManagerId + doReturn(conf).when(blockManager).conf // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)