
Commit 2e03ee6

Author: Marcelo Vanzin (committed)
[SPARK-18546][core] Fix merging shuffle spills when using encryption.
The problem exists because it's not possible to just concatenate encrypted partition data from different spill files: each spill currently sets up encryption with its own initialization vector (IV), while the final merged file should contain a single IV for each merged partition, otherwise iterating over the records becomes very hard.

To fix that, UnsafeShuffleWriter now decrypts the partitions when merging, so that the merged file contains a single IV at the start of each partition's data. Because that can't be done through the fast transferTo path, UnsafeShuffleWriter falls back to file streams for merging when encryption is enabled. A hybrid approach may be possible, using an intermediate direct buffer when reading from files and encrypting the data, but that's better left for a separate patch.

As part of the change, DiskBlockObjectWriter now takes a SerializerManager instead of a "wrap stream" closure, which makes the code easier to test without having to mock SerializerManager functionality.

Tested with newly added unit tests (UnsafeShuffleWriterSuite for the write side and ExternalAppendOnlyMapSuite for integration), and by running some apps that failed without the fix.
1 parent 1a87009 commit 2e03ee6
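
To make the IV problem concrete, here is a small, self-contained sketch of the merge idea using plain JDK AES/CTR streams. It is illustrative only: Spark routes this through SerializerManager and CryptoStreamUtils (commons-crypto), and the class and method names below (MergeSketch, encrypting, decrypting, mergePartition) are invented for the example. Each spill prefixes its ciphertext with its own IV, so the merger decrypts every spill's slice of a partition and re-encrypts it into one stream, leaving a single IV per merged partition.

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.SecureRandom;
import javax.crypto.Cipher;
import javax.crypto.CipherInputStream;
import javax.crypto.CipherOutputStream;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;

public class MergeSketch {

  // Each encrypted stream is laid out as [16-byte IV][AES/CTR ciphertext].
  static OutputStream encrypting(OutputStream out, byte[] key) throws Exception {
    byte[] iv = new byte[16];
    new SecureRandom().nextBytes(iv);
    out.write(iv);                                       // one IV at the start of the stream
    Cipher c = Cipher.getInstance("AES/CTR/NoPadding");
    c.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(key, "AES"), new IvParameterSpec(iv));
    return new CipherOutputStream(out, c);
  }

  static InputStream decrypting(InputStream in, byte[] key) throws Exception {
    byte[] iv = in.readNBytes(16);                       // the IV this spill was written with
    Cipher c = Cipher.getInstance("AES/CTR/NoPadding");
    c.init(Cipher.DECRYPT_MODE, new SecretKeySpec(key, "AES"), new IvParameterSpec(iv));
    return new CipherInputStream(in, c);
  }

  // Concatenating raw spill bytes would bury one IV per spill inside the merged partition, so
  // instead decrypt each spill's slice and re-encrypt it under a single stream (single IV).
  static void mergePartition(InputStream[] spillSlices, OutputStream mergedFile, byte[] key)
      throws Exception {
    try (OutputStream merged = encrypting(mergedFile, key)) {
      for (InputStream slice : spillSlices) {
        decrypting(slice, key).transferTo(merged);
      }
    }
  }

  public static void main(String[] args) throws Exception {
    byte[] key = new byte[16];
    new SecureRandom().nextBytes(key);

    // Two "spills", each holding part of the same partition under its own IV.
    ByteArrayOutputStream spill1 = new ByteArrayOutputStream();
    try (OutputStream o = encrypting(spill1, key)) { o.write("hello ".getBytes()); }
    ByteArrayOutputStream spill2 = new ByteArrayOutputStream();
    try (OutputStream o = encrypting(spill2, key)) { o.write("world".getBytes()); }

    ByteArrayOutputStream merged = new ByteArrayOutputStream();
    mergePartition(new InputStream[] {
        new ByteArrayInputStream(spill1.toByteArray()),
        new ByteArrayInputStream(spill2.toByteArray())}, merged, key);

    // The merged partition now decrypts as one continuous stream with a single IV.
    InputStream back = decrypting(new ByteArrayInputStream(merged.toByteArray()), key);
    System.out.println(new String(back.readAllBytes()));  // prints "hello world"
  }
}

Concatenating the raw ciphertext of the two spills instead would bury a second IV in the middle of the merged partition, which is exactly why the fast transferTo merge cannot be used once encryption is on.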

File tree: 10 files changed (+133, -118 lines)

core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java

Lines changed: 21 additions & 22 deletions

@@ -40,6 +40,8 @@
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.io.CompressionCodec;
 import org.apache.spark.io.CompressionCodec$;
+import org.apache.commons.io.output.CloseShieldOutputStream;
+import org.apache.commons.io.output.CountingOutputStream;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
@@ -264,6 +266,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti
       sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
     final boolean fastMergeIsSupported = !compressionEnabled ||
       CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
+    final boolean encryptionEnabled = blockManager.serializerManager().encryptionKey().isDefined();
     try {
       if (spills.length == 0) {
         new FileOutputStream(outputFile).close(); // Create an empty file
@@ -289,7 +292,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti
         // Compression is disabled or we are using an IO compression codec that supports
         // decompression of concatenated compressed streams, so we can perform a fast spill merge
         // that doesn't need to interpret the spilled bytes.
-        if (transferToEnabled) {
+        if (transferToEnabled && !encryptionEnabled) {
           logger.debug("Using transferTo-based fast merge");
           partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
         } else {
@@ -337,42 +340,38 @@ private long[] mergeSpillsWithFileStream(
     final int numPartitions = partitioner.numPartitions();
     final long[] partitionLengths = new long[numPartitions];
     final InputStream[] spillInputStreams = new FileInputStream[spills.length];
-    OutputStream mergedFileOutputStream = null;
+    final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(
+      new FileOutputStream(outputFile));

     boolean threwException = true;
     try {
       for (int i = 0; i < spills.length; i++) {
         spillInputStreams[i] = new FileInputStream(spills[i].file);
       }
       for (int partition = 0; partition < numPartitions; partition++) {
-        final long initialFileLength = outputFile.length();
-        mergedFileOutputStream =
-          new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true));
+        final long initialFileLength = mergedFileOutputStream.getByteCount();
+        OutputStream partitionOutput = new CloseShieldOutputStream(mergedFileOutputStream);
+        partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
         if (compressionCodec != null) {
-          mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream);
+          partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
         }
-
+        partitionOutput = new TimeTrackingOutputStream(writeMetrics, partitionOutput);
         for (int i = 0; i < spills.length; i++) {
           final long partitionLengthInSpill = spills[i].partitionLengths[partition];
           if (partitionLengthInSpill > 0) {
-            InputStream partitionInputStream = null;
-            boolean innerThrewException = true;
-            try {
-              partitionInputStream =
-                new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false);
-              if (compressionCodec != null) {
-                partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
-              }
-              ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
-              innerThrewException = false;
-            } finally {
-              Closeables.close(partitionInputStream, innerThrewException);
+            InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i],
+              partitionLengthInSpill, false);
+            partitionInputStream = blockManager.serializerManager().wrapForEncryption(
+              partitionInputStream);
+            if (compressionCodec != null) {
+              partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
             }
+            ByteStreams.copy(partitionInputStream, partitionOutput);
           }
         }
-        mergedFileOutputStream.flush();
-        mergedFileOutputStream.close();
-        partitionLengths[partition] = (outputFile.length() - initialFileLength);
+        partitionOutput.flush();
+        partitionOutput.close();
+        partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength);
       }
       threwException = false;
     } finally {
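
The rewritten mergeSpillsWithFileStream above keeps one FileOutputStream open for the whole merge and measures partition boundaries with commons-io's CountingOutputStream instead of re-reading outputFile.length(); CloseShieldOutputStream lets each partition's encryption and compression wrappers be closed, and therefore fully flushed, without closing the underlying file. Below is a standalone sketch of that accounting pattern, with GZIPOutputStream standing in for Spark's codec and cipher wrappers and the class name and data invented for illustration:

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.zip.GZIPOutputStream;
import org.apache.commons.io.output.CloseShieldOutputStream;
import org.apache.commons.io.output.CountingOutputStream;

public class PartitionLengthSketch {
  public static void main(String[] args) throws IOException {
    File out = File.createTempFile("merged", ".bin");
    try (CountingOutputStream counting = new CountingOutputStream(new FileOutputStream(out))) {
      long[] lengths = new long[2];
      for (int p = 0; p < 2; p++) {
        long start = counting.getByteCount();
        // The shield swallows close(), so closing the wrapper flushes it but keeps the file open.
        OutputStream partition = new GZIPOutputStream(new CloseShieldOutputStream(counting));
        partition.write(("partition " + p).getBytes());
        partition.close();
        lengths[p] = counting.getByteCount() - start;
      }
      System.out.println(lengths[0] + " + " + lengths[1] + " = " + counting.getByteCount());
    }
  }
}

Counting at the raw stream rather than re-checking the file matters because the higher-level wrappers buffer data until they are closed, so the on-disk length can lag behind what has logically been written for the current partition.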

core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala

Lines changed: 5 additions & 5 deletions

@@ -36,7 +36,7 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStrea
 private[spark] class SerializerManager(
     defaultSerializer: Serializer,
     conf: SparkConf,
-    encryptionKey: Option[Array[Byte]]) {
+    val encryptionKey: Option[Array[Byte]]) {

   def this(defaultSerializer: Serializer, conf: SparkConf) = this(defaultSerializer, conf, None)

@@ -126,7 +126,7 @@ private[spark] class SerializerManager(
   /**
    * Wrap an input stream for encryption if shuffle encryption is enabled
    */
-  private[this] def wrapForEncryption(s: InputStream): InputStream = {
+  def wrapForEncryption(s: InputStream): InputStream = {
     encryptionKey
       .map { key => CryptoStreamUtils.createCryptoInputStream(s, conf, key) }
       .getOrElse(s)
@@ -135,7 +135,7 @@ private[spark] class SerializerManager(
   /**
    * Wrap an output stream for encryption if shuffle encryption is enabled
    */
-  private[this] def wrapForEncryption(s: OutputStream): OutputStream = {
+  def wrapForEncryption(s: OutputStream): OutputStream = {
     encryptionKey
       .map { key => CryptoStreamUtils.createCryptoOutputStream(s, conf, key) }
       .getOrElse(s)
@@ -144,14 +144,14 @@ private[spark] class SerializerManager(
   /**
    * Wrap an output stream for compression if block compression is enabled for its block type
    */
-  private[this] def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
+  def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
     if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s
   }

   /**
    * Wrap an input stream for compression if block compression is enabled for its block type
    */
-  private[this] def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
+  def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
     if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
   }

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 2 additions & 3 deletions

@@ -62,7 +62,7 @@ private[spark] class BlockManager(
     executorId: String,
     rpcEnv: RpcEnv,
     val master: BlockManagerMaster,
-    serializerManager: SerializerManager,
+    val serializerManager: SerializerManager,
     val conf: SparkConf,
     memoryManager: MemoryManager,
     mapOutputTracker: MapOutputTracker,
@@ -745,9 +745,8 @@
       serializerInstance: SerializerInstance,
       bufferSize: Int,
       writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = {
-    val wrapStream: OutputStream => OutputStream = serializerManager.wrapStream(blockId, _)
     val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
-    new DiskBlockObjectWriter(file, serializerInstance, bufferSize, wrapStream,
+    new DiskBlockObjectWriter(file, serializerManager, serializerInstance, bufferSize,
       syncWrites, writeMetrics, blockId)
   }

core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala

Lines changed: 3 additions & 3 deletions

@@ -22,7 +22,7 @@ import java.nio.channels.FileChannel

 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.internal.Logging
-import org.apache.spark.serializer.{SerializationStream, SerializerInstance}
+import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
 import org.apache.spark.util.Utils

 /**
@@ -37,9 +37,9 @@ import org.apache.spark.util.Utils
  */
 private[spark] class DiskBlockObjectWriter(
     val file: File,
+    serializerManager: SerializerManager,
     serializerInstance: SerializerInstance,
     bufferSize: Int,
-    wrapStream: OutputStream => OutputStream,
     syncWrites: Boolean,
     // These write metrics concurrently shared with other active DiskBlockObjectWriters who
     // are themselves performing writes. All updates must be relative.
@@ -116,7 +116,7 @@ private[spark] class DiskBlockObjectWriter(
       initialized = true
     }

-    bs = wrapStream(mcs)
+    bs = serializerManager.wrapStream(blockId, mcs)
     objOut = serializerInstance.serializeStream(bs)
     streamOpen = true
     this
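
The Scala changes above make SerializerManager's wrapForEncryption overloads public and hand the manager itself to DiskBlockObjectWriter in place of the old wrap-stream closure. As a rough round-trip sketch of that API, assuming a Spark test classpath with Guava and commons-crypto available (the constructor, config entry, and wrapper calls mirror the ones exercised in this commit's test changes, while the class name here is invented):

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.io.OutputStream;

import scala.Option;

import com.google.common.io.ByteStreams;
import org.apache.spark.SparkConf;
import org.apache.spark.security.CryptoStreamUtils;
import org.apache.spark.serializer.KryoSerializer;
import org.apache.spark.serializer.SerializerManager;

public class WrapForEncryptionRoundTrip {
  public static void main(String[] args) throws Exception {
    SparkConf conf = new SparkConf();
    conf.set(org.apache.spark.internal.config.package$.MODULE$.IO_ENCRYPTION_ENABLED(), true);

    // Same construction the updated UnsafeShuffleWriterSuite uses for its encrypted runs.
    SerializerManager manager = new SerializerManager(
      new KryoSerializer(conf), conf, Option.apply(CryptoStreamUtils.createKey(conf)));

    // Write through the (now public) encrypting wrapper...
    ByteArrayOutputStream buf = new ByteArrayOutputStream();
    OutputStream out = manager.wrapForEncryption(buf);
    out.write("spill bytes".getBytes());
    out.close();

    // ...and read the ciphertext back through the matching decrypting wrapper.
    InputStream in = manager.wrapForEncryption(new ByteArrayInputStream(buf.toByteArray()));
    System.out.println(new String(ByteStreams.toByteArray(in)));   // prints "spill bytes"
  }
}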

core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java

Lines changed: 62 additions & 27 deletions

@@ -19,6 +19,7 @@

 import java.io.*;
 import java.nio.ByteBuffer;
+import java.security.PrivilegedExceptionAction;
 import java.util.*;

 import scala.Option;
@@ -40,9 +41,11 @@
 import org.mockito.stubbing.Answer;

 import org.apache.spark.HashPartitioner;
+import org.apache.spark.SecurityManager;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
 import org.apache.spark.TaskContext;
+import org.apache.spark.deploy.SparkHadoopUtil;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.executor.TaskMetrics;
 import org.apache.spark.io.CompressionCodec$;
@@ -53,6 +56,7 @@
 import org.apache.spark.memory.TestMemoryManager;
 import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.security.CryptoStreamUtils;
 import org.apache.spark.serializer.*;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.storage.*;
@@ -77,7 +81,6 @@ public class UnsafeShuffleWriterSuite {
   final LinkedList<File> spillFilesCreated = new LinkedList<>();
   SparkConf conf;
   final Serializer serializer = new KryoSerializer(new SparkConf());
-  final SerializerManager serializerManager = new SerializerManager(serializer, new SparkConf());
   TaskMetrics taskMetrics;

   @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
@@ -86,17 +89,6 @@ public class UnsafeShuffleWriterSuite {
   @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
   @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;

-  private final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> {
-    @Override
-    public OutputStream apply(OutputStream stream) {
-      if (conf.getBoolean("spark.shuffle.compress", true)) {
-        return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream);
-      } else {
-        return stream;
-      }
-    }
-  }
-
   @After
   public void tearDown() {
     Utils.deleteRecursively(tempDir);
@@ -121,6 +113,11 @@ public void setUp() throws IOException {
     memoryManager = new TestMemoryManager(conf);
     taskMemoryManager = new TaskMemoryManager(memoryManager, 0);

+    // Some tests will override this manager because they change the configuration. This is a
+    // default for tests that don't need a specific one.
+    SerializerManager manager = new SerializerManager(serializer, conf);
+    when(blockManager.serializerManager()).thenReturn(manager);
+
     when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
     when(blockManager.getDiskWriter(
       any(BlockId.class),
@@ -131,12 +128,11 @@ public void setUp() throws IOException {
       @Override
       public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
         Object[] args = invocationOnMock.getArguments();
-
         return new DiskBlockObjectWriter(
           (File) args[1],
+          blockManager.serializerManager(),
           (SerializerInstance) args[2],
           (Integer) args[3],
-          new WrapStream(),
           false,
           (ShuffleWriteMetrics) args[4],
           (BlockId) args[0]
@@ -201,9 +197,10 @@ private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
     for (int i = 0; i < NUM_PARTITITONS; i++) {
       final long partitionSize = partitionSizesInMergedFile[i];
       if (partitionSize > 0) {
-        InputStream in = new FileInputStream(mergedOutputFile);
-        ByteStreams.skipFully(in, startOffset);
-        in = new LimitedInputStream(in, partitionSize);
+        FileInputStream fin = new FileInputStream(mergedOutputFile);
+        fin.getChannel().position(startOffset);
+        InputStream in = new LimitedInputStream(fin, partitionSize);
+        in = blockManager.serializerManager().wrapForEncryption(in);
         if (conf.getBoolean("spark.shuffle.compress", true)) {
           in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in);
         }
@@ -294,14 +291,32 @@ public void writeWithoutSpilling() throws Exception {
   }

   private void testMergingSpills(
-      boolean transferToEnabled,
-      String compressionCodecName) throws IOException {
+      final boolean transferToEnabled,
+      String compressionCodecName,
+      boolean encrypt) throws Exception {
     if (compressionCodecName != null) {
       conf.set("spark.shuffle.compress", "true");
       conf.set("spark.io.compression.codec", compressionCodecName);
     } else {
       conf.set("spark.shuffle.compress", "false");
     }
+    conf.set(org.apache.spark.internal.config.package$.MODULE$.IO_ENCRYPTION_ENABLED(), encrypt);
+
+    SerializerManager manager;
+    if (encrypt) {
+      manager = new SerializerManager(serializer, conf,
+        Option.apply(CryptoStreamUtils.createKey(conf)));
+    } else {
+      manager = new SerializerManager(serializer, conf);
+    }
+
+    when(blockManager.serializerManager()).thenReturn(manager);
+    testMergingSpills(transferToEnabled, encrypt);
+  }
+
+  private void testMergingSpills(
+      boolean transferToEnabled,
+      boolean encrypted) throws IOException {
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
     final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
     for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
@@ -324,6 +339,7 @@ private void testMergingSpills(
     for (long size: partitionSizesInMergedFile) {
       sumOfPartitionSizes += size;
     }
+
     assertEquals(sumOfPartitionSizes, mergedOutputFile.length());

     assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile()));
@@ -338,42 +354,60 @@ private void testMergingSpills(

   @Test
   public void mergeSpillsWithTransferToAndLZF() throws Exception {
-    testMergingSpills(true, LZFCompressionCodec.class.getName());
+    testMergingSpills(true, LZFCompressionCodec.class.getName(), false);
   }

   @Test
   public void mergeSpillsWithFileStreamAndLZF() throws Exception {
-    testMergingSpills(false, LZFCompressionCodec.class.getName());
+    testMergingSpills(false, LZFCompressionCodec.class.getName(), false);
   }

   @Test
   public void mergeSpillsWithTransferToAndLZ4() throws Exception {
-    testMergingSpills(true, LZ4CompressionCodec.class.getName());
+    testMergingSpills(true, LZ4CompressionCodec.class.getName(), false);
   }

   @Test
   public void mergeSpillsWithFileStreamAndLZ4() throws Exception {
-    testMergingSpills(false, LZ4CompressionCodec.class.getName());
+    testMergingSpills(false, LZ4CompressionCodec.class.getName(), false);
   }

   @Test
   public void mergeSpillsWithTransferToAndSnappy() throws Exception {
-    testMergingSpills(true, SnappyCompressionCodec.class.getName());
+    testMergingSpills(true, SnappyCompressionCodec.class.getName(), false);
   }

   @Test
   public void mergeSpillsWithFileStreamAndSnappy() throws Exception {
-    testMergingSpills(false, SnappyCompressionCodec.class.getName());
+    testMergingSpills(false, SnappyCompressionCodec.class.getName(), false);
   }

   @Test
   public void mergeSpillsWithTransferToAndNoCompression() throws Exception {
-    testMergingSpills(true, null);
+    testMergingSpills(true, null, false);
   }

   @Test
   public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
-    testMergingSpills(false, null);
+    testMergingSpills(false, null, false);
+  }
+
+  @Test
+  public void mergeSpillsWithCompressionAndEncryption() throws Exception {
+    // This should actually be translated to a "file stream merge" internally, just have the
+    // test to make sure that it's the case.
+    testMergingSpills(true, LZ4CompressionCodec.class.getName(), true);
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndCompressionAndEncryption() throws Exception {
+    testMergingSpills(false, LZ4CompressionCodec.class.getName(), true);
+  }
+
+  @Test
+  public void mergeSpillsWithCompressionAndEncryptionSlowPath() throws Exception {
+    conf.set("spark.shuffle.unsafe.fastMergeEnabled", "false");
+    testMergingSpills(false, LZ4CompressionCodec.class.getName(), true);
   }

   @Test
@@ -531,4 +565,5 @@ public void testPeakMemoryUsed() throws Exception {
       writer.stop(false);
     }
   }
+
 }
