
Commit bb0b2d2

vrozov authored and cloud-fan committed
[SPARK-51821][CORE] Call interrupt() without holding uninterruptibleLock to avoid possible deadlock
### What changes were proposed in this pull request?

Do not hold the `uninterruptibleLock` monitor while calling `super.interrupt()` in `UninterruptibleThread`. Instead, use the newly introduced `awaitInterruptThread` flag and wait for `super.interrupt()` to be called.

### Why are the changes needed?

There is a potential deadlock: an `UninterruptibleThread` may be blocked on an NIO operation, and interrupting the channel while holding the `uninterruptibleLock` monitor can deadlock, as in:

```
Found one Java-level deadlock:
=============================
"pool-1-thread-1-ScalaTest-running-UninterruptibleThreadSuite":
  waiting to lock monitor 0x00006000036ee3c0 (object 0x000000070f3019d0, a java.lang.Object),
  which is held by "task thread"

"task thread":
  waiting to lock monitor 0x00006000036e75a0 (object 0x000000070f70fe80, a java.lang.Object),
  which is held by "pool-1-thread-1-ScalaTest-running-UninterruptibleThreadSuite"

Java stack information for the threads listed above:
===================================================
"pool-1-thread-1-ScalaTest-running-UninterruptibleThreadSuite":
    at java.nio.channels.spi.AbstractInterruptibleChannel$1.interrupt(java.base@17.0.14/AbstractInterruptibleChannel.java:157)
    - waiting to lock <0x000000070f3019d0> (a java.lang.Object)
    at java.lang.Thread.interrupt(java.base@17.0.14/Thread.java:1004)
    - locked <0x000000070f70fc90> (a java.lang.Object)
    at org.apache.spark.util.UninterruptibleThread.interrupt(UninterruptibleThread.scala:99)
    - locked <0x000000070f70fe80> (a java.lang.Object)
    at org.apache.spark.util.UninterruptibleThreadSuite.$anonfun$new$5(UninterruptibleThreadSuite.scala:159)
    - locked <0x000000070f70f9f8> (a java.lang.Object)
    at org.apache.spark.util.UninterruptibleThreadSuite$$Lambda$216/0x000000700120d6c8.apply$mcV$sp(Unknown Source)
    at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.scala:18)
    at org.scalatest.enablers.Timed$$anon$1.timeoutAfter(Timed.scala:127)
    at org.scalatest.concurrent.TimeLimits$.failAfterImpl(TimeLimits.scala:282)
    at org.scalatest.concurrent.TimeLimits.failAfter(TimeLimits.scala:231)
    at org.scalatest.concurrent.TimeLimits.failAfter$(TimeLimits.scala:230)
    at org.apache.spark.SparkFunSuite.failAfter(SparkFunSuite.scala:69)
    at org.apache.spark.SparkFunSuite.$anonfun$test$2(SparkFunSuite.scala:155)
    at org.apache.spark.SparkFunSuite$$Lambda$205/0x0000007001207700.apply(Unknown Source)
    at org.scalatest.OutcomeOf.outcomeOf(OutcomeOf.scala:85)
    at org.scalatest.OutcomeOf.outcomeOf$(OutcomeOf.scala:83)
    at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104)
    at org.scalatest.Transformer.apply(Transformer.scala:22)
    at org.scalatest.Transformer.apply(Transformer.scala:20)
    at org.scalatest.funsuite.AnyFunSuiteLike$$anon$1.apply(AnyFunSuiteLike.scala:226)
    at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:227)
    at org.scalatest.funsuite.AnyFunSuiteLike.invokeWithFixture$1(AnyFunSuiteLike.scala:224)
    at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTest$1(AnyFunSuiteLike.scala:236)
    at org.scalatest.funsuite.AnyFunSuiteLike$$Lambda$343/0x00000070012867b0.apply(Unknown Source)
    at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306)
    at org.scalatest.funsuite.AnyFunSuiteLike.runTest(AnyFunSuiteLike.scala:236)
    at org.scalatest.funsuite.AnyFunSuiteLike.runTest$(AnyFunSuiteLike.scala:218)
    at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterEach$$super$runTest(SparkFunSuite.scala:69)
    at org.scalatest.BeforeAndAfterEach.runTest(BeforeAndAfterEach.scala:234)
    at org.scalatest.BeforeAndAfterEach.runTest$(BeforeAndAfterEach.scala:227)
    at org.apache.spark.SparkFunSuite.runTest(SparkFunSuite.scala:69)
    at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTests$1(AnyFunSuiteLike.scala:269)
    at org.scalatest.funsuite.AnyFunSuiteLike$$Lambda$339/0x00000070012833e0.apply(Unknown Source)
    at org.scalatest.SuperEngine.$anonfun$runTestsInBranch$1(Engine.scala:413)
    at org.scalatest.SuperEngine$$Lambda$340/0x0000007001283998.apply(Unknown Source)
    at scala.collection.immutable.List.foreach(List.scala:334)
    at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:401)
    at org.scalatest.SuperEngine.runTestsInBranch(Engine.scala:396)
    at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:475)
    at org.scalatest.funsuite.AnyFunSuiteLike.runTests(AnyFunSuiteLike.scala:269)
    at org.scalatest.funsuite.AnyFunSuiteLike.runTests$(AnyFunSuiteLike.scala:268)
    at org.scalatest.funsuite.AnyFunSuite.runTests(AnyFunSuite.scala:1564)
    at org.scalatest.Suite.run(Suite.scala:1114)
    at org.scalatest.Suite.run$(Suite.scala:1096)
    at org.scalatest.funsuite.AnyFunSuite.org$scalatest$funsuite$AnyFunSuiteLike$$super$run(AnyFunSuite.scala:1564)
    at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$run$1(AnyFunSuiteLike.scala:273)
    at org.scalatest.funsuite.AnyFunSuiteLike$$Lambda$332/0x000000700127b000.apply(Unknown Source)
    at org.scalatest.SuperEngine.runImpl(Engine.scala:535)
    at org.scalatest.funsuite.AnyFunSuiteLike.run(AnyFunSuiteLike.scala:273)
    at org.scalatest.funsuite.AnyFunSuiteLike.run$(AnyFunSuiteLike.scala:272)
    at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterAll$$super$run(SparkFunSuite.scala:69)
    at org.scalatest.BeforeAndAfterAll.liftedTree1$1(BeforeAndAfterAll.scala:213)
    at org.scalatest.BeforeAndAfterAll.run(BeforeAndAfterAll.scala:210)
    at org.scalatest.BeforeAndAfterAll.run$(BeforeAndAfterAll.scala:208)
    at org.apache.spark.SparkFunSuite.run(SparkFunSuite.scala:69)
    at org.scalatest.tools.Framework.org$scalatest$tools$Framework$$runSuite(Framework.scala:321)
    at org.scalatest.tools.Framework$ScalaTestTask.execute(Framework.scala:517)
    at sbt.ForkMain$Run.lambda$runTest$1(ForkMain.java:414)
    at sbt.ForkMain$Run$$Lambda$107/0x0000007001110000.call(Unknown Source)
    at java.util.concurrent.FutureTask.run(java.base@17.0.14/FutureTask.java:264)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(java.base@17.0.14/ThreadPoolExecutor.java:1136)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(java.base@17.0.14/ThreadPoolExecutor.java:635)
    at java.lang.Thread.run(java.base@17.0.14/Thread.java:840)

"task thread":
    at org.apache.spark.util.UninterruptibleThread.interrupt(UninterruptibleThread.scala:96)
    - waiting to lock <0x000000070f70fe80> (a java.lang.Object)
    at org.apache.spark.util.UninterruptibleThreadSuite$InterruptibleChannel.implCloseChannel(UninterruptibleThreadSuite.scala:143)
    at java.nio.channels.spi.AbstractInterruptibleChannel.close(java.base@17.0.14/AbstractInterruptibleChannel.java:112)
    - locked <0x000000070f3019d0> (a java.lang.Object)
    at org.apache.spark.util.UninterruptibleThreadSuite$InterruptibleChannel.<init>(UninterruptibleThreadSuite.scala:138)
    at org.apache.spark.util.UninterruptibleThreadSuite$$anon$5.run(UninterruptibleThreadSuite.scala:153)

Found 1 deadlock.
```

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added two new test cases to `UninterruptibleThreadSuite`.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #50594 from vrozov/uninterruptible.
Authored-by: Vlad Rozov <vrozov@amazon.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
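
To make the lock-ordering cycle above concrete, here is a minimal, self-contained Scala sketch; it is not Spark code, and `uninterruptibleLock`/`channelCloseLock` are illustrative stand-ins for the monitor held by `UninterruptibleThread.interrupt()` and the internal lock of a `java.nio.channels.spi.AbstractInterruptibleChannel`. Two threads take the same two monitors in opposite order, so each ends up waiting on the lock the other holds.

```scala
// Illustrative only: running this is expected to hang, and that hang is the deadlock.
object LockOrderingDeadlockSketch {

  private val uninterruptibleLock = new Object // stand-in for UninterruptibleThread's monitor
  private val channelCloseLock = new Object    // stand-in for the NIO channel's internal lock

  def main(args: Array[String]): Unit = {
    // "Interrupter": takes uninterruptibleLock, then needs the channel lock
    // (what Thread.interrupt() does via the channel's interruptor, per the stack above).
    val interrupter = new Thread(() => {
      uninterruptibleLock.synchronized {
        Thread.sleep(100) // let the other thread grab channelCloseLock first
        channelCloseLock.synchronized {
          println("interrupter: acquired both locks")
        }
      }
    })

    // "Task thread": closes the channel (takes its lock), then calls back into
    // UninterruptibleThread.interrupt(), which needs uninterruptibleLock.
    val task = new Thread(() => {
      channelCloseLock.synchronized {
        Thread.sleep(100)
        uninterruptibleLock.synchronized {
          println("task: acquired both locks")
        }
      }
    })

    interrupter.start()
    task.start()
    interrupter.join() // with the pre-patch lock order, neither join ever returns
    task.join()
  }
}
```

The patch avoids this cycle by never calling `super.interrupt()` while the monitor is held; `isInterruptible` and the `awaitInterruptThread` flag coordinate with `runUninterruptibly()` instead.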
1 parent 9178c2c · commit bb0b2d2

2 files changed: +149, -33 lines changed

core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala

Lines changed: 95 additions & 30 deletions

@@ -35,21 +35,90 @@ private[spark] class UninterruptibleThread(
     this(null, name)
   }
 
-  /** A monitor to protect "uninterruptible" and "interrupted" */
-  private val uninterruptibleLock = new Object
+  private class UninterruptibleLock {
+    /**
+     * Indicates if `this` thread are in the uninterruptible status. If so, interrupting
+     * "this" will be deferred until `this` enters into the interruptible status.
+     */
+    @GuardedBy("uninterruptibleLock")
+    private var uninterruptible = false
 
-  /**
-   * Indicates if `this` thread are in the uninterruptible status. If so, interrupting
-   * "this" will be deferred until `this` enters into the interruptible status.
-   */
-  @GuardedBy("uninterruptibleLock")
-  private var uninterruptible = false
+    /**
+     * Indicates if we should interrupt `this` when we are leaving the uninterruptible zone.
+     */
+    @GuardedBy("uninterruptibleLock")
+    private var shouldInterruptThread = false
 
-  /**
-   * Indicates if we should interrupt `this` when we are leaving the uninterruptible zone.
-   */
-  @GuardedBy("uninterruptibleLock")
-  private var shouldInterruptThread = false
+    /**
+     * Indicates that we should wait for interrupt() call before proceeding.
+     */
+    @GuardedBy("uninterruptibleLock")
+    private var awaitInterruptThread = false
+
+    /**
+     * Set [[uninterruptible]] to given value and returns the previous value.
+     */
+    def getAndSetUninterruptible(value: Boolean): Boolean = synchronized {
+      val uninterruptible = this.uninterruptible
+      this.uninterruptible = value
+      uninterruptible
+    }
+
+    def setShouldInterruptThread(value: Boolean): Unit = synchronized {
+      shouldInterruptThread = value
+    }
+
+    def setAwaitInterruptThread(value: Boolean): Unit = synchronized {
+      awaitInterruptThread = value
+    }
+
+    /**
+     * Is call to [[java.lang.Thread.interrupt()]] pending
+     */
+    def isInterruptPending: Boolean = synchronized {
+      // Clear the interrupted status if it's set.
+      shouldInterruptThread = Thread.interrupted() || shouldInterruptThread
+      // wait for super.interrupt() to be called
+      !shouldInterruptThread && awaitInterruptThread
+    }
+
+    /**
+     * Set [[uninterruptible]] back to false and call [[java.lang.Thread.interrupt()]] to
+     * recover interrupt state if necessary
+     */
+    def recoverInterrupt(): Unit = synchronized {
+      uninterruptible = false
+      if (shouldInterruptThread) {
+        shouldInterruptThread = false
+        // Recover the interrupted status
+        UninterruptibleThread.super.interrupt()
+      }
+    }
+
+    /**
+     * Is it safe to call [[java.lang.Thread.interrupt()]] and interrupt the current thread
+     * @return true when there is no concurrent [[runUninterruptibly()]] call ([[uninterruptible]]
+     *         is true) and no concurrent [[interrupt()]] call, otherwise false
+     */
+    def isInterruptible: Boolean = synchronized {
+      shouldInterruptThread = uninterruptible
+      // as we are releasing uninterruptibleLock before calling super.interrupt() there is a
+      // possibility that runUninterruptibly() would be called after lock is released but before
+      // super.interrupt() is called. In this case to prevent runUninterruptibly() from being
+      // interrupted, we use awaitInterruptThread flag. We need to set it only if
+      // runUninterruptibly() is not yet set uninterruptible to true (!shouldInterruptThread) and
+      // there is no other threads that called interrupt (awaitInterruptThread is already true)
+      if (!shouldInterruptThread && !awaitInterruptThread) {
+        awaitInterruptThread = true
+        true
+      } else {
+        false
+      }
+    }
+  }
+
+  /** A monitor to protect "uninterruptible" and "interrupted" */
+  private val uninterruptibleLock = new UninterruptibleLock
 
   /**
    * Run `f` uninterruptibly in `this` thread. The thread won't be interrupted before returning
@@ -63,27 +132,23 @@ private[spark] class UninterruptibleThread(
         s"Expected: $this but was ${Thread.currentThread()}")
     }
 
-    if (uninterruptibleLock.synchronized { uninterruptible }) {
+    if (uninterruptibleLock.getAndSetUninterruptible(true)) {
       // We are already in the uninterruptible status. So just run "f" and return
      return f
     }
 
-    uninterruptibleLock.synchronized {
-      // Clear the interrupted status if it's set.
-      shouldInterruptThread = Thread.interrupted() || shouldInterruptThread
-      uninterruptible = true
+    while (uninterruptibleLock.isInterruptPending) {
+      try {
+        Thread.sleep(100)
+      } catch {
+        case _: InterruptedException => uninterruptibleLock.setShouldInterruptThread(true)
+      }
     }
+
    try {
      f
    } finally {
-      uninterruptibleLock.synchronized {
-        uninterruptible = false
-        if (shouldInterruptThread) {
-          // Recover the interrupted status
-          super.interrupt()
-          shouldInterruptThread = false
-        }
-      }
+      uninterruptibleLock.recoverInterrupt()
    }
  }
 
@@ -92,11 +157,11 @@ private[spark] class UninterruptibleThread(
   * interrupted until it enters into the interruptible status.
   */
  override def interrupt(): Unit = {
-    uninterruptibleLock.synchronized {
-      if (uninterruptible) {
-        shouldInterruptThread = true
-      } else {
+    if (uninterruptibleLock.isInterruptible) {
+      try {
        super.interrupt()
+      } finally {
+        uninterruptibleLock.setAwaitInterruptThread(false)
      }
    }
  }
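
For context on the semantics the patch must preserve, here is a rough usage sketch of `runUninterruptibly`: an `interrupt()` that arrives while the thread is inside the block is deferred and re-delivered when the block exits. This is a sketch under assumptions, not part of the commit: the package must sit under `org.apache.spark` because the class is `private[spark]`, and the object name, sleep durations, and timing are illustrative.

```scala
package org.apache.spark.example // UninterruptibleThread is private[spark]

import org.apache.spark.util.UninterruptibleThread

object RunUninterruptiblyExample {
  def main(args: Array[String]): Unit = {
    val worker = new UninterruptibleThread("example-worker") {
      override def run(): Unit = {
        runUninterruptibly {
          // Interrupts arriving here are only recorded, not delivered;
          // the block runs to completion.
          Thread.sleep(500)
        }
        // The deferred interrupt is re-delivered once the block exits.
        println(s"interrupt status after block: ${Thread.currentThread().isInterrupted}")
      }
    }
    worker.start()
    Thread.sleep(100)  // give the worker time to enter the uninterruptible block
    worker.interrupt() // deferred until runUninterruptibly returns
    worker.join()
  }
}
```

The patch keeps this contract unchanged (no user-facing change); only the locking inside `interrupt()` and its coordination with `runUninterruptibly()` are reworked.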

core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala

Lines changed: 54 additions & 3 deletions

@@ -17,6 +17,7 @@
 
 package org.apache.spark.util
 
+import java.nio.channels.spi.AbstractInterruptibleChannel
 import java.util.concurrent.{CountDownLatch, TimeUnit}
 
 import scala.util.Random
@@ -115,6 +116,45 @@ class UninterruptibleThreadSuite extends SparkFunSuite {
     assert(interruptStatusBeforeExit)
   }
 
+  test("no runUninterruptibly") {
+    @volatile var hasInterruptedException = false
+    val t = new UninterruptibleThread("test") {
+      override def run(): Unit = {
+        if (sleep(0)) {
+          hasInterruptedException = true
+        }
+      }
+    }
+    t.interrupt()
+    t.start()
+    t.join()
+    assert(hasInterruptedException === true)
+  }
+
+  test("SPARK-51821 uninterruptibleLock deadlock") {
+    val latch = new CountDownLatch(1)
+    val task = new UninterruptibleThread("task thread") {
+      override def run(): Unit = {
+        val channel = new AbstractInterruptibleChannel() {
+          override def implCloseChannel(): Unit = {
+            begin()
+            latch.countDown()
+            try {
+              Thread.sleep(Long.MaxValue)
+            } catch {
+              case _: InterruptedException => Thread.currentThread().interrupt()
+            }
+          }
+        }
+        channel.close()
+      }
+    }
+    task.start()
+    assert(latch.await(10, TimeUnit.SECONDS), "await timeout")
+    task.interrupt()
+    task.join()
+  }
+
   test("stress test") {
     @volatile var hasInterruptedException = false
     val t = new UninterruptibleThread("test") {
@@ -148,9 +188,20 @@ class UninterruptibleThreadSuite extends SparkFunSuite {
       }
     }
     t.start()
-    for (i <- 0 until 400) {
-      Thread.sleep(Random.nextInt(10))
-      t.interrupt()
+    val threads = new Array[Thread](10)
+    for (j <- 0 until 10) {
+      threads(j) = new Thread() {
+        override def run(): Unit = {
+          for (i <- 0 until 400) {
+            Thread.sleep(Random.nextInt(10))
+            t.interrupt()
+          }
+        }
+      }
+      threads(j).start()
+    }
+    for (j <- 0 until 10) {
+      threads(j).join()
     }
     t.join()
     assert(hasInterruptedException === false)
