diff --git a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRClientConf.java b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRClientConf.java
new file mode 100644
index 0000000000..31272ed55b
--- /dev/null
+++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRClientConf.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mapreduce;
+
+import java.util.List;
+import java.util.Set;
+
+import com.google.common.collect.ImmutableSet;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.hadoop.mapred.JobConf;
+
+import org.apache.uniffle.common.config.ConfigOption;
+import org.apache.uniffle.common.config.ConfigOptions;
+import org.apache.uniffle.common.config.ConfigUtils;
+import org.apache.uniffle.common.config.RssClientConf;
+
+public class RssMRClientConf extends RssClientConf {
+  public static final String MR_RSS_CONFIG_PREFIX = "mapreduce.";
+  public static final String RSS_ASSIGNMENT_PREFIX = "rss.assignment.partition.";
+  public static final long RSS_WRITER_BUFFER_SIZE_DEFAULT_VALUE = 1024 * 1024 * 14;
+  public static final String RSS_CONF_FILE = "rss_conf.xml";
+  public static final Set<String> RSS_MANDATORY_CLUSTER_CONF =
+      ImmutableSet.of(
+          RSS_STORAGE_TYPE.key(),
+          RSS_REMOTE_STORAGE_PATH.key()
+      );
+
+  public static final ConfigOption<Float> RSS_CLIENT_SEND_THRESHOLD = ConfigOptions
+      .key("rss.client.send.threshold")
+      .floatType()
+      .defaultValue(0.2f);
+
+  public static final ConfigOption<Integer> RSS_CLIENT_BATCH_TRIGGER_NUM = ConfigOptions
+      .key("rss.client.batch.trigger.num")
+      .intType()
+      .defaultValue(50);
+
+  public static final ConfigOption<Float> RSS_CLIENT_SORT_MEMORY_USE_THRESHOLD = ConfigOptions
+      .key("rss.client.sort.memory.use.threshold")
+      .floatType()
+      .defaultValue(0.9f);
+
+  public static final ConfigOption<Float> RSS_CLIENT_MEMORY_THRESHOLD = ConfigOptions
+      .key("rss.client.memory.threshold")
+      .floatType()
+      .defaultValue(0.8f);
+
+  public static final ConfigOption<Integer> RSS_CLIENT_BITMAP_NUM = ConfigOptions
+      .key("rss.client.bitmap.num")
+      .intType()
+      .defaultValue(1);
+
+  public static final ConfigOption<Long> RSS_CLIENT_MAX_SEGMENT_SIZE = ConfigOptions
+      .key("rss.client.max.buffer.size")
+      .longType()
+      .defaultValue(3 * 1024L);
+
+  public static final ConfigOption<Boolean> RSS_REDUCE_REMOTE_SPILL_ENABLED = ConfigOptions
+      .key("rss.reduce.remote.spill.enable")
+      .booleanType()
+      .defaultValue(false);
+
+  public static final ConfigOption<Integer> RSS_REDUCE_REMOTE_SPILL_ATTEMPT_INC = ConfigOptions
+      .key("rss.reduce.remote.spill.attempt.inc")
+      .intType()
+      .defaultValue(1);
+
+  public static final ConfigOption<Integer> RSS_REDUCE_REMOTE_SPILL_REPLICATION = ConfigOptions
+      .key("rss.reduce.remote.spill.replication")
+      .intType()
+      .defaultValue(1);
+
+  public static final ConfigOption<Integer> RSS_REDUCE_REMOTE_SPILL_RETRIES = ConfigOptions
+      .key("rss.reduce.remote.spill.retries")
+      .intType()
+      .defaultValue(5);
+
+  public static final ConfigOption<String> RSS_REMOTE_STORAGE_CONF = ConfigOptions
+      .key("rss.remote.storage.conf")
+      .stringType()
+      .noDefaultValue();
+
+  private RssMRClientConf(JobConf jobConf) {
+    List<ConfigOption<Object>> configOptions = ConfigUtils.getAllConfigOptions(RssMRClientConf.class);
+
+    for (ConfigOption<Object> option : configOptions) {
+      String val = jobConf.get(MR_RSS_CONFIG_PREFIX + option.key());
+      if (StringUtils.isNotEmpty(val)) {
+        set(option, ConfigUtils.convertValue(val, option.getClazz()));
+      }
+    }
+  }
+
+  public static RssMRClientConf from(JobConf jobConf) {
+    return new RssMRClientConf(jobConf);
+  }
+}
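Note on the pattern above: RssMRClientConf loads typed options by prefix stripping — each registered option key is looked up in the JobConf under its "mapreduce."-prefixed form, converted to the option's type, and stored under the unprefixed key. A minimal, dependency-free sketch of that lookup, with plain Maps standing in for JobConf and the ConfigOption registry (every name below is illustrative, not part of the patch):

import java.util.HashMap;
import java.util.Map;

// Simplified stand-in for the prefix-stripping load in RssMRClientConf(JobConf):
// keys arrive as "mapreduce.rss.x" and are stored under "rss.x".
public class PrefixedConfSketch {
  static final String PREFIX = "mapreduce.";

  static Map<String, String> load(Map<String, String> jobConf, Map<String, String> registry) {
    Map<String, String> typed = new HashMap<>();
    for (Map.Entry<String, String> known : registry.entrySet()) {
      // look up the option under its prefixed name, as the new constructor does
      String val = jobConf.get(PREFIX + known.getKey());
      typed.put(known.getKey(), val != null && !val.isEmpty() ? val : known.getValue());
    }
    return typed;
  }

  public static void main(String[] args) {
    Map<String, String> jobConf = new HashMap<>();
    jobConf.put("mapreduce.rss.client.batch.trigger.num", "100");

    Map<String, String> registry = new HashMap<>();   // option key -> default value
    registry.put("rss.client.batch.trigger.num", "50");
    registry.put("rss.client.bitmap.num", "1");

    System.out.println(load(jobConf, registry));
    // {rss.client.bitmap.num=1, rss.client.batch.trigger.num=100}
  }
}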
diff --git a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java
index 740de51eed..99f467673a 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java
@@ -35,6 +35,8 @@
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.util.Constants;
 
+import static org.apache.hadoop.mapreduce.RssMRClientConf.MR_RSS_CONFIG_PREFIX;
+
 public class RssMRUtils {
 
   private static final Logger LOG = LoggerFactory.getLogger(RssMRUtils.class);
@@ -73,37 +75,39 @@ public static TaskAttemptID createMRTaskAttemptId(
   }
 
   public static ShuffleWriteClient createShuffleClient(JobConf jobConf) {
-    int heartBeatThreadNum = jobConf.getInt(RssMRConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM,
-        RssMRConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM_DEFAULT_VALUE);
-    int retryMax = jobConf.getInt(RssMRConfig.RSS_CLIENT_RETRY_MAX,
-        RssMRConfig.RSS_CLIENT_RETRY_MAX_DEFAULT_VALUE);
-    long retryIntervalMax = jobConf.getLong(RssMRConfig.RSS_CLIENT_RETRY_INTERVAL_MAX,
-        RssMRConfig.RSS_CLIENT_RETRY_INTERVAL_MAX_DEFAULT_VALUE);
-    String clientType = jobConf.get(RssMRConfig.RSS_CLIENT_TYPE,
-        RssMRConfig.RSS_CLIENT_TYPE_DEFAULT_VALUE);
-    int replicaWrite = jobConf.getInt(RssMRConfig.RSS_DATA_REPLICA_WRITE,
-        RssMRConfig.RSS_DATA_REPLICA_WRITE_DEFAULT_VALUE);
-    int replicaRead = jobConf.getInt(RssMRConfig.RSS_DATA_REPLICA_READ,
-        RssMRConfig.RSS_DATA_REPLICA_READ_DEFAULT_VALUE);
-    int replica = jobConf.getInt(RssMRConfig.RSS_DATA_REPLICA,
-        RssMRConfig.RSS_DATA_REPLICA_DEFAULT_VALUE);
-    boolean replicaSkipEnabled = jobConf.getBoolean(RssMRConfig.RSS_DATA_REPLICA_SKIP_ENABLED,
-        RssMRConfig.RSS_DATA_REPLICA_SKIP_ENABLED_DEFAULT_VALUE);
-    int dataTransferPoolSize = jobConf.getInt(RssMRConfig.RSS_DATA_TRANSFER_POOL_SIZE,
-        RssMRConfig.RSS_DATA_TRANSFER_POOL_SIZE_DEFAULT_VALUE);
-    int dataCommitPoolSize = jobConf.getInt(RssMRConfig.RSS_DATA_COMMIT_POOL_SIZE,
-        RssMRConfig.RSS_DATA_COMMIT_POOL_SIZE_DEFAULT_VALUE);
+    RssMRClientConf clientConf = RssMRClientConf.from(jobConf);
+
+    int heartBeatThreadNum = clientConf.get(RssMRClientConf.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
+    int retryMax = clientConf.get(RssMRClientConf.RSS_CLIENT_RETRY_MAX);
+    long retryIntervalMax = clientConf.get(RssMRClientConf.RSS_CLIENT_RETRY_INTERVAL_MAX);
+    String clientType = clientConf.get(RssMRClientConf.RSS_CLIENT_TYPE);
+    int replicaWrite = clientConf.get(RssMRClientConf.RSS_DATA_REPLICA_WRITE);
+    int replicaRead = clientConf.get(RssMRClientConf.RSS_DATA_REPLICA_READ);
+    int replica = clientConf.get(RssMRClientConf.RSS_DATA_REPLICA);
+    boolean replicaSkipEnabled = clientConf.get(RssMRClientConf.RSS_DATA_REPLICA_SKIP_ENABLED);
+    int dataTransferPoolSize = clientConf.get(RssMRClientConf.RSS_DATA_TRANSFER_POOL_SIZE);
+    int dataCommitPoolSize = clientConf.get(RssMRClientConf.RSS_DATA_COMMIT_POOL_SIZE);
     ShuffleWriteClient client = ShuffleClientFactory
         .getInstance()
-        .createShuffleWriteClient(clientType, retryMax, retryIntervalMax,
-            heartBeatThreadNum, replica, replicaWrite, replicaRead, replicaSkipEnabled,
-            dataTransferPoolSize, dataCommitPoolSize);
+        .createShuffleWriteClient(
+            clientType,
+            retryMax,
+            retryIntervalMax,
+            heartBeatThreadNum,
+            replica,
+            replicaWrite,
+            replicaRead,
+            replicaSkipEnabled,
+            dataTransferPoolSize,
+            dataCommitPoolSize
+        );
     return client;
   }
 
   public static Set<ShuffleServerInfo> getAssignedServers(JobConf jobConf, int reduceID) {
-    String servers = jobConf.get(RssMRConfig.RSS_ASSIGNMENT_PREFIX
-        + String.valueOf(reduceID));
+    String servers = jobConf.get(
+        MR_RSS_CONFIG_PREFIX + RssMRClientConf.RSS_ASSIGNMENT_PREFIX + String.valueOf(reduceID)
+    );
     String[] splitServers = servers.split(",");
     Set<ShuffleServerInfo> assignServers = Sets.newHashSet();
     for (String splitServer : splitServers) {
@@ -111,8 +115,11 @@ public static Set<ShuffleServerInfo> getAssignedServers(JobConf jobConf, int red
       if (serverInfo.length != 2) {
         throw new RssException("partition " + reduceID + " server info isn't right");
       }
-      ShuffleServerInfo sever = new ShuffleServerInfo(StringUtils.join(serverInfo, "-"),
-          serverInfo[0], Integer.parseInt(serverInfo[1]));
+      ShuffleServerInfo sever = new ShuffleServerInfo(
+          StringUtils.join(serverInfo, "-"),
+          serverInfo[0],
+          Integer.parseInt(serverInfo[1])
+      );
       assignServers.add(sever);
     }
     return assignServers;
@@ -138,12 +145,12 @@ public static void applyDynamicClientConf(JobConf jobConf, Map<String, String> c
 
     for (Map.Entry<String, String> kv : confItems.entrySet()) {
       String mrConfKey = kv.getKey();
-      if (!mrConfKey.startsWith(RssMRConfig.MR_RSS_CONFIG_PREFIX)) {
-        mrConfKey = RssMRConfig.MR_RSS_CONFIG_PREFIX + mrConfKey;
+      if (!mrConfKey.startsWith(RssMRClientConf.MR_RSS_CONFIG_PREFIX)) {
+        mrConfKey = RssMRClientConf.MR_RSS_CONFIG_PREFIX + mrConfKey;
       }
       String mrConfVal = kv.getValue();
       if (StringUtils.isEmpty(jobConf.get(mrConfKey, ""))
-          || RssMRConfig.RSS_MANDATORY_CLUSTER_CONF.contains(mrConfKey)) {
+          || RssMRClientConf.RSS_MANDATORY_CLUSTER_CONF.contains(mrConfKey)) {
         LOG.warn("Use conf dynamic conf {} = {}", mrConfKey, mrConfVal);
         jobConf.set(mrConfKey, mrConfVal);
       }
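The applyDynamicClientConf hunk above encodes a precedence rule: a coordinator-supplied entry only fills keys the job left unset, except for mandatory cluster keys (storage type, remote storage path), which always win. A standalone sketch of just that rule, with the key-prefix handling omitted and plain JDK types (all names illustrative):

import java.util.HashMap;
import java.util.Map;
import java.util.Set;

// Precedence used by applyDynamicClientConf: dynamic entries fill unset keys,
// and mandatory cluster keys override even explicit user settings.
public class DynamicConfSketch {
  static void apply(Map<String, String> jobConf,
                    Map<String, String> dynamicConf,
                    Set<String> mandatory) {
    for (Map.Entry<String, String> kv : dynamicConf.entrySet()) {
      boolean unset = jobConf.getOrDefault(kv.getKey(), "").isEmpty();
      if (unset || mandatory.contains(kv.getKey())) {
        jobConf.put(kv.getKey(), kv.getValue());
      }
    }
  }

  public static void main(String[] args) {
    Map<String, String> jobConf = new HashMap<>();
    jobConf.put("rss.client.retry.max", "99999");             // user setting, kept
    Map<String, String> dynamic = new HashMap<>();
    dynamic.put("rss.client.retry.max", "50");                // ignored: already set
    dynamic.put("rss.storage.type", "MEMORY_LOCALFILE_HDFS"); // mandatory: wins
    apply(jobConf, dynamic, Set.of("rss.storage.type"));
    System.out.println(jobConf);
  }
}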
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkClientConf.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkClientConf.java
new file mode 100644
index 0000000000..ba98354f0e
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkClientConf.java
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableSet;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.spark.SparkConf;
+import scala.Tuple2;
+
+import org.apache.uniffle.common.config.ConfigOption;
+import org.apache.uniffle.common.config.ConfigOptions;
+import org.apache.uniffle.common.config.ConfigUtils;
+import org.apache.uniffle.common.config.RssClientConf;
+import org.apache.uniffle.common.config.RssConf;
+
+public class RssSparkClientConf extends RssClientConf {
+  public static final String SPARK_CONFIG_KEY_PREFIX = "spark.";
+  public static final String SPARK_CONFIG_RSS_KEY_PREFIX = SPARK_CONFIG_KEY_PREFIX + "rss.";
+
+  public static final String DEFAULT_RSS_WRITER_BUFFER_SIZE = "3m";
+  public static final long DEFAULT_RSS_HEARTBEAT_TIMEOUT = 5 * 1000L;
+
+  public static final Set<String> RSS_MANDATORY_CLUSTER_CONF =
+      ImmutableSet.of(RSS_STORAGE_TYPE.key(), RSS_REMOTE_STORAGE_PATH.key());
+
+  public static final ConfigOption<String> RSS_WRITER_SERIALIZER_BUFFER_SIZE = ConfigOptions
+      .key("rss.writer.serializer.buffer.size")
+      .stringType()
+      .defaultValue("3k")
+      .withDescription("");
+
+  public static final ConfigOption<String> RSS_WRITER_BUFFER_SEGMENT_SIZE = ConfigOptions
+      .key("rss.writer.buffer.segment.size")
+      .stringType()
+      .defaultValue("3k")
+      .withDescription("");
+
+  public static final ConfigOption<String> RSS_WRITER_BUFFER_SPILL_SIZE = ConfigOptions
+      .key("rss.writer.buffer.spill.size")
+      .stringType()
+      .defaultValue("128m")
+      .withDescription("Buffer size for total partition data");
+
+  public static final ConfigOption<String> RSS_WRITER_PRE_ALLOCATED_BUFFER_SIZE = ConfigOptions
+      .key("rss.writer.pre.allocated.buffer.size")
+      .stringType()
+      .defaultValue("16m")
+      .withDescription("Buffer size for total partition data");
+
+  public static final ConfigOption<Integer> RSS_WRITER_REQUIRE_MEMORY_RETRY_MAX = ConfigOptions
+      .key("rss.writer.require.memory.retryMax")
+      .intType()
+      .defaultValue(1200)
+      .withDescription("");
+
+  public static final ConfigOption<Long> RSS_WRITER_REQUIRE_MEMORY_INTERVAL = ConfigOptions
+      .key("rss.writer.require.memory.interval")
+      .longType()
+      .defaultValue(1000L)
+      .withDescription("");
+
+  public static final ConfigOption<Boolean> RSS_TEST_FLAG = ConfigOptions
+      .key("rss.test")
+      .booleanType()
+      .defaultValue(false);
+
+  public static final ConfigOption<String> RSS_CLIENT_SEND_SIZE_LIMIT = ConfigOptions
+      .key("rss.client.send.size.limit")
+      .stringType()
+      .defaultValue("16m")
+      .withDescription("The max data size sent to shuffle server");
+
+  public static final ConfigOption<Integer> RSS_CLIENT_SEND_THREAD_POOL_SIZE = ConfigOptions
+      .key("rss.client.send.threadPool.size")
+      .intType()
+      .defaultValue(10)
+      .withDescription("The thread size for send shuffle data to shuffle server");
+
+  public static final ConfigOption<Integer> RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE = ConfigOptions
+      .key("rss.client.send.threadPool.keepalive")
+      .intType()
+      .defaultValue(60)
+      .withDescription("");
+
+  public static final ConfigOption<Boolean> RSS_OZONE_DFS_NAMENODE_ODFS_ENABLE = ConfigOptions
+      .key("rss.ozone.dfs.namenode.odfs.enable")
+      .booleanType()
+      .defaultValue(false)
+      .withDescription("");
+
+  public static final ConfigOption<String> RSS_OZONE_FS_HDFS_IMPL = ConfigOptions
+      .key("rss.ozone.fs.hdfs.impl")
+      .stringType()
+      .defaultValue("org.apache.hadoop.odfs.HdfsOdfsFilesystem")
+      .withDescription("");
+
+  public static final ConfigOption<String> RSS_OZONE_FS_ABSTRACT_FILE_SYSTEM_HDFS_IMPL = ConfigOptions
+      .key("rss.ozone.fs.AbstractFileSystem.hdfs.impl")
+      .stringType()
+      .defaultValue("org.apache.hadoop.odfs.HdfsOdfs")
+      .withDescription("");
+
+  public static final ConfigOption<Integer> RSS_CLIENT_BITMAP_SPLIT_NUM = ConfigOptions
+      .key("rss.client.bitmap.splitNum")
+      .intType()
+      .defaultValue(1)
+      .withDescription("");
+
+  public static final ConfigOption<String> RSS_ACCESS_ID = ConfigOptions
+      .key("rss.access.id")
+      .stringType()
+      .noDefaultValue()
+      .withDescription("");
+
+  public static final ConfigOption<Boolean> RSS_ENABLED = ConfigOptions
+      .key("rss.enabled")
+      .booleanType()
+      .defaultValue(false)
+      .withDescription("");
+
+  public static final ConfigOption<Long> RSS_CLIENT_ACCESS_RETRY_INTERVAL_MS = ConfigOptions
+      .key("rss.client.access.retry.interval.ms")
+      .longType()
+      .defaultValue(20000L)
+      .withDescription("Interval between retries fallback to SortShuffleManager");
+
+  public static final ConfigOption<Integer> RSS_CLIENT_ACCESS_RETRY_TIMES = ConfigOptions
+      .key("rss.client.access.retry.times")
+      .intType()
+      .defaultValue(0)
+      .withDescription("Number of retries fallback to SortShuffleManager");
+
+  public RssSparkClientConf() {
+    // ignore
+  }
+
+  private RssSparkClientConf(SparkConf sparkConf) {
+    List<ConfigOption<Object>> configOptions = ConfigUtils.getAllConfigOptions(RssSparkClientConf.class);
+
+    Map<String, ConfigOption<Object>> configOptionMap = configOptions
+        .stream()
+        .collect(
+            Collectors.toMap(
+                entry -> entry.key(),
+                entry -> entry
+            )
+        );
+
+    for (Tuple2<String, String> tuple : sparkConf.getAll()) {
+      String key = tuple._1;
+      if (!key.startsWith(SPARK_CONFIG_RSS_KEY_PREFIX)) {
+        continue;
+      }
+      key = key.substring(SPARK_CONFIG_KEY_PREFIX.length());
+      String val = tuple._2;
+      ConfigOption<Object> configOption = configOptionMap.get(key);
+      if (configOption != null) {
+        set(configOption, ConfigUtils.convertValue(val, configOption.getClazz()));
+      }
+    }
+  }
+
+  public static RssSparkClientConf from(SparkConf sparkConf) {
+    return new RssSparkClientConf(sparkConf);
+  }
+
+  public static String toKey(ConfigOption<?> option) {
+    return String.format("%s%s", SPARK_CONFIG_KEY_PREFIX, option.key());
+  }
+
+  @VisibleForTesting
+  public static void toSparkConf(RssConf rssConf, SparkConf sparkConf) {
+    List<Pair<String, Object>> confs = rssConf.getAll();
+    for (Pair<String, Object> conf : confs) {
+      sparkConf.set(SPARK_CONFIG_KEY_PREFIX + conf.getLeft(), String.valueOf(conf.getRight()));
+    }
+    return;
+  }
+}
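The Spark conf class above is the mirror image of the MR one: options live under unprefixed "rss." keys internally, while toKey prepends "spark." so the same option can be addressed through SparkConf. Roughly, under those assumptions (plain strings standing in for ConfigOption):

// Rough equivalent of RssSparkClientConf.toKey and the constructor's inverse mapping.
public class SparkKeySketch {
  static final String SPARK_PREFIX = "spark.";
  static final String SPARK_RSS_PREFIX = "spark.rss.";

  static String toSparkKey(String optionKey) {    // "rss.enabled" -> "spark.rss.enabled"
    return SPARK_PREFIX + optionKey;
  }

  static String fromSparkKey(String sparkKey) {   // "spark.rss.enabled" -> "rss.enabled"
    return sparkKey.startsWith(SPARK_RSS_PREFIX)
        ? sparkKey.substring(SPARK_PREFIX.length())
        : null;                                   // not an RSS key; the loader skips it
  }

  public static void main(String[] args) {
    System.out.println(toSparkKey("rss.enabled"));            // spark.rss.enabled
    System.out.println(fromSparkKey("spark.rss.enabled"));    // rss.enabled
    System.out.println(fromSparkKey("spark.executor.cores")); // null
  }
}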
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
index 13f7305442..06fdab5614 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
@@ -18,13 +18,12 @@
 package org.apache.spark.shuffle;
 
 import java.lang.reflect.Constructor;
-import java.util.Arrays;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
-import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.collections.CollectionUtils;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.spark.SparkConf;
 import org.apache.spark.deploy.SparkHadoopUtil;
@@ -34,8 +33,11 @@
 import org.apache.uniffle.client.api.CoordinatorClient;
 import org.apache.uniffle.client.factory.CoordinatorClientFactory;
 import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.util.Constants;
 
+import static org.apache.spark.shuffle.RssSparkClientConf.toKey;
+
 public class RssSparkShuffleUtils {
 
   private static final Logger LOG = LoggerFactory.getLogger(RssSparkShuffleUtils.class);
@@ -44,16 +46,30 @@ public static Configuration newHadoopConfiguration(SparkConf sparkConf) {
     SparkHadoopUtil util = new SparkHadoopUtil();
     Configuration conf = util.newConfiguration(sparkConf);
 
-    boolean useOdfs = sparkConf.get(RssSparkConfig.RSS_OZONE_DFS_NAMENODE_ODFS_ENABLE);
+    boolean useOdfs = sparkConf.getBoolean(
+        toKey(RssSparkClientConf.RSS_OZONE_DFS_NAMENODE_ODFS_ENABLE),
+        false
+    );
     if (useOdfs) {
       final int OZONE_PREFIX_LEN = "spark.rss.ozone.".length();
-      conf.setBoolean(RssSparkConfig.RSS_OZONE_DFS_NAMENODE_ODFS_ENABLE.key().substring(OZONE_PREFIX_LEN), useOdfs);
+      conf.setBoolean(
+          toKey(RssSparkClientConf.RSS_OZONE_DFS_NAMENODE_ODFS_ENABLE).substring(OZONE_PREFIX_LEN),
+          useOdfs
+      );
       conf.set(
-          RssSparkConfig.RSS_OZONE_FS_HDFS_IMPL.key().substring(OZONE_PREFIX_LEN),
-          sparkConf.get(RssSparkConfig.RSS_OZONE_FS_HDFS_IMPL));
+          toKey(RssSparkClientConf.RSS_OZONE_FS_HDFS_IMPL).substring(OZONE_PREFIX_LEN),
+          sparkConf.get(
+              toKey(RssSparkClientConf.RSS_OZONE_FS_HDFS_IMPL),
+              RssSparkClientConf.RSS_OZONE_FS_HDFS_IMPL.defaultValue()
+          )
+      );
       conf.set(
-          RssSparkConfig.RSS_OZONE_FS_ABSTRACT_FILE_SYSTEM_HDFS_IMPL.key().substring(OZONE_PREFIX_LEN),
-          sparkConf.get(RssSparkConfig.RSS_OZONE_FS_ABSTRACT_FILE_SYSTEM_HDFS_IMPL));
+          toKey(RssSparkClientConf.RSS_OZONE_FS_ABSTRACT_FILE_SYSTEM_HDFS_IMPL).substring(OZONE_PREFIX_LEN),
+          sparkConf.get(
+              toKey(RssSparkClientConf.RSS_OZONE_FS_ABSTRACT_FILE_SYSTEM_HDFS_IMPL),
+              RssSparkClientConf.RSS_OZONE_FS_ABSTRACT_FILE_SYSTEM_HDFS_IMPL.defaultValue()
+          )
+      );
     }
 
     return conf;
@@ -74,39 +90,16 @@ public static ShuffleManager loadShuffleManager(String name, SparkConf conf, boo
   }
 
   public static List<CoordinatorClient> createCoordinatorClients(SparkConf sparkConf) throws RuntimeException {
-    String clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
-    String coordinators = sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM);
+    RssSparkClientConf clientConf = RssSparkClientConf.from(sparkConf);
+    String clientType = clientConf.get(RssSparkClientConf.RSS_CLIENT_TYPE);
+    String coordinators = clientConf.get(RssSparkClientConf.RSS_COORDINATOR_QUORUM);
     CoordinatorClientFactory coordinatorClientFactory = new CoordinatorClientFactory(clientType);
     return coordinatorClientFactory.createCoordinatorClient(coordinators);
   }
 
-  public static void applyDynamicClientConf(SparkConf sparkConf, Map<String, String> confItems) {
-    if (sparkConf == null) {
-      LOG.warn("Spark conf is null");
-      return;
-    }
-
-    if (confItems == null || confItems.isEmpty()) {
-      LOG.warn("Empty conf items");
-      return;
-    }
-
-    for (Map.Entry<String, String> kv : confItems.entrySet()) {
-      String sparkConfKey = kv.getKey();
-      if (!sparkConfKey.startsWith(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX)) {
-        sparkConfKey = RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + sparkConfKey;
-      }
-      String confVal = kv.getValue();
-      if (!sparkConf.contains(sparkConfKey) || RssSparkConfig.RSS_MANDATORY_CLUSTER_CONF.contains(sparkConfKey)) {
-        LOG.warn("Use conf dynamic conf {} = {}", sparkConfKey, confVal);
-        sparkConf.set(sparkConfKey, confVal);
-      }
-    }
-  }
-
-  public static void validateRssClientConf(SparkConf sparkConf) {
+  public static void validateRssClientConf(RssConf rssConf) {
     String msgFormat = "%s must be set by the client or fetched from coordinators.";
-    if (!sparkConf.contains(RssSparkConfig.RSS_STORAGE_TYPE.key())) {
+    if (!rssConf.contains(RssSparkClientConf.RSS_STORAGE_TYPE)) {
       String msg = String.format(msgFormat, "Storage type");
       LOG.error(msg);
       throw new IllegalArgumentException(msg);
@@ -125,12 +118,11 @@ public static Configuration getRemoteStorageHadoopConf(
     return readerHadoopConf;
   }
 
-  public static Set<String> getAssignmentTags(SparkConf sparkConf) {
+  public static Set<String> getAssignmentTags(RssConf rssConf) {
     Set<String> assignmentTags = new HashSet<>();
-    String rawTags = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_TAGS.key(), "");
-    if (StringUtils.isNotEmpty(rawTags)) {
-      rawTags = rawTags.trim();
-      assignmentTags.addAll(Arrays.asList(rawTags.split(",")));
+    List<String> tags = rssConf.get(RssSparkClientConf.RSS_CLIENT_ASSIGNMENT_TAGS);
+    if (CollectionUtils.isNotEmpty(tags)) {
+      assignmentTags.addAll(tags);
     }
     assignmentTags.add(Constants.SHUFFLE_SERVER_VERSION);
     return assignmentTags;
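getAssignmentTags now reads the list-typed option directly (instead of splitting a raw comma-separated string) and always appends the shuffle-server version constant, so the returned set is never empty. The shape of that logic in isolation — the version string below is a placeholder, not necessarily the real Constants.SHUFFLE_SERVER_VERSION value:

import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Configured tags (possibly empty) plus a version tag that is always present.
public class AssignmentTagsSketch {
  static final String SHUFFLE_SERVER_VERSION = "ss_v4"; // placeholder value

  static Set<String> getAssignmentTags(List<String> configuredTags) {
    Set<String> tags = new HashSet<>();
    if (configuredTags != null && !configuredTags.isEmpty()) {
      tags.addAll(configuredTags);
    }
    tags.add(SHUFFLE_SERVER_VERSION);
    return tags;
  }

  public static void main(String[] args) {
    System.out.println(getAssignmentTags(List.of("a", "b"))); // three tags
    System.out.println(getAssignmentTags(List.of()));         // version tag only
  }
}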
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BufferManagerOptions.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BufferManagerOptions.java
index 3ae878ccd3..d66c21167f 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BufferManagerOptions.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BufferManagerOptions.java
@@ -17,8 +17,7 @@
 
 package org.apache.spark.shuffle.writer;
 
-import org.apache.spark.SparkConf;
-import org.apache.spark.shuffle.RssSparkConfig;
+import org.apache.spark.shuffle.RssSparkClientConf;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -34,37 +33,56 @@ public class BufferManagerOptions {
   private long requireMemoryInterval;
   private int requireMemoryRetryMax;
 
-  public BufferManagerOptions(SparkConf sparkConf) {
-    bufferSize = sparkConf.getSizeAsBytes(RssSparkConfig.RSS_WRITER_BUFFER_SIZE.key(),
-        RssSparkConfig.RSS_WRITER_BUFFER_SIZE.defaultValue().get());
-    serializerBufferSize = sparkConf.getSizeAsBytes(RssSparkConfig.RSS_WRITER_SERIALIZER_BUFFER_SIZE.key(),
-        RssSparkConfig.RSS_WRITER_SERIALIZER_BUFFER_SIZE.defaultValue().get());
-    bufferSegmentSize = sparkConf.getSizeAsBytes(RssSparkConfig.RSS_WRITER_BUFFER_SEGMENT_SIZE.key(),
-        RssSparkConfig.RSS_WRITER_BUFFER_SEGMENT_SIZE.defaultValue().get());
-    bufferSpillThreshold = sparkConf.getSizeAsBytes(RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.key(),
-        RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.defaultValue().get());
-    preAllocatedBufferSize = sparkConf.getSizeAsBytes(RssSparkConfig.RSS_WRITER_PRE_ALLOCATED_BUFFER_SIZE.key(),
-        RssSparkConfig.RSS_WRITER_PRE_ALLOCATED_BUFFER_SIZE.defaultValue().get());
-    requireMemoryInterval = sparkConf.get(RssSparkConfig.RSS_WRITER_REQUIRE_MEMORY_INTERVAL);
-    requireMemoryRetryMax = sparkConf.get(RssSparkConfig.RSS_WRITER_REQUIRE_MEMORY_RETRY_MAX);
-    LOG.info(RssSparkConfig.RSS_WRITER_BUFFER_SIZE.key() + "=" + bufferSize);
-    LOG.info(RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.key() + "=" + bufferSpillThreshold);
-    LOG.info(RssSparkConfig.RSS_WRITER_PRE_ALLOCATED_BUFFER_SIZE.key() + "=" + preAllocatedBufferSize);
+  public BufferManagerOptions(RssSparkClientConf rssConf) {
+
+    bufferSize = rssConf.getSizeAsBytes(
+        RssSparkClientConf.RSS_WRITER_BUFFER_SIZE.key(),
+        RssSparkClientConf.DEFAULT_RSS_WRITER_BUFFER_SIZE
+    );
+
+    serializerBufferSize = rssConf.getSizeAsBytes(
+        RssSparkClientConf.RSS_WRITER_SERIALIZER_BUFFER_SIZE.key(),
+        RssSparkClientConf.RSS_WRITER_SERIALIZER_BUFFER_SIZE.defaultValue()
+    );
+
+    bufferSegmentSize = rssConf.getSizeAsBytes(
+        RssSparkClientConf.RSS_WRITER_BUFFER_SEGMENT_SIZE.key(),
+        RssSparkClientConf.RSS_WRITER_BUFFER_SEGMENT_SIZE.defaultValue()
+    );
+
+    bufferSpillThreshold = rssConf.getSizeAsBytes(
+        RssSparkClientConf.RSS_WRITER_BUFFER_SPILL_SIZE.key(),
+        RssSparkClientConf.RSS_WRITER_BUFFER_SPILL_SIZE.defaultValue()
+    );
+
+    preAllocatedBufferSize = rssConf.getSizeAsBytes(
+        RssSparkClientConf.RSS_WRITER_PRE_ALLOCATED_BUFFER_SIZE.key(),
+        RssSparkClientConf.RSS_WRITER_PRE_ALLOCATED_BUFFER_SIZE.defaultValue()
+    );
+
+    requireMemoryInterval = rssConf.get(RssSparkClientConf.RSS_WRITER_REQUIRE_MEMORY_INTERVAL);
+    requireMemoryRetryMax = rssConf.get(RssSparkClientConf.RSS_WRITER_REQUIRE_MEMORY_RETRY_MAX);
+
+    LOG.info(RssSparkClientConf.RSS_WRITER_BUFFER_SIZE.key() + "=" + bufferSize);
+    LOG.info(RssSparkClientConf.RSS_WRITER_BUFFER_SPILL_SIZE.key() + "=" + bufferSpillThreshold);
+    LOG.info(RssSparkClientConf.RSS_WRITER_PRE_ALLOCATED_BUFFER_SIZE.key() + "=" + preAllocatedBufferSize);
     checkBufferSize();
   }
 
   private void checkBufferSize() {
     if (bufferSize < 0) {
-      throw new RuntimeException("Unexpected value of " + RssSparkConfig.RSS_WRITER_BUFFER_SIZE.key()
+      throw new RuntimeException("Unexpected value of " + RssSparkClientConf.RSS_WRITER_BUFFER_SIZE.key()
           + "=" + bufferSize);
     }
     if (bufferSpillThreshold < 0) {
-      throw new RuntimeException("Unexpected value of " + RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.key()
+      throw new RuntimeException("Unexpected value of " + RssSparkClientConf.RSS_WRITER_BUFFER_SPILL_SIZE.key()
           + "=" + bufferSpillThreshold);
     }
     if (bufferSegmentSize > bufferSize) {
-      LOG.warn(RssSparkConfig.RSS_WRITER_BUFFER_SEGMENT_SIZE.key() + "[" + bufferSegmentSize + "] should be less than "
-          + RssSparkConfig.RSS_WRITER_BUFFER_SIZE.key() + "[" + bufferSize + "]");
+      LOG.warn(RssSparkClientConf.RSS_WRITER_BUFFER_SEGMENT_SIZE.key()
+          + "[" + bufferSegmentSize + "] should be less than "
+          + RssSparkClientConf.RSS_WRITER_BUFFER_SIZE.key()
+          + "[" + bufferSize + "]");
     }
   }
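BufferManagerOptions resolves human-readable sizes such as "3k" and "128m" through getSizeAsBytes. A simplified parser showing the intended semantics — the real Spark/Uniffle implementations accept a richer grammar (e.g. "14mb", bare byte counts with suffix variants), so this is only an approximation:

import java.util.Locale;

// Minimal size-string parser illustrating what getSizeAsBytes("3k") etc. yield.
// Supports only a bare number or a single k/m/g suffix.
public class SizeParseSketch {
  static long sizeAsBytes(String s) {
    s = s.trim().toLowerCase(Locale.ROOT);
    char unit = s.charAt(s.length() - 1);
    if (Character.isDigit(unit)) {
      return Long.parseLong(s);
    }
    long base = Long.parseLong(s.substring(0, s.length() - 1));
    switch (unit) {
      case 'k': return base * 1024L;
      case 'm': return base * 1024L * 1024L;
      case 'g': return base * 1024L * 1024L * 1024L;
      default: throw new IllegalArgumentException("bad size: " + s);
    }
  }

  public static void main(String[] args) {
    System.out.println(sizeAsBytes("3k"));   // 3072
    System.out.println(sizeAsBytes("128m")); // 134217728
  }
}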
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/RssShuffleUtilsTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/RssShuffleUtilsTest.java
new file mode 100644
index 0000000000..77a004ba0a
--- /dev/null
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/RssShuffleUtilsTest.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle;
+
+import java.util.Map;
+
+import com.google.common.collect.Maps;
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.client.util.RssShuffleUtils;
+import org.apache.uniffle.common.config.ConfigOption;
+import org.apache.uniffle.common.config.RssClientConf;
+import org.apache.uniffle.storage.util.StorageType;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class RssShuffleUtilsTest {
+
+  @Test
+  public void applyDynamicClientConfTest() {
+    final RssClientConf conf = new RssClientConf();
+
+    Map<String, String> clientConf = Maps.newHashMap();
+
+    String remoteStoragePath = "hdfs://path1";
+    String mockKey = "spark.mockKey";
+    String mockValue = "v";
+
+    clientConf.put(RssSparkClientConf.RSS_REMOTE_STORAGE_PATH.key(), remoteStoragePath);
+    clientConf.put(RssSparkClientConf.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE_HDFS.name());
+    clientConf.put(mockKey, mockValue);
+
+    putWithDefaultVal(clientConf, RssSparkClientConf.RSS_CLIENT_TYPE);
+    putWithDefaultVal(clientConf, RssSparkClientConf.RSS_CLIENT_RETRY_MAX);
+    putWithDefaultVal(clientConf, RssSparkClientConf.RSS_CLIENT_RETRY_INTERVAL_MAX);
+    putWithDefaultVal(clientConf, RssSparkClientConf.RSS_DATA_REPLICA);
+    putWithDefaultVal(clientConf, RssSparkClientConf.RSS_DATA_REPLICA_WRITE);
+    putWithDefaultVal(clientConf, RssSparkClientConf.RSS_CLIENT_READ_BUFFER_SIZE);
+    putWithDefaultVal(clientConf, RssSparkClientConf.RSS_INDEX_READ_LIMIT);
+    putWithDefaultVal(clientConf, RssSparkClientConf.RSS_PARTITION_NUM_PER_RANGE);
+    putWithDefaultVal(clientConf, RssSparkClientConf.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS);
+    putWithDefaultVal(clientConf, RssSparkClientConf.RSS_CLIENT_SEND_CHECK_INTERVAL_MS);
+    putWithDefaultVal(clientConf, RssSparkClientConf.RSS_HEARTBEAT_INTERVAL);
+    putWithDefaultVal(clientConf, RssSparkClientConf.RSS_DATA_REPLICA_READ);
+
+    RssShuffleUtils.applyDynamicClientConf(RssSparkClientConf.RSS_MANDATORY_CLUSTER_CONF, conf, clientConf);
+
+    assertEquals(remoteStoragePath, conf.get(RssSparkClientConf.RSS_REMOTE_STORAGE_PATH));
+    equalsWithDefaultVal(conf, RssSparkClientConf.RSS_CLIENT_TYPE);
+    equalsWithDefaultVal(conf, RssSparkClientConf.RSS_CLIENT_RETRY_MAX);
+    equalsWithDefaultVal(conf, RssSparkClientConf.RSS_INDEX_READ_LIMIT);
+    equalsWithDefaultVal(conf, RssSparkClientConf.RSS_PARTITION_NUM_PER_RANGE);
+    equalsWithDefaultVal(conf, RssSparkClientConf.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS);
+    equalsWithDefaultVal(conf, RssSparkClientConf.RSS_CLIENT_SEND_CHECK_INTERVAL_MS);
+    equalsWithDefaultVal(conf, RssSparkClientConf.RSS_HEARTBEAT_INTERVAL);
+    equalsWithDefaultVal(conf, RssSparkClientConf.RSS_DATA_REPLICA_READ);
+    equalsWithDefaultVal(conf, RssSparkClientConf.RSS_DATA_REPLICA_WRITE);
+    equalsWithDefaultVal(conf, RssSparkClientConf.RSS_DATA_REPLICA);
+    equalsWithDefaultVal(conf, RssSparkClientConf.RSS_CLIENT_RETRY_INTERVAL_MAX);
+
+    assertEquals(StorageType.MEMORY_LOCALFILE_HDFS.name(), conf.get(RssSparkClientConf.RSS_STORAGE_TYPE));
+
+    assertEquals(mockValue, conf.getString(mockKey, ""));
+
+    String remoteStoragePath2 = "hdfs://path2";
+    clientConf = Maps.newHashMap();
+    clientConf.put(RssSparkClientConf.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_HDFS.name());
+    clientConf.put(RssSparkClientConf.RSS_REMOTE_STORAGE_PATH.key(), remoteStoragePath2);
+    clientConf.put(mockKey, "won't be rewrite");
+    clientConf.put(RssSparkClientConf.RSS_CLIENT_RETRY_MAX.key(), "99999");
+    RssShuffleUtils.applyDynamicClientConf(RssSparkClientConf.RSS_MANDATORY_CLUSTER_CONF, conf, clientConf);
+    // overwrite
+    assertEquals(remoteStoragePath2, conf.get(RssSparkClientConf.RSS_REMOTE_STORAGE_PATH));
+    assertEquals(StorageType.MEMORY_HDFS.name(), conf.get(RssSparkClientConf.RSS_STORAGE_TYPE));
+    // won't be overwrite
+    assertEquals(mockValue, conf.getString(mockKey, ""));
+    assertEquals(
+        RssSparkClientConf.RSS_CLIENT_RETRY_MAX.defaultValue(),
+        conf.get(RssSparkClientConf.RSS_CLIENT_RETRY_MAX)
+    );
+  }
+
+  private <T> void equalsWithDefaultVal(RssClientConf conf, ConfigOption<T> option) {
+    assertEquals(option.defaultValue(), conf.get(option));
+  }
+
+  private void putWithDefaultVal(Map<String, String> clientConf, ConfigOption<?> option) {
+    clientConf.put(option.key(), String.valueOf(option.defaultValue()));
+  }
+}
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/RssSparkClientConfTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/RssSparkClientConfTest.java
new file mode 100644
index 0000000000..b136ecd006
--- /dev/null
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/RssSparkClientConfTest.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle;
+
+import org.apache.spark.SparkConf;
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.common.config.RssConf;
+
+import static org.apache.spark.shuffle.RssSparkClientConf.RSS_WRITER_SERIALIZER_BUFFER_SIZE;
+import static org.apache.uniffle.common.config.RssClientConf.RSS_CLIENT_TYPE;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class RssSparkClientConfTest {
+
+  @Test
+  public void testInitializedFromSparkConf() {
+    SparkConf sparkConf = new SparkConf();
+    sparkConf.set("spark.rss.writer.serializer.buffer.size", "6k");
+    sparkConf.set("spark.rss.client.type", "NETTY");
+
+    RssConf rssConf = RssSparkClientConf.from(sparkConf);
+    assertEquals("6k", rssConf.get(RSS_WRITER_SERIALIZER_BUFFER_SIZE));
+    assertEquals("NETTY", rssConf.get(RSS_CLIENT_TYPE));
+  }
+}
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/RssSparkShuffleUtilsTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/RssSparkShuffleUtilsTest.java
index 09d1e63ead..62a5811228 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/RssSparkShuffleUtilsTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/RssSparkShuffleUtilsTest.java
@@ -17,19 +17,18 @@
 
 package org.apache.spark.shuffle;
 
+import java.util.Arrays;
 import java.util.Iterator;
-import java.util.Map;
 import java.util.Set;
 
-import com.google.common.collect.Maps;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.spark.SparkConf;
 import org.junit.jupiter.api.Test;
 
-import org.apache.uniffle.client.util.RssClientConfig;
+import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.util.Constants;
-import org.apache.uniffle.storage.util.StorageType;
 
+import static org.apache.spark.shuffle.RssSparkClientConf.toKey;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -38,7 +37,7 @@ public class RssSparkShuffleUtilsTest {
 
   @Test
   public void testAssignmentTags() {
-    SparkConf conf = new SparkConf();
+    RssConf conf = new RssConf();
 
     /**
      * Case1: dont set the tag implicitly and will return the {@code Constants.SHUFFLE_SERVER_VERSION}
@@ -50,7 +49,7 @@
      * Case2: set the multiple tags implicitly and will return the {@code Constants.SHUFFLE_SERVER_VERSION}
      * and configured tags.
      */
-    conf.set(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_TAGS.key(), " a,b");
+    conf.set(RssSparkClientConf.RSS_CLIENT_ASSIGNMENT_TAGS, Arrays.asList("a", "b"));
     tags = RssSparkShuffleUtils.getAssignmentTags(conf);
     assertEquals(3, tags.size());
     Iterator<String> iterator = tags.iterator();
@@ -66,96 +65,22 @@ public void odfsConfigurationTest() {
     assertFalse(conf1.getBoolean("dfs.namenode.odfs.enable", false));
     assertEquals("org.apache.hadoop.fs.Hdfs", conf1.get("fs.AbstractFileSystem.hdfs.impl"));
 
-    conf.set(RssSparkConfig.RSS_OZONE_DFS_NAMENODE_ODFS_ENABLE.key(), "true");
+    conf.set(toKey(RssSparkClientConf.RSS_OZONE_DFS_NAMENODE_ODFS_ENABLE), "true");
     conf1 = RssSparkShuffleUtils.newHadoopConfiguration(conf);
     assertTrue(conf1.getBoolean("dfs.namenode.odfs.enable", false));
     assertEquals("org.apache.hadoop.odfs.HdfsOdfsFilesystem", conf1.get("fs.hdfs.impl"));
     assertEquals("org.apache.hadoop.odfs.HdfsOdfs", conf1.get("fs.AbstractFileSystem.hdfs.impl"));
 
-    conf.set(RssSparkConfig.RSS_OZONE_FS_HDFS_IMPL.key(), "expect_odfs_impl");
-    conf.set(RssSparkConfig.RSS_OZONE_FS_ABSTRACT_FILE_SYSTEM_HDFS_IMPL.key(), "expect_odfs_abstract_impl");
+    conf.set(
+        toKey(RssSparkClientConf.RSS_OZONE_FS_HDFS_IMPL),
+        "expect_odfs_impl"
+    );
+    conf.set(
+        toKey(RssSparkClientConf.RSS_OZONE_FS_ABSTRACT_FILE_SYSTEM_HDFS_IMPL),
+        "expect_odfs_abstract_impl"
+    );
     conf1 = RssSparkShuffleUtils.newHadoopConfiguration(conf);
     assertEquals("expect_odfs_impl", conf1.get("fs.hdfs.impl"));
     assertEquals("expect_odfs_abstract_impl", conf1.get("fs.AbstractFileSystem.hdfs.impl"));
   }
-
-  @Test
-  public void applyDynamicClientConfTest() {
-    final SparkConf conf = new SparkConf();
-    Map<String, String> clientConf = Maps.newHashMap();
-    String remoteStoragePath = "hdfs://path1";
-    String mockKey = "spark.mockKey";
-    String mockValue = "v";
-
-    clientConf.put(RssClientConfig.RSS_REMOTE_STORAGE_PATH, remoteStoragePath);
-    clientConf.put(RssClientConfig.RSS_CLIENT_TYPE, RssClientConfig.RSS_CLIENT_TYPE_DEFAULT_VALUE);
-    clientConf.put(RssClientConfig.RSS_CLIENT_RETRY_MAX,
-        Integer.toString(RssClientConfig.RSS_CLIENT_RETRY_MAX_DEFAULT_VALUE));
-    clientConf.put(RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX,
-        Long.toString(RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX_DEFAULT_VALUE));
-    clientConf.put(RssClientConfig.RSS_DATA_REPLICA,
-        Integer.toString(RssClientConfig.RSS_DATA_REPLICA_DEFAULT_VALUE));
-    clientConf.put(RssClientConfig.RSS_DATA_REPLICA_WRITE,
-        Integer.toString(RssClientConfig.RSS_DATA_REPLICA_WRITE_DEFAULT_VALUE));
-    clientConf.put(RssClientConfig.RSS_DATA_REPLICA_READ,
-        Integer.toString(RssClientConfig.RSS_DATA_REPLICA_READ_DEFAULT_VALUE));
-    clientConf.put(RssClientConfig.RSS_HEARTBEAT_INTERVAL,
-        Long.toString(RssClientConfig.RSS_HEARTBEAT_INTERVAL_DEFAULT_VALUE));
-    clientConf.put(RssClientConfig.RSS_STORAGE_TYPE, StorageType.MEMORY_LOCALFILE_HDFS.name());
-    clientConf.put(RssClientConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS,
-        Long.toString(RssClientConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS_DEFAULT_VALUE));
-    clientConf.put(RssClientConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS,
-        Long.toString(RssClientConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS_DEFAULT_VALUE));
-    clientConf.put(RssClientConfig.RSS_PARTITION_NUM_PER_RANGE,
-        Integer.toString(RssClientConfig.RSS_PARTITION_NUM_PER_RANGE_DEFAULT_VALUE));
-    clientConf.put(RssClientConfig.RSS_INDEX_READ_LIMIT,
-        Integer.toString(RssClientConfig.RSS_INDEX_READ_LIMIT_DEFAULT_VALUE));
-    clientConf.put(RssClientConfig.RSS_CLIENT_READ_BUFFER_SIZE,
-        RssClientConfig.RSS_CLIENT_READ_BUFFER_SIZE_DEFAULT_VALUE);
-    clientConf.put(mockKey, mockValue);
-
-    RssSparkShuffleUtils.applyDynamicClientConf(conf, clientConf);
-    assertEquals(remoteStoragePath, conf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key()));
-    assertEquals(RssClientConfig.RSS_CLIENT_TYPE_DEFAULT_VALUE,
-        conf.get(RssSparkConfig.RSS_CLIENT_TYPE.key()));
-    assertEquals(Integer.toString(RssClientConfig.RSS_CLIENT_RETRY_MAX_DEFAULT_VALUE),
-        conf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX.key()));
-    assertEquals(Long.toString(RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX_DEFAULT_VALUE),
-        conf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX.key()));
-    assertEquals(Integer.toString(RssClientConfig.RSS_DATA_REPLICA_DEFAULT_VALUE),
-        conf.get(RssSparkConfig.RSS_DATA_REPLICA.key()));
-    assertEquals(Integer.toString(RssClientConfig.RSS_DATA_REPLICA_WRITE_DEFAULT_VALUE),
-        conf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE.key()));
-    assertEquals(Integer.toString(RssClientConfig.RSS_DATA_REPLICA_READ_DEFAULT_VALUE),
-        conf.get(RssSparkConfig.RSS_DATA_REPLICA_READ.key()));
-    assertEquals(Long.toString(RssClientConfig.RSS_HEARTBEAT_INTERVAL_DEFAULT_VALUE),
-        conf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL.key()));
-    assertEquals(StorageType.MEMORY_LOCALFILE_HDFS.name(), conf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
-    assertEquals(Long.toString(RssClientConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS_DEFAULT_VALUE),
-        conf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key()));
-    assertEquals(Long.toString(RssClientConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS_DEFAULT_VALUE),
-        conf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS.key()));
-    assertEquals(Integer.toString(RssClientConfig.RSS_PARTITION_NUM_PER_RANGE_DEFAULT_VALUE),
-        conf.get(RssSparkConfig.RSS_PARTITION_NUM_PER_RANGE.key()));
-    assertEquals(Integer.toString(RssClientConfig.RSS_INDEX_READ_LIMIT_DEFAULT_VALUE),
-        conf.get(RssSparkConfig.RSS_INDEX_READ_LIMIT.key()));
-    assertEquals(RssClientConfig.RSS_CLIENT_READ_BUFFER_SIZE_DEFAULT_VALUE,
-        conf.get(RssSparkConfig.RSS_CLIENT_READ_BUFFER_SIZE.key()));
-    assertEquals(mockValue, conf.get(mockKey));
-
-    String remoteStoragePath2 = "hdfs://path2";
-    clientConf = Maps.newHashMap();
-    clientConf.put(RssClientConfig.RSS_STORAGE_TYPE, StorageType.MEMORY_HDFS.name());
-    clientConf.put(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), remoteStoragePath2);
-    clientConf.put(mockKey, "won't be rewrite");
-    clientConf.put(RssClientConfig.RSS_CLIENT_RETRY_MAX, "99999");
-    RssSparkShuffleUtils.applyDynamicClientConf(conf, clientConf);
-    // overwrite
-    assertEquals(remoteStoragePath2, conf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key()));
-    assertEquals(StorageType.MEMORY_HDFS.name(), conf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
-    // won't be overwrite
-    assertEquals(mockValue, conf.get(mockKey));
-    assertEquals(Integer.toString(RssClientConfig.RSS_CLIENT_RETRY_MAX_DEFAULT_VALUE),
-        conf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX.key()));
-  }
 }
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
index 665f5d2d6d..654f632b0c 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
@@ -25,11 +25,12 @@
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.serializer.KryoSerializer;
 import org.apache.spark.serializer.Serializer;
-import org.apache.spark.shuffle.RssSparkConfig;
+import org.apache.spark.shuffle.RssSparkClientConf;
 import org.junit.jupiter.api.Test;
 
 import org.apache.uniffle.common.ShuffleBlockInfo;
 
+import static org.apache.spark.shuffle.RssSparkClientConf.toKey;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.mockito.ArgumentMatchers.anyLong;
@@ -44,7 +45,7 @@ private WriteBufferManager createManager(SparkConf conf) {
     Serializer kryoSerializer = new KryoSerializer(conf);
     TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
 
-    BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
+    BufferManagerOptions bufferOptions = new BufferManagerOptions(RssSparkClientConf.from(conf));
     WriteBufferManager wbm = new WriteBufferManager(
         0, 0, bufferOptions, kryoSerializer, Maps.newHashMap(), mockTaskMemoryManager,
         new ShuffleWriteMetrics());
@@ -55,11 +56,11 @@ private WriteBufferManager createManager(SparkConf conf) {
 
   private SparkConf getConf() {
     SparkConf conf = new SparkConf(false);
-    conf.set(RssSparkConfig.RSS_WRITER_BUFFER_SIZE.key(), "64")
-        .set(RssSparkConfig.RSS_WRITER_BUFFER_SEGMENT_SIZE.key(), "32")
-        .set(RssSparkConfig.RSS_WRITER_SERIALIZER_BUFFER_SIZE.key(), "128")
-        .set(RssSparkConfig.RSS_WRITER_PRE_ALLOCATED_BUFFER_SIZE.key(), "512")
-        .set(RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.key(), "190");
+    conf.set(toKey(RssSparkClientConf.RSS_WRITER_BUFFER_SIZE), "64")
+        .set(toKey(RssSparkClientConf.RSS_WRITER_BUFFER_SEGMENT_SIZE), "32")
+        .set(toKey(RssSparkClientConf.RSS_WRITER_SERIALIZER_BUFFER_SIZE), "128")
+        .set(toKey(RssSparkClientConf.RSS_WRITER_PRE_ALLOCATED_BUFFER_SIZE), "512")
+        .set(toKey(RssSparkClientConf.RSS_WRITER_BUFFER_SPILL_SIZE), "190");
     return conf;
   }
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java
index 7c76e200d0..22486bf3d3 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java
@@ -36,6 +36,8 @@
 import org.apache.uniffle.common.util.Constants;
 import org.apache.uniffle.common.util.RetryUtils;
 
+import static org.apache.spark.shuffle.RssSparkClientConf.toKey;
+
 public class DelegationRssShuffleManager implements ShuffleManager {
 
   private static final Logger LOG = LoggerFactory.getLogger(DelegationRssShuffleManager.class);
@@ -44,16 +46,19 @@ public class DelegationRssShuffleManager implements ShuffleManager {
   private final List<CoordinatorClient> coordinatorClients;
   private final int accessTimeoutMs;
   private final SparkConf sparkConf;
+  private final RssSparkClientConf rssSparkClientConf;
 
   public DelegationRssShuffleManager(SparkConf sparkConf, boolean isDriver) throws Exception {
     this.sparkConf = sparkConf;
-    accessTimeoutMs = sparkConf.get(RssSparkConfig.RSS_ACCESS_TIMEOUT_MS);
+    this.rssSparkClientConf = RssSparkClientConf.from(sparkConf);
+
+    this.accessTimeoutMs = rssSparkClientConf.get(RssSparkClientConf.RSS_ACCESS_TIMEOUT_MS);
     if (isDriver) {
-      coordinatorClients = RssSparkShuffleUtils.createCoordinatorClients(sparkConf);
-      delegate = createShuffleManagerInDriver();
+      this.coordinatorClients = RssSparkShuffleUtils.createCoordinatorClients(sparkConf);
+      this.delegate = createShuffleManagerInDriver();
     } else {
-      coordinatorClients = Lists.newArrayList();
-      delegate = createShuffleManagerInExecutor();
+      this.coordinatorClients = Lists.newArrayList();
+      this.delegate = createShuffleManagerInExecutor();
     }
 
     if (delegate == null) {
@@ -68,8 +73,14 @@ private ShuffleManager createShuffleManagerInDriver() throws RssException {
     if (canAccess) {
       try {
         shuffleManager = new RssShuffleManager(sparkConf, true);
-        sparkConf.set(RssSparkConfig.RSS_ENABLED.key(), "true");
-        sparkConf.set("spark.shuffle.manager", RssShuffleManager.class.getCanonicalName());
+        sparkConf.set(
+            toKey(RssSparkClientConf.RSS_ENABLED),
+            "true"
+        );
+        sparkConf.set(
+            "spark.shuffle.manager",
+            RssShuffleManager.class.getCanonicalName()
+        );
         LOG.info("Use RssShuffleManager");
         return shuffleManager;
       } catch (Exception exception) {
@@ -79,7 +90,7 @@ private ShuffleManager createShuffleManagerInDriver() throws RssException {
     try {
       shuffleManager = RssSparkShuffleUtils.loadShuffleManager(
          Constants.SORT_SHUFFLE_MANAGER_NAME, sparkConf, true);
-      sparkConf.set(RssSparkConfig.RSS_ENABLED.key(), "false");
+      sparkConf.set(toKey(RssSparkClientConf.RSS_ENABLED), "false");
       sparkConf.set("spark.shuffle.manager", "sort");
       LOG.info("Use SortShuffleManager");
     } catch (Exception e) {
@@ -90,17 +101,16 @@ private ShuffleManager createShuffleManagerInDriver() throws RssException {
   }
 
   private boolean tryAccessCluster() {
-    String accessId = sparkConf.get(
-        RssSparkConfig.RSS_ACCESS_ID.key(), "").trim();
+    String accessId = rssSparkClientConf.getString(RssSparkClientConf.RSS_ACCESS_ID.key(), "").trim();
     if (StringUtils.isEmpty(accessId)) {
       LOG.warn("Access id key is empty");
       return false;
     }
-    long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ACCESS_RETRY_INTERVAL_MS);
-    int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ACCESS_RETRY_TIMES);
+    long retryInterval = rssSparkClientConf.get(RssSparkClientConf.RSS_CLIENT_ACCESS_RETRY_INTERVAL_MS);
+    int retryTimes = rssSparkClientConf.get(RssSparkClientConf.RSS_CLIENT_ACCESS_RETRY_TIMES);
 
     for (CoordinatorClient coordinatorClient : coordinatorClients) {
-      Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
+      Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(rssSparkClientConf);
       boolean canAccess;
       try {
         canAccess = RetryUtils.retry(() -> {
@@ -130,7 +140,7 @@ private ShuffleManager createShuffleManagerInExecutor() throws RssException {
     ShuffleManager shuffleManager;
     // get useRSS from spark conf
-    boolean useRSS = sparkConf.get(RssSparkConfig.RSS_ENABLED);
+    boolean useRSS = rssSparkClientConf.get(RssSparkClientConf.RSS_ENABLED);
     if (useRSS) {
       // Executor will not do any fallback
       shuffleManager = new RssShuffleManager(sparkConf, false);
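createShuffleManagerInDriver above tries RSS first and falls back to sort-based shuffle, recording the outcome in spark.shuffle.manager and the rss.enabled flag so executors follow the driver's decision. That control flow reduced to plain Java (RssManager/SortManager are placeholders for the real ShuffleManagers; the exception path is assumed, not copied from the patch):

import java.util.HashMap;
import java.util.Map;

// Driver-side decision flow of DelegationRssShuffleManager, in miniature.
public class DriverFallbackSketch {
  interface Manager {}
  static class RssManager implements Manager {}
  static class SortManager implements Manager {}

  static Manager createInDriver(boolean canAccessCluster, Map<String, String> sparkConf) {
    if (canAccessCluster) {
      try {
        Manager m = new RssManager();
        sparkConf.put("spark.rss.enabled", "true");   // executors read this flag
        sparkConf.put("spark.shuffle.manager", RssManager.class.getName());
        return m;
      } catch (RuntimeException e) {
        // fall through to sort-based shuffle
      }
    }
    sparkConf.put("spark.rss.enabled", "false");
    sparkConf.put("spark.shuffle.manager", "sort");
    return new SortManager();
  }

  public static void main(String[] args) {
    Map<String, String> conf = new HashMap<>();
    System.out.println(createInDriver(false, conf).getClass().getSimpleName() + " " + conf);
  }
}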
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 0d0ca539ad..d6dd3ba8f6 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -59,16 +59,20 @@
 import org.apache.uniffle.client.factory.ShuffleClientFactory;
 import org.apache.uniffle.client.response.SendShuffleDataResult;
 import org.apache.uniffle.client.util.ClientUtils;
+import org.apache.uniffle.client.util.RssShuffleUtils;
 import org.apache.uniffle.common.PartitionRange;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleAssignmentsInfo;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssClientConf;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.util.RetryUtils;
 import org.apache.uniffle.common.util.RssUtils;
 import org.apache.uniffle.common.util.ThreadUtils;
 
+import static org.apache.spark.shuffle.RssSparkClientConf.RSS_MANDATORY_CLUSTER_CONF;
+
 public class RssShuffleManager implements ShuffleManager {
 
   private static final Logger LOG = LoggerFactory.getLogger(RssShuffleManager.class);
@@ -77,7 +81,8 @@ public class RssShuffleManager implements ShuffleManager {
   private final long heartbeatTimeout;
   private final ThreadPoolExecutor threadPoolExecutor;
   private AtomicReference<String> id = new AtomicReference<>();
-  private SparkConf sparkConf;
+  private final SparkConf sparkConf;
+  private RssSparkClientConf rssSparkClientConf;
   private final int dataReplica;
   private final int dataReplicaWrite;
   private final int dataReplicaRead;
@@ -140,25 +145,27 @@ private synchronized void putBlockId(
 
   public RssShuffleManager(SparkConf conf, boolean isDriver) {
     this.sparkConf = conf;
+    this.rssSparkClientConf = RssSparkClientConf.from(sparkConf);
 
     // set & check replica config
-    this.dataReplica = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA);
-    this.dataReplicaWrite = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE);
-    this.dataReplicaRead = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_READ);
-    this.dataReplicaSkipEnabled = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_SKIP_ENABLED);
+    this.dataReplica = rssSparkClientConf.get(RssClientConf.RSS_DATA_REPLICA);
+    this.dataReplicaWrite = rssSparkClientConf.get(RssClientConf.RSS_DATA_REPLICA_WRITE);
+    this.dataReplicaRead = rssSparkClientConf.get(RssClientConf.RSS_DATA_REPLICA_READ);
+    this.dataReplicaSkipEnabled = rssSparkClientConf.get(RssClientConf.RSS_DATA_REPLICA_SKIP_ENABLED);
     LOG.info("Check quorum config ["
         + dataReplica + ":" + dataReplicaWrite + ":" + dataReplicaRead + ":" + dataReplicaSkipEnabled + "]");
     RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite, dataReplicaRead);
 
-    this.heartbeatInterval = sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL);
-    this.heartbeatTimeout = sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(), heartbeatInterval / 2);
-    final int retryMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
-    this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
-    this.dynamicConfEnabled = sparkConf.get(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED);
-    long retryIntervalMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
-    int heartBeatThreadNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
-    this.dataTransferPoolSize = sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE);
-    this.dataCommitPoolSize = sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE);
+    this.heartbeatInterval = rssSparkClientConf.get(RssClientConf.RSS_HEARTBEAT_INTERVAL);
+    this.heartbeatTimeout = rssSparkClientConf
+        .getLong(RssClientConf.RSS_HEARTBEAT_TIMEOUT.key(), heartbeatInterval / 2);
+    final int retryMax = rssSparkClientConf.get(RssClientConf.RSS_CLIENT_RETRY_MAX);
+    this.clientType = rssSparkClientConf.get(RssClientConf.RSS_CLIENT_TYPE);
+    this.dynamicConfEnabled = rssSparkClientConf.get(RssClientConf.RSS_DYNAMIC_CLIENT_CONF_ENABLED);
+    long retryIntervalMax = rssSparkClientConf.get(RssClientConf.RSS_CLIENT_RETRY_INTERVAL_MAX);
+    int heartBeatThreadNum = rssSparkClientConf.get(RssSparkClientConf.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
+    this.dataTransferPoolSize = rssSparkClientConf.get(RssClientConf.RSS_DATA_TRANSFER_POOL_SIZE);
+    this.dataCommitPoolSize = rssSparkClientConf.get(RssClientConf.RSS_DATA_COMMIT_POOL_SIZE);
 
     shuffleWriteClient = ShuffleClientFactory
         .getInstance()
@@ -169,11 +176,11 @@ public RssShuffleManager(SparkConf conf, boolean isDriver) {
     // fetch client conf and apply them if necessary and disable ESS
     if (isDriver && dynamicConfEnabled) {
       Map<String, String> clusterClientConf = shuffleWriteClient.fetchClientConf(
-          sparkConf.getInt(RssSparkConfig.RSS_ACCESS_TIMEOUT_MS.key(),
-              RssSparkConfig.RSS_ACCESS_TIMEOUT_MS.defaultValue().get()));
-      RssSparkShuffleUtils.applyDynamicClientConf(sparkConf, clusterClientConf);
+          rssSparkClientConf.get(RssClientConf.RSS_ACCESS_TIMEOUT_MS)
+      );
+      RssShuffleUtils.applyDynamicClientConf(RSS_MANDATORY_CLUSTER_CONF, rssSparkClientConf, clusterClientConf);
     }
-    RssSparkShuffleUtils.validateRssClientConf(sparkConf);
+    RssSparkShuffleUtils.validateRssClientConf(rssSparkClientConf);
     // External shuffle service is not supported when using remote shuffle service
     sparkConf.set("spark.shuffle.service.enabled", "false");
     LOG.info("Disable external shuffle service in RssShuffleManager.");
@@ -183,8 +190,8 @@ public RssShuffleManager(SparkConf conf, boolean isDriver) {
       LOG.info("RSS data send thread is starting");
       eventLoop = defaultEventLoop;
       eventLoop.start();
-      int poolSize = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE);
-      int keepAliveTime = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE);
+      int poolSize = rssSparkClientConf.get(RssSparkClientConf.RSS_CLIENT_SEND_THREAD_POOL_SIZE);
+      int keepAliveTime = rssSparkClientConf.get(RssSparkClientConf.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE);
       threadPoolExecutor = new ThreadPoolExecutor(poolSize, poolSize * 2, keepAliveTime, TimeUnit.SECONDS,
           Queues.newLinkedBlockingQueue(Integer.MAX_VALUE),
           ThreadUtils.getThreadFactory("SendData-%d"));
@@ -203,22 +210,24 @@ public RssShuffleManager(SparkConf conf, boolean isDriver) {
       Map<String, Set<Long>> taskToSuccessBlockIds,
       Map<String, Set<Long>> taskToFailedBlockIds) {
     this.sparkConf = conf;
-    this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
-    this.heartbeatInterval = sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL);
-    this.heartbeatTimeout = sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(), heartbeatInterval / 2);
-    this.dataReplica = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA);
-    this.dataReplicaWrite = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE);
-    this.dataReplicaRead = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_READ);
-    this.dataReplicaSkipEnabled = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_SKIP_ENABLED);
+    this.rssSparkClientConf = RssSparkClientConf.from(sparkConf);
+    this.clientType = rssSparkClientConf.get(RssSparkClientConf.RSS_CLIENT_TYPE);
+    this.heartbeatInterval = rssSparkClientConf.get(RssSparkClientConf.RSS_HEARTBEAT_INTERVAL);
+    this.heartbeatTimeout = rssSparkClientConf.getLong(RssSparkClientConf.RSS_HEARTBEAT_TIMEOUT.key(),
+        heartbeatInterval / 2);
+    this.dataReplica = rssSparkClientConf.get(RssSparkClientConf.RSS_DATA_REPLICA);
+    this.dataReplicaWrite = rssSparkClientConf.get(RssSparkClientConf.RSS_DATA_REPLICA_WRITE);
+    this.dataReplicaRead = rssSparkClientConf.get(RssSparkClientConf.RSS_DATA_REPLICA_READ);
+    this.dataReplicaSkipEnabled = rssSparkClientConf.get(RssSparkClientConf.RSS_DATA_REPLICA_SKIP_ENABLED);
     LOG.info("Check quorum config ["
         + dataReplica + ":" + dataReplicaWrite + ":" + dataReplicaRead + ":" + dataReplicaSkipEnabled + "]");
     RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite, dataReplicaRead);
 
-    int retryMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
-    long retryIntervalMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
-    int heartBeatThreadNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
-    this.dataTransferPoolSize = sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE);
-    this.dataCommitPoolSize = sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE);
+    int retryMax = rssSparkClientConf.get(RssSparkClientConf.RSS_CLIENT_RETRY_MAX);
+    long retryIntervalMax = rssSparkClientConf.get(RssSparkClientConf.RSS_CLIENT_RETRY_INTERVAL_MAX);
+    int heartBeatThreadNum = rssSparkClientConf.get(RssSparkClientConf.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
+    this.dataTransferPoolSize = rssSparkClientConf.get(RssSparkClientConf.RSS_DATA_TRANSFER_POOL_SIZE);
+    this.dataCommitPoolSize = rssSparkClientConf.get(RssSparkClientConf.RSS_DATA_COMMIT_POOL_SIZE);
 
     shuffleWriteClient = ShuffleClientFactory
         .getInstance()
@@ -251,19 +260,20 @@ public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, ShuffleDependency<
     }
     LOG.info("Generate application id used in rss: " + id.get());
 
-    String storageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
+    String storageType = rssSparkClientConf.get(RssSparkClientConf.RSS_STORAGE_TYPE);
     remoteStorage = new RemoteStorageInfo(
-        sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), ""));
+        rssSparkClientConf.getString(RssSparkClientConf.RSS_REMOTE_STORAGE_PATH.key(), ""));
     remoteStorage = ClientUtils.fetchRemoteStorage(
         id.get(), remoteStorage, dynamicConfEnabled, storageType, shuffleWriteClient);
 
-    Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
+    Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(rssSparkClientConf);
 
-    int requiredShuffleServerNumber = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER);
+    int requiredShuffleServerNumber = rssSparkClientConf
+        .get(RssSparkClientConf.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER);
     // retryInterval must bigger than `rss.server.heartbeat.timeout`, or maybe it will return the same result
-    long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
-    int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
+    long retryInterval = rssSparkClientConf.get(RssSparkClientConf.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
+    int retryTimes = rssSparkClientConf.get(RssSparkClientConf.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
     Map<Integer, List<ShuffleServerInfo>> partitionToServers;
     try {
       partitionToServers = RetryUtils.retry(() -> {
@@ -308,7 +318,7 @@ public <K, V> ShuffleWriter<K, V> getWriter(
     }
     int shuffleId = rssHandle.getShuffleId();
     String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
-    BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
+    BufferManagerOptions bufferOptions = new BufferManagerOptions(rssSparkClientConf);
     ShuffleWriteMetrics writeMetrics;
     if (metrics != null) {
       writeMetrics = new WriteMetrics(metrics);
@@ -394,14 +404,16 @@ public <K, C> ShuffleReader<K, C> getReaderImpl(
     if (!(handle instanceof RssShuffleHandle)) {
       throw new RuntimeException("Unexpected ShuffleHandle:" + handle.getClass().getName());
     }
-    final String storageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
-    final int indexReadLimit = sparkConf.get(RssSparkConfig.RSS_INDEX_READ_LIMIT);
+    final String storageType = rssSparkClientConf.get(RssSparkClientConf.RSS_STORAGE_TYPE);
+    final int indexReadLimit = rssSparkClientConf.get(RssSparkClientConf.RSS_INDEX_READ_LIMIT);
     RssShuffleHandle rssShuffleHandle = (RssShuffleHandle) handle;
     final int partitionNum = rssShuffleHandle.getDependency().partitioner().numPartitions();
-    long readBufferSize = sparkConf.getSizeAsBytes(RssSparkConfig.RSS_CLIENT_READ_BUFFER_SIZE.key(),
-        RssSparkConfig.RSS_CLIENT_READ_BUFFER_SIZE.defaultValue().get());
+    long readBufferSize = rssSparkClientConf.getSizeAsBytes(
+        RssSparkClientConf.RSS_CLIENT_READ_BUFFER_SIZE.key(),
+        RssSparkClientConf.RSS_CLIENT_READ_BUFFER_SIZE.defaultValue()
+    );
     if (readBufferSize > Integer.MAX_VALUE) {
-      LOG.warn(RssSparkConfig.RSS_CLIENT_READ_BUFFER_SIZE.key() + " can support 2g as max");
+      LOG.warn(RssSparkClientConf.RSS_CLIENT_READ_BUFFER_SIZE.key() + " can support 2g as max");
       readBufferSize = Integer.MAX_VALUE;
     }
     int shuffleId = rssShuffleHandle.getShuffleId();
@@ -589,7 +601,7 @@ protected void registerShuffleServers(
 
   @VisibleForTesting
   protected void registerCoordinator() {
-    String coordinators = sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key());
+    String coordinators = rssSparkClientConf.get(RssSparkClientConf.RSS_COORDINATOR_QUORUM);
     LOG.info("Start Registering coordinators {}", coordinators);
     shuffleWriteClient.registerCoordinators(coordinators);
   }
rssSparkClientConf.get(RssSparkClientConf.RSS_INDEX_READ_LIMIT); RssShuffleHandle rssShuffleHandle = (RssShuffleHandle) handle; final int partitionNum = rssShuffleHandle.getDependency().partitioner().numPartitions(); - long readBufferSize = sparkConf.getSizeAsBytes(RssSparkConfig.RSS_CLIENT_READ_BUFFER_SIZE.key(), - RssSparkConfig.RSS_CLIENT_READ_BUFFER_SIZE.defaultValue().get()); + long readBufferSize = rssSparkClientConf.getSizeAsBytes( + RssSparkClientConf.RSS_CLIENT_READ_BUFFER_SIZE.key(), + RssSparkClientConf.RSS_CLIENT_READ_BUFFER_SIZE.defaultValue() + ); if (readBufferSize > Integer.MAX_VALUE) { - LOG.warn(RssSparkConfig.RSS_CLIENT_READ_BUFFER_SIZE.key() + " can support 2g as max"); + LOG.warn(RssSparkClientConf.RSS_CLIENT_READ_BUFFER_SIZE.key() + " supports at most 2g"); readBufferSize = Integer.MAX_VALUE; } int shuffleId = rssShuffleHandle.getShuffleId(); @@ -589,7 +601,7 @@ protected void registerShuffleServers( @VisibleForTesting protected void registerCoordinator() { - String coordinators = sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key()); + String coordinators = rssSparkClientConf.get(RssSparkClientConf.RSS_COORDINATOR_QUORUM); LOG.info("Start Registering coordinators {}", coordinators); shuffleWriteClient.registerCoordinators(coordinators); } diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java index dbebc338f6..b1de59c082 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java @@ -39,7 +39,7 @@ import org.apache.spark.scheduler.MapStatus; import org.apache.spark.shuffle.RssShuffleHandle; import org.apache.spark.shuffle.RssShuffleManager; -import org.apache.spark.shuffle.RssSparkConfig; +import org.apache.spark.shuffle.RssSparkClientConf; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.BlockManagerId; import org.slf4j.Logger; @@ -95,6 +95,7 @@ public RssShuffleWriter( ShuffleWriteClient shuffleWriteClient, RssShuffleHandle rssHandle) { LOG.warn("RssShuffle start write taskAttemptId data" + taskAttemptId); + this.shuffleManager = shuffleManager; this.appId = appId; this.bufferManager = bufferManager; @@ -106,11 +107,15 @@ public RssShuffleWriter( this.shuffleDependency = rssHandle.getDependency(); this.partitioner = shuffleDependency.partitioner(); this.shouldPartition = partitioner.numPartitions() > 1; - this.sendCheckTimeout = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS); - this.sendCheckInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS); - this.sendSizeLimit = sparkConf.getSizeAsBytes(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMIT.key(), - RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMIT.defaultValue().get()); - this.bitmapSplitNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM); + + RssSparkClientConf rssSparkClientConf = RssSparkClientConf.from(sparkConf); + this.sendCheckTimeout = rssSparkClientConf.get(RssSparkClientConf.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS); + this.sendCheckInterval = rssSparkClientConf.get(RssSparkClientConf.RSS_CLIENT_SEND_CHECK_INTERVAL_MS); + this.sendSizeLimit = rssSparkClientConf.getSizeAsBytes( + RssSparkClientConf.RSS_CLIENT_SEND_SIZE_LIMIT.key(), + RssSparkClientConf.RSS_CLIENT_SEND_SIZE_LIMIT.defaultValue() + ); + this.bitmapSplitNum =
rssSparkClientConf.get(RssSparkClientConf.RSS_CLIENT_BITMAP_SPLIT_NUM); this.partitionToBlockIds = Maps.newConcurrentMap(); this.shuffleWriteClient = shuffleWriteClient; this.shuffleServersForData = rssHandle.getShuffleServersForData(); @@ -118,7 +123,7 @@ public RssShuffleWriter( Arrays.fill(partitionLengths, 0); partitionToServers = rssHandle.getPartitionToServers(); this.isMemoryShuffleEnabled = isMemoryShuffleEnabled( - sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key())); + rssSparkClientConf.get(RssSparkClientConf.RSS_STORAGE_TYPE)); } private boolean isMemoryShuffleEnabled(String storageType) { diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/DelegationRssShuffleManagerTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/DelegationRssShuffleManagerTest.java index 222dbc00f0..fb8d6823f7 100644 --- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/DelegationRssShuffleManagerTest.java +++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/DelegationRssShuffleManagerTest.java @@ -33,6 +33,7 @@ import org.apache.uniffle.client.response.RssAccessClusterResponse; import org.apache.uniffle.storage.util.StorageType; +import static org.apache.spark.shuffle.RssSparkClientConf.SPARK_CONFIG_KEY_PREFIX; import static org.apache.uniffle.client.response.ResponseStatusCode.ACCESS_DENIED; import static org.apache.uniffle.client.response.ResponseStatusCode.SUCCESS; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -65,7 +66,7 @@ public void testCreateInDriverDenied() throws Exception { mockedStaticRssShuffleUtils.when(() -> RssSparkShuffleUtils.createCoordinatorClients(any())).thenReturn(coordinatorClients); SparkConf conf = new SparkConf(); - conf.set(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED.key(), "false"); + conf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_DYNAMIC_CLIENT_CONF_ENABLED.key(), "false"); assertCreateSortShuffleManager(conf); } @@ -81,15 +82,15 @@ public void testCreateInDriver() throws Exception { SparkConf conf = new SparkConf(); assertCreateSortShuffleManager(conf); - conf.set(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED.key(), "false"); - conf.set(RssSparkConfig.RSS_ACCESS_ID.key(), "mockId"); + conf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_DYNAMIC_CLIENT_CONF_ENABLED.key(), "false"); + conf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_ACCESS_ID.key(), "mockId"); assertCreateSortShuffleManager(conf); - conf.set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), "m1:8001,m2:8002"); + conf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_COORDINATOR_QUORUM.key(), "m1:8001,m2:8002"); conf.set("spark.rss.storage.type", StorageType.LOCALFILE.name()); assertCreateRssShuffleManager(conf); conf = new SparkConf(); - conf.set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), "m1:8001,m2:8002"); + conf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_COORDINATOR_QUORUM.key(), "m1:8001,m2:8002"); when(mockCoordinatorClient.accessCluster(any())).thenReturn( new RssAccessClusterResponse(SUCCESS, "")); assertCreateSortShuffleManager(conf); @@ -99,7 +100,7 @@ public void testCreateInDriver() throws Exception { public void testCreateInExecutor() throws Exception { DelegationRssShuffleManager delegationRssShuffleManager; SparkConf conf = new SparkConf(); - conf.set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), "m1:8001,m2:8002"); + conf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_COORDINATOR_QUORUM.key(), "m1:8001,m2:8002"); delegationRssShuffleManager = new 
DelegationRssShuffleManager(conf, false); assertFalse(delegationRssShuffleManager.getDelegate() instanceof RssShuffleManager); assertTrue(delegationRssShuffleManager.getDelegate() instanceof SortShuffleManager); @@ -115,15 +116,15 @@ public void testCreateFallback() throws Exception { RssSparkShuffleUtils.createCoordinatorClients(any())).thenReturn(coordinatorClients); SparkConf conf = new SparkConf(); - conf.set(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED.key(), "false"); - conf.set(RssSparkConfig.RSS_ACCESS_ID.key(), "mockId"); - conf.set(RssSparkConfig.RSS_ENABLED.key(), "true"); + conf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_DYNAMIC_CLIENT_CONF_ENABLED.key(), "false"); + conf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_ACCESS_ID.key(), "mockId"); + conf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_ENABLED.key(), "true"); // fall back to SortShuffleManager in driver assertCreateSortShuffleManager(conf); // No fall back in executor - conf.set(RssSparkConfig.RSS_ENABLED.key(), "true"); + conf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_ENABLED.key(), "true"); boolean hasException = false; try { new DelegationRssShuffleManager(conf, false); @@ -146,11 +147,11 @@ public void testTryAccessCluster() throws Exception { mockedStaticRssShuffleUtils.when(() -> RssSparkShuffleUtils.createCoordinatorClients(any())).thenReturn(coordinatorClients); SparkConf conf = new SparkConf(); - conf.set(RssSparkConfig.RSS_CLIENT_ACCESS_RETRY_INTERVAL_MS, 3000L); - conf.set(RssSparkConfig.RSS_CLIENT_ACCESS_RETRY_TIMES, 3); - conf.set(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED.key(), "false"); - conf.set(RssSparkConfig.RSS_ACCESS_ID.key(), "mockId"); - conf.set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), "m1:8001,m2:8002"); + conf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_CLIENT_ACCESS_RETRY_INTERVAL_MS.key(), "3000"); + conf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_CLIENT_ACCESS_RETRY_TIMES.key(), "3"); + conf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_DYNAMIC_CLIENT_CONF_ENABLED.key(), "false"); + conf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_ACCESS_ID.key(), "mockId"); + conf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_COORDINATOR_QUORUM.key(), "m1:8001,m2:8002"); conf.set("spark.rss.storage.type", StorageType.LOCALFILE.name()); assertCreateRssShuffleManager(conf); @@ -164,11 +165,11 @@ public void testTryAccessCluster() throws Exception { mockedStaticRssShuffleUtils.when(() -> RssSparkShuffleUtils.createCoordinatorClients(any())).thenReturn(secondCoordinatorClients); SparkConf secondConf = new SparkConf(); - secondConf.set(RssSparkConfig.RSS_CLIENT_ACCESS_RETRY_INTERVAL_MS, 3000L); - secondConf.set(RssSparkConfig.RSS_CLIENT_ACCESS_RETRY_TIMES, 3); - secondConf.set(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED.key(), "false"); - secondConf.set(RssSparkConfig.RSS_ACCESS_ID.key(), "mockId"); - secondConf.set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), "m1:8001,m2:8002"); + secondConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_CLIENT_ACCESS_RETRY_INTERVAL_MS.key(), "3000"); + secondConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_CLIENT_ACCESS_RETRY_TIMES.key(), "3"); + secondConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_DYNAMIC_CLIENT_CONF_ENABLED.key(), "false"); + secondConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_ACCESS_ID.key(), "mockId"); + secondConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_COORDINATOR_QUORUM.key(), "m1:8001,m2:8002"); 
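// --- Reviewer aside, not part of this patch: the tests above rely on the new
// key-prefix convention. A minimal sketch, assuming SPARK_CONFIG_KEY_PREFIX is
// "spark." and that RssSparkClientConf.from(sparkConf) copies "spark.rss.*"
// entries into the client conf with the prefix stripped (neither detail is
// shown in this hunk):
//
//   SparkConf sparkConf = new SparkConf();
//   sparkConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_COORDINATOR_QUORUM.key(), "m1:8001");
//   RssSparkClientConf rssConf = RssSparkClientConf.from(sparkConf);
//   String quorum = rssConf.get(RssSparkClientConf.RSS_COORDINATOR_QUORUM); // -> "m1:8001"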
secondConf.set("spark.rss.storage.type", StorageType.LOCALFILE.name()); assertCreateSortShuffleManager(secondConf); } @@ -177,7 +178,7 @@ private DelegationRssShuffleManager assertCreateSortShuffleManager(SparkConf con DelegationRssShuffleManager delegationRssShuffleManager = new DelegationRssShuffleManager(conf, true); assertTrue(delegationRssShuffleManager.getDelegate() instanceof SortShuffleManager); assertFalse(delegationRssShuffleManager.getDelegate() instanceof RssShuffleManager); - assertFalse(conf.getBoolean(RssSparkConfig.RSS_ENABLED.key(), false)); + assertFalse(conf.getBoolean(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_ENABLED.key(), false)); assertEquals("sort", conf.get("spark.shuffle.manager")); return delegationRssShuffleManager; } @@ -186,7 +187,7 @@ private DelegationRssShuffleManager assertCreateRssShuffleManager(SparkConf conf DelegationRssShuffleManager delegationRssShuffleManager = new DelegationRssShuffleManager(conf, true); assertFalse(delegationRssShuffleManager.getDelegate() instanceof SortShuffleManager); assertTrue(delegationRssShuffleManager.getDelegate() instanceof RssShuffleManager); - assertTrue(Boolean.parseBoolean(conf.get(RssSparkConfig.RSS_ENABLED.key()))); + assertTrue(Boolean.parseBoolean(conf.get(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_ENABLED.key()))); assertEquals(RssShuffleManager.class.getCanonicalName(), conf.get("spark.shuffle.manager")); return delegationRssShuffleManager; } diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java index 1b7afcd985..99de61481e 100644 --- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java +++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java @@ -38,7 +38,7 @@ import org.apache.spark.serializer.Serializer; import org.apache.spark.shuffle.RssShuffleHandle; import org.apache.spark.shuffle.RssShuffleManager; -import org.apache.spark.shuffle.RssSparkConfig; +import org.apache.spark.shuffle.RssSparkClientConf; import org.apache.spark.shuffle.TestUtils; import org.apache.spark.util.EventLoop; import org.junit.jupiter.api.Test; @@ -51,6 +51,7 @@ import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.storage.util.StorageType; +import static org.apache.spark.shuffle.RssSparkClientConf.SPARK_CONFIG_KEY_PREFIX; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -68,11 +69,11 @@ public void checkBlockSendResultTest() { SparkConf conf = new SparkConf(); conf.setAppName("testApp") .setMaster("local[2]") - .set(RssSparkConfig.RSS_TEST_FLAG.key(), "true") - .set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS.key(), "10000") - .set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000") - .set(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name()) - .set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), "127.0.0.1:12345,127.0.0.1:12346"); + .set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_TEST_FLAG.key(), "true") + .set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS.key(), "10000") + .set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000") + .set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name()) + 
.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_COORDINATOR_QUORUM.key(), "127.0.0.1:12345,127.0.0.1:12346"); // init SparkContext final SparkContext sc = SparkContext.getOrCreate(conf); Map> failBlocks = Maps.newConcurrentMap(); @@ -95,7 +96,7 @@ public void checkBlockSendResultTest() { when(mockHandle.getPartitionToServers()).thenReturn(Maps.newHashMap()); when(mockDependency.partitioner()).thenReturn(mockPartitioner); - BufferManagerOptions bufferOptions = new BufferManagerOptions(conf); + BufferManagerOptions bufferOptions = new BufferManagerOptions(RssSparkClientConf.from(conf)); WriteBufferManager bufferManager = new WriteBufferManager( 0, 0, bufferOptions, kryoSerializer, Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics()); @@ -134,16 +135,17 @@ public void checkBlockSendResultTest() { @Test public void writeTest() throws Exception { SparkConf conf = new SparkConf(); - conf.setAppName("testApp").setMaster("local[2]") - .set(RssSparkConfig.RSS_WRITER_SERIALIZER_BUFFER_SIZE.key(), "32") - .set(RssSparkConfig.RSS_WRITER_BUFFER_SIZE.key(), "32") - .set(RssSparkConfig.RSS_TEST_FLAG.key(), "true") - .set(RssSparkConfig.RSS_WRITER_BUFFER_SEGMENT_SIZE.key(), "64") - .set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS.key(), "10000") - .set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000") - .set(RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.key(), "128") - .set(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name()) - .set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), "127.0.0.1:12345,127.0.0.1:12346"); + conf.setAppName("testApp") + .setMaster("local[2]") + .set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_WRITER_SERIALIZER_BUFFER_SIZE.key(), "32") + .set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_WRITER_BUFFER_SIZE.key(), "32") + .set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_TEST_FLAG.key(), "true") + .set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_WRITER_BUFFER_SEGMENT_SIZE.key(), "64") + .set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS.key(), "10000") + .set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000") + .set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_WRITER_BUFFER_SPILL_SIZE.key(), "128") + .set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name()) + .set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_COORDINATOR_QUORUM.key(), "127.0.0.1:12345,127.0.0.1:12346"); // init SparkContext List shuffleBlockInfos = Lists.newArrayList(); final SparkContext sc = SparkContext.getOrCreate(conf); @@ -202,7 +204,7 @@ public void onError(Throwable e) { TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class); - BufferManagerOptions bufferOptions = new BufferManagerOptions(conf); + BufferManagerOptions bufferOptions = new BufferManagerOptions(RssSparkClientConf.from(conf)); ShuffleWriteMetrics shuffleWriteMetrics = new ShuffleWriteMetrics(); WriteBufferManager bufferManager = new WriteBufferManager( 0, 0, bufferOptions, kryoSerializer, @@ -303,8 +305,8 @@ public void onError(Throwable e) { when(mockHandle.getDependency()).thenReturn(mockDependency); ShuffleWriteClient mockWriteClient = mock(ShuffleWriteClient.class); SparkConf conf = new SparkConf(); - conf.set(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMIT.key(), "64") - .set(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE.name()); + conf.set(SPARK_CONFIG_KEY_PREFIX + 
RssSparkClientConf.RSS_CLIENT_SEND_SIZE_LIMIT.key(), "64") + .set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE.name()); List shuffleBlockInfoList = createShuffleBlockList(1, 31); RssShuffleWriter writer = new RssShuffleWriter("appId", 0, "taskId", 1L, mockBufferManager, mockMetrics, mockShuffleManager, conf, mockWriteClient, mockHandle); diff --git a/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java b/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java index b3247b431a..6e11b8f90d 100644 --- a/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java +++ b/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java @@ -26,48 +26,69 @@ public class RssClientConfig { public static final String RSS_CLIENT_RETRY_INTERVAL_MAX = "rss.client.retry.interval.max"; public static final long RSS_CLIENT_RETRY_INTERVAL_MAX_DEFAULT_VALUE = 10000L; public static final String RSS_COORDINATOR_QUORUM = "rss.coordinator.quorum"; + public static final String RSS_DATA_REPLICA = "rss.data.replica"; public static final int RSS_DATA_REPLICA_DEFAULT_VALUE = 1; + public static final String RSS_DATA_REPLICA_WRITE = "rss.data.replica.write"; public static final int RSS_DATA_REPLICA_WRITE_DEFAULT_VALUE = 1; + public static final String RSS_DATA_REPLICA_READ = "rss.data.replica.read"; public static final int RSS_DATA_REPLICA_READ_DEFAULT_VALUE = 1; + public static final String RSS_DATA_REPLICA_SKIP_ENABLED = "rss.data.replica.skip.enabled"; public static final boolean RSS_DATA_REPLICA_SKIP_ENABLED_DEFAULT_VALUE = true; + public static final String RSS_DATA_TRANSFER_POOL_SIZE = "rss.client.data.transfer.pool.size"; public static final int RSS_DATA_TRANFER_POOL_SIZE_DEFAULT_VALUE = Runtime.getRuntime().availableProcessors(); + public static final String RSS_DATA_COMMIT_POOL_SIZE = "rss.client.data.commit.pool.size"; public static final int RSS_DATA_COMMIT_POOL_SIZE_DEFAULT_VALUE = -1; + public static final String RSS_HEARTBEAT_INTERVAL = "rss.heartbeat.interval"; public static final long RSS_HEARTBEAT_INTERVAL_DEFAULT_VALUE = 10 * 1000L; + public static final String RSS_HEARTBEAT_TIMEOUT = "rss.heartbeat.timeout"; + public static final String RSS_STORAGE_TYPE = "rss.storage.type"; + public static final String RSS_CLIENT_SEND_CHECK_INTERVAL_MS = "rss.client.send.check.interval.ms"; public static final long RSS_CLIENT_SEND_CHECK_INTERVAL_MS_DEFAULT_VALUE = 500L; + public static final String RSS_CLIENT_SEND_CHECK_TIMEOUT_MS = "rss.client.send.check.timeout.ms"; public static final long RSS_CLIENT_SEND_CHECK_TIMEOUT_MS_DEFAULT_VALUE = 60 * 1000 * 10L; + public static final String RSS_WRITER_BUFFER_SIZE = "rss.writer.buffer.size"; + public static final String RSS_PARTITION_NUM_PER_RANGE = "rss.partitionNum.per.range"; public static final int RSS_PARTITION_NUM_PER_RANGE_DEFAULT_VALUE = 1; + public static final String RSS_REMOTE_STORAGE_PATH = "rss.remote.storage.path"; + public static final String RSS_INDEX_READ_LIMIT = "rss.index.read.limit"; public static final int RSS_INDEX_READ_LIMIT_DEFAULT_VALUE = 500; + public static final String RSS_CLIENT_SEND_THREAD_NUM = "rss.client.send.thread.num"; public static final int RSS_CLIENT_DEFAULT_SEND_NUM = 5; + public static final String RSS_CLIENT_READ_BUFFER_SIZE = "rss.client.read.buffer.size"; // When the size of read buffer reaches the half of JVM region (i.e., 32m), // it will incur humongous allocation, so we set it to 14m. 
public static final String RSS_CLIENT_READ_BUFFER_SIZE_DEFAULT_VALUE = "14m"; + + // The tags specified by rss client to determine server assignment. public static final String RSS_CLIENT_ASSIGNMENT_TAGS = "rss.client.assignment.tags"; public static final String RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL = "rss.client.assignment.retry.interval"; public static final long RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL_DEFAULT_VALUE = 65000; + public static final String RSS_CLIENT_ASSIGNMENT_RETRY_TIMES = "rss.client.assignment.retry.times"; public static final int RSS_CLIENT_ASSIGNMENT_RETRY_TIMES_DEFAULT_VALUE = 3; public static final String RSS_ACCESS_TIMEOUT_MS = "rss.access.timeout.ms"; public static final int RSS_ACCESS_TIMEOUT_MS_DEFAULT_VALUE = 10000; + public static final String RSS_DYNAMIC_CLIENT_CONF_ENABLED = "rss.dynamicClientConf.enabled"; public static final boolean RSS_DYNAMIC_CLIENT_CONF_ENABLED_DEFAULT_VALUE = true; diff --git a/client/src/main/java/org/apache/uniffle/client/util/RssShuffleUtils.java b/client/src/main/java/org/apache/uniffle/client/util/RssShuffleUtils.java new file mode 100644 index 0000000000..22e955cc72 --- /dev/null +++ b/client/src/main/java/org/apache/uniffle/client/util/RssShuffleUtils.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.client.util; + +import java.util.Map; +import java.util.Set; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.uniffle.common.config.RssConf; + +public class RssShuffleUtils { + private static final Logger LOGGER = LoggerFactory.getLogger(RssShuffleUtils.class); + + public static void applyDynamicClientConf( + Set<String> mandatoryList, + RssConf rssConf, + Map<String, String> confItems) { + + if (rssConf == null) { + LOGGER.warn("Rss client conf is null"); + return; + } + + if (confItems == null || confItems.isEmpty()) { + LOGGER.warn("Empty conf items"); + return; + } + + for (Map.Entry<String, String> kv : confItems.entrySet()) { + String remoteKey = kv.getKey(); + String remoteVal = kv.getValue(); + + if (!rssConf.containsKey(remoteKey) || mandatoryList.contains(remoteKey)) { + LOGGER.warn("Apply dynamic client conf {} = {}", remoteKey, remoteVal); + rssConf.setString(remoteKey, remoteVal); + } + } + } +} diff --git a/client/src/test/java/org/apache/uniffle/client/RssShuffleUtilsTest.java b/client/src/test/java/org/apache/uniffle/client/RssShuffleUtilsTest.java new file mode 100644 index 0000000000..363126bc52 --- /dev/null +++ b/client/src/test/java/org/apache/uniffle/client/RssShuffleUtilsTest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.client; + +import java.util.Map; + +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import org.junit.jupiter.api.Test; + +import org.apache.uniffle.client.util.RssShuffleUtils; +import org.apache.uniffle.common.config.ConfigOption; +import org.apache.uniffle.common.config.ConfigOptions; +import org.apache.uniffle.common.config.RssConf; + +import static org.apache.uniffle.client.RssShuffleUtilsTest.MockedRssClientConf.RSS_CLIENT_RETRY_MAX; +import static org.apache.uniffle.client.RssShuffleUtilsTest.MockedRssClientConf.RSS_CLIENT_TYPE; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class RssShuffleUtilsTest { + + static class MockedRssClientConf extends RssConf { + public static final ConfigOption<String> RSS_CLIENT_TYPE = ConfigOptions + .key("rss.client.type") + .stringType() + .defaultValue("GRPC") + .withDescription(""); + + public static final ConfigOption<Integer> RSS_CLIENT_RETRY_MAX = ConfigOptions + .key("rss.client.retry.max") + .intType() + .defaultValue(100) + .withDescription(""); + } + + @Test + public void applyDynamicClientConfTest() { + RssConf conf = new MockedRssClientConf(); + + Map<String, String> remoteConf = Maps.newHashMap(); + remoteConf.put(RSS_CLIENT_TYPE.key(), "NETTY"); + remoteConf.put(RSS_CLIENT_RETRY_MAX.key(), "200"); + + // case 1: keys absent from the local conf should be overwritten + RssShuffleUtils.applyDynamicClientConf(Sets.newHashSet(), conf, remoteConf); + assertEquals("NETTY", conf.get(RSS_CLIENT_TYPE)); + assertEquals(200, conf.get(RSS_CLIENT_RETRY_MAX)); + + // case 2: if a key exists locally, it is overwritten only when it is in the mandatory list + conf = new MockedRssClientConf(); + conf.set(RSS_CLIENT_TYPE, "GRPC"); + conf.set(RSS_CLIENT_RETRY_MAX, 300); + RssShuffleUtils.applyDynamicClientConf(Sets.newHashSet(RSS_CLIENT_TYPE.key()), conf, remoteConf); + assertEquals("NETTY", conf.get(RSS_CLIENT_TYPE)); + assertEquals(300, conf.get(RSS_CLIENT_RETRY_MAX)); + + // case 3: if a key exists locally and is not mandatory, it won't be overwritten + conf = new MockedRssClientConf(); + conf.set(RSS_CLIENT_TYPE, "GRPC"); + conf.set(RSS_CLIENT_RETRY_MAX, 300); + RssShuffleUtils.applyDynamicClientConf(Sets.newHashSet(), conf, remoteConf); + assertEquals("GRPC", conf.get(RSS_CLIENT_TYPE)); + assertEquals(300, conf.get(RSS_CLIENT_RETRY_MAX)); + } +} diff --git a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java new file mode 100644 index 0000000000..bb62588b5d --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.common.config; + +import java.util.List; + +public class RssClientConf extends RssConf { + + public static final ConfigOption RSS_CLIENT_HEARTBEAT_THREAD_NUM = ConfigOptions + .key("rss.client.heartBeat.threadNum") + .intType() + .defaultValue(4) + .withDescription(""); + + public static final ConfigOption RSS_CLIENT_TYPE = ConfigOptions + .key("rss.client.type") + .stringType() + .defaultValue("GRPC") + .withDescription(""); + + public static final ConfigOption RSS_CLIENT_RETRY_MAX = ConfigOptions + .key("rss.client.retry.max") + .intType() + .defaultValue(100) + .withDescription(""); + + public static final ConfigOption RSS_CLIENT_RETRY_INTERVAL_MAX = ConfigOptions + .key("rss.client.retry.interval.max") + .longType() + .defaultValue(10000L) + .withDescription(""); + + public static final ConfigOption RSS_COORDINATOR_QUORUM = ConfigOptions + .key("rss.coordinator.quorum") + .stringType() + .noDefaultValue() + .withDescription(""); + + public static final ConfigOption RSS_DATA_REPLICA = ConfigOptions + .key("rss.data.replica") + .intType() + .defaultValue(1) + .withDescription(""); + + public static final ConfigOption RSS_DATA_REPLICA_WRITE = ConfigOptions + .key("rss.data.replica.write") + .intType() + .defaultValue(1) + .withDescription(""); + + public static final ConfigOption RSS_DATA_REPLICA_READ = ConfigOptions + .key("rss.data.replica.read") + .intType() + .defaultValue(1) + .withDescription(""); + + public static final ConfigOption RSS_DATA_REPLICA_SKIP_ENABLED = ConfigOptions + .key("rss.data.replica.skip.enabled") + .booleanType() + .defaultValue(true) + .withDescription(""); + + public static final ConfigOption RSS_DATA_TRANSFER_POOL_SIZE = ConfigOptions + .key("rss.client.data.transfer.pool.size") + .intType() + .defaultValue(Runtime.getRuntime().availableProcessors()) + .withDescription(""); + + public static final ConfigOption RSS_DATA_COMMIT_POOL_SIZE = ConfigOptions + .key("rss.client.data.commit.pool.size") + .intType() + .defaultValue(-1) + .withDescription(""); + + public static final ConfigOption RSS_HEARTBEAT_INTERVAL = ConfigOptions + .key("rss.heartbeat.interval") + .longType() + .defaultValue(10 * 1000L) + .withDescription(""); + + //todo + public static final ConfigOption RSS_HEARTBEAT_TIMEOUT = ConfigOptions + .key("rss.heartbeat.timeout") + .longType() + .noDefaultValue() + .withDescription(""); + + public static final ConfigOption RSS_STORAGE_TYPE = ConfigOptions + .key("rss.storage.type") + .stringType() + .noDefaultValue() + .withDescription(""); + + public static final ConfigOption RSS_CLIENT_SEND_CHECK_INTERVAL_MS = ConfigOptions + .key("rss.client.send.check.interval.ms") + .longType() + .defaultValue(500L) + .withDescription(""); + + public static final ConfigOption RSS_CLIENT_SEND_CHECK_TIMEOUT_MS = ConfigOptions + .key("rss.client.send.check.timeout.ms") + .longType() + .defaultValue(60 * 1000 * 10L) + .withDescription(""); + + // todo + 
public static final ConfigOption RSS_WRITER_BUFFER_SIZE = ConfigOptions + .key("rss.writer.buffer.size") + .longType() + .noDefaultValue() + .withDescription("Buffer size for single partition data"); + + public static final ConfigOption RSS_PARTITION_NUM_PER_RANGE = ConfigOptions + .key("rss.partitionNum.per.range") + .intType() + .defaultValue(1) + .withDescription(""); + + public static final ConfigOption RSS_REMOTE_STORAGE_PATH = ConfigOptions + .key("rss.remote.storage.path") + .stringType() + .noDefaultValue() + .withDescription(""); + + public static final ConfigOption RSS_INDEX_READ_LIMIT = ConfigOptions + .key("rss.index.read.limit") + .intType() + .defaultValue(500) + .withDescription(""); + + public static final ConfigOption RSS_CLIENT_SEND_THREAD_NUM = ConfigOptions + .key("rss.client.send.thread.num") + .intType() + .defaultValue(5) + .withDescription(""); + + public static final ConfigOption RSS_CLIENT_READ_BUFFER_SIZE = ConfigOptions + .key("rss.client.read.buffer.size") + .stringType() + .defaultValue("14m") + .withDescription(""); + + public static final ConfigOption> RSS_CLIENT_ASSIGNMENT_TAGS = ConfigOptions + .key("rss.client.assignment.tags") + .stringType() + .asList() + .noDefaultValue() + .withDescription(""); + + public static final ConfigOption RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL = ConfigOptions + .key("rss.client.assignment.retry.interval") + .intType() + .defaultValue(65000) + .withDescription(""); + + public static final ConfigOption RSS_CLIENT_ASSIGNMENT_RETRY_TIMES = ConfigOptions + .key("rss.client.assignment.retry.times") + .intType() + .defaultValue(3) + .withDescription(""); + + public static final ConfigOption RSS_ACCESS_TIMEOUT_MS = ConfigOptions + .key("rss.access.timeout.ms") + .intType() + .defaultValue(10000) + .withDescription(""); + + public static final ConfigOption RSS_DYNAMIC_CLIENT_CONF_ENABLED = ConfigOptions + .key("rss.dynamicClientConf.enabled") + .booleanType() + .defaultValue(true) + .withDescription(""); + + public static final ConfigOption RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER = ConfigOptions + .key("rss.client.assignment.shuffle.nodes.max") + .intType() + .defaultValue(-1) + .withDescription(""); +} diff --git a/common/src/main/java/org/apache/uniffle/common/config/RssConf.java b/common/src/main/java/org/apache/uniffle/common/config/RssConf.java index 9e1c1f7811..c7be2b50e2 100644 --- a/common/src/main/java/org/apache/uniffle/common/config/RssConf.java +++ b/common/src/main/java/org/apache/uniffle/common/config/RssConf.java @@ -18,12 +18,15 @@ package org.apache.uniffle.common.config; import java.util.Arrays; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; import com.google.common.collect.Sets; +import org.apache.commons.lang3.tuple.Pair; import org.apache.uniffle.common.util.UnitConverter; @@ -640,4 +643,8 @@ public boolean equals(Object obj) { public String toString() { return this.settings.toString(); } + + public List> getAll() { + return settings.entrySet().stream().map(x -> Pair.of(x.getKey(), x.getValue())).collect(Collectors.toList()); + } } diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/AutoAccessTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/AutoAccessTest.java index c6ad3fd589..88658b0017 100644 --- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/AutoAccessTest.java +++ 
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/AutoAccessTest.java @@ -27,7 +27,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.shuffle.DelegationRssShuffleManager; import org.apache.spark.shuffle.RssShuffleManager; -import org.apache.spark.shuffle.RssSparkConfig; +import org.apache.spark.shuffle.RssSparkClientConf; import org.apache.spark.shuffle.ShuffleManager; import org.apache.spark.shuffle.sort.SortShuffleManager; import org.junit.jupiter.api.Test; @@ -37,6 +37,7 @@ import org.apache.uniffle.storage.util.StorageType; import static java.lang.Thread.sleep; +import static org.apache.spark.shuffle.RssSparkClientConf.SPARK_CONFIG_KEY_PREFIX; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertSame; @@ -48,9 +49,9 @@ public class AutoAccessTest extends IntegrationTestBase { public void test() throws Exception { SparkConf sparkConf = new SparkConf(); sparkConf.set("spark.shuffle.manager", "org.apache.spark.shuffle.DelegationRssShuffleManager"); - sparkConf.set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), COORDINATOR_QUORUM); + sparkConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_COORDINATOR_QUORUM.key(), COORDINATOR_QUORUM); sparkConf.set("spark.mock.2", "no-overwrite-conf"); - sparkConf.set(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), "overwrite-path"); + sparkConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_REMOTE_STORAGE_PATH.key(), "overwrite-path"); sparkConf.set("spark.shuffle.service.enabled", "true"); String cfgFile = HDFS_URI + "/test/client_conf"; @@ -61,7 +62,7 @@ public void test() throws Exception { printWriter.println(" spark.mock.2 overwrite-conf "); printWriter.println(" spark.mock.3 true "); printWriter.println("spark.rss.storage.type " + StorageType.MEMORY_LOCALFILE_HDFS.name()); - printWriter.println(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key() + " expectedPath"); + printWriter.println(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_REMOTE_STORAGE_PATH.key() + " expectedPath"); printWriter.flush(); printWriter.close(); @@ -101,7 +102,7 @@ public void test() throws Exception { ShuffleManager shuffleManager = delegationRssShuffleManager.getDelegate(); assertTrue(shuffleManager instanceof SortShuffleManager); assertTrue(sparkConf.getBoolean("spark.shuffle.service.enabled", true)); - assertEquals("overwrite-path", sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key())); + assertEquals("overwrite-path", sparkConf.get(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_REMOTE_STORAGE_PATH.key())); assertFalse(sparkConf.contains("spark.rss.storage.type")); // wrong access id @@ -109,7 +110,7 @@ public void test() throws Exception { delegationRssShuffleManager = new DelegationRssShuffleManager(sparkConf, true); shuffleManager = delegationRssShuffleManager.getDelegate(); assertTrue(shuffleManager instanceof SortShuffleManager); - assertEquals("overwrite-path", sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key())); + assertEquals("overwrite-path", sparkConf.get(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_REMOTE_STORAGE_PATH.key())); assertTrue(sparkConf.getBoolean("spark.shuffle.service.enabled", true)); assertFalse(sparkConf.contains("spark.rss.storage.type")); @@ -123,7 +124,7 @@ public void test() throws Exception { assertEquals("no-overwrite-conf", sparkConf.get("spark.mock.2")); assertTrue(sparkConf.getBoolean("spark.mock.3", false)); 
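// --- Reviewer aside, not part of this patch: the assertions in this test
// exercise the precedence implemented by the new
// RssShuffleUtils.applyDynamicClientConf: a remote (cluster-side) entry is
// applied only when the key is absent locally or is in the mandatory set. A
// minimal sketch using names from this patch (assumes RssClientConf's no-arg
// construction plus java.util.HashMap and guava's ImmutableSet):
//
//   RssConf conf = new RssClientConf();
//   conf.setString("rss.remote.storage.path", "overwrite-path"); // set locally, mandatory below
//   conf.setString("mock.key", "no-overwrite-conf");             // set locally, not mandatory
//   Map<String, String> remote = new HashMap<>();
//   remote.put("rss.remote.storage.path", "expectedPath");
//   remote.put("mock.key", "overwrite-conf");
//   RssShuffleUtils.applyDynamicClientConf(
//       ImmutableSet.of("rss.remote.storage.path"), conf, remote);
//   conf.getString("rss.remote.storage.path", null); // "expectedPath" (mandatory: remote wins)
//   conf.getString("mock.key", null);                // "no-overwrite-conf" (local wins)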
assertEquals(StorageType.MEMORY_LOCALFILE_HDFS.name(), sparkConf.get("spark.rss.storage.type")); - assertEquals("expectedPath", sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key())); + assertEquals("expectedPath", sparkConf.get(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_REMOTE_STORAGE_PATH.key())); assertFalse(sparkConf.getBoolean("spark.shuffle.service.enabled", true)); // update candidates file @@ -150,7 +151,7 @@ public void test() throws Exception { assertEquals("no-overwrite-conf", sparkConf.get("spark.mock.2")); assertTrue(sparkConf.getBoolean("spark.mock.3", false)); assertEquals(StorageType.MEMORY_LOCALFILE_HDFS.name(), sparkConf.get("spark.rss.storage.type")); - assertEquals("expectedPath", sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key())); + assertEquals("expectedPath", sparkConf.get(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_REMOTE_STORAGE_PATH.key())); assertFalse(sparkConf.getBoolean("spark.shuffle.service.enabled", true)); // update client conf file @@ -163,7 +164,7 @@ public void test() throws Exception { printWriter.println(" spark.mock.2 overwrite-conf "); printWriter.println(" spark.mock.3 false "); printWriter.println("spark.rss.storage.type " + StorageType.MEMORY_LOCALFILE_HDFS.name()); - printWriter.println(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key() + " expectedPathNew"); + printWriter.println(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_REMOTE_STORAGE_PATH.key() + " expectedPathNew"); printWriter.flush(); printWriter.close(); fs.rename(tmpPath, path); @@ -178,7 +179,7 @@ public void test() throws Exception { assertEquals("overwrite-conf", sparkConf.get("spark.mock.2")); assertTrue(sparkConf.getBoolean("spark.mock.3", false)); assertEquals(StorageType.MEMORY_LOCALFILE_HDFS.name(), sparkConf.get("spark.rss.storage.type")); - assertEquals("expectedPathNew", sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key())); + assertEquals("expectedPathNew", sparkConf.get(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_REMOTE_STORAGE_PATH.key())); assertFalse(sparkConf.getBoolean("spark.shuffle.service.enabled", true)); } } diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/DynamicFetchClientConfTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/DynamicFetchClientConfTest.java index 9f259c4f56..b08190f2a4 100644 --- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/DynamicFetchClientConfTest.java +++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/DynamicFetchClientConfTest.java @@ -26,12 +26,13 @@ import org.apache.hadoop.fs.Path; import org.apache.spark.SparkConf; import org.apache.spark.shuffle.RssShuffleManager; -import org.apache.spark.shuffle.RssSparkConfig; +import org.apache.spark.shuffle.RssSparkClientConf; import org.junit.jupiter.api.Test; import org.apache.uniffle.coordinator.CoordinatorConf; import org.apache.uniffle.storage.util.StorageType; +import static org.apache.spark.shuffle.RssSparkClientConf.SPARK_CONFIG_KEY_PREFIX; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -42,7 +43,7 @@ public class DynamicFetchClientConfTest extends IntegrationTestBase { public void test() throws Exception { SparkConf sparkConf = new SparkConf(); sparkConf.set("spark.shuffle.manager", "org.apache.spark.shuffle.RssShuffleManager"); - sparkConf.set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), COORDINATOR_QUORUM); + 
sparkConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_COORDINATOR_QUORUM.key(), COORDINATOR_QUORUM); sparkConf.set("spark.mock.2", "no-overwrite-conf"); sparkConf.set("spark.shuffle.service.enabled", "true"); @@ -54,11 +55,11 @@ public void test() throws Exception { printWriter.println(" spark.mock.2 overwrite-conf "); printWriter.println(" spark.mock.3 true "); printWriter.println("spark.rss.storage.type " + StorageType.MEMORY_LOCALFILE_HDFS.name()); - printWriter.println(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key() + " expectedPath"); + printWriter.println(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_REMOTE_STORAGE_PATH.key() + " expectedPath"); printWriter.flush(); printWriter.close(); - for (String k : RssSparkConfig.RSS_MANDATORY_CLUSTER_CONF) { - sparkConf.set(k, "Dummy-" + k); + for (String k : RssSparkClientConf.RSS_MANDATORY_CLUSTER_CONF) { + sparkConf.set("spark." + k, "Dummy-" + k); } sparkConf.set("spark.mock.2", "no-overwrite-conf"); @@ -74,10 +75,10 @@ public void test() throws Exception { assertFalse(sparkConf.contains("spark.mock.1")); assertEquals("no-overwrite-conf", sparkConf.get("spark.mock.2")); assertFalse(sparkConf.contains("spark.mock.3")); - assertEquals("Dummy-" + RssSparkConfig.RSS_STORAGE_TYPE.key(), - sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key())); - assertEquals("Dummy-" + RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), - sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key())); + assertEquals("Dummy-" + RssSparkClientConf.RSS_STORAGE_TYPE.key(), + sparkConf.get(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_STORAGE_TYPE.key())); + assertEquals("Dummy-" + RssSparkClientConf.RSS_REMOTE_STORAGE_PATH.key(), + sparkConf.get(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_REMOTE_STORAGE_PATH.key())); assertTrue(sparkConf.getBoolean("spark.shuffle.service.enabled", true)); RssShuffleManager rssShuffleManager = new RssShuffleManager(sparkConf, true); @@ -86,14 +87,14 @@ public void test() throws Exception { assertEquals(1234, sparkConf1.getInt("spark.mock.1", 0)); assertEquals("no-overwrite-conf", sparkConf1.get("spark.mock.2")); assertEquals(StorageType.MEMORY_LOCALFILE_HDFS.name(), sparkConf.get("spark.rss.storage.type")); - assertEquals("expectedPath", sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key())); + assertEquals("expectedPath", sparkConf.get(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_REMOTE_STORAGE_PATH.key())); assertFalse(sparkConf1.getBoolean("spark.shuffle.service.enabled", true)); fs.delete(path, true); shutdownServers(); sparkConf = new SparkConf(); sparkConf.set("spark.shuffle.manager", "org.apache.spark.shuffle.RssShuffleManager"); - sparkConf.set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), COORDINATOR_QUORUM); + sparkConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_COORDINATOR_QUORUM.key(), COORDINATOR_QUORUM); sparkConf.set("spark.mock.2", "no-overwrite-conf"); sparkConf.set("spark.shuffle.service.enabled", "true"); diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithHdfsMultiStorageRssTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithHdfsMultiStorageRssTest.java index 4d71ea526a..b890d001cd 100644 --- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithHdfsMultiStorageRssTest.java +++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithHdfsMultiStorageRssTest.java @@ -24,20 +24,25 @@ import com.google.common.collect.Maps; import 
com.google.common.io.Files; import org.apache.spark.SparkConf; -import org.apache.spark.shuffle.RssSparkConfig; +import org.apache.spark.shuffle.RssSparkClientConf; import org.junit.jupiter.api.BeforeAll; import org.apache.uniffle.coordinator.CoordinatorConf; import org.apache.uniffle.server.ShuffleServerConf; import org.apache.uniffle.storage.util.StorageType; +import static org.apache.spark.shuffle.RssSparkClientConf.SPARK_CONFIG_KEY_PREFIX; + public class RepartitionWithHdfsMultiStorageRssTest extends RepartitionTest { @BeforeAll public static void setupServers() throws Exception { CoordinatorConf coordinatorConf = getCoordinatorConf(); Map dynamicConf = Maps.newHashMap(); dynamicConf.put(CoordinatorConf.COORDINATOR_REMOTE_STORAGE_PATH.key(), HDFS_URI + "rss/test"); - dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE_HDFS.name()); + dynamicConf.put( + SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_STORAGE_TYPE.key(), + StorageType.LOCALFILE_HDFS.name() + ); addDynamicConf(coordinatorConf, dynamicConf); createCoordinatorServer(coordinatorConf); diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithLocalFileRssTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithLocalFileRssTest.java index 82586d9113..25d752534c 100644 --- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithLocalFileRssTest.java +++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithLocalFileRssTest.java @@ -23,20 +23,22 @@ import com.google.common.collect.Maps; import com.google.common.io.Files; import org.apache.spark.SparkConf; -import org.apache.spark.shuffle.RssSparkConfig; +import org.apache.spark.shuffle.RssSparkClientConf; import org.junit.jupiter.api.BeforeAll; import org.apache.uniffle.coordinator.CoordinatorConf; import org.apache.uniffle.server.ShuffleServerConf; import org.apache.uniffle.storage.util.StorageType; +import static org.apache.spark.shuffle.RssSparkClientConf.SPARK_CONFIG_KEY_PREFIX; + public class RepartitionWithLocalFileRssTest extends RepartitionTest { @BeforeAll public static void setupServers() throws Exception { CoordinatorConf coordinatorConf = getCoordinatorConf(); Map dynamicConf = Maps.newHashMap(); - dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name()); + dynamicConf.put(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name()); addDynamicConf(coordinatorConf, dynamicConf); createCoordinatorServer(coordinatorConf); ShuffleServerConf shuffleServerConf = getShuffleServerConf(); diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithMemoryMultiStorageRssTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithMemoryMultiStorageRssTest.java index 232a31b887..7b22a186e4 100644 --- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithMemoryMultiStorageRssTest.java +++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithMemoryMultiStorageRssTest.java @@ -24,20 +24,25 @@ import com.google.common.collect.Maps; import com.google.common.io.Files; import org.apache.spark.SparkConf; -import org.apache.spark.shuffle.RssSparkConfig; +import org.apache.spark.shuffle.RssSparkClientConf; import org.junit.jupiter.api.BeforeAll; import org.apache.uniffle.coordinator.CoordinatorConf; import 
org.apache.uniffle.server.ShuffleServerConf; import org.apache.uniffle.storage.util.StorageType; +import static org.apache.spark.shuffle.RssSparkClientConf.SPARK_CONFIG_KEY_PREFIX; + public class RepartitionWithMemoryMultiStorageRssTest extends RepartitionTest { @BeforeAll public static void setupServers() throws Exception { CoordinatorConf coordinatorConf = getCoordinatorConf(); Map dynamicConf = Maps.newHashMap(); dynamicConf.put(CoordinatorConf.COORDINATOR_REMOTE_STORAGE_PATH.key(), HDFS_URI + "rss/test"); - dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE_HDFS.name()); + dynamicConf.put( + SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_STORAGE_TYPE.key(), + StorageType.MEMORY_LOCALFILE_HDFS.name() + ); addDynamicConf(coordinatorConf, dynamicConf); createCoordinatorServer(coordinatorConf); ShuffleServerConf shuffleServerConf = getShuffleServerConf(); diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithMemoryRssTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithMemoryRssTest.java index 63a641eab3..bd3fa8a5db 100644 --- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithMemoryRssTest.java +++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithMemoryRssTest.java @@ -24,7 +24,7 @@ import com.google.common.collect.Maps; import com.google.common.io.Files; import org.apache.spark.SparkConf; -import org.apache.spark.shuffle.RssSparkConfig; +import org.apache.spark.shuffle.RssSparkClientConf; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -32,6 +32,8 @@ import org.apache.uniffle.server.ShuffleServerConf; import org.apache.uniffle.storage.util.StorageType; +import static org.apache.spark.shuffle.RssSparkClientConf.SPARK_CONFIG_KEY_PREFIX; + public class RepartitionWithMemoryRssTest extends RepartitionTest { @BeforeAll @@ -39,7 +41,10 @@ public static void setupServers() throws Exception { CoordinatorConf coordinatorConf = getCoordinatorConf(); coordinatorConf.set(CoordinatorConf.COORDINATOR_APP_EXPIRED, 5000L); Map dynamicConf = Maps.newHashMap(); - dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE.name()); + dynamicConf.put( + SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_STORAGE_TYPE.key(), + StorageType.MEMORY_LOCALFILE.name() + ); addDynamicConf(coordinatorConf, dynamicConf); createCoordinatorServer(coordinatorConf); ShuffleServerConf shuffleServerConf = getShuffleServerConf(); diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SimpleTestBase.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SimpleTestBase.java index 5b7e4d5319..6c9930dc28 100644 --- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SimpleTestBase.java +++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SimpleTestBase.java @@ -21,20 +21,23 @@ import com.google.common.collect.Maps; import org.apache.spark.SparkConf; -import org.apache.spark.shuffle.RssSparkConfig; +import org.apache.spark.shuffle.RssSparkClientConf; import org.junit.jupiter.api.BeforeAll; import org.apache.uniffle.coordinator.CoordinatorConf; import org.apache.uniffle.server.ShuffleServerConf; import org.apache.uniffle.storage.util.StorageType; +import static org.apache.spark.shuffle.RssSparkClientConf.SPARK_CONFIG_KEY_PREFIX; + public abstract class SimpleTestBase extends SparkIntegrationTestBase { 
@BeforeAll public static void setupServers() throws Exception { CoordinatorConf coordinatorConf = getCoordinatorConf(); Map dynamicConf = Maps.newHashMap(); dynamicConf.put(CoordinatorConf.COORDINATOR_REMOTE_STORAGE_PATH.key(), HDFS_URI + "rss/test"); - dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE_HDFS.name()); + dynamicConf.put(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_STORAGE_TYPE.key(), + StorageType.MEMORY_LOCALFILE_HDFS.name()); addDynamicConf(coordinatorConf, dynamicConf); createCoordinatorServer(coordinatorConf); ShuffleServerConf shuffleServerConf = getShuffleServerConf(); diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java index 1ea90007d4..a691e96647 100644 --- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java +++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java @@ -22,11 +22,12 @@ import com.google.common.util.concurrent.Uninterruptibles; import org.apache.spark.SparkConf; -import org.apache.spark.shuffle.RssSparkConfig; +import org.apache.spark.shuffle.RssSparkClientConf; import org.apache.spark.sql.SparkSession; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import static org.apache.spark.shuffle.RssSparkClientConf.SPARK_CONFIG_KEY_PREFIX; import static org.junit.jupiter.api.Assertions.assertEquals; public abstract class SparkIntegrationTestBase extends IntegrationTestBase { @@ -85,18 +86,18 @@ protected SparkConf createSparkConf() { public void updateSparkConfWithRss(SparkConf sparkConf) { sparkConf.set("spark.shuffle.manager", "org.apache.spark.shuffle.RssShuffleManager"); sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer"); - sparkConf.set(RssSparkConfig.RSS_WRITER_BUFFER_SIZE.key(), "4m"); - sparkConf.set(RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.key(), "32m"); - sparkConf.set(RssSparkConfig.RSS_CLIENT_READ_BUFFER_SIZE.key(), "2m"); - sparkConf.set(RssSparkConfig.RSS_WRITER_SERIALIZER_BUFFER_SIZE.key(), "128k"); - sparkConf.set(RssSparkConfig.RSS_WRITER_BUFFER_SEGMENT_SIZE.key(), "256k"); - sparkConf.set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), COORDINATOR_QUORUM); - sparkConf.set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS.key(), "30000"); - sparkConf.set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000"); - sparkConf.set(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX.key(), "1000"); - sparkConf.set(RssSparkConfig.RSS_INDEX_READ_LIMIT.key(), "100"); - sparkConf.set(RssSparkConfig.RSS_CLIENT_READ_BUFFER_SIZE.key(), "1m"); - sparkConf.set(RssSparkConfig.RSS_HEARTBEAT_INTERVAL.key(), "2000"); + sparkConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_WRITER_BUFFER_SIZE.key(), "4m"); + sparkConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_WRITER_BUFFER_SPILL_SIZE.key(), "32m"); + sparkConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_CLIENT_READ_BUFFER_SIZE.key(), "2m"); + sparkConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_WRITER_SERIALIZER_BUFFER_SIZE.key(), "128k"); + sparkConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_WRITER_BUFFER_SEGMENT_SIZE.key(), "256k"); + sparkConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_COORDINATOR_QUORUM.key(), COORDINATOR_QUORUM); + sparkConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS.key(), 
"30000"); + sparkConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000"); + sparkConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_CLIENT_RETRY_INTERVAL_MAX.key(), "1000"); + sparkConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_INDEX_READ_LIMIT.key(), "100"); + sparkConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_CLIENT_READ_BUFFER_SIZE.key(), "1m"); + sparkConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_HEARTBEAT_INTERVAL.key(), "2000"); } private void verifyTestResult(Map expected, Map actual) { diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkSQLWithDelegationShuffleManager.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkSQLWithDelegationShuffleManager.java index 45d0e7747b..00c1ee995a 100644 --- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkSQLWithDelegationShuffleManager.java +++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkSQLWithDelegationShuffleManager.java @@ -27,13 +27,15 @@ import com.google.common.io.Files; import com.google.common.util.concurrent.Uninterruptibles; import org.apache.spark.SparkConf; -import org.apache.spark.shuffle.RssSparkConfig; +import org.apache.spark.shuffle.RssSparkClientConf; import org.junit.jupiter.api.BeforeAll; import org.apache.uniffle.coordinator.CoordinatorConf; import org.apache.uniffle.server.ShuffleServerConf; import org.apache.uniffle.storage.util.StorageType; +import static org.apache.spark.shuffle.RssSparkClientConf.SPARK_CONFIG_KEY_PREFIX; + public class SparkSQLWithDelegationShuffleManager extends SparkSQLTest { @BeforeAll @@ -49,7 +51,8 @@ public static void setupServers() throws Exception { coordinatorConf.set(CoordinatorConf.COORDINATOR_APP_EXPIRED, 5000L); coordinatorConf.set(CoordinatorConf.COORDINATOR_ACCESS_LOADCHECKER_SERVER_NUM_THRESHOLD, 1); Map dynamicConf = Maps.newHashMap(); - dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE.name()); + dynamicConf.put(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_STORAGE_TYPE.key(), + StorageType.MEMORY_LOCALFILE.name()); addDynamicConf(coordinatorConf, dynamicConf); createCoordinatorServer(coordinatorConf); ShuffleServerConf shuffleServerConf = getShuffleServerConf(); @@ -69,7 +72,7 @@ public static void setupServers() throws Exception { @Override public void updateRssStorage(SparkConf sparkConf) { - sparkConf.set(RssSparkConfig.RSS_ACCESS_ID.key(), "test_access_id"); + sparkConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_ACCESS_ID.key(), "test_access_id"); sparkConf.set("spark.shuffle.manager", "org.apache.spark.shuffle.DelegationRssShuffleManager"); } diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkSQLWithDelegationShuffleManagerFallback.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkSQLWithDelegationShuffleManagerFallback.java index 441209584f..7c1bb381d9 100644 --- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkSQLWithDelegationShuffleManagerFallback.java +++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkSQLWithDelegationShuffleManagerFallback.java @@ -27,13 +27,15 @@ import com.google.common.io.Files; import com.google.common.util.concurrent.Uninterruptibles; import org.apache.spark.SparkConf; -import org.apache.spark.shuffle.RssSparkConfig; +import org.apache.spark.shuffle.RssSparkClientConf; import 
diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkSQLWithDelegationShuffleManagerFallback.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkSQLWithDelegationShuffleManagerFallback.java
index 441209584f..7c1bb381d9 100644
--- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkSQLWithDelegationShuffleManagerFallback.java
+++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkSQLWithDelegationShuffleManagerFallback.java
@@ -27,13 +27,15 @@
 import com.google.common.io.Files;
 import com.google.common.util.concurrent.Uninterruptibles;
 import org.apache.spark.SparkConf;
-import org.apache.spark.shuffle.RssSparkConfig;
+import org.apache.spark.shuffle.RssSparkClientConf;
 import org.junit.jupiter.api.BeforeAll;
 
 import org.apache.uniffle.coordinator.CoordinatorConf;
 import org.apache.uniffle.server.ShuffleServerConf;
 import org.apache.uniffle.storage.util.StorageType;
 
+import static org.apache.spark.shuffle.RssSparkClientConf.SPARK_CONFIG_KEY_PREFIX;
+
 public class SparkSQLWithDelegationShuffleManagerFallback extends SparkSQLTest {
 
   @BeforeAll
@@ -49,7 +51,8 @@ public static void setupServers() throws Exception {
     coordinatorConf.set(CoordinatorConf.COORDINATOR_APP_EXPIRED, 5000L);
     coordinatorConf.set(CoordinatorConf.COORDINATOR_ACCESS_LOADCHECKER_SERVER_NUM_THRESHOLD, 1);
     Map<String, String> dynamicConf = Maps.newHashMap();
-    dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE.name());
+    dynamicConf.put(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_STORAGE_TYPE.key(),
+        StorageType.MEMORY_LOCALFILE.name());
     addDynamicConf(coordinatorConf, dynamicConf);
     createCoordinatorServer(coordinatorConf);
     ShuffleServerConf shuffleServerConf = getShuffleServerConf();
@@ -70,7 +73,7 @@ public static void setupServers() throws Exception {
 
   @Override
   public void updateRssStorage(SparkConf sparkConf) {
-    sparkConf.set(RssSparkConfig.RSS_ACCESS_ID.key(), "wrong_id");
+    sparkConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_ACCESS_ID.key(), "wrong_id");
     sparkConf.set("spark.shuffle.manager", "org.apache.spark.shuffle.DelegationRssShuffleManager");
   }
 
diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkSQLWithMemoryLocalTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkSQLWithMemoryLocalTest.java
index 98432720ae..150b726e2f 100644
--- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkSQLWithMemoryLocalTest.java
+++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkSQLWithMemoryLocalTest.java
@@ -23,13 +23,14 @@
 import com.google.common.collect.Maps;
 import com.google.common.io.Files;
 import org.apache.spark.SparkConf;
-import org.apache.spark.shuffle.RssSparkConfig;
+import org.apache.spark.shuffle.RssSparkClientConf;
 import org.junit.jupiter.api.BeforeAll;
 
 import org.apache.uniffle.coordinator.CoordinatorConf;
 import org.apache.uniffle.server.ShuffleServerConf;
 import org.apache.uniffle.storage.util.StorageType;
 
+import static org.apache.spark.shuffle.RssSparkClientConf.SPARK_CONFIG_KEY_PREFIX;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 
 public class SparkSQLWithMemoryLocalTest extends SparkSQLTest {
@@ -41,7 +42,8 @@ public static void setupServers() throws Exception {
     CoordinatorConf coordinatorConf = getCoordinatorConf();
     coordinatorConf.setLong("rss.coordinator.app.expired", 5000);
     Map<String, String> dynamicConf = Maps.newHashMap();
-    dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE.name());
+    dynamicConf.put(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_STORAGE_TYPE.key(),
+        StorageType.MEMORY_LOCALFILE.name());
     addDynamicConf(coordinatorConf, dynamicConf);
     createCoordinatorServer(coordinatorConf);
     ShuffleServerConf shuffleServerConf = getShuffleServerConf();
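The setupServers changes all share one shape: the coordinator's dynamic client conf now stores each option under its full Spark-facing name, so Spark clients can apply the entries verbatim. A small sketch of that map follows; the literal "spark." prefix is an assumption, and only the pattern matters:

    import java.util.Map;

    import com.google.common.collect.Maps;

    public class DynamicConfSketch {
      public static void main(String[] args) {
        Map<String, String> dynamicConf = Maps.newHashMap();
        // Keyed by the fully prefixed name rather than the bare option key.
        dynamicConf.put("spark." + "rss.storage.type", "MEMORY_LOCALFILE");
        dynamicConf.forEach((key, value) -> System.out.println(key + "=" + value));
      }
    }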
diff --git a/integration-test/spark3/src/test/java/org/apache/uniffle/test/AQERepartitionTest.java b/integration-test/spark3/src/test/java/org/apache/uniffle/test/AQERepartitionTest.java
index 297a2386b6..685ae104ae 100644
--- a/integration-test/spark3/src/test/java/org/apache/uniffle/test/AQERepartitionTest.java
+++ b/integration-test/spark3/src/test/java/org/apache/uniffle/test/AQERepartitionTest.java
@@ -24,7 +24,7 @@
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import org.apache.spark.SparkConf;
-import org.apache.spark.shuffle.RssSparkConfig;
+import org.apache.spark.shuffle.RssSparkClientConf;
 import org.apache.spark.sql.Column;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.SparkSession;
@@ -37,6 +37,7 @@
 import org.apache.uniffle.server.ShuffleServerConf;
 import org.apache.uniffle.storage.util.StorageType;
 
+import static org.apache.spark.shuffle.RssSparkClientConf.SPARK_CONFIG_KEY_PREFIX;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class AQERepartitionTest extends SparkIntegrationTestBase {
@@ -46,7 +47,8 @@ public static void setupServers() throws Exception {
     CoordinatorConf coordinatorConf = getCoordinatorConf();
     Map<String, String> dynamicConf = Maps.newHashMap();
     dynamicConf.put(CoordinatorConf.COORDINATOR_REMOTE_STORAGE_PATH.key(), HDFS_URI + "rss/test");
-    dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE_HDFS.name());
+    dynamicConf.put(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_STORAGE_TYPE.key(),
+        StorageType.MEMORY_LOCALFILE_HDFS.name());
     addDynamicConf(coordinatorConf, dynamicConf);
     createCoordinatorServer(coordinatorConf);
     ShuffleServerConf shuffleServerConf = getShuffleServerConf();
diff --git a/integration-test/spark3/src/test/java/org/apache/uniffle/test/AQESkewedJoinTest.java b/integration-test/spark3/src/test/java/org/apache/uniffle/test/AQESkewedJoinTest.java
index 50e0c27e6d..3deec5a6fd 100644
--- a/integration-test/spark3/src/test/java/org/apache/uniffle/test/AQESkewedJoinTest.java
+++ b/integration-test/spark3/src/test/java/org/apache/uniffle/test/AQESkewedJoinTest.java
@@ -24,7 +24,7 @@
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import org.apache.spark.SparkConf;
-import org.apache.spark.shuffle.RssSparkConfig;
+import org.apache.spark.shuffle.RssSparkClientConf;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
@@ -39,6 +39,7 @@
 import org.apache.uniffle.server.ShuffleServerConf;
 import org.apache.uniffle.storage.util.StorageType;
 
+import static org.apache.spark.shuffle.RssSparkClientConf.SPARK_CONFIG_KEY_PREFIX;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class AQESkewedJoinTest extends SparkIntegrationTestBase {
@@ -48,7 +49,8 @@ public static void setupServers() throws Exception {
     CoordinatorConf coordinatorConf = getCoordinatorConf();
     Map<String, String> dynamicConf = Maps.newHashMap();
     dynamicConf.put(CoordinatorConf.COORDINATOR_REMOTE_STORAGE_PATH.key(), HDFS_URI + "rss/test");
-    dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE_HDFS.name());
+    dynamicConf.put(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_STORAGE_TYPE.key(),
+        StorageType.MEMORY_LOCALFILE_HDFS.name());
     addDynamicConf(coordinatorConf, dynamicConf);
     createCoordinatorServer(coordinatorConf);
     ShuffleServerConf shuffleServerConf = getShuffleServerConf();
@@ -68,8 +70,8 @@ public void updateCommonSparkConf(SparkConf sparkConf) {
 
   @Override
   public void updateSparkConfCustomer(SparkConf sparkConf) {
-    sparkConf.set(RssSparkConfig.RSS_STORAGE_TYPE.key(), "HDFS");
-    sparkConf.set(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), HDFS_URI + "rss/test");
+    sparkConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_STORAGE_TYPE.key(), "HDFS");
+    sparkConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_REMOTE_STORAGE_PATH.key(), HDFS_URI + "rss/test");
   }
 
   @Test
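AQESkewedJoinTest depends on SparkConf.set being last-write-wins: the shared base class applies the common values first, and updateSparkConfCustomer then overrides the storage type and remote storage path for this test. A sketch of that ordering, with assumed key names:

    import org.apache.spark.SparkConf;

    public class OverrideOrderSketch {
      public static void main(String[] args) {
        SparkConf conf = new SparkConf();
        conf.set("spark.rss.storage.type", "MEMORY_LOCALFILE_HDFS"); // common setup
        conf.set("spark.rss.storage.type", "HDFS");                  // later per-test override wins
        System.out.println(conf.get("spark.rss.storage.type"));      // prints "HDFS"
      }
    }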
diff --git a/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetReaderTest.java b/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetReaderTest.java
index 827e793c89..89d6d2062a 100644
--- a/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetReaderTest.java
+++ b/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetReaderTest.java
@@ -39,7 +39,7 @@
 import org.apache.spark.shuffle.FetchFailedException;
 import org.apache.spark.shuffle.RssShuffleHandle;
 import org.apache.spark.shuffle.RssShuffleManager;
-import org.apache.spark.shuffle.RssSparkConfig;
+import org.apache.spark.shuffle.RssSparkClientConf;
 import org.apache.spark.shuffle.reader.RssShuffleReader;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.util.AccumulatorV2;
@@ -57,6 +57,7 @@
 import org.apache.uniffle.server.ShuffleServerConf;
 import org.apache.uniffle.storage.util.StorageType;
 
+import static org.apache.spark.shuffle.RssSparkClientConf.SPARK_CONFIG_KEY_PREFIX;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -67,7 +68,7 @@ public class GetReaderTest extends IntegrationTestBase {
   public void test() throws Exception {
     SparkConf sparkConf = new SparkConf();
     sparkConf.set("spark.shuffle.manager", "org.apache.spark.shuffle.RssShuffleManager");
-    sparkConf.set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), COORDINATOR_QUORUM);
+    sparkConf.set(SPARK_CONFIG_KEY_PREFIX + RssSparkClientConf.RSS_COORDINATOR_QUORUM.key(), COORDINATOR_QUORUM);
     sparkConf.setMaster("local[4]");
     final String remoteStorage1 = "hdfs://h1/p1";
     final String remoteStorage2 = "hdfs://h2/p2";
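The GetReaderTest change reduces to the same client-side recipe. A self-contained sketch of that setup, assuming the composed quorum key is "spark.rss.coordinator.quorum"; the address below is a placeholder, not a value from this patch:

    import org.apache.spark.SparkConf;

    public class RssClientSetupSketch {
      public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf()
            .setMaster("local[4]")
            .setAppName("rss-client-setup-sketch")
            .set("spark.shuffle.manager", "org.apache.spark.shuffle.RssShuffleManager")
            .set("spark.rss.coordinator.quorum", "localhost:19999"); // placeholder address
        System.out.println(sparkConf.get("spark.rss.coordinator.quorum"));
      }
    }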