From 79f48dcff976de99f06057da48951fa16e6d68d8 Mon Sep 17 00:00:00 2001 From: witgo Date: Fri, 6 Jun 2014 13:51:14 +0800 Subject: [PATCH] Add the lifecycle interface --- .../org/apache/spark/ContextCleaner.scala | 12 ++-- .../scala/org/apache/spark/HttpServer.scala | 8 ++- .../scala/org/apache/spark/Lifecycle.scala | 72 +++++++++++++++++++ .../org/apache/spark/SecurityManager.scala | 2 +- .../main/scala/org/apache/spark/Service.scala | 60 ++++++++++++++++ .../scala/org/apache/spark/SparkContext.scala | 22 ++++-- .../spark/deploy/client/AppClient.scala | 10 +-- .../spark/deploy/history/HistoryServer.scala | 9 ++- .../spark/deploy/master/ui/MasterWebUI.scala | 4 +- .../spark/deploy/worker/ui/WorkerWebUI.scala | 4 +- .../spark/metrics/sink/ConsoleSink.scala | 6 +- .../apache/spark/metrics/sink/CsvSink.scala | 6 +- .../spark/metrics/sink/GraphiteSink.scala | 6 +- .../apache/spark/metrics/sink/JmxSink.scala | 6 +- .../spark/metrics/sink/MetricsServlet.scala | 8 ++- .../org/apache/spark/metrics/sink/Sink.scala | 8 +-- .../apache/spark/scheduler/DAGScheduler.scala | 12 +++- .../spark/scheduler/SchedulerBackend.scala | 7 +- .../spark/scheduler/TaskScheduler.scala | 8 +-- .../spark/scheduler/TaskSchedulerImpl.scala | 4 +- .../CoarseGrainedSchedulerBackend.scala | 8 +-- .../cluster/SimrSchedulerBackend.scala | 10 +-- .../cluster/SparkDeploySchedulerBackend.scala | 1 + .../mesos/CoarseMesosSchedulerBackend.scala | 8 +-- .../cluster/mesos/MesosSchedulerBackend.scala | 12 ++-- .../spark/scheduler/local/LocalBackend.scala | 10 +-- .../scala/org/apache/spark/ui/SparkUI.scala | 7 +- .../scala/org/apache/spark/ui/WebUI.scala | 11 ++- .../spark/scheduler/DAGSchedulerSuite.scala | 10 +-- .../scheduler/TaskSchedulerImplSuite.scala | 7 +- .../spark/metrics/sink/GangliaSink.scala | 4 +- .../spark/streaming/StreamingContext.scala | 37 +++++----- .../api/java/JavaStreamingContext.scala | 23 ++++-- .../streaming/scheduler/JobGenerator.scala | 16 +++-- .../streaming/scheduler/JobScheduler.scala | 16 +++-- .../streaming/scheduler/ReceiverTracker.scala | 18 +++-- .../scheduler/StreamingListenerBus.scala | 8 +-- .../streaming/StreamingContextSuite.scala | 8 +-- .../cluster/YarnClientSchedulerBackend.scala | 8 +-- 39 files changed, 341 insertions(+), 155 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/Lifecycle.scala create mode 100644 core/src/main/scala/org/apache/spark/Service.scala diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index ede1e23f4fcc5..e008ae680104c 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -52,7 +52,7 @@ private class CleanupTaskWeakReference( * to be processed when the associated object goes out of scope of the application. Actual * cleanup is performed in a separate daemon thread. */ -private[spark] class ContextCleaner(sc: SparkContext) extends Logging { +private[spark] class ContextCleaner(sc: SparkContext) extends Logging with Lifecycle { private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference] with SynchronizedBuffer[CleanupTaskWeakReference] @@ -90,24 +90,22 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private val blockOnShuffleCleanupTasks = sc.conf.getBoolean( "spark.cleaner.referenceTracking.blocking.shuffle", false) - @volatile private var stopped = false - /** Attach a listener object to get information of when objects are cleaned. 
*/ def attachListener(listener: CleanerListener) { listeners += listener } + def conf = sc.conf + /** Start the cleaner. */ - def start() { + override protected def doStart() { cleaningThread.setDaemon(true) cleaningThread.setName("Spark Context Cleaner") cleaningThread.start() } /** Stop the cleaner. */ - def stop() { - stopped = true - } + override protected def doStop() { } /** Register a RDD for cleanup when it is garbage collected. */ def registerRDDForCleanup(rdd: RDD[_]) { diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index 912558d0cab7d..157ccbbc555b9 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -46,12 +46,14 @@ private[spark] class HttpServer( securityManager: SecurityManager, requestedPort: Int = 0, serverName: String = "HTTP server") - extends Logging { + extends Logging with Lifecycle { private var server: Server = null private var port: Int = requestedPort - def start() { + def conf = securityManager.sparkConf + + override protected def doStart() { if (server != null) { throw new ServerStateException("Server is already started") } else { @@ -137,7 +139,7 @@ private[spark] class HttpServer( sh } - def stop() { + override protected def doStop() { if (server == null) { throw new ServerStateException("Server is already stopped") } else { diff --git a/core/src/main/scala/org/apache/spark/Lifecycle.scala b/core/src/main/scala/org/apache/spark/Lifecycle.scala new file mode 100644 index 0000000000000..50baef65c53f0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/Lifecycle.scala @@ -0,0 +1,72 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark + +trait Lifecycle extends Service { + + import State._ + + protected var state_ = Uninitialized + + def conf: SparkConf + + def uninitialized = state_ == Uninitialized + + def initialized = state_ == Initialized + + def started = state_ == Started + + def stopped = state_ == Stopped + + def state: State.State = state_ + + def initialize(): Unit = synchronized { + if (!uninitialized) { + throw new SparkException(s"Can't move to initialized state when $state_") + } + doInitialize + state_ = Initialized + } + + override def start(): Unit = synchronized { + if (uninitialized) initialize() + if (started) { + throw new SparkException(s"Can't move to started state when $state_") + } + doStart() + state_ = Started + } + + override def stop(): Unit = synchronized { + if (!started) { + throw new SparkException(s"Can't move to stopped state when $state_") + } + doStop + state_ = Stopped + } + + override def close(): Unit = synchronized { + stop() + } + + protected def doInitialize(): Unit = {} + + protected def doStart(): Unit + + protected def doStop(): Unit +} diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 12b15fe0815be..ac013af845b53 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -140,7 +140,7 @@ import org.apache.spark.deploy.SparkHadoopUtil * can take place. */ -private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { +private[spark] class SecurityManager(val sparkConf: SparkConf) extends Logging { // key used to store the spark secret in the Hadoop UGI private val sparkSecretLookupKey = "sparkCookie" diff --git a/core/src/main/scala/org/apache/spark/Service.scala b/core/src/main/scala/org/apache/spark/Service.scala new file mode 100644 index 0000000000000..07b99057dc608 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/Service.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +trait Service extends java.io.Closeable { + + /** + * Service states + */ + object State extends Enumeration { + + /** + * Constructed but not initialized + */ + val Uninitialized = Value(0, "Uninitialized") + + /** + * Initialized but not started or stopped + */ + val Initialized = Value(1, "Initialized") + + /** + * started and not stopped + */ + val Started = Value(2, "Started") + + /** + * stopped. 
No further state transitions are permitted + */ + val Stopped = Value(3, "Stopped") + + type State = Value + } + + def conf: SparkConf + + def initialize(): Unit + + def start(): Unit + + def stop(): Unit + + def state: State.State + +} diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index cb4fb7cfbd32f..d555766fa86f7 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -60,7 +60,7 @@ import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, Metadat * this config overrides the default configs as well as system properties. */ -class SparkContext(config: SparkConf) extends Logging { +class SparkContext(config: SparkConf) extends Logging with Lifecycle { // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, // etc) too. This is typically generated from InputFormatInfo.computePreferredLocations. It @@ -154,9 +154,8 @@ class SparkContext(config: SparkConf) extends Logging { private[spark] def this(master: String, appName: String, sparkHome: String, jars: Seq[String]) = this(master, appName, sparkHome, jars, Map(), Map()) - private[spark] val conf = config.clone() + val conf = config.clone() conf.validateSettings() - /** * Return a copy of this SparkContext's configuration. The configuration ''cannot'' be * changed at runtime. @@ -987,8 +986,23 @@ class SparkContext(config: SparkConf) extends Logging { addedJars.clear() } + override def start() { + if (stopped) { + throw new SparkException("SparkContext has already been stopped") + } + super.start() + } + + override protected def doStart() {} + + start() + /** Shut down the SparkContext. */ - def stop() { + override def stop() { + if (started) super.stop() + } + + override protected def doStop() { postApplicationEnd() ui.stop() // Do this only if not stopped already - best case effort. diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 32790053a6be8..09bcc800cde85 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -26,7 +26,7 @@ import akka.actor._ import akka.pattern.ask import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Lifecycle, Logging, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master @@ -44,8 +44,8 @@ private[spark] class AppClient( masterUrls: Array[String], appDescription: ApplicationDescription, listener: AppClientListener, - conf: SparkConf) - extends Logging { + val conf: SparkConf) + extends Logging with Lifecycle { val REGISTRATION_TIMEOUT = 20.seconds val REGISTRATION_RETRIES = 3 @@ -186,12 +186,12 @@ private[spark] class AppClient( } - def start() { + override protected def doStart() { // Just launch an actor; it will call back into the listener. 
actor = actorSystem.actorOf(Props(new ClientActor)) } - def stop() { + override protected def doStop() { if (actor != null) { try { val timeout = AkkaUtils.askTimeout(conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index d1a64c1912cb8..f6156ef9507ba 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -41,7 +41,7 @@ import org.apache.spark.util.SignalLogger * EventLoggingListener. */ class HistoryServer( - conf: SparkConf, + val conf: SparkConf, provider: ApplicationHistoryProvider, securityManager: SecurityManager, port: Int) @@ -101,7 +101,6 @@ class HistoryServer( } } - initialize() /** * Initialize the history server. @@ -109,7 +108,7 @@ class HistoryServer( * This starts a background thread that periodically synchronizes information displayed on * this UI with the event logs in the provided base directory. */ - def initialize() { + override def doInitialize() { attachPage(new HistoryPage(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) @@ -125,8 +124,8 @@ class HistoryServer( } /** Stop the server and close the file system. */ - override def stop() { - super.stop() + override protected def doStop() { + super.doStop() provider.stop() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index d86ec1e03e45c..54651ef11e05b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -33,10 +33,10 @@ class MasterWebUI(val master: Master, requestedPort: Int) val masterActorRef = master.self val timeout = AkkaUtils.askTimeout(master.conf) - initialize() + def conf = master.conf /** Initialize all components of the server. */ - def initialize() { + override def doInitialize() { attachPage(new ApplicationPage(this)) attachPage(new HistoryNotFoundPage(this)) attachPage(new MasterPage(this)) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index b07942a9ca729..b9ef8c91d2394 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -40,10 +40,10 @@ class WorkerWebUI( val timeout = AkkaUtils.askTimeout(worker.conf) - initialize() + def conf = worker.conf /** Initialize all components of the server. 
*/ - def initialize() { + override def doInitialize() { val logPage = new LogPage(this) attachPage(logPage) attachPage(new WorkerPage(this)) diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala index 81b9056b40fb8..e60d2eb536083 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala @@ -50,11 +50,13 @@ private[spark] class ConsoleSink(val property: Properties, val registry: MetricR .convertRatesTo(TimeUnit.SECONDS) .build() - override def start() { + def conf = securityMgr.sparkConf + + override protected def doStart() { reporter.start(pollPeriod, pollUnit) } - override def stop() { + override protected def doStop() { reporter.stop() } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala index 9d5f2ae9328ad..5a81bbe8ee232 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala @@ -59,11 +59,13 @@ private[spark] class CsvSink(val property: Properties, val registry: MetricRegis .convertRatesTo(TimeUnit.SECONDS) .build(new File(pollDir)) - override def start() { + def conf = securityMgr.sparkConf + + override protected def doStart() { reporter.start(pollPeriod, pollUnit) } - override def stop() { + override protected def doStop() { reporter.stop() } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index d7b5f5c40efae..1930606771949 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -74,11 +74,13 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric .prefixedWith(prefix) .build(graphite) - override def start() { + def conf = securityMgr.sparkConf + + override protected def doStart() { reporter.start(pollPeriod, pollUnit) } - override def stop() { + override protected def doStop() { reporter.stop() } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala index 2588fe2c9edb8..7cee5c340154e 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala @@ -27,11 +27,13 @@ private[spark] class JmxSink(val property: Properties, val registry: MetricRegis val reporter: JmxReporter = JmxReporter.forRegistry(registry).build() - override def start() { + def conf = securityMgr.sparkConf + + override protected def doStart() { reporter.start() } - override def stop() { + override protected def doStop() { reporter.stop() } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 2f65bc8b46609..c5feb8bf23a4b 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -54,9 +54,11 @@ private[spark] class MetricsServlet(val property: Properties, val registry: Metr mapper.writeValueAsString(registry) } - override def start() { } - - override def stop() { } + def conf = securityMgr.sparkConf override def report() { } + + 
override protected def doStart() { } + + override protected def doStop() { } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala index 0d83d8c425ca4..5adeb8cc07378 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala @@ -17,8 +17,8 @@ package org.apache.spark.metrics.sink -private[spark] trait Sink { - def start: Unit - def stop: Unit - def report(): Unit +import org.apache.spark.Lifecycle + +private[spark] trait Sink extends Lifecycle { +def report(): Unit } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 2ccc27324ac8c..3ee9cb4fff131 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -65,7 +65,7 @@ class DAGScheduler( blockManagerMaster: BlockManagerMaster, env: SparkEnv, clock: Clock = SystemClock) - extends Logging { + extends Logging with Lifecycle{ import DAGScheduler._ @@ -112,6 +112,8 @@ class DAGScheduler( // stray messages to detect. private val failedEpoch = new HashMap[String, Long] + def conf = env.conf + private val dagSchedulerActorSupervisor = env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this))) @@ -134,7 +136,11 @@ class DAGScheduler( asInstanceOf[ActorRef] } - initializeEventProcessActor() + override protected def doStart() { + initializeEventProcessActor() + } + + start() // Called by TaskScheduler to report task's starting. def taskStarted(task: Task[_], taskInfo: TaskInfo) { @@ -1314,7 +1320,7 @@ class DAGScheduler( Nil } - def stop() { + override protected def doStop() { logInfo("Stopping DAGScheduler") dagSchedulerActorSupervisor ! PoisonPill taskScheduler.stop() diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index e41e0a9841691..b4cc3b3deac4c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -17,15 +17,16 @@ package org.apache.spark.scheduler +import org.apache.spark.Service + /** * A backend interface for scheduling systems that allows plugging in different ones under * TaskSchedulerImpl. We assume a Mesos-like model where the application gets resource offers as * machines become available and can launch tasks on them. 
*/ -private[spark] trait SchedulerBackend { - def start(): Unit - def stop(): Unit +private[spark] trait SchedulerBackend extends Service { def reviveOffers(): Unit + def defaultParallelism(): Int def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 1a0b877c8a5e1..bb0b8da990cdf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -17,6 +17,7 @@ package org.apache.spark.scheduler +import org.apache.spark.Lifecycle import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId @@ -29,22 +30,17 @@ import org.apache.spark.storage.BlockManagerId * them, retrying if there are failures, and mitigating stragglers. They return events to the * DAGScheduler. */ -private[spark] trait TaskScheduler { +private[spark] trait TaskScheduler extends Lifecycle { def rootPool: Pool def schedulingMode: SchedulingMode - def start(): Unit - // Invoked after system has successfully initialized (typically in spark context). // Yarn uses this to bootstrap allocation of resources based on preferred locations, // wait for slave registerations, etc. def postStartHook() { } - // Disconnect from the cluster. - def stop(): Unit - // Submit a sequence of tasks to run. def submitTasks(taskSet: TaskSet): Unit diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index ad051e59af86d..1792d970982ac 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -137,7 +137,7 @@ private[spark] class TaskSchedulerImpl( def newTaskId(): Long = nextTaskId.getAndIncrement() - override def start() { + override protected def doStart() { backend.start() if (!isLocal && conf.getBoolean("spark.speculation", false)) { @@ -389,7 +389,7 @@ private[spark] class TaskSchedulerImpl( } } - override def stop() { + override protected def doStop() { if (backend != null) { backend.stop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 2a3711ae2a78c..4e12f378eb2d1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -27,7 +27,7 @@ import akka.actor._ import akka.pattern.ask import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} -import org.apache.spark.{SparkEnv, Logging, SparkException, TaskState} +import org.apache.spark.{Lifecycle, Logging, SparkEnv, SparkException, TaskState} import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils} @@ -43,7 +43,7 @@ import org.apache.spark.ui.JettyUtils */ private[spark] class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: ActorSystem) - extends SchedulerBackend with Logging + extends 
SchedulerBackend with Logging with Lifecycle { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed var totalCoreCount = new AtomicInteger(0) @@ -205,7 +205,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A var driverActor: ActorRef = null val taskIdsOnSlave = new HashMap[String, HashSet[String]] - override def start() { + override protected def doStart() { val properties = new ArrayBuffer[(String, String)] for ((key, value) <- scheduler.sc.conf.getAll) { if (key.startsWith("spark.")) { @@ -230,7 +230,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A } } - override def stop() { + override protected def doStop() { stopExecutors() try { if (driverActor != null) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index bc7670f4a804d..ffac4a853c07a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -36,8 +36,8 @@ private[spark] class SimrSchedulerBackend( val maxCores = conf.getInt("spark.simr.executor.cores", 1) - override def start() { - super.start() + override protected def doStart() { + super.doStart() val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( SparkEnv.driverActorSystemName, @@ -63,10 +63,10 @@ private[spark] class SimrSchedulerBackend( fs.rename(tmpPath, filePath) } - override def stop() { - val conf = SparkHadoopUtil.get.newConfiguration(sc.conf) + override protected def doStop() { + val conf = SparkHadoopUtil.get.newConfiguration(sc.conf) val fs = FileSystem.get(conf) fs.delete(new Path(driverFilePath), false) - super.stop() + super.doStop() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 32138e5246700..871cb92496bdf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -40,6 +40,7 @@ private[spark] class SparkDeploySchedulerBackend( override def start() { super.start() + stopping = false // The endpoint for executors to talk to us val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 87e181e773fdf..f3775c86dfc3c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -87,8 +87,8 @@ private[spark] class CoarseMesosSchedulerBackend( id } - override def start() { - super.start() + override protected def doStart() { + super.doStart() synchronized { new Thread("CoarseMesosSchedulerBackend driver") { @@ -285,8 +285,8 @@ private[spark] class CoarseMesosSchedulerBackend( scheduler.error(message) } - override def stop() { - super.stop() + override protected def doStop() { + super.doStop() if (driver != null) { driver.stop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala 
b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 67ee4d66f151b..b7e669a73021a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -29,7 +29,7 @@ import org.apache.mesos.{Scheduler => MScheduler} import org.apache.mesos._ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} -import org.apache.spark.{Logging, SparkContext, SparkException, TaskState} +import org.apache.spark._ import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer} import org.apache.spark.util.Utils @@ -42,7 +42,7 @@ private[spark] class MesosSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, master: String) - extends SchedulerBackend + extends SchedulerBackend with Lifecycle with MScheduler with Logging { @@ -62,8 +62,9 @@ private[spark] class MesosSchedulerBackend( var classLoader: ClassLoader = null - override def start() { - synchronized { + override def conf = scheduler.conf + + override protected def doStart() { classLoader = Thread.currentThread.getContextClassLoader new Thread("MesosSchedulerBackend driver") { @@ -82,7 +83,6 @@ private[spark] class MesosSchedulerBackend( }.start() waitForRegister() - } } def createExecutorInfo(execId: String): ExecutorInfo = { @@ -311,7 +311,7 @@ private[spark] class MesosSchedulerBackend( } } - override def stop() { + override protected def doStop() { if (driver != null) { driver.stop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index bec9502f20466..8ee681a124306 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import akka.actor.{Actor, ActorRef, Props} -import org.apache.spark.{Logging, SparkEnv, TaskState} +import org.apache.spark.{Lifecycle, Logging, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} @@ -86,17 +86,19 @@ private[spark] class LocalActor( * on a single Executor (created by the LocalBackend) running locally. */ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: Int) - extends SchedulerBackend with ExecutorBackend { + extends SchedulerBackend with ExecutorBackend with Lifecycle { var localActor: ActorRef = null - override def start() { + override def conf = scheduler.conf + + override protected def doStart() { localActor = SparkEnv.get.actorSystem.actorOf( Props(new LocalActor(scheduler, this, totalCores)), "LocalBackendActor") } - override def stop() { + override protected def doStop() { localActor ! 
StopExecutor } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index cccd59d122a92..09f1e72780311 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -57,10 +57,9 @@ private[spark] class SparkUI( // Maintain executor storage status through Spark events val storageStatusListener = new StorageStatusListener - initialize() /** Initialize all components of the server. */ - def initialize() { + override def doInitialize() { listenerBus.addListener(storageStatusListener) val jobProgressTab = new JobProgressTab(this) attachTab(jobProgressTab) @@ -89,8 +88,8 @@ private[spark] class SparkUI( } /** Stop the server behind this web interface. Only valid after bind(). */ - override def stop() { - super.stop() + override protected def doStop() { + super.doStop() logInfo("Stopped Spark web UI at %s".format(appUIAddress)) } diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 5d88ca403a674..30fa8c55febb7 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -25,7 +25,7 @@ import scala.xml.Node import org.eclipse.jetty.servlet.ServletContextHandler import org.json4s.JsonAST.{JNothing, JValue} -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.{Lifecycle, Logging, SecurityManager, SparkConf} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.Utils @@ -41,7 +41,7 @@ private[spark] abstract class WebUI( conf: SparkConf, basePath: String = "", name: String = "") - extends Logging { + extends Logging with Lifecycle { protected val tabs = ArrayBuffer[WebUITab]() protected val handlers = ArrayBuffer[ServletContextHandler]() @@ -92,12 +92,10 @@ private[spark] abstract class WebUI( } } - /** Initialize all components of the server. */ - def initialize() - /** Bind to the HTTP server behind this web interface. */ def bind() { assert(!serverInfo.isDefined, "Attempted to bind %s more than once!".format(className)) + super.start() try { serverInfo = Some(startJettyServer("0.0.0.0", port, handlers, conf, name)) logInfo("Started %s at http://%s:%d".format(className, publicHostName, boundPort)) @@ -112,11 +110,12 @@ private[spark] abstract class WebUI( def boundPort: Int = serverInfo.map(_.boundPort).getOrElse(-1) /** Stop the server behind this web interface. Only valid after bind(). 
*/ - def stop() { + override protected def doStop() { assert(serverInfo.isDefined, "Attempted to stop %s before binding to a server!".format(className)) serverInfo.get.server.stop() } + override protected def doStart() { } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 1a42fc1b233ba..dee692db66b57 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -78,8 +78,9 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F val taskScheduler = new TaskScheduler() { override def rootPool: Pool = null override def schedulingMode: SchedulingMode = SchedulingMode.NONE - override def start() = {} - override def stop() = {} + override protected def doStart() = { } + override protected def doStop() = { } + override def conf: SparkConf = null override def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], blockManagerId: BlockManagerId): Boolean = true override def submitTasks(taskSet: TaskSet) = { @@ -349,10 +350,11 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F // make sure that the DAGScheduler doesn't crash when the TaskScheduler // doesn't implement killTask() val noKillTaskScheduler = new TaskScheduler() { + def conf: SparkConf = null override def rootPool: Pool = null override def schedulingMode: SchedulingMode = SchedulingMode.NONE - override def start() = {} - override def stop() = {} + override def doStart() = {} + override def doStop() = {} override def submitTasks(taskSet: TaskSet) = { taskSets += taskSet } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 7532da88c6065..92c8ced7259ef 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -23,9 +23,10 @@ import org.scalatest.FunSuite import org.apache.spark._ -class FakeSchedulerBackend extends SchedulerBackend { - def start() {} - def stop() {} +class FakeSchedulerBackend extends SchedulerBackend with Lifecycle { + def conf: SparkConf = null + def doStart() {} + def doStop() {} def reviveOffers() {} def defaultParallelism() = 1 } diff --git a/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala b/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala index 3b1880e143513..a069c0d79607a 100644 --- a/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala +++ b/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala @@ -75,11 +75,11 @@ class GangliaSink(val property: Properties, val registry: MetricRegistry, .convertRatesTo(TimeUnit.SECONDS) .build(ganglia) - override def start() { + override protected def doStart() { reporter.start(pollPeriod, pollUnit) } - override def stop() { + override protected def doStop() { reporter.stop() } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 101cec1c7a7c2..bacc34d974eb4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ 
b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -55,7 +55,7 @@ class StreamingContext private[streaming] ( sc_ : SparkContext, cp_ : Checkpoint, batchDur_ : Duration - ) extends Logging { + ) extends Logging with Lifecycle { /** * Create a StreamingContext using an existing SparkContext. @@ -122,7 +122,7 @@ class StreamingContext private[streaming] ( } } - private[streaming] val conf = sc.conf + def conf = sc.conf private[streaming] val env = SparkEnv.get @@ -164,15 +164,6 @@ class StreamingContext private[streaming] ( private val streamingSource = new StreamingSource(this) SparkEnv.get.metricsSystem.registerSource(streamingSource) - /** Enumeration to identify current state of the StreamingContext */ - private[streaming] object StreamingContextState extends Enumeration { - type CheckpointState = Value - val Initialized, Started, Stopped = Value - } - - import StreamingContextState._ - private[streaming] var state = Initialized - /** * Return the associated Spark context */ @@ -431,18 +422,18 @@ class StreamingContext private[streaming] ( /** * Start the execution of the streams. */ - def start(): Unit = synchronized { + override def start() { // Throw exception if the context has already been started once // or if a stopped context is being started again - if (state == Started) { - throw new SparkException("StreamingContext has already been started") - } - if (state == Stopped) { + if (stopped) { throw new SparkException("StreamingContext has already been stopped") } + super.start() + } + + override protected def doStart() { validate() scheduler.start() - state = Started } /** @@ -462,13 +453,16 @@ class StreamingContext private[streaming] ( waiter.waitForStopOrError(timeout) } + override def stop() { + stop(true) + } /** * Stop the execution of the streams immediately (does not wait for all received data * to be processed). 
* @param stopSparkContext Stop the associated SparkContext or not * */ - def stop(stopSparkContext: Boolean = true): Unit = synchronized { + def stop(stopSparkContext: Boolean): Unit = synchronized { stop(stopSparkContext, false) } @@ -482,11 +476,11 @@ class StreamingContext private[streaming] ( def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = synchronized { // Warn (but not fail) if context is stopped twice, // or context is stopped before starting - if (state == Initialized) { + if (uninitialized || initialized) { logWarning("StreamingContext has not been started yet") return } - if (state == Stopped) { + if (stopped) { logWarning("StreamingContext has already been stopped") return } // no need to throw an exception as its okay to stop twice @@ -494,8 +488,9 @@ class StreamingContext private[streaming] ( logInfo("StreamingContext stopped successfully") waiter.notifyStop() if (stopSparkContext) sc.stop() - state = Stopped + super.stop() } + override protected def doStop() { } } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 18605cac7006c..7e4e4d1d87bee 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -27,7 +27,7 @@ import java.util.{List => JList, Map => JMap} import akka.actor.{Props, SupervisorStrategy} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{Lifecycle, SparkConf, SparkContext} import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} import org.apache.spark.rdd.RDD @@ -49,7 +49,7 @@ import org.apache.spark.streaming.receiver.Receiver * respectively. `context.awaitTransformation()` allows the current thread to wait for the * termination of a context by `stop()` or by an exception. */ -class JavaStreamingContext(val ssc: StreamingContext) { +class JavaStreamingContext(val ssc: StreamingContext) extends Lifecycle { /** * Create a StreamingContext. @@ -148,6 +148,8 @@ class JavaStreamingContext(val ssc: StreamingContext) { @deprecated("use sparkContext", "0.9.0") val sc: JavaSparkContext = sparkContext + override def conf = sparkContext.conf + /** * Create an input stream from network source hostname:port. Data is received using * a TCP socket and the receive bytes is interpreted as UTF8 encoded \n delimited @@ -497,7 +499,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Start the execution of the streams. */ - def start(): Unit = { + override protected def doStart(): Unit = { ssc.start() } @@ -521,15 +523,17 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Stop the execution of the streams. Will stop the associated JavaSparkContext as well. */ - def stop(): Unit = { - ssc.stop() + override def stop(): Unit = { + stop(true) } /** * Stop the execution of the streams. * @param stopSparkContext Stop the associated SparkContext or not */ - def stop(stopSparkContext: Boolean) = ssc.stop(stopSparkContext) + def stop(stopSparkContext: Boolean): Unit = { + stop(stopSparkContext,false) + } /** * Stop the execution of the streams. 
@@ -538,8 +542,13 @@ class JavaStreamingContext(val ssc: StreamingContext) { * received data to be completed */ def stop(stopSparkContext: Boolean, stopGracefully: Boolean) = { - ssc.stop(stopSparkContext, stopGracefully) + if (ssc.started) { + ssc.stop(stopSparkContext, stopGracefully) + super.stop() + } } + + override protected def doStop() { } } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 374848358e700..f63479b89473e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.scheduler import akka.actor.{ActorRef, ActorSystem, Props, Actor} -import org.apache.spark.{SparkException, SparkEnv, Logging} +import org.apache.spark.{Lifecycle, SparkException, SparkEnv, Logging} import org.apache.spark.streaming.{Checkpoint, Time, CheckpointWriter} import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock} import scala.util.{Failure, Success, Try} @@ -35,10 +35,9 @@ private[scheduler] case class ClearCheckpointData(time: Time) extends JobGenerat * up DStream metadata. */ private[streaming] -class JobGenerator(jobScheduler: JobScheduler) extends Logging { +class JobGenerator(jobScheduler: JobScheduler) extends Logging with Lifecycle { private val ssc = jobScheduler.ssc - private val conf = ssc.conf private val graph = ssc.graph val clock = { @@ -67,8 +66,10 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // last batch whose completion,checkpointing and metadata cleanup has been completed private var lastProcessedBatch: Time = null + def conf = ssc.conf + /** Start generation of jobs */ - def start(): Unit = synchronized { + override protected def doStart(): Unit = { if (eventActor != null) return // generator has already been started eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { @@ -88,6 +89,12 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { * of current ongoing time interval has been generated, processed and corresponding * checkpoints written. 
*/ + override def stop() { + stop(true) + } + + override def doStop() { } + def stop(processReceivedData: Boolean): Unit = synchronized { if (eventActor == null) return // generator has already been stopped @@ -141,6 +148,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // Stop the actor and checkpoint writer if (shouldCheckpoint) checkpointWriter.stop() ssc.env.actorSystem.stop(eventActor) + super.stop() logInfo("Stopped JobGenerator") } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 1b034b9fb187c..7463d7191bf5a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -21,7 +21,7 @@ import scala.util.{Failure, Success, Try} import scala.collection.JavaConversions._ import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors} import akka.actor.{ActorRef, Actor, Props} -import org.apache.spark.{SparkException, Logging, SparkEnv} +import org.apache.spark.{Lifecycle, SparkException, Logging, SparkEnv} import org.apache.spark.streaming._ @@ -35,22 +35,23 @@ private[scheduler] case class ErrorReported(msg: String, e: Throwable) extends J * the jobs and runs them using a thread pool. */ private[streaming] -class JobScheduler(val ssc: StreamingContext) extends Logging { +class JobScheduler(val ssc: StreamingContext) extends Logging with Lifecycle { private val jobSets = new ConcurrentHashMap[Time, JobSet] private val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1) private val jobExecutor = Executors.newFixedThreadPool(numConcurrentJobs) private val jobGenerator = new JobGenerator(this) val clock = jobGenerator.clock - val listenerBus = new StreamingListenerBus() + val listenerBus = new StreamingListenerBus(ssc.conf) // These two are created only when scheduler starts. 
// eventActor not being null means the scheduler has been started and not stopped var receiverTracker: ReceiverTracker = null private var eventActor: ActorRef = null + def conf = ssc.conf - def start(): Unit = synchronized { + override protected def doStart(): Unit = { if (eventActor != null) return // scheduler has already been started logDebug("Starting JobScheduler") @@ -67,6 +68,12 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { logInfo("Started JobScheduler") } + override protected def doStop() { } + + override def stop() { + stop(true) + } + def stop(processAllReceivedData: Boolean): Unit = synchronized { if (eventActor == null) return // scheduler has already been stopped logDebug("Stopping JobScheduler") @@ -97,6 +104,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { listenerBus.stop() ssc.env.actorSystem.stop(eventActor) eventActor = null + super.stop() logInfo("Stopped JobScheduler") } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 5307fe189d717..fbfb5801e3331 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.{HashMap, SynchronizedMap, SynchronizedQueue} import scala.language.existentials import akka.actor._ -import org.apache.spark.{Logging, SparkEnv, SparkException} +import org.apache.spark.{Lifecycle, Logging, SparkEnv, SparkException} import org.apache.spark.SparkContext._ import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.{StreamingContext, Time} @@ -59,7 +59,7 @@ private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, err * has been called because it needs the final set of input streams at the time of instantiation. */ private[streaming] -class ReceiverTracker(ssc: StreamingContext) extends Logging { +class ReceiverTracker(ssc: StreamingContext) extends Logging with Lifecycle { val receiverInputStreams = ssc.graph.getReceiverInputStreams() val receiverInputStreamMap = Map(receiverInputStreams.map(x => (x.id, x)): _*) @@ -75,8 +75,10 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { var actor: ActorRef = null var currentTime: Time = null + override def conf = ssc.conf + /** Start the actor and receiver execution thread. */ - def start() = synchronized { + override protected def doStart() = { if (actor != null) { throw new SparkException("ReceiverTracker already started") } @@ -90,7 +92,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { } /** Stop the receiver execution thread. */ - def stop() = synchronized { + override protected def doStop() = { if (!receiverInputStreams.isEmpty && actor != null) { // First, stop the receivers receiverExecutor.stop() @@ -197,7 +199,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { } /** This thread class runs all the receivers on the cluster. 
*/ - class ReceiverLauncher { + class ReceiverLauncher extends Lifecycle { @transient val env = ssc.env @transient val thread = new Thread() { override def run() { @@ -210,11 +212,13 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { } } - def start() { + override def conf = ssc.conf + + override protected def doStart() { thread.start() } - def stop() { + override protected def doStop() { // Send the stop signal to all the receivers stopReceivers() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala index 398724d9e8130..1934ef645a422 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala @@ -17,12 +17,12 @@ package org.apache.spark.streaming.scheduler -import org.apache.spark.Logging +import org.apache.spark.{Lifecycle, SparkConf, Logging} import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} import java.util.concurrent.LinkedBlockingQueue /** Asynchronously passes StreamingListenerEvents to registered StreamingListeners. */ -private[spark] class StreamingListenerBus() extends Logging { +private[spark] class StreamingListenerBus(val conf: SparkConf) extends Logging with Lifecycle { private val listeners = new ArrayBuffer[StreamingListener]() with SynchronizedBuffer[StreamingListener] @@ -59,7 +59,7 @@ private[spark] class StreamingListenerBus() extends Logging { } } - def start() { + override def doStart() { listenerThread.start() } @@ -95,5 +95,5 @@ private[spark] class StreamingListenerBus() extends Logging { true } - def stop(): Unit = post(StreamingListenerShutdown) + override protected def doStop(): Unit = post(StreamingListenerShutdown) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 7b33d3b235466..0b79e96f38d0a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming import java.util.concurrent.atomic.AtomicInteger -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException} +import org.apache.spark._ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver @@ -108,11 +108,11 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w ssc = new StreamingContext(master, appName, batchDuration) addInputStream(ssc).register - assert(ssc.state === ssc.StreamingContextState.Initialized) + assert(ssc.uninitialized) ssc.start() - assert(ssc.state === ssc.StreamingContextState.Started) + assert(ssc.started) ssc.stop() - assert(ssc.state === ssc.StreamingContextState.Stopped) + assert(ssc.stopped) } test("start multiple times") { diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index a5f537dd9de30..09e5e107cf1b7 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ 
b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -49,8 +49,8 @@ private[spark] class YarnClientSchedulerBackend( } } - override def start() { - super.start() + override protected def doStart() { + super.doStart() val driverHost = conf.get("spark.driver.host") val driverPort = conf.get("spark.driver.port") @@ -140,9 +140,9 @@ private[spark] class YarnClientSchedulerBackend( t } - override def stop() { + override protected def doStop() { stopping = true - super.stop() + super.doStop() client.stop logInfo("Stopped") }
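
For reference, a minimal, self-contained sketch (not part of the patch; "DummyService" and "DummyServiceExample" are made-up names used only for illustration) of how a component adopts the Lifecycle trait introduced here: the component supplies `conf` plus the protected `doStart()`/`doStop()` hooks, and inherits the Uninitialized -> Initialized -> Started -> Stopped state machine, which throws a SparkException on an illegal transition.

import org.apache.spark.{Lifecycle, SparkConf}

// Illustrative only -- this class does not exist in the patch.
class DummyService(val conf: SparkConf) extends Lifecycle {

  // Optional hook: runs once, on the Uninitialized -> Initialized transition.
  override protected def doInitialize(): Unit = {
    // validate configuration, allocate cheap resources, ...
  }

  // Required hook: runs on the Initialized -> Started transition.
  override protected def doStart(): Unit = {
    // start threads / actors here
  }

  // Required hook: runs on the Started -> Stopped transition.
  override protected def doStop(): Unit = {
    // release resources here
  }
}

object DummyServiceExample {
  def main(args: Array[String]): Unit = {
    val svc = new DummyService(new SparkConf(false))
    assert(svc.uninitialized)   // State.Uninitialized after construction
    svc.start()                 // start() calls initialize() first if still uninitialized
    assert(svc.started)
    svc.stop()                  // close() is also wired to stop() via java.io.Closeable
    assert(svc.stopped)
    // A second svc.stop() would throw SparkException:
    //   "Can't move to stopped state when Stopped"
  }
}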