2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala

@@ -30,7 +30,7 @@ private[spark] class BroadcastManager(
   extends Logging {
 
   private var initialized = false
-  private var broadcastFactory: BroadcastFactory = null
+  var broadcastFactory: BroadcastFactory = null
 
   initialize()
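The one-line change above drops the private modifier so that broadcastFactory is readable outside BroadcastManager itself. A minimal sketch of what that enables, assuming the usual Spark 1.x layout where the manager hangs off SparkEnv; the object name FactoryProbe and its package are hypothetical, not part of this PR:

package org.apache.spark.examples // hypothetical; must sit under org.apache.spark, since BroadcastManager is private[spark]

import org.apache.spark.SparkEnv
import org.apache.spark.broadcast.BroadcastFactory

object FactoryProbe {
  // Legal only now that the field is no longer private to BroadcastManager:
  // other Spark-internal components can inspect the active factory.
  def currentFactory: BroadcastFactory = SparkEnv.get.broadcastManager.broadcastFactory
}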
174 changes: 44 additions & 130 deletions core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -19,9 +19,11 @@ package org.apache.spark.broadcast
 
 import java.io.{File, FileOutputStream, ObjectInputStream, ObjectOutputStream, OutputStream}
 import java.io.{BufferedInputStream, BufferedOutputStream}
+import java.io.IOException
 import java.net.{URL, URLConnection, URI}
 import java.util.concurrent.TimeUnit
 
+import scala.collection.mutable.HashMap
 import scala.reflect.ClassTag
 
 import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv}
@@ -37,7 +39,7 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH
  * executor to speed up future accesses.
  */
 private[spark] class HttpBroadcast[T: ClassTag](
-    @transient var value_ : T, isLocal: Boolean, id: Long)
+    @transient var value_ : T, isLocal: Boolean, id: Long, key: Int = 0)
   extends Broadcast[T](id) with Logging with Serializable {
 
   override protected def getValue() = value_
@@ -54,21 +56,21 @@ private[spark] class HttpBroadcast[T: ClassTag](
   }
 
   if (!isLocal) {
-    HttpBroadcast.write(id, value_)
+    HttpBroadcast.write(id, value_, key)
   }
 
   /**
    * Remove all persisted state associated with this HTTP broadcast on the executors.
    */
   override protected def doUnpersist(blocking: Boolean) {
-    HttpBroadcast.unpersist(id, removeFromDriver = false, blocking)
+    HttpBroadcast.unpersist(id, removeFromDriver = false, blocking, key)
   }
 
   /**
    * Remove all persisted state associated with this HTTP broadcast on the executors and driver.
    */
   override protected def doDestroy(blocking: Boolean) {
-    HttpBroadcast.unpersist(id, removeFromDriver = true, blocking)
+    HttpBroadcast.unpersist(id, removeFromDriver = true, blocking, key)
   }
 
   /** Used by the JVM when serializing this object. */
@@ -86,7 +88,7 @@ private[spark] class HttpBroadcast[T: ClassTag](
       case None => {
         logInfo("Started reading broadcast variable " + id)
         val start = System.nanoTime
-        value_ = HttpBroadcast.read[T](id)
+        value_ = HttpBroadcast.read[T](id, key)
         /*
          * We cache broadcast data in the BlockManager so that subsequent tasks using it
          * do not need to re-fetch. This data is only used locally and no other node
@@ -103,153 +105,65 @@ private[spark] class HttpBroadcast[T: ClassTag](
 }
 
 private[broadcast] object HttpBroadcast extends Logging {
-  private var initialized = false
-  private var broadcastDir: File = null
-  private var compress: Boolean = false
-  private var bufferSize: Int = 65536
-  private var serverUri: String = null
-  private var server: HttpServer = null
-  private var securityManager: SecurityManager = null
 
-  // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist
-  private val files = new TimeStampedHashSet[File]
-  private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt
-  private var compressionCodec: CompressionCodec = null
-  private var cleaner: MetadataCleaner = null
+  val activeHBC = new HashMap[Int, HttpBroadcastContainer]
 
-  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
+  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager, key: Int = 0) {
     synchronized {
-      if (!initialized) {
-        bufferSize = conf.getInt("spark.buffer.size", 65536)
-        compress = conf.getBoolean("spark.broadcast.compress", true)
-        securityManager = securityMgr
-        if (isDriver) {
-          createServer(conf)
-          conf.set("spark.httpBroadcast.uri", serverUri)
-        }
-        serverUri = conf.get("spark.httpBroadcast.uri")
-        cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf)
-        compressionCodec = CompressionCodec.createCodec(conf)
-        initialized = true
+      if (!activeHBC.contains(key)) {
+        val hbc = new HttpBroadcastContainer()
+        hbc.initialize(isDriver, conf, securityMgr)
+        activeHBC.put(key, hbc)
       }
     }
   }
 
-  def stop() {
+  def stop(key: Int = 0) {
     synchronized {
-      if (server != null) {
-        server.stop()
-        server = null
-      }
-      if (cleaner != null) {
-        cleaner.cancel()
-        cleaner = null
+      if (activeHBC.contains(key) && activeHBC(key) != null) {
+        activeHBC(key).stop()
+        activeHBC(key) = null
+        activeHBC.remove(key)
+      } else {
+        logWarning("HttpBroadcast key " + key + " not found")
       }
-      compressionCodec = null
-      initialized = false
     }
   }
 
-  private def createServer(conf: SparkConf) {
-    broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf))
-    val broadcastPort = conf.getInt("spark.broadcast.port", 0)
-    server = new HttpServer(broadcastDir, securityManager, broadcastPort, "HTTP broadcast server")
-    server.start()
-    serverUri = server.uri
-    logInfo("Broadcast server started at " + serverUri)
-  }
-
-  def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name)
-
-  private def write(id: Long, value: Any) {
-    val file = getFile(id)
-    val out: OutputStream = {
-      if (compress) {
-        compressionCodec.compressedOutputStream(new FileOutputStream(file))
-      } else {
-        new BufferedOutputStream(new FileOutputStream(file), bufferSize)
-      }
-    }
-    val ser = SparkEnv.get.serializer.newInstance()
-    val serOut = ser.serializeStream(out)
-    serOut.writeObject(value)
-    serOut.close()
-    files += file
-  }
-
-  private def read[T: ClassTag](id: Long): T = {
-    logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id)
-    val url = serverUri + "/" + BroadcastBlockId(id).name
-
-    var uc: URLConnection = null
-    if (securityManager.isAuthenticationEnabled()) {
-      logDebug("broadcast security enabled")
-      val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager)
-      uc = newuri.toURL.openConnection()
-      uc.setAllowUserInteraction(false)
-    } else {
-      logDebug("broadcast not using security")
-      uc = new URL(url).openConnection()
-    }
-
-    val in = {
-      uc.setReadTimeout(httpReadTimeout)
-      val inputStream = uc.getInputStream
-      if (compress) {
-        compressionCodec.compressedInputStream(inputStream)
-      } else {
-        new BufferedInputStream(inputStream, bufferSize)
-      }
-    }
-    val ser = SparkEnv.get.serializer.newInstance()
-    val serIn = ser.deserializeStream(in)
-    val obj = serIn.readObject[T]()
-    serIn.close()
-    obj
-  }
+  def getFile(id: Long, key: Int = 0): File = {
+    if (activeHBC.contains(key) && activeHBC(key) != null) {
+      activeHBC(key).getFile(id)
+    } else {
+      throw new IOException("HttpBroadcast key " + key + " not found")
+    }
+  }
+
+  private def write(id: Long, value: Any, key: Int = 0) {
+    if (activeHBC.contains(key) && activeHBC(key) != null) {
+      activeHBC(key).write(id, value)
+    } else {
+      logWarning("HttpBroadcast key " + key + " not found")
+    }
+  }
+
+  private def read[T: ClassTag](id: Long, key: Int = 0): T = {
+    if (activeHBC.contains(key) && activeHBC(key) != null) {
+      activeHBC(key).read[T](id)
+    } else {
+      throw new IOException("HttpBroadcast key " + key + " not found")
+    }
+  }
 
   /**
    * Remove all persisted blocks associated with this HTTP broadcast on the executors.
    * If removeFromDriver is true, also remove these persisted blocks on the driver
    * and delete the associated broadcast file.
    */
-  def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized {
-    SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
-    if (removeFromDriver) {
-      val file = getFile(id)
-      files.remove(file)
-      deleteBroadcastFile(file)
-    }
-  }
-
-  /**
-   * Periodically clean up old broadcasts by removing the associated map entries and
-   * deleting the associated files.
-   */
-  private def cleanup(cleanupTime: Long) {
-    val iterator = files.internalMap.entrySet().iterator()
-    while(iterator.hasNext) {
-      val entry = iterator.next()
-      val (file, time) = (entry.getKey, entry.getValue)
-      if (time < cleanupTime) {
-        iterator.remove()
-        deleteBroadcastFile(file)
-      }
-    }
-  }
-
-  private def deleteBroadcastFile(file: File) {
-    try {
-      if (file.exists) {
-        if (file.delete()) {
-          logInfo("Deleted broadcast file: %s".format(file))
-        } else {
-          logWarning("Could not delete broadcast file: %s".format(file))
-        }
-      }
-    } catch {
-      case e: Exception =>
-        logError("Exception while deleting broadcast file: %s".format(file), e)
-    }
-  }
+  def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean, key: Int = 0) = synchronized {
+    if (activeHBC.contains(key) && activeHBC(key) != null) {
+      activeHBC(key).unpersist(id, removeFromDriver, blocking)
+    } else {
+      logWarning("HttpBroadcast key " + key + " not found")
+    }
+  }
 }
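Everything above delegates to HttpBroadcastContainer, which this changeset references but does not define. A plausible reading, sketched below under the assumption that the container simply absorbs, per key (roughly, per SparkContext), the state and methods deleted from the object: its own HTTP server, broadcast directory, compression codec, buffer size, file set, and cleaner. The class body is a reconstruction, not code from the PR; the ??? stubs stand in for the deleted write/read/unpersist/cleanup bodies.

package org.apache.spark.broadcast

import java.io.File

import scala.reflect.ClassTag

import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.BroadcastBlockId
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}

// Reconstruction: one container per key, holding the state that used to be
// global in the HttpBroadcast object. All field names mirror the deleted code.
private[broadcast] class HttpBroadcastContainer extends Logging {
  private var initialized = false
  private var broadcastDir: File = null
  private var compress: Boolean = false
  private var bufferSize: Int = 65536
  private var serverUri: String = null
  private var server: HttpServer = null
  private var securityManager: SecurityManager = null
  private val files = new TimeStampedHashSet[File]
  private var compressionCodec: CompressionCodec = null
  private var cleaner: MetadataCleaner = null

  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit = synchronized {
    if (!initialized) {
      bufferSize = conf.getInt("spark.buffer.size", 65536)
      compress = conf.getBoolean("spark.broadcast.compress", true)
      securityManager = securityMgr
      if (isDriver) {
        // Same work the deleted object-level createServer(conf) performed.
        broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf))
        server = new HttpServer(broadcastDir, securityManager,
          conf.getInt("spark.broadcast.port", 0), "HTTP broadcast server")
        server.start()
        conf.set("spark.httpBroadcast.uri", server.uri)
      }
      serverUri = conf.get("spark.httpBroadcast.uri")
      cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf)
      compressionCodec = CompressionCodec.createCodec(conf)
      initialized = true
    }
  }

  def stop(): Unit = synchronized {
    if (server != null) { server.stop(); server = null }
    if (cleaner != null) { cleaner.cancel(); cleaner = null }
    compressionCodec = null
    initialized = false
  }

  def getFile(id: Long): File = new File(broadcastDir, BroadcastBlockId(id).name)

  // write, read, unpersist and the periodic cleanup would carry over the
  // bodies deleted above, unchanged except that they touch instance fields.
  def write(id: Long, value: Any): Unit = ???
  def read[T: ClassTag](id: Long): T = ???
  def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = ???
  private def cleanup(cleanupTime: Long): Unit = ???
}

Under that reading, a second SparkContext in the same JVM would call HttpBroadcast.initialize(isDriver = true, conf, securityMgr, key = 1) at startup and HttpBroadcast.stop(key = 1) at shutdown, without tearing down the broadcast server owned by key 0 — which is exactly the multi-context coexistence the removed TODO comment asked for.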