[SPARK-50366][SQL] Isolate user-defined tags on thread level for SparkSession in Classic #48906

Closed · wants to merge 5 commits
47 changes: 33 additions & 14 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -20,9 +20,9 @@ package org.apache.spark.sql
import java.net.URI
import java.nio.file.Paths
import java.util.{ServiceLoader, UUID}
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean

import scala.collection.mutable
import scala.concurrent.duration.DurationInt
import scala.jdk.CollectionConverters._
import scala.reflect.runtime.universe.TypeTag
@@ -133,14 +133,34 @@ class SparkSession private(
/** Tag to mark all jobs owned by this session. */
private[sql] lazy val sessionJobTag = s"spark-session-$sessionUUID"

/**
* A UUID that is unique per thread. Used by managedJobTags to ensure that the same tag
* added from two threads does not collide in the underlying SparkContext/SQLExecution.
*/
private[sql] lazy val threadUuid = new InheritableThreadLocal[String] {
Review comment (Member):
Could we have a simple test for this inheritance? Otherwise, looks fine from a cursory look.

override def childValue(parent: String): String = parent

override def initialValue(): String = UUID.randomUUID().toString
}
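
For illustration, here is a minimal, self-contained sketch of the inheritance the reviewer asks about, using the same childValue/initialValue shape as threadUuid above. This is not code from the PR; the names are illustrative.

    import java.util.UUID

    val uuid = new InheritableThreadLocal[String] {
      override def childValue(parent: String): String = parent
      override def initialValue(): String = UUID.randomUUID().toString
    }

    val parentValue = uuid.get()        // lazily initialized on the parent thread
    var seenInChild: String = null
    val child = new Thread(() => { seenInChild = uuid.get() })
    child.start()
    child.join()
    assert(seenInChild == parentValue)  // the child inherits the parent's UUID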

/**
* A map to hold the mapping from user-defined tags to the real tags attached to Jobs.
* Real tag have the current session ID attached: `"tag1" -> s"spark-session-$sessionUUID-tag1"`.
* Real tags have the current session ID and thread UUID attached:
* `tag1 -> spark-session-$sessionUUID-thread-$threadUuid-tag1`
*
*/
@transient
private[sql] lazy val managedJobTags: ConcurrentHashMap[String, String] = {
new ConcurrentHashMap(parentManagedJobTags.asJava)
}
private[sql] lazy val managedJobTags = new InheritableThreadLocal[mutable.Map[String, String]] {
override def childValue(parent: mutable.Map[String, String]): mutable.Map[String, String] = {
// Note: make a clone so that changes to the parent's tags aren't reflected in
// those of the child threads.
parent.clone()
}

override def initialValue(): mutable.Map[String, String] = {
mutable.Map(parentManagedJobTags.toSeq: _*)
}
}
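
The clone in childValue is what gives each child thread an independent copy: mutations made after the fork stay on their own thread. A hedged, standalone sketch of that behavior (illustrative names, not PR code):

    import scala.collection.mutable

    val tags = new InheritableThreadLocal[mutable.Map[String, String]] {
      override def childValue(parent: mutable.Map[String, String]): mutable.Map[String, String] =
        parent.clone()
      override def initialValue(): mutable.Map[String, String] = mutable.Map.empty
    }

    tags.get().put("tag1", "real-tag1")   // set on the parent thread
    val child = new Thread(() => {
      tags.get().put("tag2", "real-tag2") // visible only on the child thread
      assert(tags.get().keySet == Set("tag1", "tag2"))
    })
    child.start(); child.join()
    assert(tags.get().keySet == Set("tag1"))  // the parent's map is unchanged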

/** @inheritdoc */
def version: String = SPARK_VERSION
@@ -243,10 +263,10 @@ class SparkSession private(
Some(sessionState),
extensions,
Map.empty,
managedJobTags.asScala.toMap)
managedJobTags.get().toMap)
result.sessionState // force copy of SessionState
result.sessionState.artifactManager // force copy of ArtifactManager and its resources
result.managedJobTags // force copy of userDefinedToRealTagsMap
result.managedJobTags // force copy of managedJobTags
result
}

@@ -550,17 +570,17 @@
/** @inheritdoc */
override def addTag(tag: String): Unit = {
SparkContext.throwIfInvalidTag(tag)
managedJobTags.put(tag, s"spark-session-$sessionUUID-$tag")
managedJobTags.get().put(tag, s"spark-session-$sessionUUID-thread-${threadUuid.get()}-$tag")
}

/** @inheritdoc */
override def removeTag(tag: String): Unit = managedJobTags.remove(tag)
override def removeTag(tag: String): Unit = managedJobTags.get().remove(tag)

/** @inheritdoc */
override def getTags(): Set[String] = managedJobTags.keys().asScala.toSet
override def getTags(): Set[String] = managedJobTags.get().keySet.toSet

/** @inheritdoc */
override def clearTags(): Unit = managedJobTags.clear()
override def clearTags(): Unit = managedJobTags.get().clear()
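
Taken together, these overrides make the tag API read and write a per-thread map. A hedged usage sketch of the resulting behavior (the setup is illustrative, not from the PR):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local").getOrCreate()

    spark.addTag("parent-tag")
    val worker = new Thread(() => {
      // This thread inherited a clone of the parent's tags when it was created.
      assert(spark.getTags() == Set("parent-tag"))
      spark.addTag("child-tag")           // registered only on this thread
      assert(spark.getTags() == Set("parent-tag", "child-tag"))
    })
    worker.start(); worker.join()
    assert(spark.getTags() == Set("parent-tag"))  // the child's tag never leaks back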

/**
* Request to interrupt all currently running SQL operations of this session.
@@ -589,9 +609,8 @@
* @since 4.0.0
*/
override def interruptTag(tag: String): Seq[String] = {
val realTag = managedJobTags.get(tag)
if (realTag == null) return Seq.empty
doInterruptTag(realTag, s"part of cancelled job tags $tag")
val realTag = managedJobTags.get().get(tag)
realTag.map(doInterruptTag(_, s"part of cancelled job tags $tag")).getOrElse(Seq.empty)
}
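
The rewrite above swaps a null check on a ConcurrentHashMap lookup for the Option returned by Scala's mutable.Map.get. A small sketch of the equivalence, with stand-in names for doInterruptTag and the real tag value:

    val tags = scala.collection.mutable.Map("one" -> "spark-session-...-one")
    def interrupt(realTag: String): Seq[String] = Seq(realTag)  // stand-in for doInterruptTag

    // Missing tag: the Option is None, so the result is Seq.empty (the old early return).
    assert(tags.get("missing").map(interrupt).getOrElse(Seq.empty) == Seq.empty)
    // Present tag: the interrupt result is returned directly.
    assert(tags.get("one").map(interrupt).getOrElse(Seq.empty) == Seq("spark-session-...-one"))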

private def doInterruptTag(tag: String, reason: String): Seq[String] = {
sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -261,7 +261,7 @@ object SQLExecution extends Logging {
}

private[sql] def withSessionTagsApplied[T](sparkSession: SparkSession)(block: => T): T = {
val allTags = sparkSession.managedJobTags.values().asScala.toSet + sparkSession.sessionJobTag
val allTags = sparkSession.managedJobTags.get().values.toSet + sparkSession.sessionJobTag
sparkSession.sparkContext.addJobTags(allTags)

try {
sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
@@ -17,7 +17,7 @@

package org.apache.spark.sql

import java.util.concurrent.{ConcurrentHashMap, Semaphore, TimeUnit}
import java.util.concurrent.{ConcurrentHashMap, Executors, Semaphore, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger

import scala.concurrent.{ExecutionContext, Future}
@@ -100,13 +100,14 @@ class SparkSessionJobTaggingAndCancellationSuite

assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
val activeJobsFuture =
session.sparkContext.cancelJobsWithTagWithFuture(session.managedJobTags.get("one"), "reason")
session.sparkContext.cancelJobsWithTagWithFuture(
session.managedJobTags.get()("one"), "reason")
val activeJob = ThreadUtils.awaitResult(activeJobsFuture, 60.seconds).head
val actualTags = activeJob.properties.getProperty(SparkContext.SPARK_JOB_TAGS)
.split(SparkContext.SPARK_JOB_TAGS_SEP)
assert(actualTags.toSet == Set(
session.sessionJobTag,
s"${session.sessionJobTag}-one",
s"${session.sessionJobTag}-thread-${session.threadUuid.get()}-one",
SQLExecution.executionIdJobTag(
session,
activeJob.properties.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong)))
@@ -118,12 +119,12 @@ class SparkSessionJobTaggingAndCancellationSuite
val globalSession = SparkSession.builder().sparkContext(sc).getOrCreate()
var (sessionA, sessionB, sessionC): (SparkSession, SparkSession, SparkSession) =
(null, null, null)
var (threadUuidA, threadUuidB, threadUuidC): (String, String, String) = (null, null, null)

// The global ExecutionContext has only 2 threads in Apache Spark CI, so
// create our own thread pool for the three Futures used in this test.
val numThreads = 3
val fpool = ThreadUtils.newForkJoinPool("job-tags-test-thread-pool", numThreads)
val executionContext = ExecutionContext.fromExecutorService(fpool)
val threadPool = Executors.newFixedThreadPool(3)
implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(threadPool)

try {
// Add a listener to release the semaphore once jobs are launched.
@@ -143,28 +144,50 @@ class SparkSessionJobTaggingAndCancellationSuite
}
})

var realTagOneForSessionA: String = null
var childThread: Thread = null
val childThreadLock = new Semaphore(0)

// Note: since tags are added in the Future threads, they don't need to be cleared in between.
val jobA = Future {
sessionA = globalSession.cloneSession()
import globalSession.implicits._

threadUuidA = sessionA.threadUuid.get()
assert(sessionA.getTags() == Set())
sessionA.addTag("two")
assert(sessionA.getTags() == Set("two"))
sessionA.clearTags() // check that clearing all tags works
assert(sessionA.getTags() == Set())
sessionA.addTag("one")
realTagOneForSessionA = sessionA.managedJobTags.get()("one")
assert(realTagOneForSessionA ==
s"${sessionA.sessionJobTag}-thread-${sessionA.threadUuid.get()}-one")
assert(sessionA.getTags() == Set("one"))

// Create a child thread which inherits thread-local variables and tries to interrupt
// the job started from the parent thread. The child thread is blocked until the main
// thread releases the lock.
childThread = new Thread {
override def run(): Unit = {
assert(childThreadLock.tryAcquire(1, 20, TimeUnit.SECONDS))
assert(sessionA.getTags() == Set("one"))
assert(sessionA.interruptTag("one").size == 1)
}
}
childThread.start()
try {
sessionA.range(1, 10000).map { i => Thread.sleep(100); i }.count()
} finally {
childThread.interrupt()
sessionA.clearTags() // clear for the case of thread reuse by another Future
}
}(executionContext)
}
val jobB = Future {
sessionB = globalSession.cloneSession()
import globalSession.implicits._

threadUuidB = sessionB.threadUuid.get()
assert(sessionB.getTags() == Set())
sessionB.addTag("one")
sessionB.addTag("two")
Expand All @@ -176,11 +199,12 @@ class SparkSessionJobTaggingAndCancellationSuite
} finally {
sessionB.clearTags() // clear for the case of thread reuse by another Future
}
}(executionContext)
}
val jobC = Future {
sessionC = globalSession.cloneSession()
import globalSession.implicits._

threadUuidC = sessionC.threadUuid.get()
sessionC.addTag("foo")
sessionC.removeTag("foo")
assert(sessionC.getTags() == Set()) // check that remove works removing the last tag
@@ -190,12 +214,13 @@
} finally {
sessionC.clearTags() // clear for the case of thread reuse by another Future
}
}(executionContext)
}

// Block until three jobs have started.
assert(sem.tryAcquire(3, 1, TimeUnit.MINUTES))

// Tags are applied
def realUserTag(s: String, t: String, ta: String): String = s"spark-session-$s-thread-$t-$ta"
assert(jobProperties.size == 3)
for (ss <- Seq(sessionA, sessionB, sessionC)) {
val jobProperty = jobProperties.values().asScala.filter(_.get(SparkContext.SPARK_JOB_TAGS)
@@ -207,15 +232,17 @@
val executionRootIdTag = SQLExecution.executionIdJobTag(
ss,
jobProperty.head.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong)
val userTagsPrefix = s"spark-session-${ss.sessionUUID}-"

ss match {
case s if s == sessionA => assert(tags.toSet == Set(
s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one"))
s.sessionJobTag, executionRootIdTag, realUserTag(s.sessionUUID, threadUuidA, "one")))
case s if s == sessionB => assert(tags.toSet == Set(
s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one", s"${userTagsPrefix}two"))
s.sessionJobTag,
executionRootIdTag,
realUserTag(s.sessionUUID, threadUuidB, "one"),
realUserTag(s.sessionUUID, threadUuidB, "two")))
case s if s == sessionC => assert(tags.toSet == Set(
s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}boo"))
s.sessionJobTag, executionRootIdTag, realUserTag(s.sessionUUID, threadUuidC, "boo")))
}
}

@@ -239,8 +266,10 @@
assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
assert(jobEnded.intValue == 1)

// Another job cancelled
assert(sessionA.interruptTag("one").size == 1)
// Another job cancelled. The next line cancels nothing because we're now in another thread.
// The real cancellation is done by unblocking the child thread, which is waiting on a lock.
assert(sessionA.interruptTag("one").isEmpty)
childThreadLock.release()
val eA = intercept[SparkException] {
ThreadUtils.awaitResult(jobA, 1.minute)
}.getCause
@@ -257,7 +286,48 @@
assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
assert(jobEnded.intValue == 3)
} finally {
fpool.shutdownNow()
threadPool.shutdownNow()
}
}

test("Tags are isolated in multithreaded environment") {
// Custom thread pool for multi-threaded testing
val threadPool = Executors.newFixedThreadPool(2)
implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(threadPool)

val session = SparkSession.builder().master("local").getOrCreate()
@volatile var output1: Set[String] = null
@volatile var output2: Set[String] = null

def tag1(): Unit = {
session.addTag("tag1")
output1 = session.getTags()
}

def tag2(): Unit = {
session.addTag("tag2")
output2 = session.getTags()
}

try {
// Run tasks in separate threads
val future1 = Future {
tag1()
}
val future2 = Future {
tag2()
}

// Wait for threads to complete
ThreadUtils.awaitResult(Future.sequence(Seq(future1, future2)), 1.minute)

// Assert outputs
assert(output1 != null)
assert(output1 == Set("tag1"))
assert(output2 != null)
assert(output2 == Set("tag2"))
} finally {
threadPool.shutdownNow()
}
}
}