diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala index 24788d69121b..8fba5ed944c6 100644 --- a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala +++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala @@ -35,21 +35,90 @@ private[spark] class UninterruptibleThread( this(null, name) } - /** A monitor to protect "uninterruptible" and "interrupted" */ - private val uninterruptibleLock = new Object + private class UninterruptibleLock { + /** + * Indicates if `this` thread are in the uninterruptible status. If so, interrupting + * "this" will be deferred until `this` enters into the interruptible status. + */ + @GuardedBy("uninterruptibleLock") + private var uninterruptible = false - /** - * Indicates if `this` thread are in the uninterruptible status. If so, interrupting - * "this" will be deferred until `this` enters into the interruptible status. - */ - @GuardedBy("uninterruptibleLock") - private var uninterruptible = false + /** + * Indicates if we should interrupt `this` when we are leaving the uninterruptible zone. + */ + @GuardedBy("uninterruptibleLock") + private var shouldInterruptThread = false - /** - * Indicates if we should interrupt `this` when we are leaving the uninterruptible zone. - */ - @GuardedBy("uninterruptibleLock") - private var shouldInterruptThread = false + /** + * Indicates that we should wait for interrupt() call before proceeding. + */ + @GuardedBy("uninterruptibleLock") + private var awaitInterruptThread = false + + /** + * Set [[uninterruptible]] to given value and returns the previous value. + */ + def getAndSetUninterruptible(value: Boolean): Boolean = synchronized { + val uninterruptible = this.uninterruptible + this.uninterruptible = value + uninterruptible + } + + def setShouldInterruptThread(value: Boolean): Unit = synchronized { + shouldInterruptThread = value + } + + def setAwaitInterruptThread(value: Boolean): Unit = synchronized { + awaitInterruptThread = value + } + + /** + * Is call to [[java.lang.Thread.interrupt()]] pending + */ + def isInterruptPending: Boolean = synchronized { + // Clear the interrupted status if it's set. + shouldInterruptThread = Thread.interrupted() || shouldInterruptThread + // wait for super.interrupt() to be called + !shouldInterruptThread && awaitInterruptThread + } + + /** + * Set [[uninterruptible]] back to false and call [[java.lang.Thread.interrupt()]] to + * recover interrupt state if necessary + */ + def recoverInterrupt(): Unit = synchronized { + uninterruptible = false + if (shouldInterruptThread) { + shouldInterruptThread = false + // Recover the interrupted status + UninterruptibleThread.super.interrupt() + } + } + + /** + * Is it safe to call [[java.lang.Thread.interrupt()]] and interrupt the current thread + * @return true when there is no concurrent [[runUninterruptibly()]] call ([[uninterruptible]] + * is true) and no concurrent [[interrupt()]] call, otherwise false + */ + def isInterruptible: Boolean = synchronized { + shouldInterruptThread = uninterruptible + // as we are releasing uninterruptibleLock before calling super.interrupt() there is a + // possibility that runUninterruptibly() would be called after lock is released but before + // super.interrupt() is called. In this case to prevent runUninterruptibly() from being + // interrupted, we use awaitInterruptThread flag. We need to set it only if + // runUninterruptibly() is not yet set uninterruptible to true (!shouldInterruptThread) and + // there is no other threads that called interrupt (awaitInterruptThread is already true) + if (!shouldInterruptThread && !awaitInterruptThread) { + awaitInterruptThread = true + true + } else { + false + } + } + } + + /** A monitor to protect "uninterruptible" and "interrupted" */ + private val uninterruptibleLock = new UninterruptibleLock /** * Run `f` uninterruptibly in `this` thread. The thread won't be interrupted before returning @@ -63,27 +132,23 @@ private[spark] class UninterruptibleThread( s"Expected: $this but was ${Thread.currentThread()}") } - if (uninterruptibleLock.synchronized { uninterruptible }) { + if (uninterruptibleLock.getAndSetUninterruptible(true)) { // We are already in the uninterruptible status. So just run "f" and return return f } - uninterruptibleLock.synchronized { - // Clear the interrupted status if it's set. - shouldInterruptThread = Thread.interrupted() || shouldInterruptThread - uninterruptible = true + while (uninterruptibleLock.isInterruptPending) { + try { + Thread.sleep(100) + } catch { + case _: InterruptedException => uninterruptibleLock.setShouldInterruptThread(true) + } } + try { f } finally { - uninterruptibleLock.synchronized { - uninterruptible = false - if (shouldInterruptThread) { - // Recover the interrupted status - super.interrupt() - shouldInterruptThread = false - } - } + uninterruptibleLock.recoverInterrupt() } } @@ -92,11 +157,11 @@ private[spark] class UninterruptibleThread( * interrupted until it enters into the interruptible status. */ override def interrupt(): Unit = { - uninterruptibleLock.synchronized { - if (uninterruptible) { - shouldInterruptThread = true - } else { + if (uninterruptibleLock.isInterruptible) { + try { super.interrupt() + } finally { + uninterruptibleLock.setAwaitInterruptThread(false) } } } diff --git a/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala b/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala index 9c0ee1e1303e..fbc954d05af8 100644 --- a/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.util +import java.nio.channels.spi.AbstractInterruptibleChannel import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.util.Random @@ -115,6 +116,45 @@ class UninterruptibleThreadSuite extends SparkFunSuite { assert(interruptStatusBeforeExit) } + test("no runUninterruptibly") { + @volatile var hasInterruptedException = false + val t = new UninterruptibleThread("test") { + override def run(): Unit = { + if (sleep(0)) { + hasInterruptedException = true + } + } + } + t.interrupt() + t.start() + t.join() + assert(hasInterruptedException === true) + } + + test("SPARK-51821 uninterruptibleLock deadlock") { + val latch = new CountDownLatch(1) + val task = new UninterruptibleThread("task thread") { + override def run(): Unit = { + val channel = new AbstractInterruptibleChannel() { + override def implCloseChannel(): Unit = { + begin() + latch.countDown() + try { + Thread.sleep(Long.MaxValue) + } catch { + case _: InterruptedException => Thread.currentThread().interrupt() + } + } + } + channel.close() + } + } + task.start() + assert(latch.await(10, TimeUnit.SECONDS), "await timeout") + task.interrupt() + task.join() + } + test("stress test") { @volatile var hasInterruptedException = false val t = new UninterruptibleThread("test") { @@ -148,9 +188,20 @@ class UninterruptibleThreadSuite extends SparkFunSuite { } } t.start() - for (i <- 0 until 400) { - Thread.sleep(Random.nextInt(10)) - t.interrupt() + val threads = new Array[Thread](10) + for (j <- 0 until 10) { + threads(j) = new Thread() { + override def run(): Unit = { + for (i <- 0 until 400) { + Thread.sleep(Random.nextInt(10)) + t.interrupt() + } + } + } + threads(j).start() + } + for (j <- 0 until 10) { + threads(j).join() } t.join() assert(hasInterruptedException === false)