19 changes: 8 additions & 11 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -43,9 +43,8 @@ import org.apache.spark.util.{AkkaUtils, Utils}
  * :: DeveloperApi ::
  * Holds all the runtime environment objects for a running Spark instance (either master or worker),
  * including the serializer, Akka actor system, block manager, map output tracker, etc. Currently
- * Spark code finds the SparkEnv through a thread-local variable, so each thread that accesses these
- * objects needs to have the right SparkEnv set. You can get the current environment with
- * SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set.
+ * Spark code finds the SparkEnv through a global variable, so all the threads can access the same
+ * SparkEnv. It can be accessed by SparkEnv.get (e.g. after creating a SparkContext).
  *
  * NOTE: This is not intended for external use. This is exposed for Shark and may be made private
  * in a future release.
@@ -119,30 +118,28 @@ class SparkEnv (
 }
 
 object SparkEnv extends Logging {
-  private val env = new ThreadLocal[SparkEnv]
-  @volatile private var lastSetSparkEnv : SparkEnv = _
+  @volatile private var env: SparkEnv = _
 
   private[spark] val driverActorSystemName = "sparkDriver"
   private[spark] val executorActorSystemName = "sparkExecutor"
 
   def set(e: SparkEnv) {
-    lastSetSparkEnv = e
-    env.set(e)
+    env = e
   }
 
   /**
-   * Returns the ThreadLocal SparkEnv, if non-null. Else returns the SparkEnv
-   * previously set in any thread.
+   * Returns the SparkEnv.
    */
   def get: SparkEnv = {
-    Option(env.get()).getOrElse(lastSetSparkEnv)
+    env
   }
 
   /**
    * Returns the ThreadLocal SparkEnv.
    */
+  @deprecated("Use SparkEnv.get instead", "1.2")
   def getThreadLocal: SparkEnv = {
-    env.get()
+    env
   }
 
   private[spark] def create(
Contributor: You should leave this in instead of removing it, because some user code might be calling this public method.

Contributor: (Even though it's a DeveloperApi we shouldn't break it if we can help it)

Contributor: Good point. Let's throw @deprecated on it when we put it back, though.
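The review thread above explains the @deprecated annotation in the diff: getThreadLocal stays so existing callers keep compiling, with a warning steering them to SparkEnv.get. As for the change itself, the old ThreadLocal storage meant a SparkEnv set on one thread was invisible to SparkEnv.get on any other, which is why every spawned thread had to repeat SparkEnv.set(env); a @volatile field has no per-thread scoping. A minimal, self-contained sketch of that difference (demo code, not Spark's):

object ThreadLocalVsVolatileDemo {
  private val threadLocalEnv = new ThreadLocal[String]
  @volatile private var globalEnv: String = _

  def main(args: Array[String]): Unit = {
    threadLocalEnv.set("env")   // set on the main thread only
    globalEnv = "env"           // one global write

    val t = new Thread(new Runnable {
      def run(): Unit = {
        println(threadLocalEnv.get)  // null: ThreadLocal state never crosses threads
        println(globalEnv)           // "env": the @volatile field is visible to all threads
      }
    })
    t.start()
    t.join()
  }
}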
core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -196,7 +196,6 @@ private[spark] class PythonRDD(
 
     override def run(): Unit = Utils.logUncaughtExceptions {
       try {
-        SparkEnv.set(env)
         val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
         val dataOut = new DataOutputStream(stream)
         // Partition index
2 changes: 0 additions & 2 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -147,7 +146,6 @@ private[spark] class Executor(
 
     override def run() {
       val startTime = System.currentTimeMillis()
-      SparkEnv.set(env)
       Thread.currentThread.setContextClassLoader(replClassLoader)
       val ser = SparkEnv.get.closureSerializer.newInstance()
       logInfo(s"Running $taskName (TID $taskId)")
@@ -157,7 +156,6 @@ private[spark] class Executor(
       val startGCTime = gcTime
 
       try {
-        SparkEnv.set(env)
         Accumulators.clear()
         val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
         updateDependencies(taskFiles, taskJars)
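Each removal in the remaining files follows the same shape as the two above: the thread's first statement was SparkEnv.set(env), and with a global field it simply disappears. A hedged sketch of the resulting pattern (a fragment, assuming a SparkEnv has already been created and set in this JVM, as the Executor's constructor arranges):

import org.apache.spark.SparkEnv

// Post-change pattern: a spawned task thread reads the shared environment
// directly; no per-thread SparkEnv.set(env) hand-off is needed first.
val runner = new Thread("task runner") {
  override def run(): Unit = {
    val ser = SparkEnv.get.closureSerializer.newInstance()
    // ... deserialize and run the task using `ser` ...
  }
}
runner.start()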
1 change: 0 additions & 1 deletion core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -131,7 +130,6 @@ private[spark] class PipedRDD[T: ClassTag](
     // Start a thread to feed the process input from our parent's iterator
     new Thread("stdin writer for " + command) {
       override def run() {
-        SparkEnv.set(env)
         val out = new PrintWriter(proc.getOutputStream)
 
         // input the pipe context firstly
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -630,7 +630,6 @@ class DAGScheduler(
   protected def runLocallyWithinThread(job: ActiveJob) {
     var jobResult: JobResult = JobSucceeded
     try {
-      SparkEnv.set(env)
       val rdd = job.finalStage.rdd
       val split = rdd.partitions(job.partitions(0))
       val taskContext =
core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -216,8 +216,6 @@ private[spark] class TaskSchedulerImpl(
    * that tasks are balanced across the cluster.
    */
   def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized {
-    SparkEnv.set(sc.env)
-
     // Mark each slave as alive and remember its hostname
     // Also track if new executor is added
     var newExecAvail = false
streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -217,7 +217,6 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
 
   /** Generate jobs and perform checkpoint for the given `time`. */
   private def generateJobs(time: Time) {
-    SparkEnv.set(ssc.env)
     Try(graph.generateJobs(time)) match {
       case Success(jobs) =>
         val receivedBlockInfo = graph.getReceiverInputStreams.map { stream =>
streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -138,7 +138,6 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
     }
     jobSet.handleJobStart(job)
     logInfo("Starting job " + job.id + " from job set of time " + jobSet.time)
-    SparkEnv.set(ssc.env)
   }
 
   private def handleJobCompletion(job: Job) {
streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -202,7 +202,6 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
     @transient val thread = new Thread() {
       override def run() {
         try {
-          SparkEnv.set(env)
           startReceivers()
         } catch {
           case ie: InterruptedException => logInfo("ReceiverLauncher interrupted")
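The streaming-side removals rely on the same guarantee as the core ones: whoever creates the environment publishes it once through the @volatile field, and threads started at any point afterwards observe that single reference. A self-contained sketch of that publication pattern (demo code, not Spark's):

object PublicationDemo {
  final class Env(val name: String)
  @volatile private var env: Env = _   // stands in for SparkEnv's global field

  def main(args: Array[String]): Unit = {
    val readers = (1 to 3).map { i =>
      new Thread(new Runnable {
        def run(): Unit = {
          while (env == null) Thread.`yield`()   // spin until the env is published
          println(s"reader $i sees ${env.name}")
        }
      })
    }
    readers.foreach(_.start())
    env = new Env("sparkEnv")                    // the single global write
    readers.foreach(_.join())
  }
}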