diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecutePlanHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecutePlanHolder.scala index a3c17b9826ebc..9bf9df07e0173 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecutePlanHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecutePlanHolder.scala @@ -27,15 +27,16 @@ case class ExecutePlanHolder( sessionHolder: SessionHolder, request: proto.ExecutePlanRequest) { - val jobGroupId = - s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}_Request_${operationId}" + val jobTag = + "SparkConnect_" + + s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}_Request_${operationId}" def interrupt(): Unit = { // TODO/WIP: This only interrupts active Spark jobs that are actively running. // This would then throw the error from ExecutePlan and terminate it. // But if the query is not running a Spark job, but executing code on Spark driver, this // would be a noop and the execution will keep running. - sessionHolder.session.sparkContext.cancelJobGroup(jobGroupId) + sessionHolder.session.sparkContext.cancelJobsWithTag(jobTag) } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index d11f4dcc6002c..70204f2913da7 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -63,8 +63,12 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp } val executeHolder = sessionHolder.createExecutePlanHolder(v) + session.sparkContext.addJobTag(executeHolder.jobTag) + session.sparkContext.setInterruptOnCancel(true) + // Also set the tag as the JobGroup for all the jobs in the query. + // TODO: In the long term, it should be encouraged to use job tag only. 
session.sparkContext.setJobGroup( - executeHolder.jobGroupId, + executeHolder.jobTag, s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}", interruptOnCancel = true) @@ -89,6 +93,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp throw new UnsupportedOperationException(s"${v.getPlan.getOpTypeCase} not supported.") } } finally { + session.sparkContext.removeJobTag(executeHolder.jobTag) + session.sparkContext.clearJobGroup() sessionHolder.removeExecutePlanHolder(executeHolder.operationId) } } diff --git a/core/src/main/protobuf/org/apache/spark/status/protobuf/store_types.proto b/core/src/main/protobuf/org/apache/spark/status/protobuf/store_types.proto index 94ce1b8b58a34..93365add3a642 100644 --- a/core/src/main/protobuf/org/apache/spark/status/protobuf/store_types.proto +++ b/core/src/main/protobuf/org/apache/spark/status/protobuf/store_types.proto @@ -47,6 +47,7 @@ message JobData { optional int64 completion_time = 5; repeated int64 stage_ids = 6; optional string job_group = 7; + repeated string job_tags = 21; JobExecutionStatus status = 8; int32 num_tasks = 9; int32 num_active_tasks = 10; diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index cf7a405f1babc..c32c674d64e0f 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -829,6 +829,55 @@ class SparkContext(config: SparkConf) extends Logging { setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, null) } + /** + * Set the behavior of job cancellation from jobs started in this thread. + * + * @param interruptOnCancel If true, then job cancellation will result in `Thread.interrupt()` + * being called on the job's executor threads. This is useful to help ensure that the tasks + * are actually stopped in a timely manner, but is off by default due to HDFS-1208, where HDFS + * may respond to Thread.interrupt() by marking nodes as dead. + */ + def setInterruptOnCancel(interruptOnCancel: Boolean): Unit = { + setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, interruptOnCancel.toString) + } + + /** + * Add a tag to be assigned to all the jobs started by this thread. + * + * @param tag The tag to be added. Cannot contain ',' (comma) character. + */ + def addJobTag(tag: String): Unit = { + SparkContext.throwIfInvalidTag(tag) + val existingTags = getJobTags() + val newTags = (existingTags + tag).mkString(SparkContext.SPARK_JOB_TAGS_SEP) + setLocalProperty(SparkContext.SPARK_JOB_TAGS, newTags) + } + + /** + * Remove a tag previously added to be assigned to all the jobs started by this thread. + * Noop if such a tag was not added earlier. + * + * @param tag The tag to be removed. Cannot contain ',' (comma) character. + */ + def removeJobTag(tag: String): Unit = { + SparkContext.throwIfInvalidTag(tag) + val existingTags = getJobTags() + val newTags = (existingTags - tag).mkString(SparkContext.SPARK_JOB_TAGS_SEP) + setLocalProperty(SparkContext.SPARK_JOB_TAGS, newTags) + } + + /** Get the tags that are currently set to be assigned to all the jobs started by this thread. */ + def getJobTags(): Set[String] = { + Option(getLocalProperty(SparkContext.SPARK_JOB_TAGS)) + .map(_.split(SparkContext.SPARK_JOB_TAGS_SEP).toSet) + .getOrElse(Set()) + } + + /** Clear the current thread's job tags. 
*/
+  def clearJobTags(): Unit = {
+    setLocalProperty(SparkContext.SPARK_JOB_TAGS, null)
+  }
+
   /**
    * Execute a block of code in a scope such that all new RDDs created in this body will
    * be part of the same scope. For more detail, see {{org.apache.spark.rdd.RDDOperationScope}}.
@@ -2471,6 +2520,17 @@ class SparkContext(config: SparkConf) extends Logging {
     dagScheduler.cancelJobGroup(groupId)
   }
 
+  /**
+   * Cancel active jobs that have the specified tag. See `org.apache.spark.SparkContext.addJobTag`.
+   *
+   * @param tag The tag to be cancelled. Cannot contain ',' (comma) character.
+   */
+  def cancelJobsWithTag(tag: String): Unit = {
+    SparkContext.throwIfInvalidTag(tag)
+    assertNotStopped()
+    dagScheduler.cancelJobsWithTag(tag)
+  }
+
   /** Cancel all jobs that have been scheduled or are running. */
   def cancelAllJobs(): Unit = {
     assertNotStopped()
@@ -2840,6 +2900,7 @@ object SparkContext extends Logging {
   private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description"
   private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
   private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel"
+  private[spark] val SPARK_JOB_TAGS = "spark.job.tags"
   private[spark] val SPARK_SCHEDULER_POOL = "spark.scheduler.pool"
   private[spark] val RDD_SCOPE_KEY = "spark.rdd.scope"
   private[spark] val RDD_SCOPE_NO_OVERRIDE_KEY = "spark.rdd.scope.noOverride"
@@ -2851,6 +2912,22 @@ object SparkContext extends Logging {
    */
   private[spark] val DRIVER_IDENTIFIER = "driver"
 
+  /** Separator of tags in SPARK_JOB_TAGS property */
+  private[spark] val SPARK_JOB_TAGS_SEP = ","
+
+  private[spark] def throwIfInvalidTag(tag: String) = {
+    if (tag == null) {
+      throw new IllegalArgumentException("Spark job tag cannot be null.")
+    }
+    if (tag.contains(SPARK_JOB_TAGS_SEP)) {
+      throw new IllegalArgumentException(
+        s"Spark job tag cannot contain '$SPARK_JOB_TAGS_SEP'.")
+    }
+    if (tag.isEmpty) {
+      throw new IllegalArgumentException(
+        "Spark job tag cannot be an empty string.")
+    }
+  }
 
   private implicit def arrayToArrayWritable[T <: Writable : ClassTag](arr: Iterable[T])
     : ArrayWritable = {
diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala
index 22dc1d056ec0c..a55a6a8b8eb62 100644
--- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala
+++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala
@@ -52,6 +52,17 @@ class SparkStatusTracker private[spark] (sc: SparkContext, store: AppStatusStore
     store.jobsList(null).filter(_.jobGroup == expected).map(_.jobId).toArray
   }
 
+  /**
+   * Return a list of all known jobs with a particular tag.
+   *
+   * The returned list may contain running, failed, and completed jobs, and may vary across
+   * invocations of this method. This method does not guarantee the order of the elements in
+   * its result.
+   */
+  def getJobIdsForTag(jobTag: String): Array[Int] = {
+    store.jobsList(null).filter(_.jobTags.contains(jobTag)).map(_.jobId).toArray
+  }
+
   /**
    * Returns an array containing the ids of all active stages.
    *
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index c78a26d91ebfd..64a8192f8e18f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1085,6 +1085,15 @@ private[spark] class DAGScheduler(
     eventProcessLoop.post(JobGroupCancelled(groupId))
   }
 
+  /**
+   * Cancel all jobs with a given tag.
+   */
+  def cancelJobsWithTag(tag: String): Unit = {
+    SparkContext.throwIfInvalidTag(tag)
+    logInfo(s"Asked to cancel jobs with tag $tag")
+    eventProcessLoop.post(JobTagCancelled(tag))
+  }
+
   /**
    * Cancel all jobs that are running or waiting in the queue.
    */
@@ -1182,6 +1191,19 @@ private[spark] class DAGScheduler(
         Option("part of cancelled job group %s".format(groupId))))
   }
 
+  private[scheduler] def handleJobTagCancelled(tag: String): Unit = {
+    // Cancel all active jobs that have this tag.
+    // First find the active jobs that carry this tag, then cancel each of them.
+    val jobIds = activeJobs.filter { activeJob =>
+      Option(activeJob.properties).exists { properties =>
+        Option(properties.getProperty(SparkContext.SPARK_JOB_TAGS)).getOrElse("")
+          .split(SparkContext.SPARK_JOB_TAGS_SEP).toSet.contains(tag)
+      }
+    }.map(_.jobId)
+    jobIds.foreach(handleJobCancellation(_,
+      Option(s"part of cancelled job tag $tag")))
+  }
+
   private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo): Unit = {
     listenerBus.post(SparkListenerTaskStart(task.stageId, task.stageAttemptId, taskInfo))
   }
@@ -2972,6 +2994,9 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
     case JobGroupCancelled(groupId) =>
      dagScheduler.handleJobGroupCancelled(groupId)
 
+    case JobTagCancelled(tag) =>
+      dagScheduler.handleJobTagCancelled(tag)
+
     case AllJobsCancelled =>
       dagScheduler.doCancelAllJobs()
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index c16e5ea03d7c9..6f2b778ca82d7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -63,6 +63,8 @@ private[scheduler] case class JobCancelled(
 
 private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent
 
+private[scheduler] case class JobTagCancelled(tagName: String) extends DAGSchedulerEvent
+
 private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
 
 private[scheduler]
diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala
index 5dee3cb6719fd..c1f52e86dd058 100644
--- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala
+++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala
@@ -438,6 +438,12 @@ private[spark] class AppStatusListener(
       .flatMap { p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) }
     val jobGroup = Option(event.properties)
       .flatMap { p => Option(p.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) }
+    val jobTags = Option(event.properties)
+      .flatMap { p => Option(p.getProperty(SparkContext.SPARK_JOB_TAGS)) }
+      .map(_.split(SparkContext.SPARK_JOB_TAGS_SEP).toSet)
+      .getOrElse(Set())
+      .toSeq
+      .sorted
     val sqlExecutionId = Option(event.properties)
       .flatMap(p => Option(p.getProperty(SQL_EXECUTION_ID_KEY)).map(_.toLong))
 
@@ -448,6 +454,7 @@ private[spark] class AppStatusListener(
       if (event.time > 0) Some(new Date(event.time)) else None,
       event.stageIds,
       jobGroup,
+      jobTags,
       numTasks,
       sqlExecutionId)
     liveJobs.put(event.jobId, job)
diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala
index 9910a0f07fcf4..ebea11fdca07b 100644
--- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala
+++ 
b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -66,6 +66,7 @@ private class LiveJob( val submissionTime: Option[Date], val stageIds: Seq[Int], jobGroup: Option[String], + jobTags: Seq[String], numTasks: Int, sqlExecutionId: Option[Long]) extends LiveEntity { @@ -98,6 +99,7 @@ private class LiveJob( completionTime, stageIds, jobGroup, + jobTags, status, numTasks, activeTasks, diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index f436d16ca4775..8d648b9df38fa 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -199,6 +199,7 @@ class JobData private[spark]( val completionTime: Option[Date], val stageIds: collection.Seq[Int], val jobGroup: Option[String], + val jobTags: collection.Seq[String], val status: JobExecutionStatus, val numTasks: Int, val numActiveTasks: Int, diff --git a/core/src/main/scala/org/apache/spark/status/protobuf/JobDataWrapperSerializer.scala b/core/src/main/scala/org/apache/spark/status/protobuf/JobDataWrapperSerializer.scala index 97189e372f975..11f1b7070cc3c 100644 --- a/core/src/main/scala/org/apache/spark/status/protobuf/JobDataWrapperSerializer.scala +++ b/core/src/main/scala/org/apache/spark/status/protobuf/JobDataWrapperSerializer.scala @@ -71,6 +71,7 @@ private[protobuf] class JobDataWrapperSerializer extends ProtobufSerDe[JobDataWr } jobData.stageIds.foreach(id => jobDataBuilder.addStageIds(id.toLong)) jobData.jobGroup.foreach(jobDataBuilder.setJobGroup) + jobData.jobTags.foreach(jobDataBuilder.addJobTags) jobData.killedTasksSummary.foreach { entry => jobDataBuilder.putKillTasksSummary(entry._1, entry._2) } @@ -93,6 +94,7 @@ private[protobuf] class JobDataWrapperSerializer extends ProtobufSerDe[JobDataWr completionTime = completionTime, stageIds = info.getStageIdsList.asScala.map(_.toInt), jobGroup = jobGroup, + jobTags = info.getJobTagsList.asScala, status = status, numTasks = info.getNumTasks, numActiveTasks = info.getNumActiveTasks, diff --git a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json index 2f275c7bfe2f4..b7271d89e0271 100644 --- a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json @@ -2,6 +2,7 @@ "jobId" : 0, "name" : "foreach at :15", "stageIds" : [ 0 ], + "jobTags" : [ ], "status" : "SUCCEEDED", "numTasks" : 8, "numActiveTasks" : 0, diff --git a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json index 2f275c7bfe2f4..b7271d89e0271 100644 --- a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json @@ -2,6 +2,7 @@ "jobId" : 0, "name" : "foreach at :15", "stageIds" : [ 0 ], + "jobTags" : [ ], "status" : "SUCCEEDED", "numTasks" : 8, "numActiveTasks" : 0, diff --git a/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json 
b/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json index 71bf8706307c8..bb26bc47eac49 100644 --- a/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json @@ -2,6 +2,7 @@ "jobId" : 2, "name" : "count at :17", "stageIds" : [ 3 ], + "jobTags" : [ ], "status" : "SUCCEEDED", "numTasks" : 8, "numActiveTasks" : 0, @@ -19,6 +20,7 @@ "jobId" : 1, "name" : "count at :20", "stageIds" : [ 1, 2 ], + "jobTags" : [ ], "status" : "FAILED", "numTasks" : 16, "numActiveTasks" : 0, @@ -36,6 +38,7 @@ "jobId" : 0, "name" : "count at :15", "stageIds" : [ 0 ], + "jobTags" : [ ], "status" : "SUCCEEDED", "numTasks" : 8, "numActiveTasks" : 0, diff --git a/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json index 1eae5f3d5beb3..3bf4845ed1e0e 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json @@ -2,6 +2,7 @@ "jobId" : 0, "name" : "count at :15", "stageIds" : [ 0 ], + "jobTags" : [ ], "status" : "SUCCEEDED", "numTasks" : 8, "numActiveTasks" : 0, diff --git a/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json index 71bf8706307c8..bb26bc47eac49 100644 --- a/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json @@ -2,6 +2,7 @@ "jobId" : 2, "name" : "count at :17", "stageIds" : [ 3 ], + "jobTags" : [ ], "status" : "SUCCEEDED", "numTasks" : 8, "numActiveTasks" : 0, @@ -19,6 +20,7 @@ "jobId" : 1, "name" : "count at :20", "stageIds" : [ 1, 2 ], + "jobTags" : [ ], "status" : "FAILED", "numTasks" : 16, "numActiveTasks" : 0, @@ -36,6 +38,7 @@ "jobId" : 0, "name" : "count at :15", "stageIds" : [ 0 ], + "jobTags" : [ ], "status" : "SUCCEEDED", "numTasks" : 8, "numActiveTasks" : 0, diff --git a/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json index b1ddd760c9714..2b2c2fbe1f25b 100644 --- a/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json @@ -2,6 +2,7 @@ "jobId" : 2, "name" : "count at :17", "stageIds" : [ 3 ], + "jobTags" : [ ], "status" : "SUCCEEDED", "numTasks" : 8, "numActiveTasks" : 0, @@ -19,6 +20,7 @@ "jobId" : 0, "name" : "count at :15", "stageIds" : [ 0 ], + "jobTags" : [ ], "status" : "SUCCEEDED", "numTasks" : 8, "numActiveTasks" : 0, diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 77bdb882c507d..f2ad33b0be710 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -20,10 +20,10 @@ package org.apache.spark import java.util.concurrent.{Semaphore, TimeUnit} import java.util.concurrent.atomic.AtomicInteger +import scala.concurrent.{ExecutionContext, Future} // scalastyle:off executioncontextglobal 
import scala.concurrent.ExecutionContext.Implicits.global // scalastyle:on executioncontextglobal -import scala.concurrent.Future import scala.concurrent.duration._ import org.scalatest.BeforeAndAfter @@ -31,7 +31,7 @@ import org.scalatest.matchers.must.Matchers import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Deploy._ -import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerJobStart, SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart} import org.apache.spark.util.ThreadUtils /** @@ -153,6 +153,131 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft assert(jobB.get() === 100) } + test("job tags") { + sc = new SparkContext("local[2]", "test") + + // global ExecutionContext has only 2 threads in Apache Spark CI + // create own thread pool for four Futures used in this test + val numThreads = 4 + val fpool = ThreadUtils.newForkJoinPool("job-tags-test-thread-pool", numThreads) + val executionContext = ExecutionContext.fromExecutorService(fpool) + + try { + // Add a listener to release the semaphore once jobs are launched. + val sem = new Semaphore(0) + val jobEnded = new AtomicInteger(0) + + sc.addSparkListener(new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + sem.release() + } + + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + sem.release() + jobEnded.incrementAndGet() + } + }) + + val eSep = intercept[IllegalArgumentException](sc.addJobTag("foo,bar")) + assert(eSep.getMessage.contains( + s"Spark job tag cannot contain '${SparkContext.SPARK_JOB_TAGS_SEP}'.")) + val eEmpty = intercept[IllegalArgumentException](sc.addJobTag("")) + assert(eEmpty.getMessage.contains("Spark job tag cannot be an empty string.")) + val eNull = intercept[IllegalArgumentException](sc.addJobTag(null)) + assert(eNull.getMessage.contains("Spark job tag cannot be null.")) + + // Note: since tags are added in the Future threads, they don't need to be cleared in between. 
+ val jobA = Future { + assert(sc.getJobTags() == Set()) + sc.addJobTag("two") + assert(sc.getJobTags() == Set("two")) + sc.clearJobTags() // check that clearing all tags works + assert(sc.getJobTags() == Set()) + sc.addJobTag("one") + assert(sc.getJobTags() == Set("one")) + try { + sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(100); i }.count() + } finally { + sc.clearJobTags() // clear for the case of thread reuse by another Future + } + }(executionContext) + val jobB = Future { + assert(sc.getJobTags() == Set()) + sc.addJobTag("one") + sc.addJobTag("two") + sc.addJobTag("one") + sc.addJobTag("two") // duplicates shouldn't matter + assert(sc.getJobTags() == Set("one", "two")) + try { + sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(100); i }.count() + } finally { + sc.clearJobTags() // clear for the case of thread reuse by another Future + } + }(executionContext) + val jobC = Future { + assert(sc.getJobTags() == Set()) + sc.addJobTag("two") + assert(sc.getJobTags() == Set("two")) + try { + sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(100); i }.count() + } finally { + sc.clearJobTags() // clear for the case of thread reuse by another Future + } + }(executionContext) + val jobD = Future { + assert(sc.getJobTags() == Set()) + sc.addJobTag("one") + sc.addJobTag("two") + sc.addJobTag("two") + assert(sc.getJobTags() == Set("one", "two")) + sc.removeJobTag("two") // check that remove works, despite duplicate add + assert(sc.getJobTags() == Set("one")) + try { + sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(100); i }.count() + } finally { + sc.clearJobTags() // clear for the case of thread reuse by another Future + } + }(executionContext) + + // Block until four jobs have started. + val acquired1 = sem.tryAcquire(4, 1, TimeUnit.MINUTES) + assert(acquired1 == true) + + sc.cancelJobsWithTag("two") + val eB = intercept[SparkException] { + ThreadUtils.awaitResult(jobB, 1.minute) + }.getCause + assert(eB.getMessage contains "cancel") + val eC = intercept[SparkException] { + ThreadUtils.awaitResult(jobC, 1.minute) + }.getCause + assert(eC.getMessage contains "cancel") + + // two jobs cancelled + val acquired2 = sem.tryAcquire(2, 1, TimeUnit.MINUTES) + assert(acquired2 == true) + assert(jobEnded.intValue == 2) + + // this cancels the remaining two jobs + sc.cancelJobsWithTag("one") + val eA = intercept[SparkException] { + ThreadUtils.awaitResult(jobA, 1.minute) + }.getCause + assert(eA.getMessage contains "cancel") + val eD = intercept[SparkException] { + ThreadUtils.awaitResult(jobD, 1.minute) + }.getCause + assert(eD.getMessage contains "cancel") + + // another two jobs cancelled + val acquired3 = sem.tryAcquire(2, 1, TimeUnit.MINUTES) + assert(acquired3 == true) + assert(jobEnded.intValue == 4) + } finally { + fpool.shutdownNow() + } + } + test("inherited job group (SPARK-6629)") { sc = new SparkContext("local[2]", "test") diff --git a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala index e6d3377120e56..0817abbc6a328 100644 --- a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala @@ -111,4 +111,45 @@ class StatusTrackerSuite extends SparkFunSuite with Matchers with LocalSparkCont sc.statusTracker.getJobIdsForGroup("my-job-group2") should have size 2 } } + + test("getJobIdsForTag()") { + sc = new SparkContext("local", "test", new SparkConf(false)) + + sc.addJobTag("tag1") + sc.statusTracker.getJobIdsForTag("tag1") 
should be (Seq.empty) + + // countAsync() + val firstJobFuture = sc.parallelize(1 to 1000).countAsync() + val firstJobId = eventually(timeout(10.seconds)) { + firstJobFuture.jobIds.head + } + eventually(timeout(10.seconds)) { + sc.statusTracker.getJobIdsForTag("tag1") should be (Seq(firstJobId)) + } + + sc.addJobTag("tag2") + // takeAsync() + val secondJobFuture = sc.parallelize(1 to 1000).takeAsync(1) + val secondJobId = eventually(timeout(10.seconds)) { + secondJobFuture.jobIds.head + } + eventually(timeout(10.seconds)) { + sc.statusTracker.getJobIdsForTag("tag1").toSet should be ( + Set(firstJobId, secondJobId)) + sc.statusTracker.getJobIdsForTag("tag2") should be (Seq(secondJobId)) + } + + sc.removeJobTag("tag1") + // takeAsync() across multiple partitions + val thirdJobFuture = sc.parallelize(1 to 1000, 2).takeAsync(999) + val thirdJobId = eventually(timeout(10.seconds)) { + thirdJobFuture.jobIds.head + } + eventually(timeout(10.seconds)) { + sc.statusTracker.getJobIdsForTag("tag1").toSet should be ( + Set(firstJobId, secondJobId)) + sc.statusTracker.getJobIdsForTag("tag2").toSet should be ( + Set(secondJobId, thirdJobId)) + } + } } diff --git a/core/src/test/scala/org/apache/spark/status/protobuf/KVStoreProtobufSerializerSuite.scala b/core/src/test/scala/org/apache/spark/status/protobuf/KVStoreProtobufSerializerSuite.scala index 0849d63b03ec7..ac568fee1ad45 100644 --- a/core/src/test/scala/org/apache/spark/status/protobuf/KVStoreProtobufSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/protobuf/KVStoreProtobufSerializerSuite.scala @@ -65,9 +65,9 @@ class KVStoreProtobufSerializerSuite extends SparkFunSuite { test("Job data") { Seq( - ("test", Some("test description"), Some("group")), - (null, None, None) - ).foreach { case (name, description, jobGroup) => + ("test", Some("test description"), Some("group"), Seq("tag1", "tag2")), + (null, None, None, Seq()) + ).foreach { case (name, description, jobGroup, jobTags) => val input = new JobDataWrapper( new JobData( jobId = 1, @@ -77,6 +77,7 @@ class KVStoreProtobufSerializerSuite extends SparkFunSuite { completionTime = Some(new Date(654321L)), stageIds = Seq(1, 2, 3, 4), jobGroup = jobGroup, + jobTags = jobTags, status = JobExecutionStatus.UNKNOWN, numTasks = 2, numActiveTasks = 3, @@ -102,6 +103,7 @@ class KVStoreProtobufSerializerSuite extends SparkFunSuite { assert(result.info.completionTime == input.info.completionTime) assert(result.info.stageIds == input.info.stageIds) assert(result.info.jobGroup == input.info.jobGroup) + assert(result.info.jobTags == input.info.jobTags) assert(result.info.status == input.info.status) assert(result.info.numTasks == input.info.numTasks) assert(result.info.numActiveTasks == input.info.numActiveTasks) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index bba20534f44a9..31c8b68162d8e 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -56,7 +56,9 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Row.prettyJson"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.MutableAggregationBuffer.json"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.MutableAggregationBuffer.prettyJson"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.MutableAggregationBuffer.jsonValue") + 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.MutableAggregationBuffer.jsonValue"), + // [SPARK-43952][CORE][CONNECT][SQL] Add SparkContext APIs for query cancellation by tag + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.JobData.this") ) // Defulat exclude rules diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 548a8628ba44d..15141b09b6c07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -124,14 +124,17 @@ case class BroadcastExchangeExec( case _ => 512000000 } + @transient + private lazy val jobTag = s"broadcast exchange (runId ${runId.toString})" + @transient override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]]( session, BroadcastExchangeExec.executionContext) { try { - // Setup a job group here so later it may get cancelled by groupId if necessary. - sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)", - interruptOnCancel = true) + // Setup a job tag here so later it may get cancelled by tag if necessary. + sparkContext.addJobTag(jobTag) + sparkContext.setInterruptOnCancel(true) val beforeCollect = System.nanoTime() // Use executeCollect/executeCollectIterator to avoid conversion to Scala types val (numRows, input) = child.executeCollectIterator() @@ -211,7 +214,7 @@ case class BroadcastExchangeExec( case ex: TimeoutException => logError(s"Could not execute broadcast in $timeout secs.", ex) if (!relationFuture.isDone) { - sparkContext.cancelJobGroup(runId.toString) + sparkContext.cancelJobsWithTag(jobTag) relationFuture.cancel(true) } throw QueryExecutionErrors.executeBroadcastTimeoutError(timeout, Some(ex)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala index 31a8507cba0c6..0efb4180dbdb6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala @@ -38,7 +38,7 @@ class BroadcastExchangeSuite extends SparkPlanTest import testImplicits._ - test("BroadcastExchange should cancel the job group if timeout") { + test("BroadcastExchange should cancel the job tag if timeout") { val startLatch = new CountDownLatch(1) val endLatch = new CountDownLatch(1) var jobEvents: Seq[SparkListenerEvent] = Seq.empty[SparkListenerEvent] @@ -82,7 +82,7 @@ class BroadcastExchangeSuite extends SparkPlanTest val events = jobEvents.toArray val hasStart = events(0).isInstanceOf[SparkListenerJobStart] val hasCancelled = events(1).asInstanceOf[SparkListenerJobEnd].jobResult - .asInstanceOf[JobFailed].exception.getMessage.contains("cancelled job group") + .asInstanceOf[JobFailed].exception.getMessage.contains("cancelled job tag") events.length == 2 && hasStart && hasCancelled } }
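
For reference, below is a minimal usage sketch of the tag APIs added by this patch (addJobTag, setInterruptOnCancel, clearJobTags, cancelJobsWithTag, and SparkStatusTracker.getJobIdsForTag). The object name, tag string, data size, and sleep durations are illustrative assumptions only, not part of the patch; the pattern mirrors the new tests in JobCancellationSuite and StatusTrackerSuite.

import org.apache.spark.{SparkConf, SparkContext}

object JobTagExample {
  def main(args: Array[String]): Unit = {
    // Assumed local setup for the sketch; any SparkContext works.
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("job-tag-example"))

    // Submit a long-running job from a separate thread, tagged so it can be cancelled by tag.
    val runner = new Thread(() => {
      sc.addJobTag("etl-session-42")  // tags are thread-local and attach to jobs started here
      sc.setInterruptOnCancel(true)   // also interrupt executor task threads on cancellation
      try {
        sc.parallelize(1 to 100000, 4).map { i => Thread.sleep(10); i }.count()
      } catch {
        case e: Exception => println(s"Job ended: ${e.getMessage}")
      } finally {
        sc.clearJobTags()             // clear in case the thread is reused
      }
    })
    runner.start()
    Thread.sleep(2000)

    // From any other thread: look up and cancel all active jobs carrying the tag.
    println("Tagged jobs: " + sc.statusTracker.getJobIdsForTag("etl-session-42").mkString(", "))
    sc.cancelJobsWithTag("etl-session-42")

    runner.join()
    sc.stop()
  }
}

Because tags live in a thread-local job property (spark.job.tags), they must be set on the thread that submits the job, whereas cancellation by tag posts a JobTagCancelled event to the DAGScheduler and can therefore be issued from any thread.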