
Commit a456226

SQL execution listener shouldn't happen on execution thread

1 parent ebd899b
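In short: before this change, `Dataset.withAction` and `DataFrameWriter.runCommand` called `QueryExecutionListener.onSuccess`/`onFailure` directly on the thread that ran the query. This commit moves the notification onto the Spark listener bus instead: the execution path attaches an optional execution name, the duration, the `QueryExecution`, and any failure to the `SparkListenerSQLExecutionEnd` event, and `ExecutionListenerManager` (now a `SparkListener`) reacts to that event. The public `QueryExecutionListener` trait is untouched; a minimal sketch of a listener that behaves the same before and after (the class name and printouts are illustrative):

import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

// Illustrative listener: after this commit, these callbacks run on the
// listener-bus thread rather than on the thread that executed the query.
class TimingListener extends QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
    println(s"$funcName succeeded, duration = $durationNs")

  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
    println(s"$funcName failed: ${exception.getMessage}")
}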

File tree: 10 files changed, +134 −105 lines

project/MimaExcludes.scala

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@ object MimaExcludes {

   // Exclude rules for 3.0.x
   lazy val v30excludes = v24excludes ++ Seq(
+    ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.util.ExecutionListenerManager.clone"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.util.ExecutionListenerManager.this")
   )

   // Exclude rules for 2.4.x

sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala

Lines changed: 2 additions & 11 deletions
@@ -667,17 +667,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
    */
   private def runCommand(session: SparkSession, name: String)(command: LogicalPlan): Unit = {
     val qe = session.sessionState.executePlan(command)
-    try {
-      val start = System.nanoTime()
-      // call `QueryExecution.toRDD` to trigger the execution of commands.
-      SQLExecution.withNewExecutionId(session, qe)(qe.toRdd)
-      val end = System.nanoTime()
-      session.listenerManager.onSuccess(name, qe, end - start)
-    } catch {
-      case e: Exception =>
-        session.listenerManager.onFailure(name, qe, e)
-        throw e
-    }
+    // call `QueryExecution.toRDD` to trigger the execution of commands.
+    SQLExecution.withNewExecutionId(session, qe, Some(name))(qe.toRdd)
   }

   ///////////////////////////////////////////////////////////////////////////////////////
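With the timing and try/catch bookkeeping centralized in `SQLExecution.withNewExecutionId`, `runCommand` shrinks to a single call. A usage sketch (session setup and the output path are illustrative; registration still goes through `SparkSession.listenerManager`):

val spark = org.apache.spark.sql.SparkSession.builder()
  .master("local[2]").appName("listener-demo").getOrCreate()
spark.listenerManager.register(new TimingListener)  // sketch from above

// Write commands route through runCommand, so this eventually delivers
// onSuccess(funcName = "save", ...) asynchronously via the listener bus;
// the exact funcName is an assumption based on DataFrameWriter's callers.
spark.range(10).write.mode("overwrite").parquet("/tmp/listener-demo")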

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 5 additions & 15 deletions
@@ -3356,21 +3356,11 @@ class Dataset[T] private[sql](
    * user-registered callback functions.
    */
   private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = {
-    try {
-      qe.executedPlan.foreach { plan =>
-        plan.resetMetrics()
-      }
-      val start = System.nanoTime()
-      val result = SQLExecution.withNewExecutionId(sparkSession, qe) {
-        action(qe.executedPlan)
-      }
-      val end = System.nanoTime()
-      sparkSession.listenerManager.onSuccess(name, qe, end - start)
-      result
-    } catch {
-      case e: Exception =>
-        sparkSession.listenerManager.onFailure(name, qe, e)
-        throw e
+    qe.executedPlan.foreach { plan =>
+      plan.resetMetrics()
+    }
+    SQLExecution.withNewExecutionId(sparkSession, qe, Some(name)) {
+      action(qe.executedPlan)
     }
   }
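`withAction` keeps the metric reset in place but delegates timing, success, and failure reporting to `withNewExecutionId`. Every Dataset action funneled through `withAction` now yields a named execution-end event, for example (the names are an assumption based on this file's callers):

val df = spark.range(100).selectExpr("id % 10 AS k")
df.collect()  // delivered to listeners with funcName "collect"
df.count()    // delivered with funcName "count"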

sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala

Lines changed: 28 additions & 6 deletions
@@ -58,7 +58,8 @@ object SQLExecution {
    */
   def withNewExecutionId[T](
       sparkSession: SparkSession,
-      queryExecution: QueryExecution)(body: => T): T = {
+      queryExecution: QueryExecution,
+      name: Option[String] = None)(body: => T): T = {
     val sc = sparkSession.sparkContext
     val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
     val executionId = SQLExecution.nextExecutionId
@@ -71,14 +72,35 @@
     val callSite = sc.getCallSite()

     withSQLConfPropagated(sparkSession) {
-      sc.listenerBus.post(SparkListenerSQLExecutionStart(
-        executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
-        SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
+      var ex: Option[Exception] = None
+      val startTime = System.currentTimeMillis()
       try {
+        sc.listenerBus.post(SparkListenerSQLExecutionStart(
+          executionId = executionId,
+          description = callSite.shortForm,
+          details = callSite.longForm,
+          physicalPlanDescription = queryExecution.toString,
+          // `queryExecution.executedPlan` triggers query planning. If it fails, the exception
+          // will be caught and reported in the `SparkListenerSQLExecutionEnd` event.
+          sparkPlanInfo = SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan),
+          time = startTime))
         body
+      } catch {
+        case e: Exception =>
+          ex = Some(e)
+          throw e
       } finally {
-        sc.listenerBus.post(SparkListenerSQLExecutionEnd(
-          executionId, System.currentTimeMillis()))
+        val endTime = System.currentTimeMillis()
+        val event = SparkListenerSQLExecutionEnd(executionId, endTime)
+        // Currently only `Dataset.withAction` and `DataFrameWriter.runCommand` specify the `name`
+        // parameter. The `ExecutionListenerManager` only watches SQL executions with a name. We
+        // can specify the execution name in more places in the future, so that
+        // `QueryExecutionListener` can track more cases.
+        event.executionName = name
+        event.duration = endTime - startTime
+        event.qe = queryExecution
+        event.executionFailure = ex
+        sc.listenerBus.post(event)
       }
     }
   } finally {
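The new `name` parameter is the opt-in switch for listener tracking: an unnamed execution still posts start/end events for the UI, but `ExecutionListenerManager` ignores it. A minimal sketch of an internal caller opting in (the `runNamed` helper is hypothetical):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}

// Hypothetical helper: a named execution is visible to QueryExecutionListeners,
// an unnamed one (name = None, the default) only feeds the SQL UI.
def runNamed[T](session: SparkSession, qe: QueryExecution, name: String)(body: => T): T =
  SQLExecution.withNewExecutionId(session, qe, Some(name))(body)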

sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala

Lines changed: 10 additions & 3 deletions
@@ -17,15 +17,15 @@

 package org.apache.spark.sql.execution.ui

+import com.fasterxml.jackson.annotation.JsonIgnore
 import com.fasterxml.jackson.databind.JavaType
 import com.fasterxml.jackson.databind.`type`.TypeFactory
 import com.fasterxml.jackson.databind.annotation.JsonDeserialize
 import com.fasterxml.jackson.databind.util.Converter

 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.scheduler._
-import org.apache.spark.sql.execution.SparkPlanInfo
-import org.apache.spark.sql.execution.metric._
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlanInfo}

 @DeveloperApi
 case class SparkListenerSQLExecutionStart(
@@ -39,7 +39,14 @@ case class SparkListenerSQLExecutionStart(

 @DeveloperApi
 case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long)
-  extends SparkListenerEvent
+  extends SparkListenerEvent {
+
+  @JsonIgnore private[sql] var executionName: Option[String] = None
+  // These 3 fields are only accessed when `executionName` is defined.
+  @JsonIgnore private[sql] var duration: Long = 0L
+  @JsonIgnore private[sql] var qe: QueryExecution = null
+  @JsonIgnore private[sql] var executionFailure: Option[Exception] = None
+}

 /**
  * A message used to update SQL metric value for driver-side updates (which doesn't get reflected
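The four new fields are mutable `var`s marked `@JsonIgnore`, so the JSON shape of `SparkListenerSQLExecutionEnd` in the event log is unchanged; the backward-compatibility test below exercises exactly this. A sketch of how the driver populates the event (values are illustrative):

import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd

val event = SparkListenerSQLExecutionEnd(executionId = 42L, time = System.currentTimeMillis())
event.executionName = Some("collect")
event.duration = 123L
event.executionFailure = None
// event.qe would carry the live QueryExecution; none of these vars are serialized.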

sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala

Lines changed: 2 additions & 2 deletions
@@ -266,8 +266,8 @@ abstract class BaseSessionStateBuilder(
    * This gets cloned from parent if available, otherwise a new instance is created.
    */
   protected def listenerManager: ExecutionListenerManager = {
-    parentState.map(_.listenerManager.clone()).getOrElse(
-      new ExecutionListenerManager(session.sparkContext.conf))
+    parentState.map(_.listenerManager.clone(session)).getOrElse(
+      new ExecutionListenerManager(session, loadExtensions = true))
   }

   /**
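`clone(session)` binds the copy to the child session and passes `loadExtensions = false`, so listeners configured via `spark.sql.queryExecutionListeners` are not instantiated a second time when a session is forked. A usage sketch of the session scoping this enables:

// Sketch: registered listeners are copied into the child session's manager,
// and each manager only reacts to executions of its own session.
spark.listenerManager.register(new TimingListener)
val child = spark.newSession()
child.range(5).collect()  // seen by the child's cloned manager, not the parent's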

sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala

Lines changed: 34 additions & 59 deletions
@@ -17,15 +17,16 @@

 package org.apache.spark.sql.util

-import java.util.concurrent.locks.ReentrantReadWriteLock
+import java.util.concurrent.CopyOnWriteArrayList

-import scala.collection.mutable.ListBuffer
-import scala.util.control.NonFatal
+import scala.collection.JavaConverters._

-import org.apache.spark.SparkConf
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
 import org.apache.spark.internal.Logging
+import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd
 import org.apache.spark.sql.internal.StaticSQLConf._
 import org.apache.spark.util.Utils

@@ -75,95 +76,69 @@ trait QueryExecutionListener {
  */
 @Experimental
 @InterfaceStability.Evolving
-class ExecutionListenerManager private extends Logging {
+class ExecutionListenerManager private[sql](session: SparkSession, loadExtensions: Boolean)
+  extends SparkListener with Logging {

-  private[sql] def this(conf: SparkConf) = {
-    this()
+  private[this] val listeners = new CopyOnWriteArrayList[QueryExecutionListener]
+
+  if (loadExtensions) {
+    val conf = session.sparkContext.conf
     conf.get(QUERY_EXECUTION_LISTENERS).foreach { classNames =>
       Utils.loadExtensions(classOf[QueryExecutionListener], classNames, conf).foreach(register)
     }
   }

+  session.sparkContext.listenerBus.addToSharedQueue(this)
+
   /**
    * Registers the specified [[QueryExecutionListener]].
    */
   @DeveloperApi
-  def register(listener: QueryExecutionListener): Unit = writeLock {
-    listeners += listener
+  def register(listener: QueryExecutionListener): Unit = {
+    listeners.add(listener)
   }

   /**
    * Unregisters the specified [[QueryExecutionListener]].
    */
   @DeveloperApi
-  def unregister(listener: QueryExecutionListener): Unit = writeLock {
-    listeners -= listener
+  def unregister(listener: QueryExecutionListener): Unit = {
+    listeners.remove(listener)
   }

   /**
    * Removes all the registered [[QueryExecutionListener]].
    */
   @DeveloperApi
-  def clear(): Unit = writeLock {
+  def clear(): Unit = {
     listeners.clear()
   }

   /**
    * Get an identical copy of this listener manager.
    */
-  @DeveloperApi
-  override def clone(): ExecutionListenerManager = writeLock {
-    val newListenerManager = new ExecutionListenerManager
-    listeners.foreach(newListenerManager.register)
+  private[sql] def clone(session: SparkSession): ExecutionListenerManager = {
+    val newListenerManager = new ExecutionListenerManager(session, loadExtensions = false)
+    listeners.iterator().asScala.foreach(newListenerManager.register)
     newListenerManager
   }

-  private[sql] def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
-    readLock {
-      withErrorHandling { listener =>
-        listener.onSuccess(funcName, qe, duration)
+  override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
+    case e: SparkListenerSQLExecutionEnd if shouldCatchEvent(e) =>
+      val funcName = e.executionName.get
+      e.executionFailure match {
+        case Some(ex) =>
+          listeners.iterator().asScala.foreach(_.onFailure(funcName, e.qe, ex))
+        case _ =>
+          listeners.iterator().asScala.foreach(_.onSuccess(funcName, e.qe, e.duration))
       }
-    }
-  }

-  private[sql] def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
-    readLock {
-      withErrorHandling { listener =>
-        listener.onFailure(funcName, qe, exception)
-      }
-    }
+    case _ => // Ignore
   }

-  private[this] val listeners = ListBuffer.empty[QueryExecutionListener]
-
-  /** A lock to prevent updating the list of listeners while we are traversing through them. */
-  private[this] val lock = new ReentrantReadWriteLock()
-
-  private def withErrorHandling(f: QueryExecutionListener => Unit): Unit = {
-    for (listener <- listeners) {
-      try {
-        f(listener)
-      } catch {
-        case NonFatal(e) => logWarning("Error executing query execution listener", e)
-      }
-    }
-  }
-
-  /** Acquires a read lock on the cache for the duration of `f`. */
-  private def readLock[A](f: => A): A = {
-    val rl = lock.readLock()
-    rl.lock()
-    try f finally {
-      rl.unlock()
-    }
-  }
-
-  /** Acquires a write lock on the cache for the duration of `f`. */
-  private def writeLock[A](f: => A): A = {
-    val wl = lock.writeLock()
-    wl.lock()
-    try f finally {
-      wl.unlock()
-    }
+  private def shouldCatchEvent(e: SparkListenerSQLExecutionEnd): Boolean = {
+    // Only catch SQL executions with a name, and triggered by the same spark session that this
+    // listener manager belongs to.
+    e.executionName.isDefined && e.qe.sparkSession.eq(this.session)
   }
 }
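The `CopyOnWriteArrayList` replaces the old read/write lock (registration mutates a copy while delivery iterates a stable snapshot), and per-listener error handling now falls to the listener bus's own logging. The practical consequence is asynchrony: code that used to observe a callback immediately after an action must now wait for delivery. A minimal sketch using a latch (names and the timeout are illustrative):

import java.util.concurrent.{CountDownLatch, TimeUnit}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

// Count down on either outcome so a waiter is always released.
class LatchedListener extends QueryExecutionListener {
  val done = new CountDownLatch(1)
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
    done.countDown()
  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
    done.countDown()
}

val listener = new LatchedListener
spark.listenerManager.register(listener)
spark.range(1).collect()
// The callback is now asynchronous; wait for it instead of asserting inline.
assert(listener.done.await(10, TimeUnit.SECONDS), "listener not notified in time")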

sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala

Lines changed: 30 additions & 3 deletions
@@ -17,13 +17,15 @@

 package org.apache.spark.sql.execution

-import org.json4s.jackson.JsonMethods.parse
+import org.json4s.jackson.JsonMethods._

 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart
+import org.apache.spark.sql.LocalSparkSession
+import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}
+import org.apache.spark.sql.test.TestSparkSession
 import org.apache.spark.util.JsonProtocol

-class SQLJsonProtocolSuite extends SparkFunSuite {
+class SQLJsonProtocolSuite extends SparkFunSuite with LocalSparkSession {

   test("SparkPlanGraph backward compatibility: metadata") {
     val SQLExecutionStartJsonString =
@@ -49,4 +51,29 @@ class SQLJsonProtocolSuite extends SparkFunSuite {
       new SparkPlanInfo("TestNode", "test string", Nil, Map(), Nil), 0)
     assert(reconstructedEvent == expectedEvent)
   }
+
+  test("SparkListenerSQLExecutionEnd backward compatibility") {
+    spark = new TestSparkSession()
+    val qe = spark.sql("select 1").queryExecution
+    val event = SparkListenerSQLExecutionEnd(1, 10)
+    event.duration = 1000
+    event.executionName = Some("test")
+    event.qe = qe
+    event.executionFailure = Some(new RuntimeException("test"))
+    val json = JsonProtocol.sparkEventToJson(event)
+    assert(json == parse(
+      """
+        |{
+        |  "Event" : "org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd",
+        |  "executionId" : 1,
+        |  "time" : 10
+        |}
+      """.stripMargin))
+    val readBack = JsonProtocol.sparkEventFromJson(json)
+    event.duration = 0
+    event.executionName = None
+    event.qe = null
+    event.executionFailure = None
+    assert(readBack == event)
+  }
 }
