diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java
index 253fb7aca1d8..32222e910df0 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java
@@ -255,4 +255,18 @@ public void getMergedBlockMeta(
       MergedBlocksMetaListener listener) {
     throw new UnsupportedOperationException();
   }
+
+  /**
+   * Remove the shuffle merge data in shuffle services.
+   *
+   * @param host the host of the remote node.
+   * @param port the port of the remote node.
+   * @param shuffleId shuffle id.
+   * @param shuffleMergeId shuffle merge id.
+   *
+   * @since 3.4.0
+   */
+  public boolean removeShuffleMerge(String host, int port, int shuffleId, int shuffleMergeId) {
+    throw new UnsupportedOperationException();
+  }
 }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
index 4e40090b065e..3d7c1b1ca0cc 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
@@ -224,6 +224,12 @@ protected void handleMessage(
       } finally {
         responseDelayContext.stop();
       }
+    } else if (msgObj instanceof RemoveShuffleMerge) {
+      RemoveShuffleMerge msg = (RemoveShuffleMerge) msgObj;
+      checkAuth(client, msg.appId);
+      logger.info("Removing shuffle merge data for application {} shuffle {} shuffleMerge {}",
+        msg.appId, msg.shuffleId, msg.shuffleMergeId);
+      mergeManager.removeShuffleMerge(msg);
     } else if (msgObj instanceof DiagnoseCorruption) {
       DiagnoseCorruption msg = (DiagnoseCorruption) msgObj;
       checkAuth(client, msg.appId);
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java
index b066d99e8ef8..1451d5712812 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java
@@ -256,6 +256,23 @@ public void onFailure(Throwable e) {
     }
   }
 
+  @Override
+  public boolean removeShuffleMerge(String host, int port, int shuffleId, int shuffleMergeId) {
+    checkInit();
+    try {
+      TransportClient client = clientFactory.createClient(host, port);
+      client.send(
+        new RemoveShuffleMerge(appId, comparableAppAttemptId, shuffleId, shuffleMergeId)
+          .toByteBuffer());
+      // TODO(SPARK-42025): Add some error logs for RemoveShuffleMerge RPC
+    } catch (Exception e) {
+      logger.debug("Exception while sending RemoveShuffleMerge request to {}:{}",
+        host, port, e);
+      return false;
+    }
+    return true;
+  }
+
   @Override
   public MetricSet shuffleMetrics() {
     checkInit();
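Note on the new client entry point: the RPC is fire-and-forget (`client.send`), so `true` only means the request was handed to the transport, not that the service deleted anything. A minimal sketch (not part of this patch) of a caller that asks one shuffle service to drop every merge attempt for a shuffle; obtaining an initialized `ExternalBlockStoreClient` (transport conf, auth) is elided:

```java
import org.apache.spark.network.shuffle.BlockStoreClient;
import org.apache.spark.network.shuffle.RemoteBlockPushResolver;

public final class RemoveShuffleMergeCaller {
  // `client` is assumed to be an already-initialized ExternalBlockStoreClient.
  static boolean dropAllMergedData(
      BlockStoreClient client, String host, int port, int shuffleId) {
    // DELETE_ALL_MERGED_SHUFFLE (-1) asks the service to remove every merge
    // attempt for this shuffle rather than one specific shuffleMergeId.
    return client.removeShuffleMerge(
      host, port, shuffleId, RemoteBlockPushResolver.DELETE_ALL_MERGED_SHUFFLE);
  }
}
```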
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedShuffleFileManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedShuffleFileManager.java
index 051684a92d0b..ab498367d500 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedShuffleFileManager.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedShuffleFileManager.java
@@ -26,6 +26,7 @@
 import org.apache.spark.network.shuffle.protocol.FinalizeShuffleMerge;
 import org.apache.spark.network.shuffle.protocol.MergeStatuses;
 import org.apache.spark.network.shuffle.protocol.PushBlockStream;
+import org.apache.spark.network.shuffle.protocol.RemoveShuffleMerge;
 
 /**
  * The MergedShuffleFileManager is used to process push based shuffle when enabled. It works
@@ -121,6 +122,14 @@ MergedBlockMeta getMergedBlockMeta(
    */
   String[] getMergedBlockDirs(String appId);
 
+  /**
+   * Remove shuffle merge data files.
+   *
+   * @param removeShuffleMerge contains shuffle details (appId, shuffleId, etc) to uniquely
+   *                           identify a shuffle to be removed
+   */
+  void removeShuffleMerge(RemoveShuffleMerge removeShuffleMerge);
+
   /**
    * Optionally close any resources associated the MergedShuffleFileManager, such as the
    * leveldb for state persistence.
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/NoOpMergedShuffleFileManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/NoOpMergedShuffleFileManager.java
index 876b10095938..7d8f9e27402a 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/NoOpMergedShuffleFileManager.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/NoOpMergedShuffleFileManager.java
@@ -26,6 +26,7 @@
 import org.apache.spark.network.shuffle.protocol.FinalizeShuffleMerge;
 import org.apache.spark.network.shuffle.protocol.MergeStatuses;
 import org.apache.spark.network.shuffle.protocol.PushBlockStream;
+import org.apache.spark.network.shuffle.protocol.RemoveShuffleMerge;
 import org.apache.spark.network.util.TransportConf;
 
 /**
@@ -84,4 +85,9 @@ public MergedBlockMeta getMergedBlockMeta(
   public String[] getMergedBlockDirs(String appId) {
     throw new UnsupportedOperationException("Cannot handle shuffle block merge");
   }
+
+  @Override
+  public void removeShuffleMerge(RemoveShuffleMerge removeShuffleMerge) {
+    throw new UnsupportedOperationException("Cannot handle merged shuffle remove");
+  }
 }
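Since every override in `NoOpMergedShuffleFileManager` throws, a convenient way to unit-test code paths that invoke the new hook is a recording stub that extends it. This is an illustrative sketch only; the two-argument `(TransportConf, recovery file)` super constructor is an assumption about the no-op class's shape, and both arguments are unused by the stub:

```java
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.network.shuffle.NoOpMergedShuffleFileManager;
import org.apache.spark.network.shuffle.protocol.RemoveShuffleMerge;

// Recording test double: inherit the throwing behavior for everything except
// the new hook, which just captures its argument for later assertions.
class RecordingMergeManager extends NoOpMergedShuffleFileManager {
  final List<RemoveShuffleMerge> removals = new ArrayList<>();

  RecordingMergeManager() {
    super(null, null); // transport conf and recovery file are not used by this stub
  }

  @Override
  public void removeShuffleMerge(RemoveShuffleMerge msg) {
    removals.add(msg); // record instead of throwing
  }
}
```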
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
index 816d1082850c..edb0b6f2d4d1 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
@@ -71,6 +71,7 @@
 import org.apache.spark.network.shuffle.protocol.FinalizeShuffleMerge;
 import org.apache.spark.network.shuffle.protocol.MergeStatuses;
 import org.apache.spark.network.shuffle.protocol.PushBlockStream;
+import org.apache.spark.network.shuffle.protocol.RemoveShuffleMerge;
 import org.apache.spark.network.shuffledb.DB;
 import org.apache.spark.network.shuffledb.DBBackend;
 import org.apache.spark.network.shuffledb.DBIterator;
@@ -95,6 +96,12 @@ public class RemoteBlockPushResolver implements MergedShuffleFileManager {
   public static final String MERGE_DIR_KEY = "mergeDir";
   public static final String ATTEMPT_ID_KEY = "attemptId";
   private static final int UNDEFINED_ATTEMPT_ID = -1;
+
+  /**
+   * The flag for deleting all merged shuffle data.
+   */
+  public static final int DELETE_ALL_MERGED_SHUFFLE = -1;
+
   private static final String DB_KEY_DELIMITER = ";";
   private static final ErrorHandler.BlockPushErrorHandler ERROR_HANDLER = createErrorHandler();
   // ByteBuffer to respond to client upon a successful merge of a pushed block
@@ -396,6 +403,59 @@ public void applicationRemoved(String appId, boolean cleanupLocalDirs) {
     }
   }
 
+  @Override
+  public void removeShuffleMerge(RemoveShuffleMerge msg) {
+    AppShuffleInfo appShuffleInfo = validateAndGetAppShuffleInfo(msg.appId);
+    if (appShuffleInfo.attemptId != msg.appAttemptId) {
+      throw new IllegalArgumentException(
+        String.format("The attempt id %s in this RemoveShuffleMerge message does not match "
+          + "with the current attempt id %s stored in shuffle service for application %s",
+          msg.appAttemptId, appShuffleInfo.attemptId, msg.appId));
+    }
+    appShuffleInfo.shuffles.compute(msg.shuffleId, (shuffleId, mergePartitionsInfo) -> {
+      if (mergePartitionsInfo == null) {
+        if (msg.shuffleMergeId == DELETE_ALL_MERGED_SHUFFLE) {
+          return null;
+        } else {
+          writeAppAttemptShuffleMergeInfoToDB(new AppAttemptShuffleMergeId(
+            msg.appId, msg.appAttemptId, msg.shuffleId, msg.shuffleMergeId));
+          return new AppShuffleMergePartitionsInfo(msg.shuffleMergeId, true);
+        }
+      }
+      boolean deleteCurrentMergedShuffle =
+        msg.shuffleMergeId == DELETE_ALL_MERGED_SHUFFLE ||
+        msg.shuffleMergeId == mergePartitionsInfo.shuffleMergeId;
+      int shuffleMergeIdToDelete = msg.shuffleMergeId != DELETE_ALL_MERGED_SHUFFLE ?
+        msg.shuffleMergeId : mergePartitionsInfo.shuffleMergeId;
+      if (deleteCurrentMergedShuffle ||
+          shuffleMergeIdToDelete > mergePartitionsInfo.shuffleMergeId) {
+        AppAttemptShuffleMergeId currentAppAttemptShuffleMergeId =
+          new AppAttemptShuffleMergeId(
+            msg.appId, msg.appAttemptId, msg.shuffleId, mergePartitionsInfo.shuffleMergeId);
+        if (!mergePartitionsInfo.isFinalized()) {
+          // Clean up shuffle data before the shuffle was finalized. Close and delete all the open
+          // files.
+          submitCleanupTask(() ->
+            closeAndDeleteOutdatedPartitions(
+              currentAppAttemptShuffleMergeId, mergePartitionsInfo.shuffleMergePartitions));
+        } else {
+          // Current shuffle was finalized, delete all the merged files through reduceIds set
+          // in finalizeShuffleMerge method.
+          submitCleanupTask(() ->
+            deleteMergedFiles(currentAppAttemptShuffleMergeId, appShuffleInfo,
+              mergePartitionsInfo.getReduceIds(), false));
+        }
+      } else {
+        throw new RuntimeException(String.format("Asked to remove old shuffle merged data for " +
+          "application %s shuffleId %s shuffleMergeId %s, but current shuffleMergeId %s ",
+          msg.appId, msg.shuffleId, shuffleMergeIdToDelete, mergePartitionsInfo.shuffleMergeId));
+      }
+      writeAppAttemptShuffleMergeInfoToDB(new AppAttemptShuffleMergeId(
+        msg.appId, msg.appAttemptId, msg.shuffleId, shuffleMergeIdToDelete));
+      return new AppShuffleMergePartitionsInfo(shuffleMergeIdToDelete, true);
+    });
+  }
+
   /**
    * Clean up the AppShufflePartitionInfo for a specific AppShuffleInfo.
    * If cleanupLocalDirs is true, the merged shuffle files will also be deleted.
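The method above leans on `ConcurrentHashMap.compute`, which runs the remapping function atomically per key, so the whole read-current-merge-id, decide, schedule-cleanup, install-new-state sequence cannot interleave with a concurrent push or finalize for the same shuffleId. A standalone sketch of that pattern with toy types (none of Spark's classes; Java 16+ record used for brevity):

```java
import java.util.concurrent.ConcurrentHashMap;

public final class ComputeDemo {
  record MergeState(int mergeId, boolean finalized) {}

  public static void main(String[] args) {
    ConcurrentHashMap<Integer, MergeState> shuffles = new ConcurrentHashMap<>();
    shuffles.put(5, new MergeState(1, true));

    int requested = 2; // a newer merge id supersedes the current one
    shuffles.compute(5, (shuffleId, state) -> {
      // The remapping function runs atomically for this key.
      if (state == null || requested >= state.mergeId()) {
        // schedule asynchronous file cleanup here, then publish the new state
        return new MergeState(requested, true);
      }
      throw new IllegalStateException("stale remove request: " + requested);
    });
    System.out.println(shuffles.get(5)); // MergeState[mergeId=2, finalized=true]
  }
}
```

The deleteMergedFiles helper scheduled above appears in the next hunk.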
@@ -470,6 +530,40 @@ void closeAndDeleteOutdatedPartitions(
     });
   }
 
+  void deleteMergedFiles(
+      AppAttemptShuffleMergeId appAttemptShuffleMergeId,
+      AppShuffleInfo appShuffleInfo,
+      int[] reduceIds,
+      boolean deleteFromDB) {
+    if (deleteFromDB) {
+      removeAppShufflePartitionInfoFromDB(appAttemptShuffleMergeId);
+    }
+    int shuffleId = appAttemptShuffleMergeId.shuffleId;
+    int shuffleMergeId = appAttemptShuffleMergeId.shuffleMergeId;
+    int dataFilesDeleteCnt = 0;
+    int indexFilesDeleteCnt = 0;
+    int metaFilesDeleteCnt = 0;
+    for (int reduceId : reduceIds) {
+      File dataFile =
+        appShuffleInfo.getMergedShuffleDataFile(shuffleId, shuffleMergeId, reduceId);
+      if (dataFile.delete()) {
+        dataFilesDeleteCnt++;
+      }
+      File indexFile = new File(
+        appShuffleInfo.getMergedShuffleIndexFilePath(shuffleId, shuffleMergeId, reduceId));
+      if (indexFile.delete()) {
+        indexFilesDeleteCnt++;
+      }
+      File metaFile =
+        appShuffleInfo.getMergedShuffleMetaFile(shuffleId, shuffleMergeId, reduceId);
+      if (metaFile.delete()) {
+        metaFilesDeleteCnt++;
+      }
+    }
+    logger.info("Deleted {} data files, {} index files, {} meta files for {}",
+      dataFilesDeleteCnt, indexFilesDeleteCnt, metaFilesDeleteCnt, appAttemptShuffleMergeId);
+  }
+
   /**
    * Remove the finalized shuffle partition information for a specific appAttemptShuffleMergeId
    * @param appAttemptShuffleMergeId
@@ -712,6 +806,7 @@ public MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) {
         mergeStatuses = new MergeStatuses(msg.shuffleId, msg.shuffleMergeId,
           bitmaps.toArray(new RoaringBitmap[bitmaps.size()]), Ints.toArray(reduceIds),
           Longs.toArray(sizes));
+        appShuffleInfo.shuffles.get(msg.shuffleId).setReduceIds(Ints.toArray(reduceIds));
       }
       logger.info("{} attempt {} shuffle {} shuffleMerge {}: finalization of shuffle merge completed",
         msg.appId, msg.appAttemptId, msg.shuffleId, msg.shuffleMergeId);
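The `setReduceIds` call above records, at finalize time, exactly which reduce partitions produced merged files, so a later `removeShuffleMerge` can delete them without scanning directories; the `AppShuffleMergePartitionsInfo` hunks that follow add the holder for that array. A distilled sketch of the publication pattern (a toy class, not Spark's):

```java
import java.util.concurrent.atomic.AtomicReference;

final class FinalizedReduceIds {
  // AtomicReference gives safe publication of the array from the finalize
  // thread to whichever cleanup thread later reads it, without locking.
  private final AtomicReference<int[]> reduceIds = new AtomicReference<>(new int[0]);

  void onFinalize(int[] mergedReduceIds) {
    reduceIds.set(mergedReduceIds); // writer: finalization path
  }

  int[] forCleanup() {
    return reduceIds.get();         // reader: asynchronous cleanup task
  }
}
```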
@@ -1465,6 +1560,8 @@ public static class AppShuffleMergePartitionsInfo {
     private final int shuffleMergeId;
     private final Map<Integer, AppShufflePartitionInfo> shuffleMergePartitions;
 
+    private final AtomicReference<int[]> reduceIds = new AtomicReference<>(new int[0]);
+
     public AppShuffleMergePartitionsInfo(int shuffleMergeId, boolean shuffleFinalized) {
       this.shuffleMergeId = shuffleMergeId;
       this.shuffleMergePartitions = shuffleFinalized ?
         SHUFFLE_FINALIZED_MARKER : new ConcurrentHashMap<>();
@@ -1479,6 +1576,14 @@ public Map<Integer, AppShufflePartitionInfo> getShuffleMergePartitions() {
     public boolean isFinalized() {
       return shuffleMergePartitions == SHUFFLE_FINALIZED_MARKER;
     }
+
+    public void setReduceIds(int[] reduceIds) {
+      this.reduceIds.set(reduceIds);
+    }
+
+    public int[] getReduceIds() {
+      return this.reduceIds.get();
+    }
   }
 
   /**
@@ -1687,9 +1792,9 @@ void closeAllFilesAndDeleteIfNeeded(boolean delete) {
       try {
         if (dataChannel.isOpen()) {
           dataChannel.close();
-          if (delete) {
-            dataFile.delete();
-          }
+        }
+        if (delete) {
+          dataFile.delete();
         }
       } catch (IOException ioe) {
         logger.warn("Error closing data channel for {} reduceId {}",
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
index ad959c7e2e7c..33411baa09f8 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
@@ -50,7 +50,7 @@ public enum Type {
     FETCH_SHUFFLE_BLOCKS(9), GET_LOCAL_DIRS_FOR_EXECUTORS(10), LOCAL_DIRS_FOR_EXECUTORS(11),
     PUSH_BLOCK_STREAM(12), FINALIZE_SHUFFLE_MERGE(13), MERGE_STATUSES(14),
     FETCH_SHUFFLE_BLOCK_CHUNKS(15), DIAGNOSE_CORRUPTION(16), CORRUPTION_CAUSE(17),
-    PUSH_BLOCK_RETURN_CODE(18);
+    PUSH_BLOCK_RETURN_CODE(18), REMOVE_SHUFFLE_MERGE(19);
 
     private final byte id;
 
@@ -88,6 +88,7 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) {
       case 16: return DiagnoseCorruption.decode(buf);
       case 17: return CorruptionCause.decode(buf);
       case 18: return BlockPushReturnCode.decode(buf);
+      case 19: return RemoveShuffleMerge.decode(buf);
       default: throw new IllegalArgumentException("Unknown message type: " + type);
     }
   }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RemoveShuffleMerge.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RemoveShuffleMerge.java
new file mode 100644
index 000000000000..3bcb57a70bcb
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RemoveShuffleMerge.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import org.apache.commons.lang3.builder.ToStringBuilder;
+import org.apache.commons.lang3.builder.ToStringStyle;
+
+import org.apache.spark.network.protocol.Encoders;
+
+/**
+ * Remove the merged data for a given shuffle.
+ * Returns {@link Boolean}
+ *
+ * @since 3.4.0
+ */
+public class RemoveShuffleMerge extends BlockTransferMessage {
+  public final String appId;
+  public final int appAttemptId;
+  public final int shuffleId;
+  public final int shuffleMergeId;
+
+  public RemoveShuffleMerge(
+      String appId,
+      int appAttemptId,
+      int shuffleId,
+      int shuffleMergeId) {
+    this.appId = appId;
+    this.appAttemptId = appAttemptId;
+    this.shuffleId = shuffleId;
+    this.shuffleMergeId = shuffleMergeId;
+  }
+
+  @Override
+  protected Type type() {
+    return Type.REMOVE_SHUFFLE_MERGE;
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(appId, appAttemptId, shuffleId, shuffleMergeId);
+  }
+
+  @Override
+  public String toString() {
+    return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
+      .append("appId", appId)
+      .append("attemptId", appAttemptId)
+      .append("shuffleId", shuffleId)
+      .append("shuffleMergeId", shuffleMergeId)
+      .toString();
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other != null && other instanceof RemoveShuffleMerge) {
+      RemoveShuffleMerge o = (RemoveShuffleMerge) other;
+      return Objects.equal(appId, o.appId)
+        && appAttemptId == o.appAttemptId
+        && shuffleId == o.shuffleId
+        && shuffleMergeId == o.shuffleMergeId;
+    }
+    return false;
+  }
+
+  @Override
+  public int encodedLength() {
+    return Encoders.Strings.encodedLength(appId) + 4 + 4 + 4;
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    Encoders.Strings.encode(buf, appId);
+    buf.writeInt(appAttemptId);
+    buf.writeInt(shuffleId);
+    buf.writeInt(shuffleMergeId);
+  }
+
+  public static RemoveShuffleMerge decode(ByteBuf buf) {
+    String appId = Encoders.Strings.decode(buf);
+    int attemptId = buf.readInt();
+    int shuffleId = buf.readInt();
+    int shuffleMergeId = buf.readInt();
+    return new RemoveShuffleMerge(appId, attemptId, shuffleId, shuffleMergeId);
+  }
+}
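A quick wire-format sanity check for the new message, sketched as a standalone snippet: encode with `toByteBuffer()` (which prepends the type id) and decode back through the shared dispatcher registered above as case 19. The ids are arbitrary, and `BlockTransferMessage.Decoder` is assumed to be the existing decoder entry point:

```java
import java.nio.ByteBuffer;

import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.RemoveShuffleMerge;

public final class RemoveShuffleMergeRoundTrip {
  public static void main(String[] args) {
    RemoveShuffleMerge msg = new RemoveShuffleMerge("app-1", 0, 5, 2);
    ByteBuffer wire = msg.toByteBuffer();   // [type id 19][payload]
    BlockTransferMessage decoded = BlockTransferMessage.Decoder.fromByteBuffer(wire);
    System.out.println(msg.equals(decoded)); // true: all four fields round-trip
  }
}
```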
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java
index eb2c1d9fa5cb..4c0869402762 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java
@@ -31,6 +31,7 @@
 import java.util.Map;
 import java.util.concurrent.Semaphore;
 import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.TimeUnit;
 
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.common.collect.ImmutableMap;
@@ -56,6 +57,7 @@
 import org.apache.spark.network.shuffle.protocol.FinalizeShuffleMerge;
 import org.apache.spark.network.shuffle.protocol.MergeStatuses;
 import org.apache.spark.network.shuffle.protocol.PushBlockStream;
+import org.apache.spark.network.shuffle.protocol.RemoveShuffleMerge;
 import org.apache.spark.network.util.MapConfigProvider;
 import org.apache.spark.network.util.TransportConf;
 
@@ -1316,6 +1318,121 @@ public void testJsonSerializationOfPushShufflePartitionInfo() throws IOException
       RemoteBlockPushResolver.AppAttemptShuffleMergeId.class));
   }
 
+  @Test
+  public void testRemoveShuffleMerge() throws IOException, InterruptedException {
+    Semaphore closed = new Semaphore(0);
+    String testApp = "testRemoveShuffleMerge";
+    RemoteBlockPushResolver pushResolver = new RemoteBlockPushResolver(conf, null) {
+      @Override
+      void closeAndDeleteOutdatedPartitions(
+          AppAttemptShuffleMergeId appAttemptShuffleMergeId,
+          Map<Integer, AppShufflePartitionInfo> partitions) {
+        super.closeAndDeleteOutdatedPartitions(appAttemptShuffleMergeId, partitions);
+        closed.release();
+      }
+
+      @Override
+      void deleteMergedFiles(
+          AppAttemptShuffleMergeId appAttemptShuffleMergeId,
+          AppShuffleInfo appShuffleInfo,
+          int[] reduceIds,
+          boolean deleteFromDB) {
+        super.deleteMergedFiles(appAttemptShuffleMergeId, appShuffleInfo, reduceIds, deleteFromDB);
+        closed.release();
+      }
+    };
+    pushResolver.registerExecutor(testApp, new ExecutorShuffleInfo(
+      prepareLocalDirs(localDirs, MERGE_DIRECTORY), 1, MERGE_DIRECTORY_META));
+    RemoteBlockPushResolver.AppShuffleInfo shuffleInfo =
+      pushResolver.validateAndGetAppShuffleInfo(testApp);
+
+    // 1. Check whether the data is cleaned up when merged shuffle is finalized
+    // 1.1 Cleaned up the merged files when msg.shuffleMergeId is current shuffleMergeId
+    StreamCallbackWithID streamCallback0 = pushResolver.receiveBlockDataAsStream(
+      new PushBlockStream(testApp, NO_ATTEMPT_ID, 0, 1, 0, 0, 0));
+    streamCallback0.onData(streamCallback0.getID(), ByteBuffer.wrap(new byte[2]));
+    streamCallback0.onComplete(streamCallback0.getID());
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(testApp, NO_ATTEMPT_ID, 0, 1));
+    assertTrue(shuffleInfo.getMergedShuffleMetaFile(0, 1, 0).exists());
+    assertTrue(new File(shuffleInfo.getMergedShuffleIndexFilePath(0, 1, 0)).exists());
+    assertTrue(shuffleInfo.getMergedShuffleDataFile(0, 1, 0).exists());
+    pushResolver.removeShuffleMerge(
+      new RemoveShuffleMerge(testApp, NO_ATTEMPT_ID, 0, 1));
+    closed.tryAcquire(10, TimeUnit.SECONDS);
+    assertFalse(shuffleInfo.getMergedShuffleMetaFile(0, 1, 0).exists());
+    assertFalse(new File(shuffleInfo.getMergedShuffleIndexFilePath(0, 1, 0)).exists());
+    assertFalse(shuffleInfo.getMergedShuffleDataFile(0, 1, 0).exists());
+
+    // 1.2 Cleaned up the merged files when msg.shuffleMergeId is DELETE_ALL_MERGED_SHUFFLE
+    StreamCallbackWithID streamCallback1 = pushResolver.receiveBlockDataAsStream(
+      new PushBlockStream(testApp, NO_ATTEMPT_ID, 1, 1, 0, 0, 0));
+    streamCallback1.onData(streamCallback1.getID(), ByteBuffer.wrap(new byte[2]));
+    streamCallback1.onComplete(streamCallback1.getID());
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(testApp, NO_ATTEMPT_ID, 1, 1));
+    assertTrue(shuffleInfo.getMergedShuffleMetaFile(1, 1, 0).exists());
+    assertTrue(new File(shuffleInfo.getMergedShuffleIndexFilePath(1, 1, 0)).exists());
+    assertTrue(shuffleInfo.getMergedShuffleDataFile(1, 1, 0).exists());
+    pushResolver.removeShuffleMerge(new RemoveShuffleMerge(testApp, NO_ATTEMPT_ID, 1,
+      RemoteBlockPushResolver.DELETE_ALL_MERGED_SHUFFLE));
+    closed.tryAcquire(10, TimeUnit.SECONDS);
+    assertFalse(shuffleInfo.getMergedShuffleMetaFile(1, 1, 0).exists());
+    assertFalse(new File(shuffleInfo.getMergedShuffleIndexFilePath(1, 1, 0)).exists());
+    assertFalse(shuffleInfo.getMergedShuffleDataFile(1, 1, 0).exists());
+
+    // 1.3 Cleaned up the merged files when msg.shuffleMergeId < current shuffleMergeId
+    StreamCallbackWithID streamCallback2 = pushResolver.receiveBlockDataAsStream(
+      new PushBlockStream(testApp, NO_ATTEMPT_ID, 2, 1, 0, 0, 0));
+    streamCallback2.onData(streamCallback2.getID(), ByteBuffer.wrap(new byte[2]));
+    streamCallback2.onComplete(streamCallback2.getID());
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(testApp, NO_ATTEMPT_ID, 2, 1));
+    assertTrue(shuffleInfo.getMergedShuffleMetaFile(2, 1, 0).exists());
+    assertTrue(new File(shuffleInfo.getMergedShuffleIndexFilePath(2, 1, 0)).exists());
+    assertTrue(shuffleInfo.getMergedShuffleDataFile(2, 1, 0).exists());
+
+    RuntimeException e = assertThrows(RuntimeException.class,
+      () -> pushResolver.removeShuffleMerge(
+        new RemoveShuffleMerge(testApp, NO_ATTEMPT_ID, 2, 0)));
+    assertEquals("Asked to remove old shuffle merged data for application " + testApp
+      + " shuffleId 2 shuffleMergeId 0, but current shuffleMergeId 1 ", e.getMessage());
+
+    // 1.4 Cleaned up the merged files when msg.shuffleMergeId > current shuffleMergeId
+    StreamCallbackWithID streamCallback3 = pushResolver.receiveBlockDataAsStream(
+      new PushBlockStream(testApp, NO_ATTEMPT_ID, 3, 1, 0, 0, 0));
+    streamCallback3.onData(streamCallback3.getID(), ByteBuffer.wrap(new byte[2]));
+    streamCallback3.onComplete(streamCallback3.getID());
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(testApp, NO_ATTEMPT_ID, 3, 1));
+    assertTrue(shuffleInfo.getMergedShuffleMetaFile(3, 1, 0).exists());
+    assertTrue(new File(shuffleInfo.getMergedShuffleIndexFilePath(3, 1, 0)).exists());
+    assertTrue(shuffleInfo.getMergedShuffleDataFile(3, 1, 0).exists());
+    pushResolver.removeShuffleMerge(
+      new RemoveShuffleMerge(testApp, NO_ATTEMPT_ID, 3, 2));
+    closed.tryAcquire(10, TimeUnit.SECONDS);
+    assertFalse(shuffleInfo.getMergedShuffleMetaFile(3, 1, 0).exists());
+    assertFalse(new File(shuffleInfo.getMergedShuffleIndexFilePath(3, 1, 0)).exists());
+    assertFalse(shuffleInfo.getMergedShuffleDataFile(3, 1, 0).exists());
+
+    // 2. Check whether the data is cleaned up when merged shuffle is not finalized.
+    StreamCallbackWithID streamCallback4 = pushResolver.receiveBlockDataAsStream(
+      new PushBlockStream(testApp, NO_ATTEMPT_ID, 4, 1, 0, 0, 0));
+    streamCallback4.onData(streamCallback4.getID(), ByteBuffer.wrap(new byte[2]));
+    streamCallback4.onComplete(streamCallback4.getID());
+    assertTrue(shuffleInfo.getMergedShuffleMetaFile(4, 1, 0).exists());
+    pushResolver.removeShuffleMerge(
+      new RemoveShuffleMerge(testApp, NO_ATTEMPT_ID, 4, 1));
+    closed.tryAcquire(10, TimeUnit.SECONDS);
+    assertFalse(shuffleInfo.getMergedShuffleMetaFile(4, 1, 0).exists());
+
+    // 3. Check whether the data is cleaned up when a higher shuffleMergeId finalize request comes
+    StreamCallbackWithID streamCallback5 = pushResolver.receiveBlockDataAsStream(
+      new PushBlockStream(testApp, NO_ATTEMPT_ID, 5, 1, 0, 0, 0));
+    streamCallback5.onData(streamCallback5.getID(), ByteBuffer.wrap(new byte[2]));
+    streamCallback5.onComplete(streamCallback5.getID());
+    assertTrue(shuffleInfo.getMergedShuffleMetaFile(5, 1, 0).exists());
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(testApp, NO_ATTEMPT_ID, 5, 2));
+    closed.tryAcquire(10, TimeUnit.SECONDS);
+    assertFalse(shuffleInfo.getMergedShuffleMetaFile(5, 1, 0).exists());
+  }
+
   private void useTestFiles(boolean useTestIndexFile, boolean useTestMetaFile) throws IOException {
     pushResolver = new RemoteBlockPushResolver(conf, null) {
       @Override
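The test above depends on a small synchronization idiom: both cleanup hooks are overridden to release a semaphore permit once the asynchronous cleanup task has actually run, and the test blocks with a timeout before asserting on file state. Distilled into a standalone sketch:

```java
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;

public final class AsyncCleanupAwait {
  public static void main(String[] args) throws InterruptedException {
    Semaphore closed = new Semaphore(0);

    Thread cleanup = new Thread(() -> {
      // ... delete files here ...
      closed.release();              // signal: cleanup finished
    });
    cleanup.start();

    // tryAcquire avoids hanging forever if cleanup never runs; the test then
    // fails on the subsequent file-existence assertions instead.
    boolean finished = closed.tryAcquire(10, TimeUnit.SECONDS);
    System.out.println("cleanup finished: " + finished);
  }
}
```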
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index a163fef693ea..fade0b86dd8f 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -1207,7 +1207,7 @@ private[spark] class MapOutputTrackerMaster(
 
   // This method is only called in local-mode.
   override def getShufflePushMergerLocations(shuffleId: Int): Seq[BlockManagerId] = {
-    shuffleStatuses(shuffleId).getShufflePushMergerLocations
+    shuffleStatuses.get(shuffleId).map(_.getShufflePushMergerLocations).getOrElse(Seq.empty)
   }
 
   override def stop(): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index cc991178481f..c53730818a7a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1398,9 +1398,7 @@ private[spark] class DAGScheduler(
    */
  private def prepareShuffleServicesForShuffleMapStage(stage: ShuffleMapStage): Unit = {
     assert(stage.shuffleDep.shuffleMergeAllowed && !stage.shuffleDep.isShuffleMergeFinalizedMarked)
-    if (stage.shuffleDep.getMergerLocs.isEmpty) {
-      getAndSetShufflePushMergerLocations(stage)
-    }
+    configureShufflePushMergerLocations(stage)
 
     val shuffleId = stage.shuffleDep.shuffleId
     val shuffleMergeId = stage.shuffleDep.shuffleMergeId
@@ -1415,17 +1413,17 @@ private[spark] class DAGScheduler(
     }
   }
 
-  private def getAndSetShufflePushMergerLocations(stage: ShuffleMapStage): Seq[BlockManagerId] = {
+  private def configureShufflePushMergerLocations(stage: ShuffleMapStage): Unit = {
+    if (stage.shuffleDep.getMergerLocs.nonEmpty) return
     val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations(
       stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
     if (mergerLocs.nonEmpty) {
       stage.shuffleDep.setMergerLocs(mergerLocs)
+      mapOutputTracker.registerShufflePushMergerLocations(stage.shuffleDep.shuffleId, mergerLocs)
+      logDebug(s"Shuffle merge locations for shuffle ${stage.shuffleDep.shuffleId} with" +
+        s" shuffle merge ${stage.shuffleDep.shuffleMergeId} is" +
+        s" ${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
     }
-
-    logDebug(s"Shuffle merge locations for shuffle ${stage.shuffleDep.shuffleId} with" +
-      s" shuffle merge ${stage.shuffleDep.shuffleMergeId} is" +
-      s" ${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
-    mergerLocs
   }
 
   /** Called when stage's parents are available and we can now do its task. */
@@ -2629,16 +2627,15 @@ private[spark] class DAGScheduler(
       shuffleIdToMapStage.filter { case (_, stage) =>
         stage.shuffleDep.shuffleMergeAllowed && stage.shuffleDep.getMergerLocs.isEmpty &&
           runningStages.contains(stage)
-      }.foreach { case(_, stage: ShuffleMapStage) =>
-        if (getAndSetShufflePushMergerLocations(stage).nonEmpty) {
-          logInfo(s"Shuffle merge enabled adaptively for $stage with shuffle" +
-            s" ${stage.shuffleDep.shuffleId} and shuffle merge" +
-            s" ${stage.shuffleDep.shuffleMergeId} with ${stage.shuffleDep.getMergerLocs.size}" +
-            s" merger locations")
-          mapOutputTracker.registerShufflePushMergerLocations(stage.shuffleDep.shuffleId,
-            stage.shuffleDep.getMergerLocs)
-        }
+      }.foreach { case (_, stage: ShuffleMapStage) =>
+        configureShufflePushMergerLocations(stage)
+        if (stage.shuffleDep.getMergerLocs.nonEmpty) {
+          logInfo(s"Shuffle merge enabled adaptively for $stage with shuffle" +
+            s" ${stage.shuffleDep.shuffleId} and shuffle merge" +
+            s" ${stage.shuffleDep.shuffleMergeId} with ${stage.shuffleDep.getMergerLocs.size}" +
+            s" merger locations")
+        }
       }
     }
   }
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index d30272c51be3..681a812e880a 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -32,7 +32,7 @@ import com.google.common.cache.CacheBuilder
 import org.apache.spark.{MapOutputTrackerMaster, SparkConf}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.internal.{config, Logging}
-import org.apache.spark.network.shuffle.ExternalBlockStoreClient
+import org.apache.spark.network.shuffle.{ExternalBlockStoreClient, RemoteBlockPushResolver}
 import org.apache.spark.rpc.{IsolatedThreadSafeRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv}
 import org.apache.spark.scheduler._
 import org.apache.spark.scheduler.cluster.{CoarseGrainedClusterMessages, CoarseGrainedSchedulerBackend}
@@ -321,14 +321,6 @@ class BlockManagerMasterEndpoint(
   }
 
   private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = {
-    val removeMsg = RemoveShuffle(shuffleId)
-    val removeShuffleFromExecutorsFutures = blockManagerInfo.values.map { bm =>
-      bm.storageEndpoint.ask[Boolean](removeMsg).recover {
-        // use false as default value means no shuffle data were removed
-        handleBlockRemovalFailure("shuffle", shuffleId.toString, bm.blockManagerId, false)
-      }
-    }.toSeq
-
     // Find all shuffle blocks on executors that are no longer running
     val blocksToDeleteByShuffleService =
       new mutable.HashMap[BlockManagerId, mutable.HashSet[BlockId]]
@@ -366,8 +358,32 @@ class BlockManagerMasterEndpoint(
         }
       }.getOrElse(Seq.empty)
 
+    val removeShuffleMergeFromShuffleServicesFutures =
+      externalBlockStoreClient.map { shuffleClient =>
+        val mergerLocations =
+          if (Utils.isPushBasedShuffleEnabled(conf, isDriver)) {
+            mapOutputTracker.getShufflePushMergerLocations(shuffleId)
+          } else {
+            Seq.empty[BlockManagerId]
+          }
+        mergerLocations.map { bmId =>
+          Future[Boolean] {
+            shuffleClient.removeShuffleMerge(bmId.host, bmId.port, shuffleId,
+              RemoteBlockPushResolver.DELETE_ALL_MERGED_SHUFFLE)
+          }
+        }
+      }.getOrElse(Seq.empty)
+
+    val removeMsg = RemoveShuffle(shuffleId)
+    val removeShuffleFromExecutorsFutures = blockManagerInfo.values.map { bm =>
+      bm.storageEndpoint.ask[Boolean](removeMsg).recover {
+        // use false as default value means no shuffle data were removed
handleBlockRemovalFailure("shuffle", shuffleId.toString, bm.blockManagerId, false) + } + }.toSeq Future.sequence(removeShuffleFromExecutorsFutures ++ - removeShuffleFromShuffleServicesFutures) + removeShuffleFromShuffleServicesFutures ++ + removeShuffleMergeFromShuffleServicesFutures) } /** diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index a13527f4b74c..dfad4a924d7c 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -17,12 +17,15 @@ package org.apache.spark +import java.util.{Collections => JCollections, HashSet => JHashSet} import java.util.concurrent.atomic.LongAdder +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock import org.roaringbitmap.RoaringBitmap import org.apache.spark.LocalSparkContext._ @@ -30,10 +33,11 @@ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Network.{RPC_ASK_TIMEOUT, RPC_MESSAGE_MAX_SIZE} import org.apache.spark.internal.config.Tests.IS_TESTING -import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} +import org.apache.spark.network.shuffle.ExternalBlockStoreClient +import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus, MapStatus, MergeStatus} import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId, ShuffleMergedBlockId} +import org.apache.spark.storage.{BlockManagerId, BlockManagerMasterEndpoint, ShuffleBlockId, ShuffleMergedBlockId} class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { private val conf = new SparkConf @@ -913,9 +917,63 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { slaveRpcEnv.shutdown() } + private def fetchDeclaredField(value: AnyRef, fieldName: String): AnyRef = { + val field = value.getClass.getDeclaredField(fieldName) + field.setAccessible(true) + field.get(value) + } + + private def lookupBlockManagerMasterEndpoint(sc: SparkContext): BlockManagerMasterEndpoint = { + val rpcEnv = sc.env.rpcEnv + val dispatcher = fetchDeclaredField(rpcEnv, "dispatcher") + fetchDeclaredField(dispatcher, "endpointRefs"). + asInstanceOf[java.util.Map[RpcEndpoint, RpcEndpointRef]].asScala. + filter(_._1.isInstanceOf[BlockManagerMasterEndpoint]). 
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index a13527f4b74c..dfad4a924d7c 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -17,12 +17,15 @@
 package org.apache.spark
 
+import java.util.{Collections => JCollections, HashSet => JHashSet}
 import java.util.concurrent.atomic.LongAdder
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
 import org.mockito.ArgumentMatchers.any
 import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
 import org.roaringbitmap.RoaringBitmap
 
 import org.apache.spark.LocalSparkContext._
@@ -30,10 +33,11 @@ import org.apache.spark.broadcast.BroadcastManager
 import org.apache.spark.internal.config._
 import org.apache.spark.internal.config.Network.{RPC_ASK_TIMEOUT, RPC_MESSAGE_MAX_SIZE}
 import org.apache.spark.internal.config.Tests.IS_TESTING
-import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv}
+import org.apache.spark.network.shuffle.ExternalBlockStoreClient
+import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
 import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus, MapStatus, MergeStatus}
 import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId, ShuffleMergedBlockId}
+import org.apache.spark.storage.{BlockManagerId, BlockManagerMasterEndpoint, ShuffleBlockId, ShuffleMergedBlockId}
 
 class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
   private val conf = new SparkConf
@@ -913,9 +917,63 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
     slaveRpcEnv.shutdown()
   }
 
+  private def fetchDeclaredField(value: AnyRef, fieldName: String): AnyRef = {
+    val field = value.getClass.getDeclaredField(fieldName)
+    field.setAccessible(true)
+    field.get(value)
+  }
+
+  private def lookupBlockManagerMasterEndpoint(sc: SparkContext): BlockManagerMasterEndpoint = {
+    val rpcEnv = sc.env.rpcEnv
+    val dispatcher = fetchDeclaredField(rpcEnv, "dispatcher")
+    fetchDeclaredField(dispatcher, "endpointRefs").
+      asInstanceOf[java.util.Map[RpcEndpoint, RpcEndpointRef]].asScala.
+      filter(_._1.isInstanceOf[BlockManagerMasterEndpoint]).
+      head._1.asInstanceOf[BlockManagerMasterEndpoint]
+  }
+
+  test("SPARK-40480: shuffle remove should cleanup merged files as well") {
+    val newConf = new SparkConf
+    newConf.set("spark.shuffle.push.enabled", "true")
+    newConf.set("spark.shuffle.service.enabled", "true")
+    newConf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
+    newConf.set(IS_TESTING, true)
+
+    val SHUFFLE_ID = 10
+    withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) { sc =>
+      val masterTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+
+      val blockStoreClient = mock(classOf[ExternalBlockStoreClient])
+      val bmMaster = lookupBlockManagerMasterEndpoint(sc)
+      val field = bmMaster.getClass.getDeclaredField("externalBlockStoreClient")
+      field.setAccessible(true)
+      field.set(bmMaster, Some(blockStoreClient))
+
+      masterTracker.registerShuffle(SHUFFLE_ID, 10, 10)
+      val mergerLocs = (1 to 10).map(x => BlockManagerId(s"exec-$x", s"host-$x", x))
+      masterTracker.registerShufflePushMergerLocations(SHUFFLE_ID, mergerLocs)
+
+      assert(masterTracker.getShufflePushMergerLocations(SHUFFLE_ID).map(_.host).toSet ==
+        mergerLocs.map(_.host).toSet)
+
+      val foundHosts = JCollections.synchronizedSet(new JHashSet[String]())
+      when(blockStoreClient.removeShuffleMerge(any(), any(), any(), any())).thenAnswer(
+        (m: InvocationOnMock) => {
+          val host = m.getArgument(0).asInstanceOf[String]
+          val shuffleId = m.getArgument(2).asInstanceOf[Int]
+          assert(shuffleId == SHUFFLE_ID)
+          foundHosts.add(host)
+          true
+        })
+
+      sc.cleaner.get.doCleanupShuffle(SHUFFLE_ID, blocking = true)
+      assert(foundHosts.asScala == mergerLocs.map(_.host).toSet)
+    }
+  }
+
   test("SPARK-34826: Adaptive shuffle mergers") {
     val newConf = new SparkConf
-    newConf.set("spark.shuffle.push.based.enabled", "true")
+    newConf.set("spark.shuffle.push.enabled", "true")
     newConf.set("spark.shuffle.service.enabled", "true")
     // needs TorrentBroadcast so need a SparkContext