diff --git a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java index fdec31f64e..ad0b31af54 100644 --- a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java +++ b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java @@ -82,7 +82,14 @@ public synchronized ShuffleDataFlushEvent toFlushEvent( } // buffer will be cleared, and new list must be created for async flush List spBlocks = new LinkedList<>(blocks); + List inFlushedQueueBlocks = spBlocks; if (dataDistributionType == ShuffleDataDistributionType.LOCAL_ORDER) { + /** + * When reordering the blocks, it will break down the original reads sequence to cause + * the data lost in some cases. + * So we should create a reference copy to avoid this. + */ + inFlushedQueueBlocks = new LinkedList<>(spBlocks); spBlocks.sort((o1, o2) -> new Long(o1.getTaskAttemptId()).compareTo(o2.getTaskAttemptId())); } long eventId = ShuffleFlushManager.ATOMIC_EVENT_ID.getAndIncrement(); @@ -96,7 +103,7 @@ public synchronized ShuffleDataFlushEvent toFlushEvent( spBlocks, isValid, this); - inFlushBlockMap.put(eventId, spBlocks); + inFlushBlockMap.put(eventId, inFlushedQueueBlocks); blocks.clear(); size = 0; return event; diff --git a/server/src/test/java/org/apache/uniffle/server/buffer/BufferTestBase.java b/server/src/test/java/org/apache/uniffle/server/buffer/BufferTestBase.java index bfdd4895e5..837448a580 100644 --- a/server/src/test/java/org/apache/uniffle/server/buffer/BufferTestBase.java +++ b/server/src/test/java/org/apache/uniffle/server/buffer/BufferTestBase.java @@ -47,13 +47,16 @@ protected ShufflePartitionedData createData(int len) { } protected ShufflePartitionedData createData(int partitionId, int len) { + return createData(partitionId, 0, len); + } + + protected ShufflePartitionedData createData(int partitionId, int taskAttemptId, int len) { byte[] buf = new byte[len]; new Random().nextBytes(buf); ShufflePartitionedBlock block = new ShufflePartitionedBlock( - len, len, ChecksumUtils.getCrc32(buf), atomBlockId.incrementAndGet(), 0, buf); + len, len, ChecksumUtils.getCrc32(buf), atomBlockId.incrementAndGet(), taskAttemptId, buf); ShufflePartitionedData data = new ShufflePartitionedData( partitionId, new ShufflePartitionedBlock[]{block}); return data; } - } diff --git a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferTest.java b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferTest.java index bb59f23c88..d275449a11 100644 --- a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferTest.java +++ b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferTest.java @@ -24,6 +24,7 @@ import org.junit.jupiter.api.Test; import org.apache.uniffle.common.BufferSegment; +import org.apache.uniffle.common.ShuffleDataDistributionType; import org.apache.uniffle.common.ShuffleDataResult; import org.apache.uniffle.common.ShufflePartitionedBlock; import org.apache.uniffle.common.ShufflePartitionedData; @@ -81,6 +82,46 @@ public void toFlushEventTest() { assertEquals(0, shuffleBuffer.getBlocks().size()); } + @Test + public void getShuffleDataWithLocalOrderTest() { + ShuffleBuffer shuffleBuffer = new ShuffleBuffer(200); + ShufflePartitionedData spd1 = createData(1, 1, 15); + ShufflePartitionedData spd2 = createData(1, 0, 15); + ShufflePartitionedData spd3 = createData(1, 2, 15); + shuffleBuffer.append(spd1); + shuffleBuffer.append(spd2); + shuffleBuffer.append(spd3); + + // First read from the cached data + ShuffleDataResult sdr = shuffleBuffer.getShuffleData(Constants.INVALID_BLOCK_ID, 16); + byte[] expectedData = getExpectedData(spd1, spd2); + compareBufferSegment(shuffleBuffer.getBlocks(), sdr.getBufferSegments(), 0, 2); + assertArrayEquals(expectedData, sdr.getData()); + + // Second read after flushed + ShuffleDataFlushEvent event1 = shuffleBuffer.toFlushEvent( + "appId", + 0, + 0, + 1, + null, + ShuffleDataDistributionType.LOCAL_ORDER + ); + long lastBlockId = sdr.getBufferSegments().get(1).getBlockId(); + sdr = shuffleBuffer.getShuffleData(lastBlockId, 16); + expectedData = getExpectedData(spd3); + compareBufferSegment(shuffleBuffer.getInFlushBlockMap().get(event1.getEventId()), sdr.getBufferSegments(), 2, 1); + assertArrayEquals(expectedData, sdr.getData()); + + assertEquals(0, event1.getShuffleBlocks().get(0).getTaskAttemptId()); + assertEquals(1, event1.getShuffleBlocks().get(1).getTaskAttemptId()); + assertEquals(2, event1.getShuffleBlocks().get(2).getTaskAttemptId()); + + assertEquals(1, shuffleBuffer.getInFlushBlockMap().get(event1.getEventId()).get(0).getTaskAttemptId()); + assertEquals(0, shuffleBuffer.getInFlushBlockMap().get(event1.getEventId()).get(1).getTaskAttemptId()); + assertEquals(2, shuffleBuffer.getInFlushBlockMap().get(event1.getEventId()).get(2).getTaskAttemptId()); + } + @Test public void getShuffleDataTest() { ShuffleBuffer shuffleBuffer = new ShuffleBuffer(200);