diff --git a/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java b/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
index 751b1e0825..cd838c75e0 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
@@ -210,7 +210,8 @@ public Thread newThread(Runnable r) {
             numReduceTasks,
             1,
             Sets.newHashSet(assignmentTags),
-            requiredAssignmentShuffleServersNum
+            requiredAssignmentShuffleServersNum,
+            -1
         );
     Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges =
diff --git a/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java b/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
index d75d7647b8..f5c834b2bb 100644
--- a/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
+++ b/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
@@ -323,7 +323,8 @@ public void reportShuffleResult(Map<Integer, List<ShuffleServerInfo>> partitionToServers,
 
     @Override
     public ShuffleAssignmentsInfo getShuffleAssignments(String appId, int shuffleId, int partitionNum,
-        int partitionNumPerRange, Set<String> requiredTags, int assignmentShuffleServerNumber) {
+        int partitionNumPerRange, Set<String> requiredTags, int assignmentShuffleServerNumber,
+        int estimateTaskConcurrency) {
       return null;
     }
diff --git a/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java b/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
index 1b920aedab..5f05b56047 100644
--- a/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
+++ b/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
@@ -420,7 +420,8 @@ public void reportShuffleResult(Map<Integer, List<ShuffleServerInfo>> partitionToServers,
 
     @Override
     public ShuffleAssignmentsInfo getShuffleAssignments(String appId, int shuffleId, int partitionNum,
-        int partitionNumPerRange, Set<String> requiredTags, int assignmentShuffleServerNumber) {
+        int partitionNumPerRange, Set<String> requiredTags, int assignmentShuffleServerNumber,
+        int estimateTaskConcurrency) {
       return null;
     }
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 71b4c283a6..2dfca2cd25 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -241,6 +241,20 @@ public class RssSparkConfig {
           .doc("Coordinator quorum"))
       .createWithDefault("");
 
+  public static final ConfigEntry<Double> RSS_ESTIMATE_TASK_CONCURRENCY_DYNAMIC_FACTOR = createDoubleBuilder(
+      new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_ESTIMATE_TASK_CONCURRENCY_DYNAMIC_FACTOR)
+          .doc("Between 0 and 1, used to estimate task concurrency, how likely is this part of the resource between"
+              + " spark.dynamicAllocation.minExecutors and spark.dynamicAllocation.maxExecutors"
+              + " to be allocated"))
+      .createWithDefault(RssClientConfig.RSS_ESTIMATE_TASK_CONCURRENCY_DYNAMIC_FACTOR_DEFAULT_VALUE);
+
+  public static final ConfigEntry<Boolean> RSS_ESTIMATE_TASK_CONCURRENCY_ENABLED = createBooleanBuilder(
+      new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_ESTIMATE_TASK_CONCURRENCY_ENABLED)
+          .doc("When the Coordinator enables rss.coordinator.select.partition.strategy,"
+              + " this configuration item is valid and is used to estimate how many consecutive"
+              + " PartitionRanges should be allocated to a ShuffleServer"))
+      .createWithDefault(RssClientConfig.RSS_ESTIMATE_TASK_CONCURRENCY_DEFAULT_VALUE);
+
   public static final Set<String> RSS_MANDATORY_CLUSTER_CONF =
       ImmutableSet.of(RSS_STORAGE_TYPE.key(), RSS_REMOTE_STORAGE_PATH.key());
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
index 13f7305442..a260c9a98c 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
@@ -34,6 +34,7 @@
 import org.apache.uniffle.client.api.CoordinatorClient;
 import org.apache.uniffle.client.factory.CoordinatorClientFactory;
 import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.util.Constants;
 
 public class RssSparkShuffleUtils {
@@ -135,4 +136,29 @@ public static Set<String> getAssignmentTags(SparkConf sparkConf) {
     assignmentTags.add(Constants.SHUFFLE_SERVER_VERSION);
     return assignmentTags;
   }
+
+  public static int estimateTaskConcurrency(SparkConf sparkConf) {
+    int taskConcurrency;
+    double dynamicAllocationFactor = sparkConf.get(RssSparkConfig.RSS_ESTIMATE_TASK_CONCURRENCY_DYNAMIC_FACTOR);
+    if (dynamicAllocationFactor > 1 || dynamicAllocationFactor < 0) {
+      throw new RssException("dynamicAllocationFactor is not valid: " + dynamicAllocationFactor);
+    }
+    int executorCores = sparkConf.getInt(Constants.SPARK_EXECUTOR_CORES, Constants.SPARK_EXECUTOR_CORES_DEFAULT_VALUE);
+    int taskCpus = sparkConf.getInt(Constants.SPARK_TASK_CPUS, Constants.SPARK_TASK_CPUS_DEFAULT_VALUE);
+    int taskConcurrencyPerExecutor = Math.floorDiv(executorCores, taskCpus);
+    if (!sparkConf.getBoolean(Constants.SPARK_DYNAMIC_ENABLED, false)) {
+      int executorInstances = sparkConf.getInt(Constants.SPARK_EXECUTOR_INSTANTS,
+          Constants.SPARK_EXECUTOR_INSTANTS_DEFAULT_VALUE);
+      taskConcurrency = executorInstances > 0 ? executorInstances * taskConcurrencyPerExecutor : 0;
+    } else {
+      // Default is infinity
+      int maxExecutors = Math.min(sparkConf.getInt(Constants.SPARK_MAX_DYNAMIC_EXECUTOR,
+          Constants.SPARK_DYNAMIC_EXECUTOR_DEFAULT_VALUE), Constants.SPARK_MAX_DYNAMIC_EXECUTOR_LIMIT);
+      int minExecutors = sparkConf.getInt(Constants.SPARK_MIN_DYNAMIC_EXECUTOR,
+          Constants.SPARK_DYNAMIC_EXECUTOR_DEFAULT_VALUE);
+      taskConcurrency = (int) ((maxExecutors - minExecutors) * dynamicAllocationFactor + minExecutors)
+          * taskConcurrencyPerExecutor;
+    }
+    return taskConcurrency;
+  }
 }
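As a quick sanity check on the arithmetic above, here is a minimal standalone sketch; plain constants stand in for the SparkConf lookups, and the values are the ones exercised by testEstimateTaskConcurrency below (the class name is illustrative, not part of the patch):

    // Hypothetical trace of estimateTaskConcurrency with dynamic allocation on:
    // spark.dynamicAllocation.maxExecutors=200, minExecutors=100,
    // spark.executor.cores=2, spark.task.cpus=1, dynamic factor 0.3.
    public class EstimateTaskConcurrencySketch {
      public static void main(String[] args) {
        int taskConcurrencyPerExecutor = Math.floorDiv(2, 1); // cores / cpus = 2
        double dynamicAllocationFactor = 0.3;
        int estimate = (int) ((200 - 100) * dynamicAllocationFactor + 100) * taskConcurrencyPerExecutor;
        System.out.println(estimate); // 260, matching the assertion in the test below
      }
    }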
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/RssSparkShuffleUtilsTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/RssSparkShuffleUtilsTest.java
index 09d1e63ead..28159be25a 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/RssSparkShuffleUtilsTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/RssSparkShuffleUtilsTest.java
@@ -158,4 +158,34 @@ public void applyDynamicClientConfTest() {
     assertEquals(Integer.toString(RssClientConfig.RSS_CLIENT_RETRY_MAX_DEFAULT_VALUE),
         conf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX.key()));
   }
+
+  @Test
+  public void testEstimateTaskConcurrency() {
+    SparkConf sparkConf = new SparkConf();
+    sparkConf.set(Constants.SPARK_DYNAMIC_ENABLED, "true");
+    sparkConf.set(Constants.SPARK_MAX_DYNAMIC_EXECUTOR, "200");
+    sparkConf.set(Constants.SPARK_MIN_DYNAMIC_EXECUTOR, "100");
+    sparkConf.set(RssSparkConfig.RSS_ESTIMATE_TASK_CONCURRENCY_ENABLED, true);
+    sparkConf.set(Constants.SPARK_EXECUTOR_CORES, "2");
+    int taskConcurrency;
+
+    sparkConf.set(RssSparkConfig.RSS_ESTIMATE_TASK_CONCURRENCY_DYNAMIC_FACTOR, 1.0);
+    taskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
+    assertEquals(400, taskConcurrency);
+
+    sparkConf.set(RssSparkConfig.RSS_ESTIMATE_TASK_CONCURRENCY_DYNAMIC_FACTOR, 0.3);
+    taskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
+    assertEquals(260, taskConcurrency);
+
+    sparkConf.set(Constants.SPARK_TASK_CPUS, "2");
+    sparkConf.set(RssSparkConfig.RSS_ESTIMATE_TASK_CONCURRENCY_DYNAMIC_FACTOR, 0.3);
+    taskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
+    assertEquals(130, taskConcurrency);
+
+    sparkConf.set(Constants.SPARK_DYNAMIC_ENABLED, "false");
+    sparkConf.set(Constants.SPARK_EXECUTOR_INSTANTS, "70");
+    sparkConf.set(RssSparkConfig.RSS_ESTIMATE_TASK_CONCURRENCY_DYNAMIC_FACTOR, 1.0);
+    taskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
+    assertEquals(70, taskConcurrency);
+  }
 }
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index cf79c70c9c..774cd3caeb 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -233,7 +233,7 @@ public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, int numMaps, ShuffleDependency<K, V, C> dependency) {
       partitionToServers = RetryUtils.retry(() -> {
         ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments(
             appId, shuffleId, dependency.partitioner().numPartitions(),
-            partitionNumPerRange, assignmentTags, requiredShuffleServerNumber);
+            partitionNumPerRange, assignmentTags, requiredShuffleServerNumber, -1);
         registerShuffleServers(appId, shuffleId, response.getServerToPartitionRanges(), remoteStorage);
         return response.getPartitionToServers();
       }, retryInterval, retryTimes);
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 27368d2079..e780cb3ae0 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -282,6 +282,9 @@ public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, ShuffleDependency<K, V, C> dependency) {
     // retryInterval must be bigger than `rss.server.heartbeat.timeout`, or maybe it will return the same result
     long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
     int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
+    boolean enabledEstimateTaskConcurrency = sparkConf.get(RssSparkConfig.RSS_ESTIMATE_TASK_CONCURRENCY_ENABLED);
+    int estimateTaskConcurrency = enabledEstimateTaskConcurrency
+        ? RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf) : -1;
     Map<Integer, List<ShuffleServerInfo>> partitionToServers;
     try {
       partitionToServers = RetryUtils.retry(() -> {
@@ -291,7 +294,8 @@ public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, ShuffleDependency<K, V, C> dependency) {
             dependency.partitioner().numPartitions(),
             1,
             assignmentTags,
-            requiredShuffleServerNumber);
+            requiredShuffleServerNumber,
+            estimateTaskConcurrency);
         registerShuffleServers(id.get(), shuffleId, response.getServerToPartitionRanges(), remoteStorage);
         return response.getPartitionToServers();
       }, retryInterval, retryTimes);
diff --git a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
index 4ab4b51c6c..9776678ca5 100644
--- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
+++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
@@ -62,7 +62,8 @@ void reportShuffleResult(
       int bitmapNum);
 
   ShuffleAssignmentsInfo getShuffleAssignments(String appId, int shuffleId, int partitionNum,
-      int partitionNumPerRange, Set<String> requiredTags, int assignmentShuffleServerNumber);
+      int partitionNumPerRange, Set<String> requiredTags, int assignmentShuffleServerNumber,
+      int estimateTaskConcurrency);
 
   Roaring64NavigableMap getShuffleResult(String clientType, Set<ShuffleServerInfo> shuffleServerInfoSet,
       String appId, int shuffleId, int partitionId);
diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index be83ca0fc1..e44b9cd108 100644
--- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -394,9 +394,11 @@ public RemoteStorageInfo fetchRemoteStorage(String appId) {
 
   @Override
   public ShuffleAssignmentsInfo getShuffleAssignments(String appId, int shuffleId, int partitionNum,
-      int partitionNumPerRange, Set<String> requiredTags, int assignmentShuffleServerNumber) {
+      int partitionNumPerRange, Set<String> requiredTags, int assignmentShuffleServerNumber,
+      int estimateTaskConcurrency) {
     RssGetShuffleAssignmentsRequest request = new RssGetShuffleAssignmentsRequest(
-        appId, shuffleId, partitionNum, partitionNumPerRange, replica, requiredTags, assignmentShuffleServerNumber);
+        appId, shuffleId, partitionNum, partitionNumPerRange, replica, requiredTags,
+        assignmentShuffleServerNumber, estimateTaskConcurrency);
 
     RssGetShuffleAssignmentsResponse response = new RssGetShuffleAssignmentsResponse(ResponseStatusCode.INTERNAL_ERROR);
     for (CoordinatorClient coordinatorClient : coordinatorClients) {
diff --git a/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java b/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java
index b3247b431a..160f5d3d60 100644
--- a/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java
+++ b/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java
@@ -74,4 +74,11 @@ public class RssClientConfig {
   public static final String RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER =
       "rss.client.assignment.shuffle.nodes.max";
   public static final int RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE = -1;
+
+  public static final String RSS_ESTIMATE_TASK_CONCURRENCY_DYNAMIC_FACTOR =
+      "rss.estimate.task.concurrency.dynamic.factor";
+  public static final double RSS_ESTIMATE_TASK_CONCURRENCY_DYNAMIC_FACTOR_DEFAULT_VALUE = 1.0;
+
+  public static final String RSS_ESTIMATE_TASK_CONCURRENCY_ENABLED = "rss.estimate.task.concurrency.enabled";
+  public static final boolean RSS_ESTIMATE_TASK_CONCURRENCY_DEFAULT_VALUE = false;
 }
diff --git a/common/src/main/java/org/apache/uniffle/common/util/Constants.java b/common/src/main/java/org/apache/uniffle/common/util/Constants.java
index b480ce5847..4fd2f12ae6 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/Constants.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/Constants.java
@@ -45,4 +45,21 @@ public class Constants {
       RSS_CLIENT_CONF_COMMON_PREFIX + CONF_REMOTE_STORAGE_PATH;
   public static final String ACCESS_INFO_REQUIRED_SHUFFLE_NODES_NUM = "access_info_required_shuffle_nodes_num";
 
+  public static final String SPARK_DYNAMIC_ENABLED = "spark.dynamicAllocation.enabled";
+  public static final String SPARK_MAX_DYNAMIC_EXECUTOR = "spark.dynamicAllocation.maxExecutors";
+  public static final String SPARK_MIN_DYNAMIC_EXECUTOR = "spark.dynamicAllocation.minExecutors";
+  public static final int SPARK_DYNAMIC_EXECUTOR_DEFAULT_VALUE = 0;
+  public static final String SPARK_EXECUTOR_INSTANTS = "spark.executor.instances";
+  public static final int SPARK_EXECUTOR_INSTANTS_DEFAULT_VALUE = -1;
+  public static final String SPARK_EXECUTOR_CORES = "spark.executor.cores";
+  public static final int SPARK_EXECUTOR_CORES_DEFAULT_VALUE = 1;
+  public static final String SPARK_TASK_CPUS = "spark.task.cpus";
+  public static final int SPARK_TASK_CPUS_DEFAULT_VALUE = 1;
+  public static final int SPARK_MAX_DYNAMIC_EXECUTOR_LIMIT = 10000;
+
+  public static final String MR_MAPS = "mapreduce.job.maps";
+  public static final String MR_REDUCES = "mapreduce.job.reduces";
+  public static final String MR_MAP_LIMIT = "mapreduce.job.running.map.limit";
+  public static final String MR_REDUCE_LIMIT = "mapreduce.job.running.reduce.limit";
+  public static final String MR_SLOW_START = "mapreduce.job.reduce.slowstart.completedmaps";
 }
diff --git a/coordinator/src/main/java/org/apache/uniffle/coordinator/AbstractAssignmentStrategy.java b/coordinator/src/main/java/org/apache/uniffle/coordinator/AbstractAssignmentStrategy.java
index a16bdf0943..4116f6fc98 100644
--- a/coordinator/src/main/java/org/apache/uniffle/coordinator/AbstractAssignmentStrategy.java
+++ b/coordinator/src/main/java/org/apache/uniffle/coordinator/AbstractAssignmentStrategy.java
@@ -17,69 +17,69 @@
 
 package org.apache.uniffle.coordinator;
 
-import java.util.ArrayList;
-import java.util.HashSet;
 import java.util.List;
-import java.util.Set;
-import java.util.stream.Collectors;
+import java.util.SortedMap;
 
-import static org.apache.uniffle.coordinator.CoordinatorConf.COORDINATOR_ASSGINMENT_HOST_STRATEGY;
+import org.apache.uniffle.common.PartitionRange;
+
+import static org.apache.uniffle.coordinator.CoordinatorConf.COORDINATOR_ASSIGNMENT_HOST_STRATEGY;
+import static org.apache.uniffle.coordinator.CoordinatorConf.COORDINATOR_SELECT_PARTITION_STRATEGY;
 
 public abstract class AbstractAssignmentStrategy implements AssignmentStrategy {
   protected final CoordinatorConf conf;
-  private final HostAssignmentStrategy assignmentHostStrategy;
+  private HostAssignmentStrategy hostAssignmentStrategy;
+  private SelectPartitionStrategy selectPartitionStrategy;
 
   public AbstractAssignmentStrategy(CoordinatorConf conf) {
     this.conf = conf;
-    assignmentHostStrategy = conf.get(COORDINATOR_ASSGINMENT_HOST_STRATEGY);
+    loadHostAssignmentStrategy();
+    loadSelectPartitionStrategy();
   }
 
-  protected List<ServerNode> getCandidateNodes(List<ServerNode> allNodes, int expectNum) {
-    switch (assignmentHostStrategy) {
-      case MUST_DIFF: return getCandidateNodesWithDiffHost(allNodes, expectNum);
-      case PREFER_DIFF: return tryGetCandidateNodesWithDiffHost(allNodes, expectNum);
-      case NONE: return allNodes.subList(0, expectNum);
-      default: throw new RuntimeException("Unsupported host assignment strategy:" + assignmentHostStrategy);
+  private void loadSelectPartitionStrategy() {
+    SelectPartitionStrategyName selectPartitionStrategyName =
+        conf.get(COORDINATOR_SELECT_PARTITION_STRATEGY);
+    if (selectPartitionStrategyName == SelectPartitionStrategyName.ROUND) {
+      selectPartitionStrategy = new RoundSelectPartitionStrategy();
+    } else if (selectPartitionStrategyName == SelectPartitionStrategyName.CONTINUOUS) {
+      selectPartitionStrategy = new ContinuousSelectPartitionStrategy();
+    } else {
+      throw new RuntimeException("Unsupported partition assignment strategy:" + selectPartitionStrategyName);
     }
   }
 
-  protected List<ServerNode> tryGetCandidateNodesWithDiffHost(List<ServerNode> allNodes, int expectNum) {
-    List<ServerNode> candidatesNodes = getCandidateNodesWithDiffHost(allNodes, expectNum);
-    Set<ServerNode> candidatesNodeSet = candidatesNodes.stream().collect(Collectors.toSet());
-    if (candidatesNodes.size() < expectNum) {
-      for (ServerNode node : allNodes) {
-        if (candidatesNodeSet.contains(node)) {
-          continue;
-        }
-        candidatesNodes.add(node);
-        if (candidatesNodes.size() >= expectNum) {
-          break;
-        }
-      }
+  private void loadHostAssignmentStrategy() {
+    HostAssignmentStrategyName hostAssignmentStrategyName = conf.get(COORDINATOR_ASSIGNMENT_HOST_STRATEGY);
+    if (hostAssignmentStrategyName == HostAssignmentStrategyName.MUST_DIFF) {
+      hostAssignmentStrategy = new MustDiffHostAssignmentStrategy();
+    } else if (hostAssignmentStrategyName == HostAssignmentStrategyName.PREFER_DIFF) {
+      hostAssignmentStrategy = new PerferDiffHostAssignmentStrategy();
+    } else if (hostAssignmentStrategyName == HostAssignmentStrategyName.NONE) {
+      hostAssignmentStrategy = new BasicHostAssignmentStrategy();
+    } else {
+      throw new RuntimeException("Unsupported host assignment strategy:" + hostAssignmentStrategyName);
     }
-    return candidatesNodes;
   }
 
-  protected List<ServerNode> getCandidateNodesWithDiffHost(List<ServerNode> allNodes, int expectNum) {
-    List<ServerNode> candidatesNodes = new ArrayList<>();
-    Set<String> hostIpCandidate = new HashSet<>();
-    for (ServerNode node : allNodes) {
-      if (hostIpCandidate.contains(node.getIp())) {
-        continue;
-      }
-      hostIpCandidate.add(node.getIp());
-      candidatesNodes.add(node);
-      if (candidatesNodes.size() >= expectNum) {
-        break;
-      }
-    }
-    return candidatesNodes;
+  protected List<ServerNode> getCandidateNodes(List<ServerNode> allNodes, int expectNum) {
+    return hostAssignmentStrategy.assign(allNodes, expectNum);
   }
 
+  protected SortedMap<PartitionRange, List<ServerNode>> getPartitionAssignment(
+      int totalPartitionNum, int partitionNumPerRange, int replica, List<ServerNode> candidatesNodes,
+      int estimateTaskConcurrency) {
+    return selectPartitionStrategy.assign(totalPartitionNum, partitionNumPerRange, replica,
+        candidatesNodes, estimateTaskConcurrency);
+  }
 
-  public enum HostAssignmentStrategy {
+  public enum HostAssignmentStrategyName {
     MUST_DIFF,
     PREFER_DIFF,
     NONE
   }
+
+  public enum SelectPartitionStrategyName {
+    ROUND,
+    CONTINUOUS
+  }
 }
diff --git a/coordinator/src/main/java/org/apache/uniffle/coordinator/AssignmentStrategy.java b/coordinator/src/main/java/org/apache/uniffle/coordinator/AssignmentStrategy.java
index 36d1908caf..88c0027a5c 100644
--- a/coordinator/src/main/java/org/apache/uniffle/coordinator/AssignmentStrategy.java
+++ b/coordinator/src/main/java/org/apache/uniffle/coordinator/AssignmentStrategy.java
@@ -22,6 +22,6 @@ public interface AssignmentStrategy {
 
   PartitionRangeAssignment assign(int totalPartitionNum, int partitionNumPerRange,
-      int replica, Set<String> requiredTags, int requiredShuffleServerNumber);
+      int replica, Set<String> requiredTags, int requiredShuffleServerNumber, int estimateTaskConcurrency);
 }
diff --git a/coordinator/src/main/java/org/apache/uniffle/coordinator/BasicAssignmentStrategy.java b/coordinator/src/main/java/org/apache/uniffle/coordinator/BasicAssignmentStrategy.java
index 9bb4ba87b6..d56cef569f 100644
--- a/coordinator/src/main/java/org/apache/uniffle/coordinator/BasicAssignmentStrategy.java
+++ b/coordinator/src/main/java/org/apache/uniffle/coordinator/BasicAssignmentStrategy.java
@@ -18,11 +18,9 @@
 package org.apache.uniffle.coordinator;
 
 import java.util.Collections;
-import java.util.LinkedList;
 import java.util.List;
 import java.util.Set;
 import java.util.SortedMap;
-import java.util.TreeMap;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -42,8 +40,7 @@ public BasicAssignmentStrategy(ClusterManager clusterManager, CoordinatorConf co
 
   @Override
   public PartitionRangeAssignment assign(int totalPartitionNum, int partitionNumPerRange,
-      int replica, Set<String> requiredTags, int requiredShuffleServerNumber) {
-    List<PartitionRange> ranges = CoordinatorUtils.generateRanges(totalPartitionNum, partitionNumPerRange);
+      int replica, Set<String> requiredTags, int requiredShuffleServerNumber, int estimateTaskConcurrency) {
     int shuffleNodesMax = clusterManager.getShuffleNodesMax();
     int expectedShuffleNodesNum = shuffleNodesMax;
     if (requiredShuffleServerNumber < shuffleNodesMax && requiredShuffleServerNumber > 0) {
@@ -54,20 +51,8 @@ public PartitionRangeAssignment assign(int totalPartitionNum, int partitionNumPe
       return new PartitionRangeAssignment(null);
     }
 
-    SortedMap<PartitionRange, List<ServerNode>> assignments = new TreeMap<>();
-    int idx = 0;
-    int size = servers.size();
-
-    for (PartitionRange range : ranges) {
-      List<ServerNode> nodes = new LinkedList<>();
-      for (int i = 0; i < replica; ++i) {
-        ServerNode node = servers.get(idx);
-        nodes.add(node);
-        idx = CoordinatorUtils.nextIdx(idx, size);
-      }
-
-      assignments.put(range, nodes);
-    }
+    SortedMap<PartitionRange, List<ServerNode>> assignments =
+        getPartitionAssignment(totalPartitionNum, partitionNumPerRange, replica, servers, estimateTaskConcurrency);
 
     return new PartitionRangeAssignment(assignments);
   }
diff --git a/coordinator/src/main/java/org/apache/uniffle/coordinator/BasicHostAssignmentStrategy.java b/coordinator/src/main/java/org/apache/uniffle/coordinator/BasicHostAssignmentStrategy.java
new file mode 100644
index 0000000000..a8bfc0bf53
--- /dev/null
+++ b/coordinator/src/main/java/org/apache/uniffle/coordinator/BasicHostAssignmentStrategy.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.coordinator;
+
+import java.util.List;
+
+public class BasicHostAssignmentStrategy implements HostAssignmentStrategy {
+  @Override
+  public List<ServerNode> assign(List<ServerNode> allNodes, int expectNum) {
+    return allNodes.subList(0, expectNum);
+  }
+}
diff --git a/coordinator/src/main/java/org/apache/uniffle/coordinator/ContinuousSelectPartitionStrategy.java b/coordinator/src/main/java/org/apache/uniffle/coordinator/ContinuousSelectPartitionStrategy.java
new file mode 100644
index 0000000000..9a0f9bbb79
--- /dev/null
+++ b/coordinator/src/main/java/org/apache/uniffle/coordinator/ContinuousSelectPartitionStrategy.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.coordinator;
+
+import java.util.List;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import com.google.common.collect.Lists;
+
+import org.apache.uniffle.common.PartitionRange;
+
+public class ContinuousSelectPartitionStrategy implements SelectPartitionStrategy {
+  @Override
+  public SortedMap<PartitionRange, List<ServerNode>> assign(
+      int totalPartitionNum, int partitionNumPerRange, int replica,
+      List<ServerNode> candidatesNodes, int estimateTaskConcurrency) {
+    SortedMap<PartitionRange, List<ServerNode>> assignments = new TreeMap<>();
+    int serverNum = candidatesNodes.size();
+    List<List<PartitionRange>> rangesGroup = CoordinatorUtils.generateRangesGroup(totalPartitionNum,
+        partitionNumPerRange, serverNum, estimateTaskConcurrency);
+
+    for (int rc = 0; rc < replica; rc++) {
+      for (int i = 0; i < rangesGroup.size(); i++) {
+        ServerNode node = candidatesNodes.get((i + rc) % serverNum);
+        List<PartitionRange> ranges = rangesGroup.get(i);
+        ranges.forEach(range -> {
+          List<ServerNode> serverNodes = assignments.computeIfAbsent(range, key -> Lists.newArrayList());
+          serverNodes.add(node);
+        });
+      }
+    }
+    return assignments;
+  }
+}
diff --git a/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorConf.java b/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorConf.java
index 70408cb574..ae98765636 100644
--- a/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorConf.java
+++ b/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorConf.java
@@ -153,11 +153,11 @@ public class CoordinatorConf extends RssBaseConf {
       .intType()
       .defaultValue(3)
       .withDescription("The number of times to read and write HDFS files");
-  public static final ConfigOption<AbstractAssignmentStrategy.HostAssignmentStrategy>
-      COORDINATOR_ASSGINMENT_HOST_STRATEGY =
+  public static final ConfigOption<AbstractAssignmentStrategy.HostAssignmentStrategyName>
+      COORDINATOR_ASSIGNMENT_HOST_STRATEGY =
       ConfigOptions.key("rss.coordinator.assignment.host.strategy")
-          .enumType(AbstractAssignmentStrategy.HostAssignmentStrategy.class)
-          .defaultValue(AbstractAssignmentStrategy.HostAssignmentStrategy.PREFER_DIFF)
+          .enumType(AbstractAssignmentStrategy.HostAssignmentStrategyName.class)
+          .defaultValue(AbstractAssignmentStrategy.HostAssignmentStrategyName.PREFER_DIFF)
           .withDescription("Strategy for selecting shuffle servers");
   public static final ConfigOption<Boolean> COORDINATOR_START_SILENT_PERIOD_ENABLED = ConfigOptions
       .key("rss.coordinator.startup-silent-period.enabled")
@@ -172,6 +172,12 @@ public class CoordinatorConf extends RssBaseConf {
       .defaultValue(20 * 1000L)
       .withDescription("The waiting duration(ms) when conf of "
           + COORDINATOR_START_SILENT_PERIOD_ENABLED + " is enabled.");
+  public static final ConfigOption<AbstractAssignmentStrategy.SelectPartitionStrategyName>
+      COORDINATOR_SELECT_PARTITION_STRATEGY =
+      ConfigOptions.key("rss.coordinator.select.partition.strategy")
+          .enumType(AbstractAssignmentStrategy.SelectPartitionStrategyName.class)
+          .defaultValue(AbstractAssignmentStrategy.SelectPartitionStrategyName.ROUND)
+          .withDescription("Strategy for selecting partitions");
 
   public CoordinatorConf() {
   }
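Both sides have to opt in for the new path to take effect; a hedged usage sketch (the enum and ConfigOption come from the classes above, the Spark keys are SPARK_RSS_CONFIG_PREFIX plus the RssClientConfig constants, and the values are illustrative):

    // Coordinator side: switch the partition selection strategy to CONTINUOUS,
    // as the tests in this patch do programmatically.
    CoordinatorConf coordinatorConf = new CoordinatorConf();
    coordinatorConf.set(CoordinatorConf.COORDINATOR_SELECT_PARTITION_STRATEGY,
        AbstractAssignmentStrategy.SelectPartitionStrategyName.CONTINUOUS);

    // Spark client side: send an estimated task concurrency with the assignment request.
    SparkConf sparkConf = new SparkConf();
    sparkConf.set("spark.rss.estimate.task.concurrency.enabled", "true");
    sparkConf.set("spark.rss.estimate.task.concurrency.dynamic.factor", "0.5");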
diff --git a/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java b/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java
index 987c85a9b6..e9cbb79b9e 100644
--- a/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java
+++ b/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java
@@ -110,6 +110,7 @@ public void getShuffleAssignments(
     final int replica = request.getDataReplica();
     final Set<String> requiredTags = Sets.newHashSet(request.getRequireTagsList());
     final int requiredShuffleServerNumber = request.getAssignmentShuffleServerNumber();
+    final int estimateTaskConcurrency = request.getEstimateTaskConcurrency();
 
     LOG.info("Request of getShuffleAssignments for appId[" + appId
         + "], shuffleId[" + shuffleId
         + "], partitionNum[" + partitionNum
@@ -127,7 +128,8 @@ public void getShuffleAssignments(
 
       final PartitionRangeAssignment pra = coordinatorServer
           .getAssignmentStrategy()
-          .assign(partitionNum, partitionNumPerRange, replica, requiredTags, requiredShuffleServerNumber);
+          .assign(partitionNum, partitionNumPerRange, replica, requiredTags,
+              requiredShuffleServerNumber, estimateTaskConcurrency);
       response = CoordinatorUtils.toGetShuffleAssignmentsResponse(pra);
       logAssignmentResult(appId, shuffleId, pra);
diff --git a/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorUtils.java b/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorUtils.java
index e3984ed7fc..3db6627861 100644
--- a/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorUtils.java
+++ b/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorUtils.java
@@ -21,6 +21,7 @@
 import java.util.List;
 import java.util.Map;
 
+import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import org.apache.commons.lang3.ArrayUtils;
 import org.apache.commons.lang3.StringUtils;
@@ -41,8 +42,8 @@ public static GetShuffleAssignmentsResponse toGetShuffleAssignmentsResponse(
     List<RssProtos.PartitionRangeAssignment> praList = pra.convertToGrpcProto();
 
     return GetShuffleAssignmentsResponse.newBuilder()
-        .addAllAssignments(praList)
-        .build();
+            .addAllAssignments(praList)
+            .build();
   }
 
   public static int nextIdx(int idx, int size) {
@@ -53,6 +54,63 @@ public static int nextIdx(int idx, int size) {
     return idx;
   }
 
+  /**
+   * Assign multiple adjacent partitionRanges to several servers. The result returned is a nested
+   * PartitionRange list: the first inner list will be assigned to server1, the second to server2,
+   * and so on, dealt out round-robin per round.
+   * Suppose totalPartitionNum=52, partitionNumPerRange=2, serverNum=5, estimateTaskConcurrency=20.
+   * The final result generated is:
+   * server1: [0,1] [2,3] [20,21] [22,23] [40,41] [42,43]
+   * server2: [4,5] [6,7] [24,25] [26,27] [44,45]
+   * server3: [8,9] [10,11] [28,29] [30,31] [46,47]
+   * server4: [12,13] [14,15] [32,33] [34,35] [48,49]
+   * server5: [16,17] [18,19] [36,37] [38,39] [50,51]
+   */
+  public static List<List<PartitionRange>> generateRangesGroup(int totalPartitionNum, int partitionNumPerRange,
+      int serverNum, int estimateTaskConcurrency) {
+    List<List<PartitionRange>> res = Lists.newArrayList();
+    if (totalPartitionNum <= 0 || partitionNumPerRange <= 0) {
+      return res;
+    }
+    estimateTaskConcurrency = Math.min(totalPartitionNum, estimateTaskConcurrency);
+    int rangePerGroup = estimateTaskConcurrency > serverNum * partitionNumPerRange
+        ? Math.floorDiv(estimateTaskConcurrency, serverNum * partitionNumPerRange) : 1;
+    int totalRanges = (int) Math.ceil(totalPartitionNum * 1.0 / partitionNumPerRange);
+    int groupCount = 0;
+    int round = Math.floorDiv(totalRanges, rangePerGroup * serverNum);
+    int remainRange = totalRanges % (rangePerGroup * serverNum);
+    int lastRoundRangePerGroup = Math.floorDiv(remainRange, serverNum);
+    int lastRoundRemainRange = remainRange % serverNum;
+    int rangeInGroupCount = 0;
+
+    List<PartitionRange> rangeGroup = Lists.newArrayList();
+    for (int start = 0; start < totalPartitionNum; start += partitionNumPerRange) {
+      int end = start + partitionNumPerRange - 1;
+      PartitionRange range = new PartitionRange(start, end);
+      rangeGroup.add(range);
+      rangeInGroupCount += 1;
+
+      boolean isLastRound = groupCount >= round * serverNum;
+      int groupIndexInRound = groupCount % serverNum;
+      if ((!isLastRound && rangeInGroupCount == rangePerGroup)
+          || (isLastRound
+              && ((groupIndexInRound < lastRoundRemainRange
+                  && rangeInGroupCount == lastRoundRangePerGroup + 1)
+              || (groupIndexInRound >= lastRoundRemainRange
+                  && rangeInGroupCount == lastRoundRangePerGroup)))) {
+        res.add(Lists.newArrayList(rangeGroup));
+        rangeGroup.clear();
+        rangeInGroupCount = 0;
+        groupCount += 1;
+      }
+    }
+
+    if (!rangeGroup.isEmpty()) {
+      res.add(Lists.newArrayList(rangeGroup));
+    }
+    return res;
+  }
+
   public static List<PartitionRange> generateRanges(int totalPartitionNum, int partitionNumPerRange) {
     List<PartitionRange> ranges = new ArrayList<>();
     if (totalPartitionNum <= 0 || partitionNumPerRange <= 0) {
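To make the grouping arithmetic concrete, a small standalone trace of the javadoc example above (a sketch; it uses only the public CoordinatorUtils API introduced in this patch, and the class name is illustrative):

    import java.util.List;

    import org.apache.uniffle.common.PartitionRange;
    import org.apache.uniffle.coordinator.CoordinatorUtils;

    public class RangesGroupDemo {
      public static void main(String[] args) {
        // generateRangesGroup(52, 2, 5, 20): 26 ranges in total.
        // rangePerGroup = floorDiv(20, 5 * 2) = 2, so two full rounds of
        // 5 groups x 2 ranges cover 20 ranges; the remaining 6 ranges form
        // a last round with group sizes [2, 1, 1, 1, 1] - 15 groups overall.
        List<List<PartitionRange>> groups = CoordinatorUtils.generateRangesGroup(52, 2, 5, 20);
        System.out.println(groups.size());        // 15
        System.out.println(groups.get(0).size()); // 2: ranges [0,1] and [2,3]
        // ContinuousSelectPartitionStrategy then deals these groups out round-robin,
        // offsetting by one server per extra replica: group i of replica rc lands on
        // candidatesNodes.get((i + rc) % serverNum).
      }
    }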
diff --git a/coordinator/src/main/java/org/apache/uniffle/coordinator/HostAssignmentStrategy.java b/coordinator/src/main/java/org/apache/uniffle/coordinator/HostAssignmentStrategy.java
new file mode 100644
index 0000000000..b8c9552a16
--- /dev/null
+++ b/coordinator/src/main/java/org/apache/uniffle/coordinator/HostAssignmentStrategy.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.coordinator;
+
+import java.util.List;
+
+public interface HostAssignmentStrategy {
+  List<ServerNode> assign(List<ServerNode> allNodes, int expectNum);
+}
diff --git a/coordinator/src/main/java/org/apache/uniffle/coordinator/MustDiffHostAssignmentStrategy.java b/coordinator/src/main/java/org/apache/uniffle/coordinator/MustDiffHostAssignmentStrategy.java
new file mode 100644
index 0000000000..d4c0583062
--- /dev/null
+++ b/coordinator/src/main/java/org/apache/uniffle/coordinator/MustDiffHostAssignmentStrategy.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.coordinator;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+public class MustDiffHostAssignmentStrategy implements HostAssignmentStrategy {
+  @Override
+  public List<ServerNode> assign(List<ServerNode> allNodes, int expectNum) {
+    List<ServerNode> candidatesNodes = new ArrayList<>();
+    Set<String> hostIpCandidate = new HashSet<>();
+    for (ServerNode node : allNodes) {
+      if (hostIpCandidate.contains(node.getIp())) {
+        continue;
+      }
+      hostIpCandidate.add(node.getIp());
+      candidatesNodes.add(node);
+      if (candidatesNodes.size() >= expectNum) {
+        break;
+      }
+    }
+    return candidatesNodes;
+  }
+}
diff --git a/coordinator/src/main/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategy.java b/coordinator/src/main/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategy.java
index a31f4bb950..088ce5bc38 100644
--- a/coordinator/src/main/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategy.java
+++ b/coordinator/src/main/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategy.java
@@ -17,15 +17,14 @@
 
 package org.apache.uniffle.coordinator;
 
+import java.util.Collection;
 import java.util.Comparator;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.SortedMap;
-import java.util.TreeMap;
 
 import com.google.common.annotations.VisibleForTesting;
-import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -68,13 +67,14 @@ public PartitionRangeAssignment assign(
       int partitionNumPerRange,
       int replica,
       Set<String> requiredTags,
-      int requiredShuffleServerNumber) {
+      int requiredShuffleServerNumber,
+      int estimateTaskConcurrency) {
 
     if (partitionNumPerRange != 1) {
       throw new RuntimeException("PartitionNumPerRange must be one");
    }
 
-    SortedMap<PartitionRange, List<ServerNode>> assignments = new TreeMap<>();
+    SortedMap<PartitionRange, List<ServerNode>> assignments;
     synchronized (this) {
       List<ServerNode> nodes = clusterManager.getServerList(requiredTags);
       Map<ServerNode, PartitionAssignmentInfo> newPartitionInfos = Maps.newConcurrentMap();
@@ -121,18 +121,10 @@ public int compare(ServerNode o1, ServerNode o2) {
       }
 
       List<ServerNode> candidatesNodes = getCandidateNodes(nodes, expectNum);
-      int idx = 0;
-      List<PartitionRange> ranges = CoordinatorUtils.generateRanges(totalPartitionNum, 1);
-      for (PartitionRange range : ranges) {
-        List<ServerNode> assignNodes = Lists.newArrayList();
-        for (int rc = 0; rc < replica; rc++) {
-          ServerNode node = candidatesNodes.get(idx);
-          idx = CoordinatorUtils.nextIdx(idx, candidatesNodes.size());
-          serverToPartitions.get(node).incrementPartitionNum();
-          assignNodes.add(node);
-        }
-        assignments.put(range, assignNodes);
-      }
+      assignments = getPartitionAssignment(totalPartitionNum, partitionNumPerRange, replica,
+          candidatesNodes, estimateTaskConcurrency);
+      assignments.values().stream().flatMap(Collection::stream)
+          .forEach(server -> serverToPartitions.get(server).incrementPartitionNum());
     }
     return new PartitionRangeAssignment(assignments);
   }
@@ -164,6 +156,10 @@ public void incrementPartitionNum() {
       partitionNum++;
     }
 
+    public void incrementPartitionNum(int val) {
+      partitionNum += val;
+    }
+
     public long getTimestamp() {
       return timestamp;
     }
diff --git a/coordinator/src/main/java/org/apache/uniffle/coordinator/PerferDiffHostAssignmentStrategy.java b/coordinator/src/main/java/org/apache/uniffle/coordinator/PerferDiffHostAssignmentStrategy.java
new file mode 100644
index 0000000000..6721210736
--- /dev/null
+++ b/coordinator/src/main/java/org/apache/uniffle/coordinator/PerferDiffHostAssignmentStrategy.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.coordinator;
+
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+public class PerferDiffHostAssignmentStrategy implements HostAssignmentStrategy {
+
+  private MustDiffHostAssignmentStrategy strategy;
+
+  public PerferDiffHostAssignmentStrategy() {
+    strategy = new MustDiffHostAssignmentStrategy();
+  }
+
+  @Override
+  public List<ServerNode> assign(List<ServerNode> allNodes, int expectNum) {
+    List<ServerNode> candidatesNodes = strategy.assign(allNodes, expectNum);
+    Set<ServerNode> candidatesNodeSet = candidatesNodes.stream().collect(Collectors.toSet());
+    if (candidatesNodes.size() < expectNum) {
+      for (ServerNode node : allNodes) {
+        if (candidatesNodeSet.contains(node)) {
+          continue;
+        }
+        candidatesNodes.add(node);
+        if (candidatesNodes.size() >= expectNum) {
+          break;
+        }
+      }
+    }
+    return candidatesNodes;
+  }
+}
diff --git a/coordinator/src/main/java/org/apache/uniffle/coordinator/RoundSelectPartitionStrategy.java b/coordinator/src/main/java/org/apache/uniffle/coordinator/RoundSelectPartitionStrategy.java
new file mode 100644
index 0000000000..affaa64167
--- /dev/null
+++ b/coordinator/src/main/java/org/apache/uniffle/coordinator/RoundSelectPartitionStrategy.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.coordinator;
+
+import java.util.List;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import com.google.common.collect.Lists;
+
+import org.apache.uniffle.common.PartitionRange;
+
+public class RoundSelectPartitionStrategy implements SelectPartitionStrategy {
+  @Override
+  public SortedMap<PartitionRange, List<ServerNode>> assign(
+      int totalPartitionNum, int partitionNumPerRange, int replica,
+      List<ServerNode> candidatesNodes, int estimateTaskConcurrency) {
+    SortedMap<PartitionRange, List<ServerNode>> assignments = new TreeMap<>();
+    int idx = 0;
+    List<PartitionRange> ranges = CoordinatorUtils.generateRanges(totalPartitionNum, partitionNumPerRange);
+    for (PartitionRange range : ranges) {
+      List<ServerNode> assignNodes = Lists.newArrayList();
+      for (int rc = 0; rc < replica; rc++) {
+        ServerNode node = candidatesNodes.get(idx);
+        idx = CoordinatorUtils.nextIdx(idx, candidatesNodes.size());
+        assignNodes.add(node);
+      }
+      assignments.put(range, assignNodes);
+    }
+    return assignments;
+  }
+}
diff --git a/coordinator/src/main/java/org/apache/uniffle/coordinator/SelectPartitionStrategy.java b/coordinator/src/main/java/org/apache/uniffle/coordinator/SelectPartitionStrategy.java
new file mode 100644
index 0000000000..e785c0d347
--- /dev/null
+++ b/coordinator/src/main/java/org/apache/uniffle/coordinator/SelectPartitionStrategy.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.coordinator;
+
+import java.util.List;
+import java.util.SortedMap;
+
+import org.apache.uniffle.common.PartitionRange;
+
+public interface SelectPartitionStrategy {
+  /**
+   * Partition allocation strategy, which defines how to assign several partitions to several servers
+   */
+  SortedMap<PartitionRange, List<ServerNode>> assign(int totalPartitionNum, int partitionNumPerRange,
+      int replica, List<ServerNode> candidatesNodes, int estimateTaskConcurrency);
+}
diff --git a/coordinator/src/test/java/org/apache/uniffle/coordinator/BasicAssignmentStrategyTest.java b/coordinator/src/test/java/org/apache/uniffle/coordinator/BasicAssignmentStrategyTest.java
index 89e54ac983..6a2d373810 100644
--- a/coordinator/src/test/java/org/apache/uniffle/coordinator/BasicAssignmentStrategyTest.java
+++ b/coordinator/src/test/java/org/apache/uniffle/coordinator/BasicAssignmentStrategyTest.java
@@ -26,6 +26,7 @@
 import java.util.SortedMap;
 import java.util.stream.Collectors;
 
+import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
 import org.apache.hadoop.conf.Configuration;
 import org.junit.jupiter.api.AfterEach;
@@ -67,7 +68,7 @@ public void testAssign() {
           20 - i, 0, tags, true));
     }
 
-    PartitionRangeAssignment pra = strategy.assign(100, 10, 2, tags, -1);
+    PartitionRangeAssignment pra = strategy.assign(100, 10, 2, tags, -1, -1);
     SortedMap<PartitionRange, List<ServerNode>> assignments = pra.getAssignments();
     assertEquals(10, assignments.size());
 
@@ -93,14 +94,14 @@ public void testRandomAssign() {
       clusterManager.add(new ServerNode(String.valueOf(i), "127.0.0." + i, 0, 0, 0,
          0, 0, tags, true));
     }
 
-    PartitionRangeAssignment pra = strategy.assign(100, 10, 2, tags, -1);
+    PartitionRangeAssignment pra = strategy.assign(100, 10, 2, tags, -1, -1);
     SortedMap<PartitionRange, List<ServerNode>> assignments = pra.getAssignments();
     Set<ServerNode> serverNodes1 = Sets.newHashSet();
     for (Map.Entry<PartitionRange, List<ServerNode>> assignment : assignments.entrySet()) {
       serverNodes1.addAll(assignment.getValue());
     }
 
-    pra = strategy.assign(100, 10, 2, tags, -1);
+    pra = strategy.assign(100, 10, 2, tags, -1, -1);
     assignments = pra.getAssignments();
     Set<ServerNode> serverNodes2 = Sets.newHashSet();
     for (Map.Entry<PartitionRange, List<ServerNode>> assignment : assignments.entrySet()) {
@@ -121,13 +122,13 @@ public void testAssignWithDifferentNodeNum() {
         0, 0, tags, true);
     clusterManager.add(sn1);
 
-    PartitionRangeAssignment pra = strategy.assign(100, 10, 2, tags, -1);
+    PartitionRangeAssignment pra = strategy.assign(100, 10, 2, tags, -1, -1);
     // nodeNum < replica
     assertNull(pra.getAssignments());
 
     // nodeNum = replica
     clusterManager.add(sn2);
-    pra = strategy.assign(100, 10, 2, tags, -1);
+    pra = strategy.assign(100, 10, 2, tags, -1, -1);
     SortedMap<PartitionRange, List<ServerNode>> assignments = pra.getAssignments();
     Set<ServerNode> serverNodes = Sets.newHashSet();
     for (Map.Entry<PartitionRange, List<ServerNode>> assignment : assignments.entrySet()) {
@@ -139,7 +140,7 @@ public void testAssignWithDifferentNodeNum() {
 
     // nodeNum > replica & nodeNum < shuffleNodesMax
     clusterManager.add(sn3);
-    pra = strategy.assign(100, 10, 2, tags, -1);
+    pra = strategy.assign(100, 10, 2, tags, -1, -1);
     assignments = pra.getAssignments();
     serverNodes = Sets.newHashSet();
     for (Map.Entry<PartitionRange, List<ServerNode>> assignment : assignments.entrySet()) {
@@ -164,7 +165,7 @@ public void testAssignmentShuffleNodesNum() {
      * case1: user specify the illegal shuffle node num(<0)
      * it will use the default shuffle nodes num when having enough servers.
      */
-    PartitionRangeAssignment pra = strategy.assign(100, 10, 1, serverTags, -1);
+    PartitionRangeAssignment pra = strategy.assign(100, 10, 1, serverTags, -1, -1);
     assertEquals(
         shuffleNodesMax,
         pra.getAssignments()
@@ -179,7 +180,7 @@ public void testAssignmentShuffleNodesNum() {
      * case2: user specify the illegal shuffle node num(==0)
      * it will use the default shuffle nodes num when having enough servers.
      */
-    pra = strategy.assign(100, 10, 1, serverTags, 0);
+    pra = strategy.assign(100, 10, 1, serverTags, 0, -1);
     assertEquals(
         shuffleNodesMax,
         pra.getAssignments()
@@ -194,7 +195,7 @@ public void testAssignmentShuffleNodesNum() {
      * case3: user specify the illegal shuffle node num(>default max limitation)
      * it will use the default shuffle nodes num when having enough servers
      */
-    pra = strategy.assign(100, 10, 1, serverTags, shuffleNodesMax + 10);
+    pra = strategy.assign(100, 10, 1, serverTags, shuffleNodesMax + 10, -1);
     assertEquals(
         shuffleNodesMax,
         pra.getAssignments()
@@ -209,7 +210,7 @@ public void testAssignmentShuffleNodesNum() {
      * case4: user specify the legal shuffle node num,
      * it will use the customized shuffle nodes num when having enough servers
      */
-    pra = strategy.assign(100, 10, 1, serverTags, shuffleNodesMax - 1);
+    pra = strategy.assign(100, 10, 1, serverTags, shuffleNodesMax - 1, -1);
     assertEquals(
         shuffleNodesMax - 1,
         pra.getAssignments()
@@ -229,7 +230,7 @@ public void testAssignmentShuffleNodesNum() {
       clusterManager.add(new ServerNode("t2-" + i, "", 0, 0, 0,
          20 - i, 0, serverTags, true));
     }
-    pra = strategy.assign(100, 10, 1, serverTags, shuffleNodesMax);
+    pra = strategy.assign(100, 10, 1, serverTags, shuffleNodesMax, -1);
     assertEquals(
         shuffleNodesMax - 1,
         pra.getAssignments()
@@ -240,4 +241,63 @@ public void testAssignmentShuffleNodesNum() {
         .size()
     );
   }
+
+  @Test
+  public void testWithContinuousSelectPartitionStrategy() throws Exception {
+    CoordinatorConf ssc = new CoordinatorConf();
+    ssc.set(CoordinatorConf.COORDINATOR_SELECT_PARTITION_STRATEGY,
+        AbstractAssignmentStrategy.SelectPartitionStrategyName.CONTINUOUS);
+    ssc.setInteger(CoordinatorConf.COORDINATOR_SHUFFLE_NODES_MAX, shuffleNodesMax);
+    clusterManager = new SimpleClusterManager(ssc, new Configuration());
+    strategy = new BasicAssignmentStrategy(clusterManager, ssc);
+    List<Long> list = Lists.newArrayList(20L, 20L, 20L, 20L, 20L, 20L, 20L, 20L, 20L, 20L,
+        20L, 20L, 20L, 20L, 20L, 20L, 20L, 20L, 20L, 20L);
+    updateServerResource(list);
+    PartitionRangeAssignment assignment = strategy.assign(100, 1, 2, tags, 5, 20);
+    List<Long> expect = Lists.newArrayList(40L, 40L, 40L, 40L, 40L);
+    valid(expect, assignment.getAssignments());
+
+    assignment = strategy.assign(28, 1, 2, tags, 5, 20);
+    expect = Lists.newArrayList(11L, 12L, 12L, 11L, 10L);
+    valid(expect, assignment.getAssignments());
+
+    assignment = strategy.assign(29, 1, 2, tags, 5, 4);
+    expect = Lists.newArrayList(11L, 12L, 12L, 12L, 11L);
+    valid(expect, assignment.getAssignments());
+
+    assignment = strategy.assign(29, 2, 2, tags, 5, 4);
+    expect = Lists.newArrayList(12L, 12L, 12L, 12L, 12L);
+    valid(expect, assignment.getAssignments());
+  }
+
+  void updateServerResource(List<Long> resources) {
+    for (int i = 0; i < resources.size(); i++) {
+      ServerNode node = new ServerNode(
+          String.valueOf((char) ('a' + i)),
+          "127.0.0." + i,
+          0,
+          10L,
+          5L,
+          resources.get(i),
+          5,
+          tags,
+          true);
+      clusterManager.add(node);
+    }
+  }
+
+  private void valid(List<Long> expect, SortedMap<PartitionRange, List<ServerNode>> partitionToServerNodes) {
+    // The per-server counts cannot be matched exactly because the server order is not deterministic
+    int actualPartitionNum = 0;
+    Set<ServerNode> serverNodes = Sets.newHashSet();
+    for (Map.Entry<PartitionRange, List<ServerNode>> entry : partitionToServerNodes.entrySet()) {
+      PartitionRange range = entry.getKey();
+      actualPartitionNum += (range.getEnd() - range.getStart() + 1) * entry.getValue().size();
+      serverNodes.addAll(entry.getValue());
+    }
+
+    long expectPartitionNum = expect.stream().mapToLong(Long::longValue).sum();
+    assertEquals(expect.size(), serverNodes.size());
+    assertEquals(expectPartitionNum, actualPartitionNum);
+  }
 }
diff --git a/coordinator/src/test/java/org/apache/uniffle/coordinator/ContinuousSelectPartitionStrategyTest.java b/coordinator/src/test/java/org/apache/uniffle/coordinator/ContinuousSelectPartitionStrategyTest.java
new file mode 100644
index 0000000000..fa21352c90
--- /dev/null
+++ b/coordinator/src/test/java/org/apache/uniffle/coordinator/ContinuousSelectPartitionStrategyTest.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.coordinator;
+
+import java.util.Collection;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.common.PartitionRange;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class ContinuousSelectPartitionStrategyTest {
+  private int shuffleNodesMax = 5;
+  private Set<String> tags = Sets.newHashSet("test");
+
+  @Test
+  public void test() throws Exception {
+    ContinuousSelectPartitionStrategy strategy = new ContinuousSelectPartitionStrategy();
+
+    List<ServerNode> serverNodes = generateServerResource(Lists.newArrayList(20L, 20L, 20L, 20L, 20L));
+    SortedMap<PartitionRange, List<ServerNode>> assignments = strategy.assign(100, 2, 2, serverNodes, 20);
+    assertEquals(50, assignments.size());
+    List<Long> expect = Lists.newArrayList(20L, 20L, 20L, 20L, 20L);
+    valid(expect, assignments);
+
+    assignments = strategy.assign(100, 2, 3, serverNodes, 20);
+    assertEquals(50, assignments.size());
+    expect = Lists.newArrayList(30L, 30L, 30L, 30L, 30L);
+    valid(expect, assignments);
+
+    assignments = strategy.assign(100, 2, 2, serverNodes, 4);
+    assertEquals(50, assignments.size());
+    expect = Lists.newArrayList(20L, 20L, 20L, 20L, 20L);
+    valid(expect, assignments);
+
+    assignments = strategy.assign(98, 2, 2, serverNodes, 20);
+    assertEquals(49, assignments.size());
+    expect = Lists.newArrayList(19L, 20L, 20L, 20L, 19L);
+    valid(expect, assignments);
+
+    assignments = strategy.assign(98, 2, 3, serverNodes, 20);
+    assertEquals(49, assignments.size());
+    expect = Lists.newArrayList(29L, 29L, 30L, 30L, 29L);
+    valid(expect, assignments);
+
+    assignments = strategy.assign(98, 2, 3, serverNodes, 4);
+    assertEquals(49, assignments.size());
+    expect = Lists.newArrayList(29L, 29L, 30L, 30L, 29L);
+    valid(expect, assignments);
+
+    assignments = strategy.assign(4, 2, 2, serverNodes, 4);
+    assertEquals(2, assignments.size());
+    expect = Lists.newArrayList(1L, 2L, 1L);
+    valid(expect, assignments);
+  }
+
+  private List<ServerNode> generateServerResource(List<Long> resources) {
+    List<ServerNode> serverNodes = Lists.newArrayList();
+    for (int i = 0; i < resources.size(); i++) {
+      ServerNode node = new ServerNode(
+          String.valueOf((char) ('a' + i)),
+          "127.0.0." + i,
+          0,
+          10L,
+          5L,
+          resources.get(i),
+          5,
+          tags,
+          true);
+      serverNodes.add(node);
+    }
+    return serverNodes;
+  }
+
+  private void valid(List<Long> expect, SortedMap<PartitionRange, List<ServerNode>> partitionToServerNodes) {
+    SortedMap<ServerNode, Integer> serverToPartitionRangeNums = new TreeMap<>(Comparator.comparing(ServerNode::getId));
+    partitionToServerNodes.values().stream().flatMap(Collection::stream).forEach(serverNode -> {
+      int oldVal = serverToPartitionRangeNums.getOrDefault(serverNode, 0);
+      serverToPartitionRangeNums.put(serverNode, oldVal + 1);
+    });
+    assertEquals(serverToPartitionRangeNums.size(), expect.size());
+
+    int i = 0;
+    for (Map.Entry<ServerNode, Integer> entry : serverToPartitionRangeNums.entrySet()) {
+      int partitionNum = entry.getValue();
+      assertEquals(expect.get(i), partitionNum);
+      i++;
+    }
+  }
+}
diff --git a/coordinator/src/test/java/org/apache/uniffle/coordinator/CoordinatorUtilsTest.java b/coordinator/src/test/java/org/apache/uniffle/coordinator/CoordinatorUtilsTest.java
index f2c45c28ac..f862cf79df 100644
--- a/coordinator/src/test/java/org/apache/uniffle/coordinator/CoordinatorUtilsTest.java
+++ b/coordinator/src/test/java/org/apache/uniffle/coordinator/CoordinatorUtilsTest.java
@@ -90,4 +90,50 @@ private void compareConfMap(Map<String, String> expect, Map<String, String> actual) {
+
+  @Test
+  public void testGenerateRangesGroup() {
+    List<List<PartitionRange>> rangesGroup = CoordinatorUtils.generateRangesGroup(52, 2, 5, 20);
+    assertEquals(15, rangesGroup.size());
+    validate(new int[]{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1}, rangesGroup);
+
+    rangesGroup = CoordinatorUtils.generateRangesGroup(48, 2, 5, 20);
+    assertEquals(14, rangesGroup.size());
+    validate(new int[]{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1}, rangesGroup);
+
+    rangesGroup = CoordinatorUtils.generateRangesGroup(96, 2, 5, 20);
+    assertEquals(25, rangesGroup.size());
+    validate(new int[]{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1}, rangesGroup);
+
+    rangesGroup = CoordinatorUtils.generateRangesGroup(96, 2, 5, 30);
+    assertEquals(18, rangesGroup.size());
+    validate(new int[]{3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 1, 1}, rangesGroup);
+
+    rangesGroup = CoordinatorUtils.generateRangesGroup(48, 1, 5, 20);
+    assertEquals(15, rangesGroup.size());
+    validate(new int[]{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 2, 2, 1, 1}, rangesGroup);
+
+    rangesGroup = CoordinatorUtils.generateRangesGroup(26, 2, 5, 4);
+    assertEquals(13, rangesGroup.size());
+    validate(new int[]{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, rangesGroup);
+
+    rangesGroup = CoordinatorUtils.generateRangesGroup(2, 2, 5, 4);
+    assertEquals(1, rangesGroup.size());
+    validate(new int[]{1}, rangesGroup);
+
+    rangesGroup = CoordinatorUtils.generateRangesGroup(12, 2, 5, 0);
+    assertEquals(6, rangesGroup.size());
+    validate(new int[]{1, 1, 1, 1, 1, 1}, rangesGroup);
+
+    rangesGroup = CoordinatorUtils.generateRangesGroup(24, 2, 5, 50);
+    assertEquals(7, rangesGroup.size());
+    validate(new int[]{2, 2, 2, 2, 2, 1, 1}, rangesGroup);
+  }
+
+  private void validate(int[] expect, List<List<PartitionRange>> rangesGroup) {
+    for (int i = 0; i < expect.length; i++) {
+      assertEquals(expect[i], rangesGroup.get(i).size());
+    }
+  }
 }
diff --git a/coordinator/src/test/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategyTest.java b/coordinator/src/test/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategyTest.java
index fad48f7756..43b2e0d783 100644
--- a/coordinator/src/test/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategyTest.java
+++ b/coordinator/src/test/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategyTest.java
@@ -64,32 +64,32 @@ public void testAssign() {
updateServerResource(list); boolean isThrown = false; try { - strategy.assign(100, 2, 1, tags, -1); + strategy.assign(100, 2, 1, tags, -1, -1); } catch (Exception e) { isThrown = true; } assertTrue(isThrown); try { - strategy.assign(0, 1, 1, tags, -1); + strategy.assign(0, 1, 1, tags, -1, -1); } catch (Exception e) { fail(); } isThrown = false; try { - strategy.assign(10, 1, 1, Sets.newHashSet("fake"), 1); + strategy.assign(10, 1, 1, Sets.newHashSet("fake"), 1, -1); } catch (Exception e) { isThrown = true; } assertTrue(isThrown); - strategy.assign(100, 1, 1, tags, -1); + strategy.assign(100, 1, 1, tags, -1, -1); List expect = Lists.newArrayList(20L, 20L, 20L, 20L, 20L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L); valid(expect); - strategy.assign(75, 1, 1, tags, -1); + strategy.assign(75, 1, 1, tags, -1, -1); expect = Lists.newArrayList(20L, 20L, 20L, 20L, 20L, 15L, 15L, 15L, 15L, 15L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L); valid(expect); - strategy.assign(100, 1, 1, tags, -1); + strategy.assign(100, 1, 1, tags, -1, -1); expect = Lists.newArrayList(20L, 20L, 20L, 20L, 20L, 15L, 15L, 15L, 15L, 15L, 20L, 20L, 20L, 20L, 20L, 0L, 0L, 0L, 0L, 0L); valid(expect); @@ -98,16 +98,16 @@ public void testAssign() { list = Lists.newArrayList(7L, 18L, 7L, 3L, 19L, 15L, 11L, 10L, 16L, 11L, 14L, 17L, 15L, 17L, 8L, 1L, 3L, 3L, 6L, 12L); updateServerResource(list); - strategy.assign(100, 1, 1, tags, -1); + strategy.assign(100, 1, 1, tags, -1, -1); expect = Lists.newArrayList(0L, 20L, 0L, 0L, 20L, 0L, 0L, 0L, 20L, 0L, 0L, 20L, 0L, 20L, 0L, 0L, 0L, 0L, 0L, 0L); valid(expect); - strategy.assign(50, 1, 1, tags, -1); + strategy.assign(50, 1, 1, tags, -1, -1); expect = Lists.newArrayList(0L, 20L, 0L, 0L, 20L, 10L, 10L, 0L, 20L, 0L, 10L, 20L, 10L, 20L, 0L, 0L, 0L, 0L, 0L, 10L); valid(expect); - strategy.assign(75, 1, 1, tags, -1); + strategy.assign(75, 1, 1, tags, -1, -1); expect = Lists.newArrayList(0L, 20L, 0L, 0L, 20L, 25L, 10L, 15L, 20L, 15L, 25L, 20L, 25L, 20L, 0L, 0L, 0L, 0L, 0L, 10L); valid(expect); @@ -116,15 +116,15 @@ public void testAssign() { list = Lists.newArrayList(7L, 18L, 7L, 3L, 19L, 15L, 11L, 10L, 16L, 11L, 14L, 17L, 15L, 17L, 8L, 1L, 3L, 3L, 6L, 12L); updateServerResource(list); - strategy.assign(50, 1, 2, tags, -1); + strategy.assign(50, 1, 2, tags, -1, -1); expect = Lists.newArrayList(0L, 20L, 0L, 0L, 20L, 0L, 0L, 0L, 20L, 0L, 0L, 20L, 0L, 20L, 0L, 0L, 0L, 0L, 0L, 0L); valid(expect); - strategy.assign(75, 1, 2, tags, -1); + strategy.assign(75, 1, 2, tags, -1, -1); expect = Lists.newArrayList(0L, 20L, 0L, 0L, 50L, 30L, 0L, 0L, 20L, 0L, 30L, 20L, 30L, 20L, 0L, 0L, 0L, 0L, 0L, 30L); valid(expect); - strategy.assign(33, 1, 2, tags, -1); + strategy.assign(33, 1, 2, tags, -1, -1); expect = Lists.newArrayList(0L, 33L, 0L, 0L, 50L, 30L, 14L, 13L, 20L, 13L, 30L, 20L, 30L, 20L, 13L, 0L, 0L, 0L, 0L, 30L); valid(expect); @@ -140,19 +140,19 @@ public void testAssign() { Uninterruptibles.sleepUninterruptibly(10, TimeUnit.MILLISECONDS); updateServerResource(list); - strategy.assign(33, 1, 1, tags, -1); + strategy.assign(33, 1, 1, tags, -1, -1); expect = Lists.newArrayList(0L, 7L, 0L, 7L, 0L, 7L, 0L, 6L, 0L, 6L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L); valid(expect); - strategy.assign(41, 1, 2, tags, -1); + strategy.assign(41, 1, 2, tags, -1, -1); expect = Lists.newArrayList(0L, 7L, 0L, 7L, 0L, 7L, 0L, 6L, 0L, 6L, 0L, 17L, 0L, 17L, 0L, 16L, 0L, 16L, 0L, 16L); valid(expect); - strategy.assign(23, 1, 1, tags, -1); + strategy.assign(23, 1, 1, tags, -1, -1); expect = Lists.newArrayList(5L, 7L, 
5L, 7L, 5L, 7L, 4L, 6L, 4L, 6L, 0L, 17L, 0L, 17L, 0L, 16L, 0L, 16L, 0L, 16L); valid(expect); - strategy.assign(11, 1, 3, tags, -1); + strategy.assign(11, 1, 3, tags, -1, -1); expect = Lists.newArrayList(5L, 7L, 5L, 7L, 5L, 7L, 4L, 13L, 4L, 13L, 7L, 17L, 6L, 17L, 6L, 16L, 0L, 16L, 0L, 16L); valid(expect); @@ -209,7 +209,7 @@ public void testAssignmentShuffleNodesNum() { * case1: user specify the illegal shuffle node num(<0) * it will use the default shuffle nodes num when having enough servers. */ - PartitionRangeAssignment pra = strategy.assign(100, 1, 1, serverTags, -1); + PartitionRangeAssignment pra = strategy.assign(100, 1, 1, serverTags, -1, -1); assertEquals( shuffleNodesMax, pra.getAssignments() @@ -224,7 +224,7 @@ public void testAssignmentShuffleNodesNum() { * case2: user specify the illegal shuffle node num(==0) * it will use the default shuffle nodes num when having enough servers. */ - pra = strategy.assign(100, 1, 1, serverTags, 0); + pra = strategy.assign(100, 1, 1, serverTags, 0, -1); assertEquals( shuffleNodesMax, pra.getAssignments() @@ -239,7 +239,7 @@ public void testAssignmentShuffleNodesNum() { * case3: user specify the illegal shuffle node num(>default max limitation) * it will use the default shuffle nodes num when having enough servers */ - pra = strategy.assign(100, 1, 1, serverTags, shuffleNodesMax + 10); + pra = strategy.assign(100, 1, 1, serverTags, shuffleNodesMax + 10, -1); assertEquals( shuffleNodesMax, pra.getAssignments() @@ -254,7 +254,7 @@ public void testAssignmentShuffleNodesNum() { * case4: user specify the legal shuffle node num, * it will use the customized shuffle nodes num when having enough servers */ - pra = strategy.assign(100, 1, 1, serverTags, shuffleNodesMax - 1); + pra = strategy.assign(100, 1, 1, serverTags, shuffleNodesMax - 1, -1); assertEquals( shuffleNodesMax - 1, pra.getAssignments() @@ -274,7 +274,7 @@ public void testAssignmentShuffleNodesNum() { clusterManager.add(new ServerNode("t2-" + i, "127.0.0." + i, 0, 0, 0, 20 - i, 0, serverTags, true)); } - pra = strategy.assign(100, 1, 1, serverTags, shuffleNodesMax); + pra = strategy.assign(100, 1, 1, serverTags, shuffleNodesMax, -1); assertEquals( shuffleNodesMax - 1, pra.getAssignments() @@ -291,8 +291,8 @@ public void testAssignmentShuffleNodesNum() { public void testAssignmentWithMustDiff() throws Exception { CoordinatorConf ssc = new CoordinatorConf(); ssc.setInteger(CoordinatorConf.COORDINATOR_SHUFFLE_NODES_MAX, shuffleNodesMax); - ssc.set(CoordinatorConf.COORDINATOR_ASSGINMENT_HOST_STRATEGY, - AbstractAssignmentStrategy.HostAssignmentStrategy.MUST_DIFF); + ssc.set(CoordinatorConf.COORDINATOR_ASSIGNMENT_HOST_STRATEGY, + AbstractAssignmentStrategy.HostAssignmentStrategyName.MUST_DIFF); SimpleClusterManager clusterManager = new SimpleClusterManager(ssc, new Configuration()); AssignmentStrategy strategy = new PartitionBalanceAssignmentStrategy(clusterManager, ssc); @@ -306,7 +306,7 @@ public void testAssignmentWithMustDiff() throws Exception { clusterManager.add(new ServerNode("t2-" + i, "127.0.0." 
+ i, 1, 0, 0, 20 - i, 0, serverTags, true)); } - PartitionRangeAssignment pra = strategy.assign(100, 1, 5, serverTags, -1); + PartitionRangeAssignment pra = strategy.assign(100, 1, 5, serverTags, -1, -1); pra.getAssignments().values().forEach((nodeList) -> { Map nodeMap = new HashMap<>(); nodeList.forEach((node) -> { @@ -316,7 +316,7 @@ public void testAssignmentWithMustDiff() throws Exception { }); }); - pra = strategy.assign(100, 1, 6, serverTags, -1); + pra = strategy.assign(100, 1, 6, serverTags, -1, -1); pra.getAssignments().values().forEach((nodeList) -> { Map nodeMap = new HashMap<>(); boolean hasSameHost = false; @@ -337,8 +337,8 @@ public void testAssignmentWithMustDiff() throws Exception { public void testAssignmentWithPreferDiff() throws Exception { CoordinatorConf ssc = new CoordinatorConf(); ssc.setInteger(CoordinatorConf.COORDINATOR_SHUFFLE_NODES_MAX, shuffleNodesMax); - ssc.set(CoordinatorConf.COORDINATOR_ASSGINMENT_HOST_STRATEGY, - AbstractAssignmentStrategy.HostAssignmentStrategy.PREFER_DIFF); + ssc.set(CoordinatorConf.COORDINATOR_ASSIGNMENT_HOST_STRATEGY, + AbstractAssignmentStrategy.HostAssignmentStrategyName.PREFER_DIFF); SimpleClusterManager clusterManager = new SimpleClusterManager(ssc, new Configuration()); AssignmentStrategy strategy = new PartitionBalanceAssignmentStrategy(clusterManager, ssc); Set serverTags = Sets.newHashSet("tag-1"); @@ -351,7 +351,7 @@ public void testAssignmentWithPreferDiff() throws Exception { clusterManager.add(new ServerNode("t2-" + i, "127.0.0." + i, 1, 0, 0, 20 - i, 0, serverTags, true)); } - PartitionRangeAssignment pra = strategy.assign(100, 1, 5, serverTags, -1); + PartitionRangeAssignment pra = strategy.assign(100, 1, 5, serverTags, -1, -1); pra.getAssignments().values().forEach((nodeList) -> { assertEquals(5, nodeList.size()); }); @@ -367,7 +367,7 @@ public void testAssignmentWithPreferDiff() throws Exception { 20 - i, 0, serverTags, true)); } strategy = new PartitionBalanceAssignmentStrategy(clusterManager, ssc); - pra = strategy.assign(100, 1, 3, serverTags, -1); + pra = strategy.assign(100, 1, 3, serverTags, -1, -1); pra.getAssignments().values().forEach((nodeList) -> { Map nodeMap = new HashMap<>(); nodeList.forEach((node) -> { @@ -382,8 +382,8 @@ public void testAssignmentWithPreferDiff() throws Exception { public void testAssignmentWithNone() throws Exception { CoordinatorConf ssc = new CoordinatorConf(); ssc.setInteger(CoordinatorConf.COORDINATOR_SHUFFLE_NODES_MAX, shuffleNodesMax); - ssc.set(CoordinatorConf.COORDINATOR_ASSGINMENT_HOST_STRATEGY, - AbstractAssignmentStrategy.HostAssignmentStrategy.NONE); + ssc.set(CoordinatorConf.COORDINATOR_ASSIGNMENT_HOST_STRATEGY, + AbstractAssignmentStrategy.HostAssignmentStrategyName.NONE); SimpleClusterManager clusterManager = new SimpleClusterManager(ssc, new Configuration()); AssignmentStrategy strategy = new PartitionBalanceAssignmentStrategy(clusterManager, ssc); Set serverTags = Sets.newHashSet("tag-1"); @@ -396,9 +396,37 @@ public void testAssignmentWithNone() throws Exception { clusterManager.add(new ServerNode("t2-" + i, "127.0.0." 
+ i, 1, 0, 0, 20 - i, 0, serverTags, true)); } - PartitionRangeAssignment pra = strategy.assign(100, 1, 5, serverTags, -1); + PartitionRangeAssignment pra = strategy.assign(100, 1, 5, serverTags, -1, -1); pra.getAssignments().values().forEach((nodeList) -> { assertEquals(5, nodeList.size()); }); } + + @Test + public void testWithContinuousSelectPartitionStrategy() throws Exception { + CoordinatorConf ssc = new CoordinatorConf(); + ssc.set(CoordinatorConf.COORDINATOR_SELECT_PARTITION_STRATEGY, + AbstractAssignmentStrategy.SelectPartitionStrategyName.CONTINUOUS); + ssc.setInteger(CoordinatorConf.COORDINATOR_SHUFFLE_NODES_MAX, shuffleNodesMax); + clusterManager = new SimpleClusterManager(ssc, new Configuration()); + strategy = new PartitionBalanceAssignmentStrategy(clusterManager, ssc); + List list = Lists.newArrayList(20L, 20L, 20L, 20L, 20L, 20L, 20L, 20L, 20L, 20L, + 20L, 20L, 20L, 20L, 20L, 20L, 20L, 20L, 20L, 20L); + updateServerResource(list); + strategy.assign(100, 1, 2, tags, 5, 20); + List expect = Lists.newArrayList(40L, 40L, 40L, 40L, 40L, 0L, 0L, 0L, 0L, 0L, + 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L); + valid(expect); + + strategy.assign(28, 1, 2, tags, 5, 20); + expect = Lists.newArrayList(40L, 40L, 40L, 40L, 40L, 11L, 12L, 12L, 11L, 10L, + 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L); + valid(expect); + + strategy.assign(29, 1, 2, tags, 5, 4); + expect = Lists.newArrayList(40L, 40L, 40L, 40L, 40L, 11L, 12L, 12L, 11L, 10L, + 11L, 12L, 12L, 12L, 11L, 0L, 0L, 0L, 0L, 0L); + valid(expect); + } + } diff --git a/docs/client_guide.md b/docs/client_guide.md index 61cd6852fe..71b9f5e1ed 100644 --- a/docs/client_guide.md +++ b/docs/client_guide.md @@ -56,15 +56,30 @@ After apply the patch and rebuild spark, add following configuration in spark co ### Support Spark AQE -To improve performance of AQE skew optimization, uniffle introduces the LOCAL_ORDER shuffle-data distribution mechanism -to filter the lots of data to reduce network bandwidth and shuffle-server local-disk pressure. - -It can be enabled by the following config - ```bash - # Default value is NORMAL, it will directly append to file when the memory data is flushed to external storage - spark.rss.client.shuffle.data.distribution.type LOCAL_ORDER - ``` - +To improve performance of AQE skew optimization, uniffle introduces the LOCAL_ORDER shuffle-data distribution mechanism +and the continuous partition assignment mechanism. + +1. The LOCAL_ORDER shuffle-data distribution mechanism filters out a large amount of data, reducing network bandwidth usage and shuffle-server local-disk pressure. + + It can be enabled by the following config + ```bash + # Default value is NORMAL, it will directly append to file when the memory data is flushed to external storage + spark.rss.client.shuffle.data.distribution.type LOCAL_ORDER + ``` + +2. The continuous partition assignment mechanism assigns consecutive partitions to the same ShuffleServer, reducing the number of getShuffleResult calls. + + It can be enabled by the following config + ```bash + # Default value is ROUND, which assigns partitions to ShuffleServers in a round-robin fashion + rss.coordinator.select.partition.strategy CONTINUOUS + + # Default value is false. The CONTINUOUS strategy relies on this configuration being enabled; it estimates how many consecutive partitions should be assigned together based on task concurrency + --conf spark.rss.estimate.task.concurrency.enabled=true + + # Default value is 1.0. A value between 0 and 1 that estimates how much of the resource range between spark.dynamicAllocation.minExecutors and spark.dynamicAllocation.maxExecutors is likely to be allocated + --conf spark.rss.estimate.task.concurrency.dynamic.factor=1.0 + ```
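For intuition, here is a worked sketch of the estimate those two options feed into. The exact computation lives in the Spark client; the interpolation below is an editorial assumption for illustration only, and the executor settings are the ones used by the integration test added in this patch.

```java
// Editorial sketch, not part of this patch: an assumed shape for the
// task-concurrency estimate under dynamic allocation.
public class ConcurrencyEstimateSketch {
  public static void main(String[] args) {
    int minExecutors = 3;       // spark.dynamicAllocation.minExecutors
    int maxExecutors = 5;       // spark.dynamicAllocation.maxExecutors
    double dynamicFactor = 1.0; // spark.rss.estimate.task.concurrency.dynamic.factor
    int executorCores = 3;      // spark.executor.cores
    int taskCpus = 1;           // spark.task.cpus (Spark's default)

    // Each executor can run cores / cpus-per-task tasks at once.
    int slotsPerExecutor = executorCores / taskCpus;
    // Assumed interpolation: the factor says how much of the [min, max]
    // executor range is expected to actually be allocated.
    double executors = minExecutors + dynamicFactor * (maxExecutors - minExecutors);
    int estimatedConcurrency = (int) (executors * slotsPerExecutor);
    System.out.println(estimatedConcurrency); // (3 + 1.0 * 2) * 3 = 15
  }
}
```

A CONTINUOUS coordinator can then use this hint to decide how many consecutive PartitionRanges to hand to each ShuffleServer.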
### Deploy MapReduce Client Plugin 1. Add client jar to the classpath of each NodeManager, e.g., /share/hadoop/mapreduce/ @@ -103,6 +118,8 @@ These configurations are shared by all types of clients. |.rss.client.io.compression.codec|lz4|The compression codec is used to compress the shuffle data. Default codec is `lz4`. Other options are`ZSTD` and `SNAPPY`.| |.rss.client.io.compression.zstd.level|3|The zstd compression level, the default level is 3| |.rss.client.shuffle.data.distribution.type|NORMAL|The type of partition shuffle data distribution, including normal and local_order. The default value is normal. Now this config is only valid in Spark3.x| +|.rss.estimate.task.concurrency.enabled|false|Whether to enable task concurrency estimation. Only works in Spark 3, and only takes effect when the coordinator's rss.coordinator.select.partition.strategy is CONTINUOUS; this feature can improve performance in AQE scenarios.| +|.rss.estimate.task.concurrency.dynamic.factor|1.0|Between 0 and 1, used to estimate task concurrency: how much of the resource range between spark.dynamicAllocation.minExecutors and spark.dynamicAllocation.maxExecutors is likely to be allocated. Only takes effect in Spark 3 when .rss.estimate.task.concurrency.enabled=true and the coordinator's rss.coordinator.select.partition.strategy is CONTINUOUS.| Notice: 1. `` should be `spark` or `mapreduce` diff --git a/docs/coordinator_guide.md b/docs/coordinator_guide.md index 13a2190683..90d20aea9c 100644 --- a/docs/coordinator_guide.md +++ b/docs/coordinator_guide.md @@ -101,6 +101,7 @@ This document will introduce how to deploy Uniffle coordinators. |rss.coordinator.remote.storage.io.sample.access.times|3|The number of times to read and write HDFS files| |rss.coordinator.startup-silent-period.enabled|false|Enable the startup-silent-period to reject the assignment requests for avoiding partial assignments. To avoid service interruption, this mechanism is disabled by default. Especially it's recommended to use in coordinator HA mode when restarting single coordinator.| |rss.coordinator.startup-silent-period.duration|20000|The waiting duration(ms) when conf of rss.coordinator.startup-silent-period.enabled is enabled.| +|rss.coordinator.select.partition.strategy|ROUND|There are two strategies for selecting partitions: ROUND and CONTINUOUS. ROUND assigns partitions to ShuffleServers in a round-robin fashion, while CONTINUOUS tries to assign consecutive partitions to the same ShuffleServer; this can improve performance in AQE scenarios.|
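To make the CONTINUOUS grouping concrete, the CoordinatorUtilsTest cases added in this patch call the grouping helper directly. A hedged usage sketch follows; the reading of the four arguments as total partition number, partitions per range, assigned server number, and estimated task concurrency is inferred from those test cases, not stated by the patch.

```java
// Inferred from the CoordinatorUtilsTest cases in this patch: 52 partitions at
// 2 partitions per range yield 26 PartitionRanges; with 5 assigned servers and
// an estimated task concurrency of 20, the helper returns 15 groups of
// consecutive ranges (eleven of size 2, four of size 1), so that each group
// can land on a single ShuffleServer.
List<List<PartitionRange>> rangesGroup = CoordinatorUtils.generateRangesGroup(52, 2, 5, 20);
```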
### AccessClusterLoadChecker settings |Property Name|Default| Description| diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java index 2688704e6f..f54e6050bb 100644 --- a/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java @@ -154,7 +154,7 @@ public void testTags() throws Exception { // Case1 : only set the single default shuffle version tag ShuffleAssignmentsInfo assignmentsInfo = shuffleWriteClient.getShuffleAssignments("app-1", - 1, 1, 1, Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION), 1); + 1, 1, 1, Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION), 1, -1); List assignedServerPorts = assignmentsInfo .getPartitionToServers() .values() @@ -169,7 +169,7 @@ public void testTags() throws Exception { // Case2: Set the single non-exist shuffle server tag try { assignmentsInfo = shuffleWriteClient.getShuffleAssignments("app-2", - 1, 1, 1, Sets.newHashSet("non-exist"), 1); + 1, 1, 1, Sets.newHashSet("non-exist"), 1, -1); fail(); } catch (Exception e) { assertTrue(e.getMessage().startsWith("Error happened when getShuffleAssignments with")); @@ -177,7 +177,7 @@ public void testTags() throws Exception { // Case3: Set the single fixed tag assignmentsInfo = shuffleWriteClient.getShuffleAssignments("app-3", - 1, 1, 1, Sets.newHashSet("fixed"), 1); + 1, 1, 1, Sets.newHashSet("fixed"), 1, -1); assignedServerPorts = assignmentsInfo .getPartitionToServers() .values() .stream() @@ -190,7 +190,7 @@ public void testTags() throws Exception { // case4: Set the multiple tags if exists assignmentsInfo = shuffleWriteClient.getShuffleAssignments("app-4", - 1, 1, 1, Sets.newHashSet("fixed", Constants.SHUFFLE_SERVER_VERSION), 1); + 1, 1, 1, Sets.newHashSet("fixed", Constants.SHUFFLE_SERVER_VERSION), 1, -1); assignedServerPorts = assignmentsInfo .getPartitionToServers() .values() @@ -204,7 +204,7 @@ public void testTags() throws Exception { // case5: Set the multiple tags if non-exist try { assignmentsInfo = shuffleWriteClient.getShuffleAssignments("app-5", - 1, 1, 1, Sets.newHashSet("fixed", "elastic", Constants.SHUFFLE_SERVER_VERSION), 1); + 1, 1, 1, Sets.newHashSet("fixed", "elastic", Constants.SHUFFLE_SERVER_VERSION), 1, -1); fail(); } catch (Exception e) { assertTrue(e.getMessage().startsWith("Error happened when getShuffleAssignments with")); diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorAssignmentTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorAssignmentTest.java index 455bd69ff0..234ae4cdba 100644 --- a/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorAssignmentTest.java +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorAssignmentTest.java @@ -91,7 +91,7 @@ public void testSilentPeriod() throws Exception { shuffleWriteClient.registerCoordinators(QUORUM); // Case1: Disable silent period - ShuffleAssignmentsInfo info = shuffleWriteClient.getShuffleAssignments("app1", 0, 10, 1, TAGS, -1); + ShuffleAssignmentsInfo info = shuffleWriteClient.getShuffleAssignments("app1", 0, 10, 1, TAGS, -1, -1); assertEquals(SHUFFLE_NODES_MAX,
info.getServerToPartitionRanges().keySet().size()); // Case2: Enable silent period mechanism, it should fallback to slave coordinator. @@ -101,7 +101,7 @@ public void testSilentPeriod() throws Exception { clusterManager.setStartTime(System.currentTimeMillis() - 1); if (clusterManager.getNodesNum() < 10) { - info = shuffleWriteClient.getShuffleAssignments("app1", 0, 10, 1, TAGS, -1); + info = shuffleWriteClient.getShuffleAssignments("app1", 0, 10, 1, TAGS, -1, -1); assertEquals(SHUFFLE_NODES_MAX, info.getServerToPartitionRanges().keySet().size()); } @@ -119,28 +119,28 @@ public void testAssignmentServerNodesNumber() throws Exception { * case1: user specify the illegal shuffle node num(<0) * it will use the default shuffle nodes num when having enough servers. */ - ShuffleAssignmentsInfo info = shuffleWriteClient.getShuffleAssignments("app1", 0, 10, 1, TAGS, -1); + ShuffleAssignmentsInfo info = shuffleWriteClient.getShuffleAssignments("app1", 0, 10, 1, TAGS, -1, -1); assertEquals(SHUFFLE_NODES_MAX, info.getServerToPartitionRanges().keySet().size()); /** * case2: user specify the illegal shuffle node num(==0) * it will use the default shuffle nodes num when having enough servers. */ - info = shuffleWriteClient.getShuffleAssignments("app1", 0, 10, 1, TAGS, 0); + info = shuffleWriteClient.getShuffleAssignments("app1", 0, 10, 1, TAGS, 0, -1); assertEquals(SHUFFLE_NODES_MAX, info.getServerToPartitionRanges().keySet().size()); /** * case3: user specify the illegal shuffle node num(>default max limitation) * it will use the default shuffle nodes num when having enough servers */ - info = shuffleWriteClient.getShuffleAssignments("app1", 0, 10, 1, TAGS, SERVER_NUM + 10); + info = shuffleWriteClient.getShuffleAssignments("app1", 0, 10, 1, TAGS, SERVER_NUM + 10, -1); assertEquals(SHUFFLE_NODES_MAX, info.getServerToPartitionRanges().keySet().size()); /** * case4: user specify the legal shuffle node num, * it will use the customized shuffle nodes num when having enough servers */ - info = shuffleWriteClient.getShuffleAssignments("app1", 0, 10, 1, TAGS, SERVER_NUM - 1); + info = shuffleWriteClient.getShuffleAssignments("app1", 0, 10, 1, TAGS, SERVER_NUM - 1, -1); assertEquals(SHUFFLE_NODES_MAX - 1, info.getServerToPartitionRanges().keySet().size()); } } diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java index 0c98537863..b9e76f70d8 100644 --- a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java @@ -325,7 +325,7 @@ public void testRetryAssgin() throws Throwable { response = RetryUtils.retry(() -> { int currentTryTime = tryTime.incrementAndGet(); ShuffleAssignmentsInfo shuffleAssignments = shuffleWriteClientImpl.getShuffleAssignments(appId, - 1, 1, 1, Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION), 1); + 1, 1, 1, Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION), 1, -1); Map> serverToPartitionRanges = shuffleAssignments.getServerToPartitionRanges(); diff --git a/integration-test/spark3/src/test/java/org/apache/uniffle/test/ContinuousSelectPartitionStrategyTest.java b/integration-test/spark3/src/test/java/org/apache/uniffle/test/ContinuousSelectPartitionStrategyTest.java new file mode 100644 index 0000000000..02a610128d --- /dev/null +++ 
b/integration-test/spark3/src/test/java/org/apache/uniffle/test/ContinuousSelectPartitionStrategyTest.java @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.test; + +import java.io.File; +import java.nio.file.Files; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import org.apache.spark.SparkConf; +import org.apache.spark.shuffle.RssSparkConfig; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec; +import org.apache.spark.sql.execution.joins.SortMergeJoinExec; +import org.apache.spark.sql.functions; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import org.apache.uniffle.coordinator.AbstractAssignmentStrategy; +import org.apache.uniffle.coordinator.CoordinatorConf; +import org.apache.uniffle.server.MockedGrpcServer; +import org.apache.uniffle.server.MockedShuffleServerGrpcService; +import org.apache.uniffle.server.ShuffleServer; +import org.apache.uniffle.server.ShuffleServerConf; +import org.apache.uniffle.storage.util.StorageType; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ContinuousSelectPartitionStrategyTest extends SparkIntegrationTestBase { + + private static final int replicateWrite = 3; + private static final int replicateRead = 2; + + @BeforeAll + public static void setupServers() throws Exception { + CoordinatorConf coordinatorConf = getCoordinatorConf(); + Map dynamicConf = Maps.newHashMap(); + dynamicConf.put(CoordinatorConf.COORDINATOR_REMOTE_STORAGE_PATH.key(), HDFS_URI + "rss/test"); + dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE_HDFS.name()); + + coordinatorConf.set(CoordinatorConf.COORDINATOR_SELECT_PARTITION_STRATEGY, + AbstractAssignmentStrategy.SelectPartitionStrategyName.CONTINUOUS); + addDynamicConf(coordinatorConf, dynamicConf); + createCoordinatorServer(coordinatorConf); + // Create multi shuffle servers + createShuffleServers(); + startServers(); + } + + private static void createShuffleServers() throws Exception { + for (int i = 0; i < 3; i++) { + // Copy from IntegrationTestBase#getShuffleServerConf + File dataFolder = Files.createTempDirectory("rssdata" + i).toFile(); + ShuffleServerConf serverConf = new ShuffleServerConf(); + dataFolder.deleteOnExit(); + serverConf.setInteger("rss.rpc.server.port", SHUFFLE_SERVER_PORT + i); + 
serverConf.setString("rss.storage.type", StorageType.MEMORY_LOCALFILE_HDFS.name()); + serverConf.setString("rss.storage.basePath", dataFolder.getAbsolutePath()); + serverConf.setString("rss.server.buffer.capacity", String.valueOf(671088640 - i)); + serverConf.setString("rss.server.memory.shuffle.highWaterMark", "50.0"); + serverConf.setString("rss.server.memory.shuffle.lowWaterMark", "0.0"); + serverConf.setString("rss.server.read.buffer.capacity", "335544320"); + serverConf.setString("rss.coordinator.quorum", COORDINATOR_QUORUM); + serverConf.setString("rss.server.heartbeat.delay", "1000"); + serverConf.setString("rss.server.heartbeat.interval", "1000"); + serverConf.setInteger("rss.jetty.http.port", 18080 + i); + serverConf.setInteger("rss.jetty.corePool.size", 64); + serverConf.setInteger("rss.rpc.executor.size", 10); + serverConf.setString("rss.server.hadoop.dfs.replication", "2"); + serverConf.setLong("rss.server.disk.capacity", 10L * 1024L * 1024L * 1024L); + serverConf.setBoolean("rss.server.health.check.enable", false); + createMockedShuffleServer(serverConf); + } + enableRecordGetShuffleResult(); + } + + private static void enableRecordGetShuffleResult() { + for (ShuffleServer shuffleServer : shuffleServers) { + ((MockedGrpcServer) shuffleServer.getServer()).getService() + .enableRecordGetShuffleResult(); + } + } + + @Override + public void updateCommonSparkConf(SparkConf sparkConf) { + sparkConf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "true"); + sparkConf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD().key(), "-1"); + sparkConf.set(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM().key(), "1"); + sparkConf.set(SQLConf.SHUFFLE_PARTITIONS().key(), "100"); + sparkConf.set(SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD().key(), "800"); + sparkConf.set(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES().key(), "800"); + sparkConf.set("spark.dynamicAllocation.enabled", "true"); + sparkConf.set("spark.dynamicAllocation.maxExecutors", "5"); + sparkConf.set("spark.dynamicAllocation.minExecutors", "3"); + sparkConf.set("spark.executor.cores", "3"); + } + + @Override + public void updateSparkConfCustomer(SparkConf sparkConf) { + sparkConf.set(RssSparkConfig.RSS_STORAGE_TYPE.key(), "HDFS"); + sparkConf.set(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), HDFS_URI + "rss/test"); + } + + @Override + public void updateSparkConfWithRss(SparkConf sparkConf) { + super.updateSparkConfWithRss(sparkConf); + // Add multi replica conf + sparkConf.set(RssSparkConfig.RSS_DATA_REPLICA.key(), String.valueOf(replicateWrite)); + sparkConf.set(RssSparkConfig.RSS_DATA_REPLICA_WRITE.key(), String.valueOf(replicateWrite)); + sparkConf.set(RssSparkConfig.RSS_DATA_REPLICA_READ.key(), String.valueOf(replicateRead)); + sparkConf.set(RssSparkConfig.RSS_ESTIMATE_TASK_CONCURRENCY_ENABLED, true); + sparkConf.set("spark.shuffle.manager", + "org.apache.uniffle.test.GetShuffleReportForMultiPartTest$RssShuffleManagerWrapper"); + } + + @Test + public void resultCompareTest() throws Exception { + run(); + } + + @Override + Map runTest(SparkSession spark, String fileName) throws Exception { + Thread.sleep(4000); + Map map = Maps.newHashMap(); + Dataset df2 = spark.range(0, 1000, 1, 10) + .select(functions.when(functions.col("id").$less(250), 249) + .otherwise(functions.col("id")).as("key2"), functions.col("id").as("value2")); + Dataset df1 = spark.range(0, 1000, 1, 10) + .select(functions.when(functions.col("id").$less(250), 249) + .when(functions.col("id").$greater(750), 1000) + .otherwise(functions.col("id")).as("key1"), 
functions.col("id").as("value2")); + Dataset df3 = df1.join(df2, df1.col("key1").equalTo(df2.col("key2"))); + + List result = Lists.newArrayList(); + assertTrue(df3.queryExecution().executedPlan().toString().startsWith("AdaptiveSparkPlan isFinalPlan=false")); + df3.collectAsList().forEach(row -> { + result.add(row.json()); + }); + assertTrue(df3.queryExecution().executedPlan().toString().startsWith("AdaptiveSparkPlan isFinalPlan=true")); + AdaptiveSparkPlanExec plan = (AdaptiveSparkPlanExec) df3.queryExecution().executedPlan(); + SortMergeJoinExec joinExec = (SortMergeJoinExec) plan.executedPlan().children().iterator().next(); + assertTrue(joinExec.isSkewJoin()); + result.sort(new Comparator() { + @Override + public int compare(String o1, String o2) { + return o1.compareTo(o2); + } + }); + int i = 0; + for (String str : result) { + map.put(i, str); + i++; + } + SparkConf conf = spark.sparkContext().conf(); + if (conf.get("spark.shuffle.manager", "") + .equals("org.apache.uniffle.test.GetShuffleReportForMultiPartTest$RssShuffleManagerWrapper")) { + GetShuffleReportForMultiPartTest.RssShuffleManagerWrapper mockRssShuffleManager = + (GetShuffleReportForMultiPartTest.RssShuffleManagerWrapper) spark.sparkContext().env().shuffleManager(); + int expectRequestNum = mockRssShuffleManager.getShuffleIdToPartitionNum().values().stream() + .mapToInt(x -> x.get()).sum(); + // Validate getShuffleResultForMultiPart is correct before return result + validateRequestCount(spark.sparkContext().applicationId(), expectRequestNum * replicateRead); + } + return map; + } + + public void validateRequestCount(String appId, int expectRequestNum) { + for (ShuffleServer shuffleServer : shuffleServers) { + MockedShuffleServerGrpcService service = ((MockedGrpcServer) shuffleServer.getServer()).getService(); + Map> serviceRequestCount = service.getShuffleIdToPartitionRequest(); + int requestNum = serviceRequestCount.entrySet().stream().filter(x -> x.getKey().startsWith(appId)) + .flatMap(x -> x.getValue().values().stream()).mapToInt(AtomicInteger::get).sum(); + expectRequestNum -= requestNum; + } + assertEquals(0, expectRequestNum); + } + +} diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java index 53e0922c4a..be44c60c76 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java @@ -159,7 +159,8 @@ public RssProtos.GetShuffleAssignmentsResponse doGetShuffleAssignments( int partitionNumPerRange, int dataReplica, Set requiredTags, - int assignmentShuffleServerNumber) { + int assignmentShuffleServerNumber, + int estimateTaskConcurrency) { RssProtos.GetShuffleServerRequest getServerRequest = RssProtos.GetShuffleServerRequest.newBuilder() .setApplicationId(appId) @@ -169,6 +170,7 @@ public RssProtos.GetShuffleAssignmentsResponse doGetShuffleAssignments( .setDataReplica(dataReplica) .addAllRequireTags(requiredTags) .setAssignmentShuffleServerNumber(assignmentShuffleServerNumber) + .setEstimateTaskConcurrency(estimateTaskConcurrency) .build(); return blockingStub.getShuffleAssignments(getServerRequest); @@ -229,7 +231,8 @@ public RssGetShuffleAssignmentsResponse getShuffleAssignments(RssGetShuffleAssig request.getPartitionNumPerRange(), request.getDataReplica(), request.getRequiredTags(), - request.getAssignmentShuffleServerNumber()); + 
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java index d0971cbf84..9be636abd5 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java @@ -30,15 +30,17 @@ public class RssGetShuffleAssignmentsRequest { private int dataReplica; private Set requiredTags; private int assignmentShuffleServerNumber; + private int estimateTaskConcurrency; @VisibleForTesting public RssGetShuffleAssignmentsRequest(String appId, int shuffleId, int partitionNum, int partitionNumPerRange, int dataReplica, Set requiredTags) { - this(appId, shuffleId, partitionNum, partitionNumPerRange, dataReplica, requiredTags, -1); + this(appId, shuffleId, partitionNum, partitionNumPerRange, dataReplica, requiredTags, -1, -1); } public RssGetShuffleAssignmentsRequest(String appId, int shuffleId, int partitionNum, - int partitionNumPerRange, int dataReplica, Set requiredTags, int assignmentShuffleServerNumber) { + int partitionNumPerRange, int dataReplica, Set requiredTags, int assignmentShuffleServerNumber, + int estimateTaskConcurrency) { this.appId = appId; this.shuffleId = shuffleId; this.partitionNum = partitionNum; @@ -46,6 +48,7 @@ public RssGetShuffleAssignmentsRequest(String appId, int shuffleId, int partitio this.dataReplica = dataReplica; this.requiredTags = requiredTags; this.assignmentShuffleServerNumber = assignmentShuffleServerNumber; + this.estimateTaskConcurrency = estimateTaskConcurrency; } public String getAppId() { @@ -75,4 +78,8 @@ public Set getRequiredTags() { public int getAssignmentShuffleServerNumber() { return assignmentShuffleServerNumber; } + + public int getEstimateTaskConcurrency() { + return estimateTaskConcurrency; + } } diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto index 4a4077cfd9..68b5c02c06 100644 --- a/proto/src/main/proto/Rss.proto +++ b/proto/src/main/proto/Rss.proto @@ -328,6 +328,7 @@ message GetShuffleServerRequest { int32 dataReplica = 8; repeated string requireTags = 9; int32 assignmentShuffleServerNumber = 10; + int32 estimateTaskConcurrency = 11; } message PartitionRangeAssignment {