diff --git a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java index edcd78b1d3..e52b1ca389 100644 --- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java +++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java @@ -124,6 +124,9 @@ public class RssMRConfig { public static String RSS_ACCESS_TIMEOUT_MS = MR_RSS_CONFIG_PREFIX + RssClientConfig.RSS_ACCESS_TIMEOUT_MS; public static int RSS_ACCESS_TIMEOUT_MS_DEFAULT_VALUE = RssClientConfig.RSS_ACCESS_TIMEOUT_MS_DEFAULT_VALUE; + public static final String RSS_CLIENT_ASSIGNMENT_TAGS = + MR_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_ASSIGNMENT_TAGS; + public static String RSS_CONF_FILE = "rss_conf.xml"; public static Set RSS_MANDATORY_CLUSTER_CONF = Sets.newHashSet( diff --git a/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java b/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java index 2a6ce41bdb..db8d7b9c44 100644 --- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java +++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java @@ -19,8 +19,11 @@ import java.io.IOException; import java.lang.reflect.Field; +import java.util.Arrays; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadFactory; @@ -115,11 +118,19 @@ public static void main(String[] args) { LOG.info("Registering coordinators {}", coordinators); client.registerCoordinators(coordinators); + // Get the configured server assignment tags and it will also add default shuffle version tag. + Set assignmentTags = new HashSet<>(); + String rawTags = conf.get(RssMRConfig.RSS_CLIENT_ASSIGNMENT_TAGS, ""); + if (StringUtils.isNotEmpty(rawTags)) { + rawTags = rawTags.trim(); + assignmentTags.addAll(Arrays.asList(rawTags.split(","))); + } + assignmentTags.add(Constants.SHUFFLE_SERVER_VERSION); + ApplicationAttemptId applicationAttemptId = RssMRUtils.getApplicationAttemptId(); String appId = applicationAttemptId.toString(); ShuffleAssignmentsInfo response = client.getShuffleAssignments( - appId, 0, numReduceTasks, - 1, Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION)); + appId, 0, numReduceTasks, 1, Sets.newHashSet(assignmentTags)); Map> serverToPartitionRanges = response.getServerToPartitionRanges(); final ScheduledExecutorService scheduledExecutorService = Executors.newSingleThreadScheduledExecutor( diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java index 819b42fa4b..041e21f82a 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java @@ -133,6 +133,8 @@ public class RssSparkConfig { SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED; public static final boolean RSS_DYNAMIC_CLIENT_CONF_ENABLED_DEFAULT_VALUE = RssClientConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED_DEFAULT_VALUE; + public static final String RSS_CLIENT_ASSIGNMENT_TAGS = + SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_ASSIGNMENT_TAGS; public static final Set RSS_MANDATORY_CLUSTER_CONF = Sets.newHashSet(RSS_STORAGE_TYPE, RSS_REMOTE_STORAGE_PATH); diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java index ca060fe549..a4874f5129 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java @@ -18,9 +18,13 @@ package org.apache.spark.shuffle; import java.lang.reflect.Constructor; +import java.util.Arrays; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; +import org.apache.commons.lang3.StringUtils; import org.apache.hadoop.conf.Configuration; import org.apache.spark.SparkConf; import org.apache.spark.deploy.SparkHadoopUtil; @@ -30,6 +34,7 @@ import org.apache.uniffle.client.api.CoordinatorClient; import org.apache.uniffle.client.factory.CoordinatorClientFactory; import org.apache.uniffle.common.RemoteStorageInfo; +import org.apache.uniffle.common.util.Constants; public class RssSparkShuffleUtils { @@ -123,4 +128,15 @@ public static Configuration getRemoteStorageHadoopConf( } return readerHadoopConf; } + + public static Set getAssignmentTags(SparkConf sparkConf) { + Set assignmentTags = new HashSet<>(); + String rawTags = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_TAGS, ""); + if (StringUtils.isNotEmpty(rawTags)) { + rawTags = rawTags.trim(); + assignmentTags.addAll(Arrays.asList(rawTags.split(","))); + } + assignmentTags.add(Constants.SHUFFLE_SERVER_VERSION); + return assignmentTags; + } } diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/RssSparkShuffleUtilsTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/RssSparkShuffleUtilsTest.java index 09007f2236..f15c4cfddc 100644 --- a/client-spark/common/src/test/java/org/apache/spark/shuffle/RssSparkShuffleUtilsTest.java +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/RssSparkShuffleUtilsTest.java @@ -17,11 +17,15 @@ package org.apache.spark.shuffle; +import java.util.Iterator; import java.util.Map; +import java.util.Set; import com.google.common.collect.Maps; + import org.apache.hadoop.conf.Configuration; import org.apache.spark.SparkConf; +import org.apache.uniffle.common.util.Constants; import org.junit.jupiter.api.Test; import org.apache.uniffle.client.util.RssClientConfig; @@ -32,6 +36,30 @@ import static org.junit.jupiter.api.Assertions.assertTrue; public class RssSparkShuffleUtilsTest { + + @Test + public void testAssignmentTags() { + SparkConf conf = new SparkConf(); + + /** + * Case1: dont set the tag implicitly and will return the {@code Constants.SHUFFLE_SERVER_VERSION} + */ + Set tags = RssSparkShuffleUtils.getAssignmentTags(conf); + assertEquals(Constants.SHUFFLE_SERVER_VERSION, tags.iterator().next()); + + /** + * Case2: set the multiple tags implicitly and will return the {@code Constants.SHUFFLE_SERVER_VERSION} + * and configured tags. + */ + conf.set(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_TAGS, " a,b"); + tags = RssSparkShuffleUtils.getAssignmentTags(conf); + assertEquals(3, tags.size()); + Iterator iterator = tags.iterator(); + assertEquals("a", iterator.next()); + assertEquals("b", iterator.next()); + assertEquals(Constants.SHUFFLE_SERVER_VERSION, iterator.next()); + } + @Test public void odfsConfigurationTest() { SparkConf conf = new SparkConf(); diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java index c48eb1585e..01c4cf6e22 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java @@ -18,9 +18,9 @@ package org.apache.spark.shuffle; import java.util.List; +import java.util.Set; import com.google.common.collect.Lists; -import com.google.common.collect.Sets; import org.apache.commons.lang3.StringUtils; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; @@ -100,9 +100,11 @@ private boolean tryAccessCluster() { for (CoordinatorClient coordinatorClient : coordinatorClients) { try { + Set assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf); + RssAccessClusterResponse response = coordinatorClient.accessCluster(new RssAccessClusterRequest( - accessId, Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION), accessTimeoutMs)); + accessId, assignmentTags, accessTimeoutMs)); if (response.getStatusCode() == ResponseStatusCode.SUCCESS) { LOG.warn("Success to access cluster {} using {}", coordinatorClient.getDesc(), accessId); return true; diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index e3c67b3660..738589fe02 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -61,7 +61,6 @@ import org.apache.uniffle.common.ShuffleAssignmentsInfo; import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.ShuffleServerInfo; -import org.apache.uniffle.common.util.Constants; import org.apache.uniffle.common.util.RssUtils; public class RssShuffleManager implements ShuffleManager { @@ -230,9 +229,11 @@ public ShuffleHandle registerShuffle(int shuffleId, int numMaps, Shuff RssSparkConfig.RSS_PARTITION_NUM_PER_RANGE_DEFAULT_VALUE); // get all register info according to coordinator's response + Set assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf); + ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments( appId, shuffleId, dependency.partitioner().numPartitions(), - partitionNumPerRange, Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION)); + partitionNumPerRange, assignmentTags); Map> partitionToServers = response.getPartitionToServers(); startHeartbeat(); diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java index 555a55a1fe..6b8a77e630 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java @@ -18,9 +18,9 @@ package org.apache.spark.shuffle; import java.util.List; +import java.util.Set; import com.google.common.collect.Lists; -import com.google.common.collect.Sets; import org.apache.commons.lang3.StringUtils; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; @@ -100,9 +100,11 @@ private boolean tryAccessCluster() { for (CoordinatorClient coordinatorClient : coordinatorClients) { try { + Set assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf); + RssAccessClusterResponse response = coordinatorClient.accessCluster(new RssAccessClusterRequest( - accessId, Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION), accessTimeoutMs)); + accessId, assignmentTags, accessTimeoutMs)); if (response.getStatusCode() == ResponseStatusCode.SUCCESS) { LOG.warn("Success to access cluster {} using {}", coordinatorClient.getDesc(), accessId); return true; diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index 884b5e9300..5c7e2d90cf 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -65,7 +65,6 @@ import org.apache.uniffle.common.ShuffleAssignmentsInfo; import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.ShuffleServerInfo; -import org.apache.uniffle.common.util.Constants; import org.apache.uniffle.common.util.RssUtils; public class RssShuffleManager implements ShuffleManager { @@ -276,12 +275,14 @@ public ShuffleHandle registerShuffle(int shuffleId, ShuffleDependency< remoteStorage = ClientUtils.fetchRemoteStorage( id.get(), remoteStorage, dynamicConfEnabled, storageType, shuffleWriteClient); + Set assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf); + ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments( id.get(), shuffleId, dependency.partitioner().numPartitions(), 1, - Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION)); + assignmentTags); Map> partitionToServers = response.getPartitionToServers(); startHeartbeat(); diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java index efe4fae8c3..168b7c49cf 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java @@ -124,7 +124,7 @@ private boolean sendShuffleDataAsync( appId, retryMax, retryIntervalMax, shuffleIdToBlocks); long s = System.currentTimeMillis(); RssSendShuffleDataResponse response = getShuffleServerClient(ssi).sendShuffleData(request); - LOG.info("ShuffleWriteClientImpl sendShuffleData cost:" + (System.currentTimeMillis() - s)); + LOG.info("ShuffleWriteClientImpl sendShuffleData cost:" + (System.currentTimeMillis() - s) + "(ms)"); if (response.getStatusCode() == ResponseStatusCode.SUCCESS) { // mark a replica of block that has been sent diff --git a/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java b/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java index 475ea27e18..0c519a8ed6 100644 --- a/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java +++ b/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java @@ -56,6 +56,8 @@ public class RssClientConfig { // When the size of read buffer reaches the half of JVM region (i.e., 32m), // it will incur humongous allocation, so we set it to 14m. public static String RSS_CLIENT_READ_BUFFER_SIZE_DEFAULT_VALUE = "14m"; + // The tags specified by rss client to determine server assignment. + public static String RSS_CLIENT_ASSIGNMENT_TAGS = "rss.client.assignment.tags"; public static String RSS_ACCESS_TIMEOUT_MS = "rss.access.timeout.ms"; public static int RSS_ACCESS_TIMEOUT_MS_DEFAULT_VALUE = 10000; diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java new file mode 100644 index 0000000000..416af72a99 --- /dev/null +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import java.io.File; +import java.io.IOException; +import java.net.ServerSocket; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.commons.lang3.StringUtils; +import org.apache.uniffle.client.impl.ShuffleWriteClientImpl; +import org.apache.uniffle.client.util.ClientType; +import org.apache.uniffle.common.ShuffleAssignmentsInfo; +import org.apache.uniffle.common.util.Constants; +import org.apache.uniffle.coordinator.CoordinatorConf; +import org.apache.uniffle.coordinator.CoordinatorServer; +import org.apache.uniffle.server.ShuffleServer; +import org.apache.uniffle.server.ShuffleServerConf; +import org.apache.uniffle.storage.util.StorageType; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.collect.Sets; +import com.google.common.io.Files; + +/** + * This class is to test the conf of {@code org.apache.uniffle.server.ShuffleServerConf.Tags} + * and {@code RssClientConfig.RSS_CLIENT_ASSIGNMENT_TAGS} + */ +public class AssignmentWithTagsTest extends CoordinatorTestBase { + private static final Logger LOG = LoggerFactory.getLogger(AssignmentWithTagsTest.class); + + // KV: tag -> shuffle server id + private static Map> tagOfShufflePorts = new HashMap<>(); + + private static List findAvailablePorts(int num) throws IOException { + List sockets = new ArrayList<>(); + List ports = new ArrayList<>(); + + for (int i = 0; i < num; i++) { + ServerSocket socket = new ServerSocket(0); + ports.add(socket.getLocalPort()); + sockets.add(socket); + } + + for (ServerSocket socket : sockets) { + socket.close(); + } + + return ports; + } + + private static void createAndStartShuffleServerWithTags(Set tags) throws Exception { + ShuffleServerConf shuffleServerConf = getShuffleServerConf(); + shuffleServerConf.setLong("rss.server.app.expired.withoutHeartbeat", 4000); + + File tmpDir = Files.createTempDir(); + tmpDir.deleteOnExit(); + + File dataDir1 = new File(tmpDir, "data1"); + File dataDir2 = new File(tmpDir, "data2"); + String basePath = dataDir1.getAbsolutePath() + "," + dataDir2.getAbsolutePath(); + + shuffleServerConf.setString("rss.storage.type", StorageType.LOCALFILE.name()); + shuffleServerConf.setString("rss.storage.basePath", basePath); + shuffleServerConf.setString("rss.server.tags", StringUtils.join(tags, ",")); + + List ports = findAvailablePorts(2); + shuffleServerConf.setInteger("rss.rpc.server.port", ports.get(0)); + shuffleServerConf.setInteger("rss.jetty.http.port", ports.get(1)); + + for (String tag : tags) { + tagOfShufflePorts.putIfAbsent(tag, new ArrayList<>()); + tagOfShufflePorts.get(tag).add(ports.get(0)); + } + tagOfShufflePorts.putIfAbsent(Constants.SHUFFLE_SERVER_VERSION, new ArrayList<>()); + tagOfShufflePorts.get(Constants.SHUFFLE_SERVER_VERSION).add(ports.get(0)); + + LOG.info("Shuffle server data dir: {}, rpc port: {}, http port: {}", dataDir1 + "," + dataDir2, + ports.get(0), ports.get(1)); + + ShuffleServer server = new ShuffleServer(shuffleServerConf); + shuffleServers.add(server); + server.start(); + } + + @BeforeAll + public static void setupServers() throws Exception { + CoordinatorConf coordinatorConf = getCoordinatorConf(); + createCoordinatorServer(coordinatorConf); + + for (CoordinatorServer coordinator : coordinators) { + coordinator.start(); + } + + for (int i = 0; i < 2; i ++) { + createAndStartShuffleServerWithTags(Sets.newHashSet()); + } + + for (int i = 0; i < 2; i++) { + createAndStartShuffleServerWithTags(Sets.newHashSet("fixed")); + } + + for (int i = 0; i < 2; i++) { + createAndStartShuffleServerWithTags(Sets.newHashSet("elastic")); + } + + // Wait all shuffle servers registering to coordinator + long startTimeMS = System.currentTimeMillis(); + while (true) { + int nodeSum = coordinators.get(0).getClusterManager().getNodesNum(); + if (nodeSum == 6) { + break; + } + if (System.currentTimeMillis() - startTimeMS > 1000 * 5) { + throw new Exception("Timeout of waiting shuffle servers registry, timeout: 5s."); + } + } + } + + @Test + public void testTags() throws Exception { + ShuffleWriteClientImpl shuffleWriteClient = new ShuffleWriteClientImpl(ClientType.GRPC.name(), 3, 1000, 1, + 1, 1, 1, true, 1); + shuffleWriteClient.registerCoordinators(COORDINATOR_QUORUM); + + // Case1 : only set the single default shuffle version tag + ShuffleAssignmentsInfo assignmentsInfo = + shuffleWriteClient.getShuffleAssignments("app-1", + 1, 1, 1, Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION)); + + List assignedServerPorts = assignmentsInfo + .getPartitionToServers() + .values() + .stream() + .flatMap(x -> x.stream()) + .map(x -> x.getPort()) + .collect(Collectors.toList()); + assertEquals(1, assignedServerPorts.size()); + assertTrue(tagOfShufflePorts.get(Constants.SHUFFLE_SERVER_VERSION).contains(assignedServerPorts.get(0))); + + // Case2: Set the single non-exist shuffle server tag + try { + assignmentsInfo = shuffleWriteClient.getShuffleAssignments("app-2", + 1, 1, 1, Sets.newHashSet("non-exist")); + fail(); + } catch (Exception e) { + assertTrue(e.getMessage().startsWith("Error happened when getShuffleAssignments with")); + } + + // Case3: Set the single fixed tag + assignmentsInfo = shuffleWriteClient.getShuffleAssignments("app-3", + 1, 1, 1, Sets.newHashSet("fixed")); + assignedServerPorts = assignmentsInfo + .getPartitionToServers() + .values() + .stream() + .flatMap(x -> x.stream()) + .map(x -> x.getPort()) + .collect(Collectors.toList()); + assertEquals(1, assignedServerPorts.size()); + assertTrue(tagOfShufflePorts.get("fixed").contains(assignedServerPorts.get(0))); + + // case4: Set the multiple tags if exists + assignmentsInfo = shuffleWriteClient.getShuffleAssignments("app-4", + 1, 1, 1, Sets.newHashSet("fixed", Constants.SHUFFLE_SERVER_VERSION)); + assignedServerPorts = assignmentsInfo + .getPartitionToServers() + .values() + .stream() + .flatMap(x -> x.stream()) + .map(x -> x.getPort()) + .collect(Collectors.toList()); + assertEquals(1, assignedServerPorts.size()); + assertTrue(tagOfShufflePorts.get("fixed").contains(assignedServerPorts.get(0))); + + // case5: Set the multiple tags if non-exist + try { + assignmentsInfo = shuffleWriteClient.getShuffleAssignments("app-5", + 1, 1, 1, Sets.newHashSet("fixed", "elastic", Constants.SHUFFLE_SERVER_VERSION)); + fail(); + } catch (Exception e) { + assertTrue(e.getMessage().startsWith("Error happened when getShuffleAssignments with")); + } + } +} diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java index 45c446babe..8ddf6fc5ff 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java @@ -17,6 +17,7 @@ package org.apache.uniffle.server; +import java.util.Collections; import java.util.List; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; @@ -25,6 +26,7 @@ import com.google.common.collect.Lists; import com.google.common.collect.Sets; import io.prometheus.client.CollectorRegistry; +import org.apache.commons.collections.CollectionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import picocli.CommandLine; @@ -157,8 +159,18 @@ private void initialization() throws Exception { setServer(); + initServerTags(); + } + + private void initServerTags() { // it's the system tag for server's version tags.add(Constants.SHUFFLE_SERVER_VERSION); + + List configuredTags = shuffleServerConf.get(ShuffleServerConf.TAGS); + if (CollectionUtils.isNotEmpty(configuredTags)) { + tags.addAll(configuredTags); + } + LOG.info("Server tags: {}", tags); } private void registerMetrics() { @@ -262,7 +274,7 @@ public StorageManager getStorageManager() { } public Set getTags() { - return tags; + return Collections.unmodifiableSet(tags); } public boolean isHealthy() { diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java index ee4b7b0be4..7f0a74a8ab 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java @@ -337,6 +337,13 @@ public class ShuffleServerConf extends RssBaseConf { .defaultValue(0L) .withDescription("For multistorage, fail times exceed the number, will switch storage"); + public static final ConfigOption> TAGS = ConfigOptions + .key("rss.server.tags") + .stringType() + .asList() + .noDefaultValue() + .withDescription("Tags list supported by shuffle server"); + public ShuffleServerConf() { }