
Commit 4b4e329

Ferdinand Xu authored and Marcelo Vanzin committed

[SPARK-5682][CORE] Add encrypted shuffle in spark

This patch uses the Apache Commons Crypto library to enable shuffle encryption support.

Author: Ferdinand Xu <cheng.a.xu@intel.com>
Author: kellyzly <kellyzly@126.com>

Closes #8880 from winningsix/SPARK-10771.
1 parent 2720925 commit 4b4e329

File tree

27 files changed: +478 −28 lines


core/pom.xml

Lines changed: 4 additions & 0 deletions
@@ -327,6 +327,10 @@
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-tags_${scala.binary.version}</artifactId>
     </dependency>
+    <dependency>
+      <groupId>org.apache.commons</groupId>
+      <artifactId>commons-crypto</artifactId>
+    </dependency>
   </dependencies>
   <build>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java

Lines changed: 1 addition & 1 deletion
@@ -72,7 +72,7 @@ public UnsafeSorterSpillReader(
     final BufferedInputStream bs =
         new BufferedInputStream(new FileInputStream(file), (int) bufferSizeBytes);
     try {
-      this.in = serializerManager.wrapForCompression(blockId, bs);
+      this.in = serializerManager.wrapStream(blockId, bs);
       this.din = new DataInputStream(this.in);
       numRecords = numRecordsRemaining = din.readInt();
     } catch (IOException e) {

core/src/main/scala/org/apache/spark/SecurityManager.scala

Lines changed: 20 additions & 0 deletions
@@ -21,15 +21,19 @@ import java.lang.{Byte => JByte}
 import java.net.{Authenticator, PasswordAuthentication}
 import java.security.{KeyStore, SecureRandom}
 import java.security.cert.X509Certificate
+import javax.crypto.KeyGenerator
 import javax.net.ssl._

 import com.google.common.hash.HashCodes
 import com.google.common.io.Files
 import org.apache.hadoop.io.Text
+import org.apache.hadoop.security.Credentials

 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
 import org.apache.spark.network.sasl.SecretKeyHolder
+import org.apache.spark.security.CryptoStreamUtils._
 import org.apache.spark.util.Utils

 /**
@@ -554,4 +558,20 @@ private[spark] object SecurityManager {

   // key used to store the spark secret in the Hadoop UGI
   val SECRET_LOOKUP_KEY = "sparkCookie"
+
+  /**
+   * Setup the cryptographic key used by IO encryption in credentials. The key is generated using
+   * [[KeyGenerator]]. The algorithm and key length is specified by the [[SparkConf]].
+   */
+  def initIOEncryptionKey(conf: SparkConf, credentials: Credentials): Unit = {
+    if (credentials.getSecretKey(SPARK_IO_TOKEN) == null) {
+      val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS)
+      val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM)
+      val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm)
+      keyGen.init(keyLen)
+
+      val ioKey = keyGen.generateKey()
+      credentials.addSecretKey(SPARK_IO_TOKEN, ioKey.getEncoded)
+    }
+  }
 }
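
For reference, a minimal standalone sketch of the same key-generation flow, assuming the default algorithm (HmacSHA1) and key size (128 bits) defined below in config/package.scala; in Spark the resulting bytes are stored in the Hadoop Credentials under SPARK_IO_TOKEN:

```scala
import javax.crypto.KeyGenerator

// Sketch of the key material produced by initIOEncryptionKey with default settings.
val keyGen = KeyGenerator.getInstance("HmacSHA1")
keyGen.init(128)
val ioKeyBytes: Array[Byte] = keyGen.generateKey().getEncoded
```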

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 5 additions & 0 deletions
@@ -49,6 +49,7 @@ import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
 import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat,
   WholeTextFileInputFormat}
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
 import org.apache.spark.rdd._
@@ -411,6 +412,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
     }

     if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true")
+    if (_conf.get(IO_ENCRYPTION_ENABLED) && !SparkHadoopUtil.get.isYarnMode()) {
+      throw new SparkException("IO encryption is only supported in YARN mode, please disable it " +
+        s"by setting ${IO_ENCRYPTION_ENABLED.key} to false")
+    }

     // "_jobProgressListener" should be set up before creating SparkEnv because when creating
     // "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them.

core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 20 additions & 0 deletions
@@ -119,4 +119,24 @@ package object config {
   private[spark] val UI_RETAINED_TASKS = ConfigBuilder("spark.ui.retainedTasks")
     .intConf
     .createWithDefault(100000)
+
+  private[spark] val IO_ENCRYPTION_ENABLED = ConfigBuilder("spark.io.encryption.enabled")
+    .booleanConf
+    .createWithDefault(false)
+
+  private[spark] val IO_ENCRYPTION_KEYGEN_ALGORITHM =
+    ConfigBuilder("spark.io.encryption.keygen.algorithm")
+      .stringConf
+      .createWithDefault("HmacSHA1")
+
+  private[spark] val IO_ENCRYPTION_KEY_SIZE_BITS = ConfigBuilder("spark.io.encryption.keySizeBits")
+    .intConf
+    .checkValues(Set(128, 192, 256))
+    .createWithDefault(128)
+
+  private[spark] val IO_CRYPTO_CIPHER_TRANSFORMATION =
+    ConfigBuilder("spark.io.crypto.cipher.transformation")
+      .internal()
+      .stringConf
+      .createWithDefaultString("AES/CTR/NoPadding")
 }
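
Taken together, these keys let an application opt in to IO encryption. A minimal sketch of the application-side configuration (the app name is a placeholder; per the SparkContext check above, the setting is only accepted when running on YARN):

```scala
import org.apache.spark.SparkConf

// Sketch: enable IO encryption; the other two settings show the defaults and are optional.
val conf = new SparkConf()
  .setAppName("encrypted-shuffle-demo")            // placeholder app name
  .set("spark.io.encryption.enabled", "true")
  .set("spark.io.encryption.keygen.algorithm", "HmacSHA1")
  .set("spark.io.encryption.keySizeBits", "128")   // must be 128, 192 or 256
```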

core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.security
+
+import java.io.{InputStream, OutputStream}
+import java.util.Properties
+import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}
+
+import org.apache.commons.crypto.random._
+import org.apache.commons.crypto.stream._
+import org.apache.hadoop.io.Text
+
+import org.apache.spark.SparkConf
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
+
+/**
+ * A util class for manipulating IO encryption and decryption streams.
+ */
+private[spark] object CryptoStreamUtils extends Logging {
+  /**
+   * Constants and variables for spark IO encryption
+   */
+  val SPARK_IO_TOKEN = new Text("SPARK_IO_TOKEN")
+
+  // The initialization vector length in bytes.
+  val IV_LENGTH_IN_BYTES = 16
+  // The prefix of IO encryption related configurations in Spark configuration.
+  val SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX = "spark.io.encryption.commons.config."
+  // The prefix for the configurations passing to Apache Commons Crypto library.
+  val COMMONS_CRYPTO_CONF_PREFIX = "commons.crypto."
+
+  /**
+   * Helper method to wrap [[OutputStream]] with [[CryptoOutputStream]] for encryption.
+   */
+  def createCryptoOutputStream(
+      os: OutputStream,
+      sparkConf: SparkConf): OutputStream = {
+    val properties = toCryptoConf(sparkConf)
+    val iv = createInitializationVector(properties)
+    os.write(iv)
+    val credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
+    val key = credentials.getSecretKey(SPARK_IO_TOKEN)
+    val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
+    new CryptoOutputStream(transformationStr, properties, os,
+      new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
+  }
+
+  /**
+   * Helper method to wrap [[InputStream]] with [[CryptoInputStream]] for decryption.
+   */
+  def createCryptoInputStream(
+      is: InputStream,
+      sparkConf: SparkConf): InputStream = {
+    val properties = toCryptoConf(sparkConf)
+    val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
+    is.read(iv, 0, iv.length)
+    val credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
+    val key = credentials.getSecretKey(SPARK_IO_TOKEN)
+    val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
+    new CryptoInputStream(transformationStr, properties, is,
+      new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
+  }
+
+  /**
+   * Get Commons-crypto configurations from Spark configurations identified by prefix.
+   */
+  def toCryptoConf(conf: SparkConf): Properties = {
+    val props = new Properties()
+    conf.getAll.foreach { case (k, v) =>
+      if (k.startsWith(SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX)) {
+        props.put(COMMONS_CRYPTO_CONF_PREFIX + k.substring(
+          SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX.length()), v)
+      }
+    }
+    props
+  }
+
+  /**
+   * This method to generate an IV (Initialization Vector) using secure random.
+   */
+  private[this] def createInitializationVector(properties: Properties): Array[Byte] = {
+    val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
+    val initialIVStart = System.currentTimeMillis()
+    CryptoRandomFactory.getCryptoRandom(properties).nextBytes(iv)
+    val initialIVFinish = System.currentTimeMillis()
+    val initialIVTime = initialIVFinish - initialIVStart
+    if (initialIVTime > 2000) {
+      logWarning(s"It costs ${initialIVTime} milliseconds to create the Initialization Vector " +
+        s"used by CryptoStream")
+    }
+    iv
+  }
+}
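
To see the stream format these helpers produce, here is a self-contained round-trip using the same Commons Crypto calls but without Spark's credentials plumbing. The hard-coded all-zero 128-bit key and the empty Properties are illustrative assumptions only; Spark obtains the key from the Hadoop Credentials and the properties from toCryptoConf:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.charset.StandardCharsets
import java.util.Properties
import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}

import org.apache.commons.crypto.random.CryptoRandomFactory
import org.apache.commons.crypto.stream.{CryptoInputStream, CryptoOutputStream}

val transformation = "AES/CTR/NoPadding"
val props = new Properties()               // empty: Commons Crypto defaults
val key = new Array[Byte](16)              // illustrative all-zero 128-bit key
val iv = new Array[Byte](16)
CryptoRandomFactory.getCryptoRandom(props).nextBytes(iv)

// Encrypt: like createCryptoOutputStream, the IV is written in the clear
// before the CryptoOutputStream so the reader can recover it.
val sink = new ByteArrayOutputStream()
sink.write(iv)
val out = new CryptoOutputStream(transformation, props, sink,
  new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
out.write("hello shuffle".getBytes(StandardCharsets.UTF_8))
out.close()

// Decrypt: like createCryptoInputStream, read the 16-byte IV prefix first.
val source = new ByteArrayInputStream(sink.toByteArray)
val ivIn = new Array[Byte](16)
source.read(ivIn, 0, ivIn.length)
val in = new CryptoInputStream(transformation, props, source,
  new SecretKeySpec(key, "AES"), new IvParameterSpec(ivIn))
val buf = new Array[Byte](64)
val n = in.read(buf)
in.close()
println(new String(buf, 0, n, StandardCharsets.UTF_8))   // prints "hello shuffle"
```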

core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala

Lines changed: 40 additions & 7 deletions
@@ -23,13 +23,15 @@ import java.nio.ByteBuffer
 import scala.reflect.ClassTag

 import org.apache.spark.SparkConf
+import org.apache.spark.internal.config._
 import org.apache.spark.io.CompressionCodec
+import org.apache.spark.security.CryptoStreamUtils
 import org.apache.spark.storage._
 import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

 /**
- * Component which configures serialization and compression for various Spark components, including
- * automatic selection of which [[Serializer]] to use for shuffles.
+ * Component which configures serialization, compression and encryption for various Spark
+ * components, including automatic selection of which [[Serializer]] to use for shuffles.
  */
 private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) {

@@ -61,6 +63,9 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
   // Whether to compress shuffle output temporarily spilled to disk
   private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true)

+  // Whether to enable IO encryption
+  private[this] val enableIOEncryption = conf.get(IO_ENCRYPTION_ENABLED)
+
   /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
    * the initialization of the compression codec until it is first used. The reason is that a Spark
    * program could be using a user-defined codec in a third party jar, which is loaded in
@@ -102,17 +107,45 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
     }
   }

+  /**
+   * Wrap an input stream for encryption and compression
+   */
+  def wrapStream(blockId: BlockId, s: InputStream): InputStream = {
+    wrapForCompression(blockId, wrapForEncryption(s))
+  }
+
+  /**
+   * Wrap an output stream for encryption and compression
+   */
+  def wrapStream(blockId: BlockId, s: OutputStream): OutputStream = {
+    wrapForCompression(blockId, wrapForEncryption(s))
+  }
+
+  /**
+   * Wrap an input stream for encryption if shuffle encryption is enabled
+   */
+  private[this] def wrapForEncryption(s: InputStream): InputStream = {
+    if (enableIOEncryption) CryptoStreamUtils.createCryptoInputStream(s, conf) else s
+  }
+
+  /**
+   * Wrap an output stream for encryption if shuffle encryption is enabled
+   */
+  private[this] def wrapForEncryption(s: OutputStream): OutputStream = {
+    if (enableIOEncryption) CryptoStreamUtils.createCryptoOutputStream(s, conf) else s
+  }
+
   /**
    * Wrap an output stream for compression if block compression is enabled for its block type
    */
-  def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
+  private[this] def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
     if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s
   }

   /**
    * Wrap an input stream for compression if block compression is enabled for its block type
    */
-  def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
+  private[this] def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
     if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
   }

@@ -123,7 +156,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
       values: Iterator[T]): Unit = {
     val byteStream = new BufferedOutputStream(outputStream)
     val ser = getSerializer(implicitly[ClassTag[T]]).newInstance()
-    ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+    ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close()
   }

   /** Serializes into a chunked byte buffer. */
@@ -139,7 +172,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
     val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate)
     val byteStream = new BufferedOutputStream(bbos)
     val ser = getSerializer(classTag).newInstance()
-    ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+    ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close()
     bbos.toChunkedByteBuffer
   }

@@ -153,7 +186,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
     val stream = new BufferedInputStream(inputStream)
     getSerializer(implicitly[ClassTag[T]])
       .newInstance()
-      .deserializeStream(wrapForCompression(blockId, stream))
+      .deserializeStream(wrapStream(blockId, stream))
       .asIterator.asInstanceOf[Iterator[T]]
   }
 }
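
The composition order matters: for output streams, wrapStream applies wrapForEncryption first and wrapForCompression on top of it, so serialized bytes are compressed before they are encrypted on their way to disk (encrypting first would make the data effectively incompressible). A tiny conceptual sketch with stand-in wrappers, identity functions here rather than Spark's real codec and crypto streams:

```scala
import java.io.{ByteArrayOutputStream, OutputStream}

// Stand-ins only, to show the nesting that wrapStream builds for writes.
def wrapForEncryption(s: OutputStream): OutputStream = s   // CryptoStreamUtils.createCryptoOutputStream in Spark
def wrapForCompression(s: OutputStream): OutputStream = s  // compressionCodec.compressedOutputStream in Spark

val fileSink = new ByteArrayOutputStream()                 // stands in for the shuffle file
// Data written here flows: serializer -> compression -> encryption -> file.
val wrapped: OutputStream = wrapForCompression(wrapForEncryption(fileSink))
wrapped.write("serialized record bytes".getBytes("UTF-8"))
wrapped.close()
```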

core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala

Lines changed: 2 additions & 2 deletions
@@ -51,9 +51,9 @@ private[spark] class BlockStoreShuffleReader[K, C](
       SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
       SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))

-    // Wrap the streams for compression based on configuration
+    // Wrap the streams for compression and encryption based on configuration
     val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
-      serializerManager.wrapForCompression(blockId, inputStream)
+      serializerManager.wrapStream(blockId, inputStream)
     }

     val serializerInstance = dep.serializer.newInstance()

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 2 additions & 3 deletions
@@ -721,10 +721,9 @@ private[spark] class BlockManager(
       serializerInstance: SerializerInstance,
       bufferSize: Int,
       writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = {
-    val compressStream: OutputStream => OutputStream =
-      serializerManager.wrapForCompression(blockId, _)
+    val wrapStream: OutputStream => OutputStream = serializerManager.wrapStream(blockId, _)
     val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
-    new DiskBlockObjectWriter(file, serializerInstance, bufferSize, compressStream,
+    new DiskBlockObjectWriter(file, serializerInstance, bufferSize, wrapStream,
       syncWrites, writeMetrics, blockId)
   }

core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala

Lines changed: 3 additions & 2 deletions
@@ -39,7 +39,7 @@ private[spark] class DiskBlockObjectWriter(
     val file: File,
     serializerInstance: SerializerInstance,
     bufferSize: Int,
-    compressStream: OutputStream => OutputStream,
+    wrapStream: OutputStream => OutputStream,
     syncWrites: Boolean,
     // These write metrics concurrently shared with other active DiskBlockObjectWriters who
     // are themselves performing writes. All updates must be relative.
@@ -115,7 +115,8 @@ private[spark] class DiskBlockObjectWriter(
       initialize()
       initialized = true
     }
-    bs = compressStream(mcs)
+
+    bs = wrapStream(mcs)
     objOut = serializerInstance.serializeStream(bs)
     streamOpen = true
     this
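
After this rename, DiskBlockObjectWriter only sees an opaque OutputStream => OutputStream function and needs no knowledge of compression or encryption. A hedged sketch of function values matching the new wrapStream parameter type; the identity and GZIP wrappers below are illustrative, not part of this patch:

```scala
import java.io.OutputStream
import java.util.zip.GZIPOutputStream

// Two example wrapping functions with the `OutputStream => OutputStream` shape.
val identityWrap: OutputStream => OutputStream = s => s
val gzipWrap: OutputStream => OutputStream = s => new GZIPOutputStream(s)

// In this commit, BlockManager passes serializerManager.wrapStream(blockId, _),
// which layers compression and (when enabled) encryption in one function value.
```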
