From dba03005f7df674dde8e3cd419e7faa545aa26c1 Mon Sep 17 00:00:00 2001 From: Kaijie Chen Date: Fri, 29 Jul 2022 13:03:40 +0800 Subject: [PATCH] [Improvement] Add RssUtils#cloneBitMap() --- .../uniffle/client/impl/ShuffleReadClientImpl.java | 13 ++----------- .../org/apache/uniffle/common/util/RssUtils.java | 6 ++++++ .../apache/uniffle/common/util/RssUtilsTest.java | 9 +++++++++ .../uniffle/test/SparkClientWithLocalTest.java | 2 +- .../apache/uniffle/server/ShuffleTaskManager.java | 7 ++----- 5 files changed, 20 insertions(+), 17 deletions(-) diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java index fc8b4d3d62..5059f8793c 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java @@ -17,7 +17,6 @@ package org.apache.uniffle.client.impl; -import java.io.IOException; import java.nio.ByteBuffer; import java.util.List; import java.util.Queue; @@ -110,11 +109,7 @@ public ShuffleReadClientImpl( } // copy blockIdBitmap to track all pending blocks - try { - pendingBlockIds = RssUtils.deserializeBitMap(RssUtils.serializeBitMap(blockIdBitmap)); - } catch (IOException ioe) { - throw new RuntimeException("Can't create pending blockIds.", ioe); - } + pendingBlockIds = RssUtils.cloneBitMap(blockIdBitmap); clientReadHandler = ShuffleHandlerFactory.getInstance().createShuffleReadHandler(request); } @@ -213,11 +208,7 @@ private int read() { @Override public void checkProcessedBlockIds() { Roaring64NavigableMap cloneBitmap; - try { - cloneBitmap = RssUtils.deserializeBitMap(RssUtils.serializeBitMap(blockIdBitmap)); - } catch (IOException ioe) { - throw new RuntimeException("Can't validate processed blockIds.", ioe); - } + cloneBitmap = RssUtils.cloneBitMap(blockIdBitmap); cloneBitmap.and(processedBlockIds); if (!blockIdBitmap.equals(cloneBitmap)) { throw new RssException("Blocks read inconsistent: expected " + blockIdBitmap.getLongCardinality() diff --git a/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java b/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java index 2219abae43..8b7f50081e 100644 --- a/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java +++ b/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java @@ -172,6 +172,12 @@ public static Roaring64NavigableMap deserializeBitMap(byte[] bytes) throws IOExc return bitmap; } + public static Roaring64NavigableMap cloneBitMap(Roaring64NavigableMap bitmap) { + Roaring64NavigableMap clone = Roaring64NavigableMap.bitmapOf(); + clone.or(bitmap); + return clone; + } + public static List transIndexDataToSegments( ShuffleIndexResult shuffleIndexResult, int readBufferSize) { if (shuffleIndexResult == null || shuffleIndexResult.isEmpty()) { diff --git a/common/src/test/java/org/apache/uniffle/common/util/RssUtilsTest.java b/common/src/test/java/org/apache/uniffle/common/util/RssUtilsTest.java index c1c95e234a..90639f023b 100644 --- a/common/src/test/java/org/apache/uniffle/common/util/RssUtilsTest.java +++ b/common/src/test/java/org/apache/uniffle/common/util/RssUtilsTest.java @@ -36,6 +36,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; @@ -88,6 +89,14 @@ public void testSerializeBitmap() throws Exception { assertEquals(Roaring64NavigableMap.bitmapOf(), RssUtils.deserializeBitMap(new byte[]{})); } + @Test + public void testCloneBitmap() { + Roaring64NavigableMap bitmap1 = Roaring64NavigableMap.bitmapOf(1, 2, 100, 10000); + Roaring64NavigableMap bitmap2 = RssUtils.cloneBitMap(bitmap1); + assertNotSame(bitmap1, bitmap2); + assertEquals(bitmap1, bitmap2); + } + @Test public void testShuffleIndexSegment() { ShuffleIndexResult shuffleIndexResult = new ShuffleIndexResult(); diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkClientWithLocalTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkClientWithLocalTest.java index b0e68d2848..4c4e1905d1 100644 --- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkClientWithLocalTest.java +++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkClientWithLocalTest.java @@ -318,7 +318,7 @@ public void readTest9() throws Exception { ShuffleReadClientImpl readClient; createTestData(testAppId, expectedData, blockIdBitmap, taskIdBitmap); - Roaring64NavigableMap beforeAdded = RssUtils.deserializeBitMap(RssUtils.serializeBitMap(blockIdBitmap)); + Roaring64NavigableMap beforeAdded = RssUtils.cloneBitMap(blockIdBitmap); // write data by another task, read data again, the cache for index file should be updated blocks = createShuffleBlockList( 0, 0, 1, 3, 25, blockIdBitmap, Maps.newHashMap(), mockSSI); diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java index 0876c5140c..3243c2d713 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java @@ -170,11 +170,9 @@ public StatusCode commitShuffle(String appId, int shuffleId) throws Exception { if (System.currentTimeMillis() - start > commitTimeout) { throw new RuntimeException("Shuffle data commit timeout for " + commitTimeout + " ms"); } - byte[] bitmapBytes; synchronized (cachedBlockIds) { - bitmapBytes = RssUtils.serializeBitMap(cachedBlockIds); + cloneBlockIds = RssUtils.cloneBitMap(cachedBlockIds); } - cloneBlockIds = RssUtils.deserializeBitMap(bitmapBytes); long expectedCommitted = cloneBlockIds.getLongCardinality(); shuffleBufferManager.commitShuffleTask(appId, shuffleId); Roaring64NavigableMap committedBlockIds; @@ -183,9 +181,8 @@ public StatusCode commitShuffle(String appId, int shuffleId) throws Exception { while (true) { committedBlockIds = shuffleFlushManager.getCommittedBlockIds(appId, shuffleId); synchronized (committedBlockIds) { - bitmapBytes = RssUtils.serializeBitMap(committedBlockIds); + cloneCommittedBlockIds = RssUtils.cloneBitMap(committedBlockIds); } - cloneCommittedBlockIds = RssUtils.deserializeBitMap(bitmapBytes); cloneBlockIds.andNot(cloneCommittedBlockIds); if (cloneBlockIds.isEmpty()) { break;