From 4095cd23e6c58da82c6974a726ad678aa65395f9 Mon Sep 17 00:00:00 2001 From: jinxing Date: Sat, 18 Feb 2017 23:00:50 +0800 Subject: [PATCH 1/6] [SPARK-19659] Fetch big blocks to disk when shuffle-read. --- .../buffer/FileSegmentManagedBuffer.java | 2 +- .../server/OneForOneStreamManager.java | 21 ++ .../shuffle/ExternalShuffleClient.java | 7 +- .../shuffle/OneForOneBlockFetcher.java | 74 ++++++- .../spark/network/shuffle/ShuffleClient.java | 4 +- .../network/sasl/SaslIntegrationSuite.java | 2 +- .../ExternalShuffleIntegrationSuite.java | 2 +- .../shuffle/OneForOneBlockFetcherSuite.java | 7 +- .../spark/internal/config/package.scala | 7 + .../apache/spark/memory/MemoryManager.scala | 2 +- .../spark/network/BlockTransferService.scala | 7 +- .../netty/NettyBlockTransferService.scala | 7 +- .../shuffle/BlockStoreShuffleReader.scala | 15 +- .../storage/ShuffleBlockFetcherIterator.scala | 98 ++++++--- .../apache/spark/MapOutputTrackerSuite.scala | 6 +- .../NettyBlockTransferSecuritySuite.scala | 2 +- .../BlockStoreShuffleReaderSuite.scala | 16 +- .../spark/storage/BlockManagerSuite.scala | 4 +- .../ShuffleBlockFetcherIteratorSuite.scala | 194 ++++++++++++++++-- docs/configuration.md | 12 +- 20 files changed, 425 insertions(+), 64 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index c20fab83c346..1d331039e5b5 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -36,7 +36,7 @@ /** * A {@link ManagedBuffer} backed by a segment in a file. 
*/ -public final class FileSegmentManagedBuffer extends ManagedBuffer { +public class FileSegmentManagedBuffer extends ManagedBuffer { private final TransportConf conf; private final File file; private final long offset; diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index ee367f9998db..ad8e8b44d201 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -23,6 +23,8 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; +import scala.Tuple2; + import com.google.common.base.Preconditions; import io.netty.channel.Channel; import org.slf4j.Logger; @@ -94,6 +96,25 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { return nextChunk; } + @Override + public ManagedBuffer openStream(String streamChunkId) { + Tuple2 streamIdAndChunkId = parseStreamChunkId(streamChunkId); + return getChunk(streamIdAndChunkId._1, streamIdAndChunkId._2); + } + + public static String genStreamChunkId(long streamId, int chunkId) { + return String.format("%d_%d", streamId, chunkId); + } + + public static Tuple2 parseStreamChunkId(String streamChunkId) { + String[] array = streamChunkId.split("_"); + assert array.length == 2: + "Stream id and chunk index should be specified when open stream for fetching block."; + long streamId = Long.valueOf(array[0]); + int chunkIndex = Integer.valueOf(array[1]); + return new Tuple2<>(streamId, chunkIndex); + } + @Override public void connectionTerminated(Channel channel) { // Close all streams which have been associated with the channel. 
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 2c5827bf7dc5..269fa72dad5f 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -17,6 +17,7 @@ package org.apache.spark.network.shuffle; +import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; import java.util.List; @@ -86,14 +87,16 @@ public void fetchBlocks( int port, String execId, String[] blockIds, - BlockFetchingListener listener) { + BlockFetchingListener listener, + File[] shuffleFiles) { checkInit(); logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { RetryingBlockFetcher.BlockFetchStarter blockFetchStarter = (blockIds1, listener1) -> { TransportClient client = clientFactory.createClient(host, port); - new OneForOneBlockFetcher(client, appId, execId, blockIds1, listener1).start(); + new OneForOneBlockFetcher(client, appId, execId, blockIds1, listener1, conf, + shuffleFiles).start(); }; int maxRetries = conf.maxIORetries(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 35f69fe35c94..e98cd747a515 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -17,19 +17,28 @@ package org.apache.spark.network.shuffle; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; import java.util.Arrays; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; import org.apache.spark.network.shuffle.protocol.OpenBlocks; import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.util.TransportConf; /** * Simple wrapper on top of a TransportClient which interprets each chunk as a whole block, and @@ -48,6 +57,8 @@ public class OneForOneBlockFetcher { private final String[] blockIds; private final BlockFetchingListener listener; private final ChunkReceivedCallback chunkCallback; + private TransportConf transportConf = null; + private File[] shuffleFiles = null; private StreamHandle streamHandle = null; @@ -56,12 +67,20 @@ public OneForOneBlockFetcher( String appId, String execId, String[] blockIds, - BlockFetchingListener listener) { + BlockFetchingListener listener, + TransportConf transportConf, + File[] shuffleFiles) { this.client = client; this.openMessage = new OpenBlocks(appId, execId, blockIds); this.blockIds = blockIds; this.listener = listener; this.chunkCallback = new 
ChunkCallback(); + this.transportConf = transportConf; + if (shuffleFiles != null) { + this.shuffleFiles = shuffleFiles; + assert this.shuffleFiles.length == blockIds.length: + "Number of shuffle files should equal to blocks"; + } } /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */ @@ -100,7 +119,12 @@ public void onSuccess(ByteBuffer response) { // Immediately request all chunks -- we expect that the total size of the request is // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. for (int i = 0; i < streamHandle.numChunks; i++) { - client.fetchChunk(streamHandle.streamId, i, chunkCallback); + if (shuffleFiles != null) { + client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i), + new DownloadCallback(shuffleFiles[i], i)); + } else { + client.fetchChunk(streamHandle.streamId, i, chunkCallback); + } } } catch (Exception e) { logger.error("Failed while starting block fetches after success", e); @@ -126,4 +150,50 @@ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) { } } } + + private class DownloadCallback implements StreamCallback { + + private WritableByteChannel channel = null; + private File targetFile = null; + private int chunkIndex; + + public DownloadCallback(File targetFile, int chunkIndex) throws IOException { + this.targetFile = targetFile; + this.channel = Channels.newChannel(new FileOutputStream(targetFile)); + this.chunkIndex = chunkIndex; + } + + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + channel.write(buf); + } + + @Override + public void onComplete(String streamId) throws IOException { + channel.close(); + ManagedBuffer buffer = new FileSegmentManagedBuffer( + transportConf, targetFile, 0, targetFile.length()) { + @Override + public ManagedBuffer release() { + ManagedBuffer ret = super.release(); + if (!targetFile.delete()) { + logger.info("Failed to cleanup " + targetFile.getAbsolutePath()); + } + return ret; + } + }; + listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer); + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + channel.close(); + // On receipt of a failure, fail every block from chunkIndex onwards. + String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length); + failRemainingBlocks(remainingBlockIds, cause); + if (!targetFile.delete()) { + logger.info("Failed to cleanup " + targetFile.getAbsolutePath()); + } + } + } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index f72ab40690d0..978ff5a2a869 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -18,6 +18,7 @@ package org.apache.spark.network.shuffle; import java.io.Closeable; +import java.io.File; /** Provides an interface for reading shuffle files, either from an Executor or external service. 
*/ public abstract class ShuffleClient implements Closeable { @@ -40,5 +41,6 @@ public abstract void fetchBlocks( int port, String execId, String[] blockIds, - BlockFetchingListener listener); + BlockFetchingListener listener, + File[] shuffleFiles); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index c0e170e5b935..0c054fc5db8f 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -204,7 +204,7 @@ public void onBlockFetchFailure(String blockId, Throwable t) { String[] blockIds = { "shuffle_2_3_4", "shuffle_6_7_8" }; OneForOneBlockFetcher fetcher = - new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener); + new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf, null); fetcher.start(); blockFetchLatch.await(); checkSecurityException(exception.get()); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 7a33b6821792..d1d8f5b4e188 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -158,7 +158,7 @@ public void onBlockFetchFailure(String blockId, Throwable exception) { } } } - }); + }, null); if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { fail("Timeout getting response from the server"); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index 3e51fea3cf0e..61d82214e7d3 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -46,8 +46,13 @@ import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; import org.apache.spark.network.shuffle.protocol.OpenBlocks; import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.TransportConf; public class OneForOneBlockFetcherSuite { + + private static final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); + @Test public void testFetchOne() { LinkedHashMap blocks = Maps.newLinkedHashMap(); @@ -126,7 +131,7 @@ private static BlockFetchingListener fetchBlocks(LinkedHashMap { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index e193ed222e22..074d9939b730 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -287,4 +287,11 @@ package object config { .bytesConf(ByteUnit.BYTE) .createWithDefault(100 * 1024 * 1024) + private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM = + ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem") + .doc("The blocks of a shuffle request will 
be fetched to disk when size of the request is " + + "above this threshold. This is to avoid a giant request takes too much memory. Note that" + + " value of this config should be smaller than spark.memory.offHeap.size.") + .longConf + .createWithDefault(200 * 1024 * 1024) } diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index 82442cf56154..72825e620b1a 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -20,7 +20,7 @@ package org.apache.spark.memory import javax.annotation.concurrent.GuardedBy import org.apache.spark.SparkConf -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.storage.BlockId import org.apache.spark.storage.memory.MemoryStore import org.apache.spark.unsafe.Platform diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index cb9d389dd7ea..6860214c7fe3 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -17,7 +17,7 @@ package org.apache.spark.network -import java.io.Closeable +import java.io.{Closeable, File} import java.nio.ByteBuffer import scala.concurrent.{Future, Promise} @@ -67,7 +67,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo port: Int, execId: String, blockIds: Array[String], - listener: BlockFetchingListener): Unit + listener: BlockFetchingListener, + shuffleFiles: Array[File]): Unit /** * Upload a single block to a remote node, available only after [[init]] is invoked. 
@@ -100,7 +101,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo ret.flip() result.success(new NioManagedBuffer(ret)) } - }) + }, shuffleFiles = null) ThreadUtils.awaitResult(result.future, Duration.Inf) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index b75e91b66096..b13a9c681e54 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,6 +17,7 @@ package org.apache.spark.network.netty +import java.io.File import java.nio.ByteBuffer import scala.collection.JavaConverters._ @@ -88,13 +89,15 @@ private[spark] class NettyBlockTransferService( port: Int, execId: String, blockIds: Array[String], - listener: BlockFetchingListener): Unit = { + listener: BlockFetchingListener, + shuffleFiles: Array[File]): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { val client = clientFactory.createClient(host, port) - new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start() + new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener, + transportConf, shuffleFiles).start() } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index ba3e0e395e95..7fbeadd74359 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -18,7 +18,9 @@ package org.apache.spark.shuffle import org.apache.spark._ -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} +import org.apache.spark.memory.MemoryMode +import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.serializer.SerializerManager import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator @@ -42,6 +44,12 @@ private[spark] class BlockStoreShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { + val memMode = + if (SparkTransportConf.fromSparkConf(SparkEnv.get.conf, "shuffle").preferDirectBufs()) { + MemoryMode.OFF_HEAP + } else { + MemoryMode.ON_HEAP + } val wrappedStreams = new ShuffleBlockFetcherIterator( context, blockManager.shuffleClient, @@ -51,7 +59,10 @@ private[spark] class BlockStoreShuffleReader[K, C]( // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), - SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) + SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM), + SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true), + context.taskMemoryManager(), + memMode) val serializerInstance = dep.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala 
index f8906117638b..70f01e8bb2c0 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.{InputStream, IOException} +import java.io.{File, InputStream, IOException} import java.nio.ByteBuffer import java.util.concurrent.LinkedBlockingQueue import javax.annotation.concurrent.GuardedBy @@ -27,6 +27,7 @@ import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging +import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager} import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.shuffle.FetchFailedException @@ -52,7 +53,10 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * @param streamWrapper A function to wrap the returned input stream. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. + * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. * @param detectCorrupt whether to detect any corruption in fetched blocks. + * @param tmm [[TaskMemoryManager]] used in [[MemoryConsumer]] for acquiring memory. + * @param memMode [[MemoryMode]] acquire memory from whether off heap or on heap. */ private[spark] final class ShuffleBlockFetcherIterator( @@ -63,8 +67,12 @@ final class ShuffleBlockFetcherIterator( streamWrapper: (BlockId, InputStream) => InputStream, maxBytesInFlight: Long, maxReqsInFlight: Int, - detectCorrupt: Boolean) - extends Iterator[(BlockId, InputStream)] with Logging { + maxReqSizeShuffleToMem: Long, + detectCorrupt: Boolean, + tmm: TaskMemoryManager, + memMode: MemoryMode = MemoryMode.OFF_HEAP) + extends MemoryConsumer(tmm, tmm.pageSizeBytes(), memMode) + with Iterator[(BlockId, InputStream)] with Logging { import ShuffleBlockFetcherIterator._ @@ -129,6 +137,12 @@ final class ShuffleBlockFetcherIterator( @GuardedBy("this") private[this] var isZombie = false + /** + * Used to store the blocks which are shuffled to memory. A block will be removed from here + * after released. + */ + private[this] val blocksShuffleToMemPendingFree = mutable.Set[String]() + initialize() // Decrements the buffer reference count. 
@@ -137,6 +151,10 @@ final class ShuffleBlockFetcherIterator( // Release the current buffer if necessary if (currentResult != null) { currentResult.buf.release() + if (blocksShuffleToMemPendingFree.contains(currentResult.blockId.toString)) { + freeMemory(currentResult.size) + blocksShuffleToMemPendingFree -= currentResult.blockId.toString + } } currentResult = null } @@ -154,12 +172,16 @@ final class ShuffleBlockFetcherIterator( while (iter.hasNext) { val result = iter.next() result match { - case SuccessFetchResult(_, address, _, buf, _) => + case SuccessFetchResult(bId, address, size, buf, _) => if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) shuffleMetrics.incRemoteBlocksFetched(1) } buf.release() + if (blocksShuffleToMemPendingFree.contains(bId.toString)) { + freeMemory(size) + blocksShuffleToMemPendingFree -= bId.toString + } case _ => } } @@ -175,33 +197,54 @@ final class ShuffleBlockFetcherIterator( val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap val remainingBlocks = new HashSet[String]() ++= sizeMap.keys val blockIds = req.blocks.map(_._1.toString) - val address = req.address - shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - new BlockFetchingListener { - override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { - // Only add the buffer to results queue if the iterator is not zombie, - // i.e. cleanup() has not been called yet. - ShuffleBlockFetcherIterator.this.synchronized { - if (!isZombie) { - // Increment the ref count because we need to pass this to a different thread. - // This needs to be released after use. - buf.retain() - remainingBlocks -= blockId - results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf, - remainingBlocks.isEmpty)) - logDebug("remainingBlocks: " + remainingBlocks) - } + + val blockFetchingListener = new BlockFetchingListener { + override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { + // Only add the buffer to results queue if the iterator is not zombie, + // i.e. cleanup() has not been called yet. + ShuffleBlockFetcherIterator.this.synchronized { + if (!isZombie) { + // Increment the ref count because we need to pass this to a different thread. + // This needs to be released after use. + buf.retain() + remainingBlocks -= blockId + results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf, + remainingBlocks.isEmpty)) + logDebug("remainingBlocks: " + remainingBlocks) } - logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } + logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) + } - override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { - logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) - results.put(new FailureFetchResult(BlockId(blockId), address, e)) - } + override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { + logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) + results.put(new FailureFetchResult(BlockId(blockId), address, e)) } - ) + } + + // Shuffle remote blocks to disk when the request is too large or local memory shortage. 
+ val fetchToDisk = if (req.size > maxReqSizeShuffleToMem) { + true + } else { + val acquired = acquireMemory(req.size) + if (acquired < req.size) { + freeMemory(acquired) + true + } else { + false + } + } + + if (fetchToDisk) { + shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, + blockFetchingListener, + blockIds.map(bId => blockManager.diskBlockManager.getFile(s"remote-$bId")).toArray) + } else { + blocksShuffleToMemPendingFree ++= blockIds + shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, + blockFetchingListener, null) + } } private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { @@ -395,7 +438,6 @@ final class ShuffleBlockFetcherIterator( // Send fetch requests up to maxBytesInFlight fetchUpToMaxBytes() } - currentResult = result.asInstanceOf[SuccessFetchResult] (currentResult.blockId, new BufferReleasingInputStream(input, this)) } @@ -419,6 +461,8 @@ final class ShuffleBlockFetcherIterator( "Failed to get block " + blockId + ", which is not a shuffle block", e) } } + + override def spill(size: Long, trigger: MemoryConsumer): Long = 0 } /** diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index bb24c6ce4d33..3ecc25d15af6 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import scala.collection.mutable.ArrayBuffer -import org.mockito.Matchers.{any, isA} +import org.mockito.Matchers.any import org.mockito.Mockito._ import org.apache.spark.broadcast.BroadcastManager @@ -29,7 +29,11 @@ import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} class MapOutputTrackerSuite extends SparkFunSuite { + private val conf = new SparkConf + val env = mock(classOf[SparkEnv]) + doReturn(conf).when(env).conf + SparkEnv.set(env) private def newTrackerMaster(sparkConf: SparkConf = conf) = { val broadcastManager = new BroadcastManager(true, sparkConf, diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 792a1d7f57e2..474e30144f62 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -165,7 +165,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { promise.success(data.retain()) } - }) + }, null) ThreadUtils.awaitReady(promise.future, FiniteDuration(10, TimeUnit.SECONDS)) promise.future.value.get diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index dba1172d5fdb..d0c4007b1955 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -19,10 +19,15 @@ package org.apache.spark.shuffle import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer +import java.util.Properties +import org.mockito.Matchers.any import org.mockito.Mockito.{mock, when} +import 
org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.apache.spark._ +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.{JavaSerializer, SerializerManager} import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} @@ -126,11 +131,20 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext .set("spark.shuffle.compress", "false") .set("spark.shuffle.spill.compress", "false")) + val taskMemoryManager = mock(classOf[TaskMemoryManager]) + when(taskMemoryManager.acquireExecutionMemory(any(), any())) + .thenAnswer(new Answer[Long] { + override def answer(invocation: InvocationOnMock): Long = { + invocation.getArguments()(0).asInstanceOf[Long] + } + }) + val tc = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null) + val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, - TaskContext.empty(), + tc, serializerManager, blockManager, mapOutputTracker) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 1e7bcdb6740f..0d2912ba8c5f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.storage +import java.io.File import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer @@ -1290,7 +1291,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE port: Int, execId: String, blockIds: Array[String], - listener: BlockFetchingListener): Unit = { + listener: BlockFetchingListener, + shuffleFiles: Array[File]): Unit = { listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 9900d1edc4cb..5488c69bf8a1 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.storage import java.io.{File, InputStream, IOException} +import java.util.Properties import java.util.concurrent.Semaphore import scala.concurrent.ExecutionContext.Implicits.global @@ -29,7 +30,8 @@ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.PrivateMethodTester -import org.apache.spark.{SparkFunSuite, TaskContext} +import org.apache.spark.{SparkFunSuite, TaskContext, TaskContextImpl} +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.network._ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle.BlockFetchingListener @@ -44,7 +46,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT /** Creates a mock [[BlockTransferService]] that returns data from the given map. 
*/ private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = { val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]] val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -61,6 +64,17 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer } + private def createMockTaskMemoryManager(): TaskMemoryManager = { + val taskMemoryManager = mock(classOf[TaskMemoryManager]) + when(taskMemoryManager.acquireExecutionMemory(any(), any())) + .thenAnswer(new Answer[Long] { + override def answer(invocation: InvocationOnMock): Long = { + invocation.getArguments()(0).asInstanceOf[Long] + } + }) + taskMemoryManager + } + // Create a mock managed buffer for testing def createMockManagedBuffer(): ManagedBuffer = { val mockManagedBuffer = mock(classOf[ManagedBuffer]) @@ -106,7 +120,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, - true) + Int.MaxValue, + true, + createMockTaskMemoryManager()) // 3 local blocks fetched in initialization verify(blockManager, times(3)).getBlockData(any()) @@ -130,11 +146,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT verify(mockBuf, times(1)).release() verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() } - // 3 local blocks, and 2 remote blocks // (but from the same block manager so one call to fetchBlocks) verify(blockManager, times(3)).getBlockData(any()) - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any()) + verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) } test("release current unexhausted buffer in case the task completes early") { @@ -153,7 +168,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] Future { @@ -181,7 +197,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, - true) + Int.MaxValue, + true, + createMockTaskMemoryManager()) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() iterator.next()._2.close() // close() first block's input stream @@ -218,7 +236,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] Future { @@ -246,7 +265,9 @@ class ShuffleBlockFetcherIteratorSuite 
extends SparkFunSuite with PrivateMethodT (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, - true) + Int.MaxValue, + true, + createMockTaskMemoryManager) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -281,7 +302,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] Future { @@ -309,7 +331,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (_, in) => new LimitedInputStream(in, 100), 48 * 1024 * 1024, Int.MaxValue, - true) + Int.MaxValue, + true, + createMockTaskMemoryManager()) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -318,7 +342,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val (id1, _) = iterator.next() assert(id1 === ShuffleBlockId(0, 0, 0)) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] Future { @@ -359,7 +384,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT when(corruptBuffer.createInputStream()).thenReturn(corruptStream) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] Future { @@ -387,7 +413,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (_, in) => new LimitedInputStream(in, 100), 48 * 1024 * 1024, Int.MaxValue, - false) + Int.MaxValue, + false, + createMockTaskMemoryManager()) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -401,4 +429,142 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT assert(id3 === ShuffleBlockId(0, 2, 0)) } + test("Blocks should be shuffled to disk when size of the request is above the" + + " threshold(maxReqSizeShuffleToMem).") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + val diskBlockManager = mock(classOf[DiskBlockManager]) + doReturn(new File("shuffle-read-file")).when(diskBlockManager).getFile(any(classOf[String])) + doReturn(diskBlockManager).when(blockManager).diskBlockManager + + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val remoteBlocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) + val transfer = mock(classOf[BlockTransferService]) + var shuffleFiles: Array[File] = null + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] 
{ + override def answer(invocation: InvocationOnMock): Unit = { + val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + shuffleFiles = invocation.getArguments()(5).asInstanceOf[Array[File]] + Future { + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) + } + } + }) + + val taskMemoryManager = createMockTaskMemoryManager() + val tc = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null) + + val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)) + // Set maxReqSizeShuffleToMem to be 200. + val iterator1 = new ShuffleBlockFetcherIterator( + tc, + transfer, + blockManager, + blocksByAddress1, + (_, in) => in, + Int.MaxValue, + Int.MaxValue, + 200, + true, + taskMemoryManager) + assert(shuffleFiles === null) + + val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) + // Set maxReqSizeShuffleToMem to be 200. + val iterator2 = new ShuffleBlockFetcherIterator( + tc, + transfer, + blockManager, + blocksByAddress2, + (_, in) => in, + Int.MaxValue, + Int.MaxValue, + 200, + true, + taskMemoryManager) + assert(shuffleFiles != null) + } + + test("Blocks should be shuffled to disk when size is above memory threshold," + + " otherwise to memory.") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + val diskBlockManager = mock(classOf[DiskBlockManager]) + doReturn(new File("shuffle-read-file")).when(diskBlockManager).getFile(any(classOf[String])) + doReturn(diskBlockManager).when(blockManager).diskBlockManager + + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val remoteBlocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()) + val transfer = mock(classOf[BlockTransferService]) + var shuffleFiles: Array[File] = null + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + shuffleFiles = invocation.getArguments()(5).asInstanceOf[Array[File]] + Future { + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 1, 0).toString, remoteBlocks(ShuffleBlockId(0, 1, 0))) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 2, 0).toString, remoteBlocks(ShuffleBlockId(0, 2, 0))) + } + } + }) + val taskMemoryManager = mock(classOf[TaskMemoryManager]) + when(taskMemoryManager.acquireExecutionMemory(any(), any())) + .thenAnswer(new Answer[Long] { + override def answer(invocationOnMock: InvocationOnMock): Long = { + val required = invocationOnMock.getArguments()(0).asInstanceOf[Long] + // 500 bytes at most can be offered from TaskMemoryManager. 
+ math.min(required, 500) + } + }) + + val tc = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null) + + val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)) + val iterator1 = new ShuffleBlockFetcherIterator( + tc, + transfer, + blockManager, + blocksByAddress1, + (_, in) => in, + 48 * 1024 * 1024, + Int.MaxValue, + Int.MaxValue, + true, + taskMemoryManager) + assert(shuffleFiles === null) + + val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 200L)).toSeq) + ) + val iterator2 = new ShuffleBlockFetcherIterator( + tc, + transfer, + blockManager, + blocksByAddress2, + (_, in) => in, + 48 * 1024 * 1024, + Int.MaxValue, + Int.MaxValue, + true, + taskMemoryManager) + assert(shuffleFiles != null) + } } diff --git a/docs/configuration.md b/docs/configuration.md index a6b6d5dfa5f9..d95148e8fd64 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -519,6 +519,14 @@ Apart from these, the following properties are also available, and may be useful By allowing it to limit the number of fetch requests, this scenario can be mitigated. + + spark.reducer.maxReqSizeShuffleToMem + 200 * 1024 * 1024 + + The blocks of a shuffle request will be fetched to disk when size of the request is above + this threshold. This is to avoid a giant request takes too much memory. + + spark.shuffle.compress true @@ -963,12 +971,12 @@ Apart from these, the following properties are also available, and may be useful spark.memory.offHeap.enabled false - If true, Spark will attempt to use off-heap memory for certain operations. If off-heap memory use is enabled, then spark.memory.offHeap.size must be positive. + If true, Spark will attempt to use off-heap memory for certain operations(e.g. sort, aggregate, etc. However, the buffer used for fetching shuffle blocks is always off-heap). If off-heap memory use is enabled, then spark.memory.offHeap.size must be positive. spark.memory.offHeap.size - 0 + 384 * 1024 * 1024 The absolute amount of memory in bytes which can be used for off-heap allocation. This setting has no impact on heap memory usage, so if your executors' total memory consumption must fit within some hard limit then be sure to shrink your JVM heap size accordingly. From 63b3292abb714bd688a80665164ec4d84b994821 Mon Sep 17 00:00:00 2001 From: jinxing Date: Tue, 23 May 2017 13:19:48 +0800 Subject: [PATCH 2/6] Remove the MemoryConsumer logic and refine logic of deleting shuffle files. 
--- .../spark/internal/config/package.scala | 3 +- .../apache/spark/memory/MemoryManager.scala | 2 +- .../shuffle/BlockStoreShuffleReader.scala | 10 +- .../storage/ShuffleBlockFetcherIterator.scala | 49 ++------ .../apache/spark/MapOutputTrackerSuite.scala | 5 - .../BlockStoreShuffleReaderSuite.scala | 16 +-- .../ShuffleBlockFetcherIteratorSuite.scala | 119 ++---------------- docs/configuration.md | 4 +- 8 files changed, 28 insertions(+), 180 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 074d9939b730..add934d9afb7 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -290,8 +290,7 @@ package object config { private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM = ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem") .doc("The blocks of a shuffle request will be fetched to disk when size of the request is " + - "above this threshold. This is to avoid a giant request takes too much memory. Note that" + - " value of this config should be smaller than spark.memory.offHeap.size.") + "above this threshold. This is to avoid a giant request takes too much memory.") .longConf .createWithDefault(200 * 1024 * 1024) } diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index 72825e620b1a..82442cf56154 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -20,7 +20,7 @@ package org.apache.spark.memory import javax.annotation.concurrent.GuardedBy import org.apache.spark.SparkConf -import org.apache.spark.internal.{config, Logging} +import org.apache.spark.internal.Logging import org.apache.spark.storage.BlockId import org.apache.spark.storage.memory.MemoryStore import org.apache.spark.unsafe.Platform diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 7fbeadd74359..cdd22504d220 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -44,12 +44,6 @@ private[spark] class BlockStoreShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val memMode = - if (SparkTransportConf.fromSparkConf(SparkEnv.get.conf, "shuffle").preferDirectBufs()) { - MemoryMode.OFF_HEAP - } else { - MemoryMode.ON_HEAP - } val wrappedStreams = new ShuffleBlockFetcherIterator( context, blockManager.shuffleClient, @@ -60,9 +54,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM), - SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true), - context.taskMemoryManager(), - memMode) + SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) val serializerInstance = dep.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 
70f01e8bb2c0..5adbbe746e85 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.{File, InputStream, IOException} +import java.io.{InputStream, IOException} import java.nio.ByteBuffer import java.util.concurrent.LinkedBlockingQueue import javax.annotation.concurrent.GuardedBy @@ -27,7 +27,6 @@ import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging -import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager} import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.shuffle.FetchFailedException @@ -55,8 +54,6 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. * @param detectCorrupt whether to detect any corruption in fetched blocks. - * @param tmm [[TaskMemoryManager]] used in [[MemoryConsumer]] for acquiring memory. - * @param memMode [[MemoryMode]] acquire memory from whether off heap or on heap. */ private[spark] final class ShuffleBlockFetcherIterator( @@ -68,11 +65,7 @@ final class ShuffleBlockFetcherIterator( maxBytesInFlight: Long, maxReqsInFlight: Int, maxReqSizeShuffleToMem: Long, - detectCorrupt: Boolean, - tmm: TaskMemoryManager, - memMode: MemoryMode = MemoryMode.OFF_HEAP) - extends MemoryConsumer(tmm, tmm.pageSizeBytes(), memMode) - with Iterator[(BlockId, InputStream)] with Logging { + detectCorrupt: Boolean) extends Iterator[(BlockId, InputStream)] with Logging { import ShuffleBlockFetcherIterator._ @@ -137,12 +130,6 @@ final class ShuffleBlockFetcherIterator( @GuardedBy("this") private[this] var isZombie = false - /** - * Used to store the blocks which are shuffled to memory. A block will be removed from here - * after released. - */ - private[this] val blocksShuffleToMemPendingFree = mutable.Set[String]() - initialize() // Decrements the buffer reference count. @@ -151,10 +138,6 @@ final class ShuffleBlockFetcherIterator( // Release the current buffer if necessary if (currentResult != null) { currentResult.buf.release() - if (blocksShuffleToMemPendingFree.contains(currentResult.blockId.toString)) { - freeMemory(currentResult.size) - blocksShuffleToMemPendingFree -= currentResult.blockId.toString - } } currentResult = null } @@ -172,16 +155,12 @@ final class ShuffleBlockFetcherIterator( while (iter.hasNext) { val result = iter.next() result match { - case SuccessFetchResult(bId, address, size, buf, _) => + case SuccessFetchResult(bId, address, _, buf, _) => if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) shuffleMetrics.incRemoteBlocksFetched(1) } buf.release() - if (blocksShuffleToMemPendingFree.contains(bId.toString)) { - freeMemory(size) - blocksShuffleToMemPendingFree -= bId.toString - } case _ => } } @@ -223,25 +202,22 @@ final class ShuffleBlockFetcherIterator( } } - // Shuffle remote blocks to disk when the request is too large or local memory shortage. + // Shuffle remote blocks to disk when the request is too large. 
val fetchToDisk = if (req.size > maxReqSizeShuffleToMem) { true } else { - val acquired = acquireMemory(req.size) - if (acquired < req.size) { - freeMemory(acquired) - true - } else { - false - } + false } if (fetchToDisk) { + val shuffleFiles = blockIds.map(bId => blockManager.diskBlockManager + .getFile(s"${context.taskAttemptId()}-remote-$bId")).toArray + // Register with a task completion callback to ensure that they're guaranteed to be deleted + // after the task finishes. This is another layer of defensiveness against disk file leaks. + context.addTaskCompletionListener(_ => shuffleFiles.foreach(_.delete())) shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, - blockIds.map(bId => blockManager.diskBlockManager.getFile(s"remote-$bId")).toArray) + blockFetchingListener, shuffleFiles) } else { - blocksShuffleToMemPendingFree ++= blockIds shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, blockFetchingListener, null) } @@ -438,6 +414,7 @@ final class ShuffleBlockFetcherIterator( // Send fetch requests up to maxBytesInFlight fetchUpToMaxBytes() } + currentResult = result.asInstanceOf[SuccessFetchResult] (currentResult.blockId, new BufferReleasingInputStream(input, this)) } @@ -461,8 +438,6 @@ final class ShuffleBlockFetcherIterator( "Failed to get block " + blockId + ", which is not a shuffle block", e) } } - - override def spill(size: Long, trigger: MemoryConsumer): Long = 0 } /** diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 3ecc25d15af6..cd04e8e3272f 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -30,11 +30,6 @@ import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} class MapOutputTrackerSuite extends SparkFunSuite { - private val conf = new SparkConf - val env = mock(classOf[SparkEnv]) - doReturn(conf).when(env).conf - SparkEnv.set(env) - private def newTrackerMaster(sparkConf: SparkConf = conf) = { val broadcastManager = new BroadcastManager(true, sparkConf, new SecurityManager(sparkConf)) diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index d0c4007b1955..dba1172d5fdb 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -19,15 +19,10 @@ package org.apache.spark.shuffle import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer -import java.util.Properties -import org.mockito.Matchers.any import org.mockito.Mockito.{mock, when} -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer import org.apache.spark._ -import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.{JavaSerializer, SerializerManager} import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} @@ -131,20 +126,11 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext .set("spark.shuffle.compress", "false") .set("spark.shuffle.spill.compress", "false")) - val taskMemoryManager = mock(classOf[TaskMemoryManager]) - 
when(taskMemoryManager.acquireExecutionMemory(any(), any())) - .thenAnswer(new Answer[Long] { - override def answer(invocation: InvocationOnMock): Long = { - invocation.getArguments()(0).asInstanceOf[Long] - } - }) - val tc = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null) - val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, - tc, + TaskContext.empty(), serializerManager, blockManager, mapOutputTracker) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 5488c69bf8a1..43a0dbcf7cc5 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.storage import java.io.{File, InputStream, IOException} -import java.util.Properties import java.util.concurrent.Semaphore import scala.concurrent.ExecutionContext.Implicits.global @@ -30,8 +29,7 @@ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.PrivateMethodTester -import org.apache.spark.{SparkFunSuite, TaskContext, TaskContextImpl} -import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.{SparkFunSuite, TaskContext} import org.apache.spark.network._ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle.BlockFetchingListener @@ -64,17 +62,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer } - private def createMockTaskMemoryManager(): TaskMemoryManager = { - val taskMemoryManager = mock(classOf[TaskMemoryManager]) - when(taskMemoryManager.acquireExecutionMemory(any(), any())) - .thenAnswer(new Answer[Long] { - override def answer(invocation: InvocationOnMock): Long = { - invocation.getArguments()(0).asInstanceOf[Long] - } - }) - taskMemoryManager - } - // Create a mock managed buffer for testing def createMockManagedBuffer(): ManagedBuffer = { val mockManagedBuffer = mock(classOf[ManagedBuffer]) @@ -121,8 +108,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, - true, - createMockTaskMemoryManager()) + true) // 3 local blocks fetched in initialization verify(blockManager, times(3)).getBlockData(any()) @@ -198,8 +184,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, - true, - createMockTaskMemoryManager()) + true) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() iterator.next()._2.close() // close() first block's input stream @@ -266,8 +251,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, - true, - createMockTaskMemoryManager) + true) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -332,8 +316,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, - true, - createMockTaskMemoryManager()) + true) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -414,8 +397,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, - false, - createMockTaskMemoryManager()) + 
false) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -456,14 +438,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) - val taskMemoryManager = createMockTaskMemoryManager() - val tc = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null) - val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)) // Set maxReqSizeShuffleToMem to be 200. val iterator1 = new ShuffleBlockFetcherIterator( - tc, + TaskContext.empty(), transfer, blockManager, blocksByAddress1, @@ -471,15 +450,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, 200, - true, - taskMemoryManager) + true) assert(shuffleFiles === null) val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) // Set maxReqSizeShuffleToMem to be 200. val iterator2 = new ShuffleBlockFetcherIterator( - tc, + TaskContext.empty(), transfer, blockManager, blocksByAddress2, @@ -487,84 +465,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, 200, - true, - taskMemoryManager) - assert(shuffleFiles != null) - } - - test("Blocks should be shuffled to disk when size is above memory threshold," + - " otherwise to memory.") { - val blockManager = mock(classOf[BlockManager]) - val localBmId = BlockManagerId("test-client", "test-client", 1) - doReturn(localBmId).when(blockManager).blockManagerId - - val diskBlockManager = mock(classOf[DiskBlockManager]) - doReturn(new File("shuffle-read-file")).when(diskBlockManager).getFile(any(classOf[String])) - doReturn(diskBlockManager).when(blockManager).diskBlockManager - - val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) - val remoteBlocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), - ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), - ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()) - val transfer = mock(classOf[BlockTransferService]) - var shuffleFiles: Array[File] = null - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) - .thenAnswer(new Answer[Unit] { - override def answer(invocation: InvocationOnMock): Unit = { - val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - shuffleFiles = invocation.getArguments()(5).asInstanceOf[Array[File]] - Future { - listener.onBlockFetchSuccess( - ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) - listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, remoteBlocks(ShuffleBlockId(0, 1, 0))) - listener.onBlockFetchSuccess( - ShuffleBlockId(0, 2, 0).toString, remoteBlocks(ShuffleBlockId(0, 2, 0))) - } - } - }) - val taskMemoryManager = mock(classOf[TaskMemoryManager]) - when(taskMemoryManager.acquireExecutionMemory(any(), any())) - .thenAnswer(new Answer[Long] { - override def answer(invocationOnMock: InvocationOnMock): Long = { - val required = invocationOnMock.getArguments()(0).asInstanceOf[Long] - // 500 bytes at most can be offered from TaskMemoryManager. 
- math.min(required, 500) - } - }) - - val tc = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null) - - val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)) - val iterator1 = new ShuffleBlockFetcherIterator( - tc, - transfer, - blockManager, - blocksByAddress1, - (_, in) => in, - 48 * 1024 * 1024, - Int.MaxValue, - Int.MaxValue, - true, - taskMemoryManager) - assert(shuffleFiles === null) - - val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 200L)).toSeq) - ) - val iterator2 = new ShuffleBlockFetcherIterator( - tc, - transfer, - blockManager, - blocksByAddress2, - (_, in) => in, - 48 * 1024 * 1024, - Int.MaxValue, - Int.MaxValue, - true, - taskMemoryManager) + true) assert(shuffleFiles != null) } } diff --git a/docs/configuration.md b/docs/configuration.md index d95148e8fd64..807583c43057 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -971,12 +971,12 @@ Apart from these, the following properties are also available, and may be useful spark.memory.offHeap.enabled false - If true, Spark will attempt to use off-heap memory for certain operations(e.g. sort, aggregate, etc. However, the buffer used for fetching shuffle blocks is always off-heap). If off-heap memory use is enabled, then spark.memory.offHeap.size must be positive. + If true, Spark will attempt to use off-heap memory for certain operations. If off-heap memory use is enabled, then spark.memory.offHeap.size must be positive. spark.memory.offHeap.size - 384 * 1024 * 1024 + 0 The absolute amount of memory in bytes which can be used for off-heap allocation. This setting has no impact on heap memory usage, so if your executors' total memory consumption must fit within some hard limit then be sure to shrink your JVM heap size accordingly. 
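[Editor's note between patches 2/6 and 3/6, not part of the patch itself.] At this point in the series the fetch-to-disk decision no longer depends on how much execution memory the task can acquire; it is driven purely by comparing the request size against spark.reducer.maxReqSizeShuffleToMem, with one temp file per block and a task-completion callback as the cleanup backstop. A minimal sketch of that decision path follows, assuming simplified stand-ins: FetchPlanner, FetchRequest and planRequest below are illustrative names, not the real ShuffleBlockFetcherIterator API.

    // Illustrative sketch only; simplified stand-ins for the classes touched by this patch.
    import java.io.File

    case class FetchRequest(blockIds: Seq[String], size: Long)

    class FetchPlanner(maxReqSizeShuffleToMem: Long, tempDir: File) {
      // Mirrors the decision added in sendRequest(): a request larger than
      // spark.reducer.maxReqSizeShuffleToMem is written to disk instead of memory.
      def planRequest(req: FetchRequest, taskAttemptId: Long): Option[Array[File]] = {
        if (req.size > maxReqSizeShuffleToMem) {
          // One temp file per block, named after the task attempt so leaked files are traceable
          // (the patch deletes them again in a task-completion listener).
          Some(req.blockIds.map(bId => new File(tempDir, s"$taskAttemptId-remote-$bId")).toArray)
        } else {
          None // small enough to fetch into memory
        }
      }
    }

    object FetchPlannerExample extends App {
      val planner = new FetchPlanner(maxReqSizeShuffleToMem = 200L * 1024 * 1024, new File("/tmp"))
      val big = FetchRequest(Seq("shuffle_0_0_0"), 300L * 1024 * 1024)
      val small = FetchRequest(Seq("shuffle_0_1_0"), 100L * 1024)
      assert(planner.planRequest(big, taskAttemptId = 7L).isDefined)  // fetched to disk
      assert(planner.planRequest(small, taskAttemptId = 7L).isEmpty)  // fetched to memory
    }

Later patches in this series refine the same path (temp files from DiskBlockManager.createTempLocalBlock(), cleanup moved into cleanup()), but the threshold-only decision shown here is unchanged.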
From ac030fa08203bf6dbdcaa21aa5dc8b86389a3e16 Mon Sep 17 00:00:00 2001 From: jinxing Date: Tue, 23 May 2017 16:27:59 +0800 Subject: [PATCH 3/6] Remove all files in cleanup(), which is already registered as TaskCompletionListener --- .../spark/shuffle/BlockStoreShuffleReader.scala | 2 -- .../spark/storage/ShuffleBlockFetcherIterator.scala | 13 +++++++++---- .../org/apache/spark/MapOutputTrackerSuite.scala | 1 + 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index cdd22504d220..2fbac79a2305 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -19,8 +19,6 @@ package org.apache.spark.shuffle import org.apache.spark._ import org.apache.spark.internal.{config, Logging} -import org.apache.spark.memory.MemoryMode -import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.serializer.SerializerManager import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 5adbbe746e85..772973624448 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.{InputStream, IOException} +import java.io.{File, InputStream, IOException} import java.nio.ByteBuffer import java.util.concurrent.LinkedBlockingQueue import javax.annotation.concurrent.GuardedBy @@ -130,6 +130,12 @@ final class ShuffleBlockFetcherIterator( @GuardedBy("this") private[this] var isZombie = false + /** + * A set to store the files used for shuffling remote huge blocks. Files in this set will be + * deleted when cleanup. This is another layer of defensiveness against disk file leaks. + */ + val shuffleFilesSet = mutable.HashSet[File]() + initialize() // Decrements the buffer reference count. @@ -164,6 +170,7 @@ final class ShuffleBlockFetcherIterator( case _ => } } + shuffleFilesSet.foreach(_.delete()) } private[this] def sendRequest(req: FetchRequest) { @@ -212,9 +219,7 @@ final class ShuffleBlockFetcherIterator( if (fetchToDisk) { val shuffleFiles = blockIds.map(bId => blockManager.diskBlockManager .getFile(s"${context.taskAttemptId()}-remote-$bId")).toArray - // Register with a task completion callback to ensure that they're guaranteed to be deleted - // after the task finishes. This is another layer of defensiveness against disk file leaks. 
- context.addTaskCompletionListener(_ => shuffleFiles.foreach(_.delete())) + shuffleFilesSet ++= shuffleFiles shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, blockFetchingListener, shuffleFiles) } else { diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index cd04e8e3272f..71bedda5ac89 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} class MapOutputTrackerSuite extends SparkFunSuite { + private val conf = new SparkConf private def newTrackerMaster(sparkConf: SparkConf = conf) = { val broadcastManager = new BroadcastManager(true, sparkConf, From 222680c9d311f2d3fe7265fbf6e834e73cf4c05d Mon Sep 17 00:00:00 2001 From: jinxing Date: Wed, 24 May 2017 11:21:31 +0800 Subject: [PATCH 4/6] Do not remove shuffle files when release() --- .../network/shuffle/OneForOneBlockFetcher.java | 13 ++----------- .../spark/storage/ShuffleBlockFetcherIterator.scala | 8 ++++++-- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index e98cd747a515..b4380c925ca0 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -171,17 +171,8 @@ public void onData(String streamId, ByteBuffer buf) throws IOException { @Override public void onComplete(String streamId) throws IOException { channel.close(); - ManagedBuffer buffer = new FileSegmentManagedBuffer( - transportConf, targetFile, 0, targetFile.length()) { - @Override - public ManagedBuffer release() { - ManagedBuffer ret = super.release(); - if (!targetFile.delete()) { - logger.info("Failed to cleanup " + targetFile.getAbsolutePath()); - } - return ret; - } - }; + ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0, + targetFile.length()); listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer); } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 772973624448..21596afb32ee 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -132,7 +132,7 @@ final class ShuffleBlockFetcherIterator( /** * A set to store the files used for shuffling remote huge blocks. Files in this set will be - * deleted when cleanup. This is another layer of defensiveness against disk file leaks. + * deleted when cleanup. This is a layer of defensiveness against disk file leaks. 
*/ val shuffleFilesSet = mutable.HashSet[File]() @@ -170,7 +170,11 @@ final class ShuffleBlockFetcherIterator( case _ => } } - shuffleFilesSet.foreach(_.delete()) + shuffleFilesSet.foreach { file => + if (!file.delete()) { + logInfo("Failed to cleanup " + file.getAbsolutePath()); + } + } } private[this] def sendRequest(req: FetchRequest) { From 2ce269991cceaee18fbab71689454c8602342e68 Mon Sep 17 00:00:00 2001 From: jinxing Date: Wed, 24 May 2017 22:41:13 +0800 Subject: [PATCH 5/6] Generate tmp files by createTempLocalBlock and refine some log info. --- .../network/shuffle/OneForOneBlockFetcher.java | 3 --- .../apache/spark/internal/config/package.scala | 4 ++-- .../storage/ShuffleBlockFetcherIterator.scala | 16 ++++++---------- 3 files changed, 8 insertions(+), 15 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index b4380c925ca0..5f428759252a 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -182,9 +182,6 @@ public void onFailure(String streamId, Throwable cause) throws IOException { // On receipt of a failure, fail every block from chunkIndex onwards. String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length); failRemainingBlocks(remainingBlockIds, cause); - if (!targetFile.delete()) { - logger.info("Failed to cleanup " + targetFile.getAbsolutePath()); - } } } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index add934d9afb7..f8139b706a7c 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -291,6 +291,6 @@ package object config { ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem") .doc("The blocks of a shuffle request will be fetched to disk when size of the request is " + "above this threshold. This is to avoid a giant request takes too much memory.") - .longConf - .createWithDefault(200 * 1024 * 1024) + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("200m") } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 21596afb32ee..8de204070d0f 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -172,7 +172,7 @@ final class ShuffleBlockFetcherIterator( } shuffleFilesSet.foreach { file => if (!file.delete()) { - logInfo("Failed to cleanup " + file.getAbsolutePath()); + logInfo("Failed to cleanup shuffle fetch temp file " + file.getAbsolutePath()); } } } @@ -214,15 +214,11 @@ final class ShuffleBlockFetcherIterator( } // Shuffle remote blocks to disk when the request is too large. - val fetchToDisk = if (req.size > maxReqSizeShuffleToMem) { - true - } else { - false - } - - if (fetchToDisk) { - val shuffleFiles = blockIds.map(bId => blockManager.diskBlockManager - .getFile(s"${context.taskAttemptId()}-remote-$bId")).toArray + // TODO: Encryption and compression should be considered. 
+ if (req.size > maxReqSizeShuffleToMem) { + val shuffleFiles = blockIds.map { + bId => blockManager.diskBlockManager.createTempLocalBlock()._2 + }.toArray shuffleFilesSet ++= shuffleFiles shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, blockFetchingListener, shuffleFiles) From b07a3b61ba483989b2c205e88cf9fdc73a4205df Mon Sep 17 00:00:00 2001 From: jinxing Date: Thu, 25 May 2017 07:35:23 +0800 Subject: [PATCH 6/6] fix --- .../spark/network/buffer/FileSegmentManagedBuffer.java | 2 +- .../apache/spark/storage/ShuffleBlockFetcherIterator.scala | 5 +++-- .../spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 7 ++++++- docs/configuration.md | 2 +- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index 1d331039e5b5..c20fab83c346 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -36,7 +36,7 @@ /** * A {@link ManagedBuffer} backed by a segment in a file. */ -public class FileSegmentManagedBuffer extends ManagedBuffer { +public final class FileSegmentManagedBuffer extends ManagedBuffer { private final TransportConf conf; private final File file; private final long offset; diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 8de204070d0f..ee3506092655 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -65,7 +65,8 @@ final class ShuffleBlockFetcherIterator( maxBytesInFlight: Long, maxReqsInFlight: Int, maxReqSizeShuffleToMem: Long, - detectCorrupt: Boolean) extends Iterator[(BlockId, InputStream)] with Logging { + detectCorrupt: Boolean) + extends Iterator[(BlockId, InputStream)] with Logging { import ShuffleBlockFetcherIterator._ @@ -161,7 +162,7 @@ final class ShuffleBlockFetcherIterator( while (iter.hasNext) { val result = iter.next() result match { - case SuccessFetchResult(bId, address, _, buf, _) => + case SuccessFetchResult(_, address, _, buf, _) => if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) shuffleMetrics.incRemoteBlocksFetched(1) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 43a0dbcf7cc5..1f813a909fb8 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.storage import java.io.{File, InputStream, IOException} +import java.util.UUID import java.util.concurrent.Semaphore import scala.concurrent.ExecutionContext.Implicits.global @@ -132,6 +133,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT verify(mockBuf, times(1)).release() verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() } + // 3 local blocks, and 2 remote blocks // (but from the same block manager so one call to fetchBlocks) verify(blockManager, 
times(3)).getBlockData(any()) @@ -418,7 +420,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT doReturn(localBmId).when(blockManager).blockManagerId val diskBlockManager = mock(classOf[DiskBlockManager]) - doReturn(new File("shuffle-read-file")).when(diskBlockManager).getFile(any(classOf[String])) + doReturn{ + var blockId = new TempLocalBlockId(UUID.randomUUID()) + (blockId, new File(blockId.name)) + }.when(diskBlockManager).createTempLocalBlock() doReturn(diskBlockManager).when(blockManager).diskBlockManager val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) diff --git a/docs/configuration.md b/docs/configuration.md index 807583c43057..0771e36f80b5 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -521,7 +521,7 @@ Apart from these, the following properties are also available, and may be useful spark.reducer.maxReqSizeShuffleToMem - 200 * 1024 * 1024 + 200m The blocks of a shuffle request will be fetched to disk when size of the request is above this threshold. This is to avoid a giant request takes too much memory.
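[Editor's closing usage note, not part of the patch.] With patch 5/6 the threshold is declared via bytesConf and defaults to the size string "200m" rather than a raw long, so it can be set and read with the standard size-aware SparkConf helpers. A small hedged sketch, assuming only the public SparkConf API:

    import org.apache.spark.SparkConf

    object MaxReqSizeExample extends App {
      // Lower the fetch-to-disk threshold to 64 MiB, e.g. for memory-constrained reducers.
      val conf = new SparkConf(false)
        .set("spark.reducer.maxReqSizeShuffleToMem", "64m")

      // getSizeAsBytes understands the same "200m"-style size strings as the new default.
      val thresholdBytes = conf.getSizeAsBytes("spark.reducer.maxReqSizeShuffleToMem", "200m")
      assert(thresholdBytes == 64L * 1024 * 1024)
    }

Any shuffle-read request whose total size exceeds this value is streamed to temp files created by DiskBlockManager.createTempLocalBlock() and deleted in the iterator's cleanup(), per the patches above.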