diff --git a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java
index ef47e2116d..1c98bdabd3 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java
@@ -151,6 +151,15 @@ public class RssMRConfig {
   public static final int RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE =
       RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE;
 
+  public static final String RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL =
+      MR_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL;
+  public static final long RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL_DEFAULT_VALUE =
+      RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL_DEFAULT_VALUE;
+  public static final String RSS_CLIENT_ASSIGNMENT_RETRY_TIMES =
+      MR_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES;
+  public static final int RSS_CLIENT_ASSIGNMENT_RETRY_TIMES_DEFAULT_VALUE =
+      RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES_DEFAULT_VALUE;
+
   public static final String RSS_CONF_FILE = "rss_conf.xml";
   public static final Set RSS_MANDATORY_CLUSTER_CONF =
       Sets.newHashSet(
diff --git a/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java b/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
index c65f2a2d2b..a3cb2700eb 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
@@ -77,7 +77,9 @@
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleAssignmentsInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.util.Constants;
+import org.apache.uniffle.common.util.RetryUtils;
 import org.apache.uniffle.storage.util.StorageType;
 
 public class RssMRAppMaster extends MRAppMaster {
@@ -128,25 +130,9 @@ public static void main(String[] args) {
       }
       assignmentTags.add(Constants.SHUFFLE_SERVER_VERSION);
 
-      int requiredAssignmentShuffleServersNum = conf.getInt(
-          RssMRConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER,
-          RssMRConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE
-      );
-
       ApplicationAttemptId applicationAttemptId = RssMRUtils.getApplicationAttemptId();
       String appId = applicationAttemptId.toString();
 
-      ShuffleAssignmentsInfo response =
-          client.getShuffleAssignments(
-              appId,
-              0,
-              numReduceTasks,
-              1,
-              Sets.newHashSet(assignmentTags),
-              requiredAssignmentShuffleServersNum
-          );
-
-      Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges = response.getServerToPartitionRanges();
       final ScheduledExecutorService scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(
           new ThreadFactory() {
             @Override
@@ -157,40 +143,9 @@ public Thread newThread(Runnable r) {
             }
           }
       );
-      if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) {
-        return;
-      }
-
-      long heartbeatInterval = conf.getLong(RssMRConfig.RSS_HEARTBEAT_INTERVAL,
-          RssMRConfig.RSS_HEARTBEAT_INTERVAL_DEFAULT_VALUE);
-      long heartbeatTimeout = conf.getLong(RssMRConfig.RSS_HEARTBEAT_TIMEOUT, heartbeatInterval / 2);
-      scheduledExecutorService.scheduleAtFixedRate(
-          () -> {
-            try {
-              client.sendAppHeartbeat(appId, heartbeatTimeout);
-              LOG.info("Finish send heartbeat to coordinator and servers");
-            } catch (Exception e) {
-              LOG.warn("Fail to send heartbeat to coordinator and servers", e);
-            }
-          },
-          heartbeatInterval / 2,
-          heartbeatInterval,
-          TimeUnit.MILLISECONDS);
 
       JobConf extraConf = new JobConf();
       extraConf.clear();
-      // write shuffle worker assignments to submit work directory
-      // format is as below:
-      // mapreduce.rss.assignment.partition.1:server1,server2
-      // mapreduce.rss.assignment.partition.2:server3,server4
-      // ...
-      response.getPartitionToServers().entrySet().forEach(entry -> {
-        List<String> servers = Lists.newArrayList();
-        for (ShuffleServerInfo server : entry.getValue()) {
-          servers.add(server.getHost() + ":" + server.getPort());
-        }
-        extraConf.set(RssMRConfig.RSS_ASSIGNMENT_PREFIX + entry.getKey(), StringUtils.join(servers, ","));
-      });
 
       // get remote storage from coordinator if necessary
       boolean dynamicConfEnabled = conf.getBoolean(RssMRConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED,
@@ -233,14 +188,80 @@ public Thread newThread(Runnable r) {
         }
         conf.setInt(MRJobConfig.REDUCE_MAX_ATTEMPTS, originalAttempts + inc);
       }
+
+      int requiredAssignmentShuffleServersNum = conf.getInt(
+          RssMRConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER,
+          RssMRConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE
+      );
+
+      // retryInterval must be bigger than `rss.server.heartbeat.timeout`, or a retry may return the same result
+      long retryInterval = conf.getLong(RssMRConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL,
+          RssMRConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL_DEFAULT_VALUE);
+      int retryTimes = conf.getInt(RssMRConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES,
+          RssMRConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES_DEFAULT_VALUE);
+      ShuffleAssignmentsInfo response;
+      try {
+        response = RetryUtils.retry(() -> {
+          ShuffleAssignmentsInfo shuffleAssignments =
+              client.getShuffleAssignments(
+                  appId,
+                  0,
+                  numReduceTasks,
+                  1,
+                  Sets.newHashSet(assignmentTags),
+                  requiredAssignmentShuffleServersNum
+              );
+
+          Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges =
+              shuffleAssignments.getServerToPartitionRanges();
+
+          if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) {
+            return null;
+          }
+          LOG.info("Start to register shuffle");
+          long start = System.currentTimeMillis();
+          serverToPartitionRanges.entrySet().forEach(entry -> {
+            client.registerShuffle(
+                entry.getKey(), appId, 0, entry.getValue(), remoteStorage);
+          });
+          LOG.info("Finish register shuffle with " + (System.currentTimeMillis() - start) + " ms");
+          return shuffleAssignments;
+        }, retryInterval, retryTimes);
+      } catch (Throwable throwable) {
+        throw new RssException("registerShuffle failed!", throwable);
+      }
 
-      LOG.info("Start to register shuffle");
-      long start = System.currentTimeMillis();
-      serverToPartitionRanges.entrySet().forEach(entry -> {
-        client.registerShuffle(
-            entry.getKey(), appId, 0, entry.getValue(), remoteStorage);
+      if (response == null) {
+        return;
+      }
+      long heartbeatInterval = conf.getLong(RssMRConfig.RSS_HEARTBEAT_INTERVAL,
+          RssMRConfig.RSS_HEARTBEAT_INTERVAL_DEFAULT_VALUE);
+      long heartbeatTimeout = conf.getLong(RssMRConfig.RSS_HEARTBEAT_TIMEOUT, heartbeatInterval / 2);
+      scheduledExecutorService.scheduleAtFixedRate(
+          () -> {
+            try {
+              client.sendAppHeartbeat(appId, heartbeatTimeout);
+              LOG.info("Finish send heartbeat to coordinator and servers");
+            } catch (Exception e) {
+              LOG.warn("Fail to send heartbeat to coordinator and servers", e);
+            }
+          },
+          heartbeatInterval / 2,
+          heartbeatInterval,
+          TimeUnit.MILLISECONDS);
+
+      // write shuffle worker assignments to submit work directory
+      // format is as below:
+      // mapreduce.rss.assignment.partition.1:server1,server2
+      // mapreduce.rss.assignment.partition.2:server3,server4
+      // ...
+      response.getPartitionToServers().entrySet().forEach(entry -> {
+        List<String> servers = Lists.newArrayList();
+        for (ShuffleServerInfo server : entry.getValue()) {
+          servers.add(server.getHost() + ":" + server.getPort());
+        }
+        extraConf.set(RssMRConfig.RSS_ASSIGNMENT_PREFIX + entry.getKey(), StringUtils.join(servers, ","));
       });
-      LOG.info("Finish register shuffle with " + (System.currentTimeMillis() - start) + " ms");
 
       writeExtraConf(conf, extraConf);
 
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 6b549b1af3..c546bdc4df 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -207,6 +207,14 @@ public class RssSparkConfig {
       new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER))
       .createWithDefault(RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE);
 
+  public static final ConfigEntry<Long> RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL = createLongBuilder(
+      new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL))
+      .createWithDefault(RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL_DEFAULT_VALUE);
+
+  public static final ConfigEntry<Integer> RSS_CLIENT_ASSIGNMENT_RETRY_TIMES = createIntegerBuilder(
+      new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES))
+      .createWithDefault(RssClientConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES_DEFAULT_VALUE);
+
   public static final ConfigEntry<String> RSS_COORDINATOR_QUORUM = createStringBuilder(
       new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_COORDINATOR_QUORUM)
       .doc("Coordinator quorum"))
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index ec843086c8..753d75966e 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -60,6 +60,8 @@
 import org.apache.uniffle.common.ShuffleAssignmentsInfo;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.RetryUtils;
 import org.apache.uniffle.common.util.RssUtils;
 import org.apache.uniffle.common.util.ThreadUtils;
 
@@ -220,13 +222,23 @@ public ShuffleHandle registerShuffle(int shuffleId, int numMaps, Shuff
 
     int requiredShuffleServerNumber = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER);
 
-    ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments(
-        appId, shuffleId, dependency.partitioner().numPartitions(),
-        partitionNumPerRange, assignmentTags, requiredShuffleServerNumber);
-    Map<Integer, List<ShuffleServerInfo>> partitionToServers = response.getPartitionToServers();
+    // retryInterval must be bigger than `rss.server.heartbeat.timeout`, or a retry may return the same result
+    long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
+    int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers;
+    try {
+      partitionToServers = RetryUtils.retry(() -> {
+        ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments(
+            appId, shuffleId, dependency.partitioner().numPartitions(),
+            partitionNumPerRange, assignmentTags, requiredShuffleServerNumber);
+        registerShuffleServers(appId, shuffleId, response.getServerToPartitionRanges());
+        return response.getPartitionToServers();
+      }, retryInterval, retryTimes);
+    } catch (Throwable throwable) {
+      throw new RssException("registerShuffle failed!", throwable);
+    }
startHeartbeat(); - registerShuffleServers(appId, shuffleId, response.getServerToPartitionRanges()); LOG.info("RegisterShuffle with ShuffleId[" + shuffleId + "], partitionNum[" + partitionToServers.size() + "]"); return new RssShuffleHandle(shuffleId, appId, numMaps, dependency, partitionToServers, remoteStorage); diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index 80fac99eba..030e56fb5a 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -64,6 +64,8 @@ import org.apache.uniffle.common.ShuffleAssignmentsInfo; import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.ShuffleServerInfo; +import org.apache.uniffle.common.exception.RssException; +import org.apache.uniffle.common.util.RetryUtils; import org.apache.uniffle.common.util.RssUtils; import org.apache.uniffle.common.util.ThreadUtils; @@ -258,17 +260,26 @@ public ShuffleHandle registerShuffle(int shuffleId, ShuffleDependency< int requiredShuffleServerNumber = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER); - ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments( - id.get(), - shuffleId, - dependency.partitioner().numPartitions(), - 1, - assignmentTags, - requiredShuffleServerNumber); - Map> partitionToServers = response.getPartitionToServers(); - + // retryInterval must bigger than `rss.server.heartbeat.timeout`, or maybe it will return the same result + long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL); + int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES); + Map> partitionToServers; + try { + partitionToServers = RetryUtils.retry(() -> { + ShuffleAssignmentsInfo response = 
shuffleWriteClient.getShuffleAssignments( + id.get(), + shuffleId, + dependency.partitioner().numPartitions(), + 1, + assignmentTags, + requiredShuffleServerNumber); + registerShuffleServers(id.get(), shuffleId, response.getServerToPartitionRanges()); + return response.getPartitionToServers(); + }, retryInterval, retryTimes); + } catch (Throwable throwable) { + throw new RssException("registerShuffle failed!", throwable); + } startHeartbeat(); - registerShuffleServers(id.get(), shuffleId, response.getServerToPartitionRanges()); LOG.info("RegisterShuffle with ShuffleId[" + shuffleId + "], partitionNum[" + partitionToServers.size() + "], shuffleServerForResult: " + partitionToServers); diff --git a/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java b/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java index eb6006a06e..b3247b431a 100644 --- a/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java +++ b/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java @@ -60,6 +60,11 @@ public class RssClientConfig { public static final String RSS_CLIENT_READ_BUFFER_SIZE_DEFAULT_VALUE = "14m"; // The tags specified by rss client to determine server assignment. 
   public static final String RSS_CLIENT_ASSIGNMENT_TAGS = "rss.client.assignment.tags";
+
+  public static final String RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL = "rss.client.assignment.retry.interval";
+  public static final long RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL_DEFAULT_VALUE = 65000; // milliseconds
+  public static final String RSS_CLIENT_ASSIGNMENT_RETRY_TIMES = "rss.client.assignment.retry.times";
+  public static final int RSS_CLIENT_ASSIGNMENT_RETRY_TIMES_DEFAULT_VALUE = 3; // maximum attempts, including the first
 
   public static final String RSS_ACCESS_TIMEOUT_MS = "rss.access.timeout.ms";
   public static final int RSS_ACCESS_TIMEOUT_MS_DEFAULT_VALUE = 10000;
diff --git a/common/src/main/java/org/apache/uniffle/common/util/RetryUtils.java b/common/src/main/java/org/apache/uniffle/common/util/RetryUtils.java
new file mode 100644
index 0000000000..603873f313
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/util/RetryUtils.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.util;
+
+import java.util.Set;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class RetryUtils {
+  private static final Logger LOG = LoggerFactory.getLogger(RetryUtils.class);
+
+  public static <T> T retry(RetryCmd<T> cmd, long intervalMs, int retryTimes) throws Throwable {
+    return retry(cmd, null, intervalMs, retryTimes, null);
+  }
+
+  /**
+   * Executes cmd until an attempt returns normally, sleeping between failed attempts.
+   * @param cmd command to execute
+   * @param callBack executed after the sleep that follows a failed attempt which will be retried, may be null
+   * @param intervalMs time to sleep between attempts, in milliseconds
+   * @param retryTimes maximum number of attempts, including the first one
+   * @param exceptionClasses only throwables that are instances of these classes are retried, null for all.
+   * @param <T> return type of the command
+   * @return the value returned by the first attempt of cmd that does not throw
+   * @throws Throwable the last failure of cmd, a failure not eligible for retry, or a failure of sleep or callBack
+   */
+  public static <T> T retry(RetryCmd<T> cmd, RetryCallBack callBack, long intervalMs,
+      int retryTimes, Set<Class> exceptionClasses) throws Throwable {
+    int retry = 0;
+    while (true) {
+      try {
+        return cmd.execute();
+      } catch (Throwable t) {
+        retry++;
+        if ((exceptionClasses != null && !isInstanceOf(exceptionClasses, t)) || retry >= retryTimes) {
+          throw t;
+        } else {
+          LOG.info("Retry due to Throwable, " + t.getClass().getName() + " " + t.getMessage());
+          LOG.info("Waiting " + intervalMs + " milliseconds before next connection attempt.");
+          Thread.sleep(intervalMs);
+          if (callBack != null) {
+            callBack.execute();
+          }
+        }
+      }
+    }
+  }
+
+  private static boolean isInstanceOf(Set<Class> classes, Throwable t) {
+    for (Class c : classes) {
+      if (c.isInstance(t)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  public interface RetryCmd<T> {
+    T execute() throws Throwable;
+  }
+
+  public interface RetryCallBack {
+    void execute() throws Throwable;
+  }
+}
diff --git a/common/src/test/java/org/apache/uniffle/common/util/RetryUtilsTest.java b/common/src/test/java/org/apache/uniffle/common/util/RetryUtilsTest.java
new file mode 100644
index 0000000000..d119d273b8
--- /dev/null
+++ b/common/src/test/java/org/apache/uniffle/common/util/RetryUtilsTest.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.util;
+
+import com.google.common.collect.Sets;
+import org.apache.uniffle.common.exception.RssException;
+import org.junit.jupiter.api.Test;
+
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class RetryUtilsTest {
+  @Test
+  public void testRetry() throws Throwable {
+    AtomicInteger tryTimes = new AtomicInteger();
+    AtomicInteger callbackTime = new AtomicInteger();
+    int maxTryTime = 3;
+    try {
+      RetryUtils.retry(() -> {
+        tryTimes.incrementAndGet();
+        throw new RssException("");
+      }, () -> {
+        callbackTime.incrementAndGet();
+      }, 10, maxTryTime, Sets.newHashSet(RssException.class));
+    } catch (Throwable ignored) {
+      // expected: every attempt throws, so the last failure is rethrown
+    }
+    assertEquals(maxTryTime, tryTimes.get());
+    assertEquals(maxTryTime - 1, callbackTime.get());
+
+    tryTimes.set(0);
+    try {
+      RetryUtils.retry(() -> {
+        tryTimes.incrementAndGet();
+        throw new Exception("");
+      }, 10, maxTryTime);
+    } catch (Throwable ignored) {
+      // expected: every attempt throws, so the last failure is rethrown
+    }
+    assertEquals(maxTryTime, tryTimes.get());
+
+    // No try/catch here: a swallowed Throwable would also hide a failing assertion.
+    tryTimes.set(0);
+    int ret = RetryUtils.retry(() -> {
+      tryTimes.incrementAndGet();
+      return 1;
+    }, 10, maxTryTime);
+    assertEquals(1, ret);
+    assertEquals(1, tryTimes.get());
+  }
+}
diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
index f2e35c1280..ec7d2277aa 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
@@ -20,12 +20,15 @@
 import java.io.File;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
 import com.google.common.io.Files;
 import org.apache.uniffle.client.util.DefaultIdHelper;
+import org.apache.uniffle.common.util.Constants;
+import org.apache.uniffle.common.util.RetryUtils;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.BeforeAll;
@@ -39,6 +42,7 @@
 import org.apache.uniffle.client.util.ClientUtils;
 import org.apache.uniffle.common.PartitionRange;
 import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.ShuffleAssignmentsInfo;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.coordinator.CoordinatorConf;
@@ -47,6 +51,7 @@
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.junit.jupiter.api.Assertions.fail;
 
@@ -273,4 +278,45 @@ public void emptyTaskTest() {
         .sendCommit(Sets.newHashSet(shuffleServerInfo2), testAppId, 0, 2);
     assertFalse(commitResult);
   }
+
+  @Test
+  public void testRetryAssgin() throws Throwable {
+    int maxTryTime = shuffleServers.size();
+    AtomicInteger tryTime = new AtomicInteger();
+    String appId = "app-1";
+    RemoteStorageInfo remoteStorage = new RemoteStorageInfo("");
+    ShuffleAssignmentsInfo response = null;
+    ShuffleServerConf shuffleServerConf = getShuffleServerConf();
+    int heartbeatTimeout = shuffleServerConf.getInteger("rss.server.heartbeat.timeout", 65000);
+    int heartbeatInterval = shuffleServerConf.getInteger("rss.server.heartbeat.interval", 1000);
+    Thread.sleep(heartbeatInterval * 2);
+    shuffleWriteClientImpl.registerCoordinators(COORDINATOR_QUORUM);
+    response = RetryUtils.retry(() -> {
+      int currentTryTime = tryTime.incrementAndGet();
+      ShuffleAssignmentsInfo shuffleAssignments = shuffleWriteClientImpl.getShuffleAssignments(appId,
+          1, 1, 1, Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION), 1);
+
+      Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges =
+          shuffleAssignments.getServerToPartitionRanges();
+
+      serverToPartitionRanges.entrySet().forEach(entry -> {
+        if (currentTryTime < maxTryTime) {
+          shuffleServers.forEach((ss) -> {
+            if (ss.getId().equals(entry.getKey().getId())) {
+              try {
+                ss.stopServer();
+              } catch (Exception e) {
+                e.printStackTrace();
+              }
+            }
+          });
+        }
+        shuffleWriteClientImpl.registerShuffle(
+            entry.getKey(), appId, 0, entry.getValue(), remoteStorage);
+      });
+      return shuffleAssignments;
+    }, heartbeatTimeout, maxTryTime);
+
+    assertNotNull(response);
+  }
 }