diff --git a/common/src/main/java/org/apache/uniffle/common/exception/NotRetryException.java b/common/src/main/java/org/apache/uniffle/common/exception/NotRetryException.java
new file mode 100644
index 0000000000..49eaee6448
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/exception/NotRetryException.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.exception;
+
+/**
+ * Signals a failure that retrying cannot fix. {@code RetryUtils.retry} rethrows it
+ * immediately instead of scheduling another attempt.
+ */
+public class NotRetryException extends RssException {
+
+  public NotRetryException(String message) {
+    super(message);
+  }
+
+  public NotRetryException(String message, Throwable e) {
+    super(message, e);
+  }
+}
diff --git a/common/src/main/java/org/apache/uniffle/common/util/RetryUtils.java b/common/src/main/java/org/apache/uniffle/common/util/RetryUtils.java
index 603873f313..889d459ca0 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/RetryUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/RetryUtils.java
@@ -22,6 +22,8 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.uniffle.common.exception.NotRetryException;
+
 public class RetryUtils {
   private static final Logger LOG = LoggerFactory.getLogger(RetryUtils.class);
 
@@ -48,7 +50,8 @@ public static <T> T retry(RetryCmd<T> cmd, RetryCallBack callBack, long interval
         return ret;
       } catch (Throwable t) {
         retry++;
-        if ((exceptionClasses != null && !isInstanceOf(exceptionClasses, t)) || retry >= retryTimes) {
+        if ((exceptionClasses != null && !isInstanceOf(exceptionClasses, t)) || retry >= retryTimes
+            || t instanceof NotRetryException) {
           throw t;
         } else {
           LOG.info("Retry due to Throwable, " + t.getClass().getName() + " " + t.getMessage());
diff --git a/common/src/test/java/org/apache/uniffle/common/util/RetryUtilsTest.java b/common/src/test/java/org/apache/uniffle/common/util/RetryUtilsTest.java
index 70c59f8746..1d1bc1302d 100644
--- a/common/src/test/java/org/apache/uniffle/common/util/RetryUtilsTest.java
+++ b/common/src/test/java/org/apache/uniffle/common/util/RetryUtilsTest.java
@@ -20,6 +20,7 @@
 import java.util.concurrent.atomic.AtomicInteger;
 
 import com.google.common.collect.Sets;
 import org.junit.jupiter.api.Test;
 
+import org.apache.uniffle.common.exception.NotRetryException;
 import org.apache.uniffle.common.exception.RssException;
@@ -67,5 +68,16 @@ public void testRetry() {
       // ignore
     }
     assertEquals(tryTimes.get(), 1);
+
+    tryTimes.set(0);
+    try {
+      RetryUtils.retry(() -> {
+        tryTimes.incrementAndGet();
+        throw new NotRetryException("");
+      }, 10, maxTryTime);
+    } catch (Throwable throwable) {
+      // ignored on purpose: only the number of attempts is checked below
+    }
+    assertEquals(1, tryTimes.get());
   }
 }
diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
index a1cc6a1cf3..c5b8c2b22b 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
@@ -27,6 +27,7 @@
 import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
 import com.google.common.io.Files;
+import com.google.protobuf.ByteString;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Disabled;
@@ -56,6 +57,7 @@
 import org.apache.uniffle.common.config.RssBaseConf;
 import org.apache.uniffle.common.util.Constants;
 import org.apache.uniffle.coordinator.CoordinatorConf;
+import org.apache.uniffle.proto.RssProtos;
 import org.apache.uniffle.server.ShuffleDataFlushEvent;
 import org.apache.uniffle.server.ShuffleServerConf;
 import org.apache.uniffle.server.ShuffleServerGrpcMetrics;
@@ -393,6 +395,48 @@ public void sendDataWithoutRegisterTest() throws Exception {
     assertEquals(0, shuffleServers.get(0).getPreAllocatedMemory());
   }
 
+
+  @Test
+  public void sendDataWithoutRequirePreAllocation() throws Exception {
+    String appId = "sendDataWithoutRequirePreAllocation";
+    List<ShuffleBlockInfo> blockInfos = Lists.newArrayList(new ShuffleBlockInfo(0, 0, 0, 100, 0,
+        new byte[]{}, Lists.newArrayList(), 0, 100, 0));
+    Map<Integer, List<ShuffleBlockInfo>> partitionToBlocks = Maps.newHashMap();
+    partitionToBlocks.put(0, blockInfos);
+    Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleToBlocks = Maps.newHashMap();
+    shuffleToBlocks.put(0, partitionToBlocks);
+    for (Map.Entry<Integer, Map<Integer, List<ShuffleBlockInfo>>> stb : shuffleToBlocks.entrySet()) {
+      List<RssProtos.ShuffleData> shuffleData = Lists.newArrayList();
+      for (Map.Entry<Integer, List<ShuffleBlockInfo>> ptb : stb.getValue().entrySet()) {
+        List<RssProtos.ShuffleBlock> shuffleBlocks = Lists.newArrayList();
+        for (ShuffleBlockInfo sbi : ptb.getValue()) {
+          shuffleBlocks.add(RssProtos.ShuffleBlock.newBuilder().setBlockId(sbi.getBlockId())
+              .setCrc(sbi.getCrc())
+              .setLength(sbi.getLength())
+              .setTaskAttemptId(sbi.getTaskAttemptId())
+              .setUncompressLength(sbi.getUncompressLength())
+              .setData(ByteString.copyFrom(sbi.getData()))
+              .build());
+        }
+        shuffleData.add(RssProtos.ShuffleData.newBuilder().setPartitionId(ptb.getKey())
+            .addAllBlock(shuffleBlocks)
+            .build());
+      }
+
+      // requireBufferId 10000 was never returned by a requirePreAllocation call in this test.
+      RssProtos.SendShuffleDataRequest rpcRequest = RssProtos.SendShuffleDataRequest.newBuilder()
+          .setAppId(appId)
+          .setShuffleId(0)
+          .setRequireBufferId(10000)
+          .addAllShuffleData(shuffleData)
+          .build();
+      RssProtos.SendShuffleDataResponse response =
+          shuffleServerClient.getBlockingStub().sendShuffleData(rpcRequest);
+      assertEquals(RssProtos.StatusCode.INTERNAL_ERROR, response.getStatus());
+      assertTrue(response.getRetMsg().contains("Can't find requireBufferId[10000]"));
+    }
+  }
+
   @Test
   public void multipleShuffleResultTest() throws Exception {
     Set<Long> expectedBlockIds = Sets.newConcurrentHashSet();
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
index b9ac7b8a95..2852247b37 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
@@ -54,7 +54,9 @@
 import org.apache.uniffle.common.PartitionRange;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.exception.NotRetryException;
 import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.RetryUtils;
 import org.apache.uniffle.proto.RssProtos.AppHeartBeatRequest;
 import org.apache.uniffle.proto.RssProtos.AppHeartBeatResponse;
 import org.apache.uniffle.proto.RssProtos.FinishShuffleRequest;
@@ -109,7 +111,11 @@ public ShuffleServerGrpcClient(String host, int port, int maxRetryAttempts, bool
     blockingStub = ShuffleServerGrpc.newBlockingStub(channel);
   }
 
-  private ShuffleServerBlockingStub getBlockingStub() {
+  /**
+   * Returns the blocking stub with the {@code rpcTimeout} deadline applied. It is public so
+   * that tests can send hand-built protobuf requests directly.
+   */
+  public ShuffleServerBlockingStub getBlockingStub() {
     return blockingStub.withDeadlineAfter(rpcTimeout, TimeUnit.MILLISECONDS);
   }
 
@@ -253,29 +259,43 @@ public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest requ
             .build());
       }
 
-      long requireId = requirePreAllocation(size, request.getRetryMax(), request.getRetryIntervalMax());
-      if (requireId != FAILED_REQUIRE_ID) {
-        SendShuffleDataRequest rpcRequest = SendShuffleDataRequest.newBuilder()
-            .setAppId(appId)
-            .setShuffleId(stb.getKey())
-            .setRequireBufferId(requireId)
-            .addAllShuffleData(shuffleData)
-            .build();
-        long start = System.currentTimeMillis();
-        SendShuffleDataResponse response = doSendData(rpcRequest);
-        LOG.info("Do sendShuffleData to {}:{} rpc cost:" + (System.currentTimeMillis() - start)
-            + " ms for " + size + " bytes with " + blockNum + " blocks", host, port);
-
-        if (response.getStatus() != StatusCode.SUCCESS) {
-          String msg = "Can't send shuffle data with " + blockNum
-              + " blocks to " + host + ":" + port
-              + ", statusCode=" + response.getStatus()
-              + ", errorMsg:" + response.getRetMsg();
-          LOG.warn(msg);
-          isSuccessful = false;
-          break;
-        }
-      } else {
+      final int allocateSize = size;
+      final int finalBlockNum = blockNum;
+      try {
+        // Buffer allocation and the send are retried as one unit, so every attempt
+        // requests a new requireBufferId before sending.
+        RetryUtils.retry(() -> {
+          long requireId = requirePreAllocation(allocateSize, request.getRetryMax(), request.getRetryIntervalMax());
+          if (requireId == FAILED_REQUIRE_ID) {
+            throw new RssException(String.format(
+                "requirePreAllocation failed! size[%s], host[%s], port[%s]", allocateSize, host, port));
+          }
+          SendShuffleDataRequest rpcRequest = SendShuffleDataRequest.newBuilder()
+              .setAppId(appId)
+              .setShuffleId(stb.getKey())
+              .setRequireBufferId(requireId)
+              .addAllShuffleData(shuffleData)
+              .build();
+          long start = System.currentTimeMillis();
+          SendShuffleDataResponse response = getBlockingStub().sendShuffleData(rpcRequest);
+          LOG.info("Do sendShuffleData to {}:{} rpc cost:" + (System.currentTimeMillis() - start)
+              + " ms for " + allocateSize + " bytes with " + finalBlockNum + " blocks", host, port);
+          if (response.getStatus() != StatusCode.SUCCESS) {
+            String msg = "Can't send shuffle data with " + finalBlockNum
+                + " blocks to " + host + ":" + port
+                + ", statusCode=" + response.getStatus()
+                + ", errorMsg:" + response.getRetMsg();
+            if (response.getStatus() == StatusCode.NO_REGISTER) {
+              // Not retryable: RetryUtils rethrows NotRetryException without another attempt.
+              throw new NotRetryException(msg);
+            } else {
+              throw new RssException(msg);
+            }
+          }
+          return response;
+        }, request.getRetryIntervalMax(), maxRetryAttempts);
+      } catch (Throwable throwable) {
+        LOG.warn("Failed to send shuffle data to {}:{}", host, port, throwable);
         isSuccessful = false;
         break;
       }
@@ -290,21 +310,6 @@ public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest requ
     return response;
   }
 
-  private SendShuffleDataResponse doSendData(SendShuffleDataRequest rpcRequest) {
-    int retryNum = 0;
-    while (retryNum < maxRetryAttempts) {
-      try {
-        SendShuffleDataResponse response = getBlockingStub().sendShuffleData(rpcRequest);
-        return response;
-      } catch (Exception e) {
-        retryNum++;
-        LOG.warn("Send data to host[" + host + "], port[" + port
-            + "] failed, try again, retryNum[" + retryNum + "]", e);
-      }
-    }
-    throw new RssException("Send data to host[" + host + "], port[" + port + "] failed");
-  }
-
   @Override
   public RssSendCommitResponse sendCommit(RssSendCommitRequest request) {
     ShuffleCommitResponse rpcResponse = doSendCommit(request.getAppId(), request.getShuffleId());
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
index 7fd218531c..8cdb392f29 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
@@ -150,8 +150,18 @@ public void sendShuffleData(SendShuffleDataRequest req,
       ShuffleServerMetrics.counterTotalReceivedDataSize.inc(requireSize);
       boolean isPreAllocated = shuffleServer.getShuffleTaskManager().isPreAllocated(requireBufferId);
       if (!isPreAllocated) {
-        LOG.warn("Can't find requireBufferId[" + requireBufferId + "] for appId[" + appId
-            + "], shuffleId[" + shuffleId + "]");
+        String errorMsg = "Can't find requireBufferId[" + requireBufferId + "] for appId[" + appId
+            + "], shuffleId[" + shuffleId + "]";
+        LOG.warn(errorMsg);
+        responseMessage = errorMsg;
+        reply = SendShuffleDataResponse
+            .newBuilder()
+            .setStatus(valueOf(StatusCode.INTERNAL_ERROR))
+            .setRetMsg(responseMessage)
+            .build();
+        responseObserver.onNext(reply);
+        responseObserver.onCompleted();
+        return;
       }
       final long start = System.currentTimeMillis();
       List<ShufflePartitionedData> shufflePartitionedData = toPartitionedData(req);