Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,6 @@ public void waitSendFinished() {
sendBuffersToServers();
}
long start = System.currentTimeMillis();
long commitDuration = 0;
if (!isMemoryShuffleEnabled) {
long s = System.currentTimeMillis();
sendCommit();
commitDuration = System.currentTimeMillis() - s;
}
while (true) {
// if failed when send data to shuffle server, mark task as failed
if (failedBlockIds.size() > 0) {
Expand All @@ -291,6 +285,12 @@ public void waitSendFinished() {
throw new RssException(errorMsg);
}
}
long commitDuration = 0;
if (!isMemoryShuffleEnabled) {
long s = System.currentTimeMillis();
sendCommit();
commitDuration = System.currentTimeMillis() - s;
}

start = System.currentTimeMillis();
shuffleWriteClient.reportShuffleResult(partitionToServers, appId, 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@

package org.apache.hadoop.mapred;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import com.google.common.collect.Sets;
import org.apache.hadoop.conf.Configuration;
Expand Down Expand Up @@ -261,9 +263,91 @@ public void testWriteNormal() throws Exception {
assertTrue(manager.getWaitSendBuffers().isEmpty());
}

@Test
public void testCommitBlocksWhenMemoryShuffleDisabled() throws Exception {
JobConf jobConf = new JobConf(new Configuration());
SerializationFactory serializationFactory = new SerializationFactory(jobConf);
MockShuffleWriteClient client = new MockShuffleWriteClient();
client.setMode(3);
Map<Integer, List<ShuffleServerInfo>> partitionToServers = JavaUtils.newConcurrentMap();
Set<Long> successBlocks = Sets.newConcurrentHashSet();
Set<Long> failedBlocks = Sets.newConcurrentHashSet();
Counters.Counter mapOutputByteCounter = new Counters.Counter();
Counters.Counter mapOutputRecordCounter = new Counters.Counter();
SortWriteBufferManager<BytesWritable, BytesWritable> manager;
manager = new SortWriteBufferManager<BytesWritable, BytesWritable>(
10240,
1L,
10,
serializationFactory.getSerializer(BytesWritable.class),
serializationFactory.getSerializer(BytesWritable.class),
WritableComparator.get(BytesWritable.class),
0.9,
"test",
client,
500,
5 * 1000,
partitionToServers,
successBlocks,
failedBlocks,
mapOutputByteCounter,
mapOutputRecordCounter,
1,
100,
1,
false,
5,
0.2f,
1024000L,
new RssConf());
Random random = new Random();
for (int i = 0; i < 1000; i++) {
byte[] key = new byte[20];
byte[] value = new byte[1024];
random.nextBytes(key);
random.nextBytes(value);
int partitionId = random.nextInt(50);
manager.addRecord(partitionId, new BytesWritable(key), new BytesWritable(value));
}
manager.waitSendFinished();
assertTrue(manager.getWaitSendBuffers().isEmpty());
// When MEMOEY storage type is disable, all blocks should flush.
assertEquals(client.mockedShuffleServer.getFinishBlockSize(), client.mockedShuffleServer.getFlushBlockSize());
}

class MockShuffleServer {

// All methods of MockShuffle are thread safe, because send-thread may do something in concurrent way.
private List<ShuffleBlockInfo> cachedBlockInfos = new ArrayList<>();
private List<ShuffleBlockInfo> flushBlockInfos = new ArrayList<>();
private List<Long> finishedBlockInfos = new ArrayList<>();

public synchronized void finishShuffle() {
flushBlockInfos.addAll(cachedBlockInfos);
}

public synchronized void addCachedBlockInfos(List<ShuffleBlockInfo> shuffleBlockInfoList) {
cachedBlockInfos.addAll(shuffleBlockInfoList);
}

public synchronized void addFinishedBlockInfos(List<Long> shuffleBlockInfoList) {
finishedBlockInfos.addAll(shuffleBlockInfoList);
}

public synchronized int getFlushBlockSize() {
return flushBlockInfos.size();
}

public synchronized int getFinishBlockSize() {
return finishedBlockInfos.size();
}
}

class MockShuffleWriteClient implements ShuffleWriteClient {

int mode = 0;
MockShuffleServer mockedShuffleServer = new MockShuffleServer();
int committedMaps = 0;

public void setMode(int mode) {
this.mode = mode;
Expand All @@ -277,6 +361,15 @@ public SendShuffleDataResult sendShuffleData(String appId, List<ShuffleBlockInfo
} else if (mode == 1) {
return new SendShuffleDataResult(Sets.newHashSet(2L), Sets.newHashSet(1L));
} else {
if (mode == 3) {
try {
Thread.sleep(10);
mockedShuffleServer.addCachedBlockInfos(shuffleBlockInfoList);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RssException(e);
}
}
Set<Long> successBlockIds = Sets.newHashSet();
for (ShuffleBlockInfo blockInfo : shuffleBlockInfoList) {
successBlockIds.add(blockInfo.getBlockId());
Expand Down Expand Up @@ -308,6 +401,13 @@ public void registerShuffle(

@Override
public boolean sendCommit(Set<ShuffleServerInfo> shuffleServerInfoSet, String appId, int shuffleId, int numMaps) {
if (mode == 3) {
committedMaps++;
if (committedMaps >= numMaps) {
mockedShuffleServer.finishShuffle();
}
return true;
}
return false;
}

Expand All @@ -329,7 +429,10 @@ public RemoteStorageInfo fetchRemoteStorage(String appId) {
@Override
public void reportShuffleResult(Map<Integer, List<ShuffleServerInfo>> partitionToServers, String appId,
int shuffleId, long taskAttemptId, Map<Integer, List<Long>> partitionToBlockIds, int bitmapNum) {

if (mode == 3) {
mockedShuffleServer.addFinishedBlockInfos(
partitionToBlockIds.values().stream().flatMap(it -> it.stream()).collect(Collectors.toList()));
}
}

@Override
Expand Down