diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index d01de3b9ed08..138e7da9569d 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -25,6 +25,7 @@ import java.nio.ByteBuffer
 import java.util.{Locale, Properties}
 import java.util.concurrent._
 import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.locks.ReentrantLock
 import javax.annotation.concurrent.GuardedBy
 import javax.ws.rs.core.UriBuilder
 
@@ -85,6 +86,11 @@ private[spark] class Executor(
 
   private[executor] val conf = env.conf
 
+  // SPARK-40235: updateDependencies() uses a ReentrantLock instead of the `synchronized` keyword
+  // so that tasks can exit quickly if they are interrupted while waiting on another task to
+  // finish downloading dependencies.
+  private val updateDependenciesLock = new ReentrantLock()
+
   // No ip or host:port - just hostname
   Utils.checkHost(executorHostname)
   // must not have port specified.
@@ -969,13 +975,19 @@ private[spark] class Executor(
   /**
    * Download any missing dependencies if we receive a new set of files and JARs from the
    * SparkContext. Also adds any new JARs we fetched to the class loader.
+   * Visible for testing.
    */
-  private def updateDependencies(
+  private[executor] def updateDependencies(
       newFiles: Map[String, Long],
       newJars: Map[String, Long],
-      newArchives: Map[String, Long]): Unit = {
+      newArchives: Map[String, Long],
+      testStartLatch: Option[CountDownLatch] = None,
+      testEndLatch: Option[CountDownLatch] = None): Unit = {
     lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
-    synchronized {
+    updateDependenciesLock.lockInterruptibly()
+    try {
+      // For testing, so we can simulate a slow file download:
+      testStartLatch.foreach(_.countDown())
       // Fetch missing dependencies
       for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
         logInfo(s"Fetching $name with timestamp $timestamp")
@@ -1018,6 +1030,10 @@ private[spark] class Executor(
           }
         }
       }
+      // For testing, so we can simulate a slow file download:
+      testEndLatch.foreach(_.await())
+    } finally {
+      updateDependenciesLock.unlock()
     }
   }
 
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index 14871efac5bc..bef36d08e8ae 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -514,6 +514,59 @@ class ExecutorSuite extends SparkFunSuite
     }
   }
 
+  test("SPARK-40235: updateDependencies is interruptible when waiting on lock") {
+    val conf = new SparkConf
+    val serializer = new JavaSerializer(conf)
+    val env = createMockEnv(conf, serializer)
+    withExecutor("id", "localhost", env) { executor =>
+      val startLatch = new CountDownLatch(1)
+      val endLatch = new CountDownLatch(1)
+
+      // Start a thread to simulate a task that begins executing updateDependencies()
+      // and takes a long time to finish because file download is slow:
+      val slowLibraryDownloadThread = new Thread(() => {
+        executor.updateDependencies(
+          Map.empty,
+          Map.empty,
+          Map.empty,
+          Some(startLatch),
+          Some(endLatch))
+      })
+      slowLibraryDownloadThread.start()
+
+      // Wait for that thread to acquire the lock:
+      startLatch.await()
+
+      // Start a second thread to simulate a task that blocks on the other task's
+      // dependency update:
+      val blockedLibraryDownloadThread = new Thread(() => {
+        executor.updateDependencies(
+          Map.empty,
+          Map.empty,
+          Map.empty)
+      })
+      blockedLibraryDownloadThread.start()
+      eventually(timeout(10.seconds), interval(100.millis)) {
+        val threadState = blockedLibraryDownloadThread.getState
+        assert(Set(Thread.State.BLOCKED, Thread.State.WAITING).contains(threadState))
+      }
+
+      // Interrupt the blocked thread:
+      blockedLibraryDownloadThread.interrupt()
+
+      // The thread should exit:
+      eventually(timeout(10.seconds), interval(100.millis)) {
+        assert(blockedLibraryDownloadThread.getState == Thread.State.TERMINATED)
+      }
+
+      // Allow the first thread to finish and exit:
+      endLatch.countDown()
+      eventually(timeout(10.seconds), interval(100.millis)) {
+        assert(slowLibraryDownloadThread.getState == Thread.State.TERMINATED)
+      }
+    }
+  }
+
   private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = {
     val mockEnv = mock[SparkEnv]
     val mockRpcEnv = mock[RpcEnv]
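
The behavior this patch relies on: a thread blocked entering a `synchronized` block does not respond to `Thread.interrupt()` (the interrupt flag is set, but the thread keeps waiting for the monitor), whereas a thread blocked in `ReentrantLock.lockInterruptibly()` wakes up immediately with an `InterruptedException`. The following standalone sketch (a hypothetical `LockInterruptiblyDemo`, not part of the patch or of Spark) illustrates the `lockInterruptibly()` side of that contrast:

```scala
import java.util.concurrent.CountDownLatch
import java.util.concurrent.locks.ReentrantLock

object LockInterruptiblyDemo {
  def main(args: Array[String]): Unit = {
    val lock = new ReentrantLock()
    val holderHasLock = new CountDownLatch(1)
    val releaseHolder = new CountDownLatch(1)

    // Simulates the "slow download" task: acquires the lock and holds it.
    val holder = new Thread(() => {
      lock.lockInterruptibly()
      try {
        holderHasLock.countDown()
        releaseHolder.await() // hold the lock until the main thread releases us
      } finally {
        lock.unlock()
      }
    })
    holder.start()
    holderHasLock.await()

    // Simulates the blocked task: waits on the lock, then gets interrupted.
    val waiter = new Thread(() => {
      try {
        lock.lockInterruptibly() // throws InterruptedException instead of waiting forever
        try println("acquired (not expected in this demo)") finally lock.unlock()
      } catch {
        case _: InterruptedException =>
          println("waiter exited promptly after interrupt")
      }
    })
    waiter.start()
    // Crude wait (demo only) for the waiter to park on the contended lock:
    while (waiter.getState != Thread.State.WAITING) Thread.sleep(10)
    waiter.interrupt()
    waiter.join()

    releaseHolder.countDown()
    holder.join()
  }
}
```

Note that the patch calls `lockInterruptibly()` *before* entering the `try` block, so the `finally` only runs `unlock()` on a thread that actually acquired the lock; if the waiting task is interrupted, the `InterruptedException` propagates out of `updateDependencies()` and the killed task exits without touching the lock.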