49 changes: 34 additions & 15 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -746,7 +746,7 @@ private[spark] class BlockManager(
// We will drop it to disk later if the memory store can't hold it.
val putSucceeded = if (level.deserialized) {
val values = serializerManager.dataDeserialize(blockId, bytes)(classTag)
memoryStore.putIterator(blockId, values, level, classTag) match {
memoryStore.putIteratorAsValues(blockId, values, classTag) match {
case Right(_) => true
case Left(iter) =>
// If putting deserialized values in memory failed, we will put the bytes directly to
@@ -876,21 +876,40 @@ private[spark] class BlockManager(
if (level.useMemory) {
// Put it in memory first, even if it also has useDisk set to true;
// We will drop it to disk later if the memory store can't hold it.
memoryStore.putIterator(blockId, iterator(), level, classTag) match {
case Right(s) =>
size = s
case Left(iter) =>
// Not enough space to unroll this block; drop to disk if applicable
if (level.useDisk) {
logWarning(s"Persisting block $blockId to disk instead.")
diskStore.put(blockId) { fileOutputStream =>
serializerManager.dataSerializeStream(blockId, fileOutputStream, iter)(classTag)
if (level.deserialized) {
memoryStore.putIteratorAsValues(blockId, iterator(), classTag) match {
case Right(s) =>
size = s
case Left(iter) =>
// Not enough space to unroll this block; drop to disk if applicable
if (level.useDisk) {
logWarning(s"Persisting block $blockId to disk instead.")
diskStore.put(blockId) { fileOutputStream =>
serializerManager.dataSerializeStream(blockId, fileOutputStream, iter)(classTag)
}
size = diskStore.getSize(blockId)
} else {
iteratorFromFailedMemoryStorePut = Some(iter)
}
size = diskStore.getSize(blockId)
} else {
iteratorFromFailedMemoryStorePut = Some(iter)
}
}
} else { // !level.deserialized
memoryStore.putIteratorAsBytes(blockId, iterator(), classTag) match {
case Right(s) =>
size = s
case Left(partiallySerializedValues) =>
// Not enough space to unroll this block; drop to disk if applicable
if (level.useDisk) {
logWarning(s"Persisting block $blockId to disk instead.")
diskStore.put(blockId) { fileOutputStream =>
partiallySerializedValues.finishWritingToStream(fileOutputStream)
}
size = diskStore.getSize(blockId)
} else {
iteratorFromFailedMemoryStorePut = Some(partiallySerializedValues.valuesIterator)
}
}
Contributor: About my previous comment about duplicate code: never mind, it can't actually be abstracted cleanly.

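A minimal sketch of why the two branches resist a shared helper (the name putViaMemoryStore and its shape are hypothetical, not part of the PR): the two memory-store calls fail with different types, so a combined wrapper only moves the branch-specific pattern match to the call site rather than removing it.

// Hypothetical sketch only: a unified helper still surfaces a nested Either
// whose Left side needs branch-specific handling by the caller.
private def putViaMemoryStore[T](
    blockId: BlockId,
    level: StorageLevel,
    iterator: () => Iterator[T],
    classTag: ClassTag[T])
  : Either[Either[PartiallyUnrolledIterator[T], PartiallySerializedBlock[T]], Long] = {
  if (level.deserialized) {
    memoryStore.putIteratorAsValues(blockId, iterator(), classTag).left.map(Left(_))
  } else {
    memoryStore.putIteratorAsBytes(blockId, iterator(), classTag).left.map(Right(_))
  }
}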
}

} else if (level.useDisk) {
diskStore.put(blockId) { fileOutputStream =>
serializerManager.dataSerializeStream(blockId, fileOutputStream, iterator())(classTag)
@@ -991,7 +1010,7 @@ private[spark] class BlockManager(
// Note: if we had a means to discard the disk iterator, we would do that here.
memoryStore.getValues(blockId).get
} else {
memoryStore.putIterator(blockId, diskIterator, level, classTag) match {
memoryStore.putIteratorAsValues(blockId, diskIterator, classTag) match {
case Left(iter) =>
// The memory store put() failed, so it returned the iterator back to us:
iter
229 changes: 204 additions & 25 deletions core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
@@ -17,20 +17,24 @@

package org.apache.spark.storage.memory

import java.io.OutputStream
import java.nio.ByteBuffer
import java.util.LinkedHashMap

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import com.google.common.io.ByteStreams

import org.apache.spark.{SparkConf, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.memory.MemoryManager
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.serializer.{SerializationStream, SerializerManager}
import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel}
import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils}
import org.apache.spark.util.collection.SizeTrackingVector
import org.apache.spark.util.io.ChunkedByteBuffer
import org.apache.spark.util.io.{ByteArrayChunkOutputStream, ChunkedByteBuffer}

private sealed trait MemoryEntry[T] {
def size: Long
@@ -42,8 +46,9 @@ private case class DeserializedMemoryEntry[T](
classTag: ClassTag[T]) extends MemoryEntry[T]
private case class SerializedMemoryEntry[T](
buffer: ChunkedByteBuffer,
size: Long,
classTag: ClassTag[T]) extends MemoryEntry[T]
classTag: ClassTag[T]) extends MemoryEntry[T] {
def size: Long = buffer.size
}

private[storage] trait BlockEvictionHandler {
/**
@@ -132,7 +137,7 @@ private[spark] class MemoryStore(
// We acquired enough memory for the block, so go ahead and put it
val bytes = _bytes()
assert(bytes.size == size)
val entry = new SerializedMemoryEntry[T](bytes, size, implicitly[ClassTag[T]])
val entry = new SerializedMemoryEntry[T](bytes, implicitly[ClassTag[T]])
entries.synchronized {
entries.put(blockId, entry)
}
@@ -145,7 +150,7 @@ private[spark] class MemoryStore(
}

/**
* Attempt to put the given block in memory store.
* Attempt to put the given block in memory store as values.
*
* It's possible that the iterator is too large to materialize and store in memory. To avoid
* OOM exceptions, this method will gradually unroll the iterator while periodically checking
@@ -160,10 +165,9 @@
* iterator or call `close()` on it in order to free the storage memory consumed by the
* partially-unrolled block.
*/
private[storage] def putIterator[T](
private[storage] def putIteratorAsValues[T](
blockId: BlockId,
values: Iterator[T],
level: StorageLevel,
classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = {

require(!contains(blockId), s"Block $blockId is already present in the MemoryStore")
@@ -218,12 +222,8 @@ private[spark] class MemoryStore(
// We successfully unrolled the entirety of this block
val arrayValues = vector.toArray
vector = null
val entry = if (level.deserialized) {
val entry =
new DeserializedMemoryEntry[T](arrayValues, SizeEstimator.estimate(arrayValues), classTag)
} else {
val bytes = serializerManager.dataSerialize(blockId, arrayValues.iterator)(classTag)
new SerializedMemoryEntry[T](bytes, bytes.size, classTag)
}
val size = entry.size
def transferUnrollToStorage(amount: Long): Unit = {
// Synchronize so that transfer is atomic
@@ -255,12 +255,8 @@ private[spark] class MemoryStore(
entries.synchronized {
entries.put(blockId, entry)
}
val bytesOrValues = if (level.deserialized) "values" else "bytes"
logInfo("Block %s stored as %s in memory (estimated size %s, free %s)".format(
blockId,
bytesOrValues,
Utils.bytesToString(size),
Utils.bytesToString(maxMemory - blocksMemoryUsed)))
logInfo("Block %s stored as values in memory (estimated size %s, free %s)".format(
blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed)))
Right(size)
} else {
assert(currentUnrollMemoryForThisTask >= currentUnrollMemoryForThisTask,
@@ -279,13 +275,117 @@
}
}

/**
* Attempt to put the given block in memory store as bytes.
*
* It's possible that the iterator is too large to materialize and store in memory. To avoid
* OOM exceptions, this method will gradually unroll the iterator while periodically checking
* whether there is enough free memory. If the block is successfully materialized, then the
* temporary unroll memory used during the materialization is "transferred" to storage memory,
* so we won't acquire more memory than is actually needed to store the block.
*
* @return in case of success, the estimated size of the stored data. In case of
* failure, return a handle which allows the caller to either finish the serialization
* by spilling to disk or to deserialize the partially-serialized block and reconstruct
* the original input iterator. The caller must either fully consume this result
* iterator or call `discard()` on it in order to free the storage memory consumed by the
* partially-unrolled block.
*/
private[storage] def putIteratorAsBytes[T](
blockId: BlockId,
values: Iterator[T],
classTag: ClassTag[T]): Either[PartiallySerializedBlock[T], Long] = {

require(!contains(blockId), s"Block $blockId is already present in the MemoryStore")

// Whether there is still enough memory for us to continue unrolling this block
var keepUnrolling = true
// Initial per-task memory to request for unrolling blocks (bytes).
val initialMemoryThreshold = unrollMemoryThreshold
// Keep track of unroll memory used by this particular block / putIterator() operation
var unrollMemoryUsedByThisBlock = 0L
// Underlying buffer for unrolling the block
val redirectableStream = new RedirectableOutputStream
val byteArrayChunkOutputStream = new ByteArrayChunkOutputStream(initialMemoryThreshold.toInt)
redirectableStream.setOutputStream(byteArrayChunkOutputStream)
val serializationStream: SerializationStream = {
val ser = serializerManager.getSerializer(classTag).newInstance()
ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream))
}

// Request enough memory to begin unrolling
keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold)

if (!keepUnrolling) {
logWarning(s"Failed to reserve initial memory threshold of " +
s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.")
} else {
unrollMemoryUsedByThisBlock += initialMemoryThreshold
}

def reserveAdditionalMemoryIfNecessary(): Unit = {
if (byteArrayChunkOutputStream.size > unrollMemoryUsedByThisBlock) {
Contributor (author): One important implicit assumption, which I will make explicit in a line comment: we assume that we'll always be able to get enough memory to unroll at least one element in between size calculations. This is the same assumption we have in the deserialized case, since we only periodically measure memory usage there.

val amountToRequest = byteArrayChunkOutputStream.size - unrollMemoryUsedByThisBlock
keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest)
if (keepUnrolling) {
unrollMemoryUsedByThisBlock += amountToRequest
}
}
}

// Unroll this block safely, checking whether we have exceeded our threshold
while (values.hasNext && keepUnrolling) {
serializationStream.writeObject(values.next())(classTag)
reserveAdditionalMemoryIfNecessary()
}

// Make sure that we have enough memory to store the block. By this point, it is possible that
// the block's actual memory usage has exceeded the unroll memory by a small amount, so we
// perform one final call to attempt to allocate additional memory if necessary.
Contributor: This is because of the call to close? That can use more memory?

Contributor (author): Yes.

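A standalone illustration of that effect, with GZIP standing in for whatever codec wrapForCompression actually applies (the codec choice here is purely an assumption for the example): most of the compressed output only reaches the underlying buffer when the stream is closed, so the measured size can jump at close().

import java.io.ByteArrayOutputStream
import java.util.zip.GZIPOutputStream

val underlying = new ByteArrayOutputStream()
val compressed = new GZIPOutputStream(underlying)
compressed.write(Array.fill[Byte](10000)(42.toByte))
val sizeBeforeClose = underlying.size()  // mostly just the GZIP header so far
compressed.close()                       // buffered deflater output and the trailer land here
val sizeAfterClose = underlying.size()   // noticeably larger than sizeBeforeClose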
if (keepUnrolling) {
Contributor: move this into line 317?

Contributor (author): This is actually here on purpose and deserves a comment. The goal here is to make sure that once we reach line 317 we are guaranteed to have enough memory to store the block. When we finish serializing the block and reach line 311, it's possible that the actual memory usage has exceeded our unroll memory slightly, so here we do one final bumping up of the unroll memory.

serializationStream.close()
reserveAdditionalMemoryIfNecessary()
}

if (keepUnrolling) {
Contributor: Why have two of these ifs? Can't we merge them?

Contributor (author): See the comment above the first one.

Contributor: I don't get it. What change in functionality will there be if we moved those two lines into this if case? Did you separate them for readability?

Contributor (author): After the reserveAdditionalMemoryIfNecessary() call on line 347, keepUnrolling can become false.

Contributor: I see...
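To make the ordering discussed above explicit, a condensed paraphrase of the tail of putIteratorAsBytes (a sketch, not the diff itself):

if (keepUnrolling) {                    // everything written so far fit in the reservation
  serializationStream.close()           // close() can append buffered/trailer bytes
  reserveAdditionalMemoryIfNecessary()  // may flip keepUnrolling to false
}
if (keepUnrolling) {                    // only now is the full serialized size reserved
  // transfer the unroll reservation to storage memory and register the entry
} else {
  // hand back a PartiallySerializedBlock so the caller can spill or re-iterate
}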

val entry = SerializedMemoryEntry[T](
new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap)), classTag)
// Synchronize so that transfer is atomic
memoryManager.synchronized {
Contributor: I see this replicated in a few places (transferUnrollToStorage()). It might be easier to have releaseUnrollMemoryForThisTask take another argument to optionally transfer it to storage.
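A rough sketch of that suggestion, assuming the block id is passed along as well (this variant is hypothetical and not part of the PR):

def releaseUnrollMemoryForThisTask(
    amount: Long,
    transferToStorageFor: Option[BlockId] = None): Unit = memoryManager.synchronized {
  // ... existing logic that releases `amount` bytes of unroll memory ...
  transferToStorageFor.foreach { blockId =>
    val success = memoryManager.acquireStorageMemory(blockId, amount)
    assert(success, "transferring unroll memory to storage memory failed")
  }
}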

releaseUnrollMemoryForThisTask(unrollMemoryUsedByThisBlock)
val success = memoryManager.acquireStorageMemory(blockId, entry.size)
assert(success, "transferring unroll memory to storage memory failed")
}
entries.synchronized {
entries.put(blockId, entry)
}
logInfo("Block %s stored as bytes in memory (estimated size %s, free %s)".format(
blockId, Utils.bytesToString(entry.size), Utils.bytesToString(blocksMemoryUsed)))
Right(entry.size)
} else {
// We ran out of space while unrolling the values for this block
logUnrollFailureMessage(blockId, byteArrayChunkOutputStream.size)
Left(
new PartiallySerializedBlock(
this,
serializerManager,
blockId,
serializationStream,
redirectableStream,
unrollMemoryUsedByThisBlock,
new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap)),
values,
classTag))
}
}

def getBytes(blockId: BlockId): Option[ChunkedByteBuffer] = {
val entry = entries.synchronized { entries.get(blockId) }
entry match {
case null => None
case e: DeserializedMemoryEntry[_] =>
throw new IllegalArgumentException("should only call getBytes on serialized blocks")
case SerializedMemoryEntry(bytes, _, _) => Some(bytes)
case SerializedMemoryEntry(bytes, _) => Some(bytes)
}
}

@@ -373,7 +473,7 @@ private[spark] class MemoryStore(
def dropBlock[T](blockId: BlockId, entry: MemoryEntry[T]): Unit = {
val data = entry match {
case DeserializedMemoryEntry(values, _, _) => Left(values)
case SerializedMemoryEntry(buffer, _, _) => Right(buffer)
case SerializedMemoryEntry(buffer, _) => Right(buffer)
}
val newEffectiveStorageLevel =
blockEvictionHandler.dropFromMemory(blockId, () => data)(entry.classTag)
@@ -507,12 +607,13 @@ private[spark] class MemoryStore(
}

/**
* The result of a failed [[MemoryStore.putIterator()]] call.
* The result of a failed [[MemoryStore.putIteratorAsValues()]] call.
*
* @param memoryStore the memoryStore, used for freeing memory.
* @param memoryStore the memoryStore, used for freeing memory.
* @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
* @param unrolled an iterator for the partially-unrolled values.
* @param rest the rest of the original iterator passed to [[MemoryStore.putIterator()]].
* @param unrolled an iterator for the partially-unrolled values.
* @param rest the rest of the original iterator passed to
* [[MemoryStore.putIteratorAsValues()]].
*/
private[storage] class PartiallyUnrolledIterator[T](
memoryStore: MemoryStore,
@@ -544,3 +645,81 @@ private[storage] class PartiallyUnrolledIterator[T](
iter = null
}
}

/**
* A wrapper which allows an open [[OutputStream]] to be redirected to a different sink.
*/
private class RedirectableOutputStream extends OutputStream {
private[this] var os: OutputStream = _
def setOutputStream(s: OutputStream): Unit = { os = s }
Contributor: Do we need a whole new class for this? Can we just write (failed stream + the rest of input stream) to the file when we drop it to disk?

override def write(b: Int): Unit = os.write(b)
override def write(b: Array[Byte]): Unit = os.write(b)
override def write(b: Array[Byte], off: Int, len: Int): Unit = os.write(b, off, len)
override def flush(): Unit = os.flush()
override def close(): Unit = os.close()
}
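The wrapper exists so that an already-open serialization stream can keep writing while its sink changes underneath it. A minimal usage sketch (the sinks and the file path are illustrative):

import java.io.{ByteArrayOutputStream, FileOutputStream}

val buffered = new ByteArrayOutputStream()
val redirectable = new RedirectableOutputStream
redirectable.setOutputStream(buffered)              // unroll into memory first
redirectable.write("partially serialized data".getBytes("UTF-8"))

// Memory ran out: spill without reopening or re-wrapping the stream built on top.
val spill = new FileOutputStream("/tmp/spill-example.bin")
spill.write(buffered.toByteArray)                   // copy what was already written
redirectable.setOutputStream(spill)                 // redirect the live stream
redirectable.write(" plus the remaining records".getBytes("UTF-8"))
redirectable.close()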

/**
* The result of a failed [[MemoryStore.putIteratorAsBytes()]] call.
*
* @param memoryStore the MemoryStore, used for freeing memory.
* @param serializerManager the SerializerManager, used for deserializing values.
* @param blockId the block id.
* @param serializationStream a serialization stream which writes to [[redirectableOutputStream]].
* @param redirectableOutputStream an OutputStream which can be redirected to a different sink.
* @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
* @param unrolled a byte buffer containing the partially-serialized values.
* @param rest the rest of the original iterator passed to
* [[MemoryStore.putIteratorAsValues()]].
* @param classTag the [[ClassTag]] for the block.
*/
private[storage] class PartiallySerializedBlock[T](
memoryStore: MemoryStore,
serializerManager: SerializerManager,
blockId: BlockId,
serializationStream: SerializationStream,
redirectableOutputStream: RedirectableOutputStream,
unrollMemory: Long,
unrolled: ChunkedByteBuffer,
rest: Iterator[T],
classTag: ClassTag[T]) {

/**
* Called to dispose of this block and free its memory.
*/
def discard(): Unit = {
Contributor: Who calls this now?

Contributor (author): It's not actually called in the current code: its only producer, BlockManager.doPutIterator(), returns a PartiallyUnrolledIterator, so that sole call site ends up calling valuesIterator(), which takes care of the discarding.

try {
serializationStream.close()
} finally {
memoryStore.releaseUnrollMemoryForThisTask(unrollMemory)
}
}

/**
* Finish writing this block to the given output stream by first writing the serialized values
* and then serializing the values from the original input iterator.
*/
def finishWritingToStream(os: OutputStream): Unit = {
ByteStreams.copy(unrolled.toInputStream(), os)
redirectableOutputStream.setOutputStream(os)
while (rest.hasNext) {
serializationStream.writeObject(rest.next())(classTag)
}
discard()
}

/**
* Returns an iterator over the values in this block by first deserializing the serialized
* values and then consuming the rest of the original input iterator.
*
* If the caller does not plan to fully consume the resulting iterator then they must call
* `close()` on it to free its resources.
*/
def valuesIterator: PartiallyUnrolledIterator[T] = {
new PartiallyUnrolledIterator(
memoryStore,
unrollMemory,
unrolled = serializerManager.dataDeserialize(blockId, unrolled)(classTag),
rest = rest)
}
}