Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,26 @@
import scala.runtime.AbstractFunction1;

import org.apache.uniffle.client.util.RssClientConfig;
import org.apache.uniffle.common.config.ConfigOption;
import org.apache.uniffle.common.config.ConfigOptions;
import org.apache.uniffle.common.config.ConfigUtils;
import org.apache.uniffle.common.config.RssConf;

public class RssSparkConfig {

// Max data size (bytes) of a single send to the shuffle server; default 16 MiB.
// Supersedes the removed string-typed entry "spark.rss.client.send.size.limit"
// (which defaulted to "16m").
public static final ConfigOption<Long> RSS_CLIENT_SEND_SIZE_LIMITATION = ConfigOptions
.key("rss.client.send.size.limit")
.longType()
.defaultValue(1024 * 1024 * 16L)
.withDescription("The max data size sent to shuffle server");

// Timeout, in seconds, for spilling in-memory shuffle data to the remote
// shuffle servers when triggered by Spark's TaskMemoryManager; default 1s.
public static final ConfigOption<Integer> RSS_MEMORY_SPILL_TIMEOUT = ConfigOptions
.key("rss.client.memory.spill.timeout.sec")
.intType()
.defaultValue(1)
.withDescription("The timeout of spilling data to remote shuffle server, "
+ "which will be triggered by Spark TaskMemoryManager. Unit is sec, default value is 1");

public static final String SPARK_RSS_CONFIG_PREFIX = "spark.";

public static final ConfigEntry<Integer> RSS_PARTITION_NUM_PER_RANGE = createIntegerBuilder(
Expand Down Expand Up @@ -115,11 +130,6 @@ public class RssSparkConfig {
new ConfigBuilder("spark.rss.client.heartBeat.threadNum"))
.createWithDefault(4);

public static final ConfigEntry<String> RSS_CLIENT_SEND_SIZE_LIMIT = createStringBuilder(
new ConfigBuilder("spark.rss.client.send.size.limit")
.doc("The max data size sent to shuffle server"))
.createWithDefault("16m");

// Thread pool size for unregister requests; default 10.
// NOTE(review): presumably used when unregistering shuffle data from the
// servers — confirm at the call site.
public static final ConfigEntry<Integer> RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE = createIntegerBuilder(
new ConfigBuilder("spark.rss.client.unregister.thread.pool.size"))
.createWithDefault(10);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.shuffle.writer;

import java.util.ArrayList;
import java.util.List;

import org.apache.uniffle.common.ShuffleBlockInfo;
Expand All @@ -25,10 +26,26 @@ public class AddBlockEvent {

// Attempt id of the Spark task that produced these blocks.
private String taskId;
// Shuffle blocks to be sent to the remote shuffle servers.
private List<ShuffleBlockInfo> shuffleDataInfoList;
// Callbacks executed, in insertion order, after the event has been processed.
private List<Runnable> processedCallbackChain;

/**
 * Creates an event with no post-processing callbacks.
 *
 * @param taskId attempt id of the task that generated the blocks
 * @param shuffleDataInfoList blocks to send; held by reference, not copied
 */
public AddBlockEvent(String taskId, List<ShuffleBlockInfo> shuffleDataInfoList) {
this.taskId = taskId;
this.shuffleDataInfoList = shuffleDataInfoList;
this.processedCallbackChain = new ArrayList<>();
}

/**
 * Creates an event with a single post-processing callback.
 *
 * @param taskId attempt id of the task that generated the blocks
 * @param shuffleBlockInfoList blocks to send; held by reference, not copied
 * @param callback invoked after the event is processed; must not throw and
 *                 should execute quickly
 */
public AddBlockEvent(String taskId, List<ShuffleBlockInfo> shuffleBlockInfoList, Runnable callback) {
  // Chain to the primary constructor instead of duplicating field initialization.
  this(taskId, shuffleBlockInfoList);
  addCallback(callback);
}

/**
 * Registers a callback to run after this event has been processed.
 *
 * @param callback must not throw any exception and should execute quickly,
 *                 since it runs on the data-pushing thread
 */
public void addCallback(Runnable callback) {
processedCallbackChain.add(callback);
}

public String getTaskId() {
Expand All @@ -39,6 +56,10 @@ public List<ShuffleBlockInfo> getShuffleDataInfoList() {
return shuffleDataInfoList;
}

// Returns the live (mutable) callback list; callers should treat it as
// read-only and register callbacks through addCallback instead.
public List<Runnable> getProcessedCallbackChain() {
return processedCallbackChain;
}

@Override
public String toString() {
return "AddBlockEvent: TaskId[" + taskId + "], " + shuffleDataInfoList;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.writer;

import java.io.Closeable;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

import com.google.common.collect.Queues;
import com.google.common.collect.Sets;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.response.SendShuffleDataResult;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.ThreadUtils;

/**
* A {@link DataPusher} that is responsible for sending data to remote
* shuffle servers asynchronously.
*/
public class DataPusher implements Closeable {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could DataPusher be used for MapReduce?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't seen that part of the code.

private static final Logger LOGGER = LoggerFactory.getLogger(DataPusher.class);

// Pool that runs the asynchronous sends; sized in the constructor.
private final ExecutorService executorService;

private final ShuffleWriteClient shuffleWriteClient;
// Must be thread safe
private final Map<String, Set<Long>> taskToSuccessBlockIds;
// Must be thread safe
private final Map<String, Set<Long>> taskToFailedBlockIds;
// Must be set via setRssAppId(...) before send(...) is called.
private String rssAppId;
// Must be thread safe
private final Set<String> failedTaskIds;

/**
 * Creates a pusher backed by a dedicated thread pool.
 *
 * @param shuffleWriteClient client used to push shuffle data to the servers
 * @param taskToSuccessBlockIds task id to successfully sent block ids (must be thread safe)
 * @param taskToFailedBlockIds task id to failed block ids (must be thread safe)
 * @param failedTaskIds ids of tasks marked as failed (must be thread safe)
 * @param threadPoolSize core size of the sending thread pool
 * @param threadKeepAliveTime keep-alive, in seconds, for idle non-core threads
 */
public DataPusher(ShuffleWriteClient shuffleWriteClient,
Map<String, Set<Long>> taskToSuccessBlockIds,
Map<String, Set<Long>> taskToFailedBlockIds,
Set<String> failedTaskIds,
int threadPoolSize,
int threadKeepAliveTime) {
this.shuffleWriteClient = shuffleWriteClient;
this.taskToSuccessBlockIds = taskToSuccessBlockIds;
this.taskToFailedBlockIds = taskToFailedBlockIds;
this.failedTaskIds = failedTaskIds;
this.executorService = new ThreadPoolExecutor(
threadPoolSize,
threadPoolSize * 2,
threadKeepAliveTime,
TimeUnit.SECONDS,
// NOTE(review): with an effectively unbounded queue, ThreadPoolExecutor never
// grows beyond the core size, so maximumPoolSize (threadPoolSize * 2) has no
// effect — confirm whether a bounded queue was intended.
Queues.newLinkedBlockingQueue(Integer.MAX_VALUE),
ThreadUtils.getThreadFactory(this.getClass().getName())
);
}

/**
 * Asynchronously sends the event's shuffle blocks to the remote shuffle servers,
 * records the success/failure block ids per task, and then runs the event's
 * callbacks (even when the send throws).
 *
 * @param event blocks to send plus their post-processing callbacks
 * @return future completing with the total freed memory of the sent blocks
 * @throws RssException if the app id has not been set via setRssAppId
 */
public CompletableFuture<Long> send(AddBlockEvent event) {
  if (rssAppId == null) {
    throw new RssException("RssAppId should be set.");
  }
  return CompletableFuture.supplyAsync(() -> {
    String taskId = event.getTaskId();
    List<ShuffleBlockInfo> shuffleBlockInfoList = event.getShuffleDataInfoList();
    try {
      SendShuffleDataResult result = shuffleWriteClient.sendShuffleData(
          rssAppId,
          shuffleBlockInfoList,
          () -> !isValidTask(taskId)
      );
      putBlockId(taskToSuccessBlockIds, taskId, result.getSuccessBlockIds());
      putBlockId(taskToFailedBlockIds, taskId, result.getFailedBlockIds());
    } finally {
      // ofNullable (not Optional.of, which would NPE on null and defeat the
      // fallback) plus the type-safe Collections.emptyList() instead of the
      // raw Collections.EMPTY_LIST.
      List<Runnable> callbackChain =
          Optional.ofNullable(event.getProcessedCallbackChain()).orElse(Collections.emptyList());
      for (Runnable runnable : callbackChain) {
        runnable.run();
      }
    }
    // mapToLong(...).sum() yields 0 for an empty block list, unlike
    // reduce(...).get(), which would throw NoSuchElementException.
    return shuffleBlockInfoList.stream()
        .mapToLong(ShuffleBlockInfo::getFreeMemory)
        .sum();
  }, executorService);
}

/**
 * Merges the given block ids into the per-task set of the supplied map,
 * creating a concurrent set for the task on first use.
 *
 * @param taskToBlockIds map from task attempt id to its block ids
 * @param taskAttemptId task whose blocks are being recorded
 * @param blockIds block ids to add; null or empty is a no-op
 */
private synchronized void putBlockId(
    Map<String, Set<Long>> taskToBlockIds,
    String taskAttemptId,
    Set<Long> blockIds) {
  if (blockIds != null && !blockIds.isEmpty()) {
    Set<Long> idsForTask = taskToBlockIds.computeIfAbsent(taskAttemptId, key -> Sets.newConcurrentHashSet());
    idsForTask.addAll(blockIds);
  }
}

/**
 * A task is valid as long as it has not been marked failed.
 *
 * @param taskId task attempt id to check
 * @return true if the task is not present in the failed-task set
 */
public boolean isValidTask(String taskId) {
  boolean markedFailed = failedTaskIds.contains(taskId);
  return !markedFailed;
}

// Must be called before send(...); send throws RssException otherwise.
public void setRssAppId(String rssAppId) {
this.rssAppId = rssAppId;
}

/**
 * Shuts down the sending thread pool, waiting up to 5 seconds for in-flight
 * sends to finish.
 *
 * @throws IOException declared by Closeable; not thrown by this implementation
 */
@Override
public void close() throws IOException {
  if (executorService != null) {
    try {
      ThreadUtils.shutdownThreadPool(executorService, 5);
    } catch (InterruptedException interruptedException) {
      // Restore the interrupt flag so callers can observe the interruption,
      // and log the exception itself, not just a message.
      Thread.currentThread().interrupt();
      LOGGER.error("Errors on shutdown thread pool of [{}].",
          this.getClass().getSimpleName(), interruptedException);
    }
  }
}
}
Loading