diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 625f8e83cc..8f07604084 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -86,7 +86,6 @@ public class RssShuffleManager implements ShuffleManager {
   private final int dataCommitPoolSize;
   private boolean heartbeatStarted = false;
   private boolean dynamicConfEnabled = false;
-  private RemoteStorageInfo remoteStorage;
   private ThreadPoolExecutor threadPoolExecutor;
   private EventLoop<AddBlockEvent> eventLoop =
       new EventLoop<AddBlockEvent>("ShuffleDataQueue") {
@@ -213,9 +212,10 @@ public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, int numMaps, Shuff
     }
     String storageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
-    remoteStorage = new RemoteStorageInfo(sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), ""));
-    remoteStorage = ClientUtils.fetchRemoteStorage(
-        appId, remoteStorage, dynamicConfEnabled, storageType, shuffleWriteClient);
+    RemoteStorageInfo defaultRemoteStorage = new RemoteStorageInfo(
+        sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), ""));
+    RemoteStorageInfo remoteStorage = ClientUtils.fetchRemoteStorage(
+        appId, defaultRemoteStorage, dynamicConfEnabled, storageType, shuffleWriteClient);
 
     int partitionNumPerRange = sparkConf.get(RssSparkConfig.RSS_PARTITION_NUM_PER_RANGE);
@@ -233,7 +233,7 @@
         ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments(
            appId, shuffleId, dependency.partitioner().numPartitions(),
            partitionNumPerRange, assignmentTags, requiredShuffleServerNumber);
-        registerShuffleServers(appId, shuffleId, response.getServerToPartitionRanges());
+        registerShuffleServers(appId, shuffleId, response.getServerToPartitionRanges(), remoteStorage);
         return response.getPartitionToServers();
       }, retryInterval, retryTimes);
     } catch (Throwable throwable) {
@@ -268,7 +268,8 @@ private void startHeartbeat() {
   protected void registerShuffleServers(
       String appId,
       int shuffleId,
-      Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges) {
+      Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges,
+      RemoteStorageInfo remoteStorage) {
     if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) {
       return;
     }
@@ -480,8 +481,4 @@ public void setAppId(String appId) {
     this.appId = appId;
   }
 
-  @VisibleForTesting
-  public void setRemoteStorage(RemoteStorageInfo remoteStorage) {
-    this.remoteStorage = remoteStorage;
-  }
 }
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 96c93ef75e..41c2f4d7dc 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -92,7 +92,6 @@ public class RssShuffleManager implements ShuffleManager {
   private ScheduledExecutorService heartBeatScheduledExecutorService;
   private boolean heartbeatStarted = false;
   private boolean dynamicConfEnabled = false;
-  private RemoteStorageInfo remoteStorage;
   private final EventLoop<AddBlockEvent> eventLoop;
   private final EventLoop<AddBlockEvent> defaultEventLoop =
       new EventLoop<AddBlockEvent>("ShuffleDataQueue") {
@@ -266,10 +265,10 @@ public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, ShuffleDependency<
       LOG.info("Generate application id used in rss: " + id.get());
       String storageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
-      remoteStorage = new RemoteStorageInfo(
+      RemoteStorageInfo defaultRemoteStorage = new RemoteStorageInfo(
           sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), ""));
-      remoteStorage = ClientUtils.fetchRemoteStorage(
-          id.get(), remoteStorage, dynamicConfEnabled, storageType, shuffleWriteClient);
+      RemoteStorageInfo remoteStorage = ClientUtils.fetchRemoteStorage(
+          id.get(), defaultRemoteStorage, dynamicConfEnabled, storageType, shuffleWriteClient);
 
       Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
@@ -288,7 +287,7 @@ public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, ShuffleDependency<
             1,
             assignmentTags,
             requiredShuffleServerNumber);
-        registerShuffleServers(id.get(), shuffleId, response.getServerToPartitionRanges());
+        registerShuffleServers(id.get(), shuffleId, response.getServerToPartitionRanges(), remoteStorage);
         return response.getPartitionToServers();
       }, retryInterval, retryTimes);
     } catch (Throwable throwable) {
@@ -606,7 +605,8 @@ protected void registerShuffleServers(
       String appId,
       int shuffleId,
       Map<ShuffleServerInfo,
-          List<PartitionRange>> serverToPartitionRanges) {
+          List<PartitionRange>> serverToPartitionRanges,
+      RemoteStorageInfo remoteStorage) {
     if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) {
       return;
     }
@@ -728,11 +728,6 @@ public void setAppId(String appId) {
     this.id = new AtomicReference<>(appId);
   }
 
-  @VisibleForTesting
-  public void setRemoteStorage(RemoteStorageInfo remoteStorage) {
-    this.remoteStorage = remoteStorage;
-  }
-
   public String getId() {
     return id.get();
   }
diff --git a/integration-test/spark2/src/test/java/org/apache/uniffle/test/GetReaderTest.java b/integration-test/spark2/src/test/java/org/apache/uniffle/test/GetReaderTest.java
index b7a6bec37a..b273d47c85 100644
--- a/integration-test/spark2/src/test/java/org/apache/uniffle/test/GetReaderTest.java
+++ b/integration-test/spark2/src/test/java/org/apache/uniffle/test/GetReaderTest.java
@@ -141,7 +141,6 @@ public void test() throws Exception {
     assertNull(commonHadoopConf.get("k2"));
 
     // mock the scenario that get reader in an executor
-    rssShuffleManager.setRemoteStorage(null);
     rssShuffleReader = (RssShuffleReader) rssShuffleManager.getReader(
         rssShuffleHandle, 0, 0, mockTaskContextImpl);
     hadoopConf = rssShuffleReader.getHadoopConf();
diff --git a/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetReaderTest.java b/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetReaderTest.java
index 01b4d35f0d..113cb4b7a8 100644
--- a/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetReaderTest.java
+++ b/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetReaderTest.java
@@ -150,8 +150,6 @@ public void test() throws Exception {
     assertNull(commonHadoopConf.get("k1"));
     assertNull(commonHadoopConf.get("k2"));
 
-    // mock the scenario that get reader in an executor
-    rssShuffleManager.setRemoteStorage(null);
     rssShuffleReader = (RssShuffleReader) rssShuffleManager.getReader(
         rssShuffleHandle, 0, 0, new MockTaskContext(), new TempShuffleReadMetrics());
     hadoopConf = rssShuffleReader.getHadoopConf();
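Note: the substance of the patch is removing the mutable remoteStorage field from both RssShuffleManager classes. Each registerShuffle call now resolves the RemoteStorageInfo locally (the default built from RSS_REMOTE_STORAGE_PATH, then refined by ClientUtils.fetchRemoteStorage) and passes it explicitly to registerShuffleServers, so the @VisibleForTesting setter disappears and GetReaderTest no longer needs to reset the field via setRemoteStorage(null). The sketch below illustrates this field-to-parameter refactor in isolation; it is a toy stand-in with hypothetical record types, not Uniffle's real classes.

import java.util.List;
import java.util.Map;

// Toy illustration of the refactor applied by this patch: the remote storage is
// resolved once per registration and passed down the call chain instead of being
// read from a mutable manager field, so there is no shared state to reset.
public class RegisterSketch {
  // Hypothetical stand-ins for Uniffle's RemoteStorageInfo/ShuffleServerInfo/PartitionRange.
  record RemoteStorageInfo(String path) {}
  record ShuffleServerInfo(String id) {}
  record PartitionRange(int start, int end) {}

  // After the patch: remoteStorage is an explicit parameter, as in the diff above.
  static void registerShuffleServers(
      String appId,
      int shuffleId,
      Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges,
      RemoteStorageInfo remoteStorage) {
    if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) {
      return;
    }
    serverToPartitionRanges.forEach((server, ranges) ->
        System.out.printf("app=%s shuffle=%d server=%s ranges=%s storage=%s%n",
            appId, shuffleId, server.id(), ranges, remoteStorage.path()));
  }

  public static void main(String[] args) {
    // Resolve once, pass through; nothing is stored on a manager object.
    RemoteStorageInfo remoteStorage = new RemoteStorageInfo("hdfs://ns1/rss");
    registerShuffleServers(
        "application_123", 0,
        Map.of(new ShuffleServerInfo("server-1"), List.of(new PartitionRange(0, 1))),
        remoteStorage);
  }
}

Passing the value through the call chain also means concurrent registerShuffle calls cannot observe each other's half-updated storage resolution, which a shared mutable field could not guarantee without extra synchronization.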