diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 94142d33369c7..77c6841a144ab 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -67,6 +67,9 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) @transient private var compressionCodec: Option[CompressionCodec] = _ /** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */ @transient private var blockSize: Int = _ + /** Max size of block that can be embedded, Default value is 8KB. + * This value is only read by the broadcaster. */ + @transient private var embedSizeLimit : Int = _ private def setConf(conf: SparkConf) { compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) { @@ -75,6 +78,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) None } blockSize = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 + embedSizeLimit = conf.getInt("spark.broadcast.embedSizeLimit", 8) * 1024 } setConf(SparkEnv.get.conf) @@ -83,6 +87,12 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) /** Total number of blocks this broadcast variable contains. */ private val numBlocks: Int = writeBlocks(obj) + /** + * Embed the serialized object into Broadcast to reduce the overhead of RPC when the object + * is small enough. + */ + private var embeddedBlock: Array[Byte] = _ + override protected def getValue() = { _value } @@ -99,12 +109,19 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) tellMaster = false) val blocks = TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec) - blocks.zipWithIndex.foreach { case (block, i) => - SparkEnv.get.blockManager.putBytes( - BroadcastBlockId(id, "piece" + i), - block, - StorageLevel.MEMORY_AND_DISK_SER, - tellMaster = true) + assert(blocks.length > 0, "should have at least one block") + + if (blocks.size == 1 && blocks(0).limit < embedSizeLimit) { + // embed small object inside Broadcast to avoid RPC + embeddedBlock = blocks(0).array() + } else { + blocks.zipWithIndex.foreach { case (block, i) => + SparkEnv.get.blockManager.putBytes( + BroadcastBlockId(id, "piece" + i), + block, + StorageLevel.MEMORY_AND_DISK_SER, + tellMaster = true) + } } blocks.length } @@ -114,6 +131,12 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported // to the driver, so other executors can pull these chunks from this executor as well. val blocks = new Array[ByteBuffer](numBlocks) + if (embeddedBlock != null) { + // get blocks from embedded one + blocks(0) = ByteBuffer.wrap(embeddedBlock) + embeddedBlock = null // release + return blocks + } val bm = SparkEnv.get.blockManager for (pid <- Random.shuffle(Seq.range(0, numBlocks))) { @@ -184,12 +207,12 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } } } - } private object TorrentBroadcast extends Logging { + def blockifyObject[T: ClassTag]( obj: T, blockSize: Int, @@ -207,7 +230,9 @@ private object TorrentBroadcast extends Logging { blocks: Array[ByteBuffer], serializer: Serializer, compressionCodec: Option[CompressionCodec]): T = { - require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks") + if (blocks.isEmpty) { + throw new IOException("Cannot unblockify an empty array of blocks") + } val is = new SequenceInputStream( asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block)))) val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is) diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index b0a70f012f1f3..431087e753f8e 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -303,7 +303,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { new SparkContext("local", "test", broadcastConf) } val blockManagerMaster = sc.env.blockManager.master - val list = List[Int](1, 2, 3, 4) + val list = (1 to 4096).toList // Create broadcast variable val broadcast = sc.broadcast(list)