diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 1a2443f7ee78..ed01c2e137f6 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -236,6 +236,8 @@ class SparkContext(config: SparkConf) extends Logging { def appName: String = _conf.get("spark.app.name") private[spark] def isEventLogEnabled: Boolean = _conf.getBoolean("spark.eventLog.enabled", false) + private[spark] def isEventLogAsync: Boolean = _conf.getBoolean("spark.eventLog.async", false) + private[spark] def eventLogDir: Option[URI] = _eventLogDir private[spark] def eventLogCodec: Option[String] = _eventLogCodec @@ -525,9 +527,7 @@ class SparkContext(config: SparkConf) extends Logging { _eventLogger = if (isEventLogEnabled) { - val logger = - new EventLoggingListener(_applicationId, _applicationAttemptId, _eventLogDir.get, - _conf, _hadoopConfiguration) + val logger = getEventLogger(isEventLogAsync) logger.start() listenerBus.addListener(logger) Some(logger) @@ -593,6 +593,22 @@ class SparkContext(config: SparkConf) extends Logging { } } + private def getEventLogger(async: Boolean): EventLoggingListener = { + if (async) { + val queueSize = _conf.get(LISTENER_BUS_EVENT_QUEUE_CAPACITY) + new AsynchronousEventLoggingListener(_applicationId, + _applicationAttemptId, + _eventLogDir.get, + _conf, + _hadoopConfiguration, + queueSize) + } + else { + new EventLoggingListener(_applicationId, _applicationAttemptId, _eventLogDir.get, + _conf, _hadoopConfiguration) + } + } + /** * Called by the web UI to obtain executor thread dumps. This method may be expensive. * Logs an error and returns None if we failed to obtain a thread dump, which could occur due diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 4ad04b04c312..fcbdb8dd86fb 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -158,6 +158,11 @@ package object config { .checkValue(_ > 0, "The capacity of listener bus event queue must not be negative") .createWithDefault(10000) + private[spark] val LISTENER_BUS_EVENT_QUEUE_DROP = + ConfigBuilder("spark.scheduler.listenerbus.eventqueue.drop") + .booleanConf + .createWithDefault(true) + // This property sets the root namespace for metrics reporting private[spark] val METRICS_NAMESPACE = ConfigBuilder("spark.metrics.namespace") .stringConf diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index f48143633224..3f5553f82f6d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -21,6 +21,7 @@ import java.io._ import java.net.URI import java.nio.charset.StandardCharsets import java.util.Locale +import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -131,20 +132,24 @@ private[spark] class EventLoggingListener( } /** Log the event as JSON. 
*/ - private def logEvent(event: SparkListenerEvent, flushLogger: Boolean = false) { + protected def logEvent(event: SparkListenerEvent, flushLogger: Boolean = false) { val eventJson = JsonProtocol.sparkEventToJson(event) // scalastyle:off println writer.foreach(_.println(compact(render(eventJson)))) // scalastyle:on println if (flushLogger) { - writer.foreach(_.flush()) - hadoopDataStream.foreach(_.hflush()) + flush() } if (testing) { loggedEvents += eventJson } } + private def flush(): Unit = { + writer.foreach(_.flush()) + hadoopDataStream.foreach(_.hflush()) + } + // Events that do not trigger a flush override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = logEvent(event) @@ -227,6 +232,7 @@ private[spark] class EventLoggingListener( * ".inprogress" suffix. */ def stop(): Unit = { + flush() writer.foreach(_.close()) val target = new Path(logPath) @@ -250,6 +256,8 @@ private[spark] class EventLoggingListener( } } + + private[spark] def redactEvent( event: SparkListenerEnvironmentUpdate): SparkListenerEnvironmentUpdate = { // environmentDetails maps a string descriptor to a set of properties @@ -267,11 +275,89 @@ private[spark] class EventLoggingListener( } +private[spark] sealed class AsynchronousEventLoggingListener( + appId: String, + appAttemptId : Option[String], + logBaseDir: URI, + sparkConf: SparkConf, + hadoopConf: Configuration, + val bufferSize: Int) + extends EventLoggingListener(appId, appAttemptId, logBaseDir, sparkConf, hadoopConf) { + import EventLoggingListener._ + + private lazy val eventBuffer = new Array[SparkListenerEvent](bufferSize) + + private val numberOfEvents = new AtomicInteger(0) + + @volatile private var writeIndex = 0 + @volatile private var readIndex = 0 + @volatile private var stopThread = false + @volatile private var lastReportTimestamp = 0L + @volatile private var numberOfDrop = 0 + @volatile private var lastFlushEvent = 0 + + private val listenerThread = new Thread(THREAD_NAME) { + setDaemon(true) + override def run(): Unit = { + while (!stopThread || numberOfEvents.get() > 0) { + if (numberOfEvents.get() > 0) { + executelogEvent(eventBuffer(readIndex), lastFlushEvent == FLUSH_FREQUENCY) + numberOfEvents.decrementAndGet() + readIndex = (readIndex + 1) % bufferSize + if (lastFlushEvent == FLUSH_FREQUENCY) { + lastFlushEvent = 0 + } else { + lastFlushEvent = lastFlushEvent + 1 + } + } else { + Thread.sleep(20) // give more chance for producer thread to be scheduled + } + } + } + } + + private def executelogEvent(event: SparkListenerEvent, flushLogger: Boolean) = + super.logEvent(event, flushLogger) + + override protected def logEvent(event: SparkListenerEvent, flushLogger: Boolean): Unit = { + if (numberOfEvents.get() < bufferSize) { + eventBuffer(writeIndex) = event + numberOfEvents.incrementAndGet() + writeIndex = (writeIndex + 1) % bufferSize + } else { + numberOfDrop = numberOfDrop + 1 + if (System.currentTimeMillis() - lastReportTimestamp >= 60 * 1000) { + val prevLastReportTimestamp = lastReportTimestamp + lastReportTimestamp = System.currentTimeMillis() + logWarning( + s"dropped $numberOfDrop SparkListenerEvents since " + + new java.util.Date(prevLastReportTimestamp)) + numberOfDrop = 0 + } + } + } + + override def start(): Unit = { + super.start() + listenerThread.start() + } + + override def stop(): Unit = { + stopThread = true + listenerThread.join() + super.stop() + } + +} + private[spark] object EventLoggingListener extends Logging { // Suffix applied to the names of files still being written by applications. 
val IN_PROGRESS = ".inprogress" val DEFAULT_LOG_DIR = "/tmp/spark-events" + val THREAD_NAME = "EventLoggingListener" + val FLUSH_FREQUENCY = 200 + private val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort) // A cache for compression codecs to avoid creating the same codec many times diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 801dfaa62306..14b672fe8412 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -17,12 +17,14 @@ package org.apache.spark.scheduler -import java.util.concurrent._ -import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} +import java.util.concurrent.TimeoutException +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} +import java.util.concurrent.locks.ReentrantLock import scala.util.DynamicVariable import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.util.Utils @@ -38,55 +40,45 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa import LiveListenerBus._ - // Cap the capacity of the event queue so we get an explicit error (rather than - // an OOM exception) if it's perpetually being added to more quickly than it's being drained. - private lazy val eventQueue = new LinkedBlockingQueue[SparkListenerEvent]( - sparkContext.conf.get(LISTENER_BUS_EVENT_QUEUE_CAPACITY)) + private lazy val BUFFER_SIZE = sparkContext.conf.get(LISTENER_BUS_EVENT_QUEUE_CAPACITY) + private lazy val circularBuffer = new Array[SparkListenerEvent](BUFFER_SIZE) - // Indicate if `start()` is called - private val started = new AtomicBoolean(false) - // Indicate if `stop()` is called - private val stopped = new AtomicBoolean(false) + private lazy val queueStrategy = getQueueStrategy - /** A counter for dropped events. It will be reset every time we log it. */ - private val droppedEventsCounter = new AtomicLong(0L) - /** When `droppedEventsCounter` was logged last time in milliseconds. 
*/ - @volatile private var lastReportTimestamp = 0L + private def getQueueStrategy: QueuingStrategy = { + val queueDrop = sparkContext.conf.get(LISTENER_BUS_EVENT_QUEUE_DROP) + if (queueDrop) { + new DropQueuingStrategy(BUFFER_SIZE) + } else { + new WaitQueuingStrategy(BUFFER_SIZE) + } + } + + private val numberOfEvents = new AtomicInteger(0) - // Indicate if we are processing some event - // Guarded by `self` - private var processingEvent = false + @volatile private var writeIndex = 0 + @volatile private var readIndex = 0 - private val logDroppedEvent = new AtomicBoolean(false) + // Indicate if `start()` is called + private val started = new AtomicBoolean(false) + // Indicate if `stop()` is called + private val stopped = new AtomicBoolean(false) - // A counter that represents the number of events produced and consumed in the queue - private val eventLock = new Semaphore(0) + // only post is done from multiple threads so need a lock + private val postLock = new ReentrantLock() private val listenerThread = new Thread(name) { setDaemon(true) override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) { LiveListenerBus.withinListenerThread.withValue(true) { - while (true) { - eventLock.acquire() - self.synchronized { - processingEvent = true - } - try { - val event = eventQueue.poll - if (event == null) { - // Get out of the while loop and shutdown the daemon thread - if (!stopped.get) { - throw new IllegalStateException("Polling `null` from eventQueue means" + - " the listener bus has been stopped. So `stopped` must be true") - } - return - } - postToAll(event) - } finally { - self.synchronized { - processingEvent = false - } + while (!stopped.get() || numberOfEvents.get() > 0) { + if (numberOfEvents.get() > 0) { + postToAll(circularBuffer(readIndex)) + numberOfEvents.decrementAndGet() + readIndex = (readIndex + 1) % BUFFER_SIZE + } else { + Thread.sleep(20) // give more chance for producer thread to be scheduled } } } @@ -115,30 +107,14 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa logError(s"$name has already stopped! Dropping event $event") return } - val eventAdded = eventQueue.offer(event) - if (eventAdded) { - eventLock.release() - } else { - onDropEvent(event) - droppedEventsCounter.incrementAndGet() - } - - val droppedEvents = droppedEventsCounter.get - if (droppedEvents > 0) { - // Don't log too frequently - if (System.currentTimeMillis() - lastReportTimestamp >= 60 * 1000) { - // There may be multiple threads trying to decrease droppedEventsCounter. - // Use "compareAndSet" to make sure only one thread can win. - // And if another thread is increasing droppedEventsCounter, "compareAndSet" will fail and - // then that thread will update it. - if (droppedEventsCounter.compareAndSet(droppedEvents, 0)) { - val prevLastReportTimestamp = lastReportTimestamp - lastReportTimestamp = System.currentTimeMillis() - logWarning(s"Dropped $droppedEvents SparkListenerEvents since " + - new java.util.Date(prevLastReportTimestamp)) - } - } + postLock.lock() + val queueOrNot = queueStrategy.queue(numberOfEvents) + if(queueOrNot) { + circularBuffer(writeIndex) = event + numberOfEvents.incrementAndGet() + writeIndex = (writeIndex + 1) % BUFFER_SIZE } + postLock.unlock() } /** @@ -170,10 +146,15 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa /** * Return whether the event queue is empty. 
* - * The use of synchronized here guarantees that all events that once belonged to this queue + * The use of the post lock here guarantees that all events that once belonged to this queue * have already been processed by all attached listeners, if this returns true. */ - private def queueIsEmpty: Boolean = synchronized { eventQueue.isEmpty && !processingEvent } + private def queueIsEmpty: Boolean = { + postLock.lock() + val isEmpty = numberOfEvents.get() == 0 + postLock.unlock() + isEmpty + } /** * Stop the listener bus. It will wait until the queued events have been processed, but drop the @@ -183,30 +164,10 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa if (!started.get()) { throw new IllegalStateException(s"Attempted to stop $name that has not yet started!") } - if (stopped.compareAndSet(false, true)) { - // Call eventLock.release() so that listenerThread will poll `null` from `eventQueue` and know - // `stop` is called. - eventLock.release() - listenerThread.join() - } else { - // Keep quiet - } + stopped.set(true) + listenerThread.join() } - /** - * If the event queue exceeds its capacity, the new events will be dropped. The subclasses will be - * notified with the dropped events. - * - * Note: `onDropEvent` can be called in any thread. - */ - def onDropEvent(event: SparkListenerEvent): Unit = { - if (logDroppedEvent.compareAndSet(false, true)) { - // Only log the following message once to avoid duplicated annoying logs. - logError("Dropping SparkListenerEvent because no remaining room in event queue. " + - "This likely means one of the SparkListeners is too slow and cannot keep up with " + - "the rate at which tasks are being started by the scheduler.") - } - } } private[spark] object LiveListenerBus { @@ -215,5 +176,86 @@ private[spark] object LiveListenerBus { /** The thread name of Spark listener bus */ val name = "SparkListenerBus" + + private trait FirstAndRecurrentLogging extends Logging { + + @volatile private var numberOfTime = 0 + /** When `numberOfTime` was logged last time in milliseconds. */ + @volatile private var lastReportTimestamp = 0L + @volatile private var logFirstTime = false + + + def inc(): Unit = { + numberOfTime = numberOfTime + 1 + } + + def waringIfNotToClose(message: Int => String): Unit = { + if (numberOfTime > 0 && + (System.currentTimeMillis() - lastReportTimestamp >= 60 * 1000)) { + val prevLastReportTimestamp = lastReportTimestamp + lastReportTimestamp = System.currentTimeMillis() + logWarning(s"${message(numberOfTime)} SparkListenerEvents since " + + new java.util.Date(prevLastReportTimestamp)) + numberOfTime = 0 + } + } + + def errorIfFirstTime(firstTimeAction: String): Unit = { + if (!logFirstTime) { + // Only log the following message once to avoid duplicated annoying logs. + logError(s"$firstTimeAction SparkListenerEvent because no remaining room in event" + + " queue. 
" + + "This likely means one of the SparkListeners is too slow and cannot keep up with " + + "the rate at which tasks are being started by the scheduler.") + logFirstTime = true + lastReportTimestamp = System.currentTimeMillis() + } + } + + } + + private trait QueuingStrategy { + /** + * this method indicate if an element should be queued or discarded + * @param numberOfEvents atomic integer: the queue size + * @return true if an element should be queued, false if it should be dropped + */ + def queue(numberOfEvents: AtomicInteger): Boolean + + } + + private class DropQueuingStrategy(val bufferSize: Int) + extends QueuingStrategy with FirstAndRecurrentLogging { + + override def queue(numberOfEvents: AtomicInteger): Boolean = { + if (numberOfEvents.get() == bufferSize) { + errorIfFirstTime("Dropping") + inc() + waringIfNotToClose(count => s"Dropped $count") + false + } else { + true + } + } + + } + + private class WaitQueuingStrategy(val bufferSize: Int) + extends QueuingStrategy with FirstAndRecurrentLogging { + + override def queue(numberOfEvents: AtomicInteger): Boolean = { + if (numberOfEvents.get() == bufferSize) { + errorIfFirstTime("Waiting for posting") + waringIfNotToClose(count => s"Waiting $count period posting") + while (numberOfEvents.get() == bufferSize) { + inc() + Thread.sleep(20) // give more chance for consumer thread to be scheduled + } + } + true + } + + } + } diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index fa5ad4e8d81e..b7af077988ec 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -17,9 +17,8 @@ package org.apache.spark.util -import java.util.concurrent.CopyOnWriteArrayList +import java.util.concurrent.atomic.AtomicReference -import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -31,13 +30,20 @@ import org.apache.spark.internal.Logging private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { // Marked `private[spark]` for access in tests. - private[spark] val listeners = new CopyOnWriteArrayList[L] + private[spark] def listeners = internalHolder.get() + private val internalHolder = new AtomicReference[Array[L]](Array.empty.asInstanceOf[Array[L]]) /** * Add a listener to listen events. This method is thread-safe and can be called in any thread. */ final def addListener(listener: L): Unit = { - listeners.add(listener) + + var oldVal, candidate: Array[L] = null + do { + oldVal = listeners + candidate = oldVal.:+(listener)(oldVal.elemTag) + // This creates a new array so we can compare reference + } while (!internalHolder.compareAndSet(oldVal, candidate)) } /** @@ -45,7 +51,11 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { * in any thread. */ final def removeListener(listener: L): Unit = { - listeners.remove(listener) + var oldVal, candidate: Array[L] = null + do { + oldVal = listeners + candidate = oldVal.filter(l => !l.equals(listener)) + } while (!internalHolder.compareAndSet(oldVal, candidate)) } /** @@ -53,18 +63,16 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { * `postToAll` in the same thread for all events. */ def postToAll(event: E): Unit = { - // JavaConverters can create a JIterableWrapper if we use asScala. - // However, this method will be called frequently. To avoid the wrapper cost, here we use - // Java Iterator directly. 
- val iter = listeners.iterator - while (iter.hasNext) { - val listener = iter.next() + val currentVal = listeners + var i = 0 + while(i < currentVal.length) { try { - doPostEvent(listener, event) + doPostEvent(currentVal(i), event) } catch { case NonFatal(e) => - logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) + logError(s"Listener ${Utils.getFormattedClassName(currentVal(i))} threw an exception", e) } + i = i + 1 } } @@ -76,7 +84,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { private[spark] def findListenersByClass[T <: L : ClassTag](): Seq[T] = { val c = implicitly[ClassTag[T]].runtimeClass - listeners.asScala.filter(_.getClass == c).map(_.asInstanceOf[T]).toSeq + listeners.filter(_.getClass == c).map(_.asInstanceOf[T]) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 4c3d0b102152..0a86dfc9bc33 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -85,13 +85,15 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit } } - test("End-to-end event logging") { - testApplicationEventLogging() - } + Seq(false, true).foreach{ async => + test((if (async) "Async " else "") + "End-to-end event logging") { + testApplicationEventLogging(None, async) + } - test("End-to-end event logging with compression") { - CompressionCodec.ALL_COMPRESSION_CODECS.foreach { codec => - testApplicationEventLogging(compressionCodec = Some(CompressionCodec.getShortName(codec))) + test((if (async) "Async " else "") + "End-to-end event logging with compression") { + CompressionCodec.ALL_COMPRESSION_CODECS.foreach { codec => + testApplicationEventLogging(Some(CompressionCodec.getShortName(codec)), async) + } } } @@ -189,11 +191,12 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit * Test end-to-end event logging functionality in an application. * This runs a simple Spark job and asserts that the expected events are logged when expected. */ - private def testApplicationEventLogging(compressionCodec: Option[String] = None) { + private def testApplicationEventLogging(compressionCodec: Option[String], asynchronous: Boolean) { // Set defaultFS to something that would cause an exception, to make sure we don't run // into SPARK-6688. 
val conf = getLoggingConf(testDirPath, compressionCodec) .set("spark.hadoop.fs.defaultFS", "unsupported://example.com") + .set("spark.eventLog.async", asynchronous.toString) sc = new SparkContext("local-cluster[2,2,1024]", "test", conf) assert(sc.eventLogger.isDefined) val eventLogger = sc.eventLogger.get diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 80c7e0bfee6e..a01999745a2a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -36,357 +36,378 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val jobCompletionTime = 1421191296660L - test("don't call sc.stop in listener") { - sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) - val listener = new SparkContextStoppingListener(sc) - val bus = new LiveListenerBus(sc) - bus.addListener(listener) - - // Starting listener bus should flush all buffered events - bus.start() - bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) - - bus.stop() - assert(listener.sparkExSeen) - } - - test("basic creation and shutdown of LiveListenerBus") { - sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) - val counter = new BasicJobCounter - val bus = new LiveListenerBus(sc) - bus.addListener(counter) + val queuingStrategy = List("true", "false") - // Listener bus hasn't started yet, so posting events should not increment counter - (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } - assert(counter.count === 0) + queuingStrategy.foreach(qS => { - // Starting listener bus should flush all buffered events - bus.start() - bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) - assert(counter.count === 5) + def getConf: SparkConf = { + val conf = new SparkConf() + conf.set("spark.scheduler.listenerbus.eventqueue.drop", qS) + conf + } - // After listener bus has stopped, posting events should not increment counter - bus.stop() - (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } - assert(counter.count === 5) + def name(testName: String): String = { + if (qS == "true") { + s"dropping event bus: $testName" + } else { + s"Waiting event bus: $testName" + } + } - // Listener bus must not be started twice - intercept[IllegalStateException] { + test(name("don't call sc.stop in listener")) { + sc = new SparkContext("local", "SparkListenerSuite", getConf) + val listener = new SparkContextStoppingListener(sc) val bus = new LiveListenerBus(sc) + bus.addListener(listener) + + // Starting listener bus should flush all buffered events bus.start() - bus.start() - } + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) - // ... 
or stopped before starting - intercept[IllegalStateException] { - val bus = new LiveListenerBus(sc) bus.stop() + assert(listener.sparkExSeen) } - } - - test("bus.stop() waits for the event queue to completely drain") { - @volatile var drained = false - - // When Listener has started - val listenerStarted = new Semaphore(0) - // Tells the listener to stop blocking - val listenerWait = new Semaphore(0) + test(name("basic creation and shutdown of LiveListenerBus")) { + sc = new SparkContext("local", "SparkListenerSuite", getConf) + val counter = new BasicJobCounter + val bus = new LiveListenerBus(sc) + bus.addListener(counter) - // When stopper has started - val stopperStarted = new Semaphore(0) + // Listener bus hasn't started yet, so posting events should not increment counter + (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } + assert(counter.count === 0) - // When stopper has returned - val stopperReturned = new Semaphore(0) + // Starting listener bus should flush all buffered events + bus.start() + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(counter.count === 5) - class BlockingListener extends SparkListener { - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { - listenerStarted.release() - listenerWait.acquire() - drained = true + // After listener bus has stopped, posting events should not increment counter + bus.stop() + (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } + assert(counter.count === 5) + + // Listener bus must not be started twice + intercept[IllegalStateException] { + val bus = new LiveListenerBus(sc) + bus.start() + bus.start() } - } - sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) - val bus = new LiveListenerBus(sc) - val blockingListener = new BlockingListener - - bus.addListener(blockingListener) - bus.start() - bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - - listenerStarted.acquire() - // Listener should be blocked after start - assert(!drained) - - new Thread("ListenerBusStopper") { - override def run() { - stopperStarted.release() - // stop() will block until notify() is called below + + // ... 
or stopped before starting + intercept[IllegalStateException] { + val bus = new LiveListenerBus(sc) bus.stop() - stopperReturned.release() } - }.start() + } - stopperStarted.acquire() - // Listener should remain blocked after stopper started - assert(!drained) + test(name("bus.stop() waits for the event queue to completely drain")) { + @volatile var drained = false - // unblock Listener to let queue drain - listenerWait.release() - stopperReturned.acquire() - assert(drained) - } + // When Listener has started + val listenerStarted = new Semaphore(0) - test("basic creation of StageInfo") { - sc = new SparkContext("local", "SparkListenerSuite") - val listener = new SaveStageAndTaskInfo - sc.addSparkListener(listener) - val rdd1 = sc.parallelize(1 to 100, 4) - val rdd2 = rdd1.map(_.toString) - rdd2.setName("Target RDD") - rdd2.count() - - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) - - listener.stageInfos.size should be {1} - val (stageInfo, taskInfoMetrics) = listener.stageInfos.head - stageInfo.rddInfos.size should be {2} - stageInfo.rddInfos.forall(_.numPartitions == 4) should be {true} - stageInfo.rddInfos.exists(_.name == "Target RDD") should be {true} - stageInfo.numTasks should be {4} - stageInfo.submissionTime should be ('defined) - stageInfo.completionTime should be ('defined) - taskInfoMetrics.length should be {4} - } + // Tells the listener to stop blocking + val listenerWait = new Semaphore(0) - test("basic creation of StageInfo with shuffle") { - sc = new SparkContext("local", "SparkListenerSuite") - val listener = new SaveStageAndTaskInfo - sc.addSparkListener(listener) - val rdd1 = sc.parallelize(1 to 100, 4) - val rdd2 = rdd1.filter(_ % 2 == 0).map(i => (i, i)) - val rdd3 = rdd2.reduceByKey(_ + _) - rdd1.setName("Un") - rdd2.setName("Deux") - rdd3.setName("Trois") - - rdd1.count() - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) - listener.stageInfos.size should be {1} - val stageInfo1 = listener.stageInfos.keys.find(_.stageId == 0).get - stageInfo1.rddInfos.size should be {1} // ParallelCollectionRDD - stageInfo1.rddInfos.forall(_.numPartitions == 4) should be {true} - stageInfo1.rddInfos.exists(_.name == "Un") should be {true} - listener.stageInfos.clear() - - rdd2.count() - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) - listener.stageInfos.size should be {1} - val stageInfo2 = listener.stageInfos.keys.find(_.stageId == 1).get - stageInfo2.rddInfos.size should be {3} - stageInfo2.rddInfos.forall(_.numPartitions == 4) should be {true} - stageInfo2.rddInfos.exists(_.name == "Deux") should be {true} - listener.stageInfos.clear() - - rdd3.count() - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) - listener.stageInfos.size should be {2} // Shuffle map stage + result stage - val stageInfo3 = listener.stageInfos.keys.find(_.stageId == 3).get - stageInfo3.rddInfos.size should be {1} // ShuffledRDD - stageInfo3.rddInfos.forall(_.numPartitions == 4) should be {true} - stageInfo3.rddInfos.exists(_.name == "Trois") should be {true} - } + // When stopper has started + val stopperStarted = new Semaphore(0) - test("StageInfo with fewer tasks than partitions") { - sc = new SparkContext("local", "SparkListenerSuite") - val listener = new SaveStageAndTaskInfo - sc.addSparkListener(listener) - val rdd1 = sc.parallelize(1 to 100, 4) - val rdd2 = rdd1.map(_.toString) - sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1)) - - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) - - listener.stageInfos.size should be {1} - val (stageInfo, _) = 
listener.stageInfos.head - stageInfo.numTasks should be {2} - stageInfo.rddInfos.size should be {2} - stageInfo.rddInfos.forall(_.numPartitions == 4) should be {true} - } + // When stopper has returned + val stopperReturned = new Semaphore(0) - test("local metrics") { - sc = new SparkContext("local", "SparkListenerSuite") - val listener = new SaveStageAndTaskInfo - sc.addSparkListener(listener) - sc.addSparkListener(new StatsReportListener) - // just to make sure some of the tasks take a noticeable amount of time - val w = { i: Int => - if (i == 0) { - Thread.sleep(100) + class BlockingListener extends SparkListener { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + listenerStarted.release() + listenerWait.acquire() + drained = true + } } - i + sc = new SparkContext("local", "SparkListenerSuite", getConf) + val bus = new LiveListenerBus(sc) + val blockingListener = new BlockingListener + + bus.addListener(blockingListener) + bus.start() + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + + listenerStarted.acquire() + // Listener should be blocked after start + assert(!drained) + + new Thread("ListenerBusStopper") { + override def run() { + stopperStarted.release() + // stop() will block until notify() is called below + bus.stop() + stopperReturned.release() + } + }.start() + + stopperStarted.acquire() + // Listener should remain blocked after stopper started + assert(!drained) + + // unblock Listener to let queue drain + listenerWait.release() + stopperReturned.acquire() + assert(drained) } - val numSlices = 16 - val d = sc.parallelize(0 to 10000, numSlices).map(w) - d.count() - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) - listener.stageInfos.size should be (1) + test(name("basic creation of StageInfo")) { + sc = new SparkContext("local", "SparkListenerSuite", getConf) + val listener = new SaveStageAndTaskInfo + sc.addSparkListener(listener) + val rdd1 = sc.parallelize(1 to 100, 4) + val rdd2 = rdd1.map(_.toString) + rdd2.setName("Target RDD") + rdd2.count() + + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + + listener.stageInfos.size should be {1} + val (stageInfo, taskInfoMetrics) = listener.stageInfos.head + stageInfo.rddInfos.size should be {2} + stageInfo.rddInfos.forall(_.numPartitions == 4) should be {true} + stageInfo.rddInfos.exists(_.name == "Target RDD") should be {true} + stageInfo.numTasks should be {4} + stageInfo.submissionTime should be ('defined) + stageInfo.completionTime should be ('defined) + taskInfoMetrics.length should be {4} + } - val d2 = d.map { i => w(i) -> i * 2 }.setName("shuffle input 1") - val d3 = d.map { i => w(i) -> (0 to (i % 5)) }.setName("shuffle input 2") - val d4 = d2.cogroup(d3, numSlices).map { case (k, (v1, v2)) => - w(k) -> (v1.size, v2.size) + test(name("basic creation of StageInfo with shuffle")) { + sc = new SparkContext("local", "SparkListenerSuite", getConf) + val listener = new SaveStageAndTaskInfo + sc.addSparkListener(listener) + val rdd1 = sc.parallelize(1 to 100, 4) + val rdd2 = rdd1.filter(_ % 2 == 0).map(i => (i, i)) + val rdd3 = rdd2.reduceByKey(_ + _) + rdd1.setName("Un") + rdd2.setName("Deux") + rdd3.setName("Trois") + + rdd1.count() + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + listener.stageInfos.size should be {1} + val stageInfo1 = listener.stageInfos.keys.find(_.stageId == 0).get + stageInfo1.rddInfos.size should be {1} // ParallelCollectionRDD + stageInfo1.rddInfos.forall(_.numPartitions == 4) should be {true} + stageInfo1.rddInfos.exists(_.name == "Un") should be {true} 
+ listener.stageInfos.clear() + + rdd2.count() + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + listener.stageInfos.size should be {1} + val stageInfo2 = listener.stageInfos.keys.find(_.stageId == 1).get + stageInfo2.rddInfos.size should be {3} + stageInfo2.rddInfos.forall(_.numPartitions == 4) should be {true} + stageInfo2.rddInfos.exists(_.name == "Deux") should be {true} + listener.stageInfos.clear() + + rdd3.count() + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + listener.stageInfos.size should be {2} // Shuffle map stage + result stage + val stageInfo3 = listener.stageInfos.keys.find(_.stageId == 3).get + stageInfo3.rddInfos.size should be {1} // ShuffledRDD + stageInfo3.rddInfos.forall(_.numPartitions == 4) should be {true} + stageInfo3.rddInfos.exists(_.name == "Trois") should be {true} + } + + test(name("StageInfo with fewer tasks than partitions")) { + sc = new SparkContext("local", "SparkListenerSuite", getConf) + val listener = new SaveStageAndTaskInfo + sc.addSparkListener(listener) + val rdd1 = sc.parallelize(1 to 100, 4) + val rdd2 = rdd1.map(_.toString) + sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1)) + + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + + listener.stageInfos.size should be {1} + val (stageInfo, _) = listener.stageInfos.head + stageInfo.numTasks should be {2} + stageInfo.rddInfos.size should be {2} + stageInfo.rddInfos.forall(_.numPartitions == 4) should be {true} } - d4.setName("A Cogroup") - d4.collectAsMap() - - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) - listener.stageInfos.size should be (4) - listener.stageInfos.foreach { case (stageInfo, taskInfoMetrics) => - /** - * Small test, so some tasks might take less than 1 millisecond, but average should be greater - * than 0 ms. 
- */ - checkNonZeroAvg( - taskInfoMetrics.map(_._2.executorRunTime), - stageInfo + " executorRunTime") - checkNonZeroAvg( - taskInfoMetrics.map(_._2.executorDeserializeTime), - stageInfo + " executorDeserializeTime") - - /* Test is disabled (SEE SPARK-2208) - if (stageInfo.rddInfos.exists(_.name == d4.name)) { - checkNonZeroAvg( - taskInfoMetrics.map(_._2.shuffleReadMetrics.get.fetchWaitTime), - stageInfo + " fetchWaitTime") - } - */ - taskInfoMetrics.foreach { case (taskInfo, taskMetrics) => - taskMetrics.resultSize should be > (0L) - if (stageInfo.rddInfos.exists(info => info.name == d2.name || info.name == d3.name)) { - assert(taskMetrics.shuffleWriteMetrics.bytesWritten > 0L) + test(name("local metrics")) { + sc = new SparkContext("local", "SparkListenerSuite", getConf) + val listener = new SaveStageAndTaskInfo + sc.addSparkListener(listener) + sc.addSparkListener(new StatsReportListener) + // just to make sure some of the tasks take a noticeable amount of time + val w = { i: Int => + if (i == 0) { + Thread.sleep(100) } + i + } + + val numSlices = 16 + val d = sc.parallelize(0 to 10000, numSlices).map(w) + d.count() + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + listener.stageInfos.size should be (1) + + val d2 = d.map { i => w(i) -> i * 2 }.setName("shuffle input 1") + val d3 = d.map { i => w(i) -> (0 to (i % 5)) }.setName("shuffle input 2") + val d4 = d2.cogroup(d3, numSlices).map { case (k, (v1, v2)) => + w(k) -> (v1.size, v2.size) + } + d4.setName("A Cogroup") + d4.collectAsMap() + + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + listener.stageInfos.size should be (4) + listener.stageInfos.foreach { case (stageInfo, taskInfoMetrics) => + /** + * Small test, so some tasks might take less than 1 millisecond, + * but average should be greater than 0 ms. 
+ */ + checkNonZeroAvg( + taskInfoMetrics.map(_._2.executorRunTime), + stageInfo + " executorRunTime") + checkNonZeroAvg( + taskInfoMetrics.map(_._2.executorDeserializeTime), + stageInfo + " executorDeserializeTime") + + /* Test is disabled (SEE SPARK-2208) if (stageInfo.rddInfos.exists(_.name == d4.name)) { - assert(taskMetrics.shuffleReadMetrics.totalBlocksFetched == 2 * numSlices) - assert(taskMetrics.shuffleReadMetrics.localBlocksFetched == 2 * numSlices) - assert(taskMetrics.shuffleReadMetrics.remoteBlocksFetched == 0) - assert(taskMetrics.shuffleReadMetrics.remoteBytesRead == 0L) + checkNonZeroAvg( + taskInfoMetrics.map(_._2.shuffleReadMetrics.get.fetchWaitTime), + stageInfo + " fetchWaitTime") + } + */ + + taskInfoMetrics.foreach { case (taskInfo, taskMetrics) => + taskMetrics.resultSize should be > (0L) + if (stageInfo.rddInfos.exists(info => info.name == d2.name || info.name == d3.name)) { + assert(taskMetrics.shuffleWriteMetrics.bytesWritten > 0L) + } + if (stageInfo.rddInfos.exists(_.name == d4.name)) { + assert(taskMetrics.shuffleReadMetrics.totalBlocksFetched == 2 * numSlices) + assert(taskMetrics.shuffleReadMetrics.localBlocksFetched == 2 * numSlices) + assert(taskMetrics.shuffleReadMetrics.remoteBlocksFetched == 0) + assert(taskMetrics.shuffleReadMetrics.remoteBytesRead == 0L) + } } } } - } - test("onTaskGettingResult() called when result fetched remotely") { - val conf = new SparkConf().set("spark.rpc.message.maxSize", "1") - sc = new SparkContext("local", "SparkListenerSuite", conf) - val listener = new SaveTaskEvents - sc.addSparkListener(listener) - - // Make a task whose result is larger than the RPC message size - val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) - assert(maxRpcMessageSize === 1024 * 1024) - val result = sc.parallelize(Seq(1), 1) - .map { x => 1.to(maxRpcMessageSize).toArray } - .reduce { case (x, y) => x } - assert(result === 1.to(maxRpcMessageSize).toArray) - - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) - val TASK_INDEX = 0 - assert(listener.startedTasks.contains(TASK_INDEX)) - assert(listener.startedGettingResultTasks.contains(TASK_INDEX)) - assert(listener.endedTasks.contains(TASK_INDEX)) - } + test(name("onTaskGettingResult() called when result fetched remotely")) { + val conf = getConf.set("spark.rpc.message.maxSize", "1") + sc = new SparkContext("local", "SparkListenerSuite", conf) + val listener = new SaveTaskEvents + sc.addSparkListener(listener) + + // Make a task whose result is larger than the RPC message size + val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) + assert(maxRpcMessageSize === 1024 * 1024) + val result = sc.parallelize(Seq(1), 1) + .map { x => 1.to(maxRpcMessageSize).toArray } + .reduce { case (x, y) => x } + assert(result === 1.to(maxRpcMessageSize).toArray) + + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + val TASK_INDEX = 0 + assert(listener.startedTasks.contains(TASK_INDEX)) + assert(listener.startedGettingResultTasks.contains(TASK_INDEX)) + assert(listener.endedTasks.contains(TASK_INDEX)) + } - test("onTaskGettingResult() not called when result sent directly") { - sc = new SparkContext("local", "SparkListenerSuite") - val listener = new SaveTaskEvents - sc.addSparkListener(listener) + test(name("onTaskGettingResult() not called when result sent directly")) { + sc = new SparkContext("local", "SparkListenerSuite", getConf) + val listener = new SaveTaskEvents + sc.addSparkListener(listener) - // Make a task whose result is larger than the RPC message size - val result = 
sc.parallelize(Seq(1), 1).map(2 * _).reduce { case (x, y) => x } - assert(result === 2) + // Make a task whose result is larger than the RPC message size + val result = sc.parallelize(Seq(1), 1).map(2 * _).reduce { case (x, y) => x } + assert(result === 2) - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) - val TASK_INDEX = 0 - assert(listener.startedTasks.contains(TASK_INDEX)) - assert(listener.startedGettingResultTasks.isEmpty) - assert(listener.endedTasks.contains(TASK_INDEX)) - } + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + val TASK_INDEX = 0 + assert(listener.startedTasks.contains(TASK_INDEX)) + assert(listener.startedGettingResultTasks.isEmpty) + assert(listener.endedTasks.contains(TASK_INDEX)) + } - test("onTaskEnd() should be called for all started tasks, even after job has been killed") { - sc = new SparkContext("local", "SparkListenerSuite") - val WAIT_TIMEOUT_MILLIS = 10000 - val listener = new SaveTaskEvents - sc.addSparkListener(listener) - - val numTasks = 10 - val f = sc.parallelize(1 to 10000, numTasks).map { i => Thread.sleep(10); i }.countAsync() - // Wait until one task has started (because we want to make sure that any tasks that are started - // have corresponding end events sent to the listener). - var finishTime = System.currentTimeMillis + WAIT_TIMEOUT_MILLIS - listener.synchronized { - var remainingWait = finishTime - System.currentTimeMillis - while (listener.startedTasks.isEmpty && remainingWait > 0) { - listener.wait(remainingWait) - remainingWait = finishTime - System.currentTimeMillis + test(name("onTaskEnd() should be called for all started tasks, even after job has been killed")) + { + sc = new SparkContext("local", "SparkListenerSuite", getConf) + val WAIT_TIMEOUT_MILLIS = 10000 + val listener = new SaveTaskEvents + sc.addSparkListener(listener) + + val numTasks = 10 + val f = sc.parallelize(1 to 10000, numTasks).map { i => Thread.sleep(10); i }.countAsync() + // Wait until one task has started (because we want to make sure that any tasks that are + // started have corresponding end events sent to the listener). + var finishTime = System.currentTimeMillis + WAIT_TIMEOUT_MILLIS + listener.synchronized { + var remainingWait = finishTime - System.currentTimeMillis + while (listener.startedTasks.isEmpty && remainingWait > 0) { + listener.wait(remainingWait) + remainingWait = finishTime - System.currentTimeMillis + } + assert(listener.startedTasks.nonEmpty) } - assert(!listener.startedTasks.isEmpty) - } - f.cancel() + f.cancel() - // Ensure that onTaskEnd is called for all started tasks. - finishTime = System.currentTimeMillis + WAIT_TIMEOUT_MILLIS - listener.synchronized { - var remainingWait = finishTime - System.currentTimeMillis - while (listener.endedTasks.size < listener.startedTasks.size && remainingWait > 0) { - listener.wait(finishTime - System.currentTimeMillis) - remainingWait = finishTime - System.currentTimeMillis + // Ensure that onTaskEnd is called for all started tasks. 
+ finishTime = System.currentTimeMillis + WAIT_TIMEOUT_MILLIS + listener.synchronized { + var remainingWait = finishTime - System.currentTimeMillis + while (listener.endedTasks.size < listener.startedTasks.size && remainingWait > 0) { + listener.wait(finishTime - System.currentTimeMillis) + remainingWait = finishTime - System.currentTimeMillis + } + assert(listener.endedTasks.size === listener.startedTasks.size) } - assert(listener.endedTasks.size === listener.startedTasks.size) } - } - test("SparkListener moves on if a listener throws an exception") { - val badListener = new BadListener - val jobCounter1 = new BasicJobCounter - val jobCounter2 = new BasicJobCounter - sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) - val bus = new LiveListenerBus(sc) - - // Propagate events to bad listener first - bus.addListener(badListener) - bus.addListener(jobCounter1) - bus.addListener(jobCounter2) - bus.start() - - // Post events to all listeners, and wait until the queue is drained - (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } - bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) - - // The exception should be caught, and the event should be propagated to other listeners - assert(bus.listenerThreadIsAlive) - assert(jobCounter1.count === 5) - assert(jobCounter2.count === 5) - } + test(name("SparkListener moves on if a listener throws an exception")) { + val badListener = new BadListener + val jobCounter1 = new BasicJobCounter + val jobCounter2 = new BasicJobCounter + sc = new SparkContext("local", "SparkListenerSuite", getConf) + val bus = new LiveListenerBus(sc) + + // Propagate events to bad listener first + bus.addListener(badListener) + bus.addListener(jobCounter1) + bus.addListener(jobCounter2) + bus.start() - test("registering listeners via spark.extraListeners") { - val listeners = Seq( - classOf[ListenerThatAcceptsSparkConf], - classOf[FirehoseListenerThatAcceptsSparkConf], - classOf[BasicJobCounter]) - val conf = new SparkConf().setMaster("local").setAppName("test") - .set("spark.extraListeners", listeners.map(_.getName).mkString(",")) - sc = new SparkContext(conf) - sc.listenerBus.listeners.asScala.count(_.isInstanceOf[BasicJobCounter]) should be (1) - sc.listenerBus.listeners.asScala - .count(_.isInstanceOf[ListenerThatAcceptsSparkConf]) should be (1) - sc.listenerBus.listeners.asScala + // Post events to all listeners, and wait until the queue is drained + (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + + // The exception should be caught, and the event should be propagated to other listeners + assert(bus.listenerThreadIsAlive) + assert(jobCounter1.count === 5) + assert(jobCounter2.count === 5) + } + + test(name("registering listeners via spark.extraListeners")) { + val listeners = Seq( + classOf[ListenerThatAcceptsSparkConf], + classOf[FirehoseListenerThatAcceptsSparkConf], + classOf[BasicJobCounter]) + val conf = getConf.setMaster("local").setAppName("test") + .set("spark.extraListeners", listeners.map(_.getName).mkString(",")) + sc = new SparkContext(conf) + sc.listenerBus.listeners.count(_.isInstanceOf[BasicJobCounter]) should be (1) + sc.listenerBus.listeners.count(_.isInstanceOf[ListenerThatAcceptsSparkConf]) should be (1) + sc.listenerBus.listeners .count(_.isInstanceOf[FirehoseListenerThatAcceptsSparkConf]) should be (1) - } + } + + }) + /** * Assert that the given list of numbers has an average that is greater than zero.
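
For reference, the behaviour introduced by this patch is driven entirely by configuration: "spark.eventLog.async" (default false) makes SparkContext build an AsynchronousEventLoggingListener instead of the plain EventLoggingListener, and "spark.scheduler.listenerbus.eventqueue.drop" (default true) selects DropQueuingStrategy versus WaitQueuingStrategy inside LiveListenerBus. Both ring buffers are sized by LISTENER_BUS_EVENT_QUEUE_CAPACITY (default 10000; its key string is not visible in the hunk above, so it is referenced here only by the constant). The following is a minimal sketch of how an application might opt in; the object name is hypothetical, and "spark.eventLog.dir" is the standard event-log directory setting shown only for completeness.

import org.apache.spark.{SparkConf, SparkContext}

object AsyncEventLogExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("async-event-log-demo")
      .set("spark.eventLog.enabled", "true")            // event logging must be enabled at all
      .set("spark.eventLog.dir", "/tmp/spark-events")   // matches DEFAULT_LOG_DIR in the listener
      .set("spark.eventLog.async", "true")              // new flag: log events on a separate daemon thread
      .set("spark.scheduler.listenerbus.eventqueue.drop", "false") // new flag: block producers instead of dropping

    val sc = new SparkContext(conf)
    try {
      sc.parallelize(1 to 10000, 8).map(_ * 2).count()  // generate some listener traffic
    } finally {
      sc.stop()  // stop() drains both ring buffers before the event log is finalized
    }
  }
}

With the drop flag left at its default of true, the bus keeps the pre-patch behaviour of discarding events when the queue is full; setting it to false trades completeness of the listener data for back-pressure on the posting threads.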
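
Both LiveListenerBus and AsynchronousEventLoggingListener rely on the same hand-rolled bounded multi-producer/single-consumer ring buffer: a producer writes the slot first and only then bumps an AtomicInteger, the single consumer checks that counter before reading a slot, and the volatile write/read indices are each owned by exactly one side, so the atomic counter is what publishes the stored event to the consumer. The sketch below restates that invariant outside the patch; the class name is hypothetical and error handling is simplified, so treat it as an illustration of the pattern rather than the patch code itself.

import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.locks.ReentrantLock

// Standalone sketch of the bounded multi-producer / single-consumer ring buffer pattern.
class BoundedEventBus[T <: AnyRef](bufferSize: Int)(consume: T => Unit) {

  private val buffer = new Array[AnyRef](bufferSize)
  private val count = new AtomicInteger(0)   // the only producer/consumer handshake
  @volatile private var writeIndex = 0       // written only by producers, under postLock
  @volatile private var readIndex = 0        // written only by the consumer thread
  @volatile private var stopped = false
  private val postLock = new ReentrantLock() // serializes the producers

  private val consumer = new Thread("bounded-event-bus") {
    setDaemon(true)
    override def run(): Unit = {
      // Keep draining after stop() until every queued event has been handled.
      while (!stopped || count.get() > 0) {
        if (count.get() > 0) {
          // Seeing count > 0 happens-after the producer's incrementAndGet, which
          // happens-after the slot was stored, so the element is fully visible here.
          consume(buffer(readIndex).asInstanceOf[T])
          count.decrementAndGet()
          readIndex = (readIndex + 1) % bufferSize
        } else {
          Thread.sleep(20) // same polling back-off the patch uses when the buffer is empty
        }
      }
    }
  }

  def start(): Unit = consumer.start()

  /** Returns false if the buffer was full and the event was dropped (the "drop" strategy). */
  def post(event: T): Boolean = {
    postLock.lock()
    try {
      if (count.get() == bufferSize) {
        false // a "wait" strategy would instead sleep here until the consumer frees a slot
      } else {
        buffer(writeIndex) = event  // store the element first ...
        count.incrementAndGet()     // ... then publish it through the atomic counter
        writeIndex = (writeIndex + 1) % bufferSize
        true
      }
    } finally {
      postLock.unlock()
    }
  }

  def stop(): Unit = {
    stopped = true
    consumer.join()
  }
}

The 20 ms sleep mirrors the polling back-off used in both places in the patch when the buffer is empty or, for the wait strategy, full.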
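
ListenerBus itself now keeps its listeners in a plain array behind an AtomicReference: addListener and removeListener copy the array and publish the new one with compareAndSet, while postToAll reads the reference once and walks the snapshot with an index loop, avoiding the per-event iterator that the old CopyOnWriteArrayList path allocated. A minimal sketch of that copy-on-write scheme follows; the class and method names are hypothetical.

import java.util.concurrent.atomic.AtomicReference

import scala.reflect.ClassTag

// Minimal sketch of a lock-free copy-on-write listener registry.
class CopyOnWriteRegistry[L <: AnyRef : ClassTag] {

  private val holder = new AtomicReference[Array[L]](Array.empty[L])

  // A snapshot is safe to iterate without locking: published arrays are never mutated.
  def snapshot: Array[L] = holder.get()

  def add(listener: L): Unit = {
    var done = false
    while (!done) {
      val old = holder.get()
      val updated = old :+ listener                 // copy-on-write: build a fresh array
      done = holder.compareAndSet(old, updated)     // retry if another thread won the race
    }
  }

  def remove(listener: L): Unit = {
    var done = false
    while (!done) {
      val old = holder.get()
      val updated = old.filterNot(_ == listener)
      done = holder.compareAndSet(old, updated)
    }
  }

  // Hot path: index the snapshot directly, with no per-event iterator allocation.
  def foreachListener(f: L => Unit): Unit = {
    val current = snapshot
    var i = 0
    while (i < current.length) {
      f(current(i))
      i += 1
    }
  }
}

Readers never observe a half-updated array because each array is published whole and never mutated afterwards; the trade-off, as with java.util.concurrent.CopyOnWriteArrayList, is an O(n) copy on every registration, which is acceptable for a handful of listeners.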