diff --git a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala index c1fedd63f6a9..e2b6df460059 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala @@ -34,7 +34,11 @@ import org.apache.spark.util.Utils * Delivery will only begin when the `start()` method is called. The `stop()` method should be * called when no more events need to be delivered. */ -private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveListenerBusMetrics) +private class AsyncEventQueue( + val name: String, + conf: SparkConf, + metrics: LiveListenerBusMetrics, + bus: LiveListenerBus) extends SparkListenerBus with Logging { @@ -81,23 +85,18 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi } private def dispatch(): Unit = LiveListenerBus.withinListenerThread.withValue(true) { - try { - var next: SparkListenerEvent = eventQueue.take() - while (next != POISON_PILL) { - val ctx = processingTime.time() - try { - super.postToAll(next) - } finally { - ctx.stop() - } - eventCount.decrementAndGet() - next = eventQueue.take() + var next: SparkListenerEvent = eventQueue.take() + while (next != POISON_PILL) { + val ctx = processingTime.time() + try { + super.postToAll(next) + } finally { + ctx.stop() } eventCount.decrementAndGet() - } catch { - case ie: InterruptedException => - logInfo(s"Stopping listener queue $name.", ie) + next = eventQueue.take() } + eventCount.decrementAndGet() } override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = { @@ -130,7 +129,11 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi eventCount.incrementAndGet() eventQueue.put(POISON_PILL) } - dispatchThread.join() + // this thread might be trying to stop itself as part of error handling -- we can't join + // in that case. + if (Thread.currentThread() != dispatchThread) { + dispatchThread.join() + } } def post(event: SparkListenerEvent): Unit = { @@ -187,6 +190,12 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi true } + override def removeListenerOnError(listener: SparkListenerInterface): Unit = { + // the listener failed in an unrecoverably way, we want to remove it from the entire + // LiveListenerBus (potentially stopping a queue if it is empty) + bus.removeListener(listener) + } + } private object AsyncEventQueue { diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index ba6387a8f08a..d135190d1e91 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -102,7 +102,7 @@ private[spark] class LiveListenerBus(conf: SparkConf) { queue.addListener(listener) case None => - val newQueue = new AsyncEventQueue(queue, conf, metrics) + val newQueue = new AsyncEventQueue(queue, conf, metrics, this) newQueue.addListener(listener) if (started.get()) { newQueue.start(sparkContext) diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index b25a731401f2..d4474a90b26f 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -60,6 +60,15 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { } } + /** + * This can be overriden by subclasses if there is any extra cleanup to do when removing a + * listener. In particular AsyncEventQueues can clean up queues in the LiveListenerBus. + */ + def removeListenerOnError(listener: L): Unit = { + removeListener(listener) + } + + /** * Post the event to all registered listeners. The `postToAll` caller should guarantee calling * `postToAll` in the same thread for all events. @@ -80,7 +89,16 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { } try { doPostEvent(listener, event) + if (Thread.interrupted()) { + // We want to throw the InterruptedException right away so we can associate the interrupt + // with this listener, as opposed to waiting for a queue.take() etc. to detect it. + throw new InterruptedException() + } } catch { + case ie: InterruptedException => + logError(s"Interrupted while posting to ${Utils.getFormattedClassName(listener)}. " + + s"Removing that listener.", ie) + removeListenerOnError(listener) case NonFatal(e) if !isIgnorableException(e) => logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) } finally { diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index fa47a52bbbc4..6ffd1e84f7ad 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -489,6 +489,48 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match assert(bus.findListenersByClass[BasicJobCounter]().isEmpty) } + Seq(true, false).foreach { throwInterruptedException => + val suffix = if (throwInterruptedException) "throw interrupt" else "set Thread interrupted" + test(s"interrupt within listener is handled correctly: $suffix") { + val conf = new SparkConf(false) + .set(LISTENER_BUS_EVENT_QUEUE_CAPACITY, 5) + val bus = new LiveListenerBus(conf) + val counter1 = new BasicJobCounter() + val counter2 = new BasicJobCounter() + val interruptingListener1 = new InterruptingListener(throwInterruptedException) + val interruptingListener2 = new InterruptingListener(throwInterruptedException) + bus.addToSharedQueue(counter1) + bus.addToSharedQueue(interruptingListener1) + bus.addToStatusQueue(counter2) + bus.addToEventLogQueue(interruptingListener2) + assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE, EVENT_LOG_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 2) + assert(bus.findListenersByClass[InterruptingListener]().size === 2) + + bus.start(mockSparkContext, mockMetricsSystem) + + // after we post one event, both interrupting listeners should get removed, and the + // event log queue should be removed + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 2) + assert(bus.findListenersByClass[InterruptingListener]().size === 0) + assert(counter1.count === 1) + assert(counter2.count === 1) + + // posting more events should be fine, they'll just get processed from the OK queue. + (0 until 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(counter1.count === 6) + assert(counter2.count === 6) + + // Make sure stopping works -- this requires putting a poison pill in all active queues, which + // would fail if our interrupted queue was still active, as its queue would be full. + bus.stop() + } + } + /** * Assert that the given list of numbers has an average that is greater than zero. */ @@ -547,6 +589,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { throw new Exception } } + /** + * A simple listener that interrupts on job end. + */ + private class InterruptingListener(val throwInterruptedException: Boolean) extends SparkListener { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + if (throwInterruptedException) { + throw new InterruptedException("got interrupted") + } else { + Thread.currentThread().interrupt() + } + } + } } // These classes can't be declared inside of the SparkListenerSuite class because we don't want