From d05e086ad212bb37f921a39cc88670ac55795cbc Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Sun, 6 Nov 2022 09:07:25 +0800 Subject: [PATCH] [ISSUE-301][Subtask][Improvement][AQE] Merge continuous ShuffleDataSegment into single one --- .../segment/LocalOrderSegmentSplitter.java | 12 ++- .../LocalOrderSegmentSplitterTest.java | 88 +++++++++++++++++++ 2 files changed, 98 insertions(+), 2 deletions(-) diff --git a/common/src/main/java/org/apache/uniffle/common/segment/LocalOrderSegmentSplitter.java b/common/src/main/java/org/apache/uniffle/common/segment/LocalOrderSegmentSplitter.java index 77e02e06d6..0a4a669de7 100644 --- a/common/src/main/java/org/apache/uniffle/common/segment/LocalOrderSegmentSplitter.java +++ b/common/src/main/java/org/apache/uniffle/common/segment/LocalOrderSegmentSplitter.java @@ -70,6 +70,7 @@ public List split(ShuffleIndexResult shuffleIndexResult) { long totalLen = 0; long lastTaskAttemptId = -1; + long lastExpectedBlockIndex = -1; /** * One ShuffleDataSegment should meet following requirements: @@ -79,6 +80,7 @@ public List split(ShuffleIndexResult shuffleIndexResult) { * 3. 
ShuffleDataSegment's blocks should be continuous * */ + int index = 0; while (byteBuffer.hasRemaining()) { try { long offset = byteBuffer.getLong(); @@ -98,7 +100,8 @@ public List split(ShuffleIndexResult shuffleIndexResult) { break; } - if ((taskAttemptId < lastTaskAttemptId && bufferSegments.size() > 0) || bufferOffset >= readBufferSize) { + if ((taskAttemptId < lastTaskAttemptId && bufferSegments.size() > 0 && index - lastExpectedBlockIndex != 1) + || bufferOffset >= readBufferSize) { ShuffleDataSegment sds = new ShuffleDataSegment(fileOffset, bufferOffset, bufferSegments); dataFileSegments.add(sds); bufferSegments = Lists.newArrayList(); @@ -107,14 +110,20 @@ public List split(ShuffleIndexResult shuffleIndexResult) { } if (expectTaskIds.contains(taskAttemptId)) { + if (bufferOffset != 0 && index - lastExpectedBlockIndex > 1) { + throw new RssException("There are discontinuous blocks which should not happen when using LOCAL_ORDER."); + } + if (fileOffset == -1) { fileOffset = offset; } bufferSegments.add(new BufferSegment(blockId, bufferOffset, length, uncompressLength, crc, taskAttemptId)); bufferOffset += length; + lastExpectedBlockIndex = index; } lastTaskAttemptId = taskAttemptId; + index++; } catch (BufferUnderflowException ue) { throw new RssException("Read index data under flow", ue); } @@ -124,7 +133,6 @@ public List split(ShuffleIndexResult shuffleIndexResult) { ShuffleDataSegment sds = new ShuffleDataSegment(fileOffset, bufferOffset, bufferSegments); dataFileSegments.add(sds); } - return dataFileSegments; } } diff --git a/common/src/test/java/org/apache/uniffle/common/segment/LocalOrderSegmentSplitterTest.java b/common/src/test/java/org/apache/uniffle/common/segment/LocalOrderSegmentSplitterTest.java index 27a2981448..4df956d6ee 100644 --- a/common/src/test/java/org/apache/uniffle/common/segment/LocalOrderSegmentSplitterTest.java +++ b/common/src/test/java/org/apache/uniffle/common/segment/LocalOrderSegmentSplitterTest.java @@ -24,14 +24,102 @@ import 
org.junit.jupiter.api.Test; import org.roaringbitmap.longlong.Roaring64NavigableMap; +import org.apache.uniffle.common.BufferSegment; import org.apache.uniffle.common.ShuffleDataSegment; import org.apache.uniffle.common.ShuffleIndexResult; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; public class LocalOrderSegmentSplitterTest { + @Test + public void testSplitWithDiscontinuousBlocksShouldThrowException() { + Roaring64NavigableMap taskIds = Roaring64NavigableMap.bitmapOf(1, 2, 4); + LocalOrderSegmentSplitter splitter = new LocalOrderSegmentSplitter(taskIds, 32); + byte[] data = generateData( + Pair.of(1, 1), + Pair.of(1, 2), + Pair.of(1, 3), + Pair.of(1, 4) + ); + try { + splitter.split(new ShuffleIndexResult(data, -1)); + fail(); + } catch (Exception e) { + // ignore + } + } + + @Test + public void testSplitForMergeContinuousSegments() { + /** + * case1: (32, 5) (16, 1) (10, 1) (16, 2) (6, 1) (8, 1) (10, 3) (9, 1) + * + * It will skip the (32, 5) and merge the following continuous blocks into one dataSegment when not exceeding the + * read buffer size. 
+ */ + Roaring64NavigableMap taskIds = Roaring64NavigableMap.bitmapOf(1, 2); + LocalOrderSegmentSplitter splitter = new LocalOrderSegmentSplitter(taskIds, 1000); + byte[] data = generateData( + Pair.of(32, 5), + Pair.of(16, 1), + Pair.of(10, 1), + Pair.of(16, 2), + Pair.of(6, 1), + Pair.of(8, 1), + Pair.of(10, 3), + Pair.of(9, 1) + ); + List dataSegments = splitter.split(new ShuffleIndexResult(data, -1)); + assertEquals(2, dataSegments.size()); + assertEquals(32, dataSegments.get(0).getOffset()); + assertEquals(56, dataSegments.get(0).getLength()); + + List bufferSegments = dataSegments.get(0).getBufferSegments(); + assertEquals(0, bufferSegments.get(0).getOffset()); + assertEquals(16, bufferSegments.get(0).getLength()); + + assertEquals(16, bufferSegments.get(1).getOffset()); + assertEquals(10, bufferSegments.get(1).getLength()); + + assertEquals(26, bufferSegments.get(2).getOffset()); + assertEquals(16, bufferSegments.get(2).getLength()); + + assertEquals(42, bufferSegments.get(3).getOffset()); + assertEquals(6, bufferSegments.get(3).getLength()); + + assertEquals(48, bufferSegments.get(4).getOffset()); + assertEquals(8, bufferSegments.get(4).getLength()); + + assertEquals(98, dataSegments.get(1).getOffset()); + assertEquals(9, dataSegments.get(1).getLength()); + bufferSegments = dataSegments.get(1).getBufferSegments(); + assertEquals(1, bufferSegments.size()); + assertEquals(0, bufferSegments.get(0).getOffset()); + assertEquals(9, bufferSegments.get(0).getLength()); + + /** + * case2: (16, 1) (15, 2) (1, 1) (6, 1) + * + * It will skip merging into one dataSegment when exceeding the + * read buffer size. 
+ */ + data = generateData( + Pair.of(16, 1), + Pair.of(15, 2), + Pair.of(1, 1), + Pair.of(6, 1) + ); + dataSegments = new LocalOrderSegmentSplitter(taskIds, 32).split(new ShuffleIndexResult(data, -1)); + assertEquals(2, dataSegments.size()); + assertEquals(0, dataSegments.get(0).getOffset()); + assertEquals(32, dataSegments.get(0).getLength()); + assertEquals(32, dataSegments.get(1).getOffset()); + assertEquals(6, dataSegments.get(1).getLength()); + } + @Test public void testSplit() { Roaring64NavigableMap taskIds = Roaring64NavigableMap.bitmapOf(1);