diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 7c157a749965..66de9d37db0a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -119,7 +119,9 @@ object SQLExecution {
             // will be caught and reported in the `SparkListenerSQLExecutionEnd`
             sparkPlanInfo = SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan),
             time = System.currentTimeMillis(),
-            redactedConfigs))
+            modifiedConfigs = redactedConfigs,
+            jobTags = sc.getJobTags()
+          ))
           body
         } catch {
           case e: Throwable =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
index 7b9f877bdef5..3fafc399dd82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
@@ -342,7 +342,7 @@ class SQLAppStatusListener(
 
   private def onExecutionStart(event: SparkListenerSQLExecutionStart): Unit = {
     val SparkListenerSQLExecutionStart(executionId, rootExecutionId, description, details,
-      physicalPlanDescription, sparkPlanInfo, time, modifiedConfigs) = event
+      physicalPlanDescription, sparkPlanInfo, time, modifiedConfigs, _) = event
 
     val planGraph = SparkPlanGraph(sparkPlanInfo)
     val sqlPlanMetrics = planGraph.allNodes.flatMap { node =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index d4c8f600a4e2..3a22dd23548f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -53,7 +53,8 @@ case class SparkListenerSQLExecutionStart(
     physicalPlanDescription: String,
     sparkPlanInfo: SparkPlanInfo,
     time: Long,
-    modifiedConfigs: Map[String, String] = Map.empty)
+    modifiedConfigs: Map[String, String] = Map.empty,
+    jobTags: Set[String] = Set.empty)
   extends SparkListenerEvent
 
 @DeveloperApi
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
index 1a062d8d4e25..766f4959caa1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
@@ -203,6 +203,36 @@ class SQLExecutionSuite extends SparkFunSuite {
       spark.stop()
     }
   }
+
+  test("SPARK-44591: jobTags property") {
+    val spark = SparkSession.builder.master("local[*]").appName("test").getOrCreate()
+    val jobTag = "jobTag"
+    try {
+      spark.sparkContext.addJobTag(jobTag)
+
+      var jobTags: Option[String] = None
+      var sqlJobTags: Set[String] = Set.empty
+      spark.sparkContext.addSparkListener(new SparkListener {
+        override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+          jobTags = Some(jobStart.properties.getProperty(SparkContext.SPARK_JOB_TAGS))
+        }
+        override def onOtherEvent(event: SparkListenerEvent): Unit = {
+          event match {
+            case e: SparkListenerSQLExecutionStart =>
+              sqlJobTags = e.jobTags
+          }
+        }
+      })
+
+      spark.range(1).collect()
+
+      assert(jobTags.contains(jobTag))
+      assert(sqlJobTags.contains(jobTag))
+    } finally {
+      spark.sparkContext.removeJobTag(jobTag)
+      spark.stop()
+    }
+  }
 }
 
 object SQLExecutionSuite {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
index 0e3ba6b79eb5..72f8de3e9ccd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
@@ -345,7 +345,7 @@ abstract class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTes
     val listener = new SparkListener {
       override def onOtherEvent(event: SparkListenerEvent): Unit = {
         event match {
-          case SparkListenerSQLExecutionStart(_, _, _, _, planDescription, _, _, _) =>
+          case SparkListenerSQLExecutionStart(_, _, _, _, planDescription, _, _, _, _) =>
            assert(expected.forall(planDescription.contains))
             checkDone = true
           case _ => // ignore other events
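
Sketch for reviewers (not part of the patch): with the new jobTags field, any SparkListener can correlate a SQL execution with the tags set through SparkContext.addJobTag. A minimal, hypothetical consumer; the class name and log format below are illustrative only:

import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart

// Hypothetical consumer of the new field: logs the job tags attached to
// each SQL execution as it starts.
class JobTagLoggingListener extends SparkListener {
  override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
    case e: SparkListenerSQLExecutionStart =>
      // jobTags defaults to Set.empty, so events emitted by code that
      // predates this patch are still safe to read.
      println(s"SQL execution ${e.executionId} started with tags: ${e.jobTags.mkString(", ")}")
    case _ => // ignore other events
  }
}

Registered via spark.sparkContext.addSparkListener(new JobTagLoggingListener), this prints one line per SparkListenerSQLExecutionStart event.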