From f0531002477d2919abdad06c7b69d54a7761dd86 Mon Sep 17 00:00:00 2001
From: Paddy Xu
Date: Wed, 20 Nov 2024 12:24:32 +0100
Subject: [PATCH 1/5] isolate tags per thread

---
 .../org/apache/spark/sql/SparkSession.scala   | 41 ++++++---
 .../spark/sql/execution/SQLExecution.scala    |  2 +-
 ...essionJobTaggingAndCancellationSuite.scala | 89 +++++++++++++++----
 3 files changed, 101 insertions(+), 31 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index afc0a2d7df604..4b56a055641b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -20,9 +20,9 @@ package org.apache.spark.sql
 import java.net.URI
 import java.nio.file.Paths
 import java.util.{ServiceLoader, UUID}
-import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.atomic.AtomicBoolean
 
+import scala.collection.mutable
 import scala.concurrent.duration.DurationInt
 import scala.jdk.CollectionConverters._
 import scala.reflect.runtime.universe.TypeTag
@@ -133,14 +133,30 @@ class SparkSession private(
   /** Tag to mark all jobs owned by this session. */
   private[sql] lazy val sessionJobTag = s"spark-session-$sessionUUID"
 
+  /**
+   * A UUID that is unique on the thread level. Used by managedJobTags to make sure that the same
+   * use tag do not overlap in the underlying SparkContext/SQLExecution.
+   */
+  private[sql] lazy val threadUuid = new ThreadLocal[String] {
+    override def initialValue(): String = UUID.randomUUID().toString
+  }
+
   /**
    * A map to hold the mapping from user-defined tags to the real tags attached to Jobs.
    * Real tag have the current session ID attached: `"tag1" -> s"spark-session-$sessionUUID-tag1"`.
    */
   @transient
-  private[sql] lazy val managedJobTags: ConcurrentHashMap[String, String] = {
-    new ConcurrentHashMap(parentManagedJobTags.asJava)
-  }
+  private[sql] lazy val managedJobTags = new InheritableThreadLocal[mutable.Map[String, String]] {
+    override def childValue(parent: mutable.Map[String, String]): mutable.Map[String, String] = {
+      // Note: make a clone such that changes in the parent tags aren't reflected in
+      // those of the child threads.
+      parent.clone()
+    }
+
+    override def initialValue(): mutable.Map[String, String] = {
+      mutable.Map(parentManagedJobTags.toSeq: _*)
+    }
+  }
 
   /** @inheritdoc */
   def version: String = SPARK_VERSION
@@ -243,10 +259,10 @@ class SparkSession private(
       Some(sessionState),
       extensions,
       Map.empty,
-      managedJobTags.asScala.toMap)
+      managedJobTags.get().toMap)
     result.sessionState // force copy of SessionState
     result.sessionState.artifactManager // force copy of ArtifactManager and its resources
-    result.managedJobTags // force copy of userDefinedToRealTagsMap
+    result.managedJobTags // force copy of managedJobTags
     result
   }
 
@@ -550,17 +566,17 @@ class SparkSession private(
   /** @inheritdoc */
   override def addTag(tag: String): Unit = {
     SparkContext.throwIfInvalidTag(tag)
-    managedJobTags.put(tag, s"spark-session-$sessionUUID-$tag")
+    managedJobTags.get().put(tag, s"spark-session-$sessionUUID-thread-${threadUuid.get()}-$tag")
   }
 
   /** @inheritdoc */
-  override def removeTag(tag: String): Unit = managedJobTags.remove(tag)
+  override def removeTag(tag: String): Unit = managedJobTags.get().remove(tag)
 
   /** @inheritdoc */
-  override def getTags(): Set[String] = managedJobTags.keys().asScala.toSet
+  override def getTags(): Set[String] = managedJobTags.get().keySet.toSet
 
   /** @inheritdoc */
-  override def clearTags(): Unit = managedJobTags.clear()
+  override def clearTags(): Unit = managedJobTags.get().clear()
 
   /**
    * Request to interrupt all currently running SQL operations of this session.
@@ -589,9 +605,8 @@ class SparkSession private(
    * @since 4.0.0
    */
   override def interruptTag(tag: String): Seq[String] = {
-    val realTag = managedJobTags.get(tag)
-    if (realTag == null) return Seq.empty
-    doInterruptTag(realTag, s"part of cancelled job tags $tag")
+    val realTag = managedJobTags.get().get(tag)
+    realTag.map(doInterruptTag(_, s"part of cancelled job tags $tag")).getOrElse(Seq.empty)
   }
 
   private def doInterruptTag(tag: String, reason: String): Seq[String] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index e805aabe013cf..242149010ceef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -261,7 +261,7 @@ object SQLExecution extends Logging {
   }
 
   private[sql] def withSessionTagsApplied[T](sparkSession: SparkSession)(block: => T): T = {
-    val allTags = sparkSession.managedJobTags.values().asScala.toSet + sparkSession.sessionJobTag
+    val allTags = sparkSession.managedJobTags.get().values.toSet + sparkSession.sessionJobTag
     sparkSession.sparkContext.addJobTags(allTags)
 
     try {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
index 1ac51b408301a..99203d45eb073 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import java.util.concurrent.{ConcurrentHashMap, Semaphore, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, Executors, Semaphore, TimeUnit}
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.concurrent.{ExecutionContext, Future}
@@ -100,13 +100,14 @@ class SparkSessionJobTaggingAndCancellationSuite
     assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
     val activeJobsFuture =
-      session.sparkContext.cancelJobsWithTagWithFuture(session.managedJobTags.get("one"), "reason")
+      session.sparkContext.cancelJobsWithTagWithFuture(
+        session.managedJobTags.get()("one"), "reason")
     val activeJob = ThreadUtils.awaitResult(activeJobsFuture, 60.seconds).head
     val actualTags = activeJob.properties.getProperty(SparkContext.SPARK_JOB_TAGS)
       .split(SparkContext.SPARK_JOB_TAGS_SEP)
     assert(actualTags.toSet == Set(
       session.sessionJobTag,
-      s"${session.sessionJobTag}-one",
+      s"${session.sessionJobTag}-thread-${session.threadUuid.get()}-one",
       SQLExecution.executionIdJobTag(
         session,
         activeJob.properties.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong)))
 
@@ -118,12 +119,12 @@ class SparkSessionJobTaggingAndCancellationSuite
     val globalSession = SparkSession.builder().sparkContext(sc).getOrCreate()
     var (sessionA, sessionB, sessionC): (SparkSession, SparkSession, SparkSession) =
       (null, null, null)
+    var (threadUuidA, threadUuidB, threadUuidC): (String, String, String) = (null, null, null)
 
     // global ExecutionContext has only 2 threads in Apache Spark CI
     // create own thread pool for four Futures used in this test
-    val numThreads = 3
-    val fpool = ThreadUtils.newForkJoinPool("job-tags-test-thread-pool", numThreads)
-    val executionContext = ExecutionContext.fromExecutorService(fpool)
+    val threadPool = Executors.newFixedThreadPool(3)
+    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(threadPool)
 
     try {
       // Add a listener to release the semaphore once jobs are launched.
@@ -143,28 +144,35 @@ class SparkSessionJobTaggingAndCancellationSuite
         }
       })
 
+      var realTagOneForSessionA: String = null
+
       // Note: since tags are added in the Future threads, they don't need to be cleared in between.
       val jobA = Future {
         sessionA = globalSession.cloneSession()
         import globalSession.implicits._
+        threadUuidA = sessionA.threadUuid.get()
         assert(sessionA.getTags() == Set())
         sessionA.addTag("two")
         assert(sessionA.getTags() == Set("two"))
         sessionA.clearTags() // check that clearing all tags works
         assert(sessionA.getTags() == Set())
         sessionA.addTag("one")
+        realTagOneForSessionA = sessionA.managedJobTags.get()("one")
+        assert(realTagOneForSessionA ==
+          s"${sessionA.sessionJobTag}-thread-${sessionA.threadUuid.get()}-one")
         assert(sessionA.getTags() == Set("one"))
         try {
           sessionA.range(1, 10000).map { i => Thread.sleep(100); i }.count()
         } finally {
           sessionA.clearTags() // clear for the case of thread reuse by another Future
         }
-      }(executionContext)
+      }
 
      val jobB = Future {
         sessionB = globalSession.cloneSession()
         import globalSession.implicits._
+        threadUuidB = sessionB.threadUuid.get()
         assert(sessionB.getTags() == Set())
         sessionB.addTag("one")
         sessionB.addTag("two")
@@ -176,11 +184,12 @@ class SparkSessionJobTaggingAndCancellationSuite
         } finally {
           sessionB.clearTags() // clear for the case of thread reuse by another Future
         }
-      }(executionContext)
+      }
 
       val jobC = Future {
         sessionC = globalSession.cloneSession()
         import globalSession.implicits._
+        threadUuidC = sessionC.threadUuid.get()
         sessionC.addTag("foo")
         sessionC.removeTag("foo")
         assert(sessionC.getTags() == Set()) // check that remove works removing the last tag
@@ -190,12 +199,13 @@ class SparkSessionJobTaggingAndCancellationSuite
         } finally {
           sessionC.clearTags() // clear for the case of thread reuse by another Future
         }
-      }(executionContext)
+      }
 
       // Block until four jobs have started.
      assert(sem.tryAcquire(3, 1, TimeUnit.MINUTES))
 
       // Tags are applied
+      def realUserTag(s: String, t: String, ta: String): String = s"spark-session-$s-thread-$t-$ta"
       assert(jobProperties.size == 3)
       for (ss <- Seq(sessionA, sessionB, sessionC)) {
         val jobProperty = jobProperties.values().asScala.filter(_.get(SparkContext.SPARK_JOB_TAGS)
@@ -207,15 +217,17 @@ class SparkSessionJobTaggingAndCancellationSuite
         val executionRootIdTag = SQLExecution.executionIdJobTag(
           ss,
           jobProperty.head.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong)
-        val userTagsPrefix = s"spark-session-${ss.sessionUUID}-"
 
         ss match {
           case s if s == sessionA => assert(tags.toSet == Set(
-            s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one"))
+            s.sessionJobTag, executionRootIdTag, realUserTag(s.sessionUUID, threadUuidA, "one")))
           case s if s == sessionB => assert(tags.toSet == Set(
-            s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one", s"${userTagsPrefix}two"))
+            s.sessionJobTag,
+            executionRootIdTag,
+            realUserTag(s.sessionUUID, threadUuidB, "one"),
+            realUserTag(s.sessionUUID, threadUuidB, "two")))
           case s if s == sessionC => assert(tags.toSet == Set(
-            s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}boo"))
+            s.sessionJobTag, executionRootIdTag, realUserTag(s.sessionUUID, threadUuidC, "boo")))
         }
       }
 
@@ -239,12 +251,14 @@ class SparkSessionJobTaggingAndCancellationSuite
       assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
       assert(jobEnded.intValue == 1)
 
-      // Another job cancelled
-      assert(sessionA.interruptTag("one").size == 1)
+      // Another job cancelled. The next line cancels nothing because we're now in another thread
+      assert(sessionA.interruptTag("one").isEmpty)
+      // Have to cancel it via SparkContext using the real tag
+      sessionA.sparkContext.cancelJobsWithTagWithFuture(realTagOneForSessionA, "abc")
       val eA = intercept[SparkException] {
         ThreadUtils.awaitResult(jobA, 1.minute)
       }.getCause
-      assert(eA.getMessage contains "cancelled job tags one")
+      assert(eA.getMessage contains "abc")
 
       assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
       assert(jobEnded.intValue == 2)
@@ -257,7 +271,48 @@ class SparkSessionJobTaggingAndCancellationSuite
       assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
       assert(jobEnded.intValue == 3)
     } finally {
-      fpool.shutdownNow()
+      threadPool.shutdownNow()
+    }
+  }
+
+  test("Tags are isolated in multithreaded environment") {
+    // Custom thread pool for multi-threaded testing
+    val threadPool = Executors.newFixedThreadPool(2)
+    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(threadPool)
+
+    val session = SparkSession.builder().master("local").getOrCreate()
+    @volatile var output1: Set[String] = null
+    @volatile var output2: Set[String] = null
+
+    def tag1(): Unit = {
+      session.addTag("tag1")
+      output1 = session.getTags()
+    }
+
+    def tag2(): Unit = {
+      session.addTag("tag2")
+      output2 = session.getTags()
+    }
+
+    try {
+      // Run tasks in separate threads
+      val future1 = Future {
+        tag1()
+      }
+      val future2 = Future {
+        tag2()
+      }
+
+      // Wait for threads to complete
+      ThreadUtils.awaitResult(Future.sequence(Seq(future1, future2)), 1.minute)
+
+      // Assert outputs
+      assert(output1 != null)
+      assert(output1 == Set("tag1"))
+      assert(output2 != null)
+      assert(output2 == Set("tag2"))
+    } finally {
+      threadPool.shutdownNow()
     }
   }
 }
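
The core mechanism of PATCH 1/5 is worth spelling out: the session-wide ConcurrentHashMap is replaced by an InheritableThreadLocal whose childValue clones the parent's map, so each thread works against its own tag mapping, and a newly spawned thread merely starts from a snapshot of its parent's. A minimal, Spark-free sketch of that behavior (all object and tag names here are illustrative, not from the patch):

```scala
import scala.collection.mutable

// Stand-in for SparkSession.managedJobTags: each thread sees its own map, and
// a child thread starts from a snapshot of its parent's map rather than a
// shared reference, so later mutations on either side stay invisible to the other.
object TagIsolationDemo extends App {
  val tags = new InheritableThreadLocal[mutable.Map[String, String]] {
    override def childValue(parent: mutable.Map[String, String]): mutable.Map[String, String] =
      parent.clone() // inherit a copy, not the parent's map itself
    override def initialValue(): mutable.Map[String, String] = mutable.Map.empty
  }

  tags.get().put("one", "real-one") // set in the parent thread

  @volatile var seenInChild: Set[String] = null
  val child = new Thread(() => {
    seenInChild = tags.get().keySet.toSet // snapshot inherited from the parent
    tags.get().put("two", "real-two")     // mutates only the child's copy
  })
  child.start()
  child.join()

  assert(seenInChild == Set("one"))             // the child saw the parent's tag
  assert(tags.get().keySet.toSet == Set("one")) // the parent is unaffected by the child's put
  println("per-thread tag maps are isolated")
}
```
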
From 1e320204b262025ead631cd9150ad4cbe2812e1a Mon Sep 17 00:00:00 2001
From: Paddy Xu
Date: Wed, 20 Nov 2024 12:31:23 +0100
Subject: [PATCH 2/5] doc

---
 .../src/main/scala/org/apache/spark/sql/SparkSession.scala | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 4b56a055641b8..81c8795ba0e8a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -143,7 +143,9 @@ class SparkSession private(
 
   /**
    * A map to hold the mapping from user-defined tags to the real tags attached to Jobs.
-   * Real tag have the current session ID attached: `"tag1" -> s"spark-session-$sessionUUID-tag1"`.
+   * Real tags have the current session ID attached:
+   * tag1" -> s"spark-session-$sessionUUID-thread-$threadUuid-tag1
+   *
    */
   @transient
   private[sql] lazy val managedJobTags = new InheritableThreadLocal[mutable.Map[String, String]] {
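
A worked example of the real-tag encoding documented above. The UUID values are shortened and the realTag helper is invented for this sketch; only the resulting string shape comes from the patch:

```scala
// Illustrative reconstruction of the "real tag" scheme:
//   userTag -> spark-session-$sessionUUID-thread-$threadUuid-$userTag
object RealTagShape extends App {
  val sessionUUID = "3f9a" // one per SparkSession (shortened for readability)
  val threadUuid = "7c21"  // one per thread (made inheritable in PATCH 3/5)

  def realTag(userTag: String): String =
    s"spark-session-$sessionUUID-thread-$threadUuid-$userTag"

  println(realTag("tag1")) // prints: spark-session-3f9a-thread-7c21-tag1
  // The same user tag added from another thread embeds a different thread
  // UUID, so cancelling by real tag never crosses thread boundaries.
}
```
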
From 14268ee9175c79f0592189e95ecbdae5eb890cbc Mon Sep 17 00:00:00 2001
From: Paddy Xu
Date: Wed, 20 Nov 2024 13:37:35 +0100
Subject: [PATCH 3/5] address comments

---
 .../main/scala/org/apache/spark/sql/SparkSession.scala | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 81c8795ba0e8a..5afb22ffc9f0a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -134,10 +134,12 @@ class SparkSession private(
   private[sql] lazy val sessionJobTag = s"spark-session-$sessionUUID"
 
   /**
-   * A UUID that is unique on the thread level. Used by managedJobTags to make sure that the same
-   * use tag do not overlap in the underlying SparkContext/SQLExecution.
+   * A UUID that is unique on the thread level. Used by managedJobTags to make sure that the same
+   * tag from two threads does not overlap in the underlying SparkContext/SQLExecution.
    */
-  private[sql] lazy val threadUuid = new ThreadLocal[String] {
+  private[sql] lazy val threadUuid = new InheritableThreadLocal[String] {
+    override def childValue(parent: String): String = parent
+
     override def initialValue(): String = UUID.randomUUID().toString
   }
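
The childValue override above passes the parent's value through unchanged, which matches InheritableThreadLocal's default child behavior, made explicit here. The consequence, sketched without Spark: a thread spawned by a tagging thread shares that thread's UUID, so the real tags it inherits still resolve correctly, while an unrelated thread draws a fresh UUID:

```scala
import java.util.UUID

// Sketch of the threadUuid semantics after PATCH 3/5: a child thread shares
// its parent's UUID (childValue passes it through), so inherited tag mappings
// keep pointing at the parent's real tags.
object ThreadUuidDemo extends App {
  val threadUuid = new InheritableThreadLocal[String] {
    override def childValue(parent: String): String = parent // share, do not regenerate
    override def initialValue(): String = UUID.randomUUID().toString
  }

  val parentUuid = threadUuid.get() // initialized in the parent thread

  @volatile var childUuid: String = null
  val child = new Thread(() => { childUuid = threadUuid.get() })
  child.start()
  child.join()

  assert(childUuid == parentUuid) // same UUID: the child can address the parent's jobs
  println(s"parent and child share thread UUID $parentUuid")
}
```
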
From 086082750ba9c29deffc45b583e15899d7d629cb Mon Sep 17 00:00:00 2001
From: Paddy Xu
Date: Wed, 20 Nov 2024 14:56:31 +0100
Subject: [PATCH 4/5] Add interruptTag test

---
 ...essionJobTaggingAndCancellationSuite.scala | 23 +++++++++++++++----
 1 file changed, 19 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
index 99203d45eb073..89500fe51f3ac 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
@@ -145,6 +145,8 @@ class SparkSessionJobTaggingAndCancellationSuite
       })
 
       var realTagOneForSessionA: String = null
+      var childThread: Thread = null
+      val childThreadLock = new Semaphore(0)
 
       // Note: since tags are added in the Future threads, they don't need to be cleared in between.
       val jobA = Future {
@@ -162,9 +164,22 @@ class SparkSessionJobTaggingAndCancellationSuite
         assert(realTagOneForSessionA ==
           s"${sessionA.sessionJobTag}-thread-${sessionA.threadUuid.get()}-one")
         assert(sessionA.getTags() == Set("one"))
+
+        // Create a child thread which inherits thread-local variables and tries to interrupt
+        // the job started from the parent thread. The child thread is blocked until the main
+        // thread releases the lock.
+        childThread = new Thread {
+          override def run(): Unit = {
+            assert(childThreadLock.tryAcquire(1, 20, TimeUnit.SECONDS))
+            assert(sessionA.getTags() == Set("one"))
+            assert(sessionA.interruptTag("one").size == 1)
+          }
+        }
+        childThread.start()
         try {
           sessionA.range(1, 10000).map { i => Thread.sleep(100); i }.count()
         } finally {
+          childThread.interrupt()
           sessionA.clearTags() // clear for the case of thread reuse by another Future
         }
       }
@@ -251,14 +266,14 @@ class SparkSessionJobTaggingAndCancellationSuite
       assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
       assert(jobEnded.intValue == 1)
 
-      // Another job cancelled. The next line cancels nothing because we're now in another thread
-      assert(sessionA.interruptTag("one").isEmpty)
-      // Have to cancel it via SparkContext using the real tag
-      sessionA.sparkContext.cancelJobsWithTagWithFuture(realTagOneForSessionA, "abc")
+      // Another job cancelled. The next line cancels nothing because we're now in another thread.
+      // The real cancellation happens by unblocking the child thread, which is waiting on a lock.
+      assert(sessionA.interruptTag("one").isEmpty)
+      childThreadLock.release()
       val eA = intercept[SparkException] {
         ThreadUtils.awaitResult(jobA, 1.minute)
       }.getCause
-      assert(eA.getMessage contains "abc")
+      assert(eA.getMessage contains "cancelled job tags one")
 
       assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
       assert(jobEnded.intValue == 2)
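
Why the reworked test expects interruptTag("one") to be a no-op on the main thread but to succeed from the child: the user-tag map is resolved against the calling thread, and only a thread spawned by the tagging thread inherits a copy of that map. A Spark-free model of the lookup (all names invented for this sketch):

```scala
import java.util.concurrent.Executors
import scala.collection.mutable

// interruptTag is modeled as a lookup in the *calling* thread's tag map: it
// finds nothing on a thread that never added the tag, but succeeds on a child
// of the tagging thread, which inherited a copy of the map.
object InterruptLookupDemo extends App {
  val managed = new InheritableThreadLocal[mutable.Map[String, String]] {
    override def childValue(parent: mutable.Map[String, String]) = parent.clone()
    override def initialValue() = mutable.Map.empty[String, String]
  }
  // Returns the real tag that would be cancelled, if this thread knows the user tag.
  def interruptTag(tag: String): Option[String] = managed.get().get(tag)

  @volatile var fromChild: Option[String] = None
  val pool = Executors.newSingleThreadExecutor()
  pool.submit(new Runnable {
    override def run(): Unit = {
      managed.get().put("one", "real-one") // the "job" thread registers the tag
      val child = new Thread(() => { fromChild = interruptTag("one") })
      child.start()
      child.join() // the child inherited a copy of the map
    }
  }).get()
  pool.shutdown()

  assert(interruptTag("one").isEmpty)   // main thread: no mapping, a no-op
  assert(fromChild == Some("real-one")) // inheriting child: finds the real tag
  println("interruptTag resolves against the calling thread's map")
}
```
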
From 2c8419b537a0229c38e4c56d1fc5c243d017b737 Mon Sep 17 00:00:00 2001
From: Paddy Xu
Date: Thu, 21 Nov 2024 11:17:39 +0100
Subject: [PATCH 5/5] fix comment

---
 sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 5afb22ffc9f0a..a7f85db12b214 100644
The child thread is blocked until the main + // thread releases the lock. + childThread = new Thread { + override def run(): Unit = { + assert(childThreadLock.tryAcquire(1, 20, TimeUnit.SECONDS)) + assert(sessionA.getTags() == Set("one")) + assert(sessionA.interruptTag("one").size == 1) + } + } + childThread.start() try { sessionA.range(1, 10000).map { i => Thread.sleep(100); i }.count() } finally { + childThread.interrupt() sessionA.clearTags() // clear for the case of thread reuse by another Future } } @@ -251,14 +266,14 @@ class SparkSessionJobTaggingAndCancellationSuite assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES)) assert(jobEnded.intValue == 1) - // Another job cancelled. The next line cancels nothing because we're now in another thread + // Another job cancelled. The next line cancels nothing because we're now in another thread. + // The real cancel is done through unblocking a child thread, which is waiting for a lock assert(sessionA.interruptTag("one").isEmpty) - // Have to cancel it via SparkContext using the real tag - sessionA.sparkContext.cancelJobsWithTagWithFuture(realTagOneForSessionA, "abc") + childThreadLock.release() val eA = intercept[SparkException] { ThreadUtils.awaitResult(jobA, 1.minute) }.getCause - assert(eA.getMessage contains "abc") + assert(eA.getMessage contains "cancelled job tags one") assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES)) assert(jobEnded.intValue == 2) From 2c8419b537a0229c38e4c56d1fc5c243d017b737 Mon Sep 17 00:00:00 2001 From: Paddy Xu Date: Thu, 21 Nov 2024 11:17:39 +0100 Subject: [PATCH 5/5] fix comment --- sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 5afb22ffc9f0a..a7f85db12b214 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -146,7 +146,7 @@ class SparkSession private( /** * A map to hold the mapping from user-defined tags to the real tags attached to Jobs. * Real tag have the current session ID attached: - * tag1" -> s"spark-session-$sessionUUID-thread-$threadUuid-tag1 + * tag1 -> spark-session-$sessionUUID-thread-$threadUuid-tag1 * */ @transient