diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java index 3f408ebc4b..5dcf45cf4d 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java @@ -117,28 +117,19 @@ public List addRecord(int partitionId, Object key, Object valu return null; } List result = Lists.newArrayList(); - if (buffers.containsKey(partitionId)) { - WriterBuffer wb = buffers.get(partitionId); - if (wb.askForMemory(serializedDataLength)) { - requestMemory(Math.max(bufferSegmentSize, serializedDataLength)); - } - wb.addRecord(serializedData, serializedDataLength); - if (wb.getMemoryUsed() > bufferSize) { - result.add(createShuffleBlock(partitionId, wb)); - copyTime += wb.getCopyTime(); - buffers.remove(partitionId); - LOG.debug("Single buffer is full for shuffleId[" + shuffleId - + "] partition[" + partitionId + "] with memoryUsed[" + wb.getMemoryUsed() - + "], dataLength[" + wb.getDataLength() + "]"); - } - } else { - requestMemory(Math.max(bufferSegmentSize, serializedDataLength)); - WriterBuffer wb = new WriterBuffer(bufferSegmentSize); - wb.addRecord(serializedData, serializedDataLength); - buffers.put(partitionId, wb); - } + WriterBuffer wb = buffers.computeIfAbsent(partitionId, + k -> new WriterBuffer(bufferSegmentSize)); + requestMemory(wb.calculateMemoryCost(serializedDataLength)); + wb.addRecord(serializedData, serializedDataLength); shuffleWriteMetrics.incRecordsWritten(1L); - + if (wb.getMemoryUsed() > bufferSize) { + result.add(createShuffleBlock(partitionId, wb)); + copyTime += wb.getCopyTime(); + buffers.remove(partitionId); + LOG.debug("Single buffer is full for shuffleId[" + shuffleId + + "] partition[" + partitionId + "] with memoryUsed[" + wb.getMemoryUsed() + + "], dataLength[" + wb.getDataLength() + "]"); + } // check buffer size > spill threshold if (usedBytes.get() - inSendListBytes.get() > spillSize) { result.addAll(clear()); diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java index 6fb41975c0..d2b0a87d2a 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java @@ -39,35 +39,35 @@ public WriterBuffer(int bufferSize) { } public void addRecord(byte[] recordBuffer, int length) { - if (askForMemory(length)) { - // buffer has data already, add buffer to list - if (nextOffset > 0) { - buffers.add(new WrappedBuffer(buffer, nextOffset)); - nextOffset = 0; + int require = calculateMemoryCost(length); + int hasCopied = 0; + if (require > 0) { + if (buffer != null) { + int toCopy = buffer.length - nextOffset; + if (toCopy > 0) { + hasCopied = toCopy; + System.arraycopy(recordBuffer, 0, buffer, nextOffset, hasCopied); + } + buffers.add(new WrappedBuffer(buffer, buffer.length)); } - if (length > bufferSize) { - buffer = new byte[length]; - memoryUsed += length; - } else { - buffer = new byte[bufferSize]; - memoryUsed += bufferSize; - } - } - - try { - System.arraycopy(recordBuffer, 0, buffer, nextOffset, length); - } catch (Exception e) { - LOG.error("Unexpect exception for System.arraycopy, length[" + length + "], nextOffset[" - + nextOffset + "], bufferSize[" + bufferSize + "]"); - throw e; + buffer = new byte[require]; + nextOffset = 0; } - - nextOffset += length; + System.arraycopy(recordBuffer, hasCopied, buffer, nextOffset, length - hasCopied); + nextOffset += length - hasCopied; + memoryUsed += require; dataLength += length; } - public boolean askForMemory(long length) { - return buffer == null || nextOffset + length > bufferSize; + public int calculateMemoryCost(int length) { + if (buffer == null) { + return Math.max(length, bufferSize); + } + int require = length + nextOffset - buffer.length; + if (require <= 0) { + return 0; + } + return Math.max(require, bufferSize); } public byte[] getData() { diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java index 1a22633409..b38e5ae88d 100644 --- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java @@ -78,6 +78,7 @@ public void addRecordTest() { wbm.addRecord(0, testKey, testValue); wbm.addRecord(0, testKey, testValue); wbm.addRecord(0, testKey, testValue); + wbm.addRecord(0, testKey, testValue); result = wbm.addRecord(0, testKey, testValue); // single buffer is full assertEquals(1, result.size()); @@ -108,7 +109,7 @@ public void addRecordTest() { assertEquals(192, wbm.getUsedBytes()); assertEquals(192, wbm.getInSendListBytes()); - assertEquals(11, wbm.getShuffleWriteMetrics().recordsWritten()); + assertEquals(12, wbm.getShuffleWriteMetrics().recordsWritten()); assertTrue(wbm.getShuffleWriteMetrics().bytesWritten() > 0); wbm.freeAllocatedMemory(192); diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferTest.java index 5d612768f7..3678bd4843 100644 --- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferTest.java +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferTest.java @@ -58,7 +58,7 @@ public void test() { wb.addRecord(serializedData, serializedDataLength); // case: data size > output buffer size, when getData(), 2 buffer + output with 12b = 60b assertEquals(60, wb.getData().length); - assertEquals(96, wb.getMemoryUsed()); + assertEquals(64, wb.getMemoryUsed()); wb = new WriterBuffer(32);