diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 70dcc8b8b8b6..9ca5c3b761ce 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -218,21 +218,18 @@ private class ManagedBufferIterator implements Iterator { private final int shuffleId; // An array containing mapId and reduceId pairs. private final int[] mapIdAndReduceIds; + private final int shuffleGenerationId; ManagedBufferIterator(String appId, String execId, String[] blockIds) { this.appId = appId; this.execId = execId; - String[] blockId0Parts = blockIds[0].split("_"); - if (blockId0Parts.length != 4 || !blockId0Parts[0].equals("shuffle")) { - throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[0]); - } + String[] blockId0Parts = splitBlockId(blockIds[0]); this.shuffleId = Integer.parseInt(blockId0Parts[1]); mapIdAndReduceIds = new int[2 * blockIds.length]; + this.shuffleGenerationId = + (blockId0Parts.length == 5) ? Integer.parseInt(blockId0Parts[4]) : -1; for (int i = 0; i < blockIds.length; i++) { - String[] blockIdParts = blockIds[i].split("_"); - if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) { - throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]); - } + String[] blockIdParts = splitBlockId(blockIds[i]); if (Integer.parseInt(blockIdParts[1]) != shuffleId) { throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + ", got:" + blockIds[i]); @@ -242,6 +239,16 @@ private class ManagedBufferIterator implements Iterator { } } + private String[] splitBlockId(String blockId) { + String[] blockIdParts = blockId.split("_"); + if ((blockIdParts.length != 4 && blockIdParts.length != 5) + || !blockIdParts[0].equals("shuffle")) { + throw new IllegalArgumentException( + "Unexpected shuffle block id format: " + blockId); + } + return blockIdParts; + } + @Override public boolean hasNext() { return index < mapIdAndReduceIds.length; @@ -250,7 +257,7 @@ public boolean hasNext() { @Override public ManagedBuffer next() { final ManagedBuffer block = blockManager.getBlockData(appId, execId, shuffleId, - mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]); + mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1], shuffleGenerationId); index += 2; metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0); return block; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 0b7a27402369..74bcdf4e5082 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -161,6 +161,18 @@ public void registerExecutor( executors.put(fullId, executorInfo); } + /** + * Overload getBlockData with setting stageAttemptId to an invalid value of -1. 
+ */ + public ManagedBuffer getBlockData( + String appId, + String execId, + int shuffleId, + int mapId, + int reduceId) { + return getBlockData(appId, execId, shuffleId, mapId, reduceId, -1); + } + /** * Obtains a FileSegmentManagedBuffer from (shuffleId, mapId, reduceId). We make assumptions * about how the hash and sort based shuffles store their data. @@ -170,13 +182,15 @@ public ManagedBuffer getBlockData( String execId, int shuffleId, int mapId, - int reduceId) { + int reduceId, + int stageAttemptId) { ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); if (executor == null) { throw new RuntimeException( String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); } - return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); + return getSortBasedShuffleBlockData( + executor, shuffleId, mapId, reduceId, stageAttemptId); } /** @@ -278,19 +292,25 @@ public boolean accept(File dir, String name) { * Sort-based shuffle data uses an index called "shuffle_ShuffleId_MapId_0.index" into a data file * called "shuffle_ShuffleId_MapId_0.data". This logic is from IndexShuffleBlockResolver, * and the block id format is from ShuffleDataBlockId and ShuffleIndexBlockId. + * While the shuffle data and index file generated from the indeterminate stage, + * the ShuffleDataBlockId and ShuffleIndexBlockId will be extended by the stage attempt id. */ private ManagedBuffer getSortBasedShuffleBlockData( - ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId) { + ExecutorShuffleInfo executor, int shuffleId, + int mapId, int reduceId, int stageAttemptId) { + String baseFileName = "shuffle_" + shuffleId + "_" + mapId + "_0"; + if (stageAttemptId != -1) { + baseFileName = baseFileName + "_" + stageAttemptId; + } File indexFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, - "shuffle_" + shuffleId + "_" + mapId + "_0.index"); + baseFileName + ".index"); try { ShuffleIndexInformation shuffleIndexInformation = shuffleIndexCache.get(indexFile); ShuffleIndexRecord shuffleIndexRecord = shuffleIndexInformation.getIndex(reduceId); return new FileSegmentManagedBuffer( conf, - getFile(executor.localDirs, executor.subDirsPerLocalDir, - "shuffle_" + shuffleId + "_" + mapId + "_0.data"), + getFile(executor.localDirs, executor.subDirsPerLocalDir, baseFileName + ".data"), shuffleIndexRecord.getOffset(), shuffleIndexRecord.getLength()); } catch (ExecutionException e) { diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index 537c277cd26b..63344aed63f8 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -85,8 +85,8 @@ public void testOpenShuffleBlocks() { ManagedBuffer block0Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[3])); ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); - when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0)).thenReturn(block0Marker); - when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1)).thenReturn(block1Marker); + when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0, -1)).thenReturn(block0Marker); + when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1, -1)).thenReturn(block1Marker); ByteBuffer 
openBlocks = new OpenBlocks("app0", "exec1", new String[] { "shuffle_0_0_0", "shuffle_0_0_1" }) .toByteBuffer(); @@ -109,8 +109,8 @@ public void testOpenShuffleBlocks() { assertEquals(block0Marker, buffers.next()); assertEquals(block1Marker, buffers.next()); assertFalse(buffers.hasNext()); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 1); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0, -1); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 1, -1); // Verify open block request latency metrics Timer openBlockRequestLatencyMillis = (Timer) ((ExternalShuffleBlockHandler) handler) diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 459629c5f05f..3b7cee2b0145 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -53,7 +53,10 @@ public static void beforeAll() throws IOException { // Write some sort data. dataContext.insertSortShuffleData(0, 0, new byte[][] { sortBlock0.getBytes(StandardCharsets.UTF_8), - sortBlock1.getBytes(StandardCharsets.UTF_8)}); + sortBlock1.getBytes(StandardCharsets.UTF_8)}, false); + dataContext.insertSortShuffleData(0, 0, new byte[][] { + sortBlock0.getBytes(StandardCharsets.UTF_8), + sortBlock1.getBytes(StandardCharsets.UTF_8)}, true); } @AfterClass @@ -113,6 +116,27 @@ public void testSortShuffleBlocks() throws IOException { } } + @Test + public void testExtendedSortShuffleBlocks() throws IOException { + ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); + resolver.registerExecutor("app0", "exec0", + dataContext.createExecutorInfo(SORT_MANAGER)); + + try (InputStream block0Stream = resolver.getBlockData( + "app0", "exec0", 0, 0, 0, 0).createInputStream()) { + String block0 = + CharStreams.toString(new InputStreamReader(block0Stream, StandardCharsets.UTF_8)); + assertEquals(sortBlock0, block0); + } + + try (InputStream block1Stream = resolver.getBlockData( + "app0", "exec0", 0, 0, 1, 0).createInputStream()) { + String block1 = + CharStreams.toString(new InputStreamReader(block1Stream, StandardCharsets.UTF_8)); + assertEquals(sortBlock1, block1); + } + } + @Test public void jsonSerializationOfExecutorRegistration() throws IOException { ObjectMapper mapper = new ObjectMapper(); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java index 47c087088a8a..259cc2e04b7d 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -142,7 +142,7 @@ private static TestShuffleDataContext createSomeData() throws IOException { dataContext.create(); dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), new byte[][] { "ABC".getBytes(StandardCharsets.UTF_8), - "DEF".getBytes(StandardCharsets.UTF_8)}); + "DEF".getBytes(StandardCharsets.UTF_8)}, false); return dataContext; } } diff --git 
a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index f5b1ec9d46da..516e24afdaef 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -84,7 +84,8 @@ public static void beforeAll() throws IOException { dataContext0 = new TestShuffleDataContext(2, 5); dataContext0.create(); - dataContext0.insertSortShuffleData(0, 0, exec0Blocks); + dataContext0.insertSortShuffleData(0, 0, exec0Blocks, false); + dataContext0.insertSortShuffleData(0, 0, exec0Blocks, true); conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); handler = new ExternalShuffleBlockHandler(conf, null); @@ -191,6 +192,28 @@ public void testFetchThreeSort() throws Exception { exec0Fetch.releaseBuffers(); } + @Test + public void testFetchOneExtendedSort() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult exec0Fetch = fetchBlocks("exec-0", new String[] { "shuffle_0_0_0_0" }); + assertEquals(Sets.newHashSet("shuffle_0_0_0_0"), exec0Fetch.successBlocks); + assertTrue(exec0Fetch.failedBlocks.isEmpty()); + assertBufferListsEqual(exec0Fetch.buffers, Arrays.asList(exec0Blocks[0])); + exec0Fetch.releaseBuffers(); + } + + @Test + public void testFetchThreeExtendedSort() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult exec0Fetch = fetchBlocks("exec-0", + new String[] { "shuffle_0_0_0_0", "shuffle_0_0_1_0", "shuffle_0_0_2_0" }); + assertEquals(Sets.newHashSet("shuffle_0_0_0_0", "shuffle_0_0_1_0", "shuffle_0_0_2_0"), + exec0Fetch.successBlocks); + assertTrue(exec0Fetch.failedBlocks.isEmpty()); + assertBufferListsEqual(exec0Fetch.buffers, Arrays.asList(exec0Blocks)); + exec0Fetch.releaseBuffers(); + } + @Test (expected = RuntimeException.class) public void testRegisterInvalidExecutor() throws Exception { registerExecutor("exec-1", dataContext0.createExecutorInfo("unknown sort manager")); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java index d22f3ace4103..80b27010d903 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java @@ -211,7 +211,7 @@ private static void createShuffleFiles(TestShuffleDataContext dataContext) throw Random rand = new Random(123); dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), new byte[][] { "ABC".getBytes(StandardCharsets.UTF_8), - "DEF".getBytes(StandardCharsets.UTF_8)}); + "DEF".getBytes(StandardCharsets.UTF_8)}, false); } private static void createNonShuffleFiles(TestShuffleDataContext dataContext) throws IOException { diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 6989c3baf2e2..f0904fbc42a1 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ 
b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -68,8 +68,10 @@ public void cleanup() { } /** Creates reducer blocks in a sort-based data format within our local dirs. */ - public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { + public void insertSortShuffleData( + int shuffleId, int mapId, byte[][] blocks, boolean extendedBlockId) throws IOException { String blockId = "shuffle_" + shuffleId + "_" + mapId + "_0"; + if (extendedBlockId) blockId += "_0"; OutputStream dataStream = null; DataOutputStream indexStream = null; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 32b446785a9f..c2ec0bc41ad1 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -38,6 +38,7 @@ import org.apache.spark.Partitioner; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.serializer.Serializer; @@ -83,6 +84,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private final int mapId; private final Serializer serializer; private final IndexShuffleBlockResolver shuffleBlockResolver; + private final Option stageAttemptId; /** Array of file writers, one for each partition */ private DiskBlockObjectWriter[] partitionWriters; @@ -102,6 +104,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { IndexShuffleBlockResolver shuffleBlockResolver, BypassMergeSortShuffleHandle handle, int mapId, + TaskContext taskContext, SparkConf conf, ShuffleWriteMetricsReporter writeMetrics) { // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided @@ -116,6 +119,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { this.writeMetrics = writeMetrics; this.serializer = dep.serializer(); this.shuffleBlockResolver = shuffleBlockResolver; + this.stageAttemptId = taskContext.getShuffleGenerationId(dep.shuffleId()); } @Override @@ -123,7 +127,8 @@ public void write(Iterator> records) throws IOException { assert (partitionWriters == null); if (!records.hasNext()) { partitionLengths = new long[numPartitions]; - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); + shuffleBlockResolver.writeIndexFileAndCommit( + shuffleId, mapId, partitionLengths, null, stageAttemptId); mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); return; } @@ -156,11 +161,12 @@ public void write(Iterator> records) throws IOException { } } - File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); + File output = shuffleBlockResolver.getDataFile(shuffleId, mapId, stageAttemptId); File tmp = Utils.tempFileWith(output); try { partitionLengths = writePartitionedFile(tmp); - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + shuffleBlockResolver.writeIndexFileAndCommit( + shuffleId, mapId, partitionLengths, tmp, stageAttemptId); } finally { if (tmp.exists() && !tmp.delete()) { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java 
b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 9d05f03613ce..45c032545c19 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -82,6 +82,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final int initialSortBufferSize; private final int inputBufferSizeInBytes; private final int outputBufferSizeInBytes; + private final Option stageAttemptId; @Nullable private MapStatus mapStatus; @Nullable private ShuffleExternalSorter sorter; @@ -150,6 +151,7 @@ public UnsafeShuffleWriter( this.outputBufferSizeInBytes = (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; open(); + this.stageAttemptId = taskContext.getShuffleGenerationId(dep.shuffleId()); } private void updatePeakMemoryUsed() { @@ -231,7 +233,8 @@ void closeAndWriteOutput() throws IOException { final SpillInfo[] spills = sorter.closeAndGetSpills(); sorter = null; final long[] partitionLengths; - final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); + final File output = shuffleBlockResolver.getDataFile( + shuffleId, mapId, stageAttemptId); final File tmp = Utils.tempFileWith(output); try { try { @@ -243,7 +246,8 @@ void closeAndWriteOutput() throws IOException { } } } - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + shuffleBlockResolver.writeIndexFileAndCommit( + shuffleId, mapId, partitionLengths, tmp, stageAttemptId); } finally { if (tmp.exists() && !tmp.delete()) { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 1d4b1ef9c9a1..c439574955b7 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -295,8 +295,12 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * and the second item is a sequence of (shuffle block id, shuffle block size) tuples * describing the shuffle blocks that are stored at that block manager. */ - def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] + def getMapSizesByExecutorId( + shuffleId: Int, + startPartition: Int, + endPartition: Int, + stageAttemptId: Option[Int] = None) + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] /** * Deletes map output status information for the specified shuffle stage. @@ -645,13 +649,18 @@ private[spark] class MapOutputTrackerMaster( // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. 
- def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { + def getMapSizesByExecutorId( + shuffleId: Int, + startPartition: Int, + endPartition: Int, + stageAttemptId: Option[Int] = None) + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") shuffleStatuses.get(shuffleId) match { case Some (shuffleStatus) => shuffleStatus.withMapStatuses { statuses => - MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) + MapOutputTracker.convertMapStatuses( + shuffleId, startPartition, endPartition, statuses, stageAttemptId) } case None => Iterator.empty @@ -682,12 +691,17 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr private val fetching = new HashSet[Int] // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. - override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { + override def getMapSizesByExecutorId( + shuffleId: Int, + startPartition: Int, + endPartition: Int, + stageAttemptId: Option[Int] = None) + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") val statuses = getStatuses(shuffleId) try { - MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) + MapOutputTracker.convertMapStatuses( + shuffleId, startPartition, endPartition, statuses, stageAttemptId) } catch { case e: MetadataFetchFailedException => // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: @@ -863,6 +877,7 @@ private[spark] object MapOutputTracker extends Logging { * @param startPartition Start of map output partition ID range (included in range) * @param endPartition End of map output partition ID range (excluded from range) * @param statuses List of map statuses, indexed by map ID. + * @param stageAttemptId The stage attempt id for retried indeterminate stage. * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, * and the second item is a sequence of (shuffle block ID, shuffle block size) tuples * describing the shuffle blocks that are stored at that block manager. 
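[Reviewer note] The hunks around here thread an optional shuffle generation id (the attempt number of a retried indeterminate stage) from the scheduler into the shuffle block ids. Below is a minimal, self-contained Scala sketch of the resulting naming/parsing convention; it is illustrative only and not part of the patch, and ShuffleBlockNameSketch and its helpers are hypothetical names that merely mirror ShuffleBlockId.name and the external shuffle service's splitBlockId() check.

// Illustrative sketch only, not part of this patch: how an optional shuffle generation id
// extends the classic "shuffle_<shuffleId>_<mapId>_<reduceId>" name with a fifth component.
object ShuffleBlockNameSketch {
  // Build the block name; the generation id is appended only for retried indeterminate stages.
  def name(shuffleId: Int, mapId: Int, reduceId: Int, generationId: Option[Int] = None): String = {
    val base = s"shuffle_${shuffleId}_${mapId}_${reduceId}"
    generationId.map(g => s"${base}_$g").getOrElse(base)
  }

  // Accept both the 4-part (normal) and 5-part (indeterminate-stage retry) forms.
  def parse(blockId: String): (Int, Int, Int, Option[Int]) = blockId.split("_") match {
    case Array("shuffle", s, m, r) => (s.toInt, m.toInt, r.toInt, None)
    case Array("shuffle", s, m, r, g) => (s.toInt, m.toInt, r.toInt, Some(g.toInt))
    case _ => throw new IllegalArgumentException(s"Unexpected shuffle block id format: $blockId")
  }

  def main(args: Array[String]): Unit = {
    assert(name(0, 1, 2) == "shuffle_0_1_2")            // first attempt: format unchanged
    assert(name(0, 1, 2, Some(1)) == "shuffle_0_1_2_1") // attempt 1 of an indeterminate stage
    val (s, m, r, g) = parse("shuffle_0_1_2_1")         // round-trips back to its components
    assert(s == 0 && m == 1 && r == 2 && g == Some(1))
  }
}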
@@ -871,7 +886,8 @@ private[spark] object MapOutputTracker extends Logging { shuffleId: Int, startPartition: Int, endPartition: Int, - statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { + statuses: Array[MapStatus], + stageAttemptId: Option[Int]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { assert (statuses != null) val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]] for ((status, mapId) <- statuses.iterator.zipWithIndex) { @@ -884,7 +900,7 @@ private[spark] object MapOutputTracker extends Logging { val size = status.getSizeForBlock(part) if (size != 0) { splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += - ((ShuffleBlockId(shuffleId, mapId, part), size)) + ((ShuffleBlockId(shuffleId, mapId, part, stageAttemptId), size)) } } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 8b744356daae..ef8d8e40fbd5 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2547,6 +2547,7 @@ object SparkContext extends Logging { private[spark] val SPARK_SCHEDULER_POOL = "spark.scheduler.pool" private[spark] val RDD_SCOPE_KEY = "spark.rdd.scope" private[spark] val RDD_SCOPE_NO_OVERRIDE_KEY = "spark.rdd.scope.noOverride" + private[spark] val SHUFFLE_GENERATION_ID_PREFIX = "_shuffle_generation_id_" /** * Executor id for the driver. In earlier versions of Spark, this was ``, but this was diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 959f246f3f9f..e5c93bdf0165 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -176,6 +176,17 @@ abstract class TaskContext extends Serializable { */ def getLocalProperty(key: String): String + /** + * Returns the shuffle generation ID of the stage that this task belongs to: the stage + * attempt number when the stage is indeterminate, or None otherwise. + */ + def getShuffleGenerationId(shuffleId: Int): Option[Int] = { + val id = getLocalProperty(SparkContext.SHUFFLE_GENERATION_ID_PREFIX + shuffleId) + if (id != null) { + Some(id.toInt) + } else None + } + @DeveloperApi def taskMetrics(): TaskMetrics diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 524b0c4f6c3a..fe2453a4e9f9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -39,7 +39,7 @@ import org.apache.spark.internal.config import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} -import org.apache.spark.rdd.{DeterministicLevel, RDD, RDDCheckpointData} +import org.apache.spark.rdd.{RDD, RDDCheckpointData} import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -1098,7 +1098,14 @@ private[spark] class DAGScheduler( private def submitMissingTasks(stage: Stage, jobId: Int) { logDebug("submitMissingTasks(" + stage + ")") - // First figure out the indexes of partition ids to compute. + // Before finding the missing partitions, clean up the intermediate state first.
+ stage match { + case sms: ShuffleMapStage if stage.isIndeterminate => + mapOutputTracker.unregisterAllMapOutput(sms.shuffleDep.shuffleId) + case _ => + } + + // Figure out the indexes of partition ids to compute. val partitionsToCompute: Seq[Int] = stage.findMissingPartitions() // Use the scheduling pool, job group, description, etc. from an ActiveJob associated @@ -1137,12 +1144,28 @@ private[spark] class DAGScheduler( } stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq) - - // If there are tasks to execute, record the submission time of the stage. Otherwise, - // post the even without the submission time, which indicates that this stage was - // skipped. if (partitionsToCompute.nonEmpty) { + // If there are tasks to execute, record the submission time of the stage. Otherwise, + // post the even without the submission time, which indicates that this stage was + // skipped. stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) + + // While an indeterminate stage retried, the stage attempt id will be used to extend the + // shuffle file in shuffle write task, and then the mapping of shuffle id to indeterminate + // stage id will be used for shuffle reader task. + if (stage.latestInfo.attemptNumber() > 0 && stage.isIndeterminate) { + // deal with shuffle writer side property. + stage match { + case sms: ShuffleMapStage => + val stageAttemptId = stage.latestInfo.attemptNumber() + properties.setProperty( + SparkContext.SHUFFLE_GENERATION_ID_PREFIX + sms.shuffleDep.shuffleId, + stageAttemptId.toString) + logInfo(s"Set INDETERMINATE_STAGE_ATTEMPT_ID for $stage(shuffleId:" + + s" ${sms.shuffleDep.shuffleId}) to $stageAttemptId") + case _ => + } + } } listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) @@ -1547,7 +1570,6 @@ private[spark] class DAGScheduler( } abortStage(failedStage, abortMessage, None) } else { // update failedStages and make sure a ResubmitFailedStages event is enqueued - // TODO: Cancel running tasks in the failed stage -- cf. SPARK-17064 val noResubmitEnqueued = !failedStages.contains(failedStage) failedStages += failedStage failedStages += mapStage @@ -1559,7 +1581,7 @@ private[spark] class DAGScheduler( // Note that, if map stage is UNORDERED, we are fine. The shuffle partitioner is // guaranteed to be determinate, so the input data of the reducers will not change // even if the map tasks are re-tried. - if (mapStage.rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE) { + if (mapStage.isIndeterminate) { // It's a little tricky to find all the succeeding stages of `failedStage`, because // each stage only know its parents not children. Here we traverse the stages from // the leaf nodes (the result stages of active jobs), and rollback all the stages @@ -1591,11 +1613,9 @@ private[spark] class DAGScheduler( case mapStage: ShuffleMapStage => val numMissingPartitions = mapStage.findMissingPartitions().length if (numMissingPartitions < mapStage.numTasks) { - // TODO: support to rollback shuffle files. - // Currently the shuffle writing is "first write wins", so we can't re-run a - // shuffle map stage and overwrite existing shuffle files. We have to finish - // SPARK-8029 first. 
- abortStage(mapStage, generateErrorMessage(mapStage), None) + logInfo(s"The indeterminate stage $mapStage will be resubmitted;" + " the stage itself and all of its indeterminate parent stages will be" + " rolled back and rerun in full.") } case resultStage: ResultStage if resultStage.activeJob.isDefined => diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index 1b44d0aee319..522fc99e73a8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -87,7 +87,10 @@ private[spark] class ShuffleMapStage( */ def isAvailable: Boolean = numAvailableOutputs == numPartitions - /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */ + /** + * Returns the sequence of partition ids that are missing (i.e. needs to be computed). + * If the current stage is indeterminate, all partitions are considered missing every time. + */ override def findMissingPartitions(): Seq[Int] = { mapOutputTrackerMaster .findMissingPartitions(shuffleDep.shuffleId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 26cca334d3bd..b0b1eec93a14 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.HashSet import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.Logging -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{DeterministicLevel, RDD} import org.apache.spark.util.CallSite /** @@ -116,4 +116,7 @@ private[scheduler] abstract class Stage( /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). 
*/ def findMissingPartitions(): Seq[Int] + + def isIndeterminate: Boolean = + rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index c7843710413d..e3475ef74de7 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -47,7 +47,8 @@ private[spark] class BlockStoreShuffleReader[K, C]( context, blockManager.shuffleClient, blockManager, - mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), + mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition, + context.getShuffleGenerationId(handle.shuffleId)), serializerManager.wrapStream, // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index d3f1c7ec1bbe..705c2db64ed2 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -51,26 +51,35 @@ private[spark] class IndexShuffleBlockResolver( private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") - def getDataFile(shuffleId: Int, mapId: Int): File = { - blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) + def getDataFile( + shuffleId: Int, + mapId: Int, + shuffleGenerationId: Option[Int] = None): File = { + blockManager.diskBlockManager.getFile( + ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID, shuffleGenerationId)) } - private def getIndexFile(shuffleId: Int, mapId: Int): File = { - blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) + private def getIndexFile( + shuffleId: Int, + mapId: Int, + shuffleGenerationId: Option[Int] = None): File = { + blockManager.diskBlockManager.getFile( + ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID, shuffleGenerationId)) } /** * Remove data file and index file that contain the output data from one map. 
*/ - def removeDataByMap(shuffleId: Int, mapId: Int): Unit = { - var file = getDataFile(shuffleId, mapId) + def removeDataByMap( + shuffleId: Int, mapId: Int, shuffleGenerationId: Option[Int] = None): Unit = { + var file = getDataFile(shuffleId, mapId, shuffleGenerationId) if (file.exists()) { if (!file.delete()) { logWarning(s"Error deleting data ${file.getPath()}") } } - file = getIndexFile(shuffleId, mapId) + file = getIndexFile(shuffleId, mapId, shuffleGenerationId) if (file.exists()) { if (!file.delete()) { logWarning(s"Error deleting index ${file.getPath()}") @@ -137,11 +146,12 @@ private[spark] class IndexShuffleBlockResolver( shuffleId: Int, mapId: Int, lengths: Array[Long], - dataTmp: File): Unit = { - val indexFile = getIndexFile(shuffleId, mapId) + dataTmp: File, + shuffleGenerationId: Option[Int] = None): Unit = { + val indexFile = getIndexFile(shuffleId, mapId, shuffleGenerationId) val indexTmp = Utils.tempFileWith(indexFile) try { - val dataFile = getDataFile(shuffleId, mapId) + val dataFile = getDataFile(shuffleId, mapId, shuffleGenerationId) // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure // the following check and rename are atomic. synchronized { @@ -193,7 +203,7 @@ private[spark] class IndexShuffleBlockResolver( override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { // The block is actually going to be a range of a single map output file for this map, so // find out the consolidated file, then the offset within that from our index - val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId) + val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId, blockId.shuffleGenerationId) // SPARK-22982: if this FileInputStream's position is seeked forward by another piece of code // which is incorrectly using our file descriptor then this code will fetch the wrong offsets @@ -215,7 +225,7 @@ private[spark] class IndexShuffleBlockResolver( } new FileSegmentManagedBuffer( transportConf, - getDataFile(blockId.shuffleId, blockId.mapId), + getDataFile(blockId.shuffleId, blockId.mapId, blockId.shuffleGenerationId), offset, nextOffset - offset) } finally { diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index b59fa8e8a3cc..c78a154088a1 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -19,6 +19,8 @@ package org.apache.spark.shuffle.sort import java.util.concurrent.ConcurrentHashMap +import scala.collection.JavaConverters._ + import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.shuffle._ @@ -75,9 +77,10 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager } /** - * A mapping from shuffle ids to the number of mappers producing output for those shuffles. + * A mapping from shuffle ids to the tuple of number of mappers producing output and + * indeterminate stage attempt id for those shuffles. 
*/ - private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]() + private[this] val infoMapsForShuffle = new ConcurrentHashMap[Int, (Int, Option[Int])]() override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) @@ -127,8 +130,10 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager mapId: Int, context: TaskContext, metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { - numMapsForShuffle.putIfAbsent( - handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps) + infoMapsForShuffle.putIfAbsent( + handle.shuffleId, + (handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps, + context.getShuffleGenerationId(handle.shuffleId))) val env = SparkEnv.get handle match { case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => @@ -147,6 +152,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], bypassMergeSortHandle, mapId, + context, env.conf, metrics) case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => @@ -156,10 +162,11 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager /** Remove a shuffle's metadata from the ShuffleManager. */ override def unregisterShuffle(shuffleId: Int): Boolean = { - Option(numMapsForShuffle.remove(shuffleId)).foreach { numMaps => - (0 until numMaps).foreach { mapId => - shuffleBlockResolver.removeDataByMap(shuffleId, mapId) - } + Option(infoMapsForShuffle.remove(shuffleId)).foreach { + case (numMaps, shuffleGenerationId) => + (0 until numMaps).foreach { mapId => + shuffleBlockResolver.removeDataByMap(shuffleId, mapId, shuffleGenerationId) + } } true } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 16058de8bf3f..2f3cd55ca855 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -47,6 +47,8 @@ private[spark] class SortShuffleWriter[K, V, C]( private val writeMetrics = context.taskMetrics().shuffleWriteMetrics + private val shuffleGenerationId = context.getShuffleGenerationId(handle.shuffleId) + /** Write a bunch of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { sorter = if (dep.mapSideCombine) { @@ -64,12 +66,17 @@ private[spark] class SortShuffleWriter[K, V, C]( // Don't bother including the time to open the merged output file in the shuffle write time, // because it just opens a single file, so is typically too fast to measure accurately // (see SPARK-3570). 
- val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) + val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId, shuffleGenerationId) val tmp = Utils.tempFileWith(output) try { - val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) + val blockId = ShuffleBlockId( + dep.shuffleId, + mapId, + IndexShuffleBlockResolver.NOOP_REDUCE_ID, + shuffleGenerationId) val partitionLengths = sorter.writePartitionedFile(blockId, tmp) - shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) + shuffleBlockResolver.writeIndexFileAndCommit( + dep.shuffleId, mapId, partitionLengths, tmp, shuffleGenerationId) mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) } finally { if (tmp.exists() && !tmp.delete()) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 7ac2c71c18eb..d4e7f09f538a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -51,19 +51,49 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { // Format of the shuffle block ids (including data and index) should be kept in sync with // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getBlockData(). +// IndeterminateAttemptId is only set when the ShuffleMapStage's [[DeterministicLevel]] is +// INDETERMINATE and fetch fail triggered whole map stage rerun. @DeveloperApi -case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { - override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId +case class ShuffleBlockId( + shuffleId: Int, + mapId: Int, + reduceId: Int, + shuffleGenerationId: Option[Int] = None) + extends BlockId { + override def name: String = { + val nameStr = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + if (shuffleGenerationId.isEmpty) nameStr else nameStr + "_" + shuffleGenerationId.get + } } @DeveloperApi -case class ShuffleDataBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { - override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data" +case class ShuffleDataBlockId( + shuffleId: Int, + mapId: Int, + reduceId: Int, + shuffleGenerationId: Option[Int] = None) + extends BlockId { + override def name: String = { + val nameStr = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + val nameStrWithIndeterminateAttempt = + if (shuffleGenerationId.isEmpty) nameStr else nameStr + "_" + shuffleGenerationId.get + nameStrWithIndeterminateAttempt + ".data" + } } @DeveloperApi -case class ShuffleIndexBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { - override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" +case class ShuffleIndexBlockId( + shuffleId: Int, + mapId: Int, + reduceId: Int, + shuffleGenerationId: Option[Int] = None) + extends BlockId { + override def name: String = { + val nameStr = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + val nameStrWithIndeterminateAttempt = + if (shuffleGenerationId.isEmpty) nameStr else nameStr + "_" + shuffleGenerationId.get + nameStrWithIndeterminateAttempt + ".index" + } } @DeveloperApi @@ -98,7 +128,7 @@ private[spark] case class TestBlockId(id: String) extends BlockId { @DeveloperApi class UnrecognizedBlockId(name: String) - extends SparkException(s"Failed to parse $name into a block ID") + extends 
SparkException(s"Failed to parse $name into a block ID") @DeveloperApi object BlockId { @@ -106,6 +136,11 @@ object BlockId { val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r + // The EXTEND_SHUFFLE extends SHUFFLE regex with shuffle generation id, it is only used in the + // scenario of rerunning an INDETERMINATE stage. + val EXTEND_SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r + val EXTEND_SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+).data".r + val EXTEND_SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+).index".r val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r @@ -122,6 +157,14 @@ object BlockId { ShuffleDataBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) case SHUFFLE_INDEX(shuffleId, mapId, reduceId) => ShuffleIndexBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) + case EXTEND_SHUFFLE(shuffleId, mapId, reduceId, shuffleGenerationId) => + ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt, Some(shuffleGenerationId.toInt)) + case EXTEND_SHUFFLE_DATA(shuffleId, mapId, reduceId, shuffleGenerationId) => + ShuffleDataBlockId( + shuffleId.toInt, mapId.toInt, reduceId.toInt, Some(shuffleGenerationId.toInt)) + case EXTEND_SHUFFLE_INDEX(shuffleId, mapId, reduceId, shuffleGenerationId) => + ShuffleIndexBlockId( + shuffleId.toInt, mapId.toInt, reduceId.toInt, Some(shuffleGenerationId.toInt)) case BROADCAST(broadcastId, field) => BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_")) case TASKRESULT(taskId) => diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index c89d5cc971d2..69e7e4f89238 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -584,7 +584,7 @@ final class ShuffleBlockFetcherIterator( address: BlockManagerId, e: Throwable) = { blockId match { - case ShuffleBlockId(shufId, mapId, reduceId) => + case ShuffleBlockId(shufId, mapId, reduceId, _) => throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) case _ => throw new SparkException( diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 88125a6b93ad..b6405b4bbbee 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -131,7 +131,8 @@ public void setUp() throws IOException { ); }); - when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); + when(shuffleBlockResolver.getDataFile( + anyInt(), anyInt(), any())).thenReturn(mergedOutputFile); doAnswer(invocationOnMock -> { partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; File tmp = (File) invocationOnMock.getArguments()[3]; @@ -139,7 +140,8 @@ public void setUp() throws IOException { tmp.renameTo(mergedOutputFile); return null; }).when(shuffleBlockResolver) - .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class)); + .writeIndexFileAndCommit( + anyInt(), anyInt(), any(long[].class), any(File.class), 
any(Option.class)); when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> { TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID()); diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index c16e227edbfa..a66f433bd118 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -533,8 +533,8 @@ class CleanerTester( private def getShuffleBlocks(shuffleId: Int): Seq[BlockId] = { blockManager.master.getMatchingBlockIds( _ match { - case ShuffleBlockId(`shuffleId`, _, _) => true - case ShuffleIndexBlockId(`shuffleId`, _, _) => true + case ShuffleBlockId(`shuffleId`, _, _, _) => true + case ShuffleIndexBlockId(`shuffleId`, _, _, _) => true case _ => false }, askSlaves = true) } diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 8b737cd8c81f..5e2e76c4937d 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.network.TransportContext import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.server.TransportServer import org.apache.spark.network.shuffle.{ExternalShuffleBlockHandler, ExternalShuffleClient} +import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils /** @@ -63,7 +64,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { } // This test ensures that the external shuffle service is actually in use for the other tests. 
- test("using external shuffle service") { + private def checkResultWithShuffleService(createRDD: (SparkContext => RDD[_])): Unit = { sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) sc.env.blockManager.externalShuffleServiceEnabled should equal(true) sc.env.blockManager.shuffleClient.getClass should equal(classOf[ExternalShuffleClient]) @@ -76,7 +77,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { // Therefore, we should wait until all slaves are up TestUtils.waitUntilExecutorsUp(sc, 2, 60000) - val rdd = sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _) + val rdd = createRDD(sc) rdd.count() rdd.count() @@ -92,4 +93,16 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { } e.getMessage should include ("Fetch failure will not retry stage due to testing config") } + + test("using external shuffle service") { + val createRDD = (sc: SparkContext) => + sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _) + checkResultWithShuffleService(createRDD) + } + + test("using external shuffle service for indeterminate rdd") { + val createIndeterminateRDD = (sc: SparkContext) => + sc.parallelize(0 until 1000, 10).repartition(11).repartition(12) + checkResultWithShuffleService(createIndeterminateRDD) + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceDbSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceDbSuite.scala index e33c3f8f9550..79c7368f0bdb 100644 --- a/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceDbSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceDbSuite.scala @@ -58,7 +58,7 @@ class ExternalShuffleServiceDbSuite extends SparkFunSuite { // Write some sort data. dataContext.insertSortShuffleData(0, 0, Array[Array[Byte]](sortBlock0.getBytes(StandardCharsets.UTF_8), - sortBlock1.getBytes(StandardCharsets.UTF_8))) + sortBlock1.getBytes(StandardCharsets.UTF_8)), false) registerExecutor() } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 72c20a817336..4d68986bcf70 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -2700,7 +2700,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(countSubmittedMapStageAttempts() === 2) } - test("SPARK-23207: retry all the succeeding stages when the map stage is indeterminate") { + test("SPARK-25341: retry all the succeeding stages when the map stage is indeterminate") { val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true) val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2)) @@ -2711,7 +2711,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi val shuffleId2 = shuffleDep2.shuffleId val finalRdd = new MyRDD(sc, 2, List(shuffleDep2), tracker = mapOutputTracker) - submit(finalRdd, Array(0, 1)) + submit(finalRdd, Array(0, 1), properties = new Properties()) // Finish the first shuffle map stage. complete(taskSets(0), Seq( @@ -2750,8 +2750,104 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0, 0, "ignored"), null)) - // The job should fail because Spark can't rollback the shuffle map stage. 
- assert(failure != null && failure.getMessage.contains("Spark cannot rollback")) + val newFailedStages = scheduler.failedStages.toSeq + assert(newFailedStages.map(_.id) == Seq(0, 1)) + + scheduler.resubmitFailedStages() + + // First shuffle map stage resubmitted and reran all tasks. + assert(taskSets(4).stageId == 0) + assert(taskSets(4).stageAttemptId == 1) + assert(taskSets(4).tasks.length == 2) + + // Finish all stage. + complete(taskSets(4), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty)) + assert(taskSets(4).tasks.head.localProperties.getProperty( + SparkContext.SHUFFLE_GENERATION_ID_PREFIX + shuffleId1.toString) == "1") + + complete(taskSets(5), Seq( + (Success, makeMapStatus("hostC", 2)), + (Success, makeMapStatus("hostD", 2)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty)) + assert(taskSets(5).tasks.head.localProperties.getProperty( + SparkContext.SHUFFLE_GENERATION_ID_PREFIX + shuffleId2.toString) == "2") + + complete(taskSets(6), Seq((Success, 11), (Success, 12))) + + // Job successful ended. + assert(results === Map(0 -> 11, 1 -> 12)) + results.clear() + assertDataStructuresEmpty() + } + + test("SPARK-25341: continuous indeterminate stage roll back") { + // shuffleMapRdd1/2/3 are all indeterminate. + val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true) + val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2)) + val shuffleId1 = shuffleDep1.shuffleId + + val shuffleMapRdd2 = new MyRDD( + sc, 2, List(shuffleDep1), tracker = mapOutputTracker, indeterminate = true) + val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(2)) + val shuffleId2 = shuffleDep2.shuffleId + + val shuffleMapRdd3 = new MyRDD( + sc, 2, List(shuffleDep2), tracker = mapOutputTracker, indeterminate = true) + val shuffleDep3 = new ShuffleDependency(shuffleMapRdd3, new HashPartitioner(2)) + val shuffleId3 = shuffleDep3.shuffleId + val finalRdd = new MyRDD(sc, 2, List(shuffleDep3), tracker = mapOutputTracker) + + submit(finalRdd, Array(0, 1), properties = new Properties()) + + // Finish the first 3 shuffle map stages. + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty)) + + complete(taskSets(1), Seq( + (Success, makeMapStatus("hostB", 2)), + (Success, makeMapStatus("hostD", 2)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty)) + + // Executor lost on hostB, both of stage 0 and 1 should be reran. 
+ runEvent(makeCompletionEvent( + taskSets(2).tasks(0), + FetchFailed(makeBlockManagerId("hostB"), shuffleId2, 0, 0, "ignored"), + null)) + mapOutputTracker.removeOutputsOnHost("hostB") + + assert(scheduler.failedStages.toSeq.map(_.id) == Seq(1, 2)) + scheduler.resubmitFailedStages() + + def checkAndCompleteRetryStage( + taskSetIndex: Int, + stageId: Int, + shuffleId: Int): Unit = { + assert(taskSets(taskSetIndex).stageId == stageId) + assert(taskSets(taskSetIndex).stageAttemptId == 1) + assert(taskSets(taskSetIndex).tasks.length == 2) + complete(taskSets(taskSetIndex), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty)) + assert(taskSets(taskSetIndex).tasks.head.localProperties.getProperty( + SparkContext.SHUFFLE_GENERATION_ID_PREFIX + shuffleId.toString) == "1") + } + + // Check all indeterminate stage roll back. + checkAndCompleteRetryStage(3, 0, shuffleId1) + checkAndCompleteRetryStage(4, 1, shuffleId2) + checkAndCompleteRetryStage(5, 2, shuffleId3) + + // Result stage success, all job ended. + complete(taskSets(6), Seq((Success, 11), (Success, 12))) + assert(results === Map(0 -> 11, 1 -> 12)) + results.clear() + assertDataStructuresEmpty() } private def assertResultStageFailToRollback(mapRdd: MyRDD): Unit = { diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index b29d32f7b35c..d3da1bf75ecf 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -57,7 +57,7 @@ object FakeTask { val tasks = Array.tabulate[Task[_]](numTasks) { i => new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil) } - new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null) + new TaskSet(tasks, stageId, stageAttemptId, priority = 0, new Properties()) } def createShuffleMapTaskSet( @@ -92,6 +92,6 @@ object FakeTask { val tasks = Array.tabulate[Task[_]](numTasks) { i => new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil, isBarrier = true) } - new TaskSet(tasks, stageId, stageAttempId, priority = 0, null) + new TaskSet(tasks, stageId, stageAttempId, priority = 0, new Properties()) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 137ff2bd167a..1f1b5b08e654 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer +import java.util.Properties import scala.collection.mutable.HashMap import scala.concurrent.duration._ @@ -196,7 +197,8 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B config.CPUS_PER_TASK.key -> taskCpus.toString) val numFreeCores = 1 val taskSet = new TaskSet( - Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) + Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), + 0, 0, 0, new Properties()) val multiCoreWorkerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", taskCpus), new WorkerOffer("executor1", "host1", numFreeCores)) taskScheduler.submitTasks(taskSet) @@ -210,7 +212,8 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with 
diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index b29d32f7b35c..d3da1bf75ecf 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -57,7 +57,7 @@ object FakeTask {
     val tasks = Array.tabulate[Task[_]](numTasks) { i =>
       new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil)
     }
-    new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null)
+    new TaskSet(tasks, stageId, stageAttemptId, priority = 0, new Properties())
   }
 
   def createShuffleMapTaskSet(
@@ -92,6 +92,6 @@ object FakeTask {
     val tasks = Array.tabulate[Task[_]](numTasks) { i =>
       new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil, isBarrier = true)
     }
-    new TaskSet(tasks, stageId, stageAttempId, priority = 0, null)
+    new TaskSet(tasks, stageId, stageAttempId, priority = 0, new Properties())
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 137ff2bd167a..1f1b5b08e654 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.scheduler
 
 import java.nio.ByteBuffer
+import java.util.Properties
 
 import scala.collection.mutable.HashMap
 import scala.concurrent.duration._
@@ -196,7 +197,8 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
       config.CPUS_PER_TASK.key -> taskCpus.toString)
     val numFreeCores = 1
     val taskSet = new TaskSet(
-      Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null)
+      Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)),
+      0, 0, 0, new Properties())
     val multiCoreWorkerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", taskCpus),
       new WorkerOffer("executor1", "host1", numFreeCores))
     taskScheduler.submitTasks(taskSet)
@@ -210,7 +212,8 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
     // still be processed without error
     taskScheduler.submitTasks(FakeTask.createTaskSet(1))
     val taskSet2 = new TaskSet(
-      Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 1, 0, 0, null)
+      Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)),
+      1, 0, 0, new Properties())
     taskScheduler.submitTasks(taskSet2)
     taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten
     assert(taskDescriptions.map(_.executorId) === Seq("executor0"))
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index 2442670b6d3f..c6d0e0a74cc5 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -351,7 +351,8 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext {
     val denseBlockSizes = new Array[Long](5000)
     val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L)
     Seq(denseBlockSizes, sparseBlockSizes).foreach { blockSizes =>
-      ser.serialize(HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes))
+      ser.serialize(HighlyCompressedMapStatus(
+        BlockManagerId("exec-1", "host", 1234), blockSizes))
     }
   }
 
diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
index 6d2ef17a7a79..f78e43c5ebf0 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
@@ -102,7 +102,8 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
     // Make a mocked MapOutputTracker for the shuffle reader to use to determine what
     // shuffle data to read.
     val mapOutputTracker = mock(classOf[MapOutputTracker])
-    when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)).thenReturn {
+    when(mapOutputTracker.getMapSizesByExecutorId(
+      shuffleId, reduceId, reduceId + 1, None)).thenReturn {
       // Test a scenario where all data is local, to avoid creating a bunch of additional mocks
       // for the code to read data over the network.
       val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId =>
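The FakeTask and TaskSchedulerImplSuite changes above replace the null local-properties argument of TaskSet with a real java.util.Properties, since the shuffle write path now looks the generation id up there. A hedged sketch of seeding such properties in a test; the helper name and values are illustrative, and the key shape follows the DAGSchedulerSuite assertion earlier:

    import java.util.Properties

    // Hypothetical test helper: builds the local properties a retried shuffle map
    // task would carry, reusing the prefix-plus-shuffleId key shape checked above.
    def propertiesWithGeneration(prefix: String, shuffleId: Int, generationId: Int): Properties = {
      val props = new Properties()
      props.setProperty(prefix + shuffleId.toString, generationId.toString)
      props
    }

    // Usage, mirroring the TaskSet constructor calls in the diff:
    // new TaskSet(tasks, stageId, stageAttemptId, priority = 0,
    //   propertiesWithGeneration(prefix, shuffleId = 0, generationId = 1))
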
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index fc1422dfaac7..3ef59d140645 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -67,7 +67,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
     when(dependency.partitioner).thenReturn(new HashPartitioner(7))
     when(dependency.serializer).thenReturn(new JavaSerializer(conf))
     when(taskContext.taskMetrics()).thenReturn(taskMetrics)
-    when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
+    when(taskContext.getShuffleGenerationId(any[Int])).thenReturn(None)
+    when(blockResolver.getDataFile(0, 0, None)).thenReturn(outputFile)
     doAnswer { (invocationOnMock: InvocationOnMock) =>
       val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File]
       if (tmp != null) {
@@ -76,7 +77,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
       }
       null
     }.when(blockResolver)
-      .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))
+      .writeIndexFileAndCommit(
+        anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]), any(classOf[Option[Int]]))
     when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
     when(blockManager.getDiskWriter(
       any[BlockId],
@@ -125,6 +127,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
       blockResolver,
       shuffleHandle,
       0, // MapId
+      taskContext,
       conf,
       taskContext.taskMetrics().shuffleWriteMetrics
     )
@@ -149,6 +152,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
       blockResolver,
       shuffleHandle,
       0, // MapId
+      taskContext,
       conf,
       taskContext.taskMetrics().shuffleWriteMetrics
     )
@@ -184,6 +188,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
       blockResolver,
       shuffleHandle,
       0, // MapId
+      taskContext,
       conf,
       taskContext.taskMetrics().shuffleWriteMetrics
     )
@@ -206,6 +211,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
       blockResolver,
       shuffleHandle,
       0, // MapId
+      taskContext,
       conf,
       taskContext.taskMetrics().shuffleWriteMetrics
     )
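The writer suite above now passes the TaskContext into BypassMergeSortShuffleWriter and stubs the accessor this patch adds, taskContext.getShuffleGenerationId, to None, so the resolver is queried with getDataFile(0, 0, None). A small Mockito sketch of the same stubbing for the retry case, where the context reports generation 1; the calls mirror the lines in the diff, only the returned value differs:

    import org.mockito.ArgumentMatchers.any
    import org.mockito.Mockito.{mock, when}

    import org.apache.spark.TaskContext

    // Stubbing sketch for an indeterminate-stage retry: the task context reports
    // shuffle generation 1, so the writer would end up asking the resolver for
    // getDataFile(0, 0, Some(1)) instead of getDataFile(0, 0, None).
    val taskContext = mock(classOf[TaskContext])
    when(taskContext.getShuffleGenerationId(any[Int])).thenReturn(Some(1))
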
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
index 27bb06b4e063..363cd30fd53f 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
@@ -58,10 +58,8 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
     }
   }
 
-  test("commit shuffle files multiple times") {
-    val shuffleId = 1
-    val mapId = 2
-    val idxName = s"shuffle_${shuffleId}_${mapId}_0.index"
+  private def testWithIndexShuffleBlockResolver(
+      shuffleId: Int, mapId: Int, idxName: String, attemptId: Option[Int]): Unit = {
     val resolver = new IndexShuffleBlockResolver(conf, blockManager)
     val lengths = Array[Long](10, 0, 20)
     val dataTmp = File.createTempFile("shuffle", null, tempDir)
@@ -71,10 +69,10 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
     } {
       out.close()
     }
-    resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp)
+    resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp, attemptId)
 
     val indexFile = new File(tempDir.getAbsolutePath, idxName)
-    val dataFile = resolver.getDataFile(shuffleId, mapId)
+    val dataFile = resolver.getDataFile(shuffleId, mapId, attemptId)
 
     assert(indexFile.exists())
     assert(indexFile.length() === (lengths.length + 1) * 8)
@@ -91,7 +89,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
     } {
       out2.close()
     }
-    resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths2, dataTmp2)
+    resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths2, dataTmp2, attemptId)
 
     assert(indexFile.length() === (lengths.length + 1) * 8)
     assert(lengths2.toSeq === lengths.toSeq)
@@ -130,7 +128,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
     } {
       out3.close()
     }
-    resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths3, dataTmp3)
+    resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths3, dataTmp3, attemptId)
     assert(indexFile.length() === (lengths3.length + 1) * 8)
     assert(lengths3.toSeq != lengths.toSeq)
     assert(dataFile.exists())
@@ -155,4 +153,19 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
       indexIn2.close()
     }
   }
+
+  test("commit shuffle files multiple times") {
+    val shuffleId = 1
+    val mapId = 2
+    val idxName = s"shuffle_${shuffleId}_${mapId}_0.index"
+    testWithIndexShuffleBlockResolver(shuffleId, mapId, idxName, None)
+  }
+
+  test("commit shuffle files with attempt id multiple times") {
+    val shuffleId = 1
+    val mapId = 2
+    val attemptId = Some(1)
+    val idxName = s"shuffle_${shuffleId}_${mapId}_0_1.index"
+    testWithIndexShuffleBlockResolver(shuffleId, mapId, idxName, attemptId)
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
index ff4755833a91..e1d176ef4e8d 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
@@ -90,6 +90,48 @@ class BlockIdSuite extends SparkFunSuite {
     assertSame(id, BlockId(id.toString))
   }
 
+  test("extend shuffle") {
+    val id = ShuffleBlockId(1, 2, 3, Some(4))
+    assertSame(id, ShuffleBlockId(1, 2, 3, Some(4)))
+    assertDifferent(id, ShuffleBlockId(3, 2, 3, Some(5)))
+    assert(id.name === "shuffle_1_2_3_4")
+    assert(id.asRDDId === None)
+    assert(id.shuffleId === 1)
+    assert(id.mapId === 2)
+    assert(id.reduceId === 3)
+    assert(id.shuffleGenerationId.get === 4)
+    assert(id.isShuffle)
+    assertSame(id, BlockId(id.toString))
+  }
+
+  test("extend shuffle data") {
+    val id = ShuffleDataBlockId(4, 5, 6, Some(7))
+    assertSame(id, ShuffleDataBlockId(4, 5, 6, Some(7)))
+    assertDifferent(id, ShuffleDataBlockId(6, 5, 6, Some(8)))
+    assert(id.name === "shuffle_4_5_6_7.data")
+    assert(id.asRDDId === None)
+    assert(id.shuffleId === 4)
+    assert(id.mapId === 5)
+    assert(id.reduceId === 6)
+    assert(id.shuffleGenerationId.get === 7)
+    assert(!id.isShuffle)
+    assertSame(id, BlockId(id.toString))
+  }
+
+  test("extend shuffle index") {
+    val id = ShuffleIndexBlockId(7, 8, 9, Some(10))
+    assertSame(id, ShuffleIndexBlockId(7, 8, 9, Some(10)))
+    assertDifferent(id, ShuffleIndexBlockId(9, 8, 9, Some(11)))
+    assert(id.name === "shuffle_7_8_9_10.index")
+    assert(id.asRDDId === None)
+    assert(id.shuffleId === 7)
+    assert(id.mapId === 8)
+    assert(id.reduceId === 9)
+    assert(id.shuffleGenerationId.get === 10)
+    assert(!id.isShuffle)
+    assertSame(id, BlockId(id.toString))
+  }
+
   test("broadcast") {
     val id = BroadcastBlockId(42)
     assertSame(id, BroadcastBlockId(42))
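The two resolver tests and the BlockIdSuite cases above pin down the on-disk naming: without a generation the index file is shuffle_1_2_0.index, with generation 1 it becomes shuffle_1_2_0_1.index, and the extended block ids render as shuffle_<shuffleId>_<mapId>_<reduceId>_<generationId> plus the .data or .index extension. A small sketch of that naming rule, written only to make the expected strings above explicit; the helper is illustrative and not part of the patch:

    // Illustrative reconstruction of the file-name rule exercised by the tests above:
    // base name "shuffle_<shuffleId>_<mapId>_0", with "_<generationId>" appended only
    // when the map output comes from a regenerated (indeterminate) stage attempt.
    def shuffleFileName(
        shuffleId: Int,
        mapId: Int,
        generationId: Option[Int],
        extension: String): String = {
      val base = s"shuffle_${shuffleId}_${mapId}_0"
      generationId.fold(base)(gen => s"${base}_$gen") + extension
    }

    // shuffleFileName(1, 2, None, ".index")    == "shuffle_1_2_0.index"
    // shuffleFileName(1, 2, Some(1), ".index") == "shuffle_1_2_0_1.index"
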
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 2a5d01e3772c..776f3399240f 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -270,6 +270,20 @@ object MimaExcludes {
     // [SPARK-26457] Show hadoop configurations in HistoryServer environment tab
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ApplicationEnvironmentInfo.this"),
 
+    // [SPARK-25341][CORE] Support rolling back a shuffle map stage and re-generate the shuffle files
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.ShuffleIndexBlockId.copy"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.ShuffleIndexBlockId.this"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.ShuffleDataBlockId.copy"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.ShuffleDataBlockId.this"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.ShuffleBlockId.copy"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.ShuffleBlockId.this"),
+    ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.storage.ShuffleIndexBlockId$"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.ShuffleIndexBlockId.apply"),
+    ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.storage.ShuffleDataBlockId$"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.ShuffleDataBlockId.apply"),
+    ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.storage.ShuffleBlockId$"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.ShuffleBlockId.apply"),
+
     // Data Source V2 API changes
     (problem: Problem) => problem match {
       case MissingClassProblem(cls) =>
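These MiMa exclusions acknowledge that adding the optional generation field changes the compiler-generated members of the three shuffle block id case classes (apply, copy, and the constructor) and the function type their companion objects inherit, which is what the $ MissingTypesProblem entries cover. A hedged before/after sketch of the shape of that change; the default None value and the standalone class (it does not extend BlockId here) are assumptions made to keep the illustration self-contained:

    // Before (the binary signature the exclusions retire):
    //   case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)
    // After, as exercised by the BlockIdSuite "extend shuffle" test above:
    case class ShuffleBlockId(
        shuffleId: Int,
        mapId: Int,
        reduceId: Int,
        shuffleGenerationId: Option[Int] = None) {
      // Name format matching the test expectation "shuffle_1_2_3_4".
      def name: String =
        "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId +
          shuffleGenerationId.map("_" + _).getOrElse("")
    }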