Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 95 additions & 30 deletions core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,90 @@ private[spark] class UninterruptibleThread(
this(null, name)
}

/** A monitor to protect "uninterruptible" and "interrupted" */
private val uninterruptibleLock = new Object
private class UninterruptibleLock {
/**
* Indicates if `this` thread are in the uninterruptible status. If so, interrupting
* "this" will be deferred until `this` enters into the interruptible status.
*/
@GuardedBy("uninterruptibleLock")
private var uninterruptible = false

/**
* Indicates if `this` thread are in the uninterruptible status. If so, interrupting
* "this" will be deferred until `this` enters into the interruptible status.
*/
@GuardedBy("uninterruptibleLock")
private var uninterruptible = false
/**
* Indicates if we should interrupt `this` when we are leaving the uninterruptible zone.
*/
@GuardedBy("uninterruptibleLock")
private var shouldInterruptThread = false

/**
* Indicates if we should interrupt `this` when we are leaving the uninterruptible zone.
*/
@GuardedBy("uninterruptibleLock")
private var shouldInterruptThread = false
/**
* Indicates that we should wait for interrupt() call before proceeding.
*/
@GuardedBy("uninterruptibleLock")
private var awaitInterruptThread = false

/**
* Set [[uninterruptible]] to given value and returns the previous value.
*/
def getAndSetUninterruptible(value: Boolean): Boolean = synchronized {
val uninterruptible = this.uninterruptible
this.uninterruptible = value
uninterruptible
}

def setShouldInterruptThread(value: Boolean): Unit = synchronized {
shouldInterruptThread = value
}

def setAwaitInterruptThread(value: Boolean): Unit = synchronized {
awaitInterruptThread = value
}

/**
* Is call to [[java.lang.Thread.interrupt()]] pending
*/
def isInterruptPending: Boolean = synchronized {
// Clear the interrupted status if it's set.
shouldInterruptThread = Thread.interrupted() || shouldInterruptThread
// wait for super.interrupt() to be called
!shouldInterruptThread && awaitInterruptThread
}

/**
* Set [[uninterruptible]] back to false and call [[java.lang.Thread.interrupt()]] to
* recover interrupt state if necessary
*/
def recoverInterrupt(): Unit = synchronized {
uninterruptible = false
if (shouldInterruptThread) {
shouldInterruptThread = false
// Recover the interrupted status
UninterruptibleThread.super.interrupt()
}
}

/**
* Is it safe to call [[java.lang.Thread.interrupt()]] and interrupt the current thread
* @return true when there is no concurrent [[runUninterruptibly()]] call ([[uninterruptible]]
* is true) and no concurrent [[interrupt()]] call, otherwise false
*/
def isInterruptible: Boolean = synchronized {
shouldInterruptThread = uninterruptible
// as we are releasing uninterruptibleLock before calling super.interrupt() there is a
// possibility that runUninterruptibly() would be called after lock is released but before
// super.interrupt() is called. In this case to prevent runUninterruptibly() from being
// interrupted, we use awaitInterruptThread flag. We need to set it only if
// runUninterruptibly() is not yet set uninterruptible to true (!shouldInterruptThread) and
// there is no other threads that called interrupt (awaitInterruptThread is already true)
if (!shouldInterruptThread && !awaitInterruptThread) {
awaitInterruptThread = true
true
} else {
false
}
}
}

/** A monitor to protect "uninterruptible" and "interrupted" */
private val uninterruptibleLock = new UninterruptibleLock

/**
* Run `f` uninterruptibly in `this` thread. The thread won't be interrupted before returning
Expand All @@ -63,27 +132,23 @@ private[spark] class UninterruptibleThread(
s"Expected: $this but was ${Thread.currentThread()}")
}

if (uninterruptibleLock.synchronized { uninterruptible }) {
if (uninterruptibleLock.getAndSetUninterruptible(true)) {
// We are already in the uninterruptible status. So just run "f" and return
return f
}

uninterruptibleLock.synchronized {
// Clear the interrupted status if it's set.
shouldInterruptThread = Thread.interrupted() || shouldInterruptThread
uninterruptible = true
while (uninterruptibleLock.isInterruptPending) {
try {
Thread.sleep(100)
} catch {
case _: InterruptedException => uninterruptibleLock.setShouldInterruptThread(true)
}
}

try {
f
} finally {
uninterruptibleLock.synchronized {
uninterruptible = false
if (shouldInterruptThread) {
// Recover the interrupted status
super.interrupt()
shouldInterruptThread = false
}
}
uninterruptibleLock.recoverInterrupt()
}
}

Expand All @@ -92,11 +157,11 @@ private[spark] class UninterruptibleThread(
* interrupted until it enters into the interruptible status.
*/
override def interrupt(): Unit = {
uninterruptibleLock.synchronized {
if (uninterruptible) {
shouldInterruptThread = true
} else {
if (uninterruptibleLock.isInterruptible) {
try {
super.interrupt()
} finally {
uninterruptibleLock.setAwaitInterruptThread(false)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.util

import java.nio.channels.spi.AbstractInterruptibleChannel
import java.util.concurrent.{CountDownLatch, TimeUnit}

import scala.util.Random
Expand Down Expand Up @@ -115,6 +116,45 @@ class UninterruptibleThreadSuite extends SparkFunSuite {
assert(interruptStatusBeforeExit)
}

test("no runUninterruptibly") {
@volatile var hasInterruptedException = false
val t = new UninterruptibleThread("test") {
override def run(): Unit = {
if (sleep(0)) {
hasInterruptedException = true
}
}
}
t.interrupt()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks that Java 8 behaves differently when interrupt() is called on not started thread.

t.start()
t.join()
assert(hasInterruptedException === true)
}

test("SPARK-51821 uninterruptibleLock deadlock") {
val latch = new CountDownLatch(1)
val task = new UninterruptibleThread("task thread") {
override def run(): Unit = {
val channel = new AbstractInterruptibleChannel() {
override def implCloseChannel(): Unit = {
begin()
latch.countDown()
try {
Thread.sleep(Long.MaxValue)
} catch {
case _: InterruptedException => Thread.currentThread().interrupt()
}
}
}
channel.close()
Comment on lines +138 to +149
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't seem to be a right use case of UninterruptibleThread. Shouldn't the whole block be executed within runUninterruptibly()? e.g.,

override def run(): Unit = {
  this.runUninterruptibly {
    ...
  }
}

The task thread can be correctly interrupted if the whole block run inside runUninterruptibly() with the limited Thread.sleep().

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see where the problem is. Spark task always uses the UninterruptibleThread but this.runUninterruptibly() is only called for the Spark task that run with KafkaConsumer (#17761 is the original PR that introduced UninterruptibleThread).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UninterruptibleThread was introduced as part of SPARK-14169 and by design can be used to run and runUninterruptibly.

Copy link
Contributor

@mridulm mridulm Apr 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Streaming also appears to use runUninterruptibly.
Are we proposing to adapt it to not need this construct @Ngone51 ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think we should avoid using UninterruptibleThread when it is not really needed. The problem is that executor can't distinguish the tasks for different workloads. So we have to compromise with UninterruptibleThread as the default task thread and call runUninterruptibly when it is necessary.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mridulm @Ngone51 Streaming requires UninterruptibleThread, please see SPARK-21248. I also don't think that it is necessary to revisit usage of UninterruptibleThread. The run() method is not affected at all. The only affected method (overridden) is interrupt() and with the fix it also won't be impacted. The only difference with Thread.interrupt() is acquiring uninterruptibleLock that is a low cost operation when there is no contention (multiple threads calling interrupt() concurrently) and as Thread.interrupt() acquires blockerLock as well, there is pretty much no difference at all.

}
}
task.start()
assert(latch.await(10, TimeUnit.SECONDS), "await timeout")
task.interrupt()
task.join()
}

test("stress test") {
@volatile var hasInterruptedException = false
val t = new UninterruptibleThread("test") {
Expand Down Expand Up @@ -148,9 +188,20 @@ class UninterruptibleThreadSuite extends SparkFunSuite {
}
}
t.start()
for (i <- 0 until 400) {
Thread.sleep(Random.nextInt(10))
t.interrupt()
val threads = new Array[Thread](10)
for (j <- 0 until 10) {
threads(j) = new Thread() {
override def run(): Unit = {
for (i <- 0 until 400) {
Thread.sleep(Random.nextInt(10))
t.interrupt()
}
}
}
threads(j).start()
}
for (j <- 0 until 10) {
threads(j).join()
}
t.join()
assert(hasInterruptedException === false)
Expand Down