diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
index 8f8a0b11f9f2e..d0f6209db3257 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
@@ -30,7 +30,7 @@ private[spark] class BroadcastManager(
   extends Logging {
 
   private var initialized = false
-  private var broadcastFactory: BroadcastFactory = null
+  var broadcastFactory: BroadcastFactory = null
 
   initialize()
 
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 942dc7d7eac87..c97de0a0e2179 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -19,9 +19,11 @@ package org.apache.spark.broadcast
 
 import java.io.{File, FileOutputStream, ObjectInputStream, ObjectOutputStream, OutputStream}
 import java.io.{BufferedInputStream, BufferedOutputStream}
+import java.io.IOException
 import java.net.{URL, URLConnection, URI}
 import java.util.concurrent.TimeUnit
 
+import scala.collection.mutable.HashMap
 import scala.reflect.ClassTag
 
 import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv}
@@ -37,7 +39,7 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH
  * executor to speed up future accesses.
  */
 private[spark] class HttpBroadcast[T: ClassTag](
-    @transient var value_ : T, isLocal: Boolean, id: Long)
+    @transient var value_ : T, isLocal: Boolean, id: Long, key: Int = 0)
   extends Broadcast[T](id) with Logging with Serializable {
 
   override protected def getValue() = value_
@@ -54,21 +56,21 @@ private[spark] class HttpBroadcast[T: ClassTag](
   }
 
   if (!isLocal) {
-    HttpBroadcast.write(id, value_)
+    HttpBroadcast.write(id, value_, key)
   }
 
   /**
    * Remove all persisted state associated with this HTTP broadcast on the executors.
    */
   override protected def doUnpersist(blocking: Boolean) {
-    HttpBroadcast.unpersist(id, removeFromDriver = false, blocking)
+    HttpBroadcast.unpersist(id, removeFromDriver = false, blocking, key)
   }
 
   /**
    * Remove all persisted state associated with this HTTP broadcast on the executors and driver.
    */
   override protected def doDestroy(blocking: Boolean) {
-    HttpBroadcast.unpersist(id, removeFromDriver = true, blocking)
+    HttpBroadcast.unpersist(id, removeFromDriver = true, blocking, key)
   }
 
   /** Used by the JVM when serializing this object. */
@@ -86,7 +88,7 @@ private[spark] class HttpBroadcast[T: ClassTag](
       case None => {
         logInfo("Started reading broadcast variable " + id)
         val start = System.nanoTime
-        value_ = HttpBroadcast.read[T](id)
+        value_ = HttpBroadcast.read[T](id, key)
         /*
          * We cache broadcast data in the BlockManager so that subsequent tasks using it
          * do not need to re-fetch. This data is only used locally and no other node
@@ -103,109 +105,53 @@ private[spark] class HttpBroadcast[T: ClassTag](
 }
 
 private[broadcast] object HttpBroadcast extends Logging {
-  private var initialized = false
-  private var broadcastDir: File = null
-  private var compress: Boolean = false
-  private var bufferSize: Int = 65536
-  private var serverUri: String = null
-  private var server: HttpServer = null
-  private var securityManager: SecurityManager = null
-
-  // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist
-  private val files = new TimeStampedHashSet[File]
-  private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt
-  private var compressionCodec: CompressionCodec = null
-  private var cleaner: MetadataCleaner = null
-
-  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
+
+  val activeHBC = new HashMap[Int, HttpBroadcastContainer]
+
+  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager, key: Int = 0) {
     synchronized {
-      if (!initialized) {
-        bufferSize = conf.getInt("spark.buffer.size", 65536)
-        compress = conf.getBoolean("spark.broadcast.compress", true)
-        securityManager = securityMgr
-        if (isDriver) {
-          createServer(conf)
-          conf.set("spark.httpBroadcast.uri", serverUri)
-        }
-        serverUri = conf.get("spark.httpBroadcast.uri")
-        cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf)
-        compressionCodec = CompressionCodec.createCodec(conf)
-        initialized = true
+      if (!activeHBC.contains(key)) {
+        val hbc = new HttpBroadcastContainer()
+        hbc.initialize(isDriver, conf, securityMgr)
+        activeHBC.put(key, hbc)
       }
     }
   }
-
-  def stop() {
+
+  def stop(key: Int = 0) {
     synchronized {
-      if (server != null) {
-        server.stop()
-        server = null
-      }
-      if (cleaner != null) {
-        cleaner.cancel()
-        cleaner = null
+      if (activeHBC.contains(key) && activeHBC(key) != null) {
+        activeHBC(key).stop()
+        activeHBC(key) = null
+        activeHBC.remove(key)
+      } else {
+        logWarning("No HttpBroadcast container found for key " + key)
       }
-      compressionCodec = null
-      initialized = false
     }
   }
-
-  private def createServer(conf: SparkConf) {
-    broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf))
-    val broadcastPort = conf.getInt("spark.broadcast.port", 0)
-    server = new HttpServer(broadcastDir, securityManager, broadcastPort, "HTTP broadcast server")
-    server.start()
-    serverUri = server.uri
-    logInfo("Broadcast server started at " + serverUri)
-  }
-
-  def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name)
-
-  private def write(id: Long, value: Any) {
-    val file = getFile(id)
-    val out: OutputStream = {
-      if (compress) {
-        compressionCodec.compressedOutputStream(new FileOutputStream(file))
-      } else {
-        new BufferedOutputStream(new FileOutputStream(file), bufferSize)
-      }
+
+  def getFile(id: Long, key: Int = 0): File = {
+    if (activeHBC.contains(key) && activeHBC(key) != null) {
+      activeHBC(key).getFile(id)
+    } else {
+      throw new IOException("No HttpBroadcast container found for key " + key)
     }
-    val ser = SparkEnv.get.serializer.newInstance()
-    val serOut = ser.serializeStream(out)
-    serOut.writeObject(value)
-    serOut.close()
-    files += file
   }
 
-  private def read[T: ClassTag](id: Long): T = {
-    logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id)
-    val url = serverUri + "/" + BroadcastBlockId(id).name
-
-    var uc: URLConnection = null
-    if (securityManager.isAuthenticationEnabled()) {
-      logDebug("broadcast security enabled")
-      val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager)
-      uc = newuri.toURL.openConnection()
-      uc.setAllowUserInteraction(false)
+  private def write(id: Long, value: Any, key: Int = 0) {
+    if (activeHBC.contains(key) && activeHBC(key) != null) {
+      activeHBC(key).write(id, value)
     } else {
-      logDebug("broadcast not using security")
-      uc = new URL(url).openConnection()
+      logWarning("No HttpBroadcast container found for key " + key)
    }
+  }
 
-    val in = {
-      uc.setReadTimeout(httpReadTimeout)
-      val inputStream = uc.getInputStream
-      if (compress) {
-        compressionCodec.compressedInputStream(inputStream)
-      } else {
-        new BufferedInputStream(inputStream, bufferSize)
-      }
+  private def read[T: ClassTag](id: Long, key: Int = 0): T = {
+    if (activeHBC.contains(key) && activeHBC(key) != null) {
+      activeHBC(key).read[T](id)
+    } else {
+      throw new IOException("No HttpBroadcast container found for key " + key)
    }
-    val ser = SparkEnv.get.serializer.newInstance()
-    val serIn = ser.deserializeStream(in)
-    val obj = serIn.readObject[T]()
-    serIn.close()
-    obj
   }
 
   /**
@@ -213,43 +159,11 @@ private[broadcast] object HttpBroadcast extends Logging {
    * If removeFromDriver is true, also remove these persisted blocks on the driver
    * and delete the associated broadcast file.
    */
-  def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized {
-    SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
-    if (removeFromDriver) {
-      val file = getFile(id)
-      files.remove(file)
-      deleteBroadcastFile(file)
-    }
-  }
-
-  /**
-   * Periodically clean up old broadcasts by removing the associated map entries and
-   * deleting the associated files.
-   */
-  private def cleanup(cleanupTime: Long) {
-    val iterator = files.internalMap.entrySet().iterator()
-    while(iterator.hasNext) {
-      val entry = iterator.next()
-      val (file, time) = (entry.getKey, entry.getValue)
-      if (time < cleanupTime) {
-        iterator.remove()
-        deleteBroadcastFile(file)
-      }
-    }
-  }
-
-  private def deleteBroadcastFile(file: File) {
-    try {
-      if (file.exists) {
-        if (file.delete()) {
-          logInfo("Deleted broadcast file: %s".format(file))
-        } else {
-          logWarning("Could not delete broadcast file: %s".format(file))
-        }
-      }
-    } catch {
-      case e: Exception =>
-        logError("Exception while deleting broadcast file: %s".format(file), e)
+  def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean, key: Int = 0) = synchronized {
+    if (activeHBC.contains(key) && activeHBC(key) != null) {
+      activeHBC(key).unpersist(id, removeFromDriver, blocking)
+    } else {
+      logWarning("No HttpBroadcast container found for key " + key)
     }
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastContainer.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastContainer.scala
new file mode 100644
index 0000000000000..e69e587eabd33
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastContainer.scala
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import java.io.{File, FileOutputStream, ObjectInputStream, ObjectOutputStream, OutputStream}
+import java.io.{BufferedInputStream, BufferedOutputStream}
+import java.net.{URL, URLConnection, URI}
+import java.util.concurrent.TimeUnit
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv}
+import org.apache.spark.io.CompressionCodec
+import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}
+
+
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{SecurityManager, SparkConf}
+
+class HttpBroadcastContainer extends Serializable with Logging {
+
+  private var initialized = false
+  private var broadcastDir: File = null
+  private var compress: Boolean = false
+  private var bufferSize: Int = 65536
+  private var serverUri: String = null
+  private var server: HttpServer = null
+  private var securityManager: SecurityManager = null
+
+  private val files = new TimeStampedHashSet[File]
+  private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt
+  private var compressionCodec: CompressionCodec = null
+  private var cleaner: MetadataCleaner = null
+
+  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
+    synchronized {
+      if (!initialized) {
+        bufferSize = conf.getInt("spark.buffer.size", 65536)
+        compress = conf.getBoolean("spark.broadcast.compress", true)
+        securityManager = securityMgr
+        if (isDriver) {
+          createServer(conf)
+          conf.set("spark.httpBroadcast.uri", serverUri)
+        }
+        serverUri = conf.get("spark.httpBroadcast.uri")
+        cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf)
+        compressionCodec = CompressionCodec.createCodec(conf)
+        initialized = true
+      }
+    }
+  }
+
+  def stop() {
+    synchronized {
+      if (server != null) {
+        server.stop()
+        server = null
+      }
+      if (cleaner != null) {
+        cleaner.cancel()
+        cleaner = null
+      }
+      compressionCodec = null
+      initialized = false
+    }
+  }
+
+  private def createServer(conf: SparkConf) {
+    broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf))
+    val broadcastPort = conf.getInt("spark.broadcast.port", 0)
+    server = new HttpServer(broadcastDir, securityManager, broadcastPort, "HTTP broadcast server")
+    server.start()
+    serverUri = server.uri
+    logInfo("Broadcast server started at " + serverUri)
+  }
+
+  def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name)
+
+  def write(id: Long, value: Any) {
+    val file = getFile(id)
+    val out: OutputStream = {
+      if (compress) {
+        compressionCodec.compressedOutputStream(new FileOutputStream(file))
+      } else {
+        new BufferedOutputStream(new FileOutputStream(file), bufferSize)
+      }
+    }
+    val ser = SparkEnv.get.serializer.newInstance()
+    val serOut = ser.serializeStream(out)
+    serOut.writeObject(value)
+    serOut.close()
+    files += file
+  }
+
+  def read[T: ClassTag](id: Long): T = {
+    logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id)
+    val url = serverUri + "/" + BroadcastBlockId(id).name
+
+    var uc: URLConnection = null
+    if (securityManager.isAuthenticationEnabled()) {
+      logDebug("broadcast security enabled")
+      val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager)
+      uc = newuri.toURL.openConnection()
+      uc.setAllowUserInteraction(false)
+    } else {
+      logDebug("broadcast not using security")
+      uc = new URL(url).openConnection()
+    }
+
+    val in = {
+      uc.setReadTimeout(httpReadTimeout)
+      val inputStream = uc.getInputStream
+      if (compress) {
+        compressionCodec.compressedInputStream(inputStream)
+      } else {
+        new BufferedInputStream(inputStream, bufferSize)
+      }
+    }
+    val ser = SparkEnv.get.serializer.newInstance()
+    val serIn = ser.deserializeStream(in)
+    val obj = serIn.readObject[T]()
+    serIn.close()
+    obj
+  }
+
+  /**
+   * Remove all persisted blocks associated with this HTTP broadcast on the executors.
+   * If removeFromDriver is true, also remove these persisted blocks on the driver
+   * and delete the associated broadcast file.
+   */
+  def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized {
+    SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
+    if (removeFromDriver) {
+      val file = getFile(id)
+      files.remove(file)
+      deleteBroadcastFile(file)
+    }
+  }
+
+  /**
+   * Periodically clean up old broadcasts by removing the associated map entries and
+   * deleting the associated files.
+   */
+  private def cleanup(cleanupTime: Long) {
+    val iterator = files.internalMap.entrySet().iterator()
+    while(iterator.hasNext) {
+      val entry = iterator.next()
+      val (file, time) = (entry.getKey, entry.getValue)
+      if (time < cleanupTime) {
+        iterator.remove()
+        deleteBroadcastFile(file)
+      }
+    }
+  }
+
+  private def deleteBroadcastFile(file: File) {
+    try {
+      if (file.exists) {
+        if (file.delete()) {
+          logInfo("Deleted broadcast file: %s".format(file))
+        } else {
+          logWarning("Could not delete broadcast file: %s".format(file))
+        }
+      }
+    } catch {
+      case e: Exception =>
+        logError("Exception while deleting broadcast file: %s".format(file), e)
+    }
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
index c7ef02d572a19..c2d08322ea676 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
@@ -27,14 +27,17 @@ import org.apache.spark.{SecurityManager, SparkConf}
 * [[org.apache.spark.broadcast.HttpBroadcast]] for more details about this mechanism.
  */
 class HttpBroadcastFactory extends BroadcastFactory {
+
+  private val thisKey = this.hashCode()
+
   override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
-    HttpBroadcast.initialize(isDriver, conf, securityMgr)
+    HttpBroadcast.initialize(isDriver, conf, securityMgr, thisKey)
   }
 
   override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) =
-    new HttpBroadcast[T](value_, isLocal, id)
+    new HttpBroadcast[T](value_, isLocal, id, thisKey)
 
-  override def stop() { HttpBroadcast.stop() }
+  override def stop() { HttpBroadcast.stop(thisKey) }
 
   /**
    * Remove all persisted state associated with the HTTP broadcast with the given ID.
@@ -42,6 +45,6 @@ class HttpBroadcastFactory extends BroadcastFactory {
    * @param blocking Whether to block until unbroadcasted
    */
   override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
-    HttpBroadcast.unpersist(id, removeFromDriver, blocking)
+    HttpBroadcast.unpersist(id, removeFromDriver, blocking, thisKey)
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index 978a6ded80829..7c3d0208b195a 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -17,11 +17,9 @@
 
 package org.apache.spark.broadcast
 
-import org.scalatest.FunSuite
-
+import org.apache.spark.storage.{BroadcastBlockId, _}
 import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
-import org.apache.spark.storage._
-
+import org.scalatest.FunSuite
 
 class BroadcastSuite extends FunSuite with LocalSparkContext {
 
@@ -46,10 +44,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
 
   test("Accessing HttpBroadcast variables in a local cluster") {
     val numSlaves = 4
-    val conf = httpConf.clone
-    conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-    conf.set("spark.broadcast.compress", "true")
-    sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf)
+    sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", httpConf)
     val list = List[Int](1, 2, 3, 4)
     val broadcast = sc.broadcast(list)
     val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))
@@ -74,10 +69,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
 
   test("Accessing TorrentBroadcast variables in a local cluster") {
     val numSlaves = 4
-    val conf = torrentConf.clone
-    conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-    conf.set("spark.broadcast.compress", "true")
-    sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf)
+    sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", torrentConf)
     val list = List[Int](1, 2, 3, 4)
     val broadcast = sc.broadcast(list)
     val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))
@@ -126,10 +118,12 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
   private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
     val numSlaves = if (distributed) 2 else 0
 
+    def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id))
+
     // Verify that the broadcast file is created, and blocks are persisted only on the driver
-    def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) {
-      val blockId = BroadcastBlockId(broadcastId)
-      val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+    def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+      assert(blockIds.size === 1)
+      val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
       assert(statuses.size === 1)
       statuses.head match { case (bm, status) =>
         assert(bm.executorId === "<driver>", "Block should only be on the driver")
@@ -139,14 +133,14 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
       }
       if (distributed) {
         // this file is only generated in distributed mode
-        assert(HttpBroadcast.getFile(blockId.broadcastId).exists, "Broadcast file not found!")
+        assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!")
       }
     }
 
     // Verify that blocks are persisted in both the executors and the driver
-    def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) {
-      val blockId = BroadcastBlockId(broadcastId)
-      val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+    def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+      assert(blockIds.size === 1)
+      val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
       assert(statuses.size === numSlaves + 1)
       statuses.foreach { case (_, status) =>
         assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
@@ -157,21 +151,21 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
 
     // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
     // is true. In the latter case, also verify that the broadcast file is deleted on the driver.
-    def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) {
-      val blockId = BroadcastBlockId(broadcastId)
-      val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+    def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+      assert(blockIds.size === 1)
+      val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
       val expectedNumBlocks = if (removeFromDriver) 0 else 1
       val possiblyNot = if (removeFromDriver) "" else " not"
       assert(statuses.size === expectedNumBlocks,
         "Block should%s be unpersisted on the driver".format(possiblyNot))
       if (distributed && removeFromDriver) {
         // this file is only generated in distributed mode
-        assert(!HttpBroadcast.getFile(blockId.broadcastId).exists,
+        assert(!HttpBroadcast.getFile(blockIds.head.broadcastId).exists,
           "Broadcast file should%s be deleted".format(possiblyNot))
       }
     }
 
-    testUnpersistBroadcast(distributed, numSlaves, httpConf, afterCreation,
+    testUnpersistBroadcast(distributed, numSlaves, httpConf, getBlockIds, afterCreation,
       afterUsingBroadcast, afterUnpersist, removeFromDriver)
   }
 
@@ -185,51 +179,67 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
   private def testUnpersistTorrentBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
     val numSlaves = if (distributed) 2 else 0
 
-    // Verify that blocks are persisted only on the driver
-    def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) {
-      var blockId = BroadcastBlockId(broadcastId)
-      var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
-      assert(statuses.size === 1)
-
-      blockId = BroadcastBlockId(broadcastId, "piece0")
-      statuses = bmm.getBlockStatus(blockId, askSlaves = true)
-      assert(statuses.size === (if (distributed) 1 else 0))
-    }
-
-    // Verify that blocks are persisted in both the executors and the driver
-    def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) {
-      var blockId = BroadcastBlockId(broadcastId)
-      var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+    def getBlockIds(id: Long) = {
+      val broadcastBlockId = BroadcastBlockId(id)
+      val metaBlockId = BroadcastBlockId(id, "meta")
+      // Assume broadcast value is small enough to fit into 1 piece
+      val pieceBlockId = BroadcastBlockId(id, "piece0")
       if (distributed) {
-        assert(statuses.size === numSlaves + 1)
+        // the metadata and piece blocks are generated only in distributed mode
+        Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId)
       } else {
+        Seq[BroadcastBlockId](broadcastBlockId)
+      }
+    }
+
+    // Verify that blocks are persisted only on the driver
+    def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+      blockIds.foreach { blockId =>
+        val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
         assert(statuses.size === 1)
+        statuses.head match { case (bm, status) =>
+          assert(bm.executorId === "<driver>", "Block should only be on the driver")
+          assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
+          assert(status.memSize > 0, "Block should be in memory store on the driver")
+          assert(status.diskSize === 0, "Block should not be in disk store on the driver")
+        }
       }
+    }
 
-      blockId = BroadcastBlockId(broadcastId, "piece0")
-      statuses = bmm.getBlockStatus(blockId, askSlaves = true)
-      if (distributed) {
-        assert(statuses.size === numSlaves + 1)
-      } else {
-        assert(statuses.size === 0)
+    // Verify that blocks are persisted in both the executors and the driver
+    def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+      blockIds.foreach { blockId =>
+        val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+        if (blockId.field == "meta") {
+          // Metadata is only on the driver
+          assert(statuses.size === 1)
+          statuses.head match { case (bm, _) => assert(bm.executorId === "<driver>") }
+        } else {
+          // Other blocks are on both the executors and the driver
+          assert(statuses.size === numSlaves + 1,
+            blockId + " has " + statuses.size + " statuses: " + statuses.mkString(","))
+          statuses.foreach { case (_, status) =>
+            assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
+            assert(status.memSize > 0, "Block should be in memory store")
+            assert(status.diskSize === 0, "Block should not be in disk store")
+          }
+        }
       }
     }
 
     // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
     // is true.
-    def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) {
-      var blockId = BroadcastBlockId(broadcastId)
-      var expectedNumBlocks = if (removeFromDriver) 0 else 1
-      var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
-      assert(statuses.size === expectedNumBlocks)
-
-      blockId = BroadcastBlockId(broadcastId, "piece0")
-      expectedNumBlocks = if (removeFromDriver || !distributed) 0 else 1
-      statuses = bmm.getBlockStatus(blockId, askSlaves = true)
-      assert(statuses.size === expectedNumBlocks)
+    def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+      val expectedNumBlocks = if (removeFromDriver) 0 else 1
+      val possiblyNot = if (removeFromDriver) "" else " not"
+      blockIds.foreach { blockId =>
+        val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+        assert(statuses.size === expectedNumBlocks,
+          "Block should%s be unpersisted on the driver".format(possiblyNot))
+      }
     }
 
-    testUnpersistBroadcast(distributed, numSlaves, torrentConf, afterCreation,
+    testUnpersistBroadcast(distributed, numSlaves, torrentConf, getBlockIds, afterCreation,
      afterUsingBroadcast, afterUnpersist, removeFromDriver)
   }
 
@@ -246,9 +256,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
      distributed: Boolean,
      numSlaves: Int,  // used only when distributed = true
      broadcastConf: SparkConf,
-      afterCreation: (Long, BlockManagerMaster) => Unit,
-      afterUsingBroadcast: (Long, BlockManagerMaster) => Unit,
-      afterUnpersist: (Long, BlockManagerMaster) => Unit,
+      getBlockIds: Long => Seq[BroadcastBlockId],
+      afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+      afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+      afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
      removeFromDriver: Boolean) {
 
     sc = if (distributed) {
@@ -261,14 +272,15 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
 
     // Create broadcast variable
     val broadcast = sc.broadcast(list)
-    afterCreation(broadcast.id, blockManagerMaster)
+    val blocks = getBlockIds(broadcast.id)
+    afterCreation(blocks, blockManagerMaster)
 
     // Use broadcast variable on all executors
     val partitions = 10
     assert(partitions > numSlaves)
     val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum))
     assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
-    afterUsingBroadcast(broadcast.id, blockManagerMaster)
+    afterUsingBroadcast(blocks, blockManagerMaster)
 
     // Unpersist broadcast
     if (removeFromDriver) {
@@ -276,7 +288,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
     } else {
       broadcast.unpersist(blocking = true)
     }
-    afterUnpersist(broadcast.id, blockManagerMaster)
+    afterUnpersist(blocks, blockManagerMaster)
 
     // If the broadcast is removed from driver, all subsequent uses of the broadcast variable
     // should throw SparkExceptions. Otherwise, the result should be the same as before.
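
Reviewer note: the core of this change is that object HttpBroadcast no longer owns a single global HTTP server, codec, and cleaner. It instead keeps a HashMap[Int, HttpBroadcastContainer] (activeHBC) keyed by the owning HttpBroadcastFactory's hashCode(), so two factories, and therefore two SparkContexts in one JVM, get isolated broadcast state. The sketch below illustrates that keyed-registry pattern in isolation; it is not part of the patch, and the Container/Registry names and their members are invented for the example.

import scala.collection.mutable.HashMap

// Hypothetical stand-in for the per-context state that HttpBroadcastContainer
// owns in the patch (HTTP server, compression codec, metadata cleaner, ...).
class Container(val owner: String) {
  private var stopped = false
  def stop(): Unit = { stopped = true }
  def isStopped: Boolean = stopped
}

// Keyed registry mirroring HttpBroadcast.activeHBC: each factory registers its
// own container under its own key, so factories never share mutable state.
object Registry {
  private val active = new HashMap[Int, Container]

  def initialize(key: Int, owner: String): Unit = synchronized {
    if (!active.contains(key)) {
      active.put(key, new Container(owner))
    }
  }

  def get(key: Int): Option[Container] = synchronized { active.get(key) }

  def stop(key: Int): Unit = synchronized {
    // Remove the container for this key only, then release its resources.
    active.remove(key).foreach(_.stop())
  }
}

object KeyedRegistryDemo extends App {
  // Two "factories"; in the patch the key is the factory instance's hashCode().
  val keyA = new Object().hashCode()
  val keyB = new Object().hashCode()
  Registry.initialize(keyA, "contextA")
  Registry.initialize(keyB, "contextB")

  // Stopping one container leaves the other untouched.
  Registry.stop(keyA)
  assert(Registry.get(keyA).isEmpty)
  assert(Registry.get(keyB).exists(!_.isStopped))
  println("containers are isolated per key")
}

One design point that may be worth flagging in review: hashCode() is not guaranteed to be unique, so two factories could in principle collide on a key; an ID issued from a monotonically increasing counter would avoid that possibility.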