===== BytesToBytesMap.java =====

@@ -32,6 +32,7 @@
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
@@ -163,19 +164,22 @@ public final class BytesToBytesMap extends MemoryConsumer {
private long peakMemoryUsedBytes = 0L;

private final BlockManager blockManager;
private final SerializerManager serializerManager;
private volatile MapIterator destructiveIterator = null;
private LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();

public BytesToBytesMap(
TaskMemoryManager taskMemoryManager,
BlockManager blockManager,
SerializerManager serializerManager,
int initialCapacity,
double loadFactor,
long pageSizeBytes,
boolean enablePerfMetrics) {
super(taskMemoryManager, pageSizeBytes);
this.taskMemoryManager = taskMemoryManager;
this.blockManager = blockManager;
this.serializerManager = serializerManager;
this.loadFactor = loadFactor;
this.loc = new Location();
this.pageSizeBytes = pageSizeBytes;
@@ -209,6 +213,7 @@ public BytesToBytesMap(
this(
taskMemoryManager,
SparkEnv.get() != null ? SparkEnv.get().blockManager() : null,
SparkEnv.get() != null ? SparkEnv.get().serializerManager() : null,
initialCapacity,
0.70,
pageSizeBytes,
@@ -271,7 +276,7 @@ private void advanceToNextPage() {
}
try {
Closeables.close(reader, /* swallowIOException = */ false);
reader = spillWriters.getFirst().getReader(blockManager);
reader = spillWriters.getFirst().getReader(serializerManager);
recordsInPage = -1;
} catch (IOException e) {
// Scala iterator does not handle exception
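For context, a minimal sketch of what the widened constructor asks of a caller, assuming code that lives in the org.apache.spark package (these internals are package-private) and a live SparkEnv; the object name and the capacity/load-factor values are illustrative, not part of this patch.

```scala
// Illustrative sketch only: supplying the new SerializerManager argument, mirroring what the
// SparkEnv-based convenience constructor above does. Assumes a running Spark application so
// that SparkEnv.get is non-null.
package org.apache.spark

import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.unsafe.map.BytesToBytesMap

object BytesToBytesMapSketch {
  def create(taskMemoryManager: TaskMemoryManager, pageSizeBytes: Long): BytesToBytesMap = {
    new BytesToBytesMap(
      taskMemoryManager,
      SparkEnv.get.blockManager,
      SparkEnv.get.serializerManager, // new parameter: spill readers are built from this
      64,    // initialCapacity
      0.70,  // loadFactor
      pageSizeBytes,
      false) // enablePerfMetrics
  }
}
```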
===== UnsafeExternalSorter.java =====

@@ -31,6 +31,7 @@
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
@@ -51,6 +52,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
private final RecordComparator recordComparator;
private final TaskMemoryManager taskMemoryManager;
private final BlockManager blockManager;
private final SerializerManager serializerManager;
private final TaskContext taskContext;
private ShuffleWriteMetrics writeMetrics;

@@ -78,14 +80,16 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
public static UnsafeExternalSorter createWithExistingInMemorySorter(
TaskMemoryManager taskMemoryManager,
BlockManager blockManager,
SerializerManager serializerManager,
TaskContext taskContext,
RecordComparator recordComparator,
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes,
UnsafeInMemorySorter inMemorySorter) throws IOException {
UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager,
taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter);
serializerManager, taskContext, recordComparator, prefixComparator, initialSize,
pageSizeBytes, inMemorySorter);
sorter.spill(Long.MAX_VALUE, sorter);
// The external sorter will be used to insert records; the in-memory sorter is not needed.
sorter.inMemSorter = null;
@@ -95,18 +99,20 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter(
public static UnsafeExternalSorter create(
TaskMemoryManager taskMemoryManager,
BlockManager blockManager,
SerializerManager serializerManager,
TaskContext taskContext,
RecordComparator recordComparator,
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes) {
return new UnsafeExternalSorter(taskMemoryManager, blockManager,
return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager,
taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null);
}

private UnsafeExternalSorter(
TaskMemoryManager taskMemoryManager,
BlockManager blockManager,
SerializerManager serializerManager,
TaskContext taskContext,
RecordComparator recordComparator,
PrefixComparator prefixComparator,
@@ -116,6 +122,7 @@ private UnsafeExternalSorter(
super(taskMemoryManager, pageSizeBytes);
this.taskMemoryManager = taskMemoryManager;
this.blockManager = blockManager;
this.serializerManager = serializerManager;
this.taskContext = taskContext;
this.recordComparator = recordComparator;
this.prefixComparator = prefixComparator;
@@ -412,7 +419,7 @@ public UnsafeSorterIterator getSortedIterator() throws IOException {
final UnsafeSorterSpillMerger spillMerger =
new UnsafeSorterSpillMerger(recordComparator, prefixComparator, spillWriters.size());
for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager));
spillMerger.addSpillIfNotEmpty(spillWriter.getReader(serializerManager));
}
if (inMemSorter != null) {
readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
@@ -463,7 +470,7 @@ public long spill() throws IOException {
}
spillWriter.close();
spillWriters.add(spillWriter);
nextUpstream = spillWriter.getReader(blockManager);
nextUpstream = spillWriter.getReader(serializerManager);

long released = 0L;
synchronized (UnsafeExternalSorter.this) {
@@ -549,7 +556,7 @@ public UnsafeSorterIterator getIterator() throws IOException {
} else {
LinkedList<UnsafeSorterIterator> queue = new LinkedList<>();
for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
queue.add(spillWriter.getReader(blockManager));
queue.add(spillWriter.getReader(serializerManager));
}
if (inMemSorter != null) {
queue.add(inMemSorter.getSortedIterator());
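A similar sketch for the new create(...) signature, again assuming org.apache.spark package access and a running task; the object name and comparator choices are placeholders standing in for whatever a real caller uses, not part of this patch.

```scala
// Illustrative sketch only: the widened UnsafeExternalSorter.create(...) call as seen from a
// caller. The record comparator below orders records by prefix only.
package org.apache.spark

import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator, UnsafeExternalSorter}

object UnsafeExternalSorterSketch {
  // Treats records with equal 8-byte prefixes as equal; real callers compare the payloads too.
  private val recordComparator = new RecordComparator {
    override def compare(
        leftBaseObject: AnyRef,
        leftBaseOffset: Long,
        rightBaseObject: AnyRef,
        rightBaseOffset: Long): Int = 0
  }

  def create(
      taskMemoryManager: TaskMemoryManager,
      taskContext: TaskContext,
      pageSizeBytes: Long): UnsafeExternalSorter = {
    UnsafeExternalSorter.create(
      taskMemoryManager,
      SparkEnv.get.blockManager,
      SparkEnv.get.serializerManager, // new parameter: spill files are read back through it
      taskContext,
      recordComparator,
      PrefixComparators.LONG, // signed-long prefix ordering
      4096,                   // initialSize, in records
      pageSizeBytes)
  }
}
```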
===== UnsafeSorterSpillReader.java =====

@@ -22,8 +22,8 @@
import com.google.common.io.ByteStreams;
import com.google.common.io.Closeables;

import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.Platform;

/**
@@ -46,13 +46,13 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implements Closeable {
private final long baseOffset = Platform.BYTE_ARRAY_OFFSET;

public UnsafeSorterSpillReader(
BlockManager blockManager,
SerializerManager serializerManager,
File file,
BlockId blockId) throws IOException {
assert (file.length() > 0);
final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file));
try {
this.in = blockManager.wrapForCompression(blockId, bs);
this.in = serializerManager.wrapForCompression(blockId, bs);
this.din = new DataInputStream(this.in);
numRecords = numRecordsRemaining = din.readInt();
} catch (IOException e) {
===== UnsafeSorterSpillWriter.java =====

@@ -20,6 +20,7 @@
import java.io.File;
import java.io.IOException;

import org.apache.spark.serializer.SerializerManager;
import scala.Tuple2;

import org.apache.spark.executor.ShuffleWriteMetrics;
@@ -144,7 +145,7 @@ public File getFile() {
return file;
}

public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException {
return new UnsafeSorterSpillReader(blockManager, file, blockId);
public UnsafeSorterSpillReader getReader(SerializerManager serializerManager) throws IOException {
return new UnsafeSorterSpillReader(serializerManager, file, blockId);
}
}
===== SerializerManager.scala =====

@@ -17,17 +17,25 @@

package org.apache.spark.serializer

import java.io.{BufferedInputStream, BufferedOutputStream, InputStream, OutputStream}
import java.nio.ByteBuffer

import scala.reflect.ClassTag

import org.apache.spark.SparkConf
import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage._
import org.apache.spark.util.io.{ByteArrayChunkOutputStream, ChunkedByteBuffer}

/**
* Component that selects which [[Serializer]] to use for shuffles.
* Component which configures serialization and compression for various Spark components, including
* automatic selection of which [[Serializer]] to use for shuffles.
*/
private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) {

private[this] val kryoSerializer = new KryoSerializer(conf)

private[this] val stringClassTag: ClassTag[String] = implicitly[ClassTag[String]]
private[this] val primitiveAndPrimitiveArrayClassTags: Set[ClassTag[_]] = {
val primitiveClassTags = Set[ClassTag[_]](
ClassTag.Boolean,
@@ -44,7 +52,21 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) {
primitiveClassTags ++ arrayClassTags
}

private[this] val stringClassTag: ClassTag[String] = implicitly[ClassTag[String]]
// Whether to compress broadcast variables that are stored
private[this] val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true)
// Whether to compress shuffle output that is stored
private[this] val compressShuffle = conf.getBoolean("spark.shuffle.compress", true)
// Whether to compress RDD partitions that are stored serialized
private[this] val compressRdds = conf.getBoolean("spark.rdd.compress", false)
// Whether to compress shuffle output temporarily spilled to disk
private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true)

/* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
* the initialization of the compression codec until it is first used. The reason is that a Spark
* program could be using a user-defined codec in a third-party jar, which is loaded in
* Executor.updateDependencies. When the BlockManager is initialized, user-level jars haven't been
* loaded yet. */
private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf)

private def canUseKryo(ct: ClassTag[_]): Boolean = {
primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag
@@ -68,4 +90,68 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) {
defaultSerializer
}
}

private def shouldCompress(blockId: BlockId): Boolean = {
blockId match {
case _: ShuffleBlockId => compressShuffle
case _: BroadcastBlockId => compressBroadcast
case _: RDDBlockId => compressRdds
case _: TempLocalBlockId => compressShuffleSpill
case _: TempShuffleBlockId => compressShuffle
case _ => false
}
}

/**
* Wrap an output stream for compression if block compression is enabled for its block type
*/
def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s
}

/**
* Wrap an input stream for compression if block compression is enabled for its block type
*/
def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
}

/** Serializes into a stream. */
def dataSerializeStream[T: ClassTag](
blockId: BlockId,
outputStream: OutputStream,
values: Iterator[T]): Unit = {
val byteStream = new BufferedOutputStream(outputStream)
val ser = getSerializer(implicitly[ClassTag[T]]).newInstance()
ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
}

/** Serializes into a chunked byte buffer. */
def dataSerialize[T: ClassTag](blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer = {
val byteArrayChunkOutputStream = new ByteArrayChunkOutputStream(1024 * 1024 * 4)
dataSerializeStream(blockId, byteArrayChunkOutputStream, values)
new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap))
}

/**
* Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of
* the iterator is reached.
*/
def dataDeserialize[T: ClassTag](blockId: BlockId, bytes: ChunkedByteBuffer): Iterator[T] = {
dataDeserializeStream[T](blockId, bytes.toInputStream(dispose = true))
}

/**
* Deserializes an InputStream into an iterator of values and disposes of it when the end of
* the iterator is reached.
*/
def dataDeserializeStream[T: ClassTag](
blockId: BlockId,
inputStream: InputStream): Iterator[T] = {
val stream = new BufferedInputStream(inputStream)
getSerializer(implicitly[ClassTag[T]])
.newInstance()
.deserializeStream(wrapForCompression(blockId, stream))
.asIterator.asInstanceOf[Iterator[T]]
}
}
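To see the new APIs end to end, a minimal round-trip sketch, assuming code in the org.apache.spark.serializer package (the class is private[spark]); the object name and the JavaSerializer default are illustrative only, not part of this patch.

```scala
// Illustrative sketch: serialize and deserialize through the new SerializerManager methods.
package org.apache.spark.serializer

import java.util.UUID

import org.apache.spark.SparkConf
import org.apache.spark.storage.TempLocalBlockId

object SerializerManagerRoundTrip {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false)
    val serializerManager = new SerializerManager(new JavaSerializer(conf), conf)

    // TempLocalBlockId is the block type used for sorter spill files, so
    // spark.shuffle.spill.compress (true by default) routes both directions
    // through the lazily created compression codec.
    val blockId = TempLocalBlockId(UUID.randomUUID())

    // String has a class tag that canUseKryo accepts, so Kryo is picked
    // automatically even though the default serializer here is JavaSerializer.
    val bytes = serializerManager.dataSerialize(blockId, Iterator("a", "b", "c"))
    val restored = serializerManager.dataDeserialize[String](blockId, bytes).toList
    assert(restored == List("a", "b", "c"))
  }
}
```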
===== BlockStoreShuffleReader.scala =====

@@ -19,7 +19,7 @@ package org.apache.spark.shuffle

import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.Serializer
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
@@ -33,6 +33,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
startPartition: Int,
endPartition: Int,
context: TaskContext,
serializerManager: SerializerManager = SparkEnv.get.serializerManager,
blockManager: BlockManager = SparkEnv.get.blockManager,
mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
extends ShuffleReader[K, C] with Logging {
@@ -52,7 +53,7 @@

// Wrap the streams for compression based on configuration
val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
blockManager.wrapForCompression(blockId, inputStream)
serializerManager.wrapForCompression(blockId, inputStream)
}

val serializerInstance = dep.serializer.newInstance()
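The wrapping step above, exercised in isolation: a small sketch assuming org.apache.spark package access, with an illustrative object name; it is not part of this patch.

```scala
// Illustrative sketch: the same wrapForCompression calls used by the shuffle read path.
// ShuffleBlockId is governed by spark.shuffle.compress (default true); with compression
// disabled the streams are returned unchanged.
package org.apache.spark.shuffle

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import com.google.common.io.ByteStreams

import org.apache.spark.SparkConf
import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
import org.apache.spark.storage.ShuffleBlockId

object WrapForCompressionSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false)
    val serializerManager = new SerializerManager(new JavaSerializer(conf), conf)
    val blockId = ShuffleBlockId(shuffleId = 0, mapId = 0, reduceId = 0)

    // Writer side: wrap the raw output stream before shuffle bytes are written.
    val rawOut = new ByteArrayOutputStream()
    val out = serializerManager.wrapForCompression(blockId, rawOut)
    out.write("shuffle payload".getBytes("UTF-8"))
    out.close()

    // Reader side: the same block id selects the same codec, mirroring the map above.
    val in = serializerManager.wrapForCompression(
      blockId, new ByteArrayInputStream(rawOut.toByteArray))
    assert(new String(ByteStreams.toByteArray(in), "UTF-8") == "shuffle payload")
  }
}
```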