diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index d6a359db66f4..850d6845684b 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -928,6 +928,15 @@ package object config { .booleanConf .createWithDefault(true) + private[spark] val SHUFFLE_DETECT_CORRUPT_MEMORY = + ConfigBuilder("spark.shuffle.detectCorrupt.useExtraMemory") + .doc("If enabled, part of a compressed/encrypted stream will be de-compressed/de-crypted " + + "by using extra memory to detect early corruption. Any IOException thrown will cause " + + "the task to be retried once and if it fails again with same exception, then " + + "FetchFailedException will be thrown to retry previous stage") + .booleanConf + .createWithDefault(false) + private[spark] val SHUFFLE_SYNC = ConfigBuilder("spark.shuffle.sync") .doc("Whether to force outstanding writes to disk.") diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index c5eefc7c5c04..c7843710413d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -55,6 +55,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), readMetrics).toCompletionIterator val serializerInstance = dep.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index c75b20906918..c89d5cc971d2 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.{InputStream, IOException} +import java.io.{InputStream, IOException, SequenceInputStream} import java.nio.ByteBuffer import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} import javax.annotation.concurrent.GuardedBy @@ -25,6 +25,8 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import org.apache.commons.io.IOUtils + import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} @@ -32,7 +34,6 @@ import org.apache.spark.network.shuffle._ import org.apache.spark.network.util.TransportConf import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} -import org.apache.spark.util.io.ChunkedByteBufferOutputStream /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block @@ -73,6 +74,7 @@ final class ShuffleBlockFetcherIterator( maxBlocksInFlightPerAddress: Int, maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean, + detectCorruptUseExtraMemory: Boolean, shuffleMetrics: ShuffleReadMetricsReporter) extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { @@ -406,6 +408,7 @@ final class ShuffleBlockFetcherIterator( var result: FetchResult = null var input: InputStream = null + var streamCompressedOrEncrypted: Boolean = false // Take the next fetched result and try to decompress it to detect data corruption, // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch // is also corrupt, so the previous stage could be retried. @@ -463,25 +466,22 @@ final class ShuffleBlockFetcherIterator( buf.release() throwFetchFailedException(blockId, address, e) } - var isStreamCopied: Boolean = false try { input = streamWrapper(blockId, in) - // Only copy the stream if it's wrapped by compression or encryption, also the size of - // block is small (the decompressed block is smaller than maxBytesInFlight) - if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) { - isStreamCopied = true - val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) - // Decompress the whole block at once to detect any corruption, which could increase - // the memory usage tne potential increase the chance of OOM. + // If the stream is compressed or wrapped, then we optionally decompress/unwrap the + // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion + // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if + // the corruption is later, we'll still detect the corruption later in the stream. + streamCompressedOrEncrypted = !input.eq(in) + if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) { // TODO: manage the memory used here, and spill it into disk in case of OOM. - Utils.copyStream(input, out, closeStreams = true) - input = out.toChunkedByteBuffer.toInputStream(dispose = true) + input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3) } } catch { case e: IOException => buf.release() if (buf.isInstanceOf[FileSegmentManagedBuffer] - || corruptedBlocks.contains(blockId)) { + || corruptedBlocks.contains(blockId)) { throwFetchFailedException(blockId, address, e) } else { logWarning(s"got an corrupted block $blockId from $address, fetch again", e) @@ -491,7 +491,9 @@ final class ShuffleBlockFetcherIterator( } } finally { // TODO: release the buf here to free memory earlier - if (isStreamCopied) { + if (input == null) { + // Close the underlying stream if there was an issue in wrapping the stream using + // streamWrapper in.close() } } @@ -508,7 +510,13 @@ final class ShuffleBlockFetcherIterator( throw new NoSuchElementException() } currentResult = result.asInstanceOf[SuccessFetchResult] - (currentResult.blockId, new BufferReleasingInputStream(input, this)) + (currentResult.blockId, + new BufferReleasingInputStream( + input, + this, + currentResult.blockId, + currentResult.address, + detectCorrupt && streamCompressedOrEncrypted)) } def toCompletionIterator: Iterator[(BlockId, InputStream)] = { @@ -571,7 +579,10 @@ final class ShuffleBlockFetcherIterator( } } - private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = { + private[storage] def throwFetchFailedException( + blockId: BlockId, + address: BlockManagerId, + e: Throwable) = { blockId match { case ShuffleBlockId(shufId, mapId, reduceId) => throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) @@ -583,15 +594,28 @@ final class ShuffleBlockFetcherIterator( } /** - * Helper class that ensures a ManagedBuffer is released upon InputStream.close() + * Helper class that ensures a ManagedBuffer is released upon InputStream.close() and + * also detects stream corruption if streamCompressedOrEncrypted is true */ private class BufferReleasingInputStream( - private val delegate: InputStream, - private val iterator: ShuffleBlockFetcherIterator) + // This is visible for testing + private[storage] val delegate: InputStream, + private val iterator: ShuffleBlockFetcherIterator, + private val blockId: BlockId, + private val address: BlockManagerId, + private val detectCorruption: Boolean) extends InputStream { private[this] var closed = false - override def read(): Int = delegate.read() + override def read(): Int = { + try { + delegate.read() + } catch { + case e: IOException if detectCorruption => + IOUtils.closeQuietly(this) + iterator.throwFetchFailedException(blockId, address, e) + } + } override def close(): Unit = { if (!closed) { @@ -605,13 +629,37 @@ private class BufferReleasingInputStream( override def mark(readlimit: Int): Unit = delegate.mark(readlimit) - override def skip(n: Long): Long = delegate.skip(n) + override def skip(n: Long): Long = { + try { + delegate.skip(n) + } catch { + case e: IOException if detectCorruption => + IOUtils.closeQuietly(this) + iterator.throwFetchFailedException(blockId, address, e) + } + } override def markSupported(): Boolean = delegate.markSupported() - override def read(b: Array[Byte]): Int = delegate.read(b) + override def read(b: Array[Byte]): Int = { + try { + delegate.read(b) + } catch { + case e: IOException if detectCorruption => + IOUtils.closeQuietly(this) + iterator.throwFetchFailedException(blockId, address, e) + } + } - override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len) + override def read(b: Array[Byte], off: Int, len: Int): Int = { + try { + delegate.read(b, off, len) + } catch { + case e: IOException if detectCorruption => + IOUtils.closeQuietly(this) + iterator.throwFetchFailedException(blockId, address, e) + } + } override def reset(): Unit = delegate.reset() } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index cade0dd88fc7..bc5731163afd 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -67,6 +67,7 @@ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} import org.apache.spark.status.api.v1.{StackTrace, ThreadStackTrace} +import org.apache.spark.util.io.ChunkedByteBufferOutputStream /** CallSite represents a place in user code. It can have a short and a long form. */ private[spark] case class CallSite(shortForm: String, longForm: String) @@ -337,6 +338,50 @@ private[spark] object Utils extends Logging { } } + /** + * Copy the first `maxSize` bytes of data from the InputStream to an in-memory + * buffer, primarily to check for corruption. + * + * This returns a new InputStream which contains the same data as the original input stream. + * It may be entirely on in-memory buffer, or it may be a combination of in-memory data, and then + * continue to read from the original stream. The only real use of this is if the original input + * stream will potentially detect corruption while the data is being read (eg. from compression). + * This allows for an eager check of corruption in the first maxSize bytes of data. + * + * @return An InputStream which includes all data from the original stream (combining buffered + * data and remaining data in the original stream) + */ + def copyStreamUpTo(in: InputStream, maxSize: Long): InputStream = { + var count = 0L + val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) + val fullyCopied = tryWithSafeFinally { + val bufSize = Math.min(8192L, maxSize) + val buf = new Array[Byte](bufSize.toInt) + var n = 0 + while (n != -1 && count < maxSize) { + n = in.read(buf, 0, Math.min(maxSize - count, bufSize).toInt) + if (n != -1) { + out.write(buf, 0, n) + count += n + } + } + count < maxSize + } { + try { + if (count < maxSize) { + in.close() + } + } finally { + out.close() + } + } + if (fullyCopied) { + out.toChunkedByteBuffer.toInputStream(dispose = true) + } else { + new SequenceInputStream( out.toChunkedByteBuffer.toInputStream(dispose = true), in) + } + } + def copyFileStreamNIO( input: FileChannel, output: FileChannel, diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 98fe9663b621..a1c298ae9446 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.storage -import java.io.{File, InputStream, IOException} +import java.io._ +import java.nio.ByteBuffer import java.util.UUID import java.util.concurrent.Semaphore @@ -118,6 +119,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, true, + false, metrics) // 3 local blocks fetched in initialization @@ -197,6 +199,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, true, + false, taskContext.taskMetrics.createTempShuffleReadMetrics()) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() @@ -265,6 +268,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, true, + false, taskContext.taskMetrics.createTempShuffleReadMetrics()) @@ -325,6 +329,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, true, + false, taskContext.taskMetrics.createTempShuffleReadMetrics()) // Continue only after the mock calls onBlockFetchFailure @@ -337,15 +342,34 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } } - private def mockCorruptBuffer(size: Long = 1L): ManagedBuffer = { - val corruptStream = mock(classOf[InputStream]) - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) + private def mockCorruptBuffer(size: Long = 1L, corruptAt: Int = 0): ManagedBuffer = { + val corruptStream = new CorruptStream(corruptAt) val corruptBuffer = mock(classOf[ManagedBuffer]) when(corruptBuffer.size()).thenReturn(size) when(corruptBuffer.createInputStream()).thenReturn(corruptStream) corruptBuffer } + private class CorruptStream(corruptAt: Long = 0L) extends InputStream { + var pos = 0 + var closed = false + + override def read(): Int = { + if (pos >= corruptAt) { + throw new IOException("corrupt") + } else { + pos += 1 + pos + } + } + + override def read(dest: Array[Byte], off: Int, len: Int): Int = { + super.read(dest, off, len) + } + + override def close(): Unit = { closed = true } + } + test("retry corrupt blocks") { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) @@ -396,6 +420,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, true, + true, taskContext.taskMetrics.createTempShuffleReadMetrics()) // Continue only after the mock calls onBlockFetchFailure @@ -425,28 +450,98 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } } - test("big blocks are not checked for corruption") { - val corruptBuffer = mockCorruptBuffer(10000L) - + test("big blocks are also checked for corruption") { + val streamLength = 10000L val blockManager = mock(classOf[BlockManager]) - val localBmId = BlockManagerId("test-client", "test-client", 1) - doReturn(localBmId).when(blockManager).blockManagerId - doReturn(corruptBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0)) - val localBlockLengths = Seq[Tuple2[BlockId, Long]]( - ShuffleBlockId(0, 0, 0) -> corruptBuffer.size() + val localBlockManagerId = BlockManagerId("local-client", "local-client", 1) + doReturn(localBlockManagerId).when(blockManager).blockManagerId + + // This stream will throw IOException when the first byte is read + val corruptBuffer1 = mockCorruptBuffer(streamLength, 0) + val blockManagerId1 = BlockManagerId("remote-client-1", "remote-client-1", 1) + val shuffleBlockId1 = ShuffleBlockId(0, 1, 0) + val blockLengths1 = Seq[Tuple2[BlockId, Long]]( + shuffleBlockId1 -> corruptBuffer1.size() ) - val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) - val remoteBlockLengths = Seq[Tuple2[BlockId, Long]]( - ShuffleBlockId(0, 1, 0) -> corruptBuffer.size() + val streamNotCorruptTill = 8 * 1024 + // This stream will throw exception after streamNotCorruptTill bytes are read + val corruptBuffer2 = mockCorruptBuffer(streamLength, streamNotCorruptTill) + val blockManagerId2 = BlockManagerId("remote-client-2", "remote-client-2", 2) + val shuffleBlockId2 = ShuffleBlockId(0, 2, 0) + val blockLengths2 = Seq[Tuple2[BlockId, Long]]( + shuffleBlockId2 -> corruptBuffer2.size() ) val transfer = createMockTransfer( - Map(ShuffleBlockId(0, 0, 0) -> corruptBuffer, ShuffleBlockId(0, 1, 0) -> corruptBuffer)) + Map(shuffleBlockId1 -> corruptBuffer1, shuffleBlockId2 -> corruptBuffer2)) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (blockManagerId1, blockLengths1), + (blockManagerId2, blockLengths2) + ).toIterator + val taskContext = TaskContext.empty() + val maxBytesInFlight = 3 * 1024 + val iterator = new ShuffleBlockFetcherIterator( + taskContext, + transfer, + blockManager, + blocksByAddress, + (_, in) => new LimitedInputStream(in, streamLength), + maxBytesInFlight, + Int.MaxValue, + Int.MaxValue, + Int.MaxValue, + true, + true, + taskContext.taskMetrics.createTempShuffleReadMetrics()) + + // We'll get back the block which has corruption after maxBytesInFlight/3 because the other + // block will detect corruption on first fetch, and then get added to the queue again for + // a retry + val (id, st) = iterator.next() + assert(id === shuffleBlockId2) + + // The other block will throw a FetchFailedException + intercept[FetchFailedException] { + iterator.next() + } + + // Following will succeed as it reads part of the stream which is not corrupt. This will read + // maxBytesInFlight/3 bytes from the portion copied into memory, and remaining from the + // underlying stream + new DataInputStream(st).readFully( + new Array[Byte](streamNotCorruptTill), 0, streamNotCorruptTill) + + // Following will fail as it reads the remaining part of the stream which is corrupt + intercept[FetchFailedException] { st.read() } + + // Buffers are mocked and they return the original input corrupt streams + assert(corruptBuffer1.createInputStream().asInstanceOf[CorruptStream].closed) + assert(corruptBuffer2.createInputStream().asInstanceOf[CorruptStream].closed) + } + test("ensure big blocks available as a concatenated stream can be read") { + val tmpDir = Utils.createTempDir() + val tmpFile = new File(tmpDir, "someFile.txt") + val os = new FileOutputStream(tmpFile) + val buf = ByteBuffer.allocate(10000) + for (i <- 1 to 2500) { + buf.putInt(i) + } + os.write(buf.array()) + os.close() + val managedBuffer = new FileSegmentManagedBuffer(null, tmpFile, 0, 10000) + + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + doReturn(managedBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0)) + val localBlockLengths = Seq[Tuple2[BlockId, Long]]( + ShuffleBlockId(0, 0, 0) -> 10000 + ) + val transfer = createMockTransfer(Map(ShuffleBlockId(0, 0, 0) -> managedBuffer)) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (localBmId, localBlockLengths), - (remoteBmId, remoteBlockLengths) + (localBmId, localBlockLengths) ).toIterator val taskContext = TaskContext.empty() @@ -461,10 +556,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, true, + true, taskContext.taskMetrics.createTempShuffleReadMetrics()) - // Blocks should be returned without exceptions. - assert(Set(iterator.next()._1, iterator.next()._1) === - Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0))) + val (id, st) = iterator.next() + // Check that the test setup is correct -- make sure we have a concatenated stream. + assert (st.asInstanceOf[BufferReleasingInputStream].delegate.isInstanceOf[SequenceInputStream]) + + val dst = new DataInputStream(st) + for (i <- 1 to 2500) { + assert(i === dst.readInt()) + } + assert(dst.read() === -1) + dst.close() } test("retry corrupt blocks (disabled)") { @@ -515,6 +618,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, + true, false, taskContext.taskMetrics.createTempShuffleReadMetrics()) @@ -578,6 +682,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT maxBlocksInFlightPerAddress = Int.MaxValue, maxReqSizeShuffleToMem = 200, detectCorrupt = true, + false, taskContext.taskMetrics.createTempShuffleReadMetrics()) } @@ -625,6 +730,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, true, + false, taskContext.taskMetrics.createTempShuffleReadMetrics()) // All blocks fetched return zero length and should trigger a receive-side error: diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 188e3f6907da..d2d9eb06339c 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -18,8 +18,9 @@ package org.apache.spark.util import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataOutput, DataOutputStream, File, - FileOutputStream, PrintStream} + FileOutputStream, InputStream, PrintStream, SequenceInputStream} import java.lang.{Double => JDouble, Float => JFloat} +import java.lang.reflect.Field import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} import java.nio.charset.StandardCharsets @@ -43,6 +44,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.SparkListener +import org.apache.spark.util.io.ChunkedByteBufferInputStream class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { @@ -211,6 +213,56 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(os.toByteArray.toList.equals(bytes.toList)) } + test("copyStreamUpTo") { + // input array initialization + val bytes = Array.ofDim[Byte](1200) + Random.nextBytes(bytes) + + val limit = 1000 + // testing for inputLength less than, equal to and greater than limit + (limit - 2 to limit + 2).foreach { inputLength => + val in = new ByteArrayInputStream(bytes.take(inputLength)) + val mergedStream = Utils.copyStreamUpTo(in, limit) + try { + // Get a handle on the buffered data, to make sure memory gets freed once we read past the + // end of it. Need to use reflection to get handle on inner structures for this check + val byteBufferInputStream = if (mergedStream.isInstanceOf[ChunkedByteBufferInputStream]) { + assert(inputLength < limit) + mergedStream.asInstanceOf[ChunkedByteBufferInputStream] + } else { + assert(inputLength >= limit) + val sequenceStream = mergedStream.asInstanceOf[SequenceInputStream] + val fieldValue = getFieldValue(sequenceStream, "in") + assert(fieldValue.isInstanceOf[ChunkedByteBufferInputStream]) + fieldValue.asInstanceOf[ChunkedByteBufferInputStream] + } + (0 until inputLength).foreach { idx => + assert(bytes(idx) === mergedStream.read().asInstanceOf[Byte]) + if (idx == limit) { + assert(byteBufferInputStream.chunkedByteBuffer === null) + } + } + assert(mergedStream.read() === -1) + assert(byteBufferInputStream.chunkedByteBuffer === null) + } finally { + IOUtils.closeQuietly(mergedStream) + IOUtils.closeQuietly(in) + } + } + } + + private def getFieldValue(obj: AnyRef, fieldName: String): Any = { + val field: Field = obj.getClass().getDeclaredField(fieldName) + if (field.isAccessible()) { + field.get(obj) + } else { + field.setAccessible(true) + val result = field.get(obj) + field.setAccessible(false) + result + } + } + test("memoryStringToMb") { assert(Utils.memoryStringToMb("1") === 0) assert(Utils.memoryStringToMb("1048575") === 0)