diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java index 54c3fd88bd..00ffc31ae0 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java @@ -117,29 +117,16 @@ public List addRecord(int partitionId, Object key, Object valu return null; } List result = Lists.newArrayList(); - if (buffers.containsKey(partitionId)) { - WriterBuffer wb = buffers.get(partitionId); - if (wb.askForMemory(serializedDataLength)) { - if (serializedDataLength > bufferSegmentSize) { - requestMemory(serializedDataLength); - } else { - requestMemory(bufferSegmentSize); - } - } - wb.addRecord(serializedData, serializedDataLength); - if (wb.getMemoryUsed() > bufferSize) { - result.add(createShuffleBlock(partitionId, wb)); - copyTime += wb.getCopyTime(); - buffers.remove(partitionId); - LOG.debug("Single buffer is full for shuffleId[" + shuffleId - + "] partition[" + partitionId + "] with memoryUsed[" + wb.getMemoryUsed() - + "], dataLength[" + wb.getDataLength() + "]"); - } - } else { - requestMemory(bufferSegmentSize); - WriterBuffer wb = new WriterBuffer(bufferSegmentSize); - wb.addRecord(serializedData, serializedDataLength); - buffers.put(partitionId, wb); + WriterBuffer wb = buffers.computeIfAbsent(partitionId, k -> new WriterBuffer(bufferSegmentSize)); + requestMemory(wb.calculateMemoryCost(serializedDataLength)); + wb.addRecord(serializedData, serializedDataLength); + if (wb.getMemoryUsed() > bufferSize) { + result.add(createShuffleBlock(partitionId, wb)); + copyTime += wb.getCopyTime(); + buffers.remove(partitionId); + LOG.info("Single buffer is full for shuffleId[" + shuffleId + + "] partition[" + partitionId + "] with memoryUsed[" + wb.getMemoryUsed() + + "], dataLength[" + wb.getDataLength() + "]"); } 
shuffleWriteMetrics.incRecordsWritten(1L); @@ -175,15 +162,17 @@ protected ShuffleBlockInfo createShuffleBlock(int partitionId, WriterBuffer wb) final int uncompressLength = data.length; long start = System.currentTimeMillis(); final byte[] compressed = RssShuffleUtils.compressData(data); + final int compressLength = compressed.length; final long crc32 = ChecksumUtils.getCrc32(compressed); compressTime += System.currentTimeMillis() - start; + freeInnerUsed(wb.getMemoryUsed() - compressLength); final long blockId = ClientUtils.getBlockId(partitionId, taskAttemptId, getNextSeqNo(partitionId)); uncompressedDataLen += data.length; - shuffleWriteMetrics.incBytesWritten(compressed.length); + shuffleWriteMetrics.incBytesWritten(compressLength); // add memory to indicate bytes which will be sent to shuffle server - inSendListBytes.addAndGet(wb.getMemoryUsed()); + inSendListBytes.addAndGet(compressLength); return new ShuffleBlockInfo(shuffleId, partitionId, blockId, compressed.length, crc32, - compressed, partitionToServers.get(partitionId), uncompressLength, wb.getMemoryUsed(), taskAttemptId); + compressed, partitionToServers.get(partitionId), uncompressLength, compressLength, taskAttemptId); } // it's run in single thread, and is not thread safe @@ -195,6 +184,9 @@ private int getNextSeqNo(int partitionId) { } private void requestMemory(long requiredMem) { + if (requiredMem == 0) { + return; + } final long start = System.currentTimeMillis(); if (allocatedBytes.get() - usedBytes.get() < requiredMem) { requestExecutorMemory(requiredMem); @@ -204,10 +196,10 @@ private void requestMemory(long requiredMem) { } private void requestExecutorMemory(long leastMem) { - long gotMem = acquireMemory(askExecutorMemory); - allocatedBytes.addAndGet(gotMem); + long gotMem = 0; + gotMem += acquireMemory(askExecutorMemory); int retry = 0; - while (allocatedBytes.get() - usedBytes.get() < leastMem) { + while (gotMem < leastMem) { LOG.info("Can't get memory for now, sleep and try[" + retry + 
"] again, request[" + askExecutorMemory + "], got[" + gotMem + "] less than " + leastMem); @@ -216,8 +208,7 @@ private void requestExecutorMemory(long leastMem) { } catch (InterruptedException ie) { LOG.warn("Exception happened when waiting for memory.", ie); } - gotMem = acquireMemory(askExecutorMemory); - allocatedBytes.addAndGet(gotMem); + gotMem += acquireMemory(askExecutorMemory); retry++; if (retry > requireMemoryRetryMax) { String message = "Can't get memory to cache shuffle data, request[" + askExecutorMemory @@ -226,9 +217,11 @@ private void requestExecutorMemory(long leastMem) { + " or consider to optimize 'spark.executor.memory'," + " 'spark.rss.writer.buffer.spill.size'."; LOG.error(message); + allocatedBytes.addAndGet(gotMem); throw new RssException(message); } } + allocatedBytes.addAndGet(gotMem); } @Override @@ -259,6 +252,15 @@ public void freeAllocatedMemory(long freeMemory) { inSendListBytes.addAndGet(-freeMemory); } + public void freeInnerUsed(long freeMemory) { + if (freeMemory < 0) { + // this will not happen in common case + requestMemory(-freeMemory); + } else { + usedBytes.addAndGet(-freeMemory); + } + } + public void freeAllMemory() { long memory = allocatedBytes.get(); if (memory > 0) { diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java index 6fb41975c0..d5bd658e53 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java @@ -39,37 +39,43 @@ public WriterBuffer(int bufferSize) { } public void addRecord(byte[] recordBuffer, int length) { - if (askForMemory(length)) { - // buffer has data already, add buffer to list - if (nextOffset > 0) { - buffers.add(new WrappedBuffer(buffer, nextOffset)); - nextOffset = 0; - } - if (length > bufferSize) { - buffer = new byte[length]; - 
memoryUsed += length; - } else { - buffer = new byte[bufferSize]; - memoryUsed += bufferSize; - } + int require = calculateMemoryCost(length); + int hasCopied = 0; + if (require > 0 && buffer != null && buffer.length - nextOffset > 0) { + hasCopied = buffer.length - nextOffset; + System.arraycopy(recordBuffer, 0, buffer, nextOffset, hasCopied); } - - try { - System.arraycopy(recordBuffer, 0, buffer, nextOffset, length); - } catch (Exception e) { - LOG.error("Unexpect exception for System.arraycopy, length[" + length + "], nextOffset[" - + nextOffset + "], bufferSize[" + bufferSize + "]"); - throw e; + if (require > 0 && buffer != null) { + buffers.add(new WrappedBuffer(buffer, buffer.length)); } - - nextOffset += length; + if (require > 0) { + buffer = new byte[require]; + nextOffset = 0; + } + System.arraycopy(recordBuffer, 0, buffer, nextOffset, length - hasCopied); + nextOffset += length - hasCopied; + memoryUsed += require; dataLength += length; } - public boolean askForMemory(long length) { - return buffer == null || nextOffset + length > bufferSize; + public int calculateMemoryCost(int length) { + if (buffer == null) { + return Math.max(bufferSize, length); + } + int require = length + nextOffset - buffer.length; + if (require <= 0) { + return 0; + } + if (require < bufferSize) { + return bufferSize; + } + return require; } + /** + * Returns the accumulated data and releases the internal buffers immediately; + * addRecord must not be called again after this method. + */ public byte[] getData() { byte[] data = new byte[dataLength]; int offset = 0; @@ -81,6 +87,8 @@ public byte[] getData() { // nextOffset is the length of current buffer used System.arraycopy(buffer, 0, data, offset, nextOffset); copyTime += System.currentTimeMillis() - start; + buffers = null; + buffer = null; return data; } diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java index 
3e62f94926..af919d2174 100644 --- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java @@ -78,20 +78,21 @@ public void addRecordTest() { wbm.addRecord(0, testKey, testValue); wbm.addRecord(0, testKey, testValue); wbm.addRecord(0, testKey, testValue); + wbm.addRecord(0, testKey, testValue); result = wbm.addRecord(0, testKey, testValue); // single buffer is full assertEquals(1, result.size()); assertEquals(512, wbm.getAllocatedBytes()); - assertEquals(96, wbm.getUsedBytes()); - assertEquals(96, wbm.getInSendListBytes()); + assertEquals(35, wbm.getUsedBytes()); + assertEquals(35, wbm.getInSendListBytes()); assertEquals(0, wbm.getBuffers().size()); wbm.addRecord(0, testKey, testValue); wbm.addRecord(1, testKey, testValue); wbm.addRecord(2, testKey, testValue); // single buffer is not full, and less than spill size assertEquals(512, wbm.getAllocatedBytes()); - assertEquals(192, wbm.getUsedBytes()); - assertEquals(96, wbm.getInSendListBytes()); + assertEquals(131, wbm.getUsedBytes()); + assertEquals(35, wbm.getInSendListBytes()); assertEquals(3, wbm.getBuffers().size()); // all buffer size > spill size wbm.addRecord(3, testKey, testValue); @@ -99,27 +100,26 @@ public void addRecordTest() { result = wbm.addRecord(5, testKey, testValue); assertEquals(6, result.size()); assertEquals(512, wbm.getAllocatedBytes()); - assertEquals(288, wbm.getUsedBytes()); - assertEquals(288, wbm.getInSendListBytes()); + assertEquals(113, wbm.getUsedBytes()); + assertEquals(113, wbm.getInSendListBytes()); assertEquals(0, wbm.getBuffers().size()); // free memory - wbm.freeAllocatedMemory(96); - assertEquals(416, wbm.getAllocatedBytes()); - assertEquals(192, wbm.getUsedBytes()); - assertEquals(192, wbm.getInSendListBytes()); + wbm.freeAllocatedMemory(113); + assertEquals(399, wbm.getAllocatedBytes()); + assertEquals(0, wbm.getUsedBytes()); + 
assertEquals(0, wbm.getInSendListBytes()); - assertEquals(11, wbm.getShuffleWriteMetrics().recordsWritten()); + assertEquals(12, wbm.getShuffleWriteMetrics().recordsWritten()); assertTrue(wbm.getShuffleWriteMetrics().bytesWritten() > 0); - wbm.freeAllocatedMemory(192); wbm.addRecord(0, testKey, testValue); wbm.addRecord(1, testKey, testValue); wbm.addRecord(2, testKey, testValue); result = wbm.clear(); assertEquals(3, result.size()); - assertEquals(224, wbm.getAllocatedBytes()); - assertEquals(96, wbm.getUsedBytes()); - assertEquals(96, wbm.getInSendListBytes()); + assertEquals(399, wbm.getAllocatedBytes()); + assertEquals(39, wbm.getUsedBytes()); + assertEquals(39, wbm.getInSendListBytes()); } @Test diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferTest.java index 5d612768f7..0408a1dd01 100644 --- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferTest.java +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferTest.java @@ -49,16 +49,16 @@ public void test() { wb.addRecord(serializedData, serializedDataLength); assertEquals(32, wb.getMemoryUsed()); // case: data size < output buffer size, when getData(), [] + buffer with 24b = 24b - assertEquals(24, wb.getData().length); + assertEquals(24, wb.getDataLength()); wb.addRecord(serializedData, serializedDataLength); // case: data size > output buffer size, when getData(), [1 buffer] + buffer with 12 = 36b - assertEquals(36, wb.getData().length); + assertEquals(36, wb.getDataLength()); assertEquals(64, wb.getMemoryUsed()); wb.addRecord(serializedData, serializedDataLength); wb.addRecord(serializedData, serializedDataLength); // case: data size > output buffer size, when getData(), 2 buffer + output with 12b = 60b assertEquals(60, wb.getData().length); - assertEquals(96, wb.getMemoryUsed()); + assertEquals(64, wb.getMemoryUsed()); wb = 
new WriterBuffer(32); @@ -73,7 +73,6 @@ public void test() { assertEquals(99, wb.getMemoryUsed()); // 67 + 12 assertEquals(79, wb.getDataLength()); - assertEquals(79, wb.getData().length); wb.addRecord(serializedData, serializedDataLength); assertEquals(99, wb.getMemoryUsed());