diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java index fd7f72ccbb..e7f7c49b6b 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java @@ -24,10 +24,12 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.function.Function; import java.util.function.Supplier; @@ -98,6 +100,7 @@ public class RssShuffleWriter extends ShuffleWriter { private long taskAttemptId; private ShuffleDependency shuffleDependency; private ShuffleWriteMetrics shuffleWriteMetrics; + private final BlockingQueue finishEventQueue = new LinkedBlockingQueue<>(); private Partitioner partitioner; private boolean shouldPartition; private WriteBufferManager bufferManager; @@ -282,7 +285,7 @@ private void writeImpl(Iterator> records) { long s = System.currentTimeMillis(); checkAllBufferSpilled(); checkSentRecordCount(recordCount); - checkBlockSendResult(new HashSet<>(blockIds)); + checkBlockSendResult(blockIds); checkSentBlockCount(); final long checkDuration = System.currentTimeMillis() - s; long commitDuration = 0; @@ -396,6 +399,13 @@ protected List> postBlockEvent( List> futures = new ArrayList<>(); for (AddBlockEvent event : bufferManager.buildBlockEvents(shuffleBlockInfoList)) { futures.add(shuffleManager.sendData(event)); + event.addCallback( + () -> { + boolean ret = finishEventQueue.add(new Object()); + if (!ret) { + LOG.error("Add event " + event + " to finishEventQueue fail"); + } + }); } return futures; } @@ -435,42 +445,73 @@ protected void sendCommit() { @VisibleForTesting protected void checkBlockSendResult(Set blockIds) { - long start = System.currentTimeMillis(); - while (true) { - Set failedBlockIds = shuffleManager.getFailedBlockIds(taskId); - Set successBlockIds = shuffleManager.getSuccessBlockIds(taskId); - // if failed when send data to shuffle server, mark task as failed - if (failedBlockIds.size() > 0) { - String errorMsg = - "Send failed: Task[" - + taskId - + "] failed because " - + failedBlockIds.size() - + " blocks can't be sent to shuffle server: " - + shuffleManager.getBlockIdsFailedSendTracker(taskId).getFaultyShuffleServers(); - LOG.error(errorMsg); - throw new RssSendFailedException(errorMsg); - } + boolean interrupted = false; - // remove blockIds which was sent successfully, if there has none left, all data are sent - blockIds.removeAll(successBlockIds); - if (blockIds.isEmpty()) { - break; + try { + long remainingMs = sendCheckTimeout; + long end = System.currentTimeMillis() + remainingMs; + long currentAckValue = 0; + for (Long blockId : blockIds) { + currentAckValue ^= blockId; + } + while (true) { + try { + finishEventQueue.clear(); + checkDataIfAnyFailure(); + Set successBlockIds = shuffleManager.getSuccessBlockIds(taskId); + if (blockIds.size() == successBlockIds.size()) { + for (Long successBlockId : successBlockIds) { + currentAckValue ^= successBlockId; + } + if (currentAckValue != 0) { + String errorMsg = "Ack value is not equal to 0, it should not happen!"; + throw new RssSendFailedException(errorMsg); + } + break; + } + if (finishEventQueue.isEmpty()) { + remainingMs = Math.max(end - System.currentTimeMillis(), 0); + Object event = finishEventQueue.poll(remainingMs, TimeUnit.MILLISECONDS); + if (event == null) { + break; + } + } + } catch (InterruptedException e) { + interrupted = true; + } } - LOG.info("Wait " + blockIds.size() + " blocks sent to shuffle server"); - Uninterruptibles.sleepUninterruptibly(sendCheckInterval, TimeUnit.MILLISECONDS); - if (System.currentTimeMillis() - start > sendCheckTimeout) { + Set successBlockIds = shuffleManager.getSuccessBlockIds(taskId); + if (currentAckValue != 0 || blockIds.size() != successBlockIds.size()) { + int failedBlockCount = blockIds.size() - successBlockIds.size(); String errorMsg = "Timeout: Task[" + taskId + "] failed because " - + blockIds.size() + + failedBlockCount + " blocks can't be sent to shuffle server in " + sendCheckTimeout + " ms."; LOG.error(errorMsg); throw new RssWaitFailedException(errorMsg); } + } finally { + if (interrupted) { + Thread.currentThread().interrupt(); + } + } + } + + protected void checkDataIfAnyFailure() { + Set failedBlockIds = shuffleManager.getFailedBlockIds(taskId); + if (failedBlockIds.size() > 0) { + String errorMsg = + "Send failed: Task[" + + taskId + + "] failed because " + + failedBlockIds.size() + + " blocks can't be sent to shuffle server: " + + shuffleManager.getBlockIdsFailedSendTracker(taskId).getFaultyShuffleServers(); + throw new RssSendFailedException(errorMsg); } } diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java index fc9bfe53a2..4deef511d6 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java @@ -389,7 +389,7 @@ protected void writeImpl(Iterator> records) { long checkStartTs = System.currentTimeMillis(); checkAllBufferSpilled(); checkSentRecordCount(recordCount); - checkBlockSendResult(new HashSet<>(blockIds)); + checkBlockSendResult(blockIds); checkSentBlockCount(); bufferManager.getShuffleServerPushCostTracker().statistics(); long commitStartTs = System.currentTimeMillis(); @@ -524,14 +524,23 @@ protected void checkBlockSendResult(Set blockIds) { try { long remainingMs = sendCheckTimeout; long end = System.currentTimeMillis() + remainingMs; - + long currentAckValue = 0; + for (Long blockId : blockIds) { + currentAckValue ^= blockId; + } while (true) { try { finishEventQueue.clear(); checkDataIfAnyFailure(); Set successBlockIds = shuffleManager.getSuccessBlockIds(taskId); - blockIds.removeAll(successBlockIds); - if (blockIds.isEmpty()) { + if (blockIds.size() == successBlockIds.size()) { + for (Long successBlockId : successBlockIds) { + currentAckValue ^= successBlockId; + } + if (currentAckValue != 0) { + String errorMsg = "Ack value is not equal to 0, it should not happen!"; + throw new RssSendFailedException(errorMsg); + } break; } if (finishEventQueue.isEmpty()) { @@ -545,12 +554,14 @@ protected void checkBlockSendResult(Set blockIds) { interrupted = true; } } - if (!blockIds.isEmpty()) { + Set successBlockIds = shuffleManager.getSuccessBlockIds(taskId); + if (currentAckValue != 0 || blockIds.size() != successBlockIds.size()) { + int failedBlockCount = blockIds.size() - successBlockIds.size(); String errorMsg = "Timeout: Task[" + taskId + "] failed because " - + blockIds.size() + + failedBlockCount + " blocks can't be sent to shuffle server in " + sendCheckTimeout + " ms.";