@@ -210,7 +210,8 @@ public Thread newThread(Runnable r) {
numReduceTasks,
1,
Sets.newHashSet(assignmentTags),
requiredAssignmentShuffleServersNum
requiredAssignmentShuffleServersNum,
-1
);

Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges =
@@ -323,7 +323,8 @@ public void reportShuffleResult(Map<Integer, List<ShuffleServerInfo>> partitionT

@Override
public ShuffleAssignmentsInfo getShuffleAssignments(String appId, int shuffleId, int partitionNum,
int partitionNumPerRange, Set<String> requiredTags, int assignmentShuffleServerNumber) {
int partitionNumPerRange, Set<String> requiredTags, int assignmentShuffleServerNumber,
int estimateTaskConcurrency) {
return null;
}

@@ -420,7 +420,8 @@ public void reportShuffleResult(Map<Integer, List<ShuffleServerInfo>> partitionT

@Override
public ShuffleAssignmentsInfo getShuffleAssignments(String appId, int shuffleId, int partitionNum,
int partitionNumPerRange, Set<String> requiredTags, int assignmentShuffleServerNumber) {
int partitionNumPerRange, Set<String> requiredTags, int assignmentShuffleServerNumber,
int estimateTaskConcurrency) {
return null;
}

@@ -241,6 +241,20 @@ public class RssSparkConfig {
.doc("Coordinator quorum"))
.createWithDefault("");

public static final ConfigEntry<Double> RSS_ESTIMATE_TASK_CONCURRENCY_DYNAMIC_FACTOR = createDoubleBuilder(
new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_ESTIMATE_TASK_CONCURRENCY_DYNAMIC_FACTOR)
.doc("Between 0 and 1, used to estimate task concurrency, how likely is this part of the resource between"
+ " spark.dynamicAllocation.minExecutors and spark.dynamicAllocation.maxExecutors"
+ " to be allocated"))
.createWithDefault(RssClientConfig.RSS_ESTIMATE_TASK_CONCURRENCY_DYNAMIC_FACTOR_DEFAULT_VALUE);

public static final ConfigEntry<Boolean> RSS_ESTIMATE_TASK_CONCURRENCY_ENABLED = createBooleanBuilder(
new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_ESTIMATE_TASK_CONCURRENCY_ENABLED)
.doc("When the Coordinator enables rss.coordinator.select.partition.strategy,"
+ " this configuration item is valid and is used to estimate how many consecutive"
+ " PartitionRanges should be allocated to a ShuffleServer"))
.createWithDefault(RssClientConfig.RSS_ESTIMATE_TASK_CONCURRENCY_DEFAULT_VALUE);

public static final Set<String> RSS_MANDATORY_CLUSTER_CONF =
ImmutableSet.of(RSS_STORAGE_TYPE.key(), RSS_REMOTE_STORAGE_PATH.key());

@@ -34,6 +34,7 @@
import org.apache.uniffle.client.api.CoordinatorClient;
import org.apache.uniffle.client.factory.CoordinatorClientFactory;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.Constants;

public class RssSparkShuffleUtils {
@@ -135,4 +136,29 @@ public static Set<String> getAssignmentTags(SparkConf sparkConf) {
assignmentTags.add(Constants.SHUFFLE_SERVER_VERSION);
return assignmentTags;
}

public static int estimateTaskConcurrency(SparkConf sparkConf) {
int taskConcurrency;
double dynamicAllocationFactor = sparkConf.get(RssSparkConfig.RSS_ESTIMATE_TASK_CONCURRENCY_DYNAMIC_FACTOR);
if (dynamicAllocationFactor > 1 || dynamicAllocationFactor < 0) {
throw new RssException("dynamicAllocationFactor is not valid: " + dynamicAllocationFactor);
}
int executorCores = sparkConf.getInt(Constants.SPARK_EXECUTOR_CORES, Constants.SPARK_EXECUTOR_CORES_DEFAULT_VALUE);
int taskCpus = sparkConf.getInt(Constants.SPARK_TASK_CPUS, Constants.SPARK_TASK_CPUS_DEFAULT_VALUE);
int taskConcurrencyPerExecutor = Math.floorDiv(executorCores, taskCpus);
if (!sparkConf.getBoolean(Constants.SPARK_DYNAMIC_ENABLED, false)) {
int executorInstances = sparkConf.getInt(Constants.SPARK_EXECUTOR_INSTANTS,
Constants.SPARK_EXECUTOR_INSTANTS_DEFAULT_VALUE);
taskConcurrency = executorInstances > 0 ? executorInstances * taskConcurrencyPerExecutor : 0;
} else {
// spark.dynamicAllocation.maxExecutors defaults to infinity, so cap it at SPARK_MAX_DYNAMIC_EXECUTOR_LIMIT
int maxExecutors = Math.min(sparkConf.getInt(Constants.SPARK_MAX_DYNAMIC_EXECUTOR,
Constants.SPARK_DYNAMIC_EXECUTOR_DEFAULT_VALUE), Constants.SPARK_MAX_DYNAMIC_EXECUTOR_LIMIT);
int minExecutors = sparkConf.getInt(Constants.SPARK_MIN_DYNAMIC_EXECUTOR,
Constants.SPARK_DYNAMIC_EXECUTOR_DEFAULT_VALUE);
taskConcurrency = (int)((maxExecutors - minExecutors) * dynamicAllocationFactor + minExecutors)
* taskConcurrencyPerExecutor;
}
return taskConcurrency;
}
}
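As a sanity check of the dynamic-allocation branch above, the estimate is ((maxExecutors - minExecutors) * factor + minExecutors) * floorDiv(executorCores, taskCpus); the numbers below are assumed and mirror the unit test that follows:

// Assumed settings: minExecutors=100, maxExecutors=200, factor=0.3, executor cores=2, task cpus=1.
int tasksPerExecutor = Math.floorDiv(2, 1);                         // 2 concurrent tasks per executor
int estimate = (int) ((200 - 100) * 0.3 + 100) * tasksPerExecutor;  // (30 + 100) * 2 = 260
// Without dynamic allocation the estimate is simply executorInstances * tasksPerExecutor.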
@@ -158,4 +158,34 @@ public void applyDynamicClientConfTest() {
assertEquals(Integer.toString(RssClientConfig.RSS_CLIENT_RETRY_MAX_DEFAULT_VALUE),
conf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX.key()));
}

@Test
public void testEstimateTaskConcurrency() {
SparkConf sparkConf = new SparkConf();
sparkConf.set(Constants.SPARK_DYNAMIC_ENABLED, "true");
sparkConf.set(Constants.SPARK_MAX_DYNAMIC_EXECUTOR, "200");
sparkConf.set(Constants.SPARK_MIN_DYNAMIC_EXECUTOR, "100");
sparkConf.set(RssSparkConfig.RSS_ESTIMATE_TASK_CONCURRENCY_ENABLED, true);
sparkConf.set(Constants.SPARK_EXECUTOR_CORES, "2");
int taskConcurrency;

sparkConf.set(RssSparkConfig.RSS_ESTIMATE_TASK_CONCURRENCY_DYNAMIC_FACTOR, 1.0);
taskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
assertEquals(400, taskConcurrency);

sparkConf.set(RssSparkConfig.RSS_ESTIMATE_TASK_CONCURRENCY_DYNAMIC_FACTOR, 0.3);
taskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
assertEquals(260, taskConcurrency);

sparkConf.set(Constants.SPARK_TASK_CPUS, "2");
sparkConf.set(RssSparkConfig.RSS_ESTIMATE_TASK_CONCURRENCY_DYNAMIC_FACTOR, 0.3);
taskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
assertEquals(130, taskConcurrency);

sparkConf.set(Constants.SPARK_DYNAMIC_ENABLED, "false");
sparkConf.set(Constants.SPARK_EXECUTOR_INSTANTS, "70");
sparkConf.set(RssSparkConfig.RSS_ESTIMATE_TASK_CONCURRENCY_DYNAMIC_FACTOR, 1.0);
taskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
assertEquals(70, taskConcurrency);
}
}
@@ -233,7 +233,7 @@ public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, int numMaps, Shuff
partitionToServers = RetryUtils.retry(() -> {
ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments(
appId, shuffleId, dependency.partitioner().numPartitions(),
partitionNumPerRange, assignmentTags, requiredShuffleServerNumber);
partitionNumPerRange, assignmentTags, requiredShuffleServerNumber, -1);
registerShuffleServers(appId, shuffleId, response.getServerToPartitionRanges(), remoteStorage);
return response.getPartitionToServers();
}, retryInterval, retryTimes);
@@ -282,6 +282,9 @@ public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, ShuffleDependency<
// retryInterval must be bigger than `rss.server.heartbeat.timeout`, otherwise it may return the same result
long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
boolean enabledEstimateTaskConcurrency = sparkConf.get(RssSparkConfig.RSS_ESTIMATE_TASK_CONCURRENCY_ENABLED);
int estimateTaskConcurrency = enabledEstimateTaskConcurrency
? RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf) : -1;
Map<Integer, List<ShuffleServerInfo>> partitionToServers;
try {
partitionToServers = RetryUtils.retry(() -> {
@@ -291,7 +294,8 @@ public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, ShuffleDependency<
dependency.partitioner().numPartitions(),
1,
assignmentTags,
requiredShuffleServerNumber);
requiredShuffleServerNumber,
estimateTaskConcurrency);
registerShuffleServers(id.get(), shuffleId, response.getServerToPartitionRanges(), remoteStorage);
return response.getPartitionToServers();
}, retryInterval, retryTimes);
@@ -62,7 +62,8 @@ void reportShuffleResult(
int bitmapNum);

ShuffleAssignmentsInfo getShuffleAssignments(String appId, int shuffleId, int partitionNum,
int partitionNumPerRange, Set<String> requiredTags, int assignmentShuffleServerNumber);
int partitionNumPerRange, Set<String> requiredTags, int assignmentShuffleServerNumber,
int estimateTaskConcurrency);

Roaring64NavigableMap getShuffleResult(String clientType, Set<ShuffleServerInfo> shuffleServerInfoSet,
String appId, int shuffleId, int partitionId);
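Callers that do not compute an estimate (the Spark2 RssShuffleManager and the MR client above) pass -1 for the new parameter; presumably the coordinator treats a non-positive value as "no estimate" and falls back to its previous assignment behaviour.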
@@ -394,9 +394,11 @@ public RemoteStorageInfo fetchRemoteStorage(String appId) {

@Override
public ShuffleAssignmentsInfo getShuffleAssignments(String appId, int shuffleId, int partitionNum,
int partitionNumPerRange, Set<String> requiredTags, int assignmentShuffleServerNumber) {
int partitionNumPerRange, Set<String> requiredTags, int assignmentShuffleServerNumber,
int estimateTaskConcurrency) {
RssGetShuffleAssignmentsRequest request = new RssGetShuffleAssignmentsRequest(
appId, shuffleId, partitionNum, partitionNumPerRange, replica, requiredTags, assignmentShuffleServerNumber);
appId, shuffleId, partitionNum, partitionNumPerRange, replica, requiredTags,
assignmentShuffleServerNumber, estimateTaskConcurrency);

RssGetShuffleAssignmentsResponse response = new RssGetShuffleAssignmentsResponse(ResponseStatusCode.INTERNAL_ERROR);
for (CoordinatorClient coordinatorClient : coordinatorClients) {
@@ -74,4 +74,11 @@ public class RssClientConfig {
public static final String RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER =
"rss.client.assignment.shuffle.nodes.max";
public static final int RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE = -1;

public static final String RSS_ESTIMATE_TASK_CONCURRENCY_DYNAMIC_FACTOR =
"rss.estimate.task.concurrency.dynamic.factor";
public static final double RSS_ESTIMATE_TASK_CONCURRENCY_DYNAMIC_FACTOR_DEFAULT_VALUE = 1.0;

public static final String RSS_ESTIMATE_TASK_CONCURRENCY_ENABLED = "rss.estimate.task.concurrency.enabled";
public static final boolean RSS_ESTIMATE_TASK_CONCURRENCY_DEFAULT_VALUE = false;
}
common/src/main/java/org/apache/uniffle/common/util/Constants.java (17 additions, 0 deletions)
@@ -45,4 +45,21 @@ public class Constants {
RSS_CLIENT_CONF_COMMON_PREFIX + CONF_REMOTE_STORAGE_PATH;

public static final String ACCESS_INFO_REQUIRED_SHUFFLE_NODES_NUM = "access_info_required_shuffle_nodes_num";
public static final String SPARK_DYNAMIC_ENABLED = "spark.dynamicAllocation.enabled";
public static final String SPARK_MAX_DYNAMIC_EXECUTOR = "spark.dynamicAllocation.maxExecutors";
public static final String SPARK_MIN_DYNAMIC_EXECUTOR = "spark.dynamicAllocation.minExecutors";
public static final int SPARK_DYNAMIC_EXECUTOR_DEFAULT_VALUE = 0;
public static final String SPARK_EXECUTOR_INSTANTS = "spark.executor.instances";
public static final int SPARK_EXECUTOR_INSTANTS_DEFAULT_VALUE = -1;
public static final String SPARK_EXECUTOR_CORES = "spark.executor.cores";
public static final int SPARK_EXECUTOR_CORES_DEFAULT_VALUE = 1;
public static final String SPARK_TASK_CPUS = "spark.task.cpus";
public static final int SPARK_TASK_CPUS_DEFAULT_VALUE = 1;
public static final int SPARK_MAX_DYNAMIC_EXECUTOR_LIMIT = 10000;

public static final String MR_MAPS = "mapreduce.job.maps";
public static final String MR_REDUCES = "mapreduce.job.reduces";
public static final String MR_MAP_LIMIT = "mapreduce.job.running.map.limit";
public static final String MR_REDUCE_LIMIT = "mapreduce.job.running.reduce.limit";
public static final String MR_SLOW_START = "mapreduce.job.reduce.slowstart.completedmaps";
}
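The MR_* keys added here suggest the MapReduce client derives a comparable concurrency estimate from mapreduce.job.running.map.limit, mapreduce.job.running.reduce.limit and the slow-start fraction, though that code is not part of this excerpt.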
@@ -17,69 +17,69 @@

package org.apache.uniffle.coordinator;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.SortedMap;

import static org.apache.uniffle.coordinator.CoordinatorConf.COORDINATOR_ASSGINMENT_HOST_STRATEGY;
import org.apache.uniffle.common.PartitionRange;

import static org.apache.uniffle.coordinator.CoordinatorConf.COORDINATOR_ASSIGNMENT_HOST_STRATEGY;
import static org.apache.uniffle.coordinator.CoordinatorConf.COORDINATOR_SELECT_PARTITION_STRATEGY;

public abstract class AbstractAssignmentStrategy implements AssignmentStrategy {
protected final CoordinatorConf conf;
private final HostAssignmentStrategy assignmentHostStrategy;
private HostAssignmentStrategy hostAssignmentStrategy;
private SelectPartitionStrategy selectPartitionStrategy;

public AbstractAssignmentStrategy(CoordinatorConf conf) {
this.conf = conf;
assignmentHostStrategy = conf.get(COORDINATOR_ASSGINMENT_HOST_STRATEGY);
loadHostAssignmentStrategy();
loadSelectPartitionStrategy();
}

protected List<ServerNode> getCandidateNodes(List<ServerNode> allNodes, int expectNum) {
switch (assignmentHostStrategy) {
case MUST_DIFF: return getCandidateNodesWithDiffHost(allNodes, expectNum);
case PREFER_DIFF: return tryGetCandidateNodesWithDiffHost(allNodes, expectNum);
case NONE: return allNodes.subList(0, expectNum);
default: throw new RuntimeException("Unsupported host assignment strategy:" + assignmentHostStrategy);
private void loadSelectPartitionStrategy() {
SelectPartitionStrategyName selectPartitionStrategyName =
conf.get(COORDINATOR_SELECT_PARTITION_STRATEGY);
if (selectPartitionStrategyName == SelectPartitionStrategyName.ROUND) {
selectPartitionStrategy = new RoundSelectPartitionStrategy();
} else if (selectPartitionStrategyName == SelectPartitionStrategyName.CONTINUOUS) {
selectPartitionStrategy = new ContinuousSelectPartitionStrategy();
} else {
throw new RuntimeException("Unsupported partition assignment strategy:" + selectPartitionStrategyName);
}
}

protected List<ServerNode> tryGetCandidateNodesWithDiffHost(List<ServerNode> allNodes, int expectNum) {
List<ServerNode> candidatesNodes = getCandidateNodesWithDiffHost(allNodes, expectNum);
Set<ServerNode> candidatesNodeSet = candidatesNodes.stream().collect(Collectors.toSet());
if (candidatesNodes.size() < expectNum) {
for (ServerNode node : allNodes) {
if (candidatesNodeSet.contains(node)) {
continue;
}
candidatesNodes.add(node);
if (candidatesNodes.size() >= expectNum) {
break;
}
}
private void loadHostAssignmentStrategy() {
HostAssignmentStrategyName hostAssignmentStrategyName = conf.get(COORDINATOR_ASSIGNMENT_HOST_STRATEGY);
if (hostAssignmentStrategyName == HostAssignmentStrategyName.MUST_DIFF) {
hostAssignmentStrategy = new MustDiffHostAssignmentStrategy();
} else if (hostAssignmentStrategyName == HostAssignmentStrategyName.PREFER_DIFF) {
hostAssignmentStrategy = new PerferDiffHostAssignmentStrategy();
} else if (hostAssignmentStrategyName == HostAssignmentStrategyName.NONE) {
hostAssignmentStrategy = new BasicHostAssignmentStrategy();
} else {
throw new RuntimeException("Unsupported partition assignment strategy:" + hostAssignmentStrategyName);
}
return candidatesNodes;
}

protected List<ServerNode> getCandidateNodesWithDiffHost(List<ServerNode> allNodes, int expectNum) {
List<ServerNode> candidatesNodes = new ArrayList<>();
Set<String> hostIpCandidate = new HashSet<>();
for (ServerNode node : allNodes) {
if (hostIpCandidate.contains(node.getIp())) {
continue;
}
hostIpCandidate.add(node.getIp());
candidatesNodes.add(node);
if (candidatesNodes.size() >= expectNum) {
break;
}
}
return candidatesNodes;
protected List<ServerNode> getCandidateNodes(List<ServerNode> allNodes, int expectNum) {
return hostAssignmentStrategy.assign(allNodes, expectNum);
}

protected SortedMap<PartitionRange, List<ServerNode>> getPartitionAssignment(
int totalPartitionNum, int partitionNumPerRange, int replica, List<ServerNode> candidatesNodes,
int estimateTaskConcurrency) {
return selectPartitionStrategy.assign(totalPartitionNum, partitionNumPerRange, replica,
candidatesNodes, estimateTaskConcurrency);
}

public enum HostAssignmentStrategy {
public enum HostAssignmentStrategyName {
MUST_DIFF,
PREFER_DIFF,
NONE
}

public enum SelectPartitionStrategyName {
ROUND,
CONTINUOUS
}
}
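The HostAssignmentStrategy and SelectPartitionStrategy types referenced above are defined elsewhere in the PR; judging only from the call sites here (hostAssignmentStrategy.assign(allNodes, expectNum) and selectPartitionStrategy.assign(totalPartitionNum, partitionNumPerRange, replica, candidatesNodes, estimateTaskConcurrency)), their contracts look roughly like this inferred sketch:

// Inferred sketch only; the real definitions are not shown in this excerpt.
interface HostAssignmentStrategy {
  List<ServerNode> assign(List<ServerNode> allNodes, int expectNum);
}

interface SelectPartitionStrategy {
  SortedMap<PartitionRange, List<ServerNode>> assign(
      int totalPartitionNum, int partitionNumPerRange, int replica,
      List<ServerNode> candidateNodes, int estimateTaskConcurrency);
}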
@@ -22,6 +22,6 @@
public interface AssignmentStrategy {

PartitionRangeAssignment assign(int totalPartitionNum, int partitionNumPerRange,
int replica, Set<String> requiredTags, int requiredShuffleServerNumber);
int replica, Set<String> requiredTags, int requiredShuffleServerNumber, int estimateTaskConcurrency);

}
@@ -18,11 +18,9 @@
package org.apache.uniffle.coordinator;

import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -42,8 +40,7 @@ public BasicAssignmentStrategy(ClusterManager clusterManager, CoordinatorConf co

@Override
public PartitionRangeAssignment assign(int totalPartitionNum, int partitionNumPerRange,
int replica, Set<String> requiredTags, int requiredShuffleServerNumber) {
List<PartitionRange> ranges = CoordinatorUtils.generateRanges(totalPartitionNum, partitionNumPerRange);
int replica, Set<String> requiredTags, int requiredShuffleServerNumber, int estimateTaskConcurrency) {
int shuffleNodesMax = clusterManager.getShuffleNodesMax();
int expectedShuffleNodesNum = shuffleNodesMax;
if (requiredShuffleServerNumber < shuffleNodesMax && requiredShuffleServerNumber > 0) {
@@ -54,20 +51,8 @@ public PartitionRangeAssignment assign(int totalPartitionNum, int partitionNumPe
return new PartitionRangeAssignment(null);
}

SortedMap<PartitionRange, List<ServerNode>> assignments = new TreeMap<>();
int idx = 0;
int size = servers.size();

for (PartitionRange range : ranges) {
List<ServerNode> nodes = new LinkedList<>();
for (int i = 0; i < replica; ++i) {
ServerNode node = servers.get(idx);
nodes.add(node);
idx = CoordinatorUtils.nextIdx(idx, size);
}

assignments.put(range, nodes);
}
SortedMap<PartitionRange, List<ServerNode>> assignments =
getPartitionAssignment(totalPartitionNum, partitionNumPerRange, replica, servers, estimateTaskConcurrency);

return new PartitionRangeAssignment(assignments);
}