client-mr/pom.xml (4 additions, 0 deletions)
@@ -105,6 +105,10 @@
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
+    <dependency>
+      <groupId>com.github.luben</groupId>
+      <artifactId>zstd-jni</artifactId>
+    </dependency>
</dependencies>

<build>
---
@@ -130,7 +130,8 @@ public void init(Context context) throws IOException, ClassNotFoundException {
isMemoryShuffleEnabled(storageType),
sendThreadNum,
sendThreshold,
-          maxBufferSize);
+          maxBufferSize,
+          RssMRConfig.toRssConf(rssJobConf));
}

private Map<Integer, List<ShuffleServerInfo>> createAssignmentMap(JobConf jobConf) {
--- SortWriteBufferManager.java ---
@@ -43,9 +43,10 @@

import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.response.SendShuffleDataResult;
- import org.apache.uniffle.common.RssShuffleUtils;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
+ import org.apache.uniffle.common.compression.Codec;
+ import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.ChecksumUtils;
import org.apache.uniffle.common.util.ThreadUtils;
@@ -90,6 +91,8 @@ public class SortWriteBufferManager<K, V> {
private long sortTime = 0;
private final long maxBufferSize;
private final ExecutorService sendExecutorService;
+ private final RssConf rssConf;
+ private final Codec codec;

public SortWriteBufferManager(
long maxMemSize,
@@ -114,7 +117,8 @@ public SortWriteBufferManager(
boolean isMemoryShuffleEnabled,
int sendThreadNum,
double sendThreshold,
-     long maxBufferSize) {
+     long maxBufferSize,
+     RssConf rssConf) {
this.maxMemSize = maxMemSize;
this.taskAttemptId = taskAttemptId;
this.batch = batch;
@@ -140,6 +144,8 @@ public SortWriteBufferManager(
this.sendExecutorService = Executors.newFixedThreadPool(
sendThreadNum,
ThreadUtils.getThreadFactory("send-thread-%d"));
+   this.rssConf = rssConf;
+   this.codec = Codec.newInstance(rssConf);
}

// todo: Single Buffer should also have its size limit
@@ -309,7 +315,7 @@ ShuffleBlockInfo createShuffleBlock(SortWriteBuffer wb) {
int partitionId = wb.getPartitionId();
final int uncompressLength = data.length;
long start = System.currentTimeMillis();
-   final byte[] compressed = RssShuffleUtils.compressData(data);
+   final byte[] compressed = codec.compress(data);
final long crc32 = ChecksumUtils.getCrc32(compressed);
compressTime += System.currentTimeMillis() - start;
final long blockId = RssMRUtils.getBlockId((long)partitionId, taskAttemptId, getNextSeqNo(partitionId));
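The hunks above swap the static RssShuffleUtils.compressData call for a per-manager Codec created from the job's RssConf. A minimal write-side sketch of that flow, using only calls visible in this diff (Codec.newInstance, Codec.compress, ChecksumUtils.getCrc32); how the conf selects the concrete codec implementation is not shown in this hunk:

```java
import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.util.ChecksumUtils;

public class WriteSideSketch {
  // Mirrors the new createShuffleBlock() steps: compress with the codec
  // selected from the RssConf, then checksum the compressed bytes.
  static long compressAndChecksum(RssConf rssConf, byte[] data) {
    Codec codec = Codec.newInstance(rssConf); // concrete codec chosen from conf
    byte[] compressed = codec.compress(data);
    return ChecksumUtils.getCrc32(compressed);
  }
}
```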
--- RssMRConfig.java ---
@@ -17,11 +17,14 @@

package org.apache.hadoop.mapreduce;

+ import java.util.Map;
import java.util.Set;

import com.google.common.collect.ImmutableSet;
+ import org.apache.hadoop.mapred.JobConf;

import org.apache.uniffle.client.util.RssClientConfig;
+ import org.apache.uniffle.common.config.RssConf;

public class RssMRConfig {

@@ -164,4 +167,17 @@ public class RssMRConfig {

public static final Set<String> RSS_MANDATORY_CLUSTER_CONF =
ImmutableSet.of(RSS_STORAGE_TYPE, RSS_REMOTE_STORAGE_PATH);

+ public static RssConf toRssConf(JobConf jobConf) {
+   RssConf rssConf = new RssConf();
+   for (Map.Entry<String, String> entry : jobConf) {
+     String key = entry.getKey();
+     if (!key.startsWith(MR_RSS_CONFIG_PREFIX)) {
+       continue;
+     }
+     key = key.substring(MR_RSS_CONFIG_PREFIX.length());
+     rssConf.setString(key, entry.getValue());
+   }
+   return rssConf;
+ }
}
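The new toRssConf(JobConf) copies only the keys carrying the RSS prefix and strips that prefix before handing them to RssConf. A sketch of the behavior; the actual value of MR_RSS_CONFIG_PREFIX is not shown in this diff, so "mapreduce." and the codec key below are assumptions for illustration:

```java
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapreduce.RssMRConfig;

import org.apache.uniffle.common.config.RssConf;

public class ToRssConfExample {
  public static void main(String[] args) {
    JobConf jobConf = new JobConf();
    // Assumption: MR_RSS_CONFIG_PREFIX is "mapreduce." and the key below is
    // purely illustrative -- neither is confirmed by this diff.
    jobConf.set("mapreduce.rss.client.io.compression.codec", "ZSTD");
    jobConf.set("yarn.timeline-service.enabled", "false"); // no prefix match: skipped

    RssConf rssConf = RssMRConfig.toRssConf(jobConf);
    // Under the assumed prefix, rssConf now maps
    // "rss.client.io.compression.codec" -> "ZSTD".
  }
}
```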
--- RssFetcher.java ---
@@ -35,7 +35,8 @@

import org.apache.uniffle.client.api.ShuffleReadClient;
import org.apache.uniffle.client.response.CompressedShuffleBlock;
- import org.apache.uniffle.common.RssShuffleUtils;
+ import org.apache.uniffle.common.compression.Codec;
+ import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.ByteUnit;

@@ -84,14 +85,17 @@ private enum ShuffleErrors {
private long startWait;
private int waitCount = 0;
private byte[] uncompressedData = null;
+ private RssConf rssConf;
+ private Codec codec;

RssFetcher(JobConf job, TaskAttemptID reduceId,
-     TaskStatus status,
-     MergeManager<K,V> merger,
-     Progress progress,
-     Reporter reporter, ShuffleClientMetrics metrics,
-     ShuffleReadClient shuffleReadClient,
-     long totalBlockCount) {
+     TaskStatus status,
+     MergeManager<K, V> merger,
+     Progress progress,
+     Reporter reporter, ShuffleClientMetrics metrics,
+     ShuffleReadClient shuffleReadClient,
+     long totalBlockCount,
+     RssConf rssConf) {
this.jobConf = job;
this.reporter = reporter;
this.status = status;
@@ -114,6 +118,9 @@ private enum ShuffleErrors {

this.shuffleReadClient = shuffleReadClient;
this.totalBlockCount = totalBlockCount;

+   this.rssConf = rssConf;
+   this.codec = Codec.newInstance(rssConf);
}

public void fetchAllRssBlocks() throws IOException, InterruptedException {
@@ -150,8 +157,10 @@ public void copyFromRssServer() throws IOException {
// uncompress the block
if (!hasPendingData && compressedData != null) {
final long startDecompress = System.currentTimeMillis();
-       uncompressedData = RssShuffleUtils.decompressData(
-           compressedData, compressedBlock.getUncompressLength(), false).array();
+       int uncompressedLen = compressedBlock.getUncompressLength();
+       ByteBuffer decompressedBuffer = ByteBuffer.allocate(uncompressedLen);
+       codec.decompress(compressedData, uncompressedLen, decompressedBuffer, 0);
+       uncompressedData = decompressedBuffer.array();
unCompressionLength += compressedBlock.getUncompressLength();
long decompressDuration = System.currentTimeMillis() - startDecompress;
decompressTime += decompressDuration;
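On the read side, decompression moves onto the same codec abstraction: the fetcher allocates the destination buffer at the uncompressed length recorded in the block and hands it to Codec.decompress with a destination offset of 0. A sketch of the pattern; the accessor that yields the compressed ByteBuffer is not shown in this hunk, so getByteBuffer() below is an assumption:

```java
import java.nio.ByteBuffer;

import org.apache.uniffle.client.response.CompressedShuffleBlock;
import org.apache.uniffle.common.compression.Codec;

public class ReadSideSketch {
  // Mirrors the new copyFromRssServer() steps: size the destination from the
  // block's recorded uncompressed length, then decompress into it at offset 0.
  static byte[] decompressBlock(Codec codec, CompressedShuffleBlock block) {
    ByteBuffer compressed = block.getByteBuffer(); // assumed accessor name
    int uncompressedLen = block.getUncompressLength();
    ByteBuffer dest = ByteBuffer.allocate(uncompressedLen);
    codec.decompress(compressed, uncompressedLen, dest, 0);
    return dest.array();
  }
}
```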
---
@@ -197,7 +197,7 @@ public RawKeyValueIterator run() throws IOException, InterruptedException {
readerJobConf, new MRIdHelper());
ShuffleReadClient shuffleReadClient = ShuffleClientFactory.getInstance().createShuffleReadClient(request);
RssFetcher fetcher = new RssFetcher(mrJobConf, reduceId, taskStatus, merger, copyPhase, reporter, metrics,
-       shuffleReadClient, blockIdBitmap.getLongCardinality());
+       shuffleReadClient, blockIdBitmap.getLongCardinality(), RssMRConfig.toRssConf(rssJobConf));
fetcher.fetchAllRssBlocks();
LOG.info("In reduce: " + reduceId
+ ", Rss MR client fetches blocks from RSS server successfully");
--- SortWriteBufferManagerTest.java ---
@@ -38,6 +38,7 @@
import org.apache.uniffle.common.ShuffleAssignmentsInfo;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
+ import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;

import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -79,7 +80,8 @@ public void testWriteException() throws Exception {
true,
5,
0.2f,
-       1024000L);
+       1024000L,
+       new RssConf());
Random random = new Random();
for (int i = 0; i < 1000; i++) {
byte[] key = new byte[20];
@@ -128,7 +130,8 @@ public void testWriteException() throws Exception {
true,
5,
0.2f,
-       1024000L);
+       1024000L,
+       new RssConf());
byte[] key = new byte[20];
byte[] value = new byte[1024];
random.nextBytes(key);
@@ -176,7 +179,8 @@ public void testOnePartition() throws Exception {
true,
5,
0.2f,
-       100L);
+       100L,
+       new RssConf());
Random random = new Random();
for (int i = 0; i < 1000; i++) {
byte[] key = new byte[20];
@@ -223,7 +227,8 @@ public void testWriteNormal() throws Exception {
true,
5,
0.2f,
-       1024000L);
+       1024000L,
+       new RssConf());
Random random = new Random();
for (int i = 0; i < 1000; i++) {
byte[] key = new byte[20];
--- FetcherTest.java ---
@@ -65,10 +65,12 @@
import org.apache.uniffle.client.response.SendShuffleDataResult;
import org.apache.uniffle.common.PartitionRange;
import org.apache.uniffle.common.RemoteStorageInfo;
- import org.apache.uniffle.common.RssShuffleUtils;
import org.apache.uniffle.common.ShuffleAssignmentsInfo;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
+ import org.apache.uniffle.common.compression.Codec;
+ import org.apache.uniffle.common.compression.Lz4Codec;
+ import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;

import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -88,6 +90,8 @@ public class FetcherTest {
static List<byte[]> data;
static MergeManagerImpl<Text, Text> merger;

+ static Codec codec = new Lz4Codec();

@Test
public void writeAndReadDataTestWithRss() throws Throwable {
fs = FileSystem.getLocal(conf);
@@ -97,7 +101,7 @@ public void writeAndReadDataTestWithRss() throws Throwable {
null, null, new Progress(), new MROutputFiles());
ShuffleReadClient shuffleReadClient = new MockedShuffleReadClient(data);
RssFetcher fetcher = new RssFetcher(jobConf, reduceId1, taskStatus, merger, new Progress(),
-       reporter, metrics, shuffleReadClient, 3);
+       reporter, metrics, shuffleReadClient, 3, new RssConf());
fetcher.fetchAllRssBlocks();


@@ -128,7 +132,7 @@ public void writeAndReadDataTestWithoutRss() throws Throwable {
null, null, new Progress(), new MROutputFiles());
ShuffleReadClient shuffleReadClient = new MockedShuffleReadClient(data);
RssFetcher fetcher = new RssFetcher(jobConf, reduceId1, taskStatus, merger, new Progress(),
-       reporter, metrics, shuffleReadClient, 3);
+       reporter, metrics, shuffleReadClient, 3, new RssConf());
fetcher.fetchAllRssBlocks();


@@ -161,7 +165,7 @@ public void writeAndReadDataMergeFailsTestWithRss() throws Throwable {
null, null, new Progress(), new MROutputFiles(), expectedFails);
ShuffleReadClient shuffleReadClient = new MockedShuffleReadClient(data);
RssFetcher fetcher = new RssFetcher(jobConf, reduceId1, taskStatus, merger, new Progress(),
-       reporter, metrics, shuffleReadClient, 3);
+       reporter, metrics, shuffleReadClient, 3, new RssConf());
fetcher.fetchAllRssBlocks();

RawKeyValueIterator iterator = merger.close();
@@ -276,7 +280,8 @@ private static byte[] writeMapOutputRss(Configuration conf, Map<String, String>
true,
5,
0.2f,
-       1024000L);
+       1024000L,
+       new RssConf());

for (String key : keysToValues.keySet()) {
String value = keysToValues.get(key);
@@ -357,7 +362,14 @@ public SendShuffleDataResult sendShuffleData(String appId, List<ShuffleBlockInfo
successBlockIds.add(blockInfo.getBlockId());
}
shuffleBlockInfoList.forEach(block -> {
-       data.add(RssShuffleUtils.decompressData(block.getData(), block.getUncompressLength()));
+       ByteBuffer uncompressedBuffer = ByteBuffer.allocate(block.getUncompressLength());
+       codec.decompress(
+           ByteBuffer.wrap(block.getData()),
+           block.getUncompressLength(),
+           uncompressedBuffer,
+           0
+       );
+       data.add(uncompressedBuffer.array());
});
return new SendShuffleDataResult(successBlockIds, Sets.newHashSet());
}
@@ -440,7 +452,7 @@ static class MockedShuffleReadClient implements ShuffleReadClient {
MockedShuffleReadClient(List<byte[]> data) {
this.blocks = new LinkedList<>();
data.forEach(bytes -> {
-       byte[] compressed = RssShuffleUtils.compressData(bytes);
+       byte[] compressed = codec.compress(bytes);
blocks.add(new CompressedShuffleBlock(ByteBuffer.wrap(compressed), bytes.length));
});
}
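The mocked client above compresses with a hard-wired Lz4Codec and decompresses into a buffer sized to the original length. The same round trip as a standalone snippet, using only the signatures this test exercises:

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.compression.Lz4Codec;

public class CodecRoundTrip {
  public static void main(String[] args) {
    Codec codec = new Lz4Codec();
    byte[] original = "uniffle shuffle block payload".getBytes(StandardCharsets.UTF_8);

    byte[] compressed = codec.compress(original);

    // The reader must know the uncompressed length up front, exactly as
    // CompressedShuffleBlock carries it alongside the compressed bytes.
    ByteBuffer dest = ByteBuffer.allocate(original.length);
    codec.decompress(ByteBuffer.wrap(compressed), original.length, dest, 0);

    System.out.println(Arrays.equals(original, dest.array())); // true
  }
}
```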
--- RssSparkConfig.java ---
@@ -20,13 +20,16 @@
import java.util.Set;

import com.google.common.collect.ImmutableSet;
+ import org.apache.spark.SparkConf;
import org.apache.spark.internal.config.ConfigBuilder;
import org.apache.spark.internal.config.ConfigEntry;
import org.apache.spark.internal.config.TypedConfigBuilder;
+ import scala.Tuple2;
import scala.runtime.AbstractFunction1;

import org.apache.uniffle.client.util.RssClientConfig;
import org.apache.uniffle.common.config.ConfigUtils;
+ import org.apache.uniffle.common.config.RssConf;

public class RssSparkConfig {

@@ -286,4 +289,17 @@ public Double apply(String in) {
public static TypedConfigBuilder<String> createStringBuilder(ConfigBuilder builder) {
return builder.stringConf();
}

+ public static RssConf toRssConf(SparkConf sparkConf) {
[Review comment - Contributor]
Why change the conf design in this ZSTD PR?

[Reply - Member (Author)]
I want the compressor factory to be accessible from both MR and Spark, so each client can create a concrete codec initialized from the given conf. That leaves two choices:

  1. Use a shareable RssConf, as this PR does.
  2. Introduce a separate compression config bean (which I don't think is needed).

Besides, I want to refactor the MR/Spark client conf-entry code, and this PR does part of that work. Please refer to #200.
+   RssConf rssConf = new RssConf();
+   for (Tuple2<String, String> tuple : sparkConf.getAll()) {
+     String key = tuple._1;
+     if (!key.startsWith(SPARK_RSS_CONFIG_PREFIX)) {
+       continue;
+     }
+     key = key.substring(SPARK_RSS_CONFIG_PREFIX.length());
+     rssConf.setString(key, tuple._2);
+   }
+   return rssConf;
+ }
}
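The reply in the review thread above describes the intent: each engine projects its own conf into a shared RssConf, and codec construction only ever sees RssConf. A sketch of that shared path; it is illustrative only, since RssSparkConfig and RssMRConfig live in separate client modules in the repository, so no single class could actually reference both:

```java
import org.apache.hadoop.mapred.JobConf;
import org.apache.spark.SparkConf;

import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.RssConf;

public class SharedCodecPath {
  public static void main(String[] args) {
    // Same factory call regardless of the engine that produced the conf.
    RssConf fromSpark = RssSparkConfig.toRssConf(new SparkConf());
    Codec sparkCodec = Codec.newInstance(fromSpark);

    RssConf fromMr = org.apache.hadoop.mapreduce.RssMRConfig.toRssConf(new JobConf());
    Codec mrCodec = Codec.newInstance(fromMr);
  }
}
```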